Skip to content

dvc_pipeline_writer

DVCPipelineWriter

Bases: BasePipelineWriter

Class for pipeline file writer. Corresponds to "DVC" framework.

Source code in lineapy/plugins/dvc_pipeline_writer.py
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
class DVCPipelineWriter(BasePipelineWriter):
    """
    Class for pipeline file writer. Corresponds to "DVC" framework.

    Depending on the ``dag_flavor`` entry of ``self.dag_config`` the writer
    emits a ``dvc.yaml`` file into ``self.output_dir`` and, for the
    ``StagePerArtifact`` flavor, an additional ``params.yaml`` plus one
    ``task_<name>.py`` file per task.
    """

    @property
    def docker_template_name(self) -> str:
        """Return the template path of the Dockerfile used for DVC."""
        return "dvc/dvc_dockerfile.jinja"

    def _write_dag(self) -> None:
        """
        Render the DAG for the configured flavor and write it to ``dvc.yaml``.

        The flavor is read from ``self.dag_config["dag_flavor"]`` and falls
        back to ``"StagePerArtifact"`` when the key is absent.

        Raises
        ------
        ValueError
            If the flavor is not a member of ``DVCDagFlavor``, or if it is a
            member that this writer has no code generator for.
        """
        dag_flavor = self.dag_config.get("dag_flavor", "StagePerArtifact")

        # Check if the given DAG flavor is a supported/valid one
        if dag_flavor not in DVCDagFlavor.__members__:
            raise ValueError(f'"{dag_flavor}" is an invalid dvc dag flavor.')

        # Construct DAG text for the given flavor. The branches are mutually
        # exclusive, and the final ``else`` guarantees ``dvc_yaml_code`` is
        # always bound before it is written out (two independent ``if``
        # statements would leave it unbound for any other enum member).
        flavor = DVCDagFlavor[dag_flavor]
        if flavor == DVCDagFlavor.SingleStageAllSessions:
            dvc_yaml_code = self._write_operator_run_all_sessions()
        elif flavor == DVCDagFlavor.StagePerArtifact:
            dvc_yaml_code = self._write_operator_run_per_artifact()
        else:
            raise ValueError(
                f'"{dag_flavor}" dvc dag flavor is not supported by this writer.'
            )

        # Write out file
        dvc_dag_file = self.output_dir / "dvc.yaml"
        dvc_dag_file.write_text(dvc_yaml_code)
        logger.info(f"Generated DAG file: {dvc_dag_file}")

    def _write_operator_run_all_sessions(self) -> str:
        """
        This hidden method implements DVC DAG code generation corresponding
        to the `SingleStageAllSessions` flavor. This DAG only has one stage and
        calls `run_all_sessions` generated by the module file.

        Returns the rendered DAG text; nothing is written to disk here.
        """

        DAG_TEMPLATE = load_plugin_template(
            "dvc/dvc_dag_SingleStageAllSessions.jinja"
        )

        # The single stage simply runs the generated module as a script.
        return DAG_TEMPLATE.render(
            MODULE_COMMAND=f"python {self.pipeline_name}_module.py",
        )

    def _write_operator_run_per_artifact(self) -> str:
        """
        This hidden method implements DVC DAG code generation corresponding
        to the `StagePerArtifact` flavor.

        Besides returning the rendered DAG text, this method has side effects:
        it writes ``params.yaml`` (via ``_write_params``) and one python file
        per task (via ``_write_python_operator_per_run_artifact``) into
        ``self.output_dir``.
        """

        DAG_TEMPLATE = load_plugin_template(
            "dvc/dvc_dag_StagePerArtifact.jinja"
        )

        # Only the task definitions are needed; the second element returned
        # by get_task_graph is not used by this writer.
        task_defs, _ = get_task_graph(
            self.artifact_collection,
            pipeline_name=self.pipeline_name,
            task_breakdown=DagTaskBreakdown.TaskPerArtifact,
        )

        full_code = DAG_TEMPLATE.render(
            MODULE_NAME=f"{self.pipeline_name}_module", TASK_DEFS=task_defs
        )

        self._write_params()

        self._write_python_operator_per_run_artifact(task_defs)

        return full_code

    def _write_params(self) -> None:
        """
        Render the pipeline input parameters and write them to ``params.yaml``
        in ``self.output_dir``.
        """
        # Get DAG parameters for a DVC pipeline. The zero-argument super()
        # call is kept outside the comprehension, where it is always valid.
        pipeline_args = super().get_pipeline_args()
        input_parameters_dict: Dict[str, Any] = {
            parameter_name: input_spec.value
            for parameter_name, input_spec in pipeline_args.items()
        }

        PARAMS_TEMPLATE = load_plugin_template("dvc/dvc_dag_params.jinja")

        params_code = PARAMS_TEMPLATE.render(
            input_parameters_dict=input_parameters_dict
        )
        filename = "params.yaml"
        params_file = self.output_dir / filename
        params_file.write_text(params_code)
        logger.info(f"Generated DAG file: {params_file}")

    def _write_python_operator_per_run_artifact(
        self, task_defs: Dict[str, TaskDefinition]
    ) -> None:
        """
        This hidden method generates the python cmd files for each DVC stage.

        One ``task_<task_name>.py`` file is written to ``self.output_dir`` for
        every entry of ``task_defs``.
        """
        STAGE_TEMPLATE = load_plugin_template(
            "dvc/dvc_dag_PythonOperator.jinja"
        )

        rendered_task_defs = render_task_definitions(
            task_defs,
            self.pipeline_name,
            task_serialization=TaskSerializer.CWDPickle,
        )

        # use index to keep track of which rendered task should be written
        # since they are returned in the same order as the keys in task_defs
        for index, (task_name, task_def) in enumerate(task_defs.items()):
            stage_code = STAGE_TEMPLATE.render(
                MODULE_NAME=f"{self.pipeline_name}_module",
                TASK_CODE=rendered_task_defs[index],
                task_name=task_name,
                # DVC tasks read each input variable themselves and cannot
                # rely on the DAG to provide them, so the list of variables
                # is handed to the template for the main function body.
                task_parameters=task_def.user_input_variables,
            )

            filename = f"task_{task_name}.py"
            python_operator_file = self.output_dir / filename
            python_operator_file.write_text(prettify(stage_code))
            logger.info(f"Generated DAG file: {python_operator_file}")

Was this helpful?

Help us improve docs with your feedback!