Dim. Reduction PCA

Perform Principal Component Analysis and generate visualization, components, summary, and sorted feature loadings.

Processing

This brick performs Principal Component Analysis (PCA), a statistical method used to simplify complex datasets. It analyzes the numerical columns in your data to find the most significant patterns ("Principal Components") and reduces the number of dimensions while retaining as much information as possible.

Typically, this is used to:

  1. Reduce noise in data before machine learning.
  2. Visualize high-dimensional data (e.g., a table with 20 columns) in a simple 2D or 3D chart.
  3. Identify drivers, helping you understand which original features contribute most to the variance in your data.

The brick outputs the transformed data, a statistical summary, feature loadings, and a visualization of the results.
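
For intuition, the brick's core transformation is equivalent to the following minimal scikit-learn sketch (the toy data and column names are hypothetical):

import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Hypothetical toy data with three numeric measurements
df = pd.DataFrame({
    "height": [1.6, 1.7, 1.8, 1.9],
    "weight": [55, 68, 80, 92],
    "age": [23, 35, 41, 52],
})

# Standardize so every column contributes equally, then project onto 2 components
scaled = StandardScaler().fit_transform(df)
pca = PCA(n_components=2, random_state=42)
components = pca.fit_transform(scaled)  # shape (4, 2): PC1 and PC2 for each row

print(pca.explained_variance_ratio_)  # share of total variance captured by each PC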

Inputs

data
The dataset you want to analyze. This should contain the numerical columns you wish to reduce (e.g., measurements, financial metrics, counts). Non-numeric columns are automatically ignored unless specified otherwise.

Inputs Types

Input           Types
data            DataFrame, ArrowTable

You can check the list of supported types here: Available Type Hints.

Outputs

PCA
The trained Scikit-Learn PCA model object. This can be passed to other scripts or bricks if you need to apply the exact same transformation logic to new, future data.
Scaler
The Scikit-Learn StandardScaler object used to normalize the data before analysis. This contains the mean and variance statistics of the input data.
PCA Image
A visualization of the analysis (e.g., a scatter plot or biplot), returned as an image. The format depends on the Output Type option (default is a generic image array).
PCA Components
The transformed dataset. Each row corresponds to a row in your original data, but the columns are now the calculated Principal Components (PC1, PC2, etc.) instead of the original features.
PCA Summary
A statistical summary table containing details about each component, such as how much variance (information) it captures.
PCA Loadings
A detailed breakdown of how your original columns relate to the new Principal Components. It shows the "weight" (loading) of each feature for every component.

The PCA Components output contains the following specific data fields:

  • {id_column}: (Optional) The identifier column you specified in the options.
  • PC1: The first Principal Component (captures the most variance).
  • PC2: The second Principal Component.
  • PC{n}: Additional components up to the "Number of Components" selected.

The PCA Summary output contains the following specific data fields (a short sketch of how they are derived from the fitted model follows this list):

  • Component: The name of the component (PC1, PC2, etc.).
  • Eigenvalue: The variance captured by the component in absolute terms (the eigenvalue associated with that component).
  • Variance_Explained: The proportion of the dataset's total variance held by this component (e.g., 0.45 means 45%).
  • Cumulative_Variance: The running total of variance explained.
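
A minimal sketch of how these fields follow from a fitted scikit-learn PCA model (toy data is hypothetical):

import numpy as np
import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

df = pd.DataFrame({"height": [1.6, 1.7, 1.8, 1.9],
                   "weight": [55, 68, 80, 92],
                   "age": [23, 35, 41, 52]})
pca = PCA(n_components=2).fit(StandardScaler().fit_transform(df))

summary = pd.DataFrame({
    "Component": [f"PC{i + 1}" for i in range(pca.n_components_)],
    "Eigenvalue": pca.explained_variance_,               # variance of each component
    "Variance_Explained": pca.explained_variance_ratio_,
    "Cumulative_Variance": np.cumsum(pca.explained_variance_ratio_),
})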

Outputs Types

Output            Types
PCA               Any
Scaler            Any
PCA Image         MediaData, PILImage
PCA Components    DataFrame
PCA Summary       DataFrame
PCA Loadings      DataFrame

You can check the list of supported types here: Available Type Hints.

Options

The Dim. Reduction PCA brick provides the following configurable options:

Columns for PCA
Select specific columns to include in the analysis. If left empty, the brick automatically selects all numeric columns in the dataset.
Number of Components
The number of dimensions (Principal Components) to keep in the result.
  • Default (2): Reduces data to 2 dimensions, ideal for 2D plotting.
  • Higher values: Retains more information but makes visualization harder.
Scale Data
Determines whether the data should be standardized (centered to mean 0, variance 1) before analysis; the short sketch after this option illustrates the effect.
  • True (Recommended): Ensures all columns contribute equally. Without this, a column with large numbers (e.g., Salary) would dominate a column with small numbers (e.g., Age).
  • False: Uses raw values. Only use this if your data is already on the same scale.
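
A short sketch (hypothetical columns) of why scaling matters; without it, the large-valued column dominates the first component:

import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

df = pd.DataFrame({"Salary": [30000, 45000, 60000, 90000],
                   "Age": [50, 23, 61, 35]})

raw = PCA(n_components=2).fit(df)
scaled = PCA(n_components=2).fit(StandardScaler().fit_transform(df))

print(raw.explained_variance_ratio_)     # ~[1.0, 0.0]: Salary's sheer scale dominates
print(scaled.explained_variance_ratio_)  # roughly even: both features now contribute
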
ID Column
The name of a column in your input data (e.g., "Customer_ID", "Product_Name") to preserve in the PCA_Components output. This allows you to map the mathematical results back to specific items.
Hue Column (for scatter)
The name of a categorical column in your input data to use for coloring the dots in the visualization. This helps identify clusters (e.g., coloring points by "Customer_Segment").
Exclude Hue
Determines if the column used for coloring (the Hue) should be part of the mathematical calculation. If the hue column is not numerical or boolean, it is automatically excluded.
  • True (Active): The Hue column is excluded from the dimensionality reduction. The algorithm ignores this data when calculating positions/clusters, using it only to assign colors in the final result. This is best when the Hue represents a label, outcome, or category (e.g., "Customer Type") that you want to visualize but don't want influencing the clustering logic.
  • False (Inactive): The Hue column is included in the dimensionality reduction. The values in this column will actively influence where the data points are positioned on the graph.
PC for X-axis
Which Principal Component to plot on the horizontal axis (default is 1).
PC for Y-axis
Which Principal Component to plot on the vertical axis (default is 2).
Plot Type
The style of the generated visualization.
  • scatter: A standard 2D chart showing the data points plotted against the selected components.
  • biplot: A scatter plot that also includes arrows (vectors) indicating the direction and strength of the original features.
  • variance: A bar chart showing how much information (variance) each component contains.
  • correlation_circle: A unit circle plot showing the correlations between original features and principal components, where each feature is represented as a point on or inside the circle.
  • all: Generates a composite image containing all four plots (scatter, variance, biplot, and correlation circle) in a 2×2 grid.
Color Palette
The color scheme used for the chart (e.g., "husl", "deep", "bright").
Output Type
The technical format of the generated PCA Image; a conversion sketch follows this list.
  • array: Returns a NumPy array (standard for image processing).
  • pil: Returns a PIL Image object.
  • bytes: Returns the raw image file bytes (PNG format).
  • bytesio: Returns a BytesIO stream.
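
For reference, a short sketch of how a rendered matplotlib figure maps to these formats (roughly mirroring what the brick does internally):

import io
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])

buf = io.BytesIO()
fig.savefig(buf, format="png", dpi=300, bbox_inches="tight")
buf.seek(0)

as_bytesio = buf                            # "bytesio": the stream itself
as_bytes = buf.getvalue()                   # "bytes": raw PNG file bytes
as_pil = Image.open(io.BytesIO(as_bytes))   # "pil": PIL Image object
as_array = np.array(as_pil)                 # "array": NumPy array (H, W, channels)
plt.close(fig)
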
Random State
A seed number to ensure the results are reproducible. Using the same number ensures the algorithm behaves exactly the same way every time it runs.
Brick Caching
If enabled, the brick saves results to a temporary cache. If you run the flow again with the exact same data and settings, it loads the result instantly instead of recalculating.
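
A simplified sketch of the fingerprinting idea behind the cache key: hash the data rows quickly, then combine that hash with the settings that affect the result (the exact key composition used by the brick appears in the code below):

import hashlib
import pandas as pd
import xxhash

df = pd.DataFrame({"a": [1, 2, 3], "b": [4.0, 5.0, 6.0]})

# Fast, vectorized hash of the rows (no string conversion needed)
row_hashes = pd.util.hash_pandas_object(df, index=False)
hasher = xxhash.xxh128()
hasher.update(memoryview(row_hashes.values))
data_hash = hasher.hexdigest()

# Combine the data hash with the options that influence the output
settings = f"{data_hash}|n_components=2|scale_data=True|random_state=42"
cache_key = hashlib.sha256(settings.encode("utf-8")).hexdigest()
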
Verbose
If enabled, detailed logs about the analysis process (e.g., variance explained, rows processed) will be printed to the console.
import logging
import io
import json
import xxhash
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import matplotlib
from scipy import sparse
from dataclasses import dataclass
from datetime import datetime
import hashlib
import tempfile
import sklearn
import scipy
import joblib
from pathlib import Path

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from sklearn.decomposition import PCA as _PCA
from sklearn.preprocessing import StandardScaler
from coded_flows.types import (
    Union,
    DataFrame,
    ArrowTable,
    MediaData,
    PILImage,
    Tuple,
    Any,
)
from coded_flows.utils import CodedFlowsLogger

logger = CodedFlowsLogger(name="Dim. Reduction PCA", level=logging.INFO)
DataType = Union[
    pd.DataFrame, pl.DataFrame, np.ndarray, sparse.spmatrix, pd.Series, pl.Series
]


@dataclass
class _DatasetFingerprint:
    """Lightweight fingerprint of a dataset."""

    hash: str
    shape: tuple
    computed_at: str
    data_type: str
    method: str


class _UniversalDatasetHasher:
    """
    High-performance dataset hasher optimizing for zero-copy operations
    and native backend execution (C/Rust).
    """

    def __init__(
        self,
        data_size: int,
        method: str = "auto",
        sample_size: int = 100000,
        verbose: bool = False,
    ):
        self.method = method
        self.sample_size = sample_size
        self.data_size = data_size
        self.verbose = verbose

    def hash_data(self, data: DataType) -> _DatasetFingerprint:
        """
        Main entry point: hash any supported data format.
        Auto-detects format and applies optimal strategy.
        """
        if isinstance(data, pd.DataFrame):
            return self._hash_pandas(data)
        elif isinstance(data, pl.DataFrame):
            return self._hash_polars(data)
        elif isinstance(data, pd.Series):
            return self._hash_pandas_series(data)
        elif isinstance(data, pl.Series):
            return self._hash_polars_series(data)
        elif isinstance(data, np.ndarray):
            return self._hash_numpy(data)
        elif sparse.issparse(data):
            return self._hash_sparse(data)
        else:
            raise TypeError(f"Unsupported data type: {type(data)}")

    def _hash_pandas(self, df: pd.DataFrame) -> _DatasetFingerprint:
        """
        Optimized Pandas hashing using pd.util.hash_pandas_object.
        Avoids object-to-string conversion overhead.
        """
        method = self._determine_method(self.data_size, self.method)
        self.verbose and logger.info(
            f"Hashing Pandas: {self.data_size:,} rows - {method}"
        )
        target_df = df
        if method == "sampled":
            target_df = self._get_pandas_sample(df)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns.tolist(),
                "dtypes": {k: str(v) for (k, v) in df.dtypes.items()},
                "shape": df.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(target_df, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception as e:
            self.verbose and logger.warning(
                f"Fast hash failed, falling back to slow hash: {e}"
            )
            self._hash_pandas_fallback(hasher, target_df)
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas",
            method=method,
        )

    def _get_pandas_sample(self, df: pd.DataFrame) -> pd.DataFrame:
        """Deterministic slicing for sampling (Zero randomness)."""
        if self.data_size <= self.sample_size:
            return df
        chunk = self.sample_size // 3
        head = df.iloc[:chunk]
        mid_idx = self.data_size // 2
        mid = df.iloc[mid_idx : mid_idx + chunk]
        tail = df.iloc[-chunk:]
        return pd.concat([head, mid, tail])

    def _hash_pandas_fallback(self, hasher, df: pd.DataFrame):
        """Legacy fallback for complex object types."""
        for col in df.columns:
            val = df[col].astype(str).values
            hasher.update(val.astype(np.bytes_).tobytes())

    def _hash_polars(self, df: pl.DataFrame) -> _DatasetFingerprint:
        """Optimized Polars hashing using native Rust execution."""
        method = self._determine_method(self.data_size, self.method)
        self.verbose and logger.info(
            f"Hashing Polars: {self.data_size:,} rows - {method}"
        )
        target_df = df
        if method == "sampled" and self.data_size > self.sample_size:
            indices = self._get_sample_indices(self.data_size, self.sample_size)
            target_df = df.gather(indices)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns,
                "dtypes": [str(t) for t in df.dtypes],
                "shape": df.shape,
            },
        )
        row_hashes = target_df.hash_rows()
        hasher.update(memoryview(row_hashes.to_numpy()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars",
            method=method,
        )

    def _hash_pandas_series(self, series: pd.Series) -> _DatasetFingerprint:
        """Hash Pandas Series using the fastest vectorized method."""
        self.verbose and logger.info(f"Hashing Pandas Series: {self.data_size:,} rows")
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "name": series.name if series.name else "None",
                "dtype": str(series.dtype),
                "shape": series.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(series, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception as e:
            self.verbose and logger.warning(f"Series hash failed, falling back: {e}")
            hasher.update(memoryview(series.astype(str).values.tobytes()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas_series",
            method="full",
        )

    def _hash_polars_series(self, series: pl.Series) -> _DatasetFingerprint:
        """Hash Polars Series using native Polars expressions."""
        self.verbose and logger.info(f"Hashing Polars Series: {self.data_size:,} rows")
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"name": series.name, "dtype": str(series.dtype), "shape": series.shape},
        )
        try:
            row_hashes = series.hash()
            hasher.update(memoryview(row_hashes.to_numpy()))
        except Exception as e:
            self.verbose and logger.warning(
                f"Polars series native hash failed, falling back: {e}"
            )
            hasher.update(str(series.to_list()).encode())
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars_series",
            method="full",
        )

    def _hash_numpy(self, arr: np.ndarray) -> _DatasetFingerprint:
        """Optimized NumPy hashing using Buffer Protocol (Zero-Copy)."""
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"shape": arr.shape, "dtype": str(arr.dtype), "strides": arr.strides},
        )
        if arr.flags["C_CONTIGUOUS"] or arr.flags["F_CONTIGUOUS"]:
            hasher.update(memoryview(arr))
        else:
            hasher.update(memoryview(np.ascontiguousarray(arr)))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=arr.shape,
            computed_at=datetime.now().isoformat(),
            data_type="numpy",
            method="full",
        )

    def _hash_sparse(self, matrix: sparse.spmatrix) -> _DatasetFingerprint:
        """Optimized sparse hashing."""
        if not (sparse.isspmatrix_csr(matrix) or sparse.isspmatrix_csc(matrix)):
            matrix = matrix.tocsr()
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher, {"shape": matrix.shape, "format": matrix.format, "nnz": matrix.nnz}
        )
        hasher.update(memoryview(matrix.data))
        hasher.update(memoryview(matrix.indices))
        hasher.update(memoryview(matrix.indptr))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=matrix.shape,
            computed_at=datetime.now().isoformat(),
            data_type=f"sparse_{matrix.format}",
            method="sparse",
        )

    def _determine_method(self, rows: int, requested: str) -> str:
        if requested != "auto":
            return requested
        if rows < 5000000:
            return "full"
        return "sampled"

    def _hash_schema(self, hasher, schema):
        hasher.update(
            json.dumps(schema, sort_keys=True, separators=(",", ":")).encode()
        )

    def _get_sample_indices(self, total_rows: int, sample_size: int) -> list:
        chunk = sample_size // 3
        indices = list(range(min(chunk, total_rows)))
        mid_start = max(0, total_rows // 2 - chunk // 2)
        mid_end = min(mid_start + chunk, total_rows)
        indices.extend(range(mid_start, mid_end))
        last_start = max(0, total_rows - chunk)
        indices.extend(range(last_start, total_rows))
        return sorted(list(set(indices)))


def _plot_scatter(
    ax, components, hue_data, hue_col, palette, explained_var, pc_x, pc_y
):
    """Plot PCA scatter plot."""
    pc_x_idx = pc_x - 1
    pc_y_idx = pc_y - 1
    if hue_data is not None:
        scatter_df = pd.DataFrame(
            {
                f"PC{pc_x}": components[:, pc_x_idx],
                f"PC{pc_y}": components[:, pc_y_idx],
                hue_col: hue_data,
            }
        )
        sns.scatterplot(
            data=scatter_df,
            x=f"PC{pc_x}",
            y=f"PC{pc_y}",
            hue=hue_col,
            palette=palette,
            ax=ax,
            alpha=0.7,
            s=50,
        )
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    else:
        ax.scatter(components[:, pc_x_idx], components[:, pc_y_idx], alpha=0.7, s=50)
    ax.set_xlabel(f"PC{pc_x} ({explained_var[pc_x_idx]:.1%} variance)")
    ax.set_ylabel(f"PC{pc_y} ({explained_var[pc_y_idx]:.1%} variance)")
    ax.set_title(f"PCA Scatter Plot: PC{pc_x} vs PC{pc_y}")
    ax.grid(True, alpha=0.3)


def _plot_biplot(
    ax,
    components,
    loadings,
    feature_names,
    explained_var,
    hue_data=None,
    hue_col=None,
    palette=None,
    pc_x=1,
    pc_y=2,
):
    """Plot PCA biplot with loadings."""
    pc_x_idx = pc_x - 1
    pc_y_idx = pc_y - 1
    if hue_data is not None:
        scatter_df = pd.DataFrame(
            {
                f"PC{pc_x}": components[:, pc_x_idx],
                f"PC{pc_y}": components[:, pc_y_idx],
                hue_col: hue_data,
            }
        )
        sns.scatterplot(
            data=scatter_df,
            x=f"PC{pc_x}",
            y=f"PC{pc_y}",
            hue=hue_col,
            palette=palette,
            ax=ax,
            alpha=0.5,
            s=30,
            legend=True,
        )
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    else:
        ax.scatter(components[:, pc_x_idx], components[:, pc_y_idx], alpha=0.5, s=30)
    scale = components[:, [pc_x_idx, pc_y_idx]].max() * 0.8
    for i, feature in enumerate(feature_names):
        ax.arrow(
            0,
            0,
            loadings[pc_x_idx, i] * scale,
            loadings[pc_y_idx, i] * scale,
            head_width=0.05,
            head_length=0.05,
            fc="red",
            ec="red",
            alpha=0.6,
        )
        ax.text(
            loadings[pc_x_idx, i] * scale * 1.1,
            loadings[pc_y_idx, i] * scale * 1.1,
            feature,
            fontsize=7,
            ha="center",
            va="center",
        )
    ax.set_xlabel(f"PC{pc_x} ({explained_var[pc_x_idx]:.1%} variance)")
    ax.set_ylabel(f"PC{pc_y} ({explained_var[pc_y_idx]:.1%} variance)")
    ax.set_title(f"PCA Biplot: PC{pc_x} vs PC{pc_y}")
    ax.grid(True, alpha=0.3)


def _plot_correlation_circle(
    ax, loadings, feature_names, explained_var, pc_x=1, pc_y=2, text_size=7
):
    pc_x_idx = pc_x - 1
    pc_y_idx = pc_y - 1
    x_loadings = loadings[pc_x_idx, :]
    y_loadings = loadings[pc_y_idx, :]
    circle_1 = plt.Circle(
        (0, 0), 1, color="#333333", fill=False, linestyle="-", linewidth=1.2, alpha=0.6
    )
    circle_05 = plt.Circle(
        (0, 0), 0.5, color="gray", fill=False, linestyle=":", linewidth=0.8, alpha=0.5
    )
    ax.add_artist(circle_1)
    ax.add_artist(circle_05)
    ax.axhline(0, color="black", linewidth=0.8, alpha=0.5)
    ax.axvline(0, color="black", linewidth=0.8, alpha=0.5)
    ticks = [-1.0, -0.75, -0.5, -0.25, 0.0, 0.25, 0.5, 0.75, 1.0]
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.grid(True, linestyle=":", linewidth=0.6, alpha=0.4, color="gray", zorder=0)
    ax.tick_params(axis="both", which="major", labelsize=8, colors="#666666")
    color_code = "#E74C3C"
    for i, (feature, x, y) in enumerate(zip(feature_names, x_loadings, y_loadings)):
        ax.plot([0, x], [0, y], color=color_code, linewidth=0.8, alpha=0.9, zorder=5)
        ax.scatter(x, y, color=color_code, s=4, zorder=6)
        text_x = x * 1.15
        text_y = y * 1.15
        ha = "center"
        if x > 0.1:
            ha = "left"
        elif x < -0.1:
            ha = "right"
        va = "center"
        if y > 0.1:
            va = "bottom"
        elif y < -0.1:
            va = "top"
        ax.text(
            text_x,
            text_y,
            feature,
            color="#333333",
            ha=ha,
            va=va,
            fontsize=text_size,
            fontweight="normal",
            zorder=7,
        )
    ax.set_xlim(-1, 1)
    ax.set_ylim(-1, 1)
    ax.set_aspect("equal")
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.spines["left"].set_visible(False)
    ax.spines["bottom"].set_visible(False)
    ax.set_xlabel(
        f"PC{pc_x} ({explained_var[pc_x_idx]:.1%} var)",
        fontsize=12,
        fontweight="bold",
        color="#444444",
    )
    ax.set_ylabel(
        f"PC{pc_y} ({explained_var[pc_y_idx]:.1%} var)",
        fontsize=12,
        fontweight="bold",
        color="#444444",
    )
    ax.set_title(
        f"Correlation Circle (PC{pc_x} vs PC{pc_y})",
        fontsize=14,
        pad=15,
        color="#333333",
        fontweight="bold",
    )


def _plot_variance(ax, explained_var):
    """Plot explained variance."""
    n_components = len(explained_var)
    cumulative_var = np.cumsum(explained_var)
    x = np.arange(1, n_components + 1)
    ax.bar(x, explained_var, alpha=0.6, label="Individual")
    ax.plot(x, cumulative_var, "ro-", linewidth=2, label="Cumulative")
    ax.set_xlabel("Principal Component")
    ax.set_ylabel("Explained Variance Ratio")
    ax.set_title("Explained Variance by Component")
    ax.set_xticks(x)
    ax.legend()
    ax.grid(True, alpha=0.3, axis="y")


def pca_analysis(
    data: Union[DataFrame, ArrowTable], options=None
) -> Tuple[Any, Any, Union[MediaData, PILImage], DataFrame, DataFrame, DataFrame]:
    options = options or {}
    verbose = options.get("verbose", True)
    random_state = options.get("random_state", 42)
    activate_caching = options.get("activate_caching", False)
    output_type = options.get("output_type", "array")
    columns = options.get("columns", None)
    n_components = options.get("n_components", 2)
    scale_data = options.get("scale_data", True)
    id_column = options.get("id_column", "")
    hue = options.get("hue", "")
    pc_x = options.get("pc_x", 1)
    pc_y = options.get("pc_y", 2)
    plot_type = options.get("plot_type", "scatter")
    exclude_hue = options.get("exclude_hue", True)
    palette = options.get("palette", "husl")
    dpi = 300
    verbose and logger.info(f"Starting PCA with {n_components} components")
    PCA = None
    PCA_Image = None
    PCA_Components = pd.DataFrame()
    PCA_Summary = pd.DataFrame()
    PCA_Loadings = pd.DataFrame()
    Scaler = None
    plot_components = None
    plot_loadings = None
    plot_explained_var = None
    hue_data = None
    hue_col = None
    df = None
    try:
        if isinstance(data, pl.DataFrame):
            verbose and logger.info(f"Converting polars DataFrame to pandas")
            df = data.to_pandas()
        elif isinstance(data, pa.Table):
            verbose and logger.info(f"Converting Arrow table to pandas")
            df = data.to_pandas()
        elif isinstance(data, pd.DataFrame):
            verbose and logger.info(f"Input is already pandas DataFrame")
            df = data
        else:
            raise ValueError(f"Unsupported data type: {type(data).__name__}")
    except Exception as e:
        raise RuntimeError(
            f"Failed to convert input data to pandas DataFrame: {e}"
        ) from e
    if df.empty:
        raise ValueError("Input DataFrame is empty")
    verbose and logger.info(
        f"Processing DataFrame with {df.shape[0]:,} rows × {df.shape[1]:,} columns"
    )
    if hue and hue.strip():
        if hue not in df.columns:
            raise ValueError(f"Hue column '{hue}' not found")
        hue_col = hue
    if id_column and id_column.strip():
        if id_column not in df.columns:
            raise ValueError(f"ID column '{id_column}' not found in DataFrame")
    if columns and len(columns) > 0:
        missing_cols = [col for col in columns if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Columns not found in DataFrame: {missing_cols}")
        feature_cols = list(columns)
    else:
        feature_cols = df.select_dtypes(include=["number", "bool"]).columns.tolist()
        if not feature_cols:
            raise ValueError("No numeric columns found in DataFrame")
    if id_column and id_column.strip() and (id_column in feature_cols):
        feature_cols.remove(id_column)
    if exclude_hue and hue_col and (hue_col in feature_cols):
        feature_cols.remove(hue_col)
    if not feature_cols:
        raise ValueError("No feature columns remaining after excluding ID column")
    skip_computation = False
    cache_folder = None
    all_hash = None
    if activate_caching:
        verbose and logger.info(f"Caching is active")
        data_hasher = _UniversalDatasetHasher(df.shape[0], verbose=verbose)
        X_hash = data_hasher.hash_data(df[feature_cols]).hash
        all_hash_base_text = f"HASH BASE TEXT PCA v7_corr_circle_stylePandas Version {pd.__version__}POLARS Version {pl.__version__}Numpy Version {np.__version__}Scikit Learn Version {sklearn.__version__}Scipy Version {scipy.__version__}{X_hash}{n_components}{scale_data}{random_state}{exclude_hue}{sorted(feature_cols)}"
        all_hash = hashlib.sha256(all_hash_base_text.encode("utf-8")).hexdigest()
        verbose and logger.info(f"Hash was computed: {all_hash}")
        temp_folder = Path(tempfile.gettempdir())
        cache_folder = temp_folder / "coded-flows-cache"
        cache_folder.mkdir(parents=True, exist_ok=True)
        pca_components_path = cache_folder / f"pca_components_{all_hash}.parquet"
        pca_summary_path = cache_folder / f"pca_summary_{all_hash}.parquet"
        pca_loadings_path = cache_folder / f"pca_loadings_{all_hash}.parquet"
        scaler_path = cache_folder / f"pca_scaler_{all_hash}.joblib"
        pca_model_path = cache_folder / f"pca_model_{all_hash}.joblib"
        if (
            pca_components_path.is_file()
            and pca_summary_path.is_file()
            and pca_loadings_path.is_file()
            and scaler_path.is_file()
            and pca_model_path.is_file()
        ):
            verbose and logger.info(f"Cache hit! Loading results.")
            try:
                PCA_Components = pd.read_parquet(pca_components_path)
                PCA_Summary = pd.read_parquet(pca_summary_path)
                PCA_Loadings = pd.read_parquet(pca_loadings_path)
                Scaler = joblib.load(scaler_path)
                PCA = joblib.load(pca_model_path)
                plot_explained_var = PCA_Summary["Variance_Explained"].values
                loading_cols = [
                    c for c in PCA_Summary.columns if c.startswith("Loading_")
                ]
                plot_loadings = PCA_Summary[loading_cols].values
                comps_for_plot = PCA_Components.copy()
                if id_column and id_column in comps_for_plot.columns:
                    comps_for_plot = comps_for_plot.drop(columns=[id_column])
                plot_components = comps_for_plot.values
                skip_computation = True
            except Exception as e:
                verbose and logger.warning(f"Cache load failed, recomputing: {e}")
                skip_computation = False
    if not skip_computation:
        X = df[feature_cols]
        if np.any(np.isnan(X)):
            verbose and logger.warning(f"Removing rows with missing values")
            mask = ~np.isnan(X).any(axis=1)
            X = X[mask]
            df_indices = df.index[mask]
        else:
            df_indices = df.index
        if X.shape[0] == 0:
            raise ValueError("No valid rows after removing missing values")
        if scale_data:
            verbose and logger.info(f"Scaling features")
            Scaler = StandardScaler()
            X_transformed = Scaler.fit_transform(X)
        else:
            X_transformed = X
        verbose and logger.info(f"Fitting PCA model")
        PCA = _PCA(n_components=n_components, random_state=random_state)
        plot_components = PCA.fit_transform(X_transformed)
        explained_var = PCA.explained_variance_ratio_
        eigenvalues = PCA.explained_variance_
        cumulative_var = np.cumsum(explained_var)
        plot_explained_var = explained_var
        plot_loadings = PCA.components_
        pc_columns = [f"PC{i + 1}" for i in range(n_components)]
        PCA_Components = pd.DataFrame(
            plot_components, columns=pc_columns, index=df_indices
        )
        if id_column and id_column.strip():
            id_data = df.loc[df_indices, id_column].values
            PCA_Components.insert(0, id_column, id_data)
        summary_data = {
            "Component": pc_columns,
            "Eigenvalue": eigenvalues,
            "Variance_Explained": explained_var,
            "Cumulative_Variance": cumulative_var,
        }
        for idx, feature in enumerate(feature_cols):
            summary_data[f"Loading_{feature}"] = PCA.components_[:, idx]
        PCA_Summary = pd.DataFrame(summary_data)
        PCA_Loadings = pd.DataFrame(PCA.components_.T, columns=pc_columns)
        PCA_Loadings.insert(0, "Feature", feature_cols)
        PCA_Loadings = PCA_Loadings.loc[
            PCA_Loadings["PC1"].abs().sort_values(ascending=False).index
        ]
        PCA_Loadings = PCA_Loadings.reset_index(drop=True)
        if activate_caching and cache_folder and all_hash:
            try:
                pca_components_path = (
                    cache_folder / f"pca_components_{all_hash}.parquet"
                )
                pca_summary_path = cache_folder / f"pca_summary_{all_hash}.parquet"
                pca_loadings_path = cache_folder / f"pca_loadings_{all_hash}.parquet"
                scaler_path = cache_folder / f"pca_scaler_{all_hash}.joblib"
                pca_model_path = cache_folder / f"pca_model_{all_hash}.joblib"
                PCA_Components.to_parquet(pca_components_path)
                PCA_Summary.to_parquet(pca_summary_path)
                PCA_Loadings.to_parquet(pca_loadings_path)
                if Scaler is not None:
                    joblib.dump(Scaler, scaler_path)
                if PCA is not None:
                    joblib.dump(PCA, pca_model_path)
                verbose and logger.info(f"Results saved to cache")
            except Exception as e:
                verbose and logger.warning(f"Failed to save cache: {e}")
    try:
        if hue_col:
            hue_data = df.loc[PCA_Components.index, hue].values
            verbose and logger.info(f"Using hue column: '{hue}'")
        verbose and logger.info(f"Creating {plot_type} visualization")
        if plot_components is None or plot_explained_var is None:
            raise RuntimeError("PCA plot data is missing (logic error)")
        if plot_type == "all":
            (fig, axes) = plt.subplots(2, 2, figsize=(16, 14))
            ax_list = axes.flatten()
            _plot_scatter(
                ax_list[0],
                plot_components,
                hue_data,
                hue_col,
                palette,
                plot_explained_var,
                pc_x,
                pc_y,
            )
            _plot_variance(ax_list[1], plot_explained_var)
            _plot_biplot(
                ax_list[2],
                plot_components,
                plot_loadings,
                feature_cols,
                plot_explained_var,
                hue_data,
                hue_col,
                palette,
                pc_x,
                pc_y,
            )
            _plot_correlation_circle(
                ax_list[3], plot_loadings, feature_cols, plot_explained_var, pc_x, pc_y
            )
            plt.tight_layout()
        else:
            (fig, ax) = plt.subplots(figsize=(10, 8))
            if plot_type == "scatter":
                _plot_scatter(
                    ax,
                    plot_components,
                    hue_data,
                    hue_col,
                    palette,
                    plot_explained_var,
                    pc_x,
                    pc_y,
                )
            elif plot_type == "biplot":
                _plot_biplot(
                    ax,
                    plot_components,
                    plot_loadings,
                    feature_cols,
                    plot_explained_var,
                    hue_data,
                    hue_col,
                    palette,
                    pc_x,
                    pc_y,
                )
            elif plot_type == "variance":
                _plot_variance(ax, plot_explained_var)
            elif plot_type == "correlation_circle":
                _plot_correlation_circle(
                    ax, plot_loadings, feature_cols, plot_explained_var, pc_x, pc_y
                )
            else:
                raise ValueError(f"Invalid plot_type: '{plot_type}'")
        verbose and logger.info(f"Rendering to {output_type} format with DPI={dpi}")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
        buf.seek(0)
        if output_type == "bytesio":
            PCA_Image = buf
        elif output_type == "bytes":
            PCA_Image = buf.getvalue()
            buf.close()
        elif output_type == "pil":
            PCA_Image = Image.open(buf)
            PCA_Image.load()
            buf.close()
        elif output_type == "array":
            img = Image.open(buf)
            PCA_Image = np.array(img)
            buf.close()
        else:
            raise ValueError(f"Invalid output_type: '{output_type}'")
        plt.close(fig)
    except (ValueError, RuntimeError):
        plt.close("all")
        raise
    except Exception as e:
        error_msg = f"Failed to perform PCA: {e}"
        verbose and logger.error(f"{error_msg}")
        plt.close("all")
        raise RuntimeError(error_msg) from e
    if PCA_Image is None:
        raise RuntimeError("PCA analysis returned empty result")
    verbose and logger.info(f"Successfully completed PCA analysis")
    return (PCA, Scaler, PCA_Image, PCA_Components, PCA_Summary, PCA_Loadings)
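
As a usage sketch, the function above can be called directly with a pandas DataFrame and an options dictionary using the option keys documented earlier (the column names and values here are hypothetical):

import pandas as pd

df = pd.DataFrame({
    "Customer_ID": ["C1", "C2", "C3", "C4"],
    "Segment": ["A", "A", "B", "B"],
    "Spend": [120.0, 95.5, 300.2, 410.8],
    "Visits": [4, 3, 9, 12],
    "Tenure": [12, 8, 30, 45],
})

options = {
    "n_components": 2,
    "scale_data": True,
    "id_column": "Customer_ID",
    "hue": "Segment",
    "exclude_hue": True,
    "plot_type": "biplot",
    "output_type": "pil",
    "verbose": False,
}

pca_model, scaler, image, components, summary, loadings = pca_analysis(df, options)
print(summary[["Component", "Variance_Explained", "Cumulative_Variance"]])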

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • shap>=0.47.0
  • scikit-learn
  • pandas
  • pyarrow
  • polars[pyarrow]
  • numpy
  • numba>=0.56.0
  • joblib
  • matplotlib
  • seaborn
  • pillow
  • xxhash