import logging
import os
import sys
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
from lineapy.data.types import LineaID
from lineapy.db.db import RelationalLineaDB
from lineapy.graph_reader.api_utils import de_lineate_code
from lineapy.graph_reader.program_slice import (
    CodeSlice,
    get_program_slice_by_artifact_name,
)
from lineapy.plugins.task import TaskGraph, TaskGraphEdge
from lineapy.plugins.utils import (
    get_lib_version_text,
    load_plugin_template,
    safe_var_name,
)
from lineapy.utils.logging_config import configure_logging
from lineapy.utils.utils import prettify

logger = logging.getLogger(__name__)
configure_logging()


@dataclass
class BasePlugin:
    db: RelationalLineaDB
    session_id: LineaID

    def prepare_output_dir(self, copy_dst: str):
        """
        Create the output directory (and any missing parents) if it does
        not already exist.
        """
        if not os.path.exists(copy_dst):
            os.makedirs(copy_dst)
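
    # Usage sketch (the values here are hypothetical): os.makedirs is only
    # invoked when the directory is absent, so repeated calls are safe.
    #
    #   plugin = BasePlugin(db=db, session_id=session_id)
    #   plugin.prepare_output_dir("airflow_output")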

    def generate_python_module(
        self,
        module_name: str,
        artifacts_code: Dict[str, CodeSlice],
        output_dir_path: Path,
    ):
        """
        Generate Python module code and save it to a file.
        """
        full_import_block = ""
        full_code_block = ""
        for artifact_name, sliced_code in artifacts_code.items():
            # Wrap each artifact's sliced body in a function named after it.
            _import_block = "\n".join(sliced_code.import_lines)
            _code_block = f"def {artifact_name}():\n\t" + "\n\t".join(
                sliced_code.body_lines
            )
            full_import_block += "\n" + _import_block
            full_code_block += "\n" + _code_block
        full_code = prettify(
            de_lineate_code(full_import_block + full_code_block, self.db)
        )
        (output_dir_path / f"{module_name}.py").write_text(full_code)
        logger.info(f"Generated python module {module_name}.py")

    def get_working_dir_as_str(self):
        # Resolve the working directory recorded for this session to an
        # absolute path.
        working_directory = Path(
            self.db.get_session_context(self.session_id).working_directory
        )
        return str(working_directory.resolve())

    def generate_infra(
        self,
        module_name: str,
        output_dir_path: Path,
    ):
        """
        Generates templates to test the Airflow module. Currently, we
        produce a <module_name>_Dockerfile and a <module_name>_requirements.txt
        file. These can be used to test the DAG that gets generated by
        linea. For more details, see :ref:`Testing locally <testingairflow>`.
        """
        DOCKERFILE_TEMPLATE = load_plugin_template("dockerfile.jinja")
        dockerfile = DOCKERFILE_TEMPLATE.render(module_name=module_name)
        (output_dir_path / (module_name + "_Dockerfile")).write_text(
            dockerfile
        )
        logger.info(f"Generated Dockerfile {module_name}_Dockerfile")

        all_libs = self.db.get_libraries_for_session(self.session_id)
        lib_names_text = ""
        for lib in all_libs:
            # Only pin libraries that are actually loaded in this process;
            # sys.modules reflects the modules imported so far.
            if lib.name in sys.modules:
                text = get_lib_version_text(str(lib.package_name))
                lib_names_text += f"{text}\n"
        (output_dir_path / (module_name + "_requirements.txt")).write_text(
            lib_names_text
        )
        logger.info(
            f"Generated requirements file {module_name}_requirements.txt"
        )
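
    # Sketch of the expected requirements output, assuming
    # get_lib_version_text yields pinned "name==version" lines (the
    # package names and versions below are illustrative only):
    #
    #   pandas==1.3.5
    #   scikit-learn==1.0.2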

    def slice_dag_helper(
        self,
        slice_names: List[str],
        module_name: Optional[str] = None,
        task_dependencies: TaskGraphEdge = {},
        output_dir: Optional[str] = None,
    ) -> Tuple[str, List[str], Path, TaskGraph]:
        """
        A generic helper shared by the Script and Airflow plugins to create
        a DAG from the sliced code. It generates a Python file with one
        function per slice, a task-dependencies file in Airflow format, and
        an example Dockerfile and requirements.txt that can be used to run
        the pipeline.

        :param slice_names: list of slice names to be used as tasks.
        :param module_name: name of the Python module the generated code
            will be saved to.
        :param task_dependencies: task dependencies in graphlib format,
            e.g. {'B': {'A', 'C'}} means tasks A and C are prerequisites
            for task B.
        :param output_dir: directory to save the generated code to.
        """
        artifacts_code = {}
        artifact_safe_names = []
        for slice_name in slice_names:
            artifact_var = safe_var_name(slice_name)
            slice_code: CodeSlice = get_program_slice_by_artifact_name(
                self.db, slice_name, keep_lineapy_save=True
            )
            artifacts_code[artifact_var] = slice_code
            artifact_safe_names.append(artifact_var)

        task_graph = TaskGraph(
            slice_names,
            # Map each slice name to its sanitized task/function name.
            dict(zip(slice_names, artifact_safe_names)),
            task_dependencies,
        )
        module_name = module_name or "_".join(artifact_safe_names)

        # Default to the current working directory unless an output
        # directory was requested, in which case create it if needed.
        output_dir_path = Path.cwd()
        if output_dir:
            output_dir_path = Path(os.path.expanduser(output_dir))
            self.prepare_output_dir(
                copy_dst=str(output_dir_path.resolve()),
            )
        logger.info(
            "Pipeline source generated in the directory: %s", output_dir_path
        )
        self.generate_python_module(
            module_name, artifacts_code, output_dir_path
        )
        self.generate_infra(
            module_name=module_name, output_dir_path=output_dir_path
        )
        return module_name, artifact_safe_names, output_dir_path, task_graph
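
# Minimal end-to-end usage sketch (all names are hypothetical; assumes two
# artifacts, "cleaned data" and "model", were saved in the session, and
# that safe_var_name sanitizes them to cleaned_data and model):
#
#   plugin = BasePlugin(db=db, session_id=session_id)
#   module_name, tasks, out_dir, task_graph = plugin.slice_dag_helper(
#       slice_names=["cleaned data", "model"],
#       task_dependencies={"model": {"cleaned data"}},
#       output_dir="~/airflow/dags",
#   )
#
# With no explicit module_name, this writes cleaned_data_model.py,
# cleaned_data_model_Dockerfile, and cleaned_data_model_requirements.txt
# into ~/airflow/dags.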