Model Save

Saves a trained model and its metadata as a portable ZIP archive (or a plain folder) so it can be reloaded easily.

Model Save

Processing

This brick packages your trained machine learning model into a portable format (a ZIP archive or a directory) for safe storage and easy deployment. It automatically detects the library used to create the model (such as Scikit-Learn, XGBoost, CatBoost, or LightGBM) and handles the specific saving requirements for that library.

In addition to the model file, this brick generates a model_metadata.json file. This metadata records the creation timestamp, the library and its version, the model class, the serialization format, and the Python version. With it, the model can be reliably loaded and used in other workflows or environments.

Inputs

model
The trained machine learning model object you want to save. This is typically the output from a training brick (e.g., an XGBoost or Scikit-Learn classifier).
directory
The destination folder path where the model archive or directory will be created.

Inputs Types

Input Types
model Any
directory Str, DirectoryPath

You can check the list of supported types here: Available Type Hints.

Outputs

Model path
The full path to the created artifact: a .zip file or a folder, depending on the Archive Format option. It is returned as a string, or as a pathlib.Path object when Return as Path Object is enabled.

Outputs Types

Output Types
Model path Str, FilePath

You can check the list of supported types here: Available Type Hints.

Options

The Model Save brick contains some changeable options:

Custom Prefix
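A text label placed at the start of the name to help identify the model (e.g., "customer_churn" or "sales_forecast"). It is followed by the model's class name in snake_case (e.g., customer_churn_xgb_classifier). Defaults to "model".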
Include Date (YYYYMMDD)
If enabled, appends the current date to the filename. Useful for organizing models chronologically.
Include Time (HHMMSS)
If enabled, appends the current time to the filename. Useful if you generate multiple models on the same day.
Auto-Increment Version
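If enabled, adds a version suffix (_v1, _v2, …) to the name. The number is one higher than the highest version already saved under the same name in the destination folder, so previous work is never overwritten. If disabled, an existing archive or folder with the same name is replaced.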
Archive Format
Determines how the model and its metadata are packaged. Choose ZIP to compress both into a single .zip file, which is best for downloading or moving between systems. Alternatively, save them as loose files inside a standard folder, which is best when you need immediate access to the internal files without unzipping.
Return as Path Object
If enabled, returns the output as a Python pathlib.Path object instead of a standard string. Keep this disabled unless your workflow specifically requires Path objects.
Verbose
If enabled, logs detailed information about the saving process, such as the detected model type and the final save location.
import datetime
import importlib.metadata
import json
import logging
import pathlib
import re
import shutil
import sys
import tempfile
import zipfile

import joblib

from coded_flows.types import Union, Str, Any, FilePath, DirectoryPath
from coded_flows.utils import CodedFlowsLogger

# Module-level logger used by save_model for progress and error messages;
# every call site in this file is gated on the `verbose` option.
logger = CodedFlowsLogger(name="Model Save", level=logging.INFO)


def _to_snake_case(name: str) -> str:
    """Converts CamelCase to snake_case for standardized naming."""
    name = re.sub("(.)([A-Z][a-z]+)", "\\1_\\2", name)
    return re.sub("([a-z0-9])([A-Z])", "\\1_\\2", name).lower()


def _get_library_version(library_name: str) -> str:
    """Retrieves the installed version of the library safely."""
    try:
        return importlib.metadata.version(library_name)
    except importlib.metadata.PackageNotFoundError:
        return "unknown"


def _detect_model_context(model: Any) -> dict:
    """
    Analyzes the model object to extract context for metadata and serialization.

    Libraries are matched on the top-level package of every class in the
    model's MRO. A user subclass of a library estimator, defined in the
    user's own module, is therefore detected the same way as the library
    class itself. Checks run in priority order: XGBoost, CatBoost, LightGBM,
    then scikit-learn (or any object exposing both ``fit`` and ``predict``).

    Args:
        model: The trained model instance.

    Returns:
        dict with keys ``class_name``, ``module_path``, ``base_name``
        (snake_case class name), ``library``, ``extension`` (file suffix,
        including the dot), ``format`` (serialization format id) and
        ``version`` (installed library version, or ``"unknown"``).

    Raises:
        ValueError: If the model type is not recognized.
    """
    model_type = type(model)
    module_path = model_type.__module__
    class_name = model_type.__name__
    snake_name = _to_snake_case(class_name)
    # Top-level package of each class in the MRO, e.g. "xgboost.sklearn" ->
    # "xgboost". Comparing whole package names (not substrings) avoids false
    # matches on unrelated modules such as "my_xgboost_utils".
    packages = {
        (base.__module__ or "").split(".")[0] for base in model_type.__mro__
    }
    context = {
        "class_name": class_name,
        "module_path": module_path,
        "base_name": snake_name,
    }
    if "xgboost" in packages:
        context.update(
            {
                "library": "xgboost",
                "extension": ".json",
                "format": "xgboost_json",
                "version": _get_library_version("xgboost"),
            }
        )
        return context
    if "catboost" in packages:
        context.update(
            {
                "library": "catboost",
                "extension": ".cbm",
                "format": "catboost_binary",
                "version": _get_library_version("catboost"),
            }
        )
        return context
    if "lightgbm" in packages:
        # Estimator-style wrappers (anything exposing ``fit``) are pickled with
        # joblib; other LightGBM objects use the native text format.
        is_sklearn = "sklearn" in module_path or hasattr(model, "fit")
        context.update(
            {
                "library": "lightgbm",
                "extension": ".joblib" if is_sklearn else ".txt",
                "format": "joblib" if is_sklearn else "lightgbm_text",
                "version": _get_library_version("lightgbm"),
            }
        )
        return context
    if "sklearn" in packages or (
        hasattr(model, "fit") and hasattr(model, "predict")
    ):
        context.update(
            {
                "library": "sklearn",
                "extension": ".joblib",
                "format": "joblib",
                "version": _get_library_version("scikit-learn"),
            }
        )
        return context
    raise ValueError(
        f"Unsupported or unrecognized model type: {module_path}.{class_name}"
    )


def _get_next_version_index(
    directory: pathlib.Path, base_name: str, is_archive: bool
) -> int:
    """Finds the next integer version based on existing files."""
    ext_pattern = "\\.zip" if is_archive else ""
    pattern = re.compile(f"^{re.escape(base_name)}_v(\\d+){ext_pattern}$")
    max_v = 0
    if directory.exists():
        for item in directory.iterdir():
            match = pattern.match(item.name)
            if match:
                current_v = int(match.group(1))
                if current_v > max_v:
                    max_v = current_v
    return max_v + 1


def save_model(
    model: Any, directory: Union[Str, DirectoryPath], options: dict = None
) -> Union[Str, FilePath]:
    """
    Serialize a trained model plus a ``model_metadata.json`` into *directory*.

    The model library is auto-detected and the model is written with that
    library's native ``save_model`` or with ``joblib``. The model file and
    the metadata are staged in a temporary directory. They are then either
    packed into ``<name>.zip`` or copied into a ``<name>/`` folder.

    Args:
        model: The trained model to save.
        directory: Destination folder; created (with parents) if missing.
        options: Optional settings dict:
            ``verbose`` (default True): log progress and errors.
            ``custom_prefix`` (default "model"): leading part of the name.
            ``include_date`` / ``include_time`` (default False): append
                ``YYYYMMDD`` / ``HHMMSS`` to the name.
            ``use_versioning`` (default True): append ``_v<N>``, one above
                the highest existing version for that name in *directory*.
            ``archive_format`` (default "zip"): ``"zip"`` for a single
                archive; any other value saves a plain folder.
            ``return_as_pathlib`` (default False): return ``pathlib.Path``
                instead of ``str``.

    Returns:
        Path of the created ``.zip`` file or folder (``str`` by default).

    Raises:
        ValueError: If the model type is not supported.
        Exception: Any failure while writing is logged (when verbose) and
            re-raised unchanged.
    """
    options = options or {}
    verbose = options.get("verbose", True)
    custom_prefix = options.get("custom_prefix", "model")
    include_date = options.get("include_date", False)
    include_time = options.get("include_time", False)
    use_versioning = options.get("use_versioning", True)
    archive_format = options.get("archive_format", "zip")
    return_as_pathlib = options.get("return_as_pathlib", False)
    is_archive = archive_format == "zip"
    save_dir = pathlib.Path(directory)
    result_path = None
    with tempfile.TemporaryDirectory() as temp_dir_str:
        temp_dir = pathlib.Path(temp_dir_str)
        try:
            # Detect first so an unsupported model fails before anything is
            # created on disk.
            context = _detect_model_context(model)
            if verbose:
                logger.info(
                    f"Detected {context['library']} model ({context['class_name']})."
                )
            if not save_dir.exists():
                if verbose:
                    logger.info(f"Creating output directory: {save_dir}")
                save_dir.mkdir(parents=True, exist_ok=True)

            # Name layout: <prefix>_<snake_class>[_YYYYMMDD][_HHMMSS][_v<N>]
            now = datetime.datetime.now()
            base_name_parts = [custom_prefix, context["base_name"]]
            if include_date:
                base_name_parts.append(now.strftime("%Y%m%d"))
            if include_time:
                base_name_parts.append(now.strftime("%H%M%S"))
            # filter(None, ...) drops an empty/None prefix instead of
            # producing a leading underscore.
            clean_base_name = "_".join(filter(None, base_name_parts))
            version_str = ""
            if use_versioning:
                next_v = _get_next_version_index(
                    save_dir, clean_base_name, is_archive
                )
                version_str = f"_v{next_v}"
            final_name = f"{clean_base_name}{version_str}"

            # Stage the serialized model in the temporary directory.
            model_filename = f"model{context['extension']}"
            model_temp_path = temp_dir / model_filename
            if context["format"] == "joblib":
                joblib.dump(model, model_temp_path)
            else:
                # xgboost_json, catboost_binary and lightgbm_text all go
                # through the library's own save_model(path); for XGBoost the
                # ".json" suffix is what selects the JSON format.
                model.save_model(str(model_temp_path))

            metadata = {
                "id": final_name,
                "timestamp": now.isoformat(),
                "library": context["library"],
                "library_version": context["version"],
                "class_name": context["class_name"],
                "serialization_format": context["format"],
                "python_version": sys.version.split()[0],
                "artifacts": {"model": model_filename},
            }
            metadata_path = temp_dir / "model_metadata.json"
            with open(metadata_path, "w", encoding="utf-8") as f:
                json.dump(metadata, f, indent=4)
            if verbose:
                logger.info("Model artifacts and metadata prepared.")

            if is_archive:
                output_zip_path = save_dir / f"{final_name}.zip"
                # Build the archive in the staging directory and move it into
                # place only once it is complete. A failure while compressing
                # then never leaves a truncated .zip in save_dir, which would
                # also consume a version number on the next save.
                staged_zip_path = temp_dir / "archive.zip"
                with zipfile.ZipFile(
                    staged_zip_path, "w", zipfile.ZIP_DEFLATED
                ) as zf:
                    zf.write(model_temp_path, arcname=model_filename)
                    zf.write(metadata_path, arcname="model_metadata.json")
                shutil.move(str(staged_zip_path), str(output_zip_path))
                result_path = output_zip_path
                if verbose:
                    logger.info(f"Saved portable archive to: {result_path}")
            else:
                output_folder_path = save_dir / final_name
                # Without versioning the name can repeat; replace the old folder.
                if output_folder_path.exists():
                    shutil.rmtree(output_folder_path)
                shutil.copytree(temp_dir, output_folder_path)
                result_path = output_folder_path
                if verbose:
                    logger.info(f"Saved model directory to: {result_path}")
        except Exception as e:
            if verbose:
                logger.error(f"Failed to save model: {e}")
            raise
    Model_path = result_path if return_as_pathlib else str(result_path)
    return Model_path

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • joblib
  • shap>=0.47.0
  • lightgbm
  • scikit-learn
  • xgboost
  • catboost
  • numba>=0.56.0