Source code for rid.superop.data

from typing import Dict, List
from copy import deepcopy
from dflow import (
    InputParameter,
    Inputs,
    InputArtifact,
    Outputs,
    OutputArtifact,
    Step,
    Steps
)
from dflow.python import(
    PythonOPTemplate,
    OP
)
from rid.utils import init_executor


class DataGenerator(Steps):
    """dflow super-OP that collects new data and merges it with existing data.

    The workflow is populated by ``_gen_data`` with two steps, in order:

    1. ``collect-data`` runs ``collect_op`` on the ``forces`` and ``centers``
       input artifacts.
    2. ``merge-data`` runs ``merge_op`` on the optional ``data_old`` input
       artifact and the ``data_new`` artifact produced by step 1.

    The ``data_raw`` artifact of the merge step is exported as this
    super-OP's ``data`` output artifact.

    Args:
        name: Name of the ``Steps`` template.
        collect_op: OP class run by the collect step.
        merge_op: OP class run by the merge step.
        run_config: Must contain ``template_config`` and ``executor`` entries;
            ``_gen_data`` deep-copies it, pops those two entries and forwards
            any remaining entries to each ``Step``.
        upload_python_package: Forwarded to ``PythonOPTemplate`` as
            ``python_packages``.
    """

    def __init__(
        self,
        name: str,
        collect_op: OP,
        merge_op: OP,
        run_config: Dict,
        upload_python_package=None,
    ):
        self._input_parameters = {
            "block_tag": InputParameter(type=str, value=""),
        }
        self._input_artifacts = {
            "forces": InputArtifact(),
            "centers": InputArtifact(),
            "data_old": InputArtifact(optional=True),
        }
        self._output_parameters = {}
        self._output_artifacts = {
            "data": OutputArtifact(),
        }
        # Names of the steps added by ``_gen_data``; their dflow keys are
        # "<block_tag>-collect-data" and "<block_tag>-merge-data".
        # This attribute backs the ``keys`` property, which otherwise
        # raised AttributeError because ``_keys`` was never assigned.
        self._keys = ["collect-data", "merge-data"]

        super().__init__(
            name=name,
            inputs=Inputs(
                parameters=self._input_parameters,
                artifacts=self._input_artifacts,
            ),
            outputs=Outputs(
                parameters=self._output_parameters,
                artifacts=self._output_artifacts,
            ),
        )

        # ``_gen_data`` adds the steps to ``self`` in place (and returns the
        # same object), so there is no need to rebind ``self`` here.
        _gen_data(
            self,
            collect_op,
            merge_op,
            run_config=run_config,
            upload_python_package=upload_python_package,
        )

    @property
    def input_parameters(self):
        """Input parameter templates (``block_tag``)."""
        return self._input_parameters

    @property
    def input_artifacts(self):
        """Input artifact templates (``forces``, ``centers``, ``data_old``)."""
        return self._input_artifacts

    @property
    def output_parameters(self):
        """Output parameter templates (none)."""
        return self._output_parameters

    @property
    def output_artifacts(self):
        """Output artifact templates (``data``)."""
        return self._output_artifacts

    @property
    def keys(self):
        """Names of the steps in this super-OP, in execution order."""
        return self._keys
def _gen_data(
    data_steps,
    collect_op : OP,
    merge_op : OP,
    run_config : Dict,
    upload_python_package : str = None,
):
    """Add the collect and merge steps to ``data_steps`` and wire its output.

    ``run_config`` is deep-copied first, so the caller's dict is untouched.
    Its ``template_config`` entry is unpacked into every ``PythonOPTemplate``.
    Its ``executor`` entry is turned into an executor via ``init_executor``.
    All remaining entries are passed as extra keyword arguments to each
    ``Step``.

    Args:
        data_steps: The ``Steps`` object to populate; it must declare the
            ``block_tag`` parameter and the ``forces``, ``centers`` and
            ``data_old`` input artifacts, plus a ``data`` output artifact.
        collect_op: OP run by the ``collect-data`` step.
        merge_op: OP run by the ``merge-data`` step.
        run_config: Step/template configuration (see above).
        upload_python_package: Forwarded as ``python_packages``.

    Returns:
        ``data_steps`` itself, after the two steps have been added.
    """
    config = deepcopy(run_config)
    template_kwargs = config.pop('template_config')
    executor = init_executor(config.pop('executor'))

    step_inputs = data_steps.inputs
    block_tag = step_inputs.parameters["block_tag"]

    def make_template(op):
        # Both steps share identical template settings and uploaded packages.
        return PythonOPTemplate(
            op,
            python_packages=upload_python_package,
            **template_kwargs,
        )

    # Step 1: gather fresh data from the forces/centers inputs.
    collect_step = Step(
        'collect-data',
        template=make_template(collect_op),
        parameters={},
        artifacts={
            "forces": step_inputs.artifacts['forces'],
            "centers": step_inputs.artifacts['centers'],
        },
        key="{}-collect-data".format(block_tag),
        executor=executor,
        **config,
    )
    data_steps.add(collect_step)

    # Step 2: combine the (optional) previous data with the fresh data.
    merge_step = Step(
        'merge-data',
        template=make_template(merge_op),
        parameters={},
        artifacts={
            "data_old": step_inputs.artifacts["data_old"],
            "data_new": collect_step.outputs.artifacts["data_new"],
        },
        key="{}-merge-data".format(block_tag),
        executor=executor,
        **config,
    )
    data_steps.add(merge_step)

    # Expose the merged result as the super-OP's "data" output.
    data_steps.outputs.artifacts["data"]._from = merge_step.outputs.artifacts["data_raw"]
    return data_steps