Model Load

Restores a machine learning model and its metadata from a ZIP archive or directory created by the Save Model brick.

Model Load

Processing

This brick restores a machine learning model and its associated settings from a stored file or directory. It is designed to work in tandem with the Save Model brick.

When executed, the brick reads the model file and its configuration data (model_metadata.json), either directly from a directory or from a ZIP archive that it first extracts to a temporary location. It then reconstructs the model object in memory, supporting models from libraries such as scikit-learn, XGBoost, and LightGBM. Once loaded, the model is ready to be passed to other bricks for making predictions or further analysis.

Inputs

path
The location of the saved model. This can be a path to a .zip archive or to a directory containing the model files. The location must contain a valid model_metadata.json file generated during the saving process.

Inputs Types

Input Types
path Str, FilePath, DirectoryPath

You can check the list of supported types here: Available Type Hints.

Outputs

Model
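The fully reconstructed machine learning model object. This object contains all the learned patterns and rules from training and is ready to be used by a prediction brick. For LightGBM models saved in the native text format, this is a raw LightGBM Booster rather than a scikit-learn-style estimator.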

Outputs Types

Output Types
Model Any

You can check the list of supported types here: Available Type Hints.

Options

The Model Load brick provides the following configurable option:

Verbose
When enabled (the default), progress messages and errors are logged during loading; when disabled, the brick loads the model without logging.
import json
import joblib
import zipfile
import logging
import pathlib
import tempfile
import importlib
import xgboost as xgb
from coded_flows.types import Union, Str, Any, FilePath, DirectoryPath
from coded_flows.utils import CodedFlowsLogger

# Logger for this brick; load_model only emits messages through it when the
# "verbose" option is truthy.
logger = CodedFlowsLogger(name="Model Load", level=logging.INFO)


def _get_class_by_name(module_path: str, class_name: str) -> Any:
    """Dynamically imports a class based on its module and name."""
    try:
        module = importlib.import_module(module_path)
        return getattr(module, class_name)
    except (ImportError, AttributeError):
        return None


def _load_xgboost_native(model_path: str, class_name: str, module_path: str) -> Any:
    """Load an XGBoost native (JSON) model file into a scikit-learn wrapper.

    The wrapper class is chosen in this order:

    1. ``class_name`` is one of XGBoost's scikit-learn wrappers (or a known
       alias of one): use that wrapper.
    2. ``module_path`` mentions ``sklearn`` and ``class_name`` can be imported
       from it: use the imported class.
    3. Otherwise fall back to ``xgb.XGBRegressor``.

    Args:
        model_path: Path of the saved booster file.
        class_name: Class name recorded in the model metadata (may be ``None``).
        module_path: Module recorded in the metadata (may be empty or ``None``).

    Returns:
        A freshly constructed wrapper with the booster loaded via ``load_model``.
    """
    sklearn_wrappers = {
        "XGBClassifier": xgb.XGBClassifier,
        "XGBRegressor": xgb.XGBRegressor,
        "XGBRanker": xgb.XGBRanker,
        "XGBRFClassifier": xgb.XGBRFClassifier,
        "XGBRFRegressor": xgb.XGBRFRegressor,
        # Presumably a project-specific XGBClassifier subclass that is not
        # importable here; restore it as a gradient-boosted XGBClassifier.
        # (It was previously mapped to the random-forest XGBRFClassifier, which
        # reports RF hyper-parameters and would refit as a random forest.)
        "_AutoBalanceXGBClassifier": xgb.XGBClassifier,
    }
    wrapper_cls = sklearn_wrappers.get(class_name)
    # Guard against a missing/None module_path before the substring test.
    if wrapper_cls is None and module_path and "sklearn" in module_path:
        wrapper_cls = _get_class_by_name(module_path, class_name)
    if wrapper_cls is None:
        # NOTE(review): unknown classes silently become a regressor; a saved
        # classifier of an unrecognized class would predict raw scores.
        wrapper_cls = xgb.XGBRegressor
    instance = wrapper_cls()
    instance.load_model(model_path)
    return instance


def _load_lightgbm_native(model_path: str) -> Any:
    """Rebuild a raw LightGBM ``Booster`` from a native text model file.

    LightGBM estimators that use the scikit-learn API are serialized with
    joblib instead, so only plain Boosters are restored through this helper.
    """
    from lightgbm import Booster

    booster = Booster(model_file=model_path)
    return booster


def load_model(path: Union[Str, FilePath, DirectoryPath], options: dict = None) -> Any:
    """
    Loads a model from a ZIP file or directory containing artifacts and metadata.

    The location must contain ``model_metadata.json``. The keys read from it
    are ``class_name``, ``library``, ``serialization_format`` (one of
    ``"joblib"``, ``"xgboost_json"``, ``"lightgbm_text"``), ``artifacts.model``
    (model file name relative to the archive/directory root) and, optionally,
    ``module_path``.

    Args:
        path: A ``.zip`` archive or a directory produced by the Save Model brick.
        options: Brick options; ``"verbose"`` (default ``True``) enables logging.

    Returns:
        The reconstructed model object.

    Raises:
        ValueError: ``path`` is empty; it is neither a directory nor a ``.zip``
            file; the metadata is not a JSON object, names no model artifact,
            or names one outside the model location; or the serialization
            format is unknown.
        FileNotFoundError: ``path``, the metadata file or the model artifact
            does not exist.
    """
    options = options or {}
    verbose = options.get("verbose", True)
    if not path:
        raise ValueError("Path to model (ZIP or Directory) is required.")
    input_path = pathlib.Path(path)
    # NOTE(review): the variable keeps the brick's output name "Model";
    # presumably the flow framework relies on it, so it is not renamed.
    Model = None
    if not input_path.exists():
        if verbose:
            logger.error(f"Path does not exist: {input_path}")
        raise FileNotFoundError(f"Model path not found: {input_path}")
    if verbose:
        logger.info(f"Attempting to load model from: {input_path}")
    # The temporary directory is only populated for ZIP input; it stays alive
    # until the artifact has been loaded.
    with tempfile.TemporaryDirectory() as temp_dir_str:
        work_dir = pathlib.Path(temp_dir_str)
        try:
            if input_path.is_file() and input_path.suffix.lower() == ".zip":
                if verbose:
                    logger.info("Detected ZIP archive. Extracting...")
                with zipfile.ZipFile(input_path, "r") as zf:
                    zf.extractall(work_dir)
            elif input_path.is_dir():
                work_dir = input_path
            else:
                raise ValueError(
                    f"Unsupported file type: {input_path}. Must be a directory or .zip file."
                )
            metadata_path = work_dir / "model_metadata.json"
            if not metadata_path.exists():
                raise FileNotFoundError(
                    "model_metadata.json not found. Invalid model archive."
                )
            with open(metadata_path, "r", encoding="utf-8") as f:
                metadata = json.load(f)
            if not isinstance(metadata, dict):
                raise ValueError("model_metadata.json must contain a JSON object.")
            class_name = metadata.get("class_name")
            library = metadata.get("library")
            fmt = metadata.get("serialization_format")
            # "artifacts" and "module_path" may be missing or explicitly null.
            model_filename = (metadata.get("artifacts") or {}).get("model")
            module_path = metadata.get("module_path") or ""
            if not module_path:
                if library == "xgboost":
                    module_path = "xgboost.sklearn"
                elif library == "sklearn":
                    module_path = "sklearn"
            if not isinstance(model_filename, str) or not model_filename:
                raise ValueError(
                    "model_metadata.json does not specify a model artifact "
                    "('artifacts.model')."
                )
            # Reject rooted paths and '..' components so the metadata cannot
            # point the loader at files outside the model location.
            artifact_rel = pathlib.PurePath(model_filename)
            if artifact_rel.anchor or ".." in artifact_rel.parts:
                raise ValueError(
                    f"Model artifact '{model_filename}' must be a relative path "
                    "inside the model archive."
                )
            model_file_path = work_dir / artifact_rel
            if not model_file_path.exists():
                raise FileNotFoundError(
                    f"Model artifact '{model_filename}' missing from archive."
                )
            if verbose:
                logger.info(
                    f"Metadata loaded. Library: {library}, Class: {class_name}, Format: {fmt}"
                )
            if fmt == "joblib":
                # NOTE(review): joblib.load unpickles arbitrary objects and can
                # execute code -- only load model archives from trusted sources.
                Model = joblib.load(model_file_path)
            elif fmt == "xgboost_json":
                Model = _load_xgboost_native(
                    str(model_file_path), class_name, module_path
                )
            elif fmt == "lightgbm_text":
                Model = _load_lightgbm_native(str(model_file_path))
            else:
                raise ValueError(f"Unknown serialization format: {fmt}")
            if verbose:
                logger.info("Model loaded successfully.")
        except Exception as e:
            if verbose:
                logger.error(f"Failed to load model: {e}")
            # Bare raise keeps the original traceback intact.
            raise
    return Model

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • joblib
  • shap>=0.47.0
  • xgboost
  • lightgbm
  • scikit-learn
  • numba>=0.56.0