Dim. Reduction UMAP

Visualize high-dimensional data using Uniform Manifold Approximation and Projection.


Processing

This brick simplifies complex, high-dimensional datasets into a 2D or 3D representation using the Uniform Manifold Approximation and Projection (UMAP) algorithm.

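In simple terms, it takes data with many characteristics (columns) and "flattens" it into a visual map. It is designed to keep similar items close together and dissimilar items apart. Local neighborhoods are preserved most faithfully, so the exact distances between separate clusters should not be over-interpreted. This is widely used for: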

  • Clustering: Seeing natural groups in data (e.g., customer segments, biological species).
  • Anomaly Detection: Spotting points that don't fit into any group.
  • Visualization: Making datasets with dozens of columns understandable in a simple X/Y scatter plot.

The brick handles data scaling, dimensionality reduction, and automatic visualization generation.

Inputs

data
The dataset you want to analyze. This should contain the numeric features (columns) you want to use for clustering or visualization.

Input Types

Input | Types
data | DataFrame, ArrowTable

You can check the list of supported types here: Available Type Hints.

Outputs

UMAP
The trained UMAP model object. This contains the mathematical rules learned from your data and can be used in custom Python scripts to transform new data later.
Scaler
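The scaling model (a scikit-learn StandardScaler) used to normalize the data before processing. Reuse it to transform new data in exactly the same way before passing it to the UMAP model. If Scale Data is disabled, this output is None.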
UMAP_Image
A visualization of the data. This is a scatter plot showing the data points projected onto the selected dimensions (e.g., Dimension 1 vs. Dimension 2).
UMAP_Projections
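The processed data containing the new coordinates. This is a DataFrame with one row per input row, keeping the original row index, plus new columns for the reduced dimensions. Rows with missing values in the feature columns are dropped before fitting, so they do not appear here.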
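The UMAP and Scaler outputs can be reused to place new rows on an existing map. The sketch below assumes the new data has the same feature columns as the fit, listed in the same order, and contains no missing values. `project_new_rows` is a hypothetical helper, not part of the brick; it only relies on the `transform` method that both scikit-learn's `StandardScaler` and `umap.UMAP` provide.

```python
from typing import Any, Optional, Sequence

import pandas as pd


def project_new_rows(
    umap_model: Any,
    scaler: Optional[Any],
    new_data: pd.DataFrame,
    feature_cols: Sequence[str],
) -> pd.DataFrame:
    """Project unseen rows with already-fitted objects (hypothetical helper).

    umap_model: the brick's UMAP output (any fitted object with .transform).
    scaler: the brick's Scaler output, or None when Scale Data was disabled.
    feature_cols: the columns used for fitting, in the same order.
    """
    X = new_data[list(feature_cols)]
    # Apply exactly the normalization learned during fitting.
    if scaler is not None:
        X = scaler.transform(X)
    coords = umap_model.transform(X)
    # Name the columns the same way as UMAP_Projections (UMAP1, UMAP2, ...).
    columns = [f"UMAP{i + 1}" for i in range(coords.shape[1])]
    return pd.DataFrame(coords, columns=columns, index=new_data.index)
```

For example, `project_new_rows(UMAP, Scaler, new_df, feature_cols)`, where `feature_cols` is the column list the brick used for fitting.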

The UMAP_Projections output contains the following specific data fields:

  • {ID Column}: If an ID column was specified in the options, it appears here as the first column to help identify rows.
  • UMAP1: The coordinate for the first reduced dimension.
  • UMAP2: The coordinate for the second reduced dimension.
  • UMAP3, UMAP4, and so on: One additional column per component when "Number of Components" is greater than 2.
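Because UMAP_Projections keeps the original row index, it can be joined back onto the input data, for example to inspect which rows landed in a particular region of the map. A minimal pandas sketch, assuming the original data is available as a pandas DataFrame with the same (unique) index the brick saw; `attach_projections` is a hypothetical helper. Rows dropped for missing values receive NaN coordinates.

```python
from typing import Optional

import pandas as pd


def attach_projections(
    original: pd.DataFrame,
    projections: pd.DataFrame,
    id_column: Optional[str] = None,
) -> pd.DataFrame:
    """Left-join UMAP coordinates onto the original rows (hypothetical helper)."""
    coords = projections
    # The ID column already exists in the original data, so drop the copy.
    if id_column and id_column in coords.columns:
        coords = coords.drop(columns=[id_column])
    # Index-aligned left join: rows absent from the projections get NaN.
    return original.join(coords, how="left")
```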

Output Types

Output | Types
UMAP | Any
Scaler | Any
UMAP_Image | MediaData, PILImage
UMAP_Projections | DataFrame

You can check the list of supported types here: Available Type Hints.

Options

The Dim. Reduction UMAP brick exposes the following options:

Columns for UMAP
Select which numeric columns to use for the calculation. If left empty, the brick automatically selects all numeric and boolean columns, excluding the ID column and, when Exclude Hue is active, the hue column.
Number of Components
The number of dimensions to reduce the data to (default: 2). Use 2 or 3 for visualization.
Component for X-axis
Selects which of the calculated dimensions to plot on the horizontal (X) axis of the output image. Defaults to 1. A value larger than Number of Components is clamped to the last component.
Component for Y-axis
Selects which of the calculated dimensions to plot on the vertical (Y) axis of the output image. Defaults to 2. A value larger than Number of Components is clamped to the last component.
Number of Neighbors
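Controls how the algorithm balances local detail vs. global structure (default: 15). If the dataset has fewer rows than this value, it is automatically reduced to one less than the number of rows (minimum 2).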
  • Low values (e.g., 5-10): Focuses on local structure. Good for finding small clusters, but might lose the big picture.
  • High values (e.g., 50-100): Focuses on global structure. Good for seeing the overall shape of the data, but might obscure fine details.
Minimum Distance
Controls how tightly the points are allowed to pack together in the output (default: 0.1).
  • Low values (e.g., 0.1): Points clump tightly together. Good for distinct clustering.
  • High values (e.g., 0.5): Points are more evenly distributed. Good for preserving the topological structure.
Distance Metric
The mathematical rule used to calculate the "distance" between two data points (default: Euclidean).
  • Euclidean: The standard "straight line" distance. Works well for most general data.
  • Cosine: Measures the angle between vectors. Excellent for text data or word embeddings.
  • Manhattan: (Taxicab geometry) Measures distance along axes at right angles.
  • Correlation, Chebyshev, Canberra, Braycurtis: Specialized metrics for specific statistical use cases.
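The difference between metrics is easiest to see on two tiny vectors. This sketch uses SciPy's `scipy.spatial.distance` functions as a stand-in to illustrate the definitions; UMAP computes its metrics with its own internal implementations, and SciPy calls Manhattan distance "cityblock". Here `w` points in the same direction as `u` but is ten times longer, so cosine and correlation distance treat them as identical while Euclidean distance sees them as far apart.

```python
import numpy as np
from scipy.spatial import distance

# SciPy counterparts of the brick's metric names (illustrative mapping).
METRICS = {
    "euclidean": distance.euclidean,
    "cosine": distance.cosine,
    "manhattan": distance.cityblock,
    "correlation": distance.correlation,
    "chebyshev": distance.chebyshev,
    "canberra": distance.canberra,
    "braycurtis": distance.braycurtis,
}

u = np.array([1.0, 2.0])
v = np.array([4.0, 6.0])
w = 10 * u  # same direction as u, ten times longer

for name, fn in METRICS.items():
    print(f"{name:>11}: d(u, v) = {fn(u, v):.4f}   d(u, w) = {fn(u, w):.4f}")
```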
Scale Data
If enabled (the default), the data is standardized with scikit-learn's StandardScaler, rescaling each column to mean 0 and unit variance, before processing. This is highly recommended so that columns with large values (e.g., "Salary") don't dominate columns with small values (e.g., "Age").
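The made-up numbers below show why scaling matters for a distance-based method such as UMAP. Without scaling, the Euclidean distance between rows is almost entirely determined by Salary; after StandardScaler, the age gap and the salary gap contribute comparably. `pairwise` is a small illustrative helper.

```python
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler

people = pd.DataFrame(
    {"Age": [25, 60, 26], "Salary": [50_000, 51_000, 80_000]}, dtype=float
)


def pairwise(X: np.ndarray, i: int, j: int) -> float:
    """Euclidean distance between rows i and j."""
    return float(np.linalg.norm(X[i] - X[j]))


raw = people.to_numpy()
scaled = StandardScaler().fit_transform(people)

# Unscaled: the 35-year age gap barely registers next to salary gaps.
print(f"raw    d(0,1)={pairwise(raw, 0, 1):.1f}  d(0,2)={pairwise(raw, 0, 2):.1f}")
# Scaled: each column has mean 0 and unit variance, so both gaps count.
print(f"scaled d(0,1)={pairwise(scaled, 0, 1):.2f}  d(0,2)={pairwise(scaled, 0, 2):.2f}")
```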
ID Column
The name of a column in your input data that acts as a unique identifier (e.g., "Product_ID", "Email"). This column will be excluded from calculations but added back to the UMAP_Projections output so you can identify your rows.
Hue Column (for scatter)
The name of a column to use for coloring the dots in the output image. For example, if you set this to "Species", points will be colored based on their species.
Exclude Hue
Determines if the column used for coloring (the Hue) should be part of the mathematical calculation. If the hue column is not numerical or boolean, it is automatically excluded.
  • True (Active): The Hue column is excluded from the dimensionality reduction. The algorithm ignores this data when calculating positions/clusters, using it only to assign colors in the final result. This is best when the Hue represents a label, outcome, or category (e.g., "Customer Type") that you want to visualize but don't want influencing the clustering logic.
  • False (Inactive): The Hue column is included in the dimensionality reduction. The values in this column will actively influence where the data points are positioned on the graph.
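The interplay between Columns for UMAP, ID Column, Hue Column, and Exclude Hue can be summarized in a small function. This is a simplified sketch that mirrors the column-selection rules in the brick's source code; `select_feature_columns` is a hypothetical name, not part of the brick's API.

```python
from typing import List, Optional, Sequence

import pandas as pd


def select_feature_columns(
    df: pd.DataFrame,
    columns: Optional[Sequence[str]] = None,
    id_column: str = "",
    hue: str = "",
    exclude_hue: bool = True,
) -> List[str]:
    """Which columns feed the UMAP calculation (simplified sketch)."""
    if columns:
        missing = [c for c in columns if c not in df.columns]
        if missing:
            raise ValueError(f"Columns not found in DataFrame: {missing}")
        features = list(columns)
    else:
        # Default: every numeric or boolean column, in DataFrame order.
        features = df.select_dtypes(include=["number", "bool"]).columns.tolist()
    # The ID column never takes part in the calculation.
    if id_column and id_column in features:
        features.remove(id_column)
    # With exclude_hue, the hue column is used only for coloring.
    if exclude_hue and hue and hue in features:
        features.remove(hue)
    if not features:
        raise ValueError("No feature columns remaining")
    return features
```

A text hue column such as "Species" is never selected automatically, so Exclude Hue only makes a difference for numeric or boolean hue columns.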
Color Palette
The seaborn color palette used to color points by the Hue Column; it has no effect when no hue is set. Available palettes: husl (default), deep, muted, bright, pastel, dark, colorblind.
Output Type
Determines the format of the UMAP_Image output (default: array).
  • array: Returns a NumPy array (standard for image processing).
  • pil: Returns a PIL Image object.
  • bytes: Returns the raw PNG file bytes (useful for saving directly to disk).
  • bytesio: Returns a BytesIO stream containing the PNG data.
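The brick renders the plot once as a PNG and then converts it to the requested form. The sketch below reproduces that conversion step with Pillow and NumPy under that assumption; `png_to_output` is a hypothetical helper, not a function exported by the brick.

```python
import io
from typing import Union

import numpy as np
from PIL import Image

ImageOutput = Union[np.ndarray, Image.Image, bytes, io.BytesIO]


def png_to_output(png_bytes: bytes, output_type: str = "array") -> ImageOutput:
    """Convert rendered PNG bytes into one of the four output forms."""
    if output_type == "bytes":
        return png_bytes
    if output_type == "bytesio":
        return io.BytesIO(png_bytes)
    if output_type not in ("pil", "array"):
        raise ValueError(f"Invalid output_type: '{output_type}'")
    with Image.open(io.BytesIO(png_bytes)) as img:
        img.load()  # decode the pixels while the source buffer is open
        # copy() detaches the PIL image from the file handle closed by "with".
        return img.copy() if output_type == "pil" else np.array(img)
```

Matplotlib usually writes RGBA PNGs, so the array form of a rendered plot typically has four channels.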
Number of Jobs
How many CPU cores to use for parallel processing (default: 1; "All" uses every available core). More cores speed up training but use more system resources. Random State is only applied when this is set to 1.
Random State
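A seed number for the random number generator (default: 42). With Number of Jobs set to 1, running the brick twice on the same data with the same seed produces the same result. With more than one job, the seed is not applied and results can differ between runs.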
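The rules above (neighbor clamping, lowercase metric names, and passing the seed only for single-job runs) can be collected into one place. This is a sketch that mirrors the option handling in the brick's source code; `build_umap_params` is a hypothetical helper, and its defaults simply repeat the option defaults listed above.

```python
from typing import Any, Dict, Union


def build_umap_params(
    n_rows: int,
    n_components: int = 2,
    n_neighbors: int = 15,
    min_dist: float = 0.1,
    metric: str = "Euclidean",
    n_jobs: Union[str, int] = "1",
    random_state: int = 42,
) -> Dict[str, Any]:
    """Keyword arguments for umap.UMAP following the brick's rules (sketch)."""
    n_jobs_int = -1 if n_jobs == "All" else int(n_jobs)
    # Too few rows for the requested neighborhood: shrink it (never below 2).
    if n_rows < n_neighbors:
        n_neighbors = max(2, n_rows - 1)
    params: Dict[str, Any] = {
        "n_neighbors": n_neighbors,
        "min_dist": min_dist,
        "n_components": n_components,
        "metric": metric.lower(),
        "n_jobs": n_jobs_int,
    }
    # The seed is passed only for single-job runs; parallel runs are unseeded.
    if n_jobs_int == 1:
        params["random_state"] = random_state
    return params
```

The result can be unpacked into the constructor, e.g. `umap.UMAP(**build_umap_params(len(X)))`.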
Brick Caching
If enabled (off by default), results are saved to a cache folder in the system's temporary directory. If you run the workflow again with the exact same data and settings, the brick loads the projections, model, and scaler from the cache instead of recalculating, which is significantly faster.
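A cache like this works by deriving one key from the data and every setting that affects the result, so that any change produces a different key. The brick itself combines an xxhash fingerprint of the data (sampling rows only for very large datasets) with a SHA-256 digest of the settings; the sketch below swaps in only `hashlib` and pandas' row hashing to show the idea. `cache_key` is a hypothetical helper.

```python
import hashlib
import json
from typing import Any, Dict

import pandas as pd


def cache_key(features: pd.DataFrame, settings: Dict[str, Any]) -> str:
    """Key that changes whenever the data or any setting changes (sketch)."""
    hasher = hashlib.sha256()
    # Column names and dtypes, so a renamed or retyped column misses the cache.
    schema = {
        "columns": [str(c) for c in features.columns],
        "dtypes": [str(t) for t in features.dtypes],
    }
    hasher.update(json.dumps(schema).encode("utf-8"))
    # One 64-bit hash per row, computed by pandas, fed in as raw bytes.
    row_hashes = pd.util.hash_pandas_object(features, index=False)
    hasher.update(row_hashes.to_numpy().tobytes())
    # Settings are serialized with sorted keys so their order does not matter.
    hasher.update(json.dumps(settings, sort_keys=True).encode("utf-8"))
    return hasher.hexdigest()
```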
Verbose
If enabled, detailed logs about the progress (scaling, fitting, plotting) will be printed to the console.
import logging
import io
import json
import xxhash
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import matplotlib
from scipy import sparse
from dataclasses import dataclass
from datetime import datetime
import hashlib
import tempfile
import sklearn
import joblib
from pathlib import Path

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
import umap
from sklearn.preprocessing import StandardScaler
from coded_flows.types import (
    Union,
    DataFrame,
    ArrowTable,
    MediaData,
    PILImage,
    Tuple,
    Any,
)
from coded_flows.utils import CodedFlowsLogger

logger = CodedFlowsLogger(name="Dim. Reduction UMAP", level=logging.INFO)
DataType = Union[
    pd.DataFrame, pl.DataFrame, np.ndarray, sparse.spmatrix, pd.Series, pl.Series
]


@dataclass
class _DatasetFingerprint:
    """Lightweight fingerprint of a dataset."""

    hash: str
    shape: tuple
    computed_at: str
    data_type: str
    method: str


class _UniversalDatasetHasher:
    """
    High-performance dataset hasher optimizing for zero-copy operations.
    """

    def __init__(
        self,
        data_size: int,
        method: str = "auto",
        sample_size: int = 100000,
        verbose: bool = False,
    ):
        self.method = method
        self.sample_size = sample_size
        self.data_size = data_size
        self.verbose = verbose

    def hash_data(self, data: DataType) -> _DatasetFingerprint:
        if isinstance(data, pd.DataFrame):
            return self._hash_pandas(data)
        elif isinstance(data, pl.DataFrame):
            return self._hash_polars(data)
        elif isinstance(data, pd.Series):
            return self._hash_pandas_series(data)
        elif isinstance(data, pl.Series):
            return self._hash_polars_series(data)
        elif isinstance(data, np.ndarray):
            return self._hash_numpy(data)
        elif sparse.issparse(data):
            return self._hash_sparse(data)
        else:
            raise TypeError(f"Unsupported data type: {type(data)}")

    def _hash_pandas(self, df: pd.DataFrame) -> _DatasetFingerprint:
        method = self._determine_method(self.data_size, self.method)
        target_df = df
        if method == "sampled":
            target_df = self._get_pandas_sample(df)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns.tolist(),
                "dtypes": {k: str(v) for (k, v) in df.dtypes.items()},
                "shape": df.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(target_df, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception:
            self._hash_pandas_fallback(hasher, target_df)
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas",
            method=method,
        )

    def _get_pandas_sample(self, df: pd.DataFrame) -> pd.DataFrame:
        if self.data_size <= self.sample_size:
            return df
        chunk = self.sample_size // 3
        head = df.iloc[:chunk]
        mid_idx = self.data_size // 2
        mid = df.iloc[mid_idx : mid_idx + chunk]
        tail = df.iloc[-chunk:]
        return pd.concat([head, mid, tail])

    def _hash_pandas_fallback(self, hasher, df: pd.DataFrame):
        for col in df.columns:
            val = df[col].astype(str).values
            hasher.update(val.astype(np.bytes_).tobytes())

    def _hash_polars(self, df: pl.DataFrame) -> _DatasetFingerprint:
        method = self._determine_method(self.data_size, self.method)
        target_df = df
        if method == "sampled" and self.data_size > self.sample_size:
            indices = self._get_sample_indices(self.data_size, self.sample_size)
            target_df = df.gather(indices)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns,
                "dtypes": [str(t) for t in df.dtypes],
                "shape": df.shape,
            },
        )
        row_hashes = target_df.hash_rows()
        hasher.update(memoryview(row_hashes.to_numpy()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars",
            method=method,
        )

    def _hash_pandas_series(self, series: pd.Series) -> _DatasetFingerprint:
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "name": series.name if series.name else "None",
                "dtype": str(series.dtype),
                "shape": series.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(series, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception:
            hasher.update(memoryview(series.astype(str).values.tobytes()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas_series",
            method="full",
        )

    def _hash_polars_series(self, series: pl.Series) -> _DatasetFingerprint:
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"name": series.name, "dtype": str(series.dtype), "shape": series.shape},
        )
        try:
            row_hashes = series.hash()
            hasher.update(memoryview(row_hashes.to_numpy()))
        except Exception:
            hasher.update(str(series.to_list()).encode())
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars_series",
            method="full",
        )

    def _hash_numpy(self, arr: np.ndarray) -> _DatasetFingerprint:
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"shape": arr.shape, "dtype": str(arr.dtype), "strides": arr.strides},
        )
        if arr.flags["C_CONTIGUOUS"] or arr.flags["F_CONTIGUOUS"]:
            hasher.update(memoryview(arr))
        else:
            hasher.update(memoryview(np.ascontiguousarray(arr)))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=arr.shape,
            computed_at=datetime.now().isoformat(),
            data_type="numpy",
            method="full",
        )

    def _hash_sparse(self, matrix: sparse.spmatrix) -> _DatasetFingerprint:
        if not (sparse.isspmatrix_csr(matrix) or sparse.isspmatrix_csc(matrix)):
            matrix = matrix.tocsr()
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher, {"shape": matrix.shape, "format": matrix.format, "nnz": matrix.nnz}
        )
        hasher.update(memoryview(matrix.data))
        hasher.update(memoryview(matrix.indices))
        hasher.update(memoryview(matrix.indptr))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=matrix.shape,
            computed_at=datetime.now().isoformat(),
            data_type=f"sparse_{matrix.format}",
            method="sparse",
        )

    def _determine_method(self, rows: int, requested: str) -> str:
        if requested != "auto":
            return requested
        if rows < 5000000:
            return "full"
        return "sampled"

    def _hash_schema(self, hasher, schema):
        hasher.update(
            json.dumps(schema, sort_keys=True, separators=(",", ":")).encode()
        )

    def _get_sample_indices(self, total_rows: int, sample_size: int) -> list:
        chunk = sample_size // 3
        indices = list(range(min(chunk, total_rows)))
        mid_start = max(0, total_rows // 2 - chunk // 2)
        mid_end = min(mid_start + chunk, total_rows)
        indices.extend(range(mid_start, mid_end))
        last_start = max(0, total_rows - chunk)
        indices.extend(range(last_start, total_rows))
        return sorted(list(set(indices)))


def _plot_umap_scatter(
    ax, x_data, y_data, x_label, y_label, hue_data, hue_col, palette
):
    """Plot UMAP scatter plot with selected axes."""
    if hue_data is not None:
        scatter_df = pd.DataFrame({x_label: x_data, y_label: y_data, hue_col: hue_data})
        sns.scatterplot(
            data=scatter_df,
            x=x_label,
            y=y_label,
            hue=hue_col,
            palette=palette,
            ax=ax,
            alpha=0.7,
            s=50,
        )
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    else:
        ax.scatter(x_data, y_data, alpha=0.7, s=50)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title("UMAP Projection")
    ax.grid(True, alpha=0.3)


def umap_analysis(
    data: Union[DataFrame, ArrowTable], options=None
) -> Tuple[Any, Any, Union[MediaData, PILImage], DataFrame]:
    options = options or {}
    verbose = options.get("verbose", True)
    random_state = options.get("random_state", 42)
    activate_caching = options.get("activate_caching", False)
    output_type = options.get("output_type", "array")
    columns = options.get("columns", None)
    n_components = options.get("n_components", 2)
    scale_data = options.get("scale_data", True)
    id_column = options.get("id_column", "")
    hue = options.get("hue", "")
    exclude_hue = options.get("exclude_hue", True)
    palette = options.get("palette", "husl")
    n_jobs_str = options.get("n_jobs", "1")
    n_jobs_int = -1 if n_jobs_str == "All" else int(n_jobs_str)
    pc_x = options.get("pc_x", 1)
    pc_y = options.get("pc_y", 2)
    n_neighbors = options.get("n_neighbors", 15)
    min_dist = options.get("min_dist", 0.1)
    metric_raw = options.get("metric", "Euclidean")
    metric = metric_raw.lower() if metric_raw else "euclidean"
    dpi = 300
    verbose and logger.info(
        f"Starting UMAP with {n_components} components. Displaying Dim{pc_x} vs Dim{pc_y}"
    )
    verbose and logger.info(
        f"Config: neighbors={n_neighbors}, min_dist={min_dist}, metric={metric}"
    )
    if pc_x > n_components or pc_y > n_components:
        verbose and logger.warning(
            f"Selected axes (X:{pc_x}, Y:{pc_y}) exceed n_components ({n_components}). Clamping to valid range."
        )
        pc_x = min(pc_x, n_components)
        pc_y = min(pc_y, n_components)
    UMAP = None
    UMAP_Image = None
    UMAP_Projections = pd.DataFrame()
    Scaler = None
    plot_components = None
    hue_data = None
    hue_col = None
    df = None
    try:
        if isinstance(data, pl.DataFrame):
            verbose and logger.info(f"Converting polars DataFrame to pandas")
            df = data.to_pandas()
        elif isinstance(data, pa.Table):
            verbose and logger.info(f"Converting Arrow table to pandas")
            df = data.to_pandas()
        elif isinstance(data, pd.DataFrame):
            verbose and logger.info(f"Input is already pandas DataFrame")
            df = data
        else:
            raise ValueError(f"Unsupported data type: {type(data).__name__}")
    except Exception as e:
        raise RuntimeError(
            f"Failed to convert input data to pandas DataFrame: {e}"
        ) from e
    if df.empty:
        raise ValueError("Input DataFrame is empty")
    verbose and logger.info(
        f"Processing DataFrame with {df.shape[0]:,} rows × {df.shape[1]:,} columns"
    )
    if hue and hue.strip():
        if hue not in df.columns:
            raise ValueError(f"Hue column '{hue}' not found")
        hue_col = hue
    if id_column and id_column.strip() and (id_column not in df.columns):
        raise ValueError(f"ID column '{id_column}' not found in DataFrame")
    if columns and len(columns) > 0:
        missing_cols = [col for col in columns if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Columns not found in DataFrame: {missing_cols}")
        feature_cols = list(columns)
    else:
        feature_cols = df.select_dtypes(include=["number", "bool"]).columns.tolist()
        if not feature_cols:
            raise ValueError("No numeric columns found in DataFrame")
    if id_column and id_column.strip() and (id_column in feature_cols):
        feature_cols.remove(id_column)
    if exclude_hue and hue_col and (hue_col in feature_cols):
        feature_cols.remove(hue_col)
    if not feature_cols:
        raise ValueError("No feature columns remaining after excluding ID and hue columns")
    skip_computation = False
    cache_folder = None
    all_hash = None
    if activate_caching:
        verbose and logger.info(f"Caching is active")
        data_hasher = _UniversalDatasetHasher(df.shape[0], verbose=verbose)
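        # Hash the ID column too: its values are stored in the cached projections.
        hash_cols = feature_cols + (
            [id_column] if id_column and id_column.strip() else []
        )
        X_hash = data_hasher.hash_data(df[hash_cols]).hash
        all_hash_base_text = f"HASH BASE TEXT UMAPScikit Learn Version {sklearn.__version__}{X_hash}_{n_components}_{scale_data}_{random_state}{n_neighbors}_{min_dist}_{metric}{exclude_hue}{n_jobs_int}{sorted(feature_cols)}_{id_column}"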
        all_hash = hashlib.sha256(all_hash_base_text.encode("utf-8")).hexdigest()
        verbose and logger.info(f"Hash was computed: {all_hash}")
        temp_folder = Path(tempfile.gettempdir())
        cache_folder = temp_folder / "coded-flows-cache"
        cache_folder.mkdir(parents=True, exist_ok=True)
        umap_proj_path = cache_folder / f"umap_proj_{all_hash}.parquet"
        umap_est_path = cache_folder / f"umap_estimator_{all_hash}.joblib"
        scaler_path = cache_folder / f"umap_scaler_{all_hash}.joblib"
        if (
            umap_proj_path.is_file()
            and umap_est_path.is_file()
            and scaler_path.is_file()
        ):
            verbose and logger.info(f"Cache hit! Loading results.")
            try:
                UMAP_Projections = pd.read_parquet(umap_proj_path)
                UMAP = joblib.load(umap_est_path)
                Scaler = joblib.load(scaler_path)
                comps_for_plot_df = UMAP_Projections.copy()
                if id_column and id_column in comps_for_plot_df.columns:
                    comps_for_plot_df = comps_for_plot_df.drop(columns=[id_column])
                plot_components = comps_for_plot_df.values
                skip_computation = True
            except Exception as e:
                verbose and logger.warning(f"Cache load failed, recomputing: {e}")
                skip_computation = False
    if not skip_computation:
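        X = df[feature_cols]
        # isna() handles nullable and non-float dtypes, where np.isnan can fail.
        missing_mask = X.isna().any(axis=1).to_numpy()
        if missing_mask.any():
            verbose and logger.warning(f"Removing rows with missing values")
            X = X[~missing_mask]
            df_indices = df.index[~missing_mask]
        else:
            df_indices = df.index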
        if X.shape[0] < n_neighbors:
            verbose and logger.warning(
                f"Data samples ({X.shape[0]}) < n_neighbors ({n_neighbors}). Adjusting n_neighbors."
            )
            n_neighbors = max(2, int(X.shape[0] - 1))
        if scale_data:
            verbose and logger.info(f"Scaling features")
            Scaler = StandardScaler()
            X_transformed = Scaler.fit_transform(X)
        else:
            X_transformed = X
        verbose and logger.info(f"Fitting UMAP model (metric={metric})")
        umap_params = {
            "n_neighbors": n_neighbors,
            "min_dist": min_dist,
            "n_components": n_components,
            "metric": metric,
            "n_jobs": n_jobs_int,
        }
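        # umap-learn disables parallelism when random_state is set, so the seed
        # is only passed for single-job runs.
        if n_jobs_int == 1:
            umap_params["random_state"] = random_state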
        UMAP = umap.UMAP(**umap_params)
        plot_components = UMAP.fit_transform(X_transformed)
        umap_cols = [f"UMAP{i + 1}" for i in range(n_components)]
        UMAP_Projections = pd.DataFrame(
            plot_components, columns=umap_cols, index=df_indices
        )
        if id_column and id_column.strip():
            id_data = df.loc[df_indices, id_column].values
            UMAP_Projections.insert(0, id_column, id_data)
        if activate_caching and cache_folder and all_hash:
            try:
                umap_proj_path = cache_folder / f"umap_proj_{all_hash}.parquet"
                umap_est_path = cache_folder / f"umap_estimator_{all_hash}.joblib"
                scaler_path = cache_folder / f"umap_scaler_{all_hash}.joblib"
                UMAP_Projections.to_parquet(umap_proj_path)
                joblib.dump(UMAP, umap_est_path)
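                # Always write the scaler file (it holds None when scaling is off),
                # because the cache-hit check requires all three files to exist.
                joblib.dump(Scaler, scaler_path)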
                verbose and logger.info(f"Results saved to cache")
            except Exception as e:
                verbose and logger.warning(f"Failed to save cache: {e}")
    try:
        idx_x = pc_x - 1
        idx_y = pc_y - 1
        x_values = plot_components[:, idx_x]
        y_values = plot_components[:, idx_y]
        x_label = f"UMAP Dimension {pc_x}"
        y_label = f"UMAP Dimension {pc_y}"
        if hue_col:
            hue_data = df.loc[UMAP_Projections.index, hue].values
            verbose and logger.info(f"Using hue column: '{hue}'")
        verbose and logger.info(
            f"Creating scatter visualization ({x_label} vs {y_label})"
        )
        (fig, ax) = plt.subplots(figsize=(10, 7))
        _plot_umap_scatter(
            ax, x_values, y_values, x_label, y_label, hue_data, hue_col, palette
        )
        verbose and logger.info(f"Rendering to {output_type}")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
        buf.seek(0)
        if output_type == "bytesio":
            UMAP_Image = buf
        elif output_type == "bytes":
            UMAP_Image = buf.getvalue()
            buf.close()
        elif output_type == "pil":
            UMAP_Image = Image.open(buf)
            UMAP_Image.load()
            buf.close()
        elif output_type == "array":
            img = Image.open(buf)
            UMAP_Image = np.array(img)
            buf.close()
        else:
            raise ValueError(f"Invalid output_type: '{output_type}'")
        plt.close(fig)
    except Exception as e:
        verbose and logger.error(f"Error during plotting: {e}")
        plt.close("all")
        raise RuntimeError(f"Plotting failed: {e}") from e
    verbose and logger.info(f"Successfully completed UMAP analysis")
    return (UMAP, Scaler, UMAP_Image, UMAP_Projections)

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • shap>=0.47.0
  • scikit-learn
  • pandas
  • pyarrow
  • polars[pyarrow]
  • numpy
  • umap-learn
  • numba>=0.56.0
  • joblib
  • matplotlib
  • seaborn
  • pillow
  • xxhash