PCA Analysis

Perform Principal Component Analysis and generate a visualization.

Processing

This brick performs Principal Component Analysis (PCA) on the numerical columns of the input data. It optionally standardizes the features, computes component scores, explained variance, and feature loadings, and can cache results keyed on a hash of the input data and configuration. The results can be rendered as a scatter plot, biplot, or explained-variance plot in several output formats.
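
At its core, the computation is equivalent to the following sketch (column names and values are placeholders): standardize the selected numeric columns, fit scikit-learn's PCA, and read off the scores, explained variance, and loadings.

import pandas as pd
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

# Toy data with placeholder column names.
df = pd.DataFrame({"a": [1.0, 2.0, 3.0, 4.0], "b": [2.0, 1.5, 4.0, 3.5]})

X = StandardScaler().fit_transform(df[["a", "b"]])  # optional scaling step
pca = PCA(n_components=2, random_state=42)
scores = pca.fit_transform(X)               # -> PCA Components (PC1, PC2, ...)
variance = pca.explained_variance_ratio_    # -> Variance_Explained column
loadings = pca.components_                  # -> Loading_[feature] columns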

Inputs

data
The input table or DataFrame containing numerical features suitable for PCA. The supported formats are Pandas DataFrame, Polars DataFrame, or Apache Arrow Table.

Input Types

data: DataFrame, ArrowTable

You can check the list of supported types here: Available Type Hints.
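
For illustration, any of the following constructions is a valid data input; Polars and Arrow inputs are converted to pandas internally (the column names are placeholders):

import pandas as pd
import polars as pl
import pyarrow as pa

records = {"x": [1.0, 2.0, 3.0], "y": [4.0, 5.0, 6.0]}
data_pandas = pd.DataFrame(records)  # pandas DataFrame
data_polars = pl.DataFrame(records)  # Polars DataFrame
data_arrow = pa.table(records)       # Apache Arrow Table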

Outputs

Scaler
The fitted scikit-learn StandardScaler used to standardize the data before PCA. Returns None if scaling is disabled.
PCA Image
The generated visualization (scatter plot, biplot, or variance plot) in the format specified by the 'Output Type' option.
PCA Components
A DataFrame containing the transformed data (PCA scores), where each row is an observation and columns represent the principal components. The original index is preserved.
PCA Summary
A DataFrame detailing the PCA results, including explained variance ratios, eigenvalues, and feature loadings for each component.

The PCA Components output DataFrame contains:

  • ID Column: If specified, the column containing identifiers.
  • PC1, PC2, ..., PCN: The coordinates (scores) of the data points projected onto the Principal Components.

The PCA Summary output DataFrame contains:

  • Component: The name of the Principal Component (e.g., PC1).
  • Eigenvalue: The variance of the component.
  • Variance_Explained: The ratio of total variance explained by this component.
  • Cumulative_Variance: The running total of variance explained up to this component.
  • Loading_[Feature Name]: Columns representing the loading of each original feature onto the respective component.
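
For example, the loadings and variance figures can be pulled back out of the summary like this (a sketch with made-up numbers and placeholder feature names):

import pandas as pd

# A tiny PCA Summary in the shape described above (values are illustrative).
PCA_Summary = pd.DataFrame(
    {
        "Component": ["PC1", "PC2"],
        "Eigenvalue": [2.1, 0.7],
        "Variance_Explained": [0.70, 0.23],
        "Cumulative_Variance": [0.70, 0.93],
        "Loading_height": [0.71, -0.11],
        "Loading_weight": [0.70, 0.09],
    }
)

# Loading matrix: rows are components, columns are the original features.
loading_cols = [c for c in PCA_Summary.columns if c.startswith("Loading_")]
loadings = PCA_Summary.set_index("Component")[loading_cols]

# Share of total variance captured by the retained components.
total_captured = PCA_Summary["Cumulative_Variance"].iloc[-1]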

Output Types

Scaler: Any
PCA Image: MediaData, PILImage, DataFrame
PCA Components: DataFrame
PCA Summary: DataFrame

You can check the list of supported types here: Available Type Hints.

Options

The PCA Analysis brick exposes the following configurable options:

Columns for PCA
Defines which specific numerical columns should be used for the PCA calculation. If left empty, all numerical columns are used.
Number of Components
Sets the number of principal components to extract and analyze. Must be at least 2. (Default: 2)
Scale Data
If enabled, features are standardized (zero mean, unit variance) before PCA is performed. (Default: True)
ID Column
Specifies a non-numeric column (e.g., identifiers) to be attached to the resulting PCA_Components DataFrame. (optional)
Hue Column (for scatter)
Specifies a categorical or numerical column to use for coloring points in the scatter and biplots. (optional)
PC for X-axis
The index of the principal component to display on the X-axis in the scatter and biplots (e.g., 1 for PC1). (Default: 1)
PC for Y-axis
The index of the principal component to display on the Y-axis in the scatter and biplots (e.g., 2 for PC2). (Default: 2)
Plot Type
Determines the visualization generated: scatter (scores plot), biplot (scores and loadings), variance (scree plot), or all (a single image containing all three plots). (Default: scatter)
Color Palette
The color palette to use when coloring points based on the Hue Column. (Default: husl)
Output Type
Defines the format of the output visualization (PCA Image). Choices are array (NumPy array), pil (PIL Image object), bytes (raw PNG bytes), or bytesio (BytesIO buffer). (Default: array)
Random State
Sets the random seed passed to the PCA solver, ensuring reproducible results when a randomized SVD solver is used. (Default: 42)
Brick Caching
Activates caching of the computed model results (PCA Components, PCA Summary, Scaler) based on the input data and configuration parameters, speeding up subsequent identical runs. (Default: False)
Verbose
Toggles detailed logging messages during processing. (Default: True)
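
Internally, these options map to keys of the options dictionary passed to the brick function. A sketch with example values (the column names are placeholders, and every key is optional):

options = {
    "columns": ["height", "weight", "age"],  # Columns for PCA
    "n_components": 3,                       # Number of Components
    "scale_data": True,                      # Scale Data
    "id_column": "sample_id",                # ID Column
    "hue": "group",                          # Hue Column (for scatter)
    "pc_x": 1,                               # PC for X-axis
    "pc_y": 2,                               # PC for Y-axis
    "plot_type": "biplot",                   # Plot Type
    "palette": "husl",                       # Color Palette
    "output_type": "pil",                    # Output Type
    "random_state": 42,                      # Random State
    "activate_caching": True,                # Brick Caching
    "verbose": True,                         # Verbose
}

The full source of the brick follows.
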
import logging
import io
import json
import xxhash
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import matplotlib
from scipy import sparse
from dataclasses import dataclass
from datetime import datetime
import hashlib
import tempfile
import sklearn
import scipy
import joblib
from pathlib import Path

matplotlib.use("Agg")  # non-interactive backend so figures render without a display
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
from coded_flows.types import (
    Union,
    DataFrame,
    ArrowTable,
    MediaData,
    PILImage,
    Tuple,
    Any,
)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
DataType = Union[
    pd.DataFrame, pl.DataFrame, np.ndarray, sparse.spmatrix, pd.Series, pl.Series
]


@dataclass
class _DatasetFingerprint:
    """Lightweight fingerprint of a dataset."""

    hash: str
    shape: tuple
    computed_at: str
    data_type: str
    method: str


class _UniversalDatasetHasher:
    """
    High-performance dataset hasher optimizing for zero-copy operations
    and native backend execution (C/Rust).
    """

    def __init__(
        self,
        data_size: int,
        method: str = "auto",
        sample_size: int = 100000,
        verbose: bool = False,
    ):
        self.method = method
        self.sample_size = sample_size
        self.data_size = data_size
        self.verbose = verbose

    def hash_data(self, data: DataType) -> _DatasetFingerprint:
        """
        Main entry point: hash any supported data format.
        Auto-detects format and applies optimal strategy.
        """
        if isinstance(data, pd.DataFrame):
            return self._hash_pandas(data)
        elif isinstance(data, pl.DataFrame):
            return self._hash_polars(data)
        elif isinstance(data, pd.Series):
            return self._hash_pandas_series(data)
        elif isinstance(data, pl.Series):
            return self._hash_polars_series(data)
        elif isinstance(data, np.ndarray):
            return self._hash_numpy(data)
        elif sparse.issparse(data):
            return self._hash_sparse(data)
        else:
            raise TypeError(f"Unsupported data type: {type(data)}")

    def _hash_pandas(self, df: pd.DataFrame) -> _DatasetFingerprint:
        """
        Optimized Pandas hashing using pd.util.hash_pandas_object.
        Avoids object-to-string conversion overhead.
        """
        method = self._determine_method(self.data_size, self.method)
        self.verbose and logger.info(
            f"Hashing Pandas: {self.data_size:,} rows - {method}"
        )
        target_df = df
        if method == "sampled":
            target_df = self._get_pandas_sample(df)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns.tolist(),
                "dtypes": {k: str(v) for (k, v) in df.dtypes.items()},
                "shape": df.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(target_df, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception as e:
            self.verbose and logger.warning(
                f"Fast hash failed, falling back to slow hash: {e}"
            )
            self._hash_pandas_fallback(hasher, target_df)
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas",
            method=method,
        )

    def _get_pandas_sample(self, df: pd.DataFrame) -> pd.DataFrame:
        """Deterministic slicing for sampling (Zero randomness)."""
        if self.data_size <= self.sample_size:
            return df
        chunk = self.sample_size // 3
        head = df.iloc[:chunk]
        mid_idx = self.data_size // 2
        mid = df.iloc[mid_idx : mid_idx + chunk]
        tail = df.iloc[-chunk:]
        return pd.concat([head, mid, tail])

    def _hash_pandas_fallback(self, hasher, df: pd.DataFrame):
        """Legacy fallback for complex object types."""
        for col in df.columns:
            val = df[col].astype(str).values
            hasher.update(val.astype(np.bytes_).tobytes())

    def _hash_polars(self, df: pl.DataFrame) -> _DatasetFingerprint:
        """Optimized Polars hashing using native Rust execution."""
        method = self._determine_method(self.data_size, self.method)
        self.verbose and logger.info(
            f"Hashing Polars: {self.data_size:,} rows - {method}"
        )
        target_df = df
        if method == "sampled" and self.data_size > self.sample_size:
            indices = self._get_sample_indices(self.data_size, self.sample_size)
            target_df = df.select(pl.all().gather(indices))
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns,
                "dtypes": [str(t) for t in df.dtypes],
                "shape": df.shape,
            },
        )
        row_hashes = target_df.hash_rows()
        hasher.update(memoryview(row_hashes.to_numpy()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars",
            method=method,
        )

    def _hash_pandas_series(self, series: pd.Series) -> _DatasetFingerprint:
        """Hash Pandas Series using the fastest vectorized method."""
        self.verbose and logger.info(f"Hashing Pandas Series: {self.data_size:,} rows")
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "name": series.name if series.name else "None",
                "dtype": str(series.dtype),
                "shape": series.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(series, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception as e:
            self.verbose and logger.warning(f"Series hash failed, falling back: {e}")
            hasher.update(memoryview(series.astype(str).values.tobytes()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas_series",
            method="full",
        )

    def _hash_polars_series(self, series: pl.Series) -> _DatasetFingerprint:
        """Hash Polars Series using native Polars expressions."""
        self.verbose and logger.info(f"Hashing Polars Series: {self.data_size:,} rows")
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"name": series.name, "dtype": str(series.dtype), "shape": series.shape},
        )
        try:
            row_hashes = series.hash()
            hasher.update(memoryview(row_hashes.to_numpy()))
        except Exception as e:
            self.verbose and logger.warning(
                f"Polars series native hash failed, falling back: {e}"
            )
            hasher.update(str(series.to_list()).encode())
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars_series",
            method="full",
        )

    def _hash_numpy(self, arr: np.ndarray) -> _DatasetFingerprint:
        """Optimized NumPy hashing using Buffer Protocol (Zero-Copy)."""
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"shape": arr.shape, "dtype": str(arr.dtype), "strides": arr.strides},
        )
        if arr.flags["C_CONTIGUOUS"] or arr.flags["F_CONTIGUOUS"]:
            hasher.update(memoryview(arr))
        else:
            hasher.update(memoryview(np.ascontiguousarray(arr)))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=arr.shape,
            computed_at=datetime.now().isoformat(),
            data_type="numpy",
            method="full",
        )

    def _hash_sparse(self, matrix: sparse.spmatrix) -> _DatasetFingerprint:
        """Optimized sparse hashing."""
        if not (sparse.isspmatrix_csr(matrix) or sparse.isspmatrix_csc(matrix)):
            matrix = matrix.tocsr()
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher, {"shape": matrix.shape, "format": matrix.format, "nnz": matrix.nnz}
        )
        hasher.update(memoryview(matrix.data))
        hasher.update(memoryview(matrix.indices))
        hasher.update(memoryview(matrix.indptr))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=matrix.shape,
            computed_at=datetime.now().isoformat(),
            data_type=f"sparse_{matrix.format}",
            method="sparse",
        )

    def _determine_method(self, rows: int, requested: str) -> str:
        if requested != "auto":
            return requested
        if rows < 5000000:
            return "full"
        return "sampled"

    def _hash_schema(self, hasher, schema):
        hasher.update(
            json.dumps(schema, sort_keys=True, separators=(",", ":")).encode()
        )

    def _get_sample_indices(self, total_rows: int, sample_size: int) -> list:
        chunk = sample_size // 3
        indices = list(range(min(chunk, total_rows)))
        mid_start = max(0, total_rows // 2 - chunk // 2)
        mid_end = min(mid_start + chunk, total_rows)
        indices.extend(range(mid_start, mid_end))
        last_start = max(0, total_rows - chunk)
        indices.extend(range(last_start, total_rows))
        return sorted(list(set(indices)))


def pca_analysis(
    data: Union[DataFrame, ArrowTable], options=None
) -> Tuple[Any, Union[MediaData, PILImage, DataFrame], DataFrame, DataFrame]:
    brick_display_name = "pca_analysis"
    options = options or {}
    verbose = options.get("verbose", True)
    random_state = options.get("random_state", 42)
    activate_caching = options.get("activate_caching", False)
    output_type = options.get("output_type", "array")
    columns = options.get("columns", None)
    n_components = options.get("n_components", 2)
    scale_data = options.get("scale_data", True)
    id_column = options.get("id_column", "")
    hue = options.get("hue", "")
    pc_x = options.get("pc_x", 1)
    pc_y = options.get("pc_y", 2)
    plot_type = options.get("plot_type", "scatter")
    palette = options.get("palette", "husl")
    dpi = 300
    verbose and logger.info(
        f"[{brick_display_name}] Starting PCA with {n_components} components"
    )
    PCA_Image = None
    PCA_Components = pd.DataFrame()
    PCA_Summary = pd.DataFrame()
    Scaler = None
    plot_components = None
    plot_loadings = None
    plot_explained_var = None
    hue_data = None
    hue_col = None
    df = None
    try:
        if isinstance(data, pl.DataFrame):
            verbose and logger.info(
                f"[{brick_display_name}] Converting polars DataFrame to pandas"
            )
            df = data.to_pandas()
        elif isinstance(data, pa.Table):
            verbose and logger.info(
                f"[{brick_display_name}] Converting Arrow table to pandas"
            )
            df = data.to_pandas()
        elif isinstance(data, pd.DataFrame):
            verbose and logger.info(
                f"[{brick_display_name}] Input is already pandas DataFrame"
            )
            df = data
        else:
            raise ValueError(f"Unsupported data type: {type(data).__name__}")
    except Exception as e:
        raise RuntimeError(
            f"Failed to convert input data to pandas DataFrame: {e}"
        ) from e
    if df.empty:
        raise ValueError("Input DataFrame is empty")
    verbose and logger.info(
        f"[{brick_display_name}] Processing DataFrame with {df.shape[0]:,} rows × {df.shape[1]:,} columns"
    )
    if id_column and id_column.strip():
        if id_column not in df.columns:
            raise ValueError(f"ID column '{id_column}' not found in DataFrame")
    if columns and len(columns) > 0:
        missing_cols = [col for col in columns if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Columns not found in DataFrame: {missing_cols}")
        feature_cols = list(columns)
    else:
        feature_cols = df.select_dtypes(include=[np.number]).columns.tolist()
        if not feature_cols:
            raise ValueError("No numeric columns found in DataFrame")
    if id_column and id_column.strip() and (id_column in feature_cols):
        feature_cols.remove(id_column)
    if not feature_cols:
        raise ValueError("No feature columns remaining after excluding ID column")
    skip_computation = False
    cache_folder = None
    all_hash = None
    if activate_caching:
        verbose and logger.info(f"[{brick_display_name}] Caching is active")
        data_hasher = _UniversalDatasetHasher(df.shape[0], verbose=verbose)
        X_hash = data_hasher.hash_data(df[feature_cols]).hash
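        # Build the cache key from the data hash, library versions, and PCA settings,
        # so a change to any of them invalidates previous cache entries.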
        all_hash_base_text = f"HASH BASE TEXT PCAPandas Version {pd.__version__}POLARS Version {pl.__version__}Numpy Version {np.__version__}Scikit Learn Version {sklearn.__version__}Scipy Version {scipy.__version__}{X_hash}{n_components}{scale_data}{random_state}{sorted(feature_cols)}"
        all_hash = hashlib.sha256(all_hash_base_text.encode("utf-8")).hexdigest()
        verbose and logger.info(f"[{brick_display_name}] Hash was computed: {all_hash}")
        temp_folder = Path(tempfile.gettempdir())
        cache_folder = temp_folder / "coded-flows-cache"
        cache_folder.mkdir(parents=True, exist_ok=True)
        pca_components_path = cache_folder / f"pca_components_{all_hash}.parquet"
        pca_summary_path = cache_folder / f"pca_summary_{all_hash}.parquet"
        scaler_path = cache_folder / f"pca_scaler_{all_hash}.joblib"
        if (
            pca_components_path.is_file()
            and pca_summary_path.is_file()
            and scaler_path.is_file()
        ):
            verbose and logger.info(
                f"[{brick_display_name}] Cache hit! Loading results."
            )
            try:
                PCA_Components = pd.read_parquet(pca_components_path)
                PCA_Summary = pd.read_parquet(pca_summary_path)
                Scaler = joblib.load(scaler_path)
                plot_explained_var = PCA_Summary["Variance_Explained"].values
                loading_cols = [
                    c for c in PCA_Summary.columns if c.startswith("Loading_")
                ]
                plot_loadings = PCA_Summary[loading_cols].values
                comps_for_plot = PCA_Components.copy()
                if id_column and id_column in comps_for_plot.columns:
                    comps_for_plot = comps_for_plot.drop(columns=[id_column])
                plot_components = comps_for_plot.values
                skip_computation = True
            except Exception as e:
                verbose and logger.warning(
                    f"[{brick_display_name}] Cache load failed, recomputing: {e}"
                )
                skip_computation = False
    if not skip_computation:
        X = df[feature_cols].values
        if np.any(np.isnan(X)):
            verbose and logger.warning(
                f"[{brick_display_name}] Removing rows with missing values"
            )
            mask = ~np.isnan(X).any(axis=1)
            X = X[mask]
            df_indices = df.index[mask]
        else:
            df_indices = df.index
        if X.shape[0] == 0:
            raise ValueError("No valid rows after removing missing values")
        if scale_data:
            verbose and logger.info(f"[{brick_display_name}] Scaling features")
            Scaler = StandardScaler()
            X_transformed = Scaler.fit_transform(X)
        else:
            X_transformed = X
        verbose and logger.info(f"[{brick_display_name}] Fitting PCA model")
        pca = PCA(n_components=n_components, random_state=random_state)
        plot_components = pca.fit_transform(X_transformed)
        explained_var = pca.explained_variance_ratio_
        eigenvalues = pca.explained_variance_
        cumulative_var = np.cumsum(explained_var)
        plot_explained_var = explained_var
        plot_loadings = pca.components_
        pc_columns = [f"PC{i + 1}" for i in range(n_components)]
        PCA_Components = pd.DataFrame(
            plot_components, columns=pc_columns, index=df_indices
        )
        if id_column and id_column.strip():
            id_data = df.loc[df_indices, id_column].values
            PCA_Components.insert(0, id_column, id_data)
        summary_data = {
            "Component": pc_columns,
            "Eigenvalue": eigenvalues,
            "Variance_Explained": explained_var,
            "Cumulative_Variance": cumulative_var,
        }
        for idx, feature in enumerate(feature_cols):
            summary_data[f"Loading_{feature}"] = pca.components_[:, idx]
        PCA_Summary = pd.DataFrame(summary_data)
        if activate_caching and cache_folder and all_hash:
            try:
                pca_components_path = (
                    cache_folder / f"pca_components_{all_hash}.parquet"
                )
                pca_summary_path = cache_folder / f"pca_summary_{all_hash}.parquet"
                scaler_path = cache_folder / f"pca_scaler_{all_hash}.joblib"
                PCA_Components.to_parquet(pca_components_path)
                PCA_Summary.to_parquet(pca_summary_path)
                if Scaler is not None:
                    joblib.dump(Scaler, scaler_path)
                verbose and logger.info(
                    f"[{brick_display_name}] Results saved to cache"
                )
            except Exception as e:
                verbose and logger.warning(
                    f"[{brick_display_name}] Failed to save cache: {e}"
                )
    try:
        if hue and hue.strip():
            if hue not in df.columns:
                raise ValueError(f"Hue column '{hue}' not found in DataFrame")
            hue_col = hue
            hue_data = df.loc[PCA_Components.index, hue].values
            verbose and logger.info(f"[{brick_display_name}] Using hue column: '{hue}'")
        verbose and logger.info(
            f"[{brick_display_name}] Creating {plot_type} visualization"
        )
        if plot_components is None or plot_explained_var is None:
            raise RuntimeError("PCA plot data is missing (logic error)")
        if plot_type == "all":
            (fig, axes) = plt.subplots(1, 3, figsize=(18, 5))
            _plot_scatter(
                axes[0],
                plot_components,
                hue_data,
                hue_col,
                palette,
                plot_explained_var,
                pc_x,
                pc_y,
            )
            _plot_biplot(
                axes[1],
                plot_components,
                plot_loadings,
                feature_cols,
                plot_explained_var,
                hue_data,
                hue_col,
                palette,
                pc_x,
                pc_y,
            )
            _plot_variance(axes[2], plot_explained_var)
            plt.tight_layout()
        else:
            (fig, ax) = plt.subplots(figsize=(10, 7))
            if plot_type == "scatter":
                _plot_scatter(
                    ax,
                    plot_components,
                    hue_data,
                    hue_col,
                    palette,
                    plot_explained_var,
                    pc_x,
                    pc_y,
                )
            elif plot_type == "biplot":
                _plot_biplot(
                    ax,
                    plot_components,
                    plot_loadings,
                    feature_cols,
                    plot_explained_var,
                    hue_data,
                    hue_col,
                    palette,
                    pc_x,
                    pc_y,
                )
            elif plot_type == "variance":
                _plot_variance(ax, plot_explained_var)
            else:
                raise ValueError(f"Invalid plot_type: '{plot_type}'")
        verbose and logger.info(
            f"[{brick_display_name}] Rendering to {output_type} format with DPI={dpi}"
        )
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
        buf.seek(0)
        if output_type == "bytesio":
            PCA_Image = buf
        elif output_type == "bytes":
            PCA_Image = buf.getvalue()
            buf.close()
        elif output_type == "pil":
            PCA_Image = Image.open(buf)
            PCA_Image.load()
            buf.close()
        elif output_type == "array":
            img = Image.open(buf)
            PCA_Image = np.array(img)
            buf.close()
        else:
            raise ValueError(f"Invalid output_type: '{output_type}'")
        plt.close(fig)
    except (ValueError, RuntimeError):
        plt.close("all")
        raise
    except Exception as e:
        error_msg = f"Failed to perform PCA: {e}"
        verbose and logger.error(f"[{brick_display_name}] {error_msg}")
        plt.close("all")
        raise RuntimeError(error_msg) from e
    if PCA_Image is None:
        raise RuntimeError("PCA analysis returned empty result")
    verbose and logger.info(
        f"[{brick_display_name}] Successfully completed PCA analysis"
    )
    return (Scaler, PCA_Image, PCA_Components, PCA_Summary)


def _plot_scatter(
    ax, components, hue_data, hue_col, palette, explained_var, pc_x, pc_y
):
    """Plot PCA scatter plot."""
    pc_x_idx = pc_x - 1
    pc_y_idx = pc_y - 1
    if hue_data is not None:
        scatter_df = pd.DataFrame(
            {
                f"PC{pc_x}": components[:, pc_x_idx],
                f"PC{pc_y}": components[:, pc_y_idx],
                hue_col: hue_data,
            }
        )
        sns.scatterplot(
            data=scatter_df,
            x=f"PC{pc_x}",
            y=f"PC{pc_y}",
            hue=hue_col,
            palette=palette,
            ax=ax,
            alpha=0.7,
            s=50,
        )
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    else:
        ax.scatter(components[:, pc_x_idx], components[:, pc_y_idx], alpha=0.7, s=50)
    ax.set_xlabel(f"PC{pc_x} ({explained_var[pc_x_idx]:.1%} variance)")
    ax.set_ylabel(f"PC{pc_y} ({explained_var[pc_y_idx]:.1%} variance)")
    ax.set_title(f"PCA Scatter Plot: PC{pc_x} vs PC{pc_y}")
    ax.grid(True, alpha=0.3)


def _plot_biplot(
    ax,
    components,
    loadings,
    feature_names,
    explained_var,
    hue_data=None,
    hue_col=None,
    palette=None,
    pc_x=1,
    pc_y=2,
):
    """Plot PCA biplot with loadings."""
    pc_x_idx = pc_x - 1
    pc_y_idx = pc_y - 1
    if hue_data is not None:
        scatter_df = pd.DataFrame(
            {
                f"PC{pc_x}": components[:, pc_x_idx],
                f"PC{pc_y}": components[:, pc_y_idx],
                hue_col: hue_data,
            }
        )
        sns.scatterplot(
            data=scatter_df,
            x=f"PC{pc_x}",
            y=f"PC{pc_y}",
            hue=hue_col,
            palette=palette,
            ax=ax,
            alpha=0.5,
            s=30,
            legend=True,
        )
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    else:
        ax.scatter(components[:, pc_x_idx], components[:, pc_y_idx], alpha=0.5, s=30)
    scale = components[:, [pc_x_idx, pc_y_idx]].max() * 0.8
    for i, feature in enumerate(feature_names):
        ax.arrow(
            0,
            0,
            loadings[pc_x_idx, i] * scale,
            loadings[pc_y_idx, i] * scale,
            head_width=0.05,
            head_length=0.05,
            fc="red",
            ec="red",
            alpha=0.6,
        )
        ax.text(
            loadings[pc_x_idx, i] * scale * 1.1,
            loadings[pc_y_idx, i] * scale * 1.1,
            feature,
            fontsize=9,
            ha="center",
            va="center",
        )
    ax.set_xlabel(f"PC{pc_x} ({explained_var[pc_x_idx]:.1%} variance)")
    ax.set_ylabel(f"PC{pc_y} ({explained_var[pc_y_idx]:.1%} variance)")
    ax.set_title(f"PCA Biplot: PC{pc_x} vs PC{pc_y}")
    ax.grid(True, alpha=0.3)


def _plot_variance(ax, explained_var):
    """Plot explained variance."""
    n_components = len(explained_var)
    cumulative_var = np.cumsum(explained_var)
    x = np.arange(1, n_components + 1)
    ax.bar(x, explained_var, alpha=0.6, label="Individual")
    ax.plot(x, cumulative_var, "ro-", linewidth=2, label="Cumulative")
    ax.set_xlabel("Principal Component")
    ax.set_ylabel("Explained Variance Ratio")
    ax.set_title("Explained Variance by Component")
    ax.set_xticks(x)
    ax.legend()
    ax.grid(True, alpha=0.3, axis="y")

Brick Info

version v0.1.0
python 3.10, 3.11, 3.12, 3.13
requirements
  • seaborn
  • xxhash
  • numpy
  • scikit-learn
  • pyarrow
  • matplotlib
  • polars[pyarrow]
  • pillow
  • pandas
  • joblib