Cla. Support Vector Machine

Train a Support Vector Machine (SVM) classification model.

Processing

This brick trains a Support Vector Machine (SVM) classification model. In simple terms, the SVM algorithm looks at your data points and attempts to find the best boundary (or "hyperplane") to separate them into different categories.

It is particularly effective for complex data where categories cannot be easily separated by a straight line. The brick handles essential preprocessing steps automatically, such as scaling your data (crucial for SVMs) and splitting it into training and testing sets. It also includes advanced features like Hyperparameter Optimization to automatically find the best settings for your specific dataset.
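
Below is a minimal sketch of the core training flow this brick automates, written against scikit-learn directly with a toy dataset standing in for the X / y inputs. It is illustrative only; the actual brick code further down adds hyperparameter optimization, cross-validation, SHAP, and caching on top of this.

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Toy stand-in for the brick's X / y inputs
X, y = load_iris(return_X_y=True, as_frame=True)

# Split, scale (crucial for SVMs), train, then evaluate on the held-out split
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.15, shuffle=True, stratify=y, random_state=42
)
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
model = SVC(kernel="rbf", C=1.0, gamma="scale", probability=True, random_state=42)
model.fit(X_train_scaled, y_train)
print("Test accuracy:", model.score(scaler.transform(X_test), y_test))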

Inputs

X
The features (predictors) used to train the model. This must be a collection of numerical data (e.g., age, income, test scores). If your data contains text categories, they must be encoded into numbers before entering this brick (see the encoding example after this list).
y
The target variable (labels) you want to predict. This contains the known categories for the training data (e.g., "Yes/No", "Spam/Ham", or specific product categories).
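
If X still contains text categories, they can be one-hot encoded upstream before being connected to this brick. A hedged example, using a hypothetical city column:

import pandas as pd

# Hypothetical raw data with a text column that must be encoded first
raw = pd.DataFrame({"age": [25, 40, 31], "city": ["Paris", "Lyon", "Paris"]})
X = pd.get_dummies(raw, columns=["city"])  # text categories -> numerical indicator columns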

Inputs Types

Input Types
X DataFrame
y DataSeries, NDArray, List

You can check the list of supported types here: Available Type Hints.

Outputs

Model
The trained SVM model object (Scikit-learn SVC). This object can be passed to other bricks to make predictions on new, unseen data (see the usage sketch after this list).
Model Classes
A reference table mapping the model's internal numerical indices to the actual class names (labels).
SHAP
A SHAP explainer object. This is used to generate explanations for why the model made specific predictions (requires the SHAP Explainer option to be enabled).
Scaler
The scaling logic used during training. This is required to process new data in the exact same way as the training data.
Metrics
A summary of the model's performance on the test set, including Accuracy, Precision, Recall, F1-Score, and ROC-AUC.
CV Metrics
Performance metrics calculated using Cross-Validation. This provides a more robust estimate of how the model will perform on different subsets of data.
Prediction Set
A detailed table showing the model's predictions on the test dataset. It includes the original features, the actual values, the predicted values, and probability scores.
HPO Trials
A log of all attempts made during Hyperparameter Optimization, showing which settings were tried and how they performed.
HPO Best
A dictionary containing the single best combination of settings found during optimization.
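
For illustration, the Model and Scaler outputs can be reused together to score new data. This is a hedged sketch assuming Standard Scaling was enabled and that new_data is a hypothetical DataFrame with the same feature columns as X:

new_data_scaled = Scaler.transform(new_data)                # apply the exact training-time scaling
predicted_labels = Model.predict(new_data_scaled)           # predicted class labels
class_probabilities = Model.predict_proba(new_data_scaled)  # columns follow the order in Model Classes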

The Prediction Set output contains the following specific data fields:

  • feature_{name}: The original input features used for the prediction.
  • proba_{class_name}: The probability score calculated by the model for a specific class (e.g., proba_Churn). For binary problems, a single proba column holds the positive-class probability instead.
  • y_true: The actual known label for that row.
  • y_pred: The label predicted by the model.
  • is_false_prediction: A boolean (True/False) indicating if the model made a mistake on this row.
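
For example, the Prediction Set can be filtered with ordinary pandas operations; the proba_Churn column name below is hypothetical and depends on your class labels:

errors = Prediction_Set[Prediction_Set["is_false_prediction"]]                # rows the model got wrong
borderline = Prediction_Set[Prediction_Set["proba_Churn"].between(0.4, 0.6)]  # low-confidence predictions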

Outputs Types

Output Types
Model Any
Model Classes DataFrame
SHAP Any
Scaler Any
Metrics DataFrame, Dict
CV Metrics DataFrame
Prediction Set DataFrame
HPO Trials DataFrame
HPO Best Dict

You can check the list of supported types here: Available Type Hints.

Options

The Cla. Support Vector Machine brick exposes the following configurable options:

Kernel
Determines the shape of the decision boundary used to separate classes.
  • RBF (Radial Basis Function): The default and most versatile choice. Good for curved boundaries and complex patterns.
  • Linear: Creates a straight line (or flat plane) boundary. Best for simple, high-dimensional data.
  • Polynomial: Creates curved boundaries based on polynomial equations.
  • Sigmoid: Similar to a neural network activation function.
C
Controls the trade-off between a smooth decision boundary and classifying training points correctly.
  • Low value (e.g., 0.1): Makes the decision surface smoother. It allows more misclassifications in the training data but may generalize better to new data.
  • High value (e.g., 100): Attempts to classify all training examples correctly. This creates a complex boundary but risks "overfitting" (memorizing) the data.
Polynomial Degree
Only used when Kernel is set to "Polynomial". It controls the complexity of the polynomial curve (e.g., 2 is quadratic, 3 is cubic). Higher degrees create more complex shapes.
Gamma
Defines how far the influence of a single training example reaches.
  • Scale: (Default) Uses 1 / (n_features * X.var()). Adapts to the variance of your data.
  • Auto: Uses 1 / n_features.
  • Numerical: Allows you to set a specific custom value using the Gamma Num. Value option.
Gamma Num. Value
A specific float value for Gamma. Only used if Gamma is set to "Numerical". A low value means "far" (points far apart are considered similar), while a high value means "close" (only very close points are similar).
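
These settings map directly onto scikit-learn's gamma parameter; a hedged illustration of the three choices:

from sklearn.svm import SVC

SVC(kernel="rbf", gamma="scale")  # Scale: 1 / (n_features * X.var())
SVC(kernel="rbf", gamma="auto")   # Auto: 1 / n_features
SVC(kernel="rbf", gamma=0.1)      # Numerical: the float from "Gamma Num. Value"
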
Tolerance
The stopping criterion. It tells the algorithm when to stop searching for a better solution because the improvements have become insignificant.
Maximum Iterations
The limit on the number of iterations to run. Set to -1 for no limit (runs until the Tolerance is met).
Standard Scaling
If enabled, the brick automatically standardizes features (removes mean and scales to unit variance) before training. Highly Recommended for SVMs, as they are very sensitive to the scale of the data.
Auto Split Data
If enabled, the system automatically determines the best ratio for splitting data into training and testing sets based on the dataset size.
Shuffle Split
If enabled, shuffles the data randomly before splitting. This ensures the order of data doesn't bias the training.
Stratify Split
If enabled, ensures that the training and testing sets have the same proportion of class labels as the original data. Crucial for imbalanced datasets (e.g., rare fraud detection).
Test/Validation Set %
The percentage of data to hold back for testing. Only used if Auto Split Data is disabled.
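
Together, these split options drive scikit-learn's train_test_split. A hedged manual equivalent of a 15% shuffled, stratified split, with a toy dataset standing in for X and y:

from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split

X, y = load_iris(return_X_y=True, as_frame=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y,
    test_size=0.15,   # Test/Validation Set % = 15 (used when Auto Split Data is off)
    shuffle=True,     # Shuffle Split
    stratify=y,       # Stratify Split keeps class proportions equal in both sets
    random_state=42,  # Random State
)
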
Retrain On Full Data
If enabled, the model performs a final training pass using 100% of the available data after metrics have been calculated on the test split. Useful for production models.
Average Strategy
Determines how metrics (Precision, Recall, F1) are averaged for multiclass problems.
  • auto: Automatically selects the best strategy based on the data balance.
  • binary: For two classes only.
  • micro: Calculates metrics globally by counting total true positives, false negatives, and false positives.
  • macro: Calculates metrics for each label, and finds their unweighted mean. Does not take label imbalance into account.
  • weighted: Calculates metrics for each label, and finds their average weighted by the number of true instances for each label.
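
These strategies correspond to scikit-learn's average argument; a small self-contained comparison of macro and weighted F1 on a toy multiclass result:

from sklearn.metrics import f1_score

y_true = ["a", "a", "b", "b", "c", "c"]
y_pred = ["a", "b", "b", "b", "c", "a"]
f1_macro = f1_score(y_true, y_pred, average="macro")        # unweighted mean over the three classes
f1_weighted = f1_score(y_true, y_pred, average="weighted")  # weighted by each class's support
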
Enable Cross-Validation
If enabled, performs k-fold cross-validation instead of a single train/test split to evaluate performance. This takes longer but produces more reliable metrics.
Number of CV Folds
The number of groups (folds) to split the data into during cross-validation (e.g., 5 folds means the model is trained 5 times).
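
Cross-validation in this brick is built on StratifiedKFold plus cross_validate (see _perform_cross_validation in the code further down). A reduced, hedged sketch with toy data:

from sklearn.datasets import load_iris
from sklearn.model_selection import StratifiedKFold, cross_validate
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)  # Number of CV Folds = 5
scores = cross_validate(SVC(probability=True), X, y, cv=cv, scoring="f1_macro")
print(scores["test_score"].mean(), "+/-", scores["test_score"].std())
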
Hyperparameter Optim.
If enabled, the brick attempts to find the optimal values for C, Kernel, Gamma, etc., automatically using the Optuna library.
Optimization Metric
The performance score the optimizer tries to maximize (e.g., "F1 Score").
Optimization Method
The algorithm used to search for the best hyperparameters.
  • Tree-structured Parzen: A Bayesian optimization method that models good vs bad parameter regions using probability distributions and prioritizes sampling where success is statistically more likely.
  • Gaussian Process: Uses a probabilistic regression model (a Gaussian Process) to estimate performance uncertainty and selects new trials using acquisition functions.
  • CMA-ES: An evolutionary strategy that adapts the covariance matrix of a multivariate normal distribution to efficiently search complex, non-linear, non-convex spaces.
  • Random Sobol Search: Uses low-discrepancy quasi-random sequences to ensure uniform coverage of the parameter space, avoiding clustering and gaps.
  • Random Search: Uniform random sampling of parameter configurations without learning or feedback between iterations.
Optimization Iterations
The number of different setting combinations to try. Higher numbers take longer but may find better models.
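
Internally, HPO is driven by Optuna (see _hyperparameters_optimization in the code further down). The following is a trimmed, hedged sketch of a comparable search over C and kernel using the TPE sampler and toy data:

import optuna
from optuna.samplers import TPESampler
from sklearn.datasets import load_iris
from sklearn.model_selection import cross_val_score
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)

def objective(trial):
    C = trial.suggest_float("C", 1e-4, 1e3, log=True)  # search C on a log scale
    kernel = trial.suggest_categorical("kernel", ["linear", "poly", "rbf", "sigmoid"])
    model = SVC(C=C, kernel=kernel, probability=True, random_state=42)
    return cross_val_score(model, X, y, cv=3, scoring="f1_macro").mean()  # Optimization Metric

study = optuna.create_study(direction="maximize", sampler=TPESampler(seed=42))
study.optimize(objective, n_trials=50)  # Optimization Iterations
print(study.best_params, study.best_value)
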
Positive Label (Binary Only)
Explicitly defines which class label should be considered "Positive" (e.g., "1", "Yes", "Churned"). Important for calculating Precision/Recall correctly in binary classification.
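
With string labels, the positive class must be named explicitly for binary precision and recall; a hedged illustration using scikit-learn directly:

from sklearn.metrics import precision_score, recall_score

y_true = ["Churned", "Stayed", "Churned", "Stayed"]
y_pred = ["Churned", "Churned", "Churned", "Stayed"]
precision = precision_score(y_true, y_pred, pos_label="Churned")  # "Churned" treated as the positive class
recall = recall_score(y_true, y_pred, pos_label="Churned")
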
Metrics as
The format of the Metrics output: a DataFrame (table) or a dictionary of metric names to values.
SHAP Explainer
If enabled, generates the SHAP output object for model explainability.
SHAP Sampler
If enabled, uses a subset of data to generate the SHAP background. This significantly speeds up calculation for large datasets.
Number of Jobs
The number of CPU cores to use for calculation. "All" uses all available cores.
Random State
A seed number to ensure reproducibility. Using the same seed and input data will always result in the same model.
Brick Caching
If enabled, saves the results to a temporary cache. If the brick is run again with the exact same inputs and settings, it loads the result from cache instead of recalculating.
Verbose Logging
If enabled, prints detailed progress logs to the console.

Brick Code
import logging
import warnings
import shap
import json
import xxhash
import hashlib
import tempfile
import sklearn
import scipy
import joblib
import numpy as np
import pandas as pd
import polars as pl
from pathlib import Path
from scipy import sparse
from optuna.samplers import (
    TPESampler,
    RandomSampler,
    GPSampler,
    CmaEsSampler,
    QMCSampler,
)
import optuna
from optuna import Study
from optuna.trial import FrozenTrial
from optuna.pruners import HyperbandPruner
from optuna import create_study
from sklearn.svm import SVC
from sklearn.model_selection import train_test_split, cross_validate, StratifiedKFold
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    make_scorer,
)
from dataclasses import dataclass
from datetime import datetime
from coded_flows.types import (
    Union,
    Dict,
    List,
    Tuple,
    NDArray,
    DataFrame,
    DataSeries,
    Any,
)
from coded_flows.utils import CodedFlowsLogger

logger = CodedFlowsLogger(name="Cla. Support Vector Machine", level=logging.INFO)
optuna.logging.set_verbosity(optuna.logging.ERROR)
warnings.filterwarnings("ignore", category=optuna.exceptions.ExperimentalWarning)
DataType = Union[
    pd.DataFrame, pl.DataFrame, np.ndarray, sparse.spmatrix, pd.Series, pl.Series
]


@dataclass
class _DatasetFingerprint:
    """Lightweight fingerprint of a dataset."""

    hash: str
    shape: tuple
    computed_at: str
    data_type: str
    method: str


class _UniversalDatasetHasher:
    """
    High-performance dataset hasher optimizing for zero-copy operations
    and native backend execution (C/Rust).
    """

    def __init__(
        self,
        data_size: int,
        method: str = "auto",
        sample_size: int = 100000,
        verbose: bool = False,
    ):
        self.method = method
        self.sample_size = sample_size
        self.data_size = data_size
        self.verbose = verbose

    def hash_data(self, data: DataType) -> _DatasetFingerprint:
        """
        Main entry point: hash any supported data format.
        Auto-detects format and applies optimal strategy.
        """
        if isinstance(data, pd.DataFrame):
            return self._hash_pandas(data)
        elif isinstance(data, pl.DataFrame):
            return self._hash_polars(data)
        elif isinstance(data, pd.Series):
            return self._hash_pandas_series(data)
        elif isinstance(data, pl.Series):
            return self._hash_polars_series(data)
        elif isinstance(data, np.ndarray):
            return self._hash_numpy(data)
        elif sparse.issparse(data):
            return self._hash_sparse(data)
        else:
            raise TypeError(f"Unsupported data type: {type(data)}")

    def _hash_pandas(self, df: pd.DataFrame) -> _DatasetFingerprint:
        """
        Optimized Pandas hashing using pd.util.hash_pandas_object.
        Avoids object-to-string conversion overhead.
        """
        method = self._determine_method(self.data_size, self.method)
        self.verbose and logger.info(
            f"Hashing Pandas: {self.data_size:,} rows - {method}"
        )
        target_df = df
        if method == "sampled":
            target_df = self._get_pandas_sample(df)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns.tolist(),
                "dtypes": {k: str(v) for (k, v) in df.dtypes.items()},
                "shape": df.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(target_df, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception as e:
            self.verbose and logger.warning(
                f"Fast hash failed, falling back to slow hash: {e}"
            )
            self._hash_pandas_fallback(hasher, target_df)
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas",
            method=method,
        )

    def _get_pandas_sample(self, df: pd.DataFrame) -> pd.DataFrame:
        """Deterministic slicing for sampling (Zero randomness)."""
        if self.data_size <= self.sample_size:
            return df
        chunk = self.sample_size // 3
        head = df.iloc[:chunk]
        mid_idx = self.data_size // 2
        mid = df.iloc[mid_idx : mid_idx + chunk]
        tail = df.iloc[-chunk:]
        return pd.concat([head, mid, tail])

    def _hash_pandas_fallback(self, hasher, df: pd.DataFrame):
        """Legacy fallback for complex object types."""
        for col in df.columns:
            val = df[col].astype(str).values
            hasher.update(val.astype(np.bytes_).tobytes())

    def _hash_polars(self, df: pl.DataFrame) -> _DatasetFingerprint:
        """
        Optimized Polars hashing using native Rust execution.
        """
        method = self._determine_method(self.data_size, self.method)
        self.verbose and logger.info(
            f"Hashing Polars: {self.data_size:,} rows - {method}"
        )
        target_df = df
        if method == "sampled" and self.data_size > self.sample_size:
            indices = self._get_sample_indices(self.data_size, self.sample_size)
            target_df = df.gather(indices)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns,
                "dtypes": [str(t) for t in df.dtypes],
                "shape": df.shape,
            },
        )
        row_hashes = target_df.hash_rows()
        hasher.update(memoryview(row_hashes.to_numpy()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars",
            method=method,
        )

    def _hash_pandas_series(self, series: pd.Series) -> _DatasetFingerprint:
        """Hash Pandas Series using the fastest vectorized method."""
        self.verbose and logger.info(f"Hashing Pandas Series: {self.data_size:,} rows")
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "name": series.name if series.name else "None",
                "dtype": str(series.dtype),
                "shape": series.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(series, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception as e:
            self.verbose and logger.warning(f"Series hash failed, falling back: {e}")
            hasher.update(memoryview(series.astype(str).values.tobytes()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas_series",
            method="full",
        )

    def _hash_polars_series(self, series: pl.Series) -> _DatasetFingerprint:
        """Hash Polars Series using native Polars expressions."""
        self.verbose and logger.info(f"Hashing Polars Series: {self.data_size:,} rows")
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"name": series.name, "dtype": str(series.dtype), "shape": series.shape},
        )
        try:
            row_hashes = series.hash()
            hasher.update(memoryview(row_hashes.to_numpy()))
        except Exception as e:
            self.verbose and logger.warning(
                f"Polars series native hash failed. Falling back."
            )
            hasher.update(str(series.to_list()).encode())
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars_series",
            method="full",
        )

    def _hash_numpy(self, arr: np.ndarray) -> _DatasetFingerprint:
        """
        Optimized NumPy hashing using Buffer Protocol (Zero-Copy).
        """
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"shape": arr.shape, "dtype": str(arr.dtype), "strides": arr.strides},
        )
        if arr.flags["C_CONTIGUOUS"] or arr.flags["F_CONTIGUOUS"]:
            hasher.update(memoryview(arr))
        else:
            hasher.update(memoryview(np.ascontiguousarray(arr)))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=arr.shape,
            computed_at=datetime.now().isoformat(),
            data_type="numpy",
            method="full",
        )

    def _hash_sparse(self, matrix: sparse.spmatrix) -> _DatasetFingerprint:
        """
        Optimized sparse hashing. Hashes underlying data arrays directly.
        """
        if not (sparse.isspmatrix_csr(matrix) or sparse.isspmatrix_csc(matrix)):
            matrix = matrix.tocsr()
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher, {"shape": matrix.shape, "format": matrix.format, "nnz": matrix.nnz}
        )
        hasher.update(memoryview(matrix.data))
        hasher.update(memoryview(matrix.indices))
        hasher.update(memoryview(matrix.indptr))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=matrix.shape,
            computed_at=datetime.now().isoformat(),
            data_type=f"sparse_{matrix.format}",
            method="sparse",
        )

    def _determine_method(self, rows: int, requested: str) -> str:
        if requested != "auto":
            return requested
        if rows < 5000000:
            return "full"
        return "sampled"

    def _hash_schema(self, hasher, schema: Dict[str, Any]):
        """Compact schema hashing."""
        hasher.update(
            json.dumps(schema, sort_keys=True, separators=(",", ":")).encode()
        )

    def _get_sample_indices(self, total_rows: int, sample_size: int) -> list:
        """Calculate indices for sampling without generating full range lists."""
        chunk = sample_size // 3
        indices = list(range(min(chunk, total_rows)))
        mid_start = max(0, total_rows // 2 - chunk // 2)
        mid_end = min(mid_start + chunk, total_rows)
        indices.extend(range(mid_start, mid_end))
        last_start = max(0, total_rows - chunk)
        indices.extend(range(last_start, total_rows))
        return sorted(list(set(indices)))


def _normalize_hpo_df(df):
    df = df.copy()
    param_cols = [c for c in df.columns if c.startswith("params_")]
    df[param_cols] = df[param_cols].astype("string[pyarrow]")
    return df


def _validate_numerical_data(data):
    """
    Validates if the input data (NumPy array, Pandas DataFrame/Series,
    Polars DataFrame/Series, or SciPy sparse matrix) contains only
    numerical (integer, float) or boolean values.

    Args:
        data: The input data structure to check.

    Raises:
        TypeError: If the input data contains non-numerical and non-boolean types.
        ValueError: If the input data is of an unsupported type.
    """
    if sparse.issparse(data):
        if not (
            np.issubdtype(data.dtype, np.number) or np.issubdtype(data.dtype, np.bool_)
        ):
            raise TypeError(
                f"Sparse matrix contains unsupported data type: {data.dtype}. Only numerical or boolean types are allowed."
            )
        return
    elif isinstance(data, np.ndarray):
        if not (
            np.issubdtype(data.dtype, np.number) or np.issubdtype(data.dtype, np.bool_)
        ):
            raise TypeError(
                f"NumPy array contains unsupported data type: {data.dtype}. Only numerical or boolean types are allowed."
            )
        return
    elif isinstance(data, (pd.DataFrame, pd.Series)):
        d_types = data.dtypes.apply(lambda x: x.kind)
        non_numerical_mask = ~d_types.isin(["i", "f", "b"])
        if non_numerical_mask.any():
            non_numerical_columns = (
                data.columns[non_numerical_mask].tolist()
                if isinstance(data, pd.DataFrame)
                else [data.name]
            )
            raise TypeError(
                f"Pandas {('DataFrame' if isinstance(data, pd.DataFrame) else 'Series')} contains non-numerical/boolean data. Offending column(s) and types: {data.dtypes[non_numerical_mask].to_dict()}"
            )
        return
    elif isinstance(data, (pl.DataFrame, pl.Series)):
        pl_numerical_types = [
            pl.Int8,
            pl.Int16,
            pl.Int32,
            pl.Int64,
            pl.UInt8,
            pl.UInt16,
            pl.UInt32,
            pl.UInt64,
            pl.Float32,
            pl.Float64,
            pl.Boolean,
        ]
        if isinstance(data, pl.DataFrame):
            for col, dtype in data.schema.items():
                if dtype not in pl_numerical_types:
                    raise TypeError(
                        f"Polars DataFrame column '{col}' has unsupported data type: {dtype}. Only numerical or boolean types are allowed."
                    )
        elif isinstance(data, pl.Series):
            if data.dtype not in pl_numerical_types:
                raise TypeError(
                    f"Polars Series has unsupported data type: {data.dtype}. Only numerical or boolean types are allowed."
                )
        return
    else:
        raise ValueError(
            f"Unsupported data type provided: {type(data)}. Function supports NumPy, Pandas, Polars, and SciPy sparse matrices."
        )


def _get_shape_and_sparsity(X: Any) -> Tuple[int, int, float, bool]:
    """
    Efficiently extracts shape and estimates sparsity without converting
    the entire dataset to numpy.
    """
    (n_samples, n_features) = (0, 0)
    is_sparse = False
    sparsity = 0.0
    if hasattr(X, "nnz") and hasattr(X, "shape"):
        (n_samples, n_features) = X.shape
        is_sparse = True
        sparsity = 1.0 - X.nnz / (n_samples * n_features)
        return (n_samples, n_features, sparsity, is_sparse)
    if hasattr(X, "height") and hasattr(X, "width"):
        (n_samples, n_features) = (X.height, X.width)
        return (n_samples, n_features, 0.0, False)
    if hasattr(X, "shape") and hasattr(X, "iloc"):
        (n_samples, n_features) = X.shape
        return (n_samples, n_features, 0.0, False)
    if isinstance(X, list):
        X = np.array(X)
    if hasattr(X, "shape"):
        (n_samples, n_features) = X.shape
        return (n_samples, n_features, 0.0, False)
    raise ValueError("Unsupported data type")


def _choose_class_weight_logreg(y):
    from collections import Counter

    counts = np.array(list(Counter(y).values()))
    r = counts.max() / counts.min()
    if r <= 2:
        return None
    return "balanced"


def _smart_split(
    n_samples,
    X,
    y,
    *,
    random_state=42,
    shuffle=True,
    stratify=None,
    fixed_test_split=None,
    verbose=True,
):
    """
    Parameters
    ----------
    n_samples : int
        Number of samples in the dataset (len(X) or len(y))
    X : array-like
        Features
    y : array-like
        Target
    random_state : int
    shuffle : bool
    stratify : array-like or None
        For stratified splitting (recommended for classification)
    fixed_test_split : float or None
        If set, overrides the size-based heuristic for the test/validation ratio
    verbose : bool
        If True, logs the chosen split ratios

    Returns
    -------
    X_train, X_test, y_train, y_test, val_size_in_train
        Train/test splits, plus the validation ratio expressed as a fraction
        of the training portion (used later for hyperparameter optimization)
    """
    if fixed_test_split:
        test_ratio = fixed_test_split
        val_ratio = fixed_test_split
    elif n_samples <= 1000:
        test_ratio = 0.2
        val_ratio = 0.1
    elif n_samples < 10000:
        test_ratio = 0.15
        val_ratio = 0.15
    elif n_samples < 100000:
        test_ratio = 0.1
        val_ratio = 0.1
    elif n_samples < 1000000:
        test_ratio = 0.05
        val_ratio = 0.05
    else:
        test_ratio = 0.01
        val_ratio = 0.01
    (X_train, X_test, y_train, y_test) = train_test_split(
        X,
        y,
        test_size=test_ratio,
        random_state=random_state,
        shuffle=shuffle,
        stratify=stratify,
    )
    val_size_in_train = val_ratio / (1 - test_ratio)
    verbose and logger.info(
        f"Split → Train: {1 - test_ratio:.2%} | Test: {test_ratio:.2%} (no validation set)"
    )
    return (X_train, X_test, y_train, y_test, val_size_in_train)


def _get_best_metric_average_strategy(y_true, balance_threshold: float = 0.5) -> str:
    """
    Analyzes y_true to determine the best averaging strategy.

    Args:
        y_true: Input array (Numpy array, Pandas Series, or Polars Series).
        balance_threshold: Float (0 to 1). If min_class_count / max_class_count
                           is below this, the data is considered imbalanced.

    Returns:
        str: 'binary', 'weighted', or 'macro'
    """
    counts = None
    if hasattr(y_true, "value_counts") and hasattr(y_true, "values"):
        counts = y_true.value_counts().values
    elif hasattr(y_true, "value_counts") and hasattr(y_true, "to_numpy"):
        vc = y_true.value_counts()
        if "count" in vc.columns:
            counts = vc["count"].to_numpy()
        else:
            counts = vc[:, 1].to_numpy()
    elif isinstance(y_true, np.ndarray):
        (_, counts) = np.unique(y_true, return_counts=True)
    else:
        (_, counts) = np.unique(np.array(y_true), return_counts=True)
    if counts is None or len(counts) == 0:
        raise ValueError("Input y_true appears to be empty.")
    n_classes = len(counts)
    if n_classes <= 2:
        return "binary"
    min_c = np.min(counts)
    max_c = np.max(counts)
    ratio = min_c / max_c
    if ratio < balance_threshold:
        return "weighted"
    else:
        return "macro"


def _ensure_feature_names(X, feature_names=None):
    if isinstance(X, pd.DataFrame):
        return list(X.columns)
    if isinstance(X, np.ndarray):
        if feature_names is None:
            feature_names = [f"feature_{i}" for i in range(X.shape[1])]
        return feature_names
    raise TypeError("X must be a pandas DataFrame or numpy ndarray")


def _perform_cross_validation(
    model,
    X,
    y,
    cv_folds,
    average_strategy,
    shuffle,
    random_state,
    n_jobs,
    verbose,
    pos_label=None,
) -> pd.DataFrame:
    """Perform cross-validation on the model."""
    verbose and logger.info(f"Performing {cv_folds}-fold cross-validation...")
    cv = StratifiedKFold(n_splits=cv_folds, shuffle=shuffle, random_state=random_state)
    if average_strategy == "binary":
        scoring = {
            "accuracy": "accuracy",
            "precision": make_scorer(
                precision_score, average="binary", pos_label=pos_label
            ),
            "recall": make_scorer(recall_score, average="binary", pos_label=pos_label),
            "f1": make_scorer(f1_score, average="binary", pos_label=pos_label),
            "roc_auc": "roc_auc",
        }
    else:
        average_strategy_suffix = f"_{average_strategy}"
        roc_average_strategy_suffix = (
            f"_{average_strategy}" if average_strategy == "weighted" else ""
        )
        roc_auc_ovr_suffix = "_ovr"
        scoring = (
            f"f1{average_strategy_suffix}",
            "accuracy",
            f"precision{average_strategy_suffix}",
            f"recall{average_strategy_suffix}",
            f"roc_auc{roc_auc_ovr_suffix}{roc_average_strategy_suffix}",
        )
    cv_results = cross_validate(
        model, X, y, cv=cv, scoring=scoring, return_train_score=True, n_jobs=n_jobs
    )

    def get_score_mean_std(metric_key):
        if metric_key in cv_results:
            return (cv_results[metric_key].mean(), cv_results[metric_key].std())
        return (0.0, 0.0)

    if average_strategy == "binary":
        (accuracy_mean, accuracy_std) = get_score_mean_std("test_accuracy")
        (precision_mean, precision_std) = get_score_mean_std("test_precision")
        (recall_mean, recall_std) = get_score_mean_std("test_recall")
        (f1_mean, f1_std) = get_score_mean_std("test_f1")
        (roc_auc_mean, roc_auc_std) = get_score_mean_std("test_roc_auc")
    else:
        (accuracy_mean, accuracy_std) = get_score_mean_std("test_accuracy")
        (precision_mean, precision_std) = get_score_mean_std(
            f"test_precision{average_strategy_suffix}"
        )
        (recall_mean, recall_std) = get_score_mean_std(
            f"test_recall{average_strategy_suffix}"
        )
        (f1_mean, f1_std) = get_score_mean_std(f"test_f1{average_strategy_suffix}")
        roc_key = f"test_roc_auc{roc_auc_ovr_suffix}{roc_average_strategy_suffix}"
        (roc_auc_mean, roc_auc_std) = get_score_mean_std(roc_key)
    verbose and logger.info(
        f"CV Accuracy  : {accuracy_mean:.4f} (+/- {accuracy_std:.4f})"
    )
    verbose and logger.info(
        f"CV Precision : {precision_mean:.4f} (+/- {precision_std:.4f})"
    )
    verbose and logger.info(f"CV Recall    : {recall_mean:.4f} (+/- {recall_std:.4f})")
    verbose and logger.info(f"CV F1 Score  : {f1_mean:.4f} (+/- {f1_std:.4f})")
    verbose and logger.info(
        f"CV ROC-AUC   : {roc_auc_mean:.4f} (+/- {roc_auc_std:.4f})"
    )
    CV_metrics = pd.DataFrame(
        {
            "Metric": ["Accuracy", "Precision", "Recall", "F1-Score", "ROC AUC"],
            "Mean": [accuracy_mean, precision_mean, recall_mean, f1_mean, roc_auc_mean],
            "Std": [accuracy_std, precision_std, recall_std, f1_std, roc_auc_std],
        }
    )
    return CV_metrics


def _compute_score(model, X, y, metric, average_strategy, pos_label=None):
    score_params = {"average": average_strategy, "zero_division": 0}
    y_pred = model.predict(X)
    if average_strategy != "binary":
        y_score = model.predict_proba(X)
    else:
        score_params["pos_label"] = pos_label
        if pos_label is not None:
            classes = list(model.classes_)
            try:
                pos_idx = classes.index(pos_label)
            except ValueError:
                pos_idx = 1 if len(classes) > 1 else 0
            y_score = model.predict_proba(X)[:, pos_idx]
        else:
            y_score = model.predict_proba(X)[:, 1]
    if metric == "Accuracy":
        score = accuracy_score(y, y_pred)
    elif metric == "Precision":
        score = precision_score(y, y_pred, **score_params)
    elif metric == "Recall":
        score = recall_score(y, y_pred, **score_params)
    elif metric == "F1 Score":
        score = f1_score(y, y_pred, **score_params)
    elif metric == "ROC-AUC":
        if average_strategy != "binary":
            score = roc_auc_score(
                y, y_score, multi_class="ovr", average=average_strategy
            )
        else:
            score = roc_auc_score(y, y_score)
    return score


def _get_cv_scoring_object(metric, average_strategy, pos_label=None):
    """
    Returns a scoring object (string or callable) suitable for cross_validate.
    Used during HPO.
    """
    if average_strategy == "binary":
        if metric == "F1 Score":
            return make_scorer(f1_score, average="binary", pos_label=pos_label)
        elif metric == "Accuracy":
            return "accuracy"
        elif metric == "Precision":
            return make_scorer(precision_score, average="binary", pos_label=pos_label)
        elif metric == "Recall":
            return make_scorer(recall_score, average="binary", pos_label=pos_label)
        elif metric == "ROC-AUC":
            return "roc_auc"
    else:
        average_strategy_suffix = f"_{average_strategy}"
        roc_auc_ovr_suffix = "_ovr"
        roc_average_strategy_suffix = (
            f"_{average_strategy}" if average_strategy == "weighted" else ""
        )
        if metric == "F1 Score":
            return f"f1{average_strategy_suffix}"
        elif metric == "Accuracy":
            return "accuracy"
        elif metric == "Precision":
            return f"precision{average_strategy_suffix}"
        elif metric == "Recall":
            return f"recall{average_strategy_suffix}"
        elif metric == "ROC-AUC":
            return f"roc_auc{roc_auc_ovr_suffix}{roc_average_strategy_suffix}"


def _hyperparameters_optimization(
    X,
    y,
    constant_hyperparameters,
    optimization_metric,
    metric_average_strategy,
    val_ratio,
    shuffle_split,
    stratify_split,
    use_cross_val,
    cv_folds,
    n_trials=50,
    strategy="maximize",
    sampler="Tree-structured Parzen",
    standard_scaling=False,
    seed=None,
    n_jobs=-1,
    verbose=False,
    pos_label=None,
):
    direction = "maximize" if strategy.lower() == "maximize" else "minimize"
    sampler_map = {
        "Tree-structured Parzen": TPESampler(seed=seed),
        "Gaussian Process": GPSampler(seed=seed),
        "CMA-ES": CmaEsSampler(seed=seed),
        "Random Search": RandomSampler(seed=seed),
        "Random Sobol Search": QMCSampler(seed=seed),
    }
    if sampler in sampler_map:
        chosen_sampler = sampler_map[sampler]
    else:
        logger.warning(f"Sampler '{sampler}' not recognized → falling back to TPE")
        chosen_sampler = TPESampler(seed=seed)
    chosen_pruner = HyperbandPruner()
    if use_cross_val:
        cv = StratifiedKFold(
            n_splits=cv_folds, shuffle=shuffle_split, random_state=seed
        )
        cv_score_obj = _get_cv_scoring_object(
            optimization_metric, metric_average_strategy, pos_label
        )
    else:
        (X_train, X_val, y_train, y_val) = train_test_split(
            X,
            y,
            test_size=val_ratio,
            random_state=seed,
            stratify=y if stratify_split else None,
            shuffle=shuffle_split,
        )

    def logging_callback(study: Study, trial: FrozenTrial):
        """Callback function to log trial progress"""
        verbose and logger.info(
            f"Trial {trial.number} finished with value: {trial.value} and parameters: {trial.params}"
        )
        try:
            verbose and logger.info(f"Best value so far: {study.best_value}")
            verbose and logger.info(f"Best parameters so far: {study.best_params}")
        except ValueError:
            verbose and logger.info(f"No successful trials completed yet")
        verbose and logger.info(f"" + "-" * 50)

    def objective(trial):
        try:
            C_value = trial.suggest_float("C", 0.0001, 1000.0, log=True)
            kernel = trial.suggest_categorical(
                "kernel", ["linear", "poly", "rbf", "sigmoid"]
            )
            degree = trial.suggest_int("degree", 2, 5) if kernel == "poly" else 3
            gamma_type = trial.suggest_categorical(
                "gamma_type", ["scale", "auto", "numerical"]
            )
            gamma = (
                trial.suggest_float("gamma", 0.0001, 10.0, log=True)
                if gamma_type == "numerical"
                else gamma_type
            )
            tol = trial.suggest_float("tol", 0.0001, 1, log=True)
            class_weight = constant_hyperparameters.get("class_weight")
            max_iter = constant_hyperparameters.get("max_iter")
            model = SVC(
                C=C_value,
                kernel=kernel,
                degree=degree,
                gamma=gamma,
                tol=tol,
                max_iter=max_iter,
                class_weight=class_weight,
                probability=True,
                random_state=seed,
            )
            if standard_scaling:
                model = Pipeline([("scaler", StandardScaler()), ("svc", model)])
            if use_cross_val:
                scores = cross_validate(
                    model, X, y, cv=cv, n_jobs=n_jobs, scoring=cv_score_obj
                )
                return scores["test_score"].mean()
            else:
                model.fit(X_train, y_train)
                score = _compute_score(
                    model,
                    X_val,
                    y_val,
                    optimization_metric,
                    metric_average_strategy,
                    pos_label,
                )
                return score
        except Exception as e:
            verbose and logger.error(
                f"Trial {trial.number} failed with error: {str(e)}"
            )
            raise

    study = create_study(
        direction=direction, sampler=chosen_sampler, pruner=chosen_pruner
    )
    study.optimize(
        objective,
        n_trials=n_trials,
        catch=(Exception,),
        n_jobs=n_jobs,
        callbacks=[logging_callback],
    )
    verbose and logger.info(f"Optimization completed!")
    verbose and logger.info(f"   Best C            : {study.best_params['C']:.6e}")
    verbose and logger.info(f"   Best Kernel       : {study.best_params['kernel']}")
    verbose and logger.info(
        f"   Best Poly Degree  : {study.best_params.get('degree', 3)}"
    )
    verbose and logger.info(f"   Best Gamma        : {study.best_params['gamma_type']}")
    if study.best_params["gamma_type"] == "numerical":
        verbose and logger.info(
            f"   Best Gamma Value  : {study.best_params['gamma']:.6e}"
        )
    verbose and logger.info(f"   Best Tolerance    : {study.best_params['tol']:.6e}")
    verbose and logger.info(
        f"   Best {optimization_metric:<13}: {study.best_value:.4f}"
    )
    verbose and logger.info(f"   Sampler used      : {sampler}")
    verbose and logger.info(f"   Direction         : {direction}")
    if use_cross_val:
        verbose and logger.info(f"   Cross-validation  : {cv_folds}-fold")
    else:
        verbose and logger.info(f"   Validation        : single train/val split")
    trials = study.trials_dataframe()
    trials["best_value"] = trials["value"].cummax()
    cols = list(trials.columns)
    value_idx = cols.index("value")
    cols = [c for c in cols if c != "best_value"]
    new_order = cols[: value_idx + 1] + ["best_value"] + cols[value_idx + 1 :]
    trials = trials[new_order]
    return (study.best_params, trials)


def _combine_test_data(
    X_test, y_true, y_pred, y_proba, class_names, features_names=None
):
    """
    Combine X_test, y_true, y_pred, and y_proba into a single DataFrame.

    Parameters:
    -----------
    X_test : pandas/polars DataFrame, numpy array, or scipy sparse matrix
        Test features
    y_true : pandas/polars Series, numpy array, or list
        True labels
    y_pred : pandas/polars Series, numpy array, or list
        Predicted labels
    y_proba : pandas/polars Series/DataFrame, numpy array (1D or 2D), or list
        Prediction probabilities - can be:
        - 1D array for binary classification (probability of positive class)
        - 2D array for multiclass (probabilities for each class)
    class_names : list or array-like
        Names of the classes in order.
        For binary classification with 1D y_proba, only the positive class name is needed.

    Returns:
    --------
    pandas.DataFrame
        Combined DataFrame with features, probabilities, y_true, and y_pred
    """
    if sparse.issparse(X_test):
        X_df = pd.DataFrame(X_test.toarray())
    elif isinstance(X_test, np.ndarray):
        X_df = pd.DataFrame(X_test)
    elif hasattr(X_test, "to_pandas"):
        X_df = X_test.to_pandas()
    elif isinstance(X_test, pd.DataFrame):
        X_df = X_test.copy()
    else:
        raise TypeError(f"Unsupported type for X_test: {type(X_test)}")
    if X_df.columns.tolist() == list(range(len(X_df.columns))):
        X_df.columns = (
            [f"feature_{i}" for i in range(len(X_df.columns))]
            if features_names is None
            else features_names
        )
    if isinstance(y_true, list):
        y_true_series = pd.Series(y_true, name="y_true")
    elif isinstance(y_true, np.ndarray):
        y_true_series = pd.Series(y_true, name="y_true")
    elif hasattr(y_true, "to_pandas"):
        y_true_series = y_true.to_pandas()
        y_true_series.name = "y_true"
    elif isinstance(y_true, pd.Series):
        y_true_series = y_true.copy()
        y_true_series.name = "y_true"
    else:
        raise TypeError(f"Unsupported type for y_true: {type(y_true)}")
    if isinstance(y_pred, list):
        y_pred_series = pd.Series(y_pred, name="y_pred")
    elif isinstance(y_pred, np.ndarray):
        y_pred_series = pd.Series(y_pred, name="y_pred")
    elif hasattr(y_pred, "to_pandas"):
        y_pred_series = y_pred.to_pandas()
        y_pred_series.name = "y_pred"
    elif isinstance(y_pred, pd.Series):
        y_pred_series = y_pred.copy()
        y_pred_series.name = "y_pred"
    else:
        raise TypeError(f"Unsupported type for y_pred: {type(y_pred)}")
    if isinstance(y_proba, list):
        y_proba_array = np.array(y_proba)
    elif isinstance(y_proba, np.ndarray):
        y_proba_array = y_proba
    elif hasattr(y_proba, "to_pandas"):
        y_proba_array = y_proba.to_pandas().values
    elif isinstance(y_proba, pd.Series):
        y_proba_array = y_proba.values
    elif isinstance(y_proba, pd.DataFrame):
        y_proba_array = y_proba.values
    else:
        raise TypeError(f"Unsupported type for y_proba: {type(y_proba)}")

    def sanitize_class_name(class_name):
        """Convert class name to valid column name by replacing spaces and special chars"""
        return str(class_name).replace(" ", "_").replace("-", "_")

    if y_proba_array.ndim == 1:
        y_proba_df = pd.DataFrame({"proba": y_proba_array})
    else:
        n_classes = y_proba_array.shape[1]
        if len(class_names) == n_classes:
            proba_columns = [f"proba_{sanitize_class_name(cls)}" for cls in class_names]
        else:
            proba_columns = [f"proba_{i}" for i in range(n_classes)]
        y_proba_df = pd.DataFrame(y_proba_array, columns=proba_columns)
    y_proba_df = y_proba_df.reset_index(drop=True)
    X_df = X_df.reset_index(drop=True)
    y_true_series = y_true_series.reset_index(drop=True)
    y_pred_series = y_pred_series.reset_index(drop=True)
    is_false_prediction = pd.Series(
        y_true_series != y_pred_series, name="is_false_prediction"
    ).reset_index(drop=True)
    result_df = pd.concat(
        [X_df, y_proba_df, y_true_series, y_pred_series, is_false_prediction], axis=1
    )
    return result_df


def _smart_shap_background(
    X: Union[np.ndarray, pd.DataFrame],
    model_type: str = "tree",
    seed: int = 42,
    verbose: bool = True,
) -> Union[np.ndarray, pd.DataFrame, object]:
    """
    Intelligently prepares a background dataset for SHAP based on model type.

    Strategies:
    - Tree: Higher sample cap (1000), uses Random Sampling (preserves data structure).
    - Other: Lower sample cap (100), uses K-Means (maximizes info density).
    """
    (n_rows, n_features) = X.shape
    if model_type == "tree":
        max_samples = 1000
        use_kmeans = False
    else:
        max_samples = 100
        use_kmeans = True
    if n_rows <= max_samples:
        verbose and logger.info(
            f"✓ Dataset small ({n_rows} <= {max_samples}). Using full data."
        )
        return X
    verbose and logger.info(
        f"⚡ Large dataset detected ({n_rows} rows). Optimization Strategy: {('K-Means' if use_kmeans else 'Random Sampling')}"
    )
    if use_kmeans:
        try:
            verbose and logger.info(
                f"   Summarizing to {max_samples} weighted centroids..."
            )
            return shap.kmeans(X, max_samples)
        except Exception as e:
            logger.warning(
                f"   K-Means failed ({str(e)}). Falling back to random sampling."
            )
            return shap.sample(X, max_samples, random_state=seed)
    else:
        verbose and logger.info(f"   Sampling {max_samples} random rows...")
        return shap.sample(X, max_samples, random_state=seed)


def _class_index_df(model):
    columns = {"index": pd.Series(dtype="int64"), "class": pd.Series(dtype="object")}
    if model is None:
        return pd.DataFrame(columns)
    classes = getattr(model, "classes_", None)
    if classes is None:
        return pd.DataFrame(columns)
    return pd.DataFrame({"index": range(len(classes)), "class": classes})


def train_cla_svm(
    X: DataFrame, y: Union[DataSeries, NDArray, List], options=None
) -> Tuple[
    Any,
    DataFrame,
    Any,
    Any,
    Union[DataFrame, Dict],
    DataFrame,
    DataFrame,
    DataFrame,
    Dict,
]:
    options = options or {}
    kernel = options.get("kernel", "rbf").lower()
    kernel = "poly" if kernel == "polynomial" else kernel
    C_value = options.get("c_value", 1.0)
    degree = options.get("degree", 3)
    gamma = options.get("gamma", "scale").lower()
    gamma_float = options.get("gamma_float", 0.1)
    tol = options.get("tol", 0.001)
    max_iter = options.get("max_iterations", -1)
    if gamma == "numerical":
        gamma = gamma_float
    standard_scaling = options.get("standard_scaling", True)
    auto_split = options.get("auto_split", True)
    test_val_size = options.get("test_val_size", 15) / 100
    shuffle_split = options.get("shuffle_split", True)
    stratify_split = options.get("stratify_split", True)
    retrain_on_full = options.get("retrain_on_full", False)
    custom_average_strategy = options.get("custom_average_strategy", "auto")
    use_cross_validation = options.get("use_cross_validation", False)
    cv_folds = options.get("cv_folds", 5)
    use_hpo = options.get("use_hyperparameter_optimization", False)
    optimization_metric = options.get("optimization_metric", "F1 Score")
    optimization_method = options.get("optimization_method", "Tree-structured Parzen")
    optimization_iterations = options.get("optimization_iterations", 50)
    pos_label_option = options.get("pos_label", "").strip()
    if pos_label_option == "":
        pos_label_option = None
    return_shap_explainer = options.get("return_shap_explainer", False)
    use_shap_sampler = options.get("use_shap_sampler", False)
    metrics_as = options.get("metrics_as", "Dataframe")
    n_jobs_str = options.get("n_jobs", "1")
    random_state = options.get("random_state", 42)
    activate_caching = options.get("activate_caching", False)
    verbose = options.get("verbose", True)
    n_jobs_int = -1 if n_jobs_str == "All" else int(n_jobs_str)
    skip_computation = False
    Scaler = None
    Model = None
    Metrics = pd.DataFrame()
    CV_Metrics = pd.DataFrame()
    SHAP = None
    HPO_Trials = pd.DataFrame()
    HPO_Best = None
    accuracy = None
    precision = None
    recall = None
    f1 = None
    roc_auc = None
    (n_samples, n_features, sparsity, is_sparse) = _get_shape_and_sparsity(X)
    shap_feature_names = _ensure_feature_names(X)
    if standard_scaling:
        verbose and logger.info("Standard scaling is activated")
    if activate_caching:
        verbose and logger.info(f"Caching is activate")
        data_hasher = _UniversalDatasetHasher(n_samples, verbose=verbose)
        X_hash = data_hasher.hash_data(X).hash
        y_hash = data_hasher.hash_data(y).hash
        all_hash_base_text = f"HASH BASE TEXTPandas Version {pd.__version__}POLARS Version {pl.__version__}Numpy Version {np.__version__}Scikit Learn Version {sklearn.__version__}Scipy Version {scipy.__version__}{('SHAP Version ' + shap.__version__ if return_shap_explainer else 'NO SHAP Version')}{X_hash}{y_hash}{kernel}{C_value}{degree}{gamma}{gamma_float}{tol}{max_iter}{('Use HPO' if use_hpo else 'No HPO')}{(optimization_metric if use_hpo else 'No HPO Metric')}{(optimization_method if use_hpo else 'No HPO Method')}{(optimization_iterations if use_hpo else 'No HPO Iter')}{(cv_folds if use_cross_validation else 'No CV')}{standard_scaling}{('Auto Split' if auto_split else test_val_size)}{shuffle_split}{stratify_split}{return_shap_explainer}{use_shap_sampler}{random_state}{(pos_label_option if pos_label_option else 'default_pos')}{custom_average_strategy}"
        all_hash = hashlib.sha256(all_hash_base_text.encode("utf-8")).hexdigest()
        verbose and logger.info(f"Hash was computed: {all_hash}")
        temp_folder = Path(tempfile.gettempdir())
        cache_folder = temp_folder / "coded-flows-cache"
        cache_folder.mkdir(parents=True, exist_ok=True)
        model_path = cache_folder / f"{all_hash}.model"
        metrics_dict_path = cache_folder / f"metrics_{all_hash}.json"
        metrics_df_path = cache_folder / f"metrics_{all_hash}.parquet"
        cv_metrics_path = cache_folder / f"cv_metrics_{all_hash}.parquet"
        hpo_trials_path = cache_folder / f"hpo_trials_{all_hash}.parquet"
        hpo_best_params_path = cache_folder / f"hpo_best_params_{all_hash}.json"
        prediction_set_path = cache_folder / f"prediction_set_{all_hash}.parquet"
        shap_path = cache_folder / f"{all_hash}.shap"
        scaler_path = cache_folder / f"{all_hash}.scaler"
        skip_computation = model_path.is_file()
    if not skip_computation:
        try:
            _validate_numerical_data(X)
        except Exception as e:
            verbose and logger.error(
                f"Only numerical or boolean types are allowed for 'X' input!"
            )
            raise
        if hasattr(y, "nunique"):
            n_classes = y.nunique()
        elif hasattr(y, "n_unique"):
            n_classes = y.n_unique()
        else:
            n_classes = len(np.unique(y))
        is_multiclass = n_classes > 2
        features_names = X.columns if hasattr(X, "columns") else None
        class_weight = _choose_class_weight_logreg(y)
        if class_weight == "balanced":
            verbose and logger.info(
                f"Class imbalance, using 'balanced' mode for class weight"
            )
        fixed_test_split = None if auto_split else test_val_size
        (X_train, X_test, y_train, y_test, val_ratio) = _smart_split(
            n_samples,
            X,
            y,
            random_state=random_state,
            shuffle=shuffle_split,
            stratify=y if stratify_split else None,
            fixed_test_split=fixed_test_split,
            verbose=verbose,
        )
        if custom_average_strategy == "auto":
            metric_average_strategy = _get_best_metric_average_strategy(y_test)
        else:
            metric_average_strategy = custom_average_strategy
        effective_pos_label = None
        if metric_average_strategy == "binary":
            unique_classes = np.unique(y_train)
            if pos_label_option is not None:
                if pos_label_option in unique_classes:
                    effective_pos_label = pos_label_option
                else:
                    try:
                        as_int = int(pos_label_option)
                        if as_int in unique_classes:
                            effective_pos_label = as_int
                    except ValueError:
                        pass
            elif 1 in unique_classes:
                effective_pos_label = 1
            elif "1" in unique_classes:
                effective_pos_label = "1"
            if effective_pos_label is not None:
                verbose and logger.info(f"Using positive label: {effective_pos_label}")
            elif effective_pos_label is None and metric_average_strategy == "binary":
                error_message = 'The target appears to be binary, but no positive label was provided and no "1" class exists in the label set.'
                verbose and logger.error(error_message)
                raise ValueError(error_message)
        if use_hpo:
            verbose and logger.info(f"Performing Hyperparameters Optimization")
            constant_hyperparameters = {
                "class_weight": class_weight,
                "max_iter": max_iter,
            }
            (HPO_Best, HPO_Trials) = _hyperparameters_optimization(
                X_train,
                y_train,
                constant_hyperparameters,
                optimization_metric,
                metric_average_strategy,
                val_ratio,
                shuffle_split,
                stratify_split,
                use_cross_validation,
                cv_folds,
                optimization_iterations,
                "maximize",
                optimization_method,
                random_state,
                n_jobs_int,
                verbose=verbose,
                pos_label=effective_pos_label,
            )
            HPO_Trials = _normalize_hpo_df(HPO_Trials)
            C_value = HPO_Best["C"]
            kernel = HPO_Best["kernel"]
            degree = HPO_Best.get("degree", 3)
            gamma_type = HPO_Best["gamma_type"]
            gamma = HPO_Best["gamma"] if gamma_type == "numerical" else gamma_type
            tol = HPO_Best["tol"]
        if standard_scaling:
            Scaler = StandardScaler().set_output(transform="pandas")
            X_train = Scaler.fit_transform(X_train)
        Model = SVC(
            C=C_value,
            kernel=kernel,
            degree=degree,
            gamma=gamma,
            tol=tol,
            max_iter=max_iter,
            class_weight=class_weight,
            probability=True,
            random_state=random_state,
        )
        if use_cross_validation and (not use_hpo):
            verbose and logger.info(
                f"Using Cross-Validation to measure performance metrics"
            )
            CV_Metrics = _perform_cross_validation(
                Model,
                X_train,
                y_train,
                cv_folds,
                metric_average_strategy,
                shuffle_split,
                random_state,
                n_jobs_int,
                verbose,
                pos_label=effective_pos_label,
            )
        Model.fit(X_train, y_train)
        y_pred = Model.predict(Scaler.transform(X_test) if standard_scaling else X_test)
        if is_multiclass:
            y_score = Model.predict_proba(
                Scaler.transform(X_test) if standard_scaling else X_test
            )
        elif effective_pos_label is not None:
            try:
                pos_idx = list(Model.classes_).index(effective_pos_label)
                y_score = Model.predict_proba(
                    Scaler.transform(X_test) if standard_scaling else X_test
                )[:, pos_idx]
            except ValueError:
                y_score = Model.predict_proba(
                    Scaler.transform(X_test) if standard_scaling else X_test
                )[:, 1]
        else:
            y_score = Model.predict_proba(
                Scaler.transform(X_test) if standard_scaling else X_test
            )[:, 1]
        score_params = {"average": metric_average_strategy, "zero_division": 0}
        if effective_pos_label is not None:
            score_params["pos_label"] = effective_pos_label
        accuracy = accuracy_score(y_test, y_pred)
        precision = precision_score(y_test, y_pred, **score_params)
        recall = recall_score(y_test, y_pred, **score_params)
        f1 = f1_score(y_test, y_pred, **score_params)
        if is_multiclass:
            roc_auc = roc_auc_score(
                y_test, y_score, multi_class="ovr", average=metric_average_strategy
            )
        else:
            roc_auc = roc_auc_score(y_test, y_score)
        if metrics_as == "Dataframe":
            Metrics = pd.DataFrame(
                {
                    "Metric": [
                        "Accuracy",
                        "Precision",
                        "Recall",
                        "F1-Score",
                        "ROC AUC",
                    ],
                    "Value": [accuracy, precision, recall, f1, roc_auc],
                }
            )
        else:
            Metrics = {
                "Accuracy": accuracy,
                "Precision": precision,
                "Recall": recall,
                "F1-Score": f1,
                "ROC AUC": roc_auc,
            }
        verbose and logger.info(f"Accuracy  : {accuracy:.4f}")
        verbose and logger.info(f"Precision : {precision:.4f}")
        verbose and logger.info(f"Recall    : {recall:.4f}")
        verbose and logger.info(f"F1-Score  : {f1:.4f}")
        verbose and logger.info(f"ROC-AUC   : {roc_auc:.4f}")
        Prediction_Set = _combine_test_data(
            X_test, y_test, y_pred, y_score, Model.classes_, features_names
        )
        verbose and logger.info(f"Prediction Set created")
        if retrain_on_full:
            verbose and logger.info(
                "Retraining model on full dataset for production deployment"
            )
            if standard_scaling:
                Scaler = StandardScaler().set_output(transform="pandas")
                X = Scaler.fit_transform(X)
            Model.fit(X, y)
            verbose and logger.info(
                "Model successfully retrained on full dataset. Reported metrics remain from original held-out test set."
            )
        if return_shap_explainer:
            SHAP = shap.KernelExplainer(
                Model.predict_proba,
                (
                    _smart_shap_background(
                        X if retrain_on_full else X_train,
                        model_type="other",
                        seed=random_state,
                        verbose=verbose,
                    )
                    if use_shap_sampler
                    else X if retrain_on_full else X_train
                ),
                feature_names=shap_feature_names,
                link="logit",
            )
            verbose and logger.info(f"SHAP explainer generated")
        if activate_caching:
            verbose and logger.info(f"Caching output elements")
            joblib.dump(Model, model_path)
            if isinstance(Metrics, dict):
                with metrics_dict_path.open("w", encoding="utf-8") as f:
                    json.dump(Metrics, f, ensure_ascii=False, indent=4)
            else:
                Metrics.to_parquet(metrics_df_path)
            if use_cross_validation and (not use_hpo):
                CV_Metrics.to_parquet(cv_metrics_path)
            if use_hpo:
                HPO_Trials.to_parquet(hpo_trials_path)
                with hpo_best_params_path.open("w", encoding="utf-8") as f:
                    json.dump(HPO_Best, f, ensure_ascii=False, indent=4)
            Prediction_Set.to_parquet(prediction_set_path)
            if return_shap_explainer:
                with shap_path.open("wb") as f:
                    joblib.dump(SHAP, f)
            joblib.dump(Scaler, scaler_path)
            verbose and logger.info(f"Caching done")
    else:
        verbose and logger.info(f"Skipping computations and loading cached elements")
        Model = joblib.load(model_path)
        verbose and logger.info(f"Model loaded")
        if metrics_dict_path.is_file():
            with metrics_dict_path.open("r", encoding="utf-8") as f:
                Metrics = json.load(f)
        else:
            Metrics = pd.read_parquet(metrics_df_path)
        verbose and logger.info(f"Metrics loaded")
        if use_cross_validation and (not use_hpo):
            CV_Metrics = pd.read_parquet(cv_metrics_path)
            verbose and logger.info(f"Cross Validation metrics loaded")
        if use_hpo:
            HPO_Trials = pd.read_parquet(hpo_trials_path)
            with hpo_best_params_path.open("r", encoding="utf-8") as f:
                HPO_Best = json.load(f)
            verbose and logger.info(
                f"Hyperparameters Optimization trials and best params loaded"
            )
        Prediction_Set = pd.read_parquet(prediction_set_path)
        verbose and logger.info(f"Prediction Set loaded")
        if return_shap_explainer:
            with shap_path.open("rb") as f:
                SHAP = joblib.load(f)
            verbose and logger.info(f"SHAP Explainer loaded")
        Scaler = joblib.load(scaler_path)
        verbose and logger.info(f"Standard Scaler loaded")
    Model_Classes = _class_index_df(Model)
    return (
        Model,
        Model_Classes,
        SHAP,
        Scaler,
        Metrics,
        CV_Metrics,
        Prediction_Set,
        HPO_Trials,
        HPO_Best,
    )

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • shap>=0.47.0
  • scikit-learn
  • pandas
  • numpy
  • torch
  • numba>=0.56.0
  • cmaes
  • optuna
  • scipy
  • polars
  • xxhash