Reg. ElasticNet

Train an ElasticNet regression model.

Processing

This brick trains an ElasticNet regression model to predict numerical values based on your input features. ElasticNet combines two regularization techniques (L1 and L2) to prevent overfitting and handle datasets with many features, including cases where features might be correlated.

The brick automatically splits your data into training and test sets, measures performance using multiple metrics, and can optionally optimize hyperparameters to find the best model configuration. It returns a trained model ready for making predictions, along with detailed performance metrics and optional explainability tools.
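For orientation, here is a minimal sketch of calling the brick's underlying training function (train_reg_elasticnet, defined in the source listing further down) directly on synthetic data. The option keys and the eight returned outputs match that listing; the data and option values are purely illustrative.

import numpy as np
import pandas as pd

# Illustrative inputs; in a workflow the brick receives X and y from upstream bricks.
rng = np.random.default_rng(42)
X = pd.DataFrame(rng.normal(size=(500, 4)), columns=["f1", "f2", "f3", "f4"])
y = 3.0 * X["f1"] - 2.0 * X["f3"] + rng.normal(scale=0.1, size=500)

model, shap_exp, scaler, metrics, cv_metrics, pred_set, hpo_trials, hpo_best = train_reg_elasticnet(
    X, y, options={"alpha": 0.5, "l1_ratio": 0.5, "verbose": False}
)
print(metrics)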

Inputs

X
The feature data used to train the model. This should be a table (DataFrame) or array containing the independent variables that will be used to predict the target. Each row represents one observation, and each column represents a feature.
y
The target values you want to predict. This is the dependent variable—the actual numerical outcomes your model will learn to predict. Must be numerical values that correspond to each row in your feature data.

Inputs Types

Input   Types
X       DataFrame
y       DataSeries, NDArray, List

You can check the list of supported types here: Available Type Hints.

Outputs

Model
The trained ElasticNet regression model object. Use this to make predictions on new data by connecting it to prediction bricks in your workflow.
SHAP
The SHAP explainer object for interpreting model predictions (only returned if "SHAP Explainer" option is enabled). This helps you understand which features contribute most to each prediction.
Scaler
The StandardScaler object used to normalize the data (only returned if "Standard Scaling" option is enabled). You'll need this to transform new data the same way before making predictions.
Metrics
Performance measurements showing how well the model predicts on the test set. Includes metrics like R2 Score, Mean Absolute Error, Root Mean Squared Error, and others. Format depends on the "Metrics as" option.
CV Metrics
Cross-validation performance metrics showing model stability across different data splits (only returned if "Enable Cross-Validation" is enabled). This provides mean and standard deviation for each metric across all folds.
Prediction Set
A DataFrame combining the test features, actual values (y_true), and predicted values (y_pred). Useful for analyzing individual predictions and identifying patterns in model errors.
HPO Trials
A detailed history of all hyperparameter optimization trials (only returned if "Hyperparameter Optim." is enabled). Shows which parameter combinations were tested and their performance scores.
HPO Best
The best hyperparameter values found during optimization (only returned if "Hyperparameter Optim." is enabled). Contains the optimal Alpha, L1 Ratio, and Tolerance values.

The Metrics output contains the following performance measurements (a short example of computing WAPE and Forecast Accuracy follows the list):

  • Forecast Accuracy: The percentage of total magnitude predicted correctly (1 - WAPE). Values closer to 1.0 (100%) indicate better predictions.
  • Weighted Absolute Percentage Error: The sum of absolute errors divided by the sum of actual values. Lower values indicate better accuracy.
  • Mean Absolute Error: The average absolute difference between predicted and actual values, in the same units as your target variable.
  • Mean Squared Error: The average squared difference between predicted and actual values. More sensitive to large errors than MAE.
  • Root Mean Squared Error: The square root of MSE, returning to the original units of the target variable.
  • R2 Score: The proportion of variance explained by the model. Values range from negative (worse than predicting the mean) to 1.0 (perfect predictions).
  • Mean Absolute Percentage Error: The average percentage error across all predictions.
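As a quick illustration of how the first two metrics relate (mirroring the wape_score and forecast_accuracy helpers in the source listing below), here is a small sketch with made-up numbers:

import numpy as np

y_true = np.array([100.0, 250.0, 80.0])
y_pred = np.array([110.0, 240.0, 70.0])

# WAPE = sum of absolute errors / sum of absolute actuals
wape = np.sum(np.abs(y_true - y_pred)) / np.sum(np.abs(y_true))  # 30 / 430 ≈ 0.070
forecast_accuracy = 1.0 - wape  # ≈ 0.930, i.e. roughly 93% of the total magnitude predicted correctly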

The Prediction_Set output contains the following columns (a brief usage sketch follows the list):

  • feature columns: All original feature columns from your test data, preserving their original names.
  • y_true: The actual target values from the test set.
  • y_pred: The model's predicted values for the test set.
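A small, hypothetical example of working with this output (prediction_set stands for the DataFrame returned by the brick; the column names match the list above):

# Inspect the rows where the model misses by the largest margin.
prediction_set["abs_error"] = (prediction_set["y_pred"] - prediction_set["y_true"]).abs()
worst_rows = prediction_set.nlargest(10, "abs_error")
print(worst_rows)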

The HPO_Best output contains the optimized hyperparameters (a reuse sketch follows the list):

  • alpha: The optimized regularization strength.
  • l1_ratio: The optimized L1/L2 ratio.
  • tol: The optimized convergence tolerance.
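Because these keys match scikit-learn's ElasticNet parameter names, the HPO_Best dictionary can be unpacked directly into a new estimator; a sketch with placeholder values:

from sklearn.linear_model import ElasticNet

hpo_best = {"alpha": 0.0123, "l1_ratio": 0.7, "tol": 1e-3}  # illustrative values
model = ElasticNet(**hpo_best, max_iter=1000, fit_intercept=True)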

Outputs Types

Output           Types
Model            Any
SHAP             Any
Scaler           Any
Metrics          DataFrame, Dict
CV_Metrics       DataFrame
Prediction_Set   DataFrame
HPO_Trials       DataFrame
HPO_Best         Dict

You can check the list of supported types here: Available Type Hints.

Options

The Reg. ElasticNet brick exposes the following configurable options; an example options dictionary is shown after the list:

Alpha
Controls the overall strength of regularization (the penalty applied to model complexity). Higher values create simpler models that may underfit, while lower values create more complex models that may overfit. Default is 1.0.
L1 Ratio
Determines the balance between L1 (Lasso) and L2 (Ridge) regularization. A value of 1.0 means pure L1 (feature selection), 0.0 means pure L2 (coefficient shrinkage), and 0.5 is an equal mix. Default is 0.5.
Include Intercept
When enabled, the model learns a baseline value (intercept/bias term) in addition to feature weights. Disable this only if your data is already centered around zero. Default is enabled.
Tolerance
The precision threshold for stopping model training. Training stops when improvements become smaller than this value. Lower values mean more precise models but longer training time. Default is 0.0001.
Maximum Iterations
The maximum number of training cycles allowed before stopping. Increase this if you see warnings about the model not converging. Default is 1000.
Standard Scaling
When enabled, normalizes all features to have zero mean and unit variance before training. This is recommended when features have different scales (e.g., mixing ages with salaries). Default is disabled.
Auto Split Data
When enabled, automatically determines the optimal train/test split ratio based on your dataset size. Larger datasets use smaller test percentages. When disabled, uses the percentage specified in "Test/Validation Set %". Default is enabled.
Shuffle Split
When enabled, randomly shuffles the data before splitting into train and test sets. Disable this for time-series data or when the order of observations matters. Default is enabled.
Test/Validation Set %
The percentage of data reserved for testing (not used for training). Only applies when "Auto Split Data" is disabled. Higher values give more reliable performance estimates but less training data. Default is 15%.
Retrain On Full Data
When enabled, retrains the final model on the entire dataset after performance evaluation. Use this for production deployment when you want to use all available data. Performance metrics still reflect the original train/test split. Default is disabled.
Enable Cross-Validation
When enabled, evaluates the model using k-fold cross-validation instead of a single train/test split. This provides more robust performance estimates by testing on multiple different data splits. Default is disabled.
Number of CV Folds
The number of cross-validation splits to create. Higher values give more reliable estimates but take longer to compute. Only applies when "Enable Cross-Validation" is enabled. Default is 5.
Hyperparameter Optim.
When enabled, automatically searches for the best Alpha, L1 Ratio, and Tolerance values instead of using the manually specified values. This can significantly improve model performance but takes longer to run. Default is disabled.
Optimization Metric
The performance metric to optimize when searching for the best hyperparameters. The algorithm tries to find parameters that maximize or minimize this metric.
  • Forecast Accuracy: Measures what percentage of the total magnitude is predicted correctly (higher is better). Good for business forecasting.
  • Weighted Absolute Percentage Error (WAPE): The sum of errors divided by the sum of actual values, expressed as a percentage (lower is better). Less sensitive to outliers than MAPE.
  • Mean Absolute Error (MAE): The average absolute difference between predictions and actual values (lower is better). Easy to interpret in the same units as your target.
  • Mean Squared Error (MSE): The average squared difference between predictions and actual values (lower is better). Penalizes large errors more heavily than MAE.
  • Root Mean Squared Error (RMSE): The square root of MSE, in the same units as your target (lower is better). Standard metric for regression.
  • R2 Score: The proportion of variance in the target explained by the model, up to 1.0 for a perfect fit (higher is better); it can be negative when the model performs worse than predicting the mean. Useful for comparing different models.
  • Mean Absolute Percentage Error (MAPE): The average percentage error across all predictions (lower is better). Can be unstable when actual values are close to zero.
Optimization Method
The algorithm used to search for the best hyperparameters. Different methods have different strengths in exploring the parameter space.
  • Tree-structured Parzen: An intelligent Bayesian method that learns from previous trials. Generally the best all-around choice for most problems.
  • Gaussian Process: A sophisticated Bayesian method that models the entire parameter space. Works well for smooth optimization landscapes but slower for many trials.
  • CMA-ES: Evolution-based strategy good for complex, non-smooth optimization problems. Particularly effective when parameters interact in complex ways.
  • Random Sobol Search: Quasi-random sampling that ensures even coverage of the parameter space. More systematic than pure random search.
  • Random Search: Randomly tries parameter combinations. Simple and can work well as a baseline, but less efficient than smarter methods.
Optimization Iterations
The number of different hyperparameter combinations to try. More iterations increase the chance of finding better parameters but take longer. Default is 50.
Metrics as
Determines the format of the performance metrics output: a DataFrame with one row per metric, or a dictionary keyed by metric name.
SHAP Explainer
When enabled, generates a SHAP (SHapley Additive exPlanations) explainer object that can show which features contribute most to each prediction. Essential for model interpretability and debugging. Default is disabled.
SHAP Sampler
When enabled and using SHAP, creates a condensed background dataset for faster SHAP calculations on large datasets. Disable for complete accuracy on smaller datasets. Default is disabled.
Number of Jobs
The number of parallel processing threads to use for training and optimization. Higher values speed up computation but use more system resources.
Random State
A seed number that controls randomness in data shuffling and model initialization. Using the same value ensures reproducible results across different runs. Default is 42.
Brick Caching
When enabled, saves trained models and results to disk. If you run the brick again with identical inputs and settings, it loads the cached results instantly instead of retraining. Useful for iterative workflow development. Default is disabled.
Verbose Logging
When enabled, prints detailed progress information during training, including metric values and optimization progress. Helpful for monitoring long-running processes. Default is enabled.
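As referenced above, here is a sketch of an options dictionary using the keys read by the train_reg_elasticnet function in the source listing below (key names come from its options.get(...) calls; the values are purely illustrative):

# For reference, scikit-learn's ElasticNet minimizes:
#   1/(2n) * ||y - Xw||^2 + alpha * l1_ratio * ||w||_1 + 0.5 * alpha * (1 - l1_ratio) * ||w||^2
options = {
    "alpha": 1.0,
    "l1_ratio": 0.5,
    "intercept": True,
    "tol": 1e-4,
    "max_iterations": 1000,
    "standard_scaling": True,
    "auto_split": True,
    "shuffle_split": True,
    "use_cross_validation": False,
    "cv_folds": 5,
    "use_hyperparameter_optimization": True,
    "optimization_metric": "Root Mean Squared Error (RMSE)",
    "optimization_method": "Tree-structured Parzen",
    "optimization_iterations": 50,
    "metrics_as": "Dataframe",
    "return_shap_explainer": False,
    "random_state": 42,
    "verbose": True,
}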
import logging
import warnings
import shap
import json
import xxhash
import hashlib
import tempfile
import sklearn
import scipy
import joblib
import numpy as np
import pandas as pd
import polars as pl
from pathlib import Path
from scipy import sparse
from optuna.samplers import (
    TPESampler,
    RandomSampler,
    GPSampler,
    CmaEsSampler,
    QMCSampler,
)
import optuna
from optuna import Study
from optuna.trial import FrozenTrial
from optuna.pruners import HyperbandPruner
from optuna import create_study
from sklearn.linear_model import ElasticNet
from sklearn.model_selection import train_test_split, cross_validate, KFold
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import (
    mean_absolute_error,
    mean_squared_error,
    r2_score,
    root_mean_squared_error,
    mean_absolute_percentage_error,
    make_scorer,
)
from dataclasses import dataclass
from datetime import datetime
from coded_flows.types import (
    Union,
    Dict,
    List,
    Tuple,
    NDArray,
    DataFrame,
    DataSeries,
    Any,
)
from coded_flows.utils import CodedFlowsLogger

logger = CodedFlowsLogger(name="Reg. ElasticNet", level=logging.INFO)
optuna.logging.set_verbosity(optuna.logging.ERROR)
warnings.filterwarnings("ignore", category=optuna.exceptions.ExperimentalWarning)
METRICS_DICT = {
    "Forecast Accuracy": "fa",
    "Weighted Absolute Percentage Error (WAPE)": "wape",
    "Mean Absolute Error (MAE)": "mae",
    "Mean Squared Error (MSE)": "mse",
    "Root Mean Squared Error (RMSE)": "rmse",
    "R2 Score": "r2",
    "Mean Absolute Percentage Error (MAPE)": "mape",
}
METRICS_OPT = {
    "fa": "maximize",
    "wape": "minimize",
    "mae": "minimize",
    "mse": "minimize",
    "rmse": "minimize",
    "r2": "maximize",
    "mape": "minimize",
}
DataType = Union[
    pd.DataFrame, pl.DataFrame, np.ndarray, sparse.spmatrix, pd.Series, pl.Series
]


@dataclass
class _DatasetFingerprint:
    """Lightweight fingerprint of a dataset."""

    hash: str
    shape: tuple
    computed_at: str
    data_type: str
    method: str


class _UniversalDatasetHasher:
    """
    High-performance dataset hasher optimizing for zero-copy operations
    and native backend execution (C/Rust).
    """

    def __init__(
        self,
        data_size: int,
        method: str = "auto",
        sample_size: int = 100000,
        verbose: bool = False,
    ):
        self.method = method
        self.sample_size = sample_size
        self.data_size = data_size
        self.verbose = verbose

    def hash_data(self, data: DataType) -> _DatasetFingerprint:
        """
        Main entry point: hash any supported data format.
        Auto-detects format and applies optimal strategy.
        """
        if isinstance(data, pd.DataFrame):
            return self._hash_pandas(data)
        elif isinstance(data, pl.DataFrame):
            return self._hash_polars(data)
        elif isinstance(data, pd.Series):
            return self._hash_pandas_series(data)
        elif isinstance(data, pl.Series):
            return self._hash_polars_series(data)
        elif isinstance(data, np.ndarray):
            return self._hash_numpy(data)
        elif sparse.issparse(data):
            return self._hash_sparse(data)
        else:
            raise TypeError(f"Unsupported data type: {type(data)}")

    def _hash_pandas(self, df: pd.DataFrame) -> _DatasetFingerprint:
        """
        Optimized Pandas hashing using pd.util.hash_pandas_object.
        Avoids object-to-string conversion overhead.
        """
        method = self._determine_method(self.data_size, self.method)
        self.verbose and logger.info(
            f"Hashing Pandas: {self.data_size:,} rows - {method}"
        )
        target_df = df
        if method == "sampled":
            target_df = self._get_pandas_sample(df)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns.tolist(),
                "dtypes": {k: str(v) for (k, v) in df.dtypes.items()},
                "shape": df.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(target_df, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception as e:
            self.verbose and logger.warning(
                f"Fast hash failed, falling back to slow hash: {e}"
            )
            self._hash_pandas_fallback(hasher, target_df)
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas",
            method=method,
        )

    def _get_pandas_sample(self, df: pd.DataFrame) -> pd.DataFrame:
        """Deterministic slicing for sampling (Zero randomness)."""
        if self.data_size <= self.sample_size:
            return df
        chunk = self.sample_size // 3
        head = df.iloc[:chunk]
        mid_idx = self.data_size // 2
        mid = df.iloc[mid_idx : mid_idx + chunk]
        tail = df.iloc[-chunk:]
        return pd.concat([head, mid, tail])

    def _hash_pandas_fallback(self, hasher, df: pd.DataFrame):
        """Legacy fallback for complex object types."""
        for col in df.columns:
            val = df[col].astype(str).values
            hasher.update(val.astype(np.bytes_).tobytes())

    def _hash_polars(self, df: pl.DataFrame) -> _DatasetFingerprint:
        """
        Optimized Polars hashing using native Rust execution.
        """
        method = self._determine_method(self.data_size, self.method)
        self.verbose and logger.info(
            f"Hashing Polars: {self.data_size:,} rows - {method}"
        )
        target_df = df
        if method == "sampled" and self.data_size > self.sample_size:
            indices = self._get_sample_indices(self.data_size, self.sample_size)
            target_df = df.gather(indices)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns,
                "dtypes": [str(t) for t in df.dtypes],
                "shape": df.shape,
            },
        )
        row_hashes = target_df.hash_rows()
        hasher.update(memoryview(row_hashes.to_numpy()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars",
            method=method,
        )

    def _hash_pandas_series(self, series: pd.Series) -> _DatasetFingerprint:
        """Hash Pandas Series using the fastest vectorized method."""
        self.verbose and logger.info(f"Hashing Pandas Series: {self.data_size:,} rows")
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "name": series.name if series.name else "None",
                "dtype": str(series.dtype),
                "shape": series.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(series, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception as e:
            self.verbose and logger.warning(f"Series hash failed, falling back: {e}")
            hasher.update(memoryview(series.astype(str).values.tobytes()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas_series",
            method="full",
        )

    def _hash_polars_series(self, series: pl.Series) -> _DatasetFingerprint:
        """Hash Polars Series using native Polars expressions."""
        self.verbose and logger.info(f"Hashing Polars Series: {self.data_size:,} rows")
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"name": series.name, "dtype": str(series.dtype), "shape": series.shape},
        )
        try:
            row_hashes = series.hash()
            hasher.update(memoryview(row_hashes.to_numpy()))
        except Exception as e:
            self.verbose and logger.warning(
                f"Polars series native hash failed, falling back: {e}"
            )
            hasher.update(str(series.to_list()).encode())
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars_series",
            method="full",
        )

    def _hash_numpy(self, arr: np.ndarray) -> _DatasetFingerprint:
        """
        Optimized NumPy hashing using Buffer Protocol (Zero-Copy).
        """
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"shape": arr.shape, "dtype": str(arr.dtype), "strides": arr.strides},
        )
        if arr.flags["C_CONTIGUOUS"] or arr.flags["F_CONTIGUOUS"]:
            hasher.update(memoryview(arr))
        else:
            hasher.update(memoryview(np.ascontiguousarray(arr)))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=arr.shape,
            computed_at=datetime.now().isoformat(),
            data_type="numpy",
            method="full",
        )

    def _hash_sparse(self, matrix: sparse.spmatrix) -> _DatasetFingerprint:
        """
        Optimized sparse hashing. Hashes underlying data arrays directly.
        """
        if not (sparse.isspmatrix_csr(matrix) or sparse.isspmatrix_csc(matrix)):
            matrix = matrix.tocsr()
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher, {"shape": matrix.shape, "format": matrix.format, "nnz": matrix.nnz}
        )
        hasher.update(memoryview(matrix.data))
        hasher.update(memoryview(matrix.indices))
        hasher.update(memoryview(matrix.indptr))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=matrix.shape,
            computed_at=datetime.now().isoformat(),
            data_type=f"sparse_{matrix.format}",
            method="sparse",
        )

    def _determine_method(self, rows: int, requested: str) -> str:
        if requested != "auto":
            return requested
        if rows < 5000000:
            return "full"
        return "sampled"

    def _hash_schema(self, hasher, schema: Dict[str, Any]):
        """Compact schema hashing."""
        hasher.update(
            json.dumps(schema, sort_keys=True, separators=(",", ":")).encode()
        )

    def _get_sample_indices(self, total_rows: int, sample_size: int) -> list:
        """Calculate indices for sampling without generating full range lists."""
        chunk = sample_size // 3
        indices = list(range(min(chunk, total_rows)))
        mid_start = max(0, total_rows // 2 - chunk // 2)
        mid_end = min(mid_start + chunk, total_rows)
        indices.extend(range(mid_start, mid_end))
        last_start = max(0, total_rows - chunk)
        indices.extend(range(last_start, total_rows))
        return sorted(list(set(indices)))


def wape_score(y_true, y_pred):
    """
    Calculates Weighted Absolute Percentage Error (WAPE).

    WAPE = sum(|Error|) / sum(|Groundtruth|)
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    eps = np.finfo(np.float64).eps
    sum_abs_error = np.sum(np.abs(y_true - y_pred))
    sum_abs_truth = np.maximum(np.sum(np.abs(y_true)), eps)
    return sum_abs_error / sum_abs_truth


def forecast_accuracy(y_true, y_pred):
    """
    Calculates Forecast Accuracy.

    FA = 1 - (sum(|Error|) / sum(|Groundtruth|))
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    eps = np.finfo(np.float64).eps
    sum_abs_error = np.sum(np.abs(y_true - y_pred))
    sum_abs_truth = np.maximum(np.sum(np.abs(y_true)), eps)
    return 1 - sum_abs_error / sum_abs_truth


def _normalize_hpo_df(df):
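    """Cast Optuna trial 'params_*' columns to a uniform string dtype (string[pyarrow])."""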
    df = df.copy()
    param_cols = [c for c in df.columns if c.startswith("params_")]
    df[param_cols] = df[param_cols].astype("string[pyarrow]")
    return df


def _validate_numerical_data(data):
    """
    Validates if the input data (NumPy array, Pandas DataFrame/Series,
    Polars DataFrame/Series, or SciPy sparse matrix) contains only
    numerical (integer, float) or boolean values.

    Args:
        data: The input data structure to check.

    Raises:
        TypeError: If the input data contains non-numerical and non-boolean types.
        ValueError: If the input data is of an unsupported type.
    """
    if sparse.issparse(data):
        if not (
            np.issubdtype(data.dtype, np.number) or np.issubdtype(data.dtype, np.bool_)
        ):
            raise TypeError(
                f"Sparse matrix contains unsupported data type: {data.dtype}. Only numerical or boolean types are allowed."
            )
        return
    elif isinstance(data, np.ndarray):
        if not (
            np.issubdtype(data.dtype, np.number) or np.issubdtype(data.dtype, np.bool_)
        ):
            raise TypeError(
                f"NumPy array contains unsupported data type: {data.dtype}. Only numerical or boolean types are allowed."
            )
        return
    elif isinstance(data, (pd.DataFrame, pd.Series)):
        d_types = data.dtypes.apply(lambda x: x.kind)
        non_numerical_mask = ~d_types.isin(["i", "f", "b"])
        if non_numerical_mask.any():
            non_numerical_columns = (
                data.columns[non_numerical_mask].tolist()
                if isinstance(data, pd.DataFrame)
                else [data.name]
            )
            raise TypeError(
                f"Pandas {('DataFrame' if isinstance(data, pd.DataFrame) else 'Series')} contains non-numerical/boolean data. Offending column(s) and types: {data.dtypes[non_numerical_mask].to_dict()}"
            )
        return
    elif isinstance(data, (pl.DataFrame, pl.Series)):
        pl_numerical_types = [
            pl.Int8,
            pl.Int16,
            pl.Int32,
            pl.Int64,
            pl.UInt8,
            pl.UInt16,
            pl.UInt32,
            pl.UInt64,
            pl.Float32,
            pl.Float64,
            pl.Boolean,
        ]
        if isinstance(data, pl.DataFrame):
            for col, dtype in data.schema.items():
                if dtype not in pl_numerical_types:
                    raise TypeError(
                        f"Polars DataFrame column '{col}' has unsupported data type: {dtype}. Only numerical or boolean types are allowed."
                    )
        elif isinstance(data, pl.Series):
            if data.dtype not in pl_numerical_types:
                raise TypeError(
                    f"Polars Series has unsupported data type: {data.dtype}. Only numerical or boolean types are allowed."
                )
        return
    else:
        raise ValueError(
            f"Unsupported data type provided: {type(data)}. Function supports NumPy, Pandas, Polars, and SciPy sparse matrices."
        )


def _smart_split(
    n_samples,
    X,
    y,
    *,
    random_state=42,
    shuffle=True,
    stratify=None,
    fixed_test_split=None,
    verbose=True,
):
    """
    Parameters
    ----------
    n_samples : int
        Number of samples in the dataset (len(X) or len(y))
    X : array-like
        Features
    y : array-like
        Target
    random_state : int
    shuffle : bool
    stratify : array-like or None
        For stratified splitting (recommended for classification)

    Returns
    -------
    If return_val=True  → X_train, X_val, X_test, y_train, y_val, y_test
    If return_val=False → X_train, X_test, y_train, y_test
    """
    if fixed_test_split:
        test_ratio = fixed_test_split
        val_ratio = fixed_test_split
    elif n_samples <= 1000:
        test_ratio = 0.2
        val_ratio = 0.1
    elif n_samples < 10000:
        test_ratio = 0.15
        val_ratio = 0.15
    elif n_samples < 100000:
        test_ratio = 0.1
        val_ratio = 0.1
    elif n_samples < 1000000:
        test_ratio = 0.05
        val_ratio = 0.05
    else:
        test_ratio = 0.01
        val_ratio = 0.01
    (X_train, X_test, y_train, y_test) = train_test_split(
        X,
        y,
        test_size=test_ratio,
        random_state=random_state,
        shuffle=shuffle,
        stratify=stratify,
    )
    val_size_in_train = val_ratio / (1 - test_ratio)
    verbose and logger.info(
        f"Split → Train: {1 - test_ratio:.2%} | Test: {test_ratio:.2%} (no validation set)"
    )
    return (X_train, X_test, y_train, y_test, val_size_in_train)


def _ensure_feature_names(X, feature_names=None):
    if isinstance(X, pd.DataFrame):
        return list(X.columns)
    if isinstance(X, np.ndarray):
        if feature_names is None:
            feature_names = [f"feature_{i}" for i in range(X.shape[1])]
        return feature_names
    raise TypeError("X must be a pandas DataFrame or numpy ndarray")


def _perform_cross_validation(
    model, X, y, cv_folds, shuffle, random_state, n_jobs, verbose
) -> pd.DataFrame:
    """Perform cross-validation on the regression model and return a metrics DataFrame."""
    verbose and logger.info(f"Performing {cv_folds}-fold cross-validation...")
    cv = KFold(n_splits=cv_folds, shuffle=shuffle, random_state=random_state)
    scoring = {
        "MAE": "neg_mean_absolute_error",
        "MSE": "neg_mean_squared_error",
        "RMSE": "neg_root_mean_squared_error",
        "MAPE": "neg_mean_absolute_percentage_error",
        "R2": "r2",
        "WAPE": make_scorer(wape_score, greater_is_better=False),
        "Forecast_Accuracy": make_scorer(forecast_accuracy, greater_is_better=True),
    }
    cv_results = cross_validate(
        model, X, y, cv=cv, scoring=scoring, return_train_score=False, n_jobs=n_jobs
    )

    def get_score_stats(metric_key, invert_sign=False):
        key = f"test_{metric_key}"
        if key in cv_results:
            scores = cv_results[key]
            if invert_sign:
                scores = -scores
            return (scores.mean(), scores.std())
        return (0.0, 0.0)

    (mae_mean, mae_std) = get_score_stats("MAE", invert_sign=True)
    (mse_mean, mse_std) = get_score_stats("MSE", invert_sign=True)
    (rmse_mean, rmse_std) = get_score_stats("RMSE", invert_sign=True)
    (mape_mean, mape_std) = get_score_stats("MAPE", invert_sign=True)
    (wape_mean, wape_std) = get_score_stats("WAPE", invert_sign=True)
    (r2_mean, r2_std) = get_score_stats("R2", invert_sign=False)
    (fa_mean, fa_std) = get_score_stats("Forecast_Accuracy", invert_sign=False)
    verbose and logger.info(f"CV MAE          : {mae_mean:.4f} (+/- {mae_std:.4f})")
    verbose and logger.info(f"CV MSE          : {mse_mean:.4f} (+/- {mse_std:.4f})")
    verbose and logger.info(f"CV RMSE         : {rmse_mean:.4f} (+/- {rmse_std:.4f})")
    verbose and logger.info(f"CV MAPE         : {mape_mean:.4f} (+/- {mape_std:.4f})")
    verbose and logger.info(f"CV WAPE         : {wape_mean:.4f} (+/- {wape_std:.4f})")
    verbose and logger.info(f"CV R2 Score     : {r2_mean:.4f} (+/- {r2_std:.4f})")
    verbose and logger.info(f"CV Forecast Acc : {fa_mean:.4f} (+/- {fa_std:.4f})")
    CV_metrics = pd.DataFrame(
        {
            "Metric": [
                "Mean Absolute Error (MAE)",
                "Mean Squared Error (MSE)",
                "Root Mean Squared Error (RMSE)",
                "Mean Absolute Percentage Error (MAPE)",
                "Weighted Absolute Percentage Error (WAPE)",
                "R2 Score",
                "Forecast Accuracy",
            ],
            "Mean": [
                mae_mean,
                mse_mean,
                rmse_mean,
                mape_mean,
                wape_mean,
                r2_mean,
                fa_mean,
            ],
            "Std": [mae_std, mse_std, rmse_std, mape_std, wape_std, r2_std, fa_std],
        }
    )
    return CV_metrics


def _compute_score(model, X, y, metric):
    """
    Computes the score for the model on the given data based on the selected metric.
    Assumes 'metric' is passed as the short code (e.g., "mae", "r2", "fa").
    """
    y_pred = model.predict(X)
    if metric == "mae":
        score = mean_absolute_error(y, y_pred)
    elif metric == "mse":
        score = mean_squared_error(y, y_pred)
    elif metric == "rmse":
        score = root_mean_squared_error(y, y_pred)
    elif metric == "mape":
        score = mean_absolute_percentage_error(y, y_pred)
    elif metric == "r2":
        score = r2_score(y, y_pred)
    elif metric == "wape" or metric == "fa":
        y_true_np = np.array(y, dtype=float).flatten()
        y_pred_np = np.array(y_pred, dtype=float).flatten()
        eps = np.finfo(np.float64).eps
        sum_abs_error = np.sum(np.abs(y_true_np - y_pred_np))
        sum_abs_truth = np.maximum(np.sum(np.abs(y_true_np)), eps)
        wape_val = sum_abs_error / sum_abs_truth
        if metric == "fa":
            score = 1.0 - wape_val
        else:
            score = wape_val
    else:
        raise ValueError(f"Unknown regression metric: {metric}")
    return score


def _get_cv_scoring_object(metric: str) -> Any:
    """
    Returns a scoring object (string or callable) suitable for cross_validate or GridSearchCV.
    Used during HPO for Regression.
    """
    if metric == "mae":
        return "neg_mean_absolute_error"
    elif metric == "mse":
        return "neg_mean_squared_error"
    elif metric == "rmse":
        return "neg_root_mean_squared_error"
    elif metric == "r2":
        return "r2"
    elif metric == "mape":
        return "neg_mean_absolute_percentage_error"
    elif metric == "wape":
        return make_scorer(wape_score, greater_is_better=False)
    elif metric == "fa":
        return make_scorer(forecast_accuracy, greater_is_better=True)
    else:
        return "neg_root_mean_squared_error"


def _hyperparameters_optimization(
    X,
    y,
    constant_hyperparameters,
    optimization_metric,
    val_ratio,
    shuffle_split,
    use_cross_val,
    cv_folds,
    n_trials=50,
    strategy="maximize",
    sampler="Tree-structured Parzen",
    standard_scaling=False,
    seed=None,
    n_jobs=-1,
    verbose=False,
):
    direction = "maximize" if strategy.lower() == "maximize" else "minimize"
    sampler_map = {
        "Tree-structured Parzen": TPESampler(seed=seed),
        "Gaussian Process": GPSampler(seed=seed),
        "CMA-ES": CmaEsSampler(seed=seed),
        "Random Search": RandomSampler(seed=seed),
        "Random Sobol Search": QMCSampler(seed=seed),
    }
    if sampler in sampler_map:
        chosen_sampler = sampler_map[sampler]
    else:
        logger.warning(f"Sampler '{sampler}' not recognized → falling back to TPE")
        chosen_sampler = TPESampler(seed=seed)
    chosen_pruner = HyperbandPruner()
    if use_cross_val:
        cv = KFold(n_splits=cv_folds, shuffle=shuffle_split, random_state=seed)
        cv_score_obj = _get_cv_scoring_object(optimization_metric)
    else:
        (X_train, X_val, y_train, y_val) = train_test_split(
            X, y, test_size=val_ratio, random_state=seed, shuffle=shuffle_split
        )

    def logging_callback(study: Study, trial: FrozenTrial):
        """Callback function to log trial progress"""
        verbose and logger.info(
            f"Trial {trial.number} finished with value: {trial.value} and parameters: {trial.params}"
        )
        try:
            verbose and logger.info(f"Best value so far: {study.best_value}")
            verbose and logger.info(f"Best parameters so far: {study.best_params}")
        except ValueError:
            verbose and logger.info(f"No successful trials completed yet")
        verbose and logger.info(f"" + "-" * 50)

    def objective(trial):
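        """Optuna objective: sample alpha, l1_ratio and tol, fit an ElasticNet, and return the validation or CV score."""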
        try:
            alpha = trial.suggest_float("alpha", 1e-05, 100.0, log=True)
            l1_ratio = trial.suggest_float("l1_ratio", 0.0, 1.0)
            tol = trial.suggest_float("tol", 0.0001, 1, log=True)
            max_iter = constant_hyperparameters.get("max_iter")
            intercept = constant_hyperparameters.get("intercept")
            model = ElasticNet(
                alpha=alpha,
                l1_ratio=l1_ratio,
                tol=tol,
                max_iter=max_iter,
                fit_intercept=intercept,
            )
            if standard_scaling:
                model = Pipeline([("scaler", StandardScaler()), ("lr", model)])
            if use_cross_val:
                scores = cross_validate(
                    model, X, y, cv=cv, n_jobs=n_jobs, scoring=cv_score_obj
                )
                return scores["test_score"].mean()
            else:
                model.fit(X_train, y_train)
                score = _compute_score(model, X_val, y_val, optimization_metric)
                return score
        except Exception as e:
            verbose and logger.error(
                f"Trial {trial.number} failed with error: {str(e)}"
            )
            raise

    study = create_study(
        direction=direction, sampler=chosen_sampler, pruner=chosen_pruner
    )
    study.optimize(
        objective,
        n_trials=n_trials,
        catch=(Exception,),
        n_jobs=n_jobs,
        callbacks=[logging_callback],
    )
    verbose and logger.info(f"Optimization completed!")
    verbose and logger.info(f"   Best Alpha        : {study.best_params['alpha']:.6e}")
    verbose and logger.info(
        f"   Best L1 Ratio     : {study.best_params['l1_ratio']:.4f}"
    )
    verbose and logger.info(f"   Best Tolerance    : {study.best_params['tol']:.6e}")
    verbose and logger.info(
        f"   Best {optimization_metric:<13}: {study.best_value:.4f}"
    )
    verbose and logger.info(f"   Sampler used      : {sampler}")
    verbose and logger.info(f"   Direction         : {direction}")
    if use_cross_val:
        verbose and logger.info(f"   Cross-validation  : {cv_folds}-fold")
    else:
        verbose and logger.info(f"   Validation        : single train/val split")
    trials = study.trials_dataframe()
    trials["best_value"] = trials["value"].cummax()
    cols = list(trials.columns)
    value_idx = cols.index("value")
    cols = [c for c in cols if c != "best_value"]
    new_order = cols[: value_idx + 1] + ["best_value"] + cols[value_idx + 1 :]
    trials = trials[new_order]
    return (study.best_params, trials)


def _combine_test_data(X_test, y_true, y_pred, features_names=None):
    """
    Combine X_test, y_true, y_pred into a single DataFrame.

    Parameters:
    -----------
    X_test : pandas/polars DataFrame, numpy array, or scipy sparse matrix
        Test features
    y_true : pandas/polars Series, numpy array, or list
        True labels
    y_pred : pandas/polars Series, numpy array, or list
        Predicted labels

    Returns:
    --------
    pandas.DataFrame
        Combined DataFrame with features, y_true, and y_pred
    """
    if sparse.issparse(X_test):
        X_df = pd.DataFrame(X_test.toarray())
    elif isinstance(X_test, np.ndarray):
        X_df = pd.DataFrame(X_test)
    elif hasattr(X_test, "to_pandas"):
        X_df = X_test.to_pandas()
    elif isinstance(X_test, pd.DataFrame):
        X_df = X_test.copy()
    else:
        raise TypeError(f"Unsupported type for X_test: {type(X_test)}")
    if X_df.columns.tolist() == list(range(len(X_df.columns))):
        X_df.columns = (
            [f"feature_{i}" for i in range(len(X_df.columns))]
            if features_names is None
            else features_names
        )
    if isinstance(y_true, list):
        y_true_series = pd.Series(y_true, name="y_true")
    elif isinstance(y_true, np.ndarray):
        y_true_series = pd.Series(y_true, name="y_true")
    elif hasattr(y_true, "to_pandas"):
        y_true_series = y_true.to_pandas()
        y_true_series.name = "y_true"
    elif isinstance(y_true, pd.Series):
        y_true_series = y_true.copy()
        y_true_series.name = "y_true"
    else:
        raise TypeError(f"Unsupported type for y_true: {type(y_true)}")
    if isinstance(y_pred, list):
        y_pred_series = pd.Series(y_pred, name="y_pred")
    elif isinstance(y_pred, np.ndarray):
        y_pred_series = pd.Series(y_pred, name="y_pred")
    elif hasattr(y_pred, "to_pandas"):
        y_pred_series = y_pred.to_pandas()
        y_pred_series.name = "y_pred"
    elif isinstance(y_pred, pd.Series):
        y_pred_series = y_pred.copy()
        y_pred_series.name = "y_pred"
    else:
        raise TypeError(f"Unsupported type for y_pred: {type(y_pred)}")
    X_df = X_df.reset_index(drop=True)
    y_true_series = y_true_series.reset_index(drop=True)
    y_pred_series = y_pred_series.reset_index(drop=True)
    result_df = pd.concat([X_df, y_true_series, y_pred_series], axis=1)
    return result_df


def _smart_shap_background(
    X: Union[np.ndarray, pd.DataFrame],
    model_type: str = "tree",
    seed: int = 42,
    verbose: bool = True,
) -> Union[np.ndarray, pd.DataFrame, object]:
    """
    Intelligently prepares a background dataset for SHAP based on model type.

    Strategies:
    - Tree: Higher sample cap (1000), uses Random Sampling (preserves data structure).
    - Other: Lower sample cap (100), uses K-Means (maximizes info density).
    """
    (n_rows, n_features) = X.shape
    if model_type == "tree":
        max_samples = 1000
        use_kmeans = False
    else:
        max_samples = 100
        use_kmeans = True
    if n_rows <= max_samples:
        verbose and logger.info(
            f"✓ Dataset small ({n_rows} <= {max_samples}). Using full data."
        )
        return X
    verbose and logger.info(
        f"⚡ Large dataset detected ({n_rows} rows). Optimization Strategy: {('K-Means' if use_kmeans else 'Random Sampling')}"
    )
    if use_kmeans:
        try:
            verbose and logger.info(
                f"   Summarizing to {max_samples} weighted centroids..."
            )
            return shap.kmeans(X, max_samples)
        except Exception as e:
            logger.warning(
                f"   K-Means failed ({str(e)}). Falling back to random sampling."
            )
            return shap.sample(X, max_samples, random_state=seed)
    else:
        verbose and logger.info(f"   Sampling {max_samples} random rows...")
        return shap.sample(X, max_samples, random_state=seed)


def _class_index_df(model):
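    """Map a fitted classifier's classes_ attribute to a DataFrame of (index, class); empty if unavailable."""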
    columns = {"index": pd.Series(dtype="int64"), "class": pd.Series(dtype="object")}
    if model is None:
        return pd.DataFrame(columns)
    classes = getattr(model, "classes_", None)
    if classes is None:
        return pd.DataFrame(columns)
    return pd.DataFrame({"index": range(len(classes)), "class": classes})


def train_reg_elasticnet(
    X: DataFrame, y: Union[DataSeries, NDArray, List], options=None
) -> Tuple[
    Any, Any, Any, Union[DataFrame, Dict], DataFrame, DataFrame, DataFrame, Dict
]:
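    """Train an ElasticNet regressor and return
    (Model, SHAP, Scaler, Metrics, CV_Metrics, Prediction_Set, HPO_Trials, HPO_Best).

    The `options` dict mirrors the brick options documented above; keys missing
    from it fall back to the defaults used in the options.get(...) calls below.
    """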
    options = options or {}
    intercept = options.get("intercept", True)
    alpha = options.get("alpha", 1.0)
    l1_ratio = options.get("l1_ratio", 0.5)
    tol = options.get("tol", 0.0001)
    max_iter = options.get("max_iterations", 1000)
    standard_scaling = options.get("standard_scaling", True)
    auto_split = options.get("auto_split", True)
    test_val_size = options.get("test_val_size", 15) / 100
    shuffle_split = options.get("shuffle_split", True)
    retrain_on_full = options.get("retrain_on_full", False)
    use_cross_validation = options.get("use_cross_validation", False)
    cv_folds = options.get("cv_folds", 5)
    use_hpo = options.get("use_hyperparameter_optimization", False)
    optimization_metric = options.get(
        "optimization_metric", "Root Mean Squared Error (RMSE)"
    )
    optimization_metric = METRICS_DICT[optimization_metric]
    optimization_method = options.get("optimization_method", "Tree-structured Parzen")
    optimization_iterations = options.get("optimization_iterations", 50)
    return_shap_explainer = options.get("return_shap_explainer", False)
    use_shap_sampler = options.get("use_shap_sampler", False)
    metrics_as = options.get("metrics_as", "Dataframe")
    n_jobs_str = options.get("n_jobs", "1")
    random_state = options.get("random_state", 42)
    activate_caching = options.get("activate_caching", False)
    verbose = options.get("verbose", True)
    n_jobs_int = -1 if n_jobs_str == "All" else int(n_jobs_str)
    skip_computation = False
    Scaler = None
    Model = None
    Metrics = pd.DataFrame()
    CV_Metrics = pd.DataFrame()
    SHAP = None
    HPO_Trials = pd.DataFrame()
    HPO_Best = None
    fa = None
    wape = None
    mae = None
    mse = None
    rmse = None
    r2 = None
    mape = None
    (n_samples, _) = X.shape
    shap_feature_names = _ensure_feature_names(X)
    if standard_scaling:
        verbose and logger.info("Standard scaling is activated")
    if activate_caching:
        verbose and logger.info(f"Caching is activate")
        data_hasher = _UniversalDatasetHasher(n_samples, verbose=verbose)
        X_hash = data_hasher.hash_data(X).hash
        y_hash = data_hasher.hash_data(y).hash
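        # Cache key material: library versions, input data hashes, and every option that affects the trained artifacts.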
        all_hash_base_text = f"HASH BASE TEXTPandas Version {pd.__version__}POLARS Version {pl.__version__}Numpy Version {np.__version__}Scikit Learn Version {sklearn.__version__}Scipy Version {scipy.__version__}{('SHAP Version ' + shap.__version__ if return_shap_explainer else 'NO SHAP Version')}{X_hash}{y_hash}{alpha}{l1_ratio}{tol}{max_iter}{('Use HPO' if use_hpo else 'No HPO')}{(optimization_metric if use_hpo else 'No HPO Metric')}{(optimization_method if use_hpo else 'No HPO Method')}{(optimization_iterations if use_hpo else 'No HPO Iter')}{(cv_folds if use_cross_validation else 'No CV')}{standard_scaling}{('Auto Split' if auto_split else test_val_size)}{shuffle_split}{return_shap_explainer}{use_shap_sampler}{random_state}"
        all_hash = hashlib.sha256(all_hash_base_text.encode("utf-8")).hexdigest()
        verbose and logger.info(f"Hash was computed: {all_hash}")
        temp_folder = Path(tempfile.gettempdir())
        cache_folder = temp_folder / "coded-flows-cache"
        cache_folder.mkdir(parents=True, exist_ok=True)
        model_path = cache_folder / f"{all_hash}.model"
        metrics_dict_path = cache_folder / f"metrics_{all_hash}.json"
        metrics_df_path = cache_folder / f"metrics_{all_hash}.parquet"
        cv_metrics_path = cache_folder / f"cv_metrics_{all_hash}.parquet"
        hpo_trials_path = cache_folder / f"hpo_trials_{all_hash}.parquet"
        hpo_best_params_path = cache_folder / f"hpo_best_params_{all_hash}.json"
        prediction_set_path = cache_folder / f"prediction_set_{all_hash}.parquet"
        shap_path = cache_folder / f"{all_hash}.shap"
        scaler_path = cache_folder / f"{all_hash}.scaler"
        skip_computation = model_path.is_file()
    if not skip_computation:
        try:
            _validate_numerical_data(X)
        except Exception as e:
            verbose and logger.error(
                f"Only numerical or boolean types are allowed for 'X' input!"
            )
            raise
        features_names = X.columns if hasattr(X, "columns") else None
        fixed_test_split = None if auto_split else test_val_size
        (X_train, X_test, y_train, y_test, val_ratio) = _smart_split(
            n_samples,
            X,
            y,
            random_state=random_state,
            shuffle=shuffle_split,
            fixed_test_split=fixed_test_split,
            verbose=verbose,
        )
        if use_hpo:
            verbose and logger.info(f"Performing Hyperparameters Optimization")
            constant_hyperparameters = {"max_iter": max_iter, "intercept": intercept}
            (HPO_Best, HPO_Trials) = _hyperparameters_optimization(
                X_train,
                y_train,
                constant_hyperparameters,
                optimization_metric,
                val_ratio,
                shuffle_split,
                use_cross_validation,
                cv_folds,
                optimization_iterations,
                METRICS_OPT[optimization_metric],
                optimization_method,
                random_state,
                n_jobs_int,
                verbose=verbose,
            )
            HPO_Trials = _normalize_hpo_df(HPO_Trials)
            alpha = HPO_Best["alpha"]
            l1_ratio = HPO_Best["l1_ratio"]
            tol = HPO_Best["tol"]
        if standard_scaling:
            Scaler = StandardScaler().set_output(transform="pandas")
            X_train = Scaler.fit_transform(X_train)
        Model = ElasticNet(
            alpha=alpha,
            l1_ratio=l1_ratio,
            tol=tol,
            max_iter=max_iter,
            fit_intercept=intercept,
        )
        if use_cross_validation and (not use_hpo):
            verbose and logger.info(
                f"Using Cross-Validation to measure performance metrics"
            )
            CV_Metrics = _perform_cross_validation(
                Model,
                X_train,
                y_train,
                cv_folds,
                shuffle_split,
                random_state,
                n_jobs_int,
                verbose,
            )
        Model.fit(X_train, y_train)
        y_pred = Model.predict(Scaler.transform(X_test) if standard_scaling else X_test)
        fa = forecast_accuracy(y_test, y_pred)
        wape = wape_score(y_test, y_pred)
        mae = mean_absolute_error(y_test, y_pred)
        mse = mean_squared_error(y_test, y_pred)
        rmse = root_mean_squared_error(y_test, y_pred)
        r2 = r2_score(y_test, y_pred)
        mape = mean_absolute_percentage_error(y_test, y_pred)
        if metrics_as == "Dataframe":
            Metrics = pd.DataFrame(
                {
                    "Metric": [
                        "Forecast Accuracy",
                        "Weighted Absolute Percentage Error",
                        "Mean Absolute Error",
                        "Mean Squared Error",
                        "Root Mean Squared Error",
                        "R2 Score",
                        "Mean Absolute Percentage Error",
                    ],
                    "Value": [fa, wape, mae, mse, rmse, r2, mape],
                }
            )
        else:
            Metrics = {
                "forecast_accuracy": fa,
                "weighted_absolute_percentage_error ": wape,
                "mean_absolute_error": mae,
                "mean_squared_error": mse,
                "root_mean_squared_error": rmse,
                "r2_score": r2,
                "mean_absolute_percentage_error": mape,
            }
        verbose and logger.info(f"Forecast Accuracy                  : {fa:.2%}")
        verbose and logger.info(f"Weighted Absolute Percentage Error : {wape:.2%}")
        verbose and logger.info(f"Mean Absolute Error                : {mae:.4f}")
        verbose and logger.info(f"Mean Squared Error                 : {mse:.4f}")
        verbose and logger.info(f"Root Mean Squared Error            : {rmse:.4f}")
        verbose and logger.info(f"R2 Score                           : {r2:.4f}")
        verbose and logger.info(f"Mean Absolute Percentage Error     : {mape:.2%}")
        Prediction_Set = _combine_test_data(X_test, y_test, y_pred, features_names)
        verbose and logger.info(f"Prediction Set created")
        if retrain_on_full:
            verbose and logger.info(
                "Retraining model on full dataset for production deployment"
            )
            if standard_scaling:
                Scaler = StandardScaler().set_output(transform="pandas")
                X = Scaler.fit_transform(X)
            Model.fit(X, y)
            verbose and logger.info(
                "Model successfully retrained on full dataset. Reported metrics remain from original held-out test set."
            )
        if return_shap_explainer:
            SHAP = shap.LinearExplainer(
                Model,
                (
                    _smart_shap_background(
                        X if retrain_on_full else X_train,
                        model_type="other",
                        seed=random_state,
                        verbose=verbose,
                    )
                    if use_shap_sampler
                    else X if retrain_on_full else X_train
                ),
                feature_names=shap_feature_names,
            )
            verbose and logger.info(f"SHAP explainer generated")
        if activate_caching:
            verbose and logger.info(f"Caching output elements")
            joblib.dump(Model, model_path)
            if isinstance(Metrics, dict):
                with metrics_dict_path.open("w", encoding="utf-8") as f:
                    json.dump(Metrics, f, ensure_ascii=False, indent=4)
            else:
                Metrics.to_parquet(metrics_df_path)
            if use_cross_validation and (not use_hpo):
                CV_Metrics.to_parquet(cv_metrics_path)
            if use_hpo:
                HPO_Trials.to_parquet(hpo_trials_path)
                with hpo_best_params_path.open("w", encoding="utf-8") as f:
                    json.dump(HPO_Best, f, ensure_ascii=False, indent=4)
            Prediction_Set.to_parquet(prediction_set_path)
            if return_shap_explainer:
                with shap_path.open("wb") as f:
                    joblib.dump(SHAP, f)
            joblib.dump(Scaler, scaler_path)
            verbose and logger.info(f"Caching done")
    else:
        verbose and logger.info(f"Skipping computations and loading cached elements")
        Model = joblib.load(model_path)
        verbose and logger.info(f"Model loaded")
        if metrics_dict_path.is_file():
            with metrics_dict_path.open("r", encoding="utf-8") as f:
                Metrics = json.load(f)
        else:
            Metrics = pd.read_parquet(metrics_df_path)
        verbose and logger.info(f"Metrics loaded")
        if use_cross_validation and (not use_hpo):
            CV_Metrics = pd.read_parquet(cv_metrics_path)
            verbose and logger.info(f"Cross Validation metrics loaded")
        if use_hpo:
            HPO_Trials = pd.read_parquet(hpo_trials_path)
            with hpo_best_params_path.open("r", encoding="utf-8") as f:
                HPO_Best = json.load(f)
            verbose and logger.info(
                f"Hyperparameters Optimization trials and best params loaded"
            )
        Prediction_Set = pd.read_parquet(prediction_set_path)
        verbose and logger.info(f"Prediction Set loaded")
        if return_shap_explainer:
            with shap_path.open("rb") as f:
                SHAP = joblib.load(f)
            verbose and logger.info(f"SHAP Explainer loaded")
        Scaler = joblib.load(scaler_path)
        verbose and logger.info(f"Standard Scaler loaded")
    return (
        Model,
        SHAP,
        Scaler,
        Metrics,
        CV_Metrics,
        Prediction_Set,
        HPO_Trials,
        HPO_Best,
    )

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • shap>=0.47.0
  • scikit-learn
  • pandas
  • numpy
  • torch
  • numba>=0.56.0
  • cmaes
  • optuna
  • scipy
  • polars
  • xxhash