
mlflow_io

Serializers and deserializers for MLflow-supported model flavors. Only annotated modules are included at this time:

  • sklearn
  • prophet
  • tensorflow (current annotation needs updating)
  • keras (current annotation needs updating)
  • h2o
  • gluon
  • xgboost
  • lightgbm
  • catboost
  • spacy
  • fastai
  • statsmodels

mlflow_io = {} module-attribute

This dictionary holds information about how to serialize and deserialize each MLflow-supported model flavor. Each key is a module name, i.e., an MLflow-supported flavor such as sklearn. Each value is a dictionary with the following three keys:

  1. class: list of object classes within the module that can be logged/saved in MLflow
  2. serializer: the MLflow method used to log the model flavor
  3. deserializer: the MLflow method used to retrieve the model
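To make the shape of each entry concrete, here is a minimal sketch of a registry with the same three keys. The flavor name `toyflavor`, the `ToyModel` class, the toy serializer/deserializer, and the in-memory `_store` are all hypothetical stand-ins so the example runs without MLflow; a real entry would reference MLflow's own flavor functions instead (for sklearn, presumably `mlflow.sklearn.log_model` and `mlflow.sklearn.load_model`).

```python
from typing import Any, Callable, Dict, List, TypedDict

# Functional TypedDict syntax is required because "class" is a Python keyword.
FlavorIO = TypedDict(
    "FlavorIO",
    {
        "class": List[type],  # model classes this flavor can log/save
        "serializer": Callable[..., Any],  # logs/saves a model
        "deserializer": Callable[[str], Any],  # loads a model from a URI
    },
)


class ToyModel:
    """Hypothetical model class standing in for, e.g., an sklearn estimator."""

    def __init__(self, coef: float) -> None:
        self.coef = coef


# Hypothetical in-memory store standing in for the MLflow artifact store.
_store: Dict[str, Any] = {}


def toy_log_model(model: Any, artifact_path: str, registered_model_name: str) -> str:
    """Hypothetical serializer: store the model and return a model URI."""
    model_uri = f"models:/{registered_model_name}/1"
    _store[model_uri] = model
    return model_uri


def toy_load_model(model_uri: str) -> Any:
    """Hypothetical deserializer: fetch the model stored under model_uri."""
    return _store[model_uri]


# Keyed by module name, mirroring the layout described above.
mlflow_io_sketch: Dict[str, FlavorIO] = {
    "toyflavor": {
        "class": [ToyModel],
        "serializer": toy_log_model,
        "deserializer": toy_load_model,
    },
}

# Round trip through one registry entry.
entry = mlflow_io_sketch["toyflavor"]
uri = entry["serializer"](ToyModel(2.0), artifact_path="toy", registered_model_name="toy")
restored = entry["deserializer"](uri)
print(uri, restored.coef)  # models:/toy/1 2.0
```

Keeping the serializer and deserializer as plain callables in a dictionary means supporting a new flavor only requires adding one entry, with no changes to the dispatch code.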

read_mlflow(mlflow_metadata)

Read a model from the MLflow artifact store.

To handle the case where the saved tracking_uri or registry_uri differs from the lineapy configs (i.e., multiple MLflow backends), we record the current lineapy-related MLflow configs for tracking/registry, set tracking/registry based on the metadata to load the model, and then reset back to the original configs.

Source code in lineapy/plugins/serializers/mlflow_io.py
def read_mlflow(mlflow_metadata: MLflowArtifactInfo) -> Any:
    """
    Read a model from the MLflow artifact store.

    To handle the case where the saved tracking_uri or registry_uri differs
    from the lineapy configs (i.e., multiple MLflow backends), we record the
    current lineapy-related MLflow configs for tracking/registry, set
    tracking/registry based on the metadata to load the model, and then
    reset back to the original configs.
    """

    # Remember lineapy's configured MLflow backends so they can be restored
    current_mlflow_tracking_uri = options.get("mlflow_tracking_uri")
    current_mlflow_registry_uri = options.get("mlflow_registry_uri")
    # Point MLflow at the backends recorded when the model was saved
    mlflow.set_tracking_uri(mlflow_metadata.tracking_uri)
    mlflow.set_registry_uri(mlflow_metadata.registry_uri)

    # Look up the flavor's deserializer in mlflow_io and load the model
    assert isinstance(mlflow_metadata.model_flavor, str)
    value = mlflow_io[mlflow_metadata.model_flavor]["deserializer"](
        mlflow_metadata.model_uri
    )

    # Restore lineapy's original MLflow configs
    mlflow.set_tracking_uri(current_mlflow_tracking_uri)
    mlflow.set_registry_uri(current_mlflow_registry_uri)
    return value
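`read_mlflow` switches MLflow's tracking and registry URIs to the ones stored in the metadata, loads the model, and then switches them back. In the source above the URIs are restored only after the deserializer returns, so an exception during loading would leave the metadata's URIs in place. A minimal sketch of the same record/switch/restore pattern as a context manager with `try`/`finally` follows; `temporary_setting` is a hypothetical helper, and the getter/setter are passed in (backed here by a hypothetical `_config` dict) so the example runs without MLflow. With MLflow installed, `mlflow.get_tracking_uri` and `mlflow.set_tracking_uri` could fill those roles, although `read_mlflow` itself takes the values to restore from lineapy's `options` rather than from MLflow.

```python
from contextlib import contextmanager
from typing import Callable, Dict, Iterator, Optional


@contextmanager
def temporary_setting(
    get_current: Callable[[], Optional[str]],
    set_value: Callable[[Optional[str]], None],
    new_value: Optional[str],
) -> Iterator[None]:
    """Switch a global setting for the duration of a with-block, then restore it."""
    previous = get_current()  # record the original value
    set_value(new_value)
    try:
        yield
    finally:
        # Runs both on normal exit and when the block raises.
        set_value(previous)


# Hypothetical global state standing in for MLflow's tracking URI.
_config: Dict[str, Optional[str]] = {"tracking_uri": "sqlite:///local.db"}


def get_tracking_uri() -> Optional[str]:
    return _config["tracking_uri"]


def set_tracking_uri(uri: Optional[str]) -> None:
    _config["tracking_uri"] = uri


with temporary_setting(get_tracking_uri, set_tracking_uri, "http://remote:5000"):
    print(get_tracking_uri())  # http://remote:5000
print(get_tracking_uri())  # sqlite:///local.db
```

For two settings (tracking and registry), two such context managers can be stacked in one `with` statement.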

try_write_to_mlflow(value, name, **kwargs)

Try to save the artifact with MLflow. Return an MLflow ModelInfo on success, or None if the value's type is not supported. Raises ModuleNotFoundError if mlflow is not installed.

Parameters:

  • value (Any, required): value (ML model) to save with MLflow
  • name (str, required): artifact_path and registered_model_name used in mlflow.sklearn.log_model or equivalent flavors
  • **kwargs (default: {}): arguments to pass into mlflow.sklearn.log_model or equivalent flavors

Returns:

  • Optional[Any]: a ModelInfo (MLflow model metadata) if the model is successfully saved with MLflow; otherwise None. Any is used instead of ModelInfo so that loading lineapy does not fail when mlflow is not installed.
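The return type is annotated as `Optional[Any]` because importing `ModelInfo` at module level would fail when mlflow is not installed. A common alternative, sketched here and not what lineapy does, is to import `ModelInfo` only under `typing.TYPE_CHECKING` and use a string annotation. `save_model_sketch` is a hypothetical function with a placeholder body, and the import path `mlflow.models.model.ModelInfo` is an assumption; it is never executed at runtime.

```python
from typing import TYPE_CHECKING, Any, Optional

if TYPE_CHECKING:
    # Evaluated only by static type checkers such as mypy, never at runtime,
    # so this module imports fine without mlflow installed.
    # (Import path assumed: mlflow.models.model.ModelInfo.)
    from mlflow.models.model import ModelInfo


def save_model_sketch(value: Any, name: str) -> "Optional[ModelInfo]":
    """Hypothetical stand-in for try_write_to_mlflow with a precise return type."""
    # Placeholder body: a real implementation would call an MLflow
    # serializer here and return the ModelInfo it produces.
    return None


print(save_model_sketch(object(), "model"))  # None
```

The trade-off is that anything resolving annotations at runtime (for example `typing.get_type_hints`) cannot resolve `ModelInfo` when mlflow is absent, whereas `Any` never has that problem.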

Source code in lineapy/plugins/serializers/mlflow_io.py
def try_write_to_mlflow(value: Any, name: str, **kwargs) -> Optional[Any]:
    """
    Try to save the artifact with MLflow. Return an MLflow ModelInfo on
    success, or None if the value's type is not supported.

    Parameters
    ----------
    value: Any
        value (ML model) to save with MLflow
    name: str
        artifact_path and registered_model_name used in
        `mlflow.sklearn.log_model` or equivalent flavors
    **kwargs:
        arguments to pass into `mlflow.sklearn.log_model` or equivalent
        flavors

    Returns
    -------
    Optional[Any]
        a ModelInfo (MLflow model metadata) if the model is successfully
        saved with MLflow; otherwise None. Any is used instead of ModelInfo
        so that loading lineapy does not fail when mlflow is not installed.

    Raises
    ------
    ModuleNotFoundError
        if mlflow is not installed
    """

    logger.info("Trying to save the object to MLflow.")

    # Check that mlflow is installed; if not, raise an error
    if "mlflow" not in sys.modules:
        msg = (
            "module 'mlflow' is not installed;"
            + " please install it with 'pip install lineapy[mlflow]'"
        )
        raise ModuleNotFoundError(msg)
    mlflow.set_tracking_uri(options.get("mlflow_tracking_uri"))

    # Check that value comes from a module supported by MLflow
    full_module_name = getmodule(value)
    if full_module_name is not None:
        root_module_name = full_module_name.__name__.split(".")[0]
        if root_module_name in mlflow_io:
            flavor_io = mlflow_io[root_module_name]
            # Check that value is an instance of a class supported by the flavor
            if any(
                isinstance(value, target_class)
                for target_class in flavor_io["class"]
            ):
                # Default both names to `name` unless the caller overrides them
                kwargs.setdefault("registered_model_name", name)
                kwargs.setdefault("artifact_path", name)
                # This is where the model is saved to MLflow
                model_info = flavor_io["serializer"](value, **kwargs)
                return model_info

    logger.info(
        f"LineaPy is currently not supporting saving {type(value)} to MLflow."
    )
    return None
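The lookup in `try_write_to_mlflow` has two gates: the root package name of the value's module must be a key of `mlflow_io`, and the value must be an instance of one of that flavor's classes. A minimal, MLflow-free sketch of that dispatch follows; `try_write_sketch` and `demo_registry` are hypothetical, and `json.JSONDecoder` stands in for a model class only because it is defined in a dotted stdlib module (`json.decoder`), which exercises the root-name split.

```python
import json
from inspect import getmodule
from typing import Any, Dict, Optional


def try_write_sketch(
    value: Any, name: str, registry: Dict[str, Dict[str, Any]], **kwargs: Any
) -> Optional[Any]:
    """Dispatch value to a registry serializer, mirroring try_write_to_mlflow."""
    module = getmodule(value)  # module object, or None if it cannot be determined
    if module is None:
        return None
    # "json.decoder" -> "json", just as "sklearn.linear_model" -> "sklearn"
    root_module_name = module.__name__.split(".")[0]
    flavor_io = registry.get(root_module_name)
    if flavor_io is None:
        return None
    # isinstance accepts a tuple of classes, equivalent to any(...) over the list
    if not isinstance(value, tuple(flavor_io["class"])):
        return None
    # Both names default to `name` unless the caller passes them explicitly
    kwargs.setdefault("registered_model_name", name)
    kwargs.setdefault("artifact_path", name)
    return flavor_io["serializer"](value, **kwargs)


# Hypothetical registry whose "serializer" just echoes the kwargs it receives.
demo_registry: Dict[str, Dict[str, Any]] = {
    "json": {
        "class": [json.JSONDecoder],
        "serializer": lambda value, **kw: kw,
        "deserializer": lambda uri: uri,
    }
}

print(try_write_sketch(json.JSONDecoder(), "my_model", demo_registry))
# {'registered_model_name': 'my_model', 'artifact_path': 'my_model'}
print(try_write_sketch(json.JSONEncoder(), "my_model", demo_registry))  # None
```

`json.JSONEncoder` passes the module check (its root package is also `json`) but fails the class check, which is why the second call returns None.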
