Source code for rid.superop.label

from typing import Dict, List
from copy import deepcopy
from dflow import (
    InputParameter,
    Inputs,
    InputArtifact,
    Outputs,
    OutputArtifact,
    Step,
    Steps,
    argo_range,
    argo_len,
)
from dflow.python import(
    PythonOPTemplate,
    OP,
    Slices,
)
from rid.utils import init_executor


class Label(Steps):
    """Steps template that chains the check, prep, run and post labeling OPs.

    The constructor declares the input/output parameters and artifacts of the
    template, builds the per-step keys from the ``block_tag`` input parameter,
    and delegates the creation of the four steps to :func:`_label`.

    Parameters
    ----------
    name : str
        Name passed to the ``Steps`` base class.
    check_input_op : OP
        OP used by the ``check-label-inputs`` step.
    prep_op : OP
        OP used by the ``prep-label`` step.
    run_op : OP
        OP used by the ``run-label`` step.
    post_op : OP
        OP used by the ``post-label`` step.
    prep_config : Dict
        Step configuration for the check and prep steps. It is also handed to
        :func:`_label` as the post-step configuration. Must contain the
        ``template_config`` and ``executor`` entries that ``_label`` pops.
    run_config : Dict
        Step configuration for the run step (same required entries).
    upload_python_package : optional
        Forwarded as ``python_packages`` to every ``PythonOPTemplate``.
    """

    def __init__(
        self,
        name: str,
        check_input_op: OP,
        prep_op: OP,
        run_op: OP,
        post_op: OP,
        prep_config: Dict,
        run_config: Dict,
        upload_python_package=None,
    ):
        self._input_parameters = {
            "label_config": InputParameter(type=Dict),
            "cv_config": InputParameter(type=Dict),
            "kappas": InputParameter(type=List[float]),
            "angular_mask": InputParameter(type=List),
            "tail": InputParameter(type=float, value=0.9),
            "conf_tags": InputParameter(type=List),
            "block_tag": InputParameter(type=str, value=""),
        }
        self._input_artifacts = {
            "topology": InputArtifact(optional=True),
            "models": InputArtifact(optional=True),
            "forcefield": InputArtifact(optional=True),
            "inputfile": InputArtifact(optional=True),
            "confs": InputArtifact(),
            "at": InputArtifact(),
            "index_file": InputArtifact(optional=True),
            "dp_files": InputArtifact(optional=True),
            "cv_file": InputArtifact(optional=True),
        }
        self._output_parameters = {}
        self._output_artifacts = {
            "md_log": OutputArtifact(),
            "forces": OutputArtifact(),
        }

        super().__init__(
            name=name,
            inputs=Inputs(
                parameters=self._input_parameters,
                artifacts=self._input_artifacts,
            ),
            outputs=Outputs(
                parameters=self._output_parameters,
                artifacts=self._output_artifacts,
            ),
        )

        # Every step key is prefixed with the (templated) block_tag parameter.
        block_tag = self.inputs.parameters["block_tag"]
        step_keys = {
            "check_label_inputs": "{}-check-label-inputs".format(block_tag),
            "prep_label": "{}-prep-label".format(block_tag),
            "run_label": "{}-run-label".format(block_tag),
            "post_label": "{}-post-label".format(block_tag),
        }
        # Back the ``keys`` property: without this assignment reading
        # ``self.keys`` raised AttributeError because ``_keys`` never existed.
        self._keys = list(step_keys)

        # ``_label`` adds the steps to this object in place and returns it, so
        # the return value does not need to be rebound.
        _label(
            self,
            step_keys,
            check_input_op,
            prep_op,
            run_op,
            post_op,
            prep_config=prep_config,
            run_config=run_config,
            # There is no dedicated post_config argument; the post step
            # reuses the prep configuration.
            post_config=prep_config,
            upload_python_package=upload_python_package,
        )

    @property
    def input_parameters(self):
        """Dict of the input parameters declared in ``__init__``."""
        return self._input_parameters

    @property
    def input_artifacts(self):
        """Dict of the input artifacts declared in ``__init__``."""
        return self._input_artifacts

    @property
    def output_parameters(self):
        """Dict of the output parameters declared in ``__init__`` (empty)."""
        return self._output_parameters

    @property
    def output_artifacts(self):
        """Dict of the output artifacts declared in ``__init__``."""
        return self._output_artifacts

    @property
    def keys(self):
        """List of the logical step names used as keys of ``step_keys``."""
        return self._keys
def _label(
    label_steps,
    step_keys,
    check_label_input_op: OP,
    prep_label_op: OP,
    run_label_op: OP,
    post_label_op: OP,
    prep_config: Dict,
    run_config: Dict,
    post_config: Dict,
    upload_python_package: str = None,
):
    """Add the check, prep, run and post labeling steps to ``label_steps``.

    Parameters
    ----------
    label_steps
        Steps object whose ``inputs``/``outputs`` are wired up and to which
        the four steps are added, in order.
    step_keys
        Mapping with the entries ``check_label_inputs``, ``prep_label``,
        ``run_label`` and ``post_label``; each value is used as the step key
        (the sliced steps append ``"-{{item}}"``).
    check_label_input_op, prep_label_op, run_label_op, post_label_op : OP
        OPs wrapped in a ``PythonOPTemplate`` for the respective step.
    prep_config, run_config, post_config : Dict
        Step configurations. Each is deep-copied, then ``template_config``
        and ``executor`` are popped from the copy; whatever remains is passed
        to ``Step`` as keyword arguments. The check step shares the prep
        configuration.
    upload_python_package : str, optional
        Passed as ``python_packages`` to every ``PythonOPTemplate``.

    Returns
    -------
    The same ``label_steps`` object, with its ``forces`` and ``md_log``
    output artifacts sourced from the post and run steps.

    Raises
    ------
    KeyError
        If a configuration lacks ``template_config`` or ``executor``.
    """
    # Work on private copies so the caller's dictionaries are never mutated.
    prep_step_kwargs = deepcopy(prep_config)
    run_step_kwargs = deepcopy(run_config)
    post_step_kwargs = deepcopy(post_config)

    prep_tmpl_kwargs = prep_step_kwargs.pop('template_config')
    run_tmpl_kwargs = run_step_kwargs.pop('template_config')
    post_tmpl_kwargs = post_step_kwargs.pop('template_config')

    prep_exec = init_executor(prep_step_kwargs.pop('executor'))
    run_exec = init_executor(run_step_kwargs.pop('executor'))
    post_exec = init_executor(post_step_kwargs.pop('executor'))

    in_params = label_steps.inputs.parameters
    in_arts = label_steps.inputs.artifacts

    # --- step 1: check the labeling inputs -------------------------------
    check_template = PythonOPTemplate(
        check_label_input_op,
        python_packages=upload_python_package,
        **prep_tmpl_kwargs,
    )
    check_label_inputs = Step(
        'check-label-inputs',
        template=check_template,
        parameters={"conf_tags": in_params['conf_tags']},
        artifacts={"confs": in_arts['confs']},
        key=step_keys['check_label_inputs'],
        executor=prep_exec,
        **prep_step_kwargs,
    )
    label_steps.add(check_label_inputs)

    checked = check_label_inputs.outputs.parameters
    checked_tags = checked['conf_tags']

    # --- step 2: prepare one labeling task per slice ---------------------
    prep_template = PythonOPTemplate(
        prep_label_op,
        python_packages=upload_python_package,
        slices=Slices(
            "{{item}}",
            input_parameter=["task_name"],
            input_artifact=["conf", "at"],
            output_artifact=["task_path"],
        ),
        **prep_tmpl_kwargs,
    )
    prep_label = Step(
        'prep-label',
        template=prep_template,
        parameters={
            "label_config": in_params['label_config'],
            "cv_config": in_params['cv_config'],
            "task_name": checked_tags,
            "kappas": in_params['kappas'],
        },
        artifacts={
            "topology": in_arts['topology'],
            "conf": in_arts['confs'],
            "at": in_arts['at'],
            "cv_file": in_arts['cv_file'],
        },
        key=step_keys['prep_label'] + "-{{item}}",
        executor=prep_exec,
        with_param=argo_range(argo_len(checked_tags)),
        # Only this step is guarded by the check step's "if_continue" output.
        when="%s > 0" % (checked["if_continue"]),
        **prep_step_kwargs,
    )
    label_steps.add(prep_label)

    # --- step 3: run every prepared task ---------------------------------
    run_template = PythonOPTemplate(
        run_label_op,
        python_packages=upload_python_package,
        slices=Slices(
            "{{item}}",
            input_artifact=["task_path"],
            output_artifact=["plm_out", "md_log"],
        ),
        **run_tmpl_kwargs,
    )
    run_label = Step(
        'run-label',
        template=run_template,
        parameters={"label_config": in_params["label_config"]},
        artifacts={
            "forcefield": in_arts['forcefield'],
            "task_path": prep_label.outputs.artifacts["task_path"],
            "index_file": in_arts['index_file'],
            "dp_files": in_arts['dp_files'],
            "cv_file": in_arts['cv_file'],
            "inputfile": in_arts['inputfile'],
        },
        key=step_keys['run_label'] + "-{{item}}",
        executor=run_exec,
        with_param=argo_range(argo_len(checked_tags)),
        **run_step_kwargs,
    )
    label_steps.add(run_label)

    # --- step 4: post-process the run outputs into forces ----------------
    post_template = PythonOPTemplate(
        post_label_op,
        python_packages=upload_python_package,
        slices=Slices(
            "{{item}}",
            input_parameter=["task_name"],
            input_artifact=["plm_out", "at"],
            output_artifact=["forces"],
        ),
        **post_tmpl_kwargs,
    )
    post_label = Step(
        'post-label',
        template=post_template,
        parameters={
            "task_name": checked_tags,
            "kappas": in_params['kappas'],
            "tail": in_params['tail'],
            "angular_mask": in_params['angular_mask'],
        },
        artifacts={
            "plm_out": run_label.outputs.artifacts["plm_out"],
            "at": in_arts['at'],
        },
        key=step_keys['post_label'] + "-{{item}}",
        executor=post_exec,
        with_param=argo_range(argo_len(checked_tags)),
        **post_step_kwargs,
    )
    label_steps.add(post_label)

    # Expose the step results as the outputs of the whole Steps template.
    out_arts = label_steps.outputs.artifacts
    out_arts["forces"]._from = post_label.outputs.artifacts["forces"]
    out_arts["md_log"]._from = run_label.outputs.artifacts["md_log"]

    return label_steps