import os, sys, shutil, logging
from typing import List, Dict
from pathlib import Path
from copy import deepcopy
from dflow.python import (
OP,
OPIO,
OPIOSign,
Artifact
)
from rid.utils import load_json
from rid.constants import model_tag_fmt, init_conf_name, init_input_name, walker_tag_fmt
# Configure the root logger once, at import time. Records are written to
# stdout with a timestamp, level, logger name and message. The threshold is
# taken from the LOGLEVEL environment variable (upper-cased) and falls back
# to INFO when the variable is unset.
# NOTE(review): per the stdlib logging documentation, basicConfig() is a
# no-op if the root logger already has handlers, so an application that
# configured logging before importing this module keeps its own settings.
logging.basicConfig(
format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
level=os.environ.get("LOGLEVEL", "INFO").upper(),
stream=sys.stdout,
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def prep_confs(confs, numb_walkers):
    """Copy initial conformation files so that there is exactly one per walker.

    Each walker ``idx`` in ``range(numb_walkers)`` gets a copy of
    ``confs[idx % len(confs)]`` written to ``init_conf_name.format(idx=idx)``
    (a path relative to the current working directory unless the format
    string itself is absolute):

    * fewer confs than walkers -- the confs are reused cyclically;
    * more confs than walkers -- only the first ``numb_walkers`` are used;
    * equal numbers -- a one-to-one copy.

    Parameters
    ----------
    confs : sequence of path-like
        Source conformation files. Must support ``len()`` and integer
        indexing.
    numb_walkers : int
        Number of walkers, i.e. number of files to produce.

    Returns
    -------
    list of pathlib.Path
        The destination paths, in walker order.

    Raises
    ------
    ValueError
        If ``confs`` is empty while ``numb_walkers`` is positive. (Previously
        this surfaced as a ``ZeroDivisionError`` from the modulo operation.)
    """
    numb_confs = len(confs)

    # Fail early, before touching the file system, with an actionable message
    # instead of letting `idx % 0` raise ZeroDivisionError further down.
    if numb_walkers > 0 and numb_confs == 0:
        raise ValueError(
            "No conformation files were provided, but %d walker(s) were "
            "requested; at least one conformation is required." % numb_walkers
        )

    if numb_confs < numb_walkers:
        logger.info("Number of confs is smaller than number of walkers. Copy replicas up to number of walkers.")
    elif numb_confs > numb_walkers:
        logger.info("Number of confs is greater than number of walkers. Only use the fist `numb_walkers` confs.")

    conf_list = []
    for idx in range(numb_walkers):
        target = init_conf_name.format(idx=idx)
        # When numb_confs >= numb_walkers, idx % numb_confs == idx, so this
        # single expression covers all three cases described above.
        shutil.copyfile(confs[idx % numb_confs], target)
        conf_list.append(Path(target))
    return conf_list
class PrepRiD(OP):
    """Pre-processing of RiD.

    1. Parse the RiD configuration JSON file.
    2. Rearrange conformation files so that there is one per walker.
    3. Make walker and model tag names.

    NOTE(review): every configuration key read in `execute` is fetched with
    ``pop``/indexing and no fallback, so a missing key raises ``KeyError``
    unless `load_json` fills in default values itself -- confirm against
    `rid.utils.load_json`.
    """

    @classmethod
    def get_input_sign(cls):
        """Declare the inputs consumed by `execute`.

        `confs` is used as a sequence of files (it is measured with ``len``
        and indexed when the conformations are rearranged); `rid_config` is a
        single JSON file handed to `load_json`.
        """
        return OPIOSign(
            {
                "confs": Artifact(List[Path]),
                "rid_config": Artifact(Path),
            }
        )

    @classmethod
    def get_output_sign(cls):
        """Declare the outputs produced by `execute`."""
        return OPIOSign(
            {
                "numb_iters": int,
                "numb_walkers": int,
                "numb_models": int,
                "confs": Artifact(List[Path]),
                "walker_tags": List,
                "model_tags": List,
                "exploration_config": Dict,
                "cv_config": Dict,
                "trust_lvl_1": List[float],
                "trust_lvl_2": List[float],
                "cluster_threshold": List[float],
                "angular_mask": List,
                "weights": List,
                "numb_cluster_upper": int,
                "numb_cluster_lower": int,
                "max_selection": int,
                "numb_cluster_threshold": int,
                "dt": float,
                "output_freq": float,
                "slice_mode": str,
                "label_config": Dict,
                "kappas": List,
                "train_config": Dict,
            }
        )

    @OP.exec_sign_check
    def execute(
        self,
        op_in: OPIO,
    ) -> OPIO:
        r"""Execute the OP.

        Parameters
        ----------
        op_in : dict
            Input dict with components:

            - `confs`: (`Artifact(List[Path])`) User-provided initial conformation files (.gro) for reinforced dynamics.
            - `rid_config`: (`Artifact(Path)`) Configuration file (.json) of RiD.
              Parameters in this file will be parsed.

        Returns
        -------
            Output dict with components:

            - `numb_iters`: (`int`) Max number of iterations for RiD.
            - `numb_walkers`: (`int`) Number of parallel walkers for exploration.
            - `numb_models`: (`int`) Number of neural network models of RiD.
            - `confs`: (`Artifact(List[Path])`) Rearranged initial conformation files (.gro) for reinforced dynamics.
            - `walker_tags`: (`List`) Tags for parallel walkers, one per walker.
            - `model_tags`: (`List`) Tags for neural network models, one per model.
            - `exploration_config`: (`Dict`) Configuration of simulations in exploration steps.
            - `cv_config`: (`Dict`) Configuration to create CV in PLUMED2 style.
            - `trust_lvl_1`: (`List[float]`) Trust level 1, or e0, repeated once per walker.
            - `trust_lvl_2`: (`List[float]`) Trust level 2, or e1, repeated once per walker.
            - `cluster_threshold`: (`List[float]`) Initial guess of cluster threshold, repeated once per walker.
            - `angular_mask`: (`List`) Angular mask for periodic collective variables.
              1 represents periodic, 0 represents non-periodic.
            - `weights`: (`List`) Weights for clustering collective variables. See details in cluster algorithms.
            - `numb_cluster_upper`: (`int`) Upper limit of cluster number to make cluster threshold.
            - `numb_cluster_lower`: (`int`) Lower limit of cluster number to make cluster threshold.
            - `max_selection`: (`int`) Max selection number of clusters in Selection steps for each parallel walker.
            - `numb_cluster_threshold`: (`int`) Used to adjust trust level. When cluster number is greater than this threshold,
              trust levels will be increased adaptively.
            - `dt`: (`float`) Time interval of exploration MD simulations. Gromacs `trjconv` commands will need this parameter
              to slice trajectories by `-dump` tag, see `selection` steps for detail.
            - `output_freq`: (`float`) The `output_freq` entry of the exploration configuration
              (presumably the output frequency of exploration MD -- confirm against the config schema).
            - `slice_mode`: (`str`) Mode to slice trajectories. Either `gmx` or `mdtraj`.
            - `label_config`: (`Dict`) Configuration of simulations in labeling steps.
            - `kappas`: (`List`) Force constants of harmonic restraints to perform restrained MD simulations.
            - `train_config`: (`Dict`) Configuration to train neural networks, including training strategy and network structures.

        Raises
        ------
        KeyError
            If a key read below is absent from the parsed configuration.
        """
        # Work on a private copy: the sections below are consumed with pop().
        jdata = deepcopy(load_json(op_in["rid_config"]))

        numb_walkers = jdata.pop("numb_walkers")
        train_config = jdata.pop("Train")
        numb_models = train_config.pop("numb_models")
        numb_iters = jdata.pop("numb_iters")

        # Copies the conformations into place, one file per walker.
        conf_list = prep_confs(op_in["confs"], numb_walkers)

        walker_tags = [walker_tag_fmt.format(idx=idx) for idx in range(numb_walkers)]
        model_tags = [model_tag_fmt.format(idx=idx) for idx in range(numb_models)]

        exploration_config = jdata.pop("ExploreMDConfig")
        # Read without popping: both values stay inside `exploration_config`
        # and are additionally exposed as stand-alone outputs.
        dt = exploration_config["dt"]
        output_freq = exploration_config["output_freq"]

        cv_config = jdata.pop("CV")
        angular_mask = cv_config.pop("angular_mask")
        weights = cv_config.pop("weights")

        selection_config = jdata.pop("SelectorConfig")

        label_config = jdata.pop("LabelMDConfig")
        kappas = label_config.pop("kappas")

        # Each walker starts from the same trust levels and cluster threshold.
        trust_lvl_1_list = [jdata.pop("trust_lvl_1")] * numb_walkers
        trust_lvl_2_list = [jdata.pop("trust_lvl_2")] * numb_walkers
        cluster_threshold_list = [selection_config.pop("cluster_threshold")] * numb_walkers

        op_out = OPIO(
            {
                "numb_iters": numb_iters,
                "numb_walkers": numb_walkers,
                "numb_models": numb_models,
                "confs": conf_list,
                "walker_tags": walker_tags,
                "model_tags": model_tags,
                "exploration_config": exploration_config,
                "cv_config": cv_config,
                "trust_lvl_1": trust_lvl_1_list,
                "trust_lvl_2": trust_lvl_2_list,
                "cluster_threshold": cluster_threshold_list,
                "angular_mask": angular_mask,
                "weights": weights,
                "numb_cluster_upper": selection_config.pop("numb_cluster_upper"),
                "numb_cluster_lower": selection_config.pop("numb_cluster_lower"),
                "max_selection": selection_config.pop("max_selection"),
                "numb_cluster_threshold": selection_config.pop("numb_cluster_threshold"),
                "dt": dt,
                "output_freq": output_freq,
                "slice_mode": selection_config.pop("slice_mode"),
                "label_config": label_config,
                "kappas": kappas,
                "train_config": train_config,
            }
        )
        return op_out