Bases: BasePipelineWriter
Pipeline file writer for the "DVC" framework.
Source code in lineapy/plugins/dvc_pipeline_writer.py
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
class DVCPipelineWriter(BasePipelineWriter):
    """
    Class for pipeline file writer. Corresponds to "DVC" framework.

    Writes ``dvc.yaml`` into ``self.output_dir``. For the ``StagePerArtifact``
    flavor it additionally writes ``params.yaml`` and one ``task_<name>.py``
    script per task definition.
    """

    @property
    def docker_template_name(self) -> str:
        """Name of the Jinja template used to render the DVC Dockerfile."""
        return "dvc/dvc_dockerfile.jinja"

    def _write_dag(self) -> None:
        """
        Render ``dvc.yaml`` for the configured DAG flavor and write it to
        ``self.output_dir``.

        The flavor name is read from ``self.dag_config["dag_flavor"]`` and
        defaults to ``"StagePerArtifact"``.

        Raises:
            ValueError: if the flavor is not a member name of ``DVCDagFlavor``.
            NotImplementedError: if the flavor is a valid ``DVCDagFlavor``
                member that this writer has no generator for.
        """
        dag_flavor = self.dag_config.get("dag_flavor", "StagePerArtifact")

        # Check if the given DAG flavor is a supported/valid one
        if dag_flavor not in DVCDagFlavor.__members__:
            raise ValueError(f'"{dag_flavor}" is an invalid dvc dag flavor.')
        flavor = DVCDagFlavor[dag_flavor]

        # Construct DAG text for the given flavor. The branches are mutually
        # exclusive; the final else guards against enum members that have no
        # generator here, which would otherwise leave dvc_yaml_code unbound.
        if flavor == DVCDagFlavor.SingleStageAllSessions:
            dvc_yaml_code = self._write_operator_run_all_sessions()
        elif flavor == DVCDagFlavor.StagePerArtifact:
            dvc_yaml_code = self._write_operator_run_per_artifact()
        else:
            raise NotImplementedError(
                f'"{dag_flavor}" dvc dag flavor is not supported by this writer.'
            )

        # Write out file
        dvc_dag_file = self.output_dir / "dvc.yaml"
        dvc_dag_file.write_text(dvc_yaml_code)
        logger.info(f"Generated DAG file: {dvc_dag_file}")

    def _write_operator_run_all_sessions(self) -> str:
        """
        This hidden method implements DVC DAG code generation corresponding
        to the `SingleStageAllSessions` flavor. This DAG only has one stage and
        calls `run_all_sessions` generated by the module file.

        Returns:
            The rendered ``dvc.yaml`` text. Nothing is written to disk here.
        """
        DAG_TEMPLATE = load_plugin_template(
            "dvc/dvc_dag_SingleStageAllSessions.jinja"
        )
        # The single stage simply runs the generated module script.
        return DAG_TEMPLATE.render(
            MODULE_COMMAND=f"python {self.pipeline_name}_module.py",
        )

    def _write_operator_run_per_artifact(self) -> str:
        """
        This hidden method implements DVC DAG code generation corresponding
        to the `StagePerArtifact` flavor.

        As side effects it writes ``params.yaml`` and one ``task_<name>.py``
        file per task into ``self.output_dir``.

        Returns:
            The rendered ``dvc.yaml`` text.
        """
        DAG_TEMPLATE = load_plugin_template(
            "dvc/dvc_dag_StagePerArtifact.jinja"
        )
        # Only the task definitions are needed; the second element returned
        # by get_task_graph is not used by this flavor.
        task_defs, _ = get_task_graph(
            self.artifact_collection,
            pipeline_name=self.pipeline_name,
            task_breakdown=DagTaskBreakdown.TaskPerArtifact,
        )
        full_code = DAG_TEMPLATE.render(
            MODULE_NAME=f"{self.pipeline_name}_module", TASK_DEFS=task_defs
        )
        self._write_params()
        self._write_python_operator_per_run_artifact(task_defs)
        return full_code

    def _write_params(self) -> None:
        """
        Render the pipeline's input parameters into ``params.yaml`` in
        ``self.output_dir``.

        Each parameter name from ``get_pipeline_args()`` is mapped to the
        ``value`` attribute of its input spec.
        """
        # Get DAG parameters for a DVC pipeline. Hoisting the call keeps the
        # zero-argument super() in the method's own scope.
        pipeline_args = super().get_pipeline_args()
        input_parameters_dict: Dict[str, Any] = {
            parameter_name: input_spec.value
            for parameter_name, input_spec in pipeline_args.items()
        }

        PARAMS_TEMPLATE = load_plugin_template("dvc/dvc_dag_params.jinja")
        params_code = PARAMS_TEMPLATE.render(
            input_parameters_dict=input_parameters_dict
        )
        filename = "params.yaml"
        params_file = self.output_dir / filename
        params_file.write_text(params_code)
        logger.info(f"Generated DAG file: {params_file}")

    def _write_python_operator_per_run_artifact(
        self, task_defs: Dict[str, TaskDefinition]
    ) -> None:
        """
        This hidden method generates the python cmd files for each DVC stage.

        Writes one prettified ``task_<task_name>.py`` file per entry of
        ``task_defs`` into ``self.output_dir``.

        Args:
            task_defs: mapping from task name to its ``TaskDefinition``.
        """
        STAGE_TEMPLATE = load_plugin_template(
            "dvc/dvc_dag_PythonOperator.jinja"
        )
        rendered_task_defs = render_task_definitions(
            task_defs,
            self.pipeline_name,
            task_serialization=TaskSerializer.CWDPickle,
        )

        # Use the index to pick the matching rendered task, since rendered
        # tasks are returned in the same order as the keys in task_defs.
        for index, (task_name, task_def) in enumerate(task_defs.items()):
            stage_code = STAGE_TEMPLATE.render(
                MODULE_NAME=f"{self.pipeline_name}_module",
                TASK_CODE=rendered_task_defs[index],
                task_name=task_name,
                # DVC tasks read each input variable and cannot rely on the
                # DAG to provide them; pass the list for the main function body.
                task_parameters=task_def.user_input_variables,
            )
            filename = f"task_{task_name}.py"
            python_operator_file = self.output_dir / filename
            python_operator_file.write_text(prettify(stage_code))
            logger.info(f"Generated DAG file: {python_operator_file}")
|