Reg. XGBoost

Train an XGBoost regression model.

Reg. XGBoost

Processing

This brick trains an XGBoost regression model to predict continuous numerical values. XGBoost (Extreme Gradient Boosting) is a powerful machine learning algorithm that builds an ensemble of decision trees, where each new tree corrects the errors of previous ones.

The brick handles the complete training workflow: it splits your data into training and testing sets, trains the model, evaluates its performance using multiple metrics, and optionally performs advanced techniques like hyperparameter optimization to find the best settings automatically. It returns the trained model along with detailed performance metrics, feature importance rankings, and optional explainability tools.
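As a rough end-to-end sketch (assuming the train_reg_xgboost function from the code listing below is available, and using synthetic data purely for illustration; the option keys are those read in that listing):

import numpy as np
import pandas as pd

# Synthetic regression data (illustrative only)
rng = np.random.default_rng(42)
X = pd.DataFrame(rng.normal(size=(500, 4)), columns=["f0", "f1", "f2", "f3"])
y = 3 * X["f0"] - 2 * X["f2"] + rng.normal(scale=0.1, size=500)

# Option keys mirror those read in the code listing below
options = {"n_estimators": 200, "early_stopping": True, "verbose": False}
(Model, SHAP, Label_Encoder, Metrics, CV_Metrics, Features_Importance,
 Prediction_Set, HPO_Trials, HPO_Best) = train_reg_xgboost(X, y, options=options)
print(Metrics)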

Important Notice for Apple Silicon Users

If you're using a Mac with Apple Silicon (M1, M2, M3, M4, or later), you must install the OpenMP library before using this brick, or XGBoost will crash or fail to load.

Required Setup:

Open your Terminal and run:

brew install libomp

If you already have Homebrew installed but still encounter errors, try:

brew reinstall libomp

Common Error Messages (without libomp):

  • OMP: Error #15: ...
  • dlopen(...libomp.dylib...): image not found
  • Symbol not found: _omp_init_lock
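Once libomp is installed, you can verify that XGBoost loads cleanly from the same Python environment (a quick sanity check, not required by the brick):

import xgboost
print(xgboost.__version__)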

Inputs

X
The dataset containing your input features (the variables used to make predictions). This should be a table where each column represents a different feature and each row represents a different data point.
y
The target values you want to predict (the actual numerical outcomes). This should be a single column or list of numbers corresponding to each row in your feature dataset.

Input Types

  • X: DataFrame
  • y: DataSeries, NDArray, List

You can check the list of supported types here: Available Type Hints.

Outputs

Model
The trained XGBoost regression model, ready to make predictions on new data. You can use this model in subsequent bricks to predict values for unseen examples (see the usage sketch at the end of this section).
SHAP
A SHAP explainer object that helps you understand why the model made specific predictions. This is only returned if the "SHAP Explainer" option is enabled, otherwise it returns None.
Label Encoder
An encoder object used internally to preprocess the target variable. This is primarily for internal consistency and typically not needed for regression workflows.
Metrics
Performance measurements showing how well the model predicts on the test set. Contains metrics like Mean Absolute Error, R2 Score, and Forecast Accuracy. The format (table or dictionary) depends on your "Metrics as" setting.
CV Metrics
Cross-validation performance metrics that provide a more robust estimate of model quality by testing on multiple data splits. Only populated if "Enable Cross-Validation" is turned on, otherwise returns an empty table.
Features Importance
A ranked table showing which input features have the most influence on the model's predictions. Features at the top contribute more to the model's decisions.
Prediction Set
A table combining the test data features, actual values, and predicted values side-by-side. Useful for analyzing individual predictions and identifying where the model performs well or poorly.
HPO Trials
A detailed log of all hyperparameter combinations tested during optimization, showing how each configuration performed. Only populated if "Hyperparameter Optim." is enabled, otherwise returns an empty table.
HPO Best
The best hyperparameter configuration found during optimization, stored as a dictionary. Only populated if "Hyperparameter Optim." is enabled, otherwise returns None.

The Metrics output contains the following performance measurements (a short worked example of WAPE and Forecast Accuracy follows the list):

  • Forecast Accuracy: The proportion of total error eliminated by the model (1 - WAPE). Values closer to 1.0 indicate better predictions.
  • Weighted Absolute Percentage Error: Sum of absolute errors divided by sum of actual values. Measures overall prediction accuracy as a percentage.
  • Mean Absolute Error: Average absolute difference between predictions and actual values, in the same units as your target variable.
  • Mean Squared Error: Average squared difference between predictions and actual values. Emphasizes larger errors.
  • Root Mean Squared Error: Square root of MSE, returning to original units while still penalizing large errors more than MAE.
  • R2 Score: Proportion of variance explained by the model. 1.0 = perfect predictions, 0.0 = no better than predicting the mean.
  • Mean Absolute Percentage Error: Average percentage error across all predictions.
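As a short worked example of the two headline metrics (made-up numbers):

import numpy as np

y_true = np.array([100.0, 200.0, 300.0])
y_pred = np.array([110.0, 190.0, 330.0])

wape = np.sum(np.abs(y_true - y_pred)) / np.sum(np.abs(y_true))  # (10 + 10 + 30) / 600 ≈ 0.083
fa = 1 - wape                                                    # ≈ 0.917, i.e. roughly 92% Forecast Accuracy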

The Features Importance output contains:

  • feature: The name of each input feature from your dataset.
  • importance: A score indicating how much this feature contributes to the model's predictions. Higher values mean the feature is more important for making accurate predictions.

The Prediction Set output contains:

  • feature_0, feature_1, ... (or your actual feature names): All the input features from the test set.
  • y_true: The actual target values from your test data.
  • y_pred: The model's predicted values for the test data.

The HPO Trials output contains (when hyperparameter optimization is enabled):

  • number: Trial number in the optimization sequence.
  • value: The optimization metric value achieved by this trial's configuration.
  • datetime_start / datetime_complete: When the trial started and finished.
  • duration: How long the trial took to complete.
  • params_*: The specific hyperparameter values tested in this trial (e.g., params_n_estimators, params_learning_rate).
  • state: Whether the trial completed successfully or failed.
  • best_value: The best metric value found up to this point in the optimization.

The HPO Best output contains (when hyperparameter optimization is enabled):

  • n_estimators: Optimal number of trees.
  • max_depth: Optimal tree depth.
  • learning_rate: Optimal learning rate.
  • reg_lambda: Optimal L2 regularization.
  • reg_alpha: Optimal L1 regularization.
  • gamma: Optimal minimum split loss.
  • min_child_weight: Optimal minimum child weight.
  • subsample: Optimal subsample ratio.
  • colsample_bytree: Optimal column sampling ratio.
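Because these keys match XGBRegressor constructor arguments, one way to reuse the result outside the brick is sketched below with hypothetical values:

from xgboost import XGBRegressor

# Hypothetical contents of HPO_Best (actual values depend on your data and trials)
hpo_best = {
    "n_estimators": 420, "max_depth": 5, "learning_rate": 0.07,
    "reg_lambda": 1.2, "reg_alpha": 0.01, "gamma": 0.0,
    "min_child_weight": 2.5, "subsample": 0.85, "colsample_bytree": 0.9,
}
model = XGBRegressor(**hpo_best, random_state=42)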

Output Types

  • Model: Any
  • SHAP: Any
  • Label_Encoder: Any
  • Metrics: DataFrame, Dict
  • CV_Metrics: DataFrame
  • Features_Importance: DataFrame
  • Prediction_Set: DataFrame
  • HPO_Trials: DataFrame
  • HPO_Best: Dict

You can check the list of supported types here: Available Type Hints.
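To give a feel for how these outputs might be consumed downstream, here is a sketch that assumes the brick has already run and that X_new is a new DataFrame with the same columns as X:

# Predict on new, unseen data with the trained model
y_new = Model.predict(X_new)

# Inspect which features drive the model's predictions
print(Features_Importance.head(10))

# If the "SHAP Explainer" option was enabled, explain individual predictions
if SHAP is not None:
    shap_values = SHAP.shap_values(X_new)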

Options

The Reg. XGBoost brick exposes the following configurable options (an example options dictionary follows the list):

Max Number of Trees
The maximum number of decision trees the model will build. More trees can improve accuracy but take longer to train and may overfit. Start with 100 and increase if performance improves on validation data.
Enable Early Stopping
When enabled, training stops automatically if the model's performance stops improving on a validation set, preventing wasted computation and overfitting. The model will use the best iteration found rather than all trees.
Early Stopping Rounds
How many consecutive rounds without improvement before training stops. For example, with a value of 10, if the validation score doesn't improve for 10 consecutive trees, training halts.
Use Dropout
Enables DART (Dropouts meet Multiple Additive Regression Trees), a variant that randomly drops trees during training to prevent overfitting. This can improve generalization but may slow down training.
Max Depth
The maximum number of levels each decision tree can have. Deeper trees can capture more complex patterns but are more prone to overfitting. Values between 3 and 10 work well for most problems.
Learning Rate (Eta)
Controls how much each tree contributes to the final prediction. Lower values (e.g., 0.01) make learning more gradual and robust but require more trees. Higher values (e.g., 0.3) speed up training but may overshoot optimal solutions.
L2 Regularization (Lambda)
Adds a penalty for overly complex models based on the squared magnitude of tree weights. Higher values create simpler, more generalizable models. Default of 1.0 provides moderate regularization.
L1 Regularization (Alpha)
Adds a penalty for overly complex models based on the absolute magnitude of tree weights. Can push some feature weights to exactly zero, effectively performing feature selection. Default of 0.0 means no L1 regularization.
Gamma (Min Split Loss)
The minimum reduction in error required to make a further split in a tree. Higher values make the model more conservative, creating simpler trees. A value of 0 means no minimum requirement.
Min Child Weight
The minimum sum of instance weights needed in a child node. Higher values prevent the model from learning overly specific patterns that only apply to a few examples.
Subsample Ratio
The fraction of training samples used to build each tree. For example, 0.8 means each tree sees 80% of the data, randomly selected. Values below 1.0 can prevent overfitting by introducing randomness.
Colsample by Tree
The fraction of features (columns) randomly selected when building each tree. For example, 0.8 means each tree uses 80% of available features. This adds diversity to the ensemble and can improve generalization.
Auto Split Data
When enabled, the brick automatically determines optimal train/test split ratios based on your dataset size. Larger datasets use smaller test percentages. When disabled, you control the exact percentage via "Test/Validation Set %".
Shuffle Split
Randomly shuffles your data before splitting into train/test sets. Recommended to ensure both sets are representative of the overall data distribution, unless your data has a meaningful time-based order.
Test/Validation Set %
The percentage of data reserved for testing (and validation if early stopping is enabled). Only used when "Auto Split Data" is disabled. For example, 15 means 15% for testing and potentially 15% for validation.
Retrain On Full Data
When enabled, after evaluating performance on the test set, the model is retrained using all available data (train + test) to maximize learning from every example. The reported metrics still reflect the original test set performance.
Enable Cross-Validation
Enables k-fold cross-validation, which splits the training data into multiple folds, trains on some folds while testing on others, and repeats this process to get a more reliable performance estimate. This is more robust than a single train/test split.
Number of CV Folds
How many folds to use in cross-validation. For example, 5-fold CV splits data into 5 parts, trains 5 times (each time holding out a different part for testing), and averages the results.
Hyperparameter Optim.
Enables automatic hyperparameter optimization, which systematically tests different combinations of settings (tree depth, learning rate, etc.) to find the configuration that produces the best model. This can significantly improve performance but increases training time.
Optimization Metric
The performance measure used to evaluate different hyperparameter configurations during optimization. The algorithm will try to optimize this specific metric.
  • Forecast Accuracy: Measures how much of the total error is eliminated (1 - WAPE). Higher is better: 1.0 means perfect predictions, while 0.0 means the total absolute error is as large as the total of the actual values.
  • Weighted Absolute Percentage Error (WAPE): Sum of absolute errors divided by sum of actual values. Lower is better. Robust to scale and doesn't break on zero values.
  • Mean Absolute Error (MAE): Average absolute difference between predictions and actual values. Lower is better. Easy to interpret in the same units as your target.
  • Mean Squared Error (MSE): Average squared difference between predictions and actual values. Lower is better. Penalizes large errors more heavily than MAE.
  • Root Mean Squared Error (RMSE): Square root of MSE, returning error to original units. Lower is better. More interpretable than MSE while still penalizing large errors.
  • R2 Score: Proportion of variance in the target explained by the model. Higher is better (1.0 = perfect fit, 0.0 = model no better than predicting the mean).
  • Mean Absolute Percentage Error (MAPE): Average absolute percentage error. Lower is better. Can be problematic with values near zero.
Optimization Method
The search strategy used to explore different hyperparameter combinations.
  • Tree-structured Parzen: An intelligent Bayesian approach that learns from previous trials to suggest promising configurations. Efficient and generally performs well.
  • Gaussian Process: Uses probabilistic modeling to predict which configurations might perform best. Good for smooth optimization landscapes.
  • CMA-ES: Evolution strategy that adapts its search based on the distribution of good solutions. Effective for complex optimization problems.
  • Random Sobol Search: Quasi-random sampling that covers the search space more evenly than pure random search. Good baseline approach.
  • Random Search: Completely random sampling of configurations. Simple but can be effective, especially with sufficient iterations.
Optimization Iterations
How many different hyperparameter configurations to test. More iterations increase the chance of finding better settings but take longer. 50-100 iterations typically provide good results.
Metrics as
Controls whether the Metrics output is returned as a DataFrame table or as a dictionary.
SHAP Explainer
When enabled, generates a SHAP (SHapley Additive exPlanations) explainer object that can explain individual predictions by showing each feature's contribution. Useful for understanding model decisions and building trust.
SHAP Sampler
When enabled, uses a representative subset of training data for SHAP computations instead of the entire dataset. This speeds up explanation generation for large datasets while maintaining accuracy.
SHAP Feature Perturbation
Controls how SHAP calculates feature contributions.
  • Interventional: Replaces feature values with samples from the background dataset. More accurate but slower. Recommended for most cases.
  • Tree Path Dependent: Uses the tree structure itself to calculate contributions. Faster but may give different results than interventional.
Number of Jobs
Controls parallel processing for training and optimization. Higher values use more CPU cores for faster computation.
Random State
A seed value for random number generation, ensuring reproducible results. Using the same seed with the same data and settings will produce identical models. Change this value to try different random initializations.
Brick Caching
When enabled, saves the trained model and results to disk. If you run the brick again with identical data and settings, it loads the cached results instantly instead of retraining. Useful for iterative workflow development.
Verbose Logging
When enabled, displays detailed progress messages during training, including metric values, split information, and optimization progress. Helpful for understanding what the brick is doing and debugging issues.
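For reference, a hedged sketch of how these settings map onto the options dictionary read by the code listing below (keys as they appear in the listing; values are illustrative):

options = {
    "n_estimators": 300,                       # Max Number of Trees
    "early_stopping": True,                    # Enable Early Stopping
    "early_stopping_rounds": 10,
    "max_depth": 6,
    "learning_rate": 0.1,
    "use_cross_validation": True,              # Enable Cross-Validation
    "cv_folds": 5,
    "use_hyperparameter_optimization": False,  # Hyperparameter Optim.
    "metrics_as": "Dataframe",
    "return_shap_explainer": True,             # SHAP Explainer
    "random_state": 42,
    "verbose": True,
}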
import logging
import warnings
import shap
import json
import xxhash
import hashlib
import tempfile
import sklearn
import scipy
import joblib
import numpy as np
import pandas as pd
import polars as pl
from xgboost import XGBRegressor
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_sample_weight
from pathlib import Path
from scipy import sparse
from collections import Counter
from optuna.samplers import (
    TPESampler,
    RandomSampler,
    GPSampler,
    CmaEsSampler,
    QMCSampler,
)
import optuna
from optuna import Study
from optuna.trial import FrozenTrial
from optuna.pruners import HyperbandPruner
from optuna import create_study
from sklearn.model_selection import train_test_split, cross_validate, KFold
from sklearn.metrics import (
    mean_absolute_error,
    mean_squared_error,
    r2_score,
    root_mean_squared_error,
    mean_absolute_percentage_error,
    make_scorer,
)
from dataclasses import dataclass
from datetime import datetime
from coded_flows.types import (
    Union,
    Dict,
    List,
    Tuple,
    NDArray,
    DataFrame,
    DataSeries,
    Any,
)
from coded_flows.utils import CodedFlowsLogger

logger = CodedFlowsLogger(name="Reg. XGBoost", level=logging.INFO)
optuna.logging.set_verbosity(optuna.logging.ERROR)
warnings.filterwarnings("ignore", category=optuna.exceptions.ExperimentalWarning)
METRICS_DICT = {
    "Forecast Accuracy": "fa",
    "Weighted Absolute Percentage Error (WAPE)": "wape",
    "Mean Absolute Error (MAE)": "mae",
    "Mean Squared Error (MSE)": "mse",
    "Root Mean Squared Error (RMSE)": "rmse",
    "R2 Score": "r2",
    "Mean Absolute Percentage Error (MAPE)": "mape",
}
METRICS_OPT = {
    "fa": "maximize",
    "wape": "minimize",
    "mae": "minimize",
    "mse": "minimize",
    "rmse": "minimize",
    "r2": "maximize",
    "mape": "minimize",
}
DataType = Union[
    pd.DataFrame, pl.DataFrame, np.ndarray, sparse.spmatrix, pd.Series, pl.Series
]


@dataclass
class _DatasetFingerprint:
    """Lightweight fingerprint of a dataset."""

    hash: str
    shape: tuple
    computed_at: str
    data_type: str
    method: str


class _UniversalDatasetHasher:
    """
    High-performance dataset hasher optimizing for zero-copy operations
    and native backend execution (C/Rust).
    """

    def __init__(
        self,
        data_size: int,
        method: str = "auto",
        sample_size: int = 100000,
        verbose: bool = False,
    ):
        self.method = method
        self.sample_size = sample_size
        self.data_size = data_size
        self.verbose = verbose

    def hash_data(self, data: DataType) -> _DatasetFingerprint:
        """
        Main entry point: hash any supported data format.
        Auto-detects format and applies optimal strategy.
        """
        if isinstance(data, pd.DataFrame):
            return self._hash_pandas(data)
        elif isinstance(data, pl.DataFrame):
            return self._hash_polars(data)
        elif isinstance(data, pd.Series):
            return self._hash_pandas_series(data)
        elif isinstance(data, pl.Series):
            return self._hash_polars_series(data)
        elif isinstance(data, np.ndarray):
            return self._hash_numpy(data)
        elif sparse.issparse(data):
            return self._hash_sparse(data)
        else:
            raise TypeError(f"Unsupported data type: {type(data)}")

    def _hash_pandas(self, df: pd.DataFrame) -> _DatasetFingerprint:
        """
        Optimized Pandas hashing using pd.util.hash_pandas_object.
        Avoids object-to-string conversion overhead.
        """
        method = self._determine_method(self.data_size, self.method)
        self.verbose and logger.info(
            f"Hashing Pandas: {self.data_size:,} rows - {method}"
        )
        target_df = df
        if method == "sampled":
            target_df = self._get_pandas_sample(df)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns.tolist(),
                "dtypes": {k: str(v) for (k, v) in df.dtypes.items()},
                "shape": df.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(target_df, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception as e:
            self.verbose and logger.warning(
                f"Fast hash failed, falling back to slow hash: {e}"
            )
            self._hash_pandas_fallback(hasher, target_df)
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas",
            method=method,
        )

    def _get_pandas_sample(self, df: pd.DataFrame) -> pd.DataFrame:
        """Deterministic slicing for sampling (Zero randomness)."""
        if self.data_size <= self.sample_size:
            return df
        chunk = self.sample_size // 3
        head = df.iloc[:chunk]
        mid_idx = self.data_size // 2
        mid = df.iloc[mid_idx : mid_idx + chunk]
        tail = df.iloc[-chunk:]
        return pd.concat([head, mid, tail])

    def _hash_pandas_fallback(self, hasher, df: pd.DataFrame):
        """Legacy fallback for complex object types."""
        for col in df.columns:
            val = df[col].astype(str).values
            hasher.update(val.astype(np.bytes_).tobytes())

    def _hash_polars(self, df: pl.DataFrame) -> _DatasetFingerprint:
        """
        Optimized Polars hashing using native Rust execution.
        """
        method = self._determine_method(self.data_size, self.method)
        self.verbose and logger.info(
            f"Hashing Polars: {self.data_size:,} rows - {method}"
        )
        target_df = df
        if method == "sampled" and self.data_size > self.sample_size:
            indices = self._get_sample_indices(self.data_size, self.sample_size)
            target_df = df.gather(indices)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns,
                "dtypes": [str(t) for t in df.dtypes],
                "shape": df.shape,
            },
        )
        row_hashes = target_df.hash_rows()
        hasher.update(memoryview(row_hashes.to_numpy()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars",
            method=method,
        )

    def _hash_pandas_series(self, series: pd.Series) -> _DatasetFingerprint:
        """Hash Pandas Series using the fastest vectorized method."""
        self.verbose and logger.info(f"Hashing Pandas Series: {self.data_size:,} rows")
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "name": series.name if series.name else "None",
                "dtype": str(series.dtype),
                "shape": series.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(series, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception as e:
            self.verbose and logger.warning(f"Series hash failed, falling back: {e}")
            hasher.update(memoryview(series.astype(str).values.tobytes()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas_series",
            method="full",
        )

    def _hash_polars_series(self, series: pl.Series) -> _DatasetFingerprint:
        """Hash Polars Series using native Polars expressions."""
        self.verbose and logger.info(f"Hashing Polars Series: {self.data_size:,} rows")
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"name": series.name, "dtype": str(series.dtype), "shape": series.shape},
        )
        try:
            row_hashes = series.hash()
            hasher.update(memoryview(row_hashes.to_numpy()))
        except Exception as e:
            self.verbose and logger.warning(
                f"Polars series native hash failed, falling back: {e}"
            )
            hasher.update(str(series.to_list()).encode())
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars_series",
            method="full",
        )

    def _hash_numpy(self, arr: np.ndarray) -> _DatasetFingerprint:
        """
        Optimized NumPy hashing using Buffer Protocol (Zero-Copy).
        """
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"shape": arr.shape, "dtype": str(arr.dtype), "strides": arr.strides},
        )
        if arr.flags["C_CONTIGUOUS"] or arr.flags["F_CONTIGUOUS"]:
            hasher.update(memoryview(arr))
        else:
            hasher.update(memoryview(np.ascontiguousarray(arr)))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=arr.shape,
            computed_at=datetime.now().isoformat(),
            data_type="numpy",
            method="full",
        )

    def _hash_sparse(self, matrix: sparse.spmatrix) -> _DatasetFingerprint:
        """
        Optimized sparse hashing. Hashes underlying data arrays directly.
        """
        if not (sparse.isspmatrix_csr(matrix) or sparse.isspmatrix_csc(matrix)):
            matrix = matrix.tocsr()
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher, {"shape": matrix.shape, "format": matrix.format, "nnz": matrix.nnz}
        )
        hasher.update(memoryview(matrix.data))
        hasher.update(memoryview(matrix.indices))
        hasher.update(memoryview(matrix.indptr))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=matrix.shape,
            computed_at=datetime.now().isoformat(),
            data_type=f"sparse_{matrix.format}",
            method="sparse",
        )

    def _determine_method(self, rows: int, requested: str) -> str:
        if requested != "auto":
            return requested
        if rows < 5000000:
            return "full"
        return "sampled"

    def _hash_schema(self, hasher, schema: Dict[str, Any]):
        """Compact schema hashing."""
        hasher.update(
            json.dumps(schema, sort_keys=True, separators=(",", ":")).encode()
        )

    def _get_sample_indices(self, total_rows: int, sample_size: int) -> list:
        """Calculate indices for sampling without generating full range lists."""
        chunk = sample_size // 3
        indices = list(range(min(chunk, total_rows)))
        mid_start = max(0, total_rows // 2 - chunk // 2)
        mid_end = min(mid_start + chunk, total_rows)
        indices.extend(range(mid_start, mid_end))
        last_start = max(0, total_rows - chunk)
        indices.extend(range(last_start, total_rows))
        return sorted(list(set(indices)))


def wape_score(y_true, y_pred):
    """
    Calculates Weighted Absolute Percentage Error (WAPE).

    WAPE = sum(|Error|) / sum(|Groundtruth|)
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    eps = np.finfo(np.float64).eps
    sum_abs_error = np.sum(np.abs(y_true - y_pred))
    sum_abs_truth = np.maximum(np.sum(np.abs(y_true)), eps)
    return sum_abs_error / sum_abs_truth


def forecast_accuracy(y_true, y_pred):
    """
    Calculates Forecast Accuracy.

    FA = 1 - (sum(|Error|) / sum(|Groundtruth|))
    """
    y_true = np.asarray(y_true, dtype=np.float64)
    y_pred = np.asarray(y_pred, dtype=np.float64)
    eps = np.finfo(np.float64).eps
    sum_abs_error = np.sum(np.abs(y_true - y_pred))
    sum_abs_truth = np.maximum(np.sum(np.abs(y_true)), eps)
    return 1 - sum_abs_error / sum_abs_truth


def _normalize_hpo_df(df):
    df = df.copy()
    param_cols = [c for c in df.columns if c.startswith("params_")]
    df[param_cols] = df[param_cols].astype("string[pyarrow]")
    return df


def _get_shape_and_sparsity(X: Any) -> Tuple[int, int, float, bool]:
    """
    Efficiently extracts shape and estimates sparsity without converting
    the entire dataset to numpy.
    """
    (n_samples, n_features) = (0, 0)
    is_sparse = False
    sparsity = 0.0
    if hasattr(X, "nnz") and hasattr(X, "shape"):
        (n_samples, n_features) = X.shape
        is_sparse = True
        sparsity = 1.0 - X.nnz / (n_samples * n_features)
        return (n_samples, n_features, sparsity, is_sparse)
    if hasattr(X, "height") and hasattr(X, "width"):
        (n_samples, n_features) = (X.height, X.width)
        return (n_samples, n_features, 0.0, False)
    if hasattr(X, "shape") and hasattr(X, "iloc"):
        (n_samples, n_features) = X.shape
        return (n_samples, n_features, 0.0, False)
    if isinstance(X, list):
        X = np.array(X)
    if hasattr(X, "shape"):
        (n_samples, n_features) = X.shape
        return (n_samples, n_features, 0.0, False)
    raise ValueError("Unsupported data type")


def _smart_split(
    n_samples,
    X,
    y,
    *,
    random_state=42,
    shuffle=True,
    stratify=None,
    fixed_test_split=None,
    verbose=True,
):
    """
    Parameters
    ----------
    n_samples : int
        Number of samples in the dataset (len(X) or len(y))
    X : array-like
        Features
    y : array-like
        Target
    random_state : int
    shuffle : bool
    stratify : array-like or None
        For stratified splitting (recommended for classification)

    Returns
    -------
    If return_val=True  → X_train, X_val, X_test, y_train, y_val, y_test
    If return_val=False → X_train, X_test, y_train, y_test
    """
    if fixed_test_split:
        test_ratio = fixed_test_split
        val_ratio = fixed_test_split
    elif n_samples <= 1000:
        test_ratio = 0.2
        val_ratio = 0.1
    elif n_samples < 10000:
        test_ratio = 0.15
        val_ratio = 0.15
    elif n_samples < 100000:
        test_ratio = 0.1
        val_ratio = 0.1
    elif n_samples < 1000000:
        test_ratio = 0.05
        val_ratio = 0.05
    else:
        test_ratio = 0.01
        val_ratio = 0.01
    (X_train, X_test, y_train, y_test) = train_test_split(
        X,
        y,
        test_size=test_ratio,
        random_state=random_state,
        shuffle=shuffle,
        stratify=stratify,
    )
    val_size_in_train = val_ratio / (1 - test_ratio)
    verbose and logger.info(
        f"Split → Train: {1 - test_ratio:.2%} | Test: {test_ratio:.2%} (no validation set)"
    )
    return (X_train, X_test, y_train, y_test, val_size_in_train)


def _ensure_feature_names(X, feature_names=None):
    if isinstance(X, pd.DataFrame):
        return list(X.columns)
    if isinstance(X, np.ndarray):
        if feature_names is None:
            feature_names = [f"feature_{i}" for i in range(X.shape[1])]
        return feature_names
    raise TypeError("X must be a pandas DataFrame or numpy ndarray")


def _perform_cross_validation(
    model, X, y, cv_folds, shuffle, random_state, n_jobs, verbose
) -> dict[str, Any]:
    """Perform cross-validation on the regression model."""
    verbose and logger.info(f"Performing {cv_folds}-fold cross-validation...")
    cv = KFold(n_splits=cv_folds, shuffle=shuffle, random_state=random_state)
    scoring = {
        "MAE": "neg_mean_absolute_error",
        "MSE": "neg_mean_squared_error",
        "RMSE": "neg_root_mean_squared_error",
        "MAPE": "neg_mean_absolute_percentage_error",
        "R2": "r2",
        "WAPE": make_scorer(wape_score, greater_is_better=False),
        "Forecast_Accuracy": make_scorer(forecast_accuracy, greater_is_better=True),
    }
    cv_results = cross_validate(
        model, X, y, cv=cv, scoring=scoring, return_train_score=False, n_jobs=n_jobs
    )

    def get_score_stats(metric_key, invert_sign=False):
        key = f"test_{metric_key}"
        if key in cv_results:
            scores = cv_results[key]
            if invert_sign:
                scores = -scores
            return (scores.mean(), scores.std())
        return (0.0, 0.0)

    (mae_mean, mae_std) = get_score_stats("MAE", invert_sign=True)
    (mse_mean, mse_std) = get_score_stats("MSE", invert_sign=True)
    (rmse_mean, rmse_std) = get_score_stats("RMSE", invert_sign=True)
    (mape_mean, mape_std) = get_score_stats("MAPE", invert_sign=True)
    (wape_mean, wape_std) = get_score_stats("WAPE", invert_sign=True)
    (r2_mean, r2_std) = get_score_stats("R2", invert_sign=False)
    (fa_mean, fa_std) = get_score_stats("Forecast_Accuracy", invert_sign=False)
    verbose and logger.info(f"CV MAE          : {mae_mean:.4f} (+/- {mae_std:.4f})")
    verbose and logger.info(f"CV MSE          : {mse_mean:.4f} (+/- {mse_std:.4f})")
    verbose and logger.info(f"CV RMSE         : {rmse_mean:.4f} (+/- {rmse_std:.4f})")
    verbose and logger.info(f"CV MAPE         : {mape_mean:.4f} (+/- {mape_std:.4f})")
    verbose and logger.info(f"CV WAPE         : {wape_mean:.4f} (+/- {wape_std:.4f})")
    verbose and logger.info(f"CV R2 Score     : {r2_mean:.4f} (+/- {r2_std:.4f})")
    verbose and logger.info(f"CV Forecast Acc : {fa_mean:.4f} (+/- {fa_std:.4f})")
    CV_metrics = pd.DataFrame(
        {
            "Metric": [
                "Mean Absolute Error (MAE)",
                "Mean Squared Error (MSE)",
                "Root Mean Squared Error (RMSE)",
                "Mean Absolute Percentage Error (MAPE)",
                "Weighted Absolute Percentage Error (WAPE)",
                "R2 Score",
                "Forecast Accuracy",
            ],
            "Mean": [
                mae_mean,
                mse_mean,
                rmse_mean,
                mape_mean,
                wape_mean,
                r2_mean,
                fa_mean,
            ],
            "Std": [mae_std, mse_std, rmse_std, mape_std, wape_std, r2_std, fa_std],
        }
    )
    return CV_metrics


def _compute_score(model, X, y, metric):
    """
    Computes the score for the model on the given data based on the selected metric.
    Assumes 'metric' is passed as the short code (e.g., "MAE", "R2", "FA").
    """
    y_pred = model.predict(X)
    if metric == "mae":
        score = mean_absolute_error(y, y_pred)
    elif metric == "mse":
        score = mean_squared_error(y, y_pred)
    elif metric == "rmse":
        score = root_mean_squared_error(y, y_pred)
    elif metric == "mape":
        score = mean_absolute_percentage_error(y, y_pred)
    elif metric == "r2":
        score = r2_score(y, y_pred)
    elif metric == "wape" or metric == "fa":
        y_true_np = np.array(y, dtype=float).flatten()
        y_pred_np = np.array(y_pred, dtype=float).flatten()
        eps = np.finfo(np.float64).eps
        sum_abs_error = np.sum(np.abs(y_true_np - y_pred_np))
        sum_abs_truth = np.maximum(np.sum(np.abs(y_true_np)), eps)
        wape_val = sum_abs_error / sum_abs_truth
        if metric == "fa":
            score = 1.0 - wape_val
        else:
            score = wape_val
    else:
        raise ValueError(f"Unknown regression metric: {metric}")
    return score


def _get_cv_scoring_object(metric: str) -> Any:
    """
    Returns a scoring object (string or callable) suitable for cross_validate or GridSearchCV.
    Used during HPO for Regression.
    """
    if metric == "mae":
        return "neg_mean_absolute_error"
    elif metric == "mse":
        return "neg_mean_squared_error"
    elif metric == "rmse":
        return "neg_root_mean_squared_error"
    elif metric == "r2":
        return "r2"
    elif metric == "mape":
        return "neg_mean_absolute_percentage_error"
    elif metric == "wape":
        return make_scorer(wape_score, greater_is_better=False)
    elif metric == "fa":
        return make_scorer(forecast_accuracy, greater_is_better=True)
    else:
        return "neg_root_mean_squared_error"


def _hyperparameters_optimization(
    X,
    y,
    constant_hyperparameters,
    optimization_metric,
    val_ratio,
    shuffle_split,
    use_cross_val,
    cv_folds,
    n_trials=50,
    strategy="maximize",
    sampler="Tree-structured Parzen",
    seed=None,
    n_jobs=-1,
    verbose=False,
):
    direction = "maximize" if strategy.lower() == "maximize" else "minimize"
    sampler_map = {
        "Tree-structured Parzen": TPESampler(seed=seed),
        "Gaussian Process": GPSampler(seed=seed),
        "CMA-ES": CmaEsSampler(seed=seed),
        "Random Search": RandomSampler(seed=seed),
        "Random Sobol Search": QMCSampler(seed=seed),
    }
    if sampler in sampler_map:
        chosen_sampler = sampler_map[sampler]
    else:
        logger.warning(f"Sampler '{sampler}' not recognized → falling back to TPE")
        chosen_sampler = TPESampler(seed=seed)
    chosen_pruner = HyperbandPruner()
    if use_cross_val:
        cv = KFold(n_splits=cv_folds, shuffle=shuffle_split, random_state=seed)
        cv_score_obj = _get_cv_scoring_object(optimization_metric)
    else:
        (X_train, X_val, y_train, y_val) = train_test_split(
            X, y, test_size=val_ratio, random_state=seed, shuffle=shuffle_split
        )

    def logging_callback(study: Study, trial: FrozenTrial):
        """Callback function to log trial progress"""
        verbose and logger.info(
            f"Trial {trial.number} finished with value: {trial.value} and parameters: {trial.params}"
        )
        try:
            verbose and logger.info(f"Best value so far: {study.best_value}")
            verbose and logger.info(f"Best parameters so far: {study.best_params}")
        except ValueError:
            verbose and logger.info(f"No successful trials completed yet")
        verbose and logger.info(f"" + "-" * 50)

    def objective(trial):
        try:
            booster = constant_hyperparameters.get("booster")
            feature_types = constant_hyperparameters.get("feature_types")
            params = {}
            params["n_estimators"] = trial.suggest_int("n_estimators", 50, 1000)
            params["max_depth"] = trial.suggest_int("max_depth", 1, 15)
            params["learning_rate"] = trial.suggest_float(
                "learning_rate", 0.0001, 1.0, log=True
            )
            params["reg_lambda"] = trial.suggest_float(
                "reg_lambda", 1e-08, 100.0, log=True
            )
            params["reg_alpha"] = trial.suggest_float(
                "reg_alpha", 1e-08, 100.0, log=True
            )
            params["gamma"] = trial.suggest_float("gamma", 0.0, 5.0)
            params["min_child_weight"] = trial.suggest_float(
                "min_child_weight", 0.0, 10.0
            )
            params["subsample"] = trial.suggest_float("subsample", 0.1, 1.0)
            params["colsample_bytree"] = trial.suggest_float(
                "colsample_bytree", 0.1, 1.0
            )
            model = XGBRegressor(
                **params,
                feature_types=feature_types,
                tree_method="hist",
                enable_categorical=True,
                booster=booster,
                random_state=seed,
                n_jobs=n_jobs,
            )
            if use_cross_val:
                scores = cross_validate(
                    model, X, y, cv=cv, n_jobs=n_jobs, scoring=cv_score_obj
                )
                return scores["test_score"].mean()
            else:
                model.fit(X_train, y_train)
                score = _compute_score(model, X_val, y_val, optimization_metric)
                return score
        except Exception as e:
            verbose and logger.error(
                f"Trial {trial.number} failed with error: {str(e)}"
            )
            raise

    study = create_study(
        direction=direction, sampler=chosen_sampler, pruner=chosen_pruner
    )
    study.optimize(
        objective,
        n_trials=n_trials,
        catch=(Exception,),
        n_jobs=n_jobs,
        callbacks=[logging_callback],
    )
    verbose and logger.info(f"Optimization completed!")
    verbose and logger.info(
        f"   Best Number of Trees       : {study.best_params['n_estimators']}"
    )
    verbose and logger.info(
        f"   Best Max Depth             : {study.best_params['max_depth']}"
    )
    verbose and logger.info(
        f"   Best Learning Rate         : {study.best_params['learning_rate']}"
    )
    verbose and logger.info(
        f"   Best L1 Regularization     : {study.best_params['reg_alpha']}"
    )
    verbose and logger.info(
        f"   Best L2 Regularization     : {study.best_params['reg_lambda']}"
    )
    verbose and logger.info(
        f"   Best Gamma                 : {study.best_params['gamma']}"
    )
    verbose and logger.info(
        f"   Best Min Child Weight      : {study.best_params['min_child_weight']}"
    )
    verbose and logger.info(
        f"   Best Subsample Ratio       : {study.best_params['subsample']}"
    )
    verbose and logger.info(
        f"   Best Colsample by Tree     : {study.best_params['colsample_bytree']}"
    )
    verbose and logger.info(
        f"   Best {optimization_metric:<22}: {study.best_value:.4f}"
    )
    verbose and logger.info(f"   Sampler used               : {sampler}")
    verbose and logger.info(f"   Direction                  : {direction}")
    if use_cross_val:
        verbose and logger.info(f"   Cross-validation           : {cv_folds}-fold")
    else:
        verbose and logger.info(
            f"   Validation                 : single train/val split"
        )
    trials = study.trials_dataframe()
    trials["best_value"] = trials["value"].cummax()
    cols = list(trials.columns)
    value_idx = cols.index("value")
    cols = [c for c in cols if c != "best_value"]
    new_order = cols[: value_idx + 1] + ["best_value"] + cols[value_idx + 1 :]
    trials = trials[new_order]
    return (study.best_params, trials)


def _combine_test_data(X_test, y_true, y_pred, features_names=None):
    """
    Combine X_test, y_true, y_pred into a single DataFrame.

    Parameters:
    -----------
    X_test : pandas/polars DataFrame, numpy array, or scipy sparse matrix
        Test features
    y_true : pandas/polars Series, numpy array, or list
        True labels
    y_pred : pandas/polars Series, numpy array, or list
        Predicted labels

    Returns:
    --------
    pandas.DataFrame
        Combined DataFrame with features, y_true, and y_pred
    """
    if sparse.issparse(X_test):
        X_df = pd.DataFrame(X_test.toarray())
    elif isinstance(X_test, np.ndarray):
        X_df = pd.DataFrame(X_test)
    elif hasattr(X_test, "to_pandas"):
        X_df = X_test.to_pandas()
    elif isinstance(X_test, pd.DataFrame):
        X_df = X_test.copy()
    else:
        raise TypeError(f"Unsupported type for X_test: {type(X_test)}")
    if X_df.columns.tolist() == list(range(len(X_df.columns))):
        X_df.columns = (
            [f"feature_{i}" for i in range(len(X_df.columns))]
            if features_names is None
            else features_names
        )
    if isinstance(y_true, list):
        y_true_series = pd.Series(y_true, name="y_true")
    elif isinstance(y_true, np.ndarray):
        y_true_series = pd.Series(y_true, name="y_true")
    elif hasattr(y_true, "to_pandas"):
        y_true_series = y_true.to_pandas()
        y_true_series.name = "y_true"
    elif isinstance(y_true, pd.Series):
        y_true_series = y_true.copy()
        y_true_series.name = "y_true"
    else:
        raise TypeError(f"Unsupported type for y_true: {type(y_true)}")
    if isinstance(y_pred, list):
        y_pred_series = pd.Series(y_pred, name="y_pred")
    elif isinstance(y_pred, np.ndarray):
        y_pred_series = pd.Series(y_pred, name="y_pred")
    elif hasattr(y_pred, "to_pandas"):
        y_pred_series = y_pred.to_pandas()
        y_pred_series.name = "y_pred"
    elif isinstance(y_pred, pd.Series):
        y_pred_series = y_pred.copy()
        y_pred_series.name = "y_pred"
    else:
        raise TypeError(f"Unsupported type for y_pred: {type(y_pred)}")
    X_df = X_df.reset_index(drop=True)
    y_true_series = y_true_series.reset_index(drop=True)
    y_pred_series = y_pred_series.reset_index(drop=True)
    result_df = pd.concat([X_df, y_true_series, y_pred_series], axis=1)
    return result_df


def _get_feature_importance(model, feature_names=None, sort=True, top_n=None):
    """
    Extract feature importance from a fitted XGBoost model.

    Parameters:
    -----------
    model : Fitted model
    feature_names : list or array-like, optional
        Names of features. If None, uses generic names like 'feature_0', 'feature_1', etc.
    sort : bool, default=True
        Whether to sort features by importance (descending)
    top_n : int, optional
        If specified, returns only the top N most important features

    Returns:
    --------
    pd.DataFrame
        DataFrame with columns: 'feature', 'importance'
        Importance values come from the model's feature_importances_ attribute (gain-based by default for tree boosters)
    """
    importances = model.feature_importances_
    if feature_names is None:
        feature_names = [f"feature_{i}" for i in range(len(importances))]
    importance_df = pd.DataFrame({"feature": feature_names, "importance": importances})
    if sort:
        importance_df = importance_df.sort_values("importance", ascending=False)
    importance_df = importance_df.reset_index(drop=True)
    if top_n is not None:
        importance_df = importance_df.head(top_n)
    return importance_df


def _smart_shap_background(
    X: Union[np.ndarray, pd.DataFrame],
    model_type: str = "tree",
    seed: int = 42,
    verbose: bool = True,
) -> Union[np.ndarray, pd.DataFrame, object]:
    """
    Intelligently prepares a background dataset for SHAP based on model type.

    Strategies:
    - Tree: Higher sample cap (1000), uses Random Sampling (preserves data structure).
    - Other: Lower sample cap (100), uses K-Means (maximizes info density).
    """
    (n_rows, n_features) = X.shape
    if model_type == "tree":
        max_samples = 1000
        use_kmeans = False
    else:
        max_samples = 100
        use_kmeans = True
    if n_rows <= max_samples:
        verbose and logger.info(
            f"✓ Dataset small ({n_rows} <= {max_samples}). Using full data."
        )
        return X
    verbose and logger.info(
        f"⚡ Large dataset detected ({n_rows} rows). Optimization Strategy: {('K-Means' if use_kmeans else 'Random Sampling')}"
    )
    if use_kmeans:
        try:
            verbose and logger.info(
                f"   Summarizing to {max_samples} weighted centroids..."
            )
            return shap.kmeans(X, max_samples)
        except Exception as e:
            logger.warning(
                f"   K-Means failed ({str(e)}). Falling back to random sampling."
            )
            return shap.sample(X, max_samples, random_state=seed)
    else:
        verbose and logger.info(f"   Sampling {max_samples} random rows...")
        return shap.sample(X, max_samples, random_state=seed)


def _get_xgb_feature_types(X):
    """
    Generate feature_types array for XGBRegressor based on data types in X.

    Parameters:
    -----------
    X : pandas.DataFrame or numpy.ndarray
        Feature matrix for XGBRegressor

    Returns:
    --------
    list : List of feature types ('q' for quantitative, 'c' for categorical)

    Notes:
    ------
    - Numeric and boolean DataFrame columns are treated as quantitative ('q')
    - Object and category DataFrame columns are treated as categorical ('c')
    - Boolean NumPy arrays are treated as all categorical ('c'); other NumPy arrays as all quantitative ('q')
    """
    if isinstance(X, pd.DataFrame):
        feature_types = []
        for col in X.columns:
            dtype = X[col].dtype
            if dtype == "bool":
                feature_types.append("q")
            elif dtype == "object" or isinstance(dtype, pd.CategoricalDtype):
                feature_types.append("c")
            elif pd.api.types.is_numeric_dtype(dtype):
                feature_types.append("q")
            else:
                feature_types.append("q")
        return feature_types
    elif isinstance(X, np.ndarray):
        if X.dtype == bool:
            return ["c"] * X.shape[1]
        else:
            return ["q"] * X.shape[1]
    else:
        raise TypeError("X must be a pandas DataFrame or numpy ndarray")


def train_reg_xgboost(
    X: DataFrame, y: Union[DataSeries, NDArray, List], options=None
) -> Tuple[
    Any,
    Any,
    Any,
    Union[DataFrame, Dict],
    DataFrame,
    DataFrame,
    DataFrame,
    DataFrame,
    Dict,
]:
    options = options or {}
    n_estimators = options.get("n_estimators", 100)
    early_stopping = options.get("early_stopping", True)
    early_stopping_rounds = options.get("early_stopping_rounds", 10)
    use_dart = options.get("use_dart", False)
    max_depth = options.get("max_depth", 6)
    learning_rate = options.get("learning_rate", 0.3)
    reg_lambda = options.get("reg_lambda", 1.0)
    reg_alpha = options.get("reg_alpha", 0.0)
    gamma = options.get("gamma", 0.0)
    min_child_weight = options.get("min_child_weight", 1.0)
    subsample = options.get("subsample", 1.0)
    colsample_bytree = options.get("colsample_bytree", 1.0)
    auto_split = options.get("auto_split", True)
    test_val_size = options.get("test_val_size", 15) / 100
    shuffle_split = options.get("shuffle_split", True)
    retrain_on_full = options.get("retrain_on_full", False)
    use_cross_validation = options.get("use_cross_validation", False)
    cv_folds = options.get("cv_folds", 5)
    use_hpo = options.get("use_hyperparameter_optimization", False)
    optimization_metric = options.get(
        "optimization_metric", "Root Mean Squared Error (RMSE)"
    )
    optimization_metric = METRICS_DICT[optimization_metric]
    optimization_method = options.get("optimization_method", "Tree-structured Parzen")
    optimization_iterations = options.get("optimization_iterations", 50)
    return_shap_explainer = options.get("return_shap_explainer", False)
    use_shap_sampler = options.get("use_shap_sampler", False)
    shap_feature_perturbation = options.get(
        "shap_feature_perturbation", "Interventional"
    )
    metrics_as = options.get("metrics_as", "Dataframe")
    n_jobs_str = options.get("n_jobs", "1")
    random_state = options.get("random_state", 42)
    activate_caching = options.get("activate_caching", False)
    verbose = options.get("verbose", True)
    n_jobs_int = -1 if n_jobs_str == "All" else int(n_jobs_str)
    skip_computation = False
    Model = None
    Metrics = pd.DataFrame()
    CV_Metrics = pd.DataFrame()
    Features_Importance = pd.DataFrame()
    Label_Encoder = None
    SHAP = None
    HPO_Trials = pd.DataFrame()
    HPO_Best = None
    fa = None
    wape = None
    mae = None
    mse = None
    rmse = None
    r2 = None
    mape = None
    (n_samples, _) = X.shape
    if activate_caching:
        verbose and logger.info(f"Caching is activate")
        data_hasher = _UniversalDatasetHasher(n_samples, verbose=verbose)
        X_hash = data_hasher.hash_data(X).hash
        y_hash = data_hasher.hash_data(y).hash
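        # Build the cache key from library versions, data fingerprints, and every option that affects the trained model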
        all_hash_base_text = f"HASH BASE TEXTPandas Version {pd.__version__}POLARS Version {pl.__version__}Numpy Version {np.__version__}Scikit Learn Version {sklearn.__version__}Scipy Version {scipy.__version__}{('SHAP Version ' + shap.__version__ if return_shap_explainer else 'NO SHAP Version')}{X_hash}{y_hash}{n_estimators}{early_stopping}{early_stopping_rounds}{max_depth}{learning_rate}{reg_lambda}{reg_alpha}{gamma}{min_child_weight}{subsample}{colsample_bytree}{use_dart}{('Use HPO' if use_hpo else 'No HPO')}{(optimization_metric if use_hpo else 'No HPO Metric')}{(optimization_method if use_hpo else 'No HPO Method')}{(optimization_iterations if use_hpo else 'No HPO Iter')}{(cv_folds if use_cross_validation else 'No CV')}{('Auto Split' if auto_split else test_val_size)}{shuffle_split}{return_shap_explainer}{shap_feature_perturbation}{use_shap_sampler}{random_state}"
        all_hash = hashlib.sha256(all_hash_base_text.encode("utf-8")).hexdigest()
        verbose and logger.info(f"Hash was computed: {all_hash}")
        temp_folder = Path(tempfile.gettempdir())
        cache_folder = temp_folder / "coded-flows-cache"
        cache_folder.mkdir(parents=True, exist_ok=True)
        model_path = cache_folder / f"{all_hash}.json"
        metrics_dict_path = cache_folder / f"metrics_{all_hash}.json"
        metrics_df_path = cache_folder / f"metrics_{all_hash}.parquet"
        cv_metrics_path = cache_folder / f"cv_metrics_{all_hash}.parquet"
        hpo_trials_path = cache_folder / f"hpo_trials_{all_hash}.parquet"
        hpo_best_params_path = cache_folder / f"hpo_best_params_{all_hash}.json"
        features_importance_path = (
            cache_folder / f"features_importance_{all_hash}.parquet"
        )
        prediction_set_path = cache_folder / f"prediction_set_{all_hash}.parquet"
        shap_path = cache_folder / f"{all_hash}.shap"
        label_encoder_path = cache_folder / f"{all_hash}.encoder"
        skip_computation = model_path.is_file()
    if not skip_computation:
        features_names = X.columns if hasattr(X, "columns") else None
        shap_feature_names = _ensure_feature_names(X)
        Label_Encoder = LabelEncoder()
        y = Label_Encoder.fit_transform(y)
        booster = "dart" if use_dart else "gbtree"
        eval_metric = "rmse"
        es_objective = "reg:squarederror"
        feature_types = _get_xgb_feature_types(X)
        fixed_test_split = None if auto_split else test_val_size
        (X_train, X_test, y_train, y_test, val_ratio) = _smart_split(
            n_samples,
            X,
            y,
            random_state=random_state,
            shuffle=shuffle_split,
            fixed_test_split=fixed_test_split,
            verbose=verbose,
        )
        if use_hpo:
            verbose and logger.info(f"Performing Hyperparameters Optimization")
            constant_hyperparameters = {
                "booster": booster,
                "feature_types": feature_types,
            }
            (HPO_Best, HPO_Trials) = _hyperparameters_optimization(
                X_train,
                y_train,
                constant_hyperparameters,
                optimization_metric,
                val_ratio,
                shuffle_split,
                use_cross_validation,
                cv_folds,
                optimization_iterations,
                METRICS_OPT[optimization_metric],
                optimization_method,
                random_state,
                n_jobs_int,
                verbose=verbose,
            )
            HPO_Trials = _normalize_hpo_df(HPO_Trials)
            n_estimators = HPO_Best["n_estimators"]
            max_depth = HPO_Best["max_depth"]
            learning_rate = HPO_Best["learning_rate"]
            reg_lambda = HPO_Best["reg_lambda"]
            reg_alpha = HPO_Best["reg_alpha"]
            gamma = HPO_Best["gamma"]
            min_child_weight = HPO_Best["min_child_weight"]
            subsample = HPO_Best["subsample"]
            colsample_bytree = HPO_Best["colsample_bytree"]
        model_params = {}
        model_params["n_estimators"] = n_estimators
        model_params["max_depth"] = max_depth
        model_params["learning_rate"] = learning_rate
        model_params["reg_lambda"] = reg_lambda
        model_params["reg_alpha"] = reg_alpha
        model_params["gamma"] = gamma
        model_params["min_child_weight"] = min_child_weight
        model_params["subsample"] = subsample
        model_params["colsample_bytree"] = colsample_bytree
        if early_stopping and (not use_hpo):
            model_params["early_stopping_rounds"] = early_stopping_rounds
            model_params["objective"] = es_objective
            model_params["eval_metric"] = eval_metric
        Model = XGBRegressor(
            **model_params,
            feature_types=feature_types,
            tree_method="hist",
            enable_categorical=True,
            booster=booster,
            random_state=random_state,
            n_jobs=n_jobs_int,
        )
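        # With early stopping (and no HPO), hold out a validation set from the training data for the eval_set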
        if early_stopping and (not use_hpo):
            (X_train, X_val, y_train, y_val) = train_test_split(
                X_train,
                y_train,
                test_size=val_ratio,
                random_state=random_state,
                shuffle=shuffle_split,
            )
            Model.fit(X_train, y_train, eval_set=[(X_val, y_val)], verbose=False)
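            # Record the number of trees actually used (best iteration) so later refits reuse the same count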
            model_params["n_estimators"] = Model.best_iteration + 1
        else:
            Model.fit(X_train, y_train, verbose=False)
        if use_cross_validation and (not use_hpo):
            verbose and logger.info(
                f"Using Cross-Validation to measure performance metrics"
            )
            cv_params = model_params.copy()
            cv_params.pop("early_stopping_rounds", None)
            cv_params.pop("objective", None)
            cv_params.pop("eval_metric", None)
            CV_Model = XGBRegressor(
                **cv_params,
                feature_types=feature_types,
                tree_method="hist",
                enable_categorical=True,
                booster=booster,
                random_state=random_state,
                n_jobs=n_jobs_int,
            )
            CV_Metrics = _perform_cross_validation(
                CV_Model,
                X_train,
                y_train,
                cv_folds,
                shuffle_split,
                random_state,
                n_jobs_int,
                verbose,
            )
        y_pred = Model.predict(X_test)
        fa = forecast_accuracy(y_test, y_pred)
        wape = wape_score(y_test, y_pred)
        mae = mean_absolute_error(y_test, y_pred)
        mse = mean_squared_error(y_test, y_pred)
        rmse = root_mean_squared_error(y_test, y_pred)
        r2 = r2_score(y_test, y_pred)
        mape = mean_absolute_percentage_error(y_test, y_pred)
        if metrics_as == "Dataframe":
            Metrics = pd.DataFrame(
                {
                    "Metric": [
                        "Forecast Accuracy",
                        "Weighted Absolute Percentage Error",
                        "Mean Absolute Error",
                        "Mean Squared Error",
                        "Root Mean Squared Error",
                        "R2 Score",
                        "Mean Absolute Percentage Error",
                    ],
                    "Value": [fa, wape, mae, mse, rmse, r2, mape],
                }
            )
        else:
            Metrics = {
                "forecast_accuracy": fa,
                "weighted_absolute_percentage_error ": wape,
                "mean_absolute_error": mae,
                "mean_squared_error": mse,
                "root_mean_squared_error": rmse,
                "r2_score": r2,
                "mean_absolute_percentage_error": mape,
            }
        verbose and logger.info(f"Forecast Accuracy                  : {fa:.2%}")
        verbose and logger.info(f"Weighted Absolute Percentage Error : {wape:.2%}")
        verbose and logger.info(f"Mean Absolute Error                : {mae:.4f}")
        verbose and logger.info(f"Mean Squared Error                 : {mse:.4f}")
        verbose and logger.info(f"Root Mean Squared Error            : {rmse:.4f}")
        verbose and logger.info(f"R2 Score                           : {r2:.4f}")
        verbose and logger.info(f"Mean Absolute Percentage Error     : {mape:.2%}")
        Prediction_Set = _combine_test_data(X_test, y_test, y_pred, features_names)
        verbose and logger.info(f"Prediction Set created")
        if retrain_on_full:
            verbose and logger.info(
                "Retraining model on full dataset for production deployment"
            )
            if early_stopping:
                model_params.pop("early_stopping_rounds", None)
                model_params.pop("objective", None)
                model_params.pop("eval_metric", None)
                Model = XGBRegressor(
                    **model_params,
                    feature_types=feature_types,
                    tree_method="hist",
                    enable_categorical=True,
                    booster=booster,
                    random_state=random_state,
                    n_jobs=n_jobs_int,
                )
            Model.fit(X, y, verbose=False)
            verbose and logger.info(
                "Model successfully retrained on full dataset. Reported metrics remain from original held-out test set."
            )
        Features_Importance = _get_feature_importance(Model, features_names)
        verbose and logger.info(f"Features Importance computed")
        if return_shap_explainer:
            if shap_feature_perturbation == "Interventional":
                SHAP = shap.TreeExplainer(
                    Model,
                    (
                        _smart_shap_background(
                            X if retrain_on_full else X_train,
                            model_type="tree",
                            seed=random_state,
                            verbose=verbose,
                        )
                        if use_shap_sampler
                        else X if retrain_on_full else X_train
                    ),
                    feature_names=shap_feature_names,
                )
            else:
                SHAP = shap.TreeExplainer(
                    Model,
                    feature_names=shap_feature_names,
                    feature_perturbation="tree_path_dependent",
                )
            verbose and logger.info(f"SHAP explainer generated")
        if activate_caching:
            verbose and logger.info(f"Caching output elements")
            Model.save_model(model_path)
            if isinstance(Metrics, dict):
                with metrics_dict_path.open("w", encoding="utf-8") as f:
                    json.dump(Metrics, f, ensure_ascii=False, indent=4)
            else:
                Metrics.to_parquet(metrics_df_path)
            if use_cross_validation and (not use_hpo):
                CV_Metrics.to_parquet(cv_metrics_path)
            if use_hpo:
                HPO_Trials.to_parquet(hpo_trials_path)
                with hpo_best_params_path.open("w", encoding="utf-8") as f:
                    json.dump(HPO_Best, f, ensure_ascii=False, indent=4)
            Features_Importance.to_parquet(features_importance_path)
            Prediction_Set.to_parquet(prediction_set_path)
            if return_shap_explainer:
                with shap_path.open("wb") as f:
                    joblib.dump(SHAP, f)
            joblib.dump(Label_Encoder, label_encoder_path)
            verbose and logger.info(f"Caching done")
    else:
        verbose and logger.info(f"Skipping computations and loading cached elements")
        Model = XGBRegressor()
        Model.load_model(model_path)
        verbose and logger.info(f"Model loaded")
        if metrics_dict_path.is_file():
            with metrics_dict_path.open("r", encoding="utf-8") as f:
                Metrics = json.load(f)
        else:
            Metrics = pd.read_parquet(metrics_df_path)
        verbose and logger.info(f"Metrics loaded")
        if use_cross_validation and (not use_hpo):
            CV_Metrics = pd.read_parquet(cv_metrics_path)
            verbose and logger.info(f"Cross Validation metrics loaded")
        if use_hpo:
            HPO_Trials = pd.read_parquet(hpo_trials_path)
            with hpo_best_params_path.open("r", encoding="utf-8") as f:
                HPO_Best = json.load(f)
            verbose and logger.info(
                f"Hyperparameters Optimization trials and best params loaded"
            )
        Features_Importance = pd.read_parquet(features_importance_path)
        verbose and logger.info(f"Features Importance loaded")
        Prediction_Set = pd.read_parquet(prediction_set_path)
        verbose and logger.info(f"Prediction Set loaded")
        if return_shap_explainer:
            with shap_path.open("rb") as f:
                SHAP = joblib.load(f)
            verbose and logger.info(f"SHAP Explainer loaded")
        Label_Encoder = joblib.load(label_encoder_path)
        verbose and logger.info(f"Label Encoder loaded")
    return (
        Model,
        SHAP,
        Label_Encoder,
        Metrics,
        CV_Metrics,
        Features_Importance,
        Prediction_Set,
        HPO_Trials,
        HPO_Best,
    )

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • shap>=0.47.0
  • scikit-learn
  • pandas
  • numpy
  • torch
  • numba>=0.56.0
  • xgboost
  • cmaes
  • optuna
  • scipy
  • polars
  • xxhash