Dim. Reduction T-SNE

Visualize high-dimensional data using t-Distributed Stochastic Neighbor Embedding.

Dim. Reduction T-SNE

Processing

This brick reduces high-dimensional data to 2 or 3 dimensions using t-SNE (t-Distributed Stochastic Neighbor Embedding). It helps you visualize complex datasets by placing similar items close together, revealing patterns or clusters that may not be visible in the raw data.

The brick processes numeric columns, optionally scales them to a standard range, calculates the projection, and generates two main results: a scatter plot visualization of the data and a dataset containing the new coordinate values.
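
For intuition, the core processing is roughly equivalent to the following scikit-learn steps (a minimal sketch; the column names and values are illustrative and not produced by the brick):

import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.manifold import TSNE

# Illustrative numeric dataset (hypothetical columns).
df = pd.DataFrame(
    {"height": [170, 182, 165, 190, 175, 168], "weight": [65, 80, 58, 95, 72, 61]}
)

X = StandardScaler().fit_transform(df)  # the optional "Scale Data" step
coords = TSNE(n_components=2, perplexity=3, random_state=42).fit_transform(X)

# These coordinates correspond to the "TSNE Projections" output.
projections = pd.DataFrame(coords, columns=["TSNE1", "TSNE2"])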

Inputs

data
The dataset containing the information you want to analyze. It must include numeric columns (e.g., measurements, scores, counts) to be processed.

Input Types

Input Types
data DataFrame, ArrowTable

You can check the list of supported types here: Available Type Hints.

Outputs

Scaler
The scaling model used to normalize the data before processing. This can be passed to subsequent bricks if you need to reverse the scaling or apply the same scaling to new data.
TSNE Image
A visual representation of the t-SNE result. This is a scatter plot where similar data points are grouped closer together.
TSNE Projections
The processed data containing the new calculated coordinates (dimensions).

The TSNE Projections output contains the following specific data fields:

  • TSNE1: The coordinate for the first dimension.
  • TSNE2: The coordinate for the second dimension.
  • TSNE3: (If 3 components selected) The coordinate for the third dimension.
  • {ID Column}: (If provided) The identifier from your original data.

Output Types

Output Types
Scaler Any
TSNE Image MediaData, PILImage
TSNE Projections DataFrame

You can check the list of supported types here: Available Type Hints.

Options

The Dim. Reduction T-SNE brick exposes the following configurable options:

Columns for T-SNE
Select specific numeric columns to use for the analysis. If left empty, all numeric columns in the dataset will be used.
Number of Components
The number of dimensions to reduce the data into.
  • 2: Reduces data to 2 dimensions (best for standard 2D plots).
  • 3: Reduces data to 3 dimensions.
Component for X-axis
Selects which of the calculated dimensions (components) to plot on the horizontal X-axis. Usually "1".
Component for Y-axis
Selects which of the calculated dimensions (components) to plot on the vertical Y-axis. Usually "2".
Distance Metric
The method used to calculate the distance/similarity between data points.
  • Euclidean: Standard straight-line distance. A good default for most numeric data.
  • Cosine: Measures the angle between vectors. Good for text data or high-dimensional sparse data.
  • Manhattan: Grid-like distance (L1 norm).
  • Chebyshev: The greatest difference along any coordinate dimension.
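
Each of these choices corresponds to the metric argument of scikit-learn's TSNE. A minimal sketch with synthetic data (the random array is purely illustrative):

import numpy as np
from sklearn.manifold import TSNE

rng = np.random.RandomState(0)
X = rng.rand(100, 50)  # high-dimensional features where direction matters more than magnitude

# Selecting "Cosine" in the options corresponds to metric="cosine" here.
coords = TSNE(n_components=2, metric="cosine", random_state=0).fit_transform(X)
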
Perplexity
Controls how the algorithm balances attention between local and global aspects of your data. It is roughly a guess about the number of close neighbors each point has.
  • Lower values (5-30): Focuses on local structure (small, tight groups).
  • Higher values (30-100): Focuses on global structure (overall shape).
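
The effect is easiest to see by projecting the same data with two different perplexities (a minimal sketch using synthetic blob data):

from sklearn.manifold import TSNE
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=300, n_features=10, centers=4, random_state=0)

# Low perplexity emphasizes small neighborhoods; clusters may fragment.
local_view = TSNE(n_components=2, perplexity=5, random_state=0).fit_transform(X)

# Higher perplexity emphasizes the overall arrangement of the clusters.
global_view = TSNE(n_components=2, perplexity=50, random_state=0).fit_transform(X)
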
Auto Learning Rate
If enabled, the algorithm automatically calculates the best learning rate based on your data size.
Learning Rate (if not Auto)
The step size for the optimization algorithm. If the result looks like a "ball" with no structure, try increasing this. If it looks like a condensed cloud with points far apart, try decreasing it.
Max Iterations
The maximum number of times the algorithm will run to refine the shape. Higher numbers take longer but may produce more stable results.
Early Exaggeration
Controls how tight natural clusters in the original space are in the embedded space and how much space will be between them.
Angle (Barnes-Hut)
Controls the trade-off between speed and accuracy.
  • Lower (e.g., 0.2): More accurate, but slower.
  • Higher (e.g., 0.8): Faster, but less accurate.
Min Gradient Norm
A convergence threshold: the optimization stops early once the gradient norm falls below this value, i.e., when further iterations would barely change the embedding.
Scale Data
If enabled, data is standardized (mean=0, variance=1) before processing. This is highly recommended to prevent columns with large numbers (e.g., "Salary") from dominating columns with small numbers (e.g., "Age").
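
Scaling matters because t-SNE operates on raw distances, so a column measured in the tens of thousands would otherwise dominate one measured in the tens. A minimal sketch of the standardization step (the columns are hypothetical):

import pandas as pd
from sklearn.preprocessing import StandardScaler

df = pd.DataFrame({"Age": [25, 32, 47, 51], "Salary": [40000, 52000, 88000, 91000]})

scaler = StandardScaler()
scaled = scaler.fit_transform(df)  # each column now has mean 0 and unit variance

# The fitted scaler is what the brick exposes on its "Scaler" output.
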
ID Column
The name of a column in your input data that serves as a unique identifier (e.g., "Product_ID", "Email"). This column will be preserved in the output dataset so you can map the results back to your original items.
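
Because the identifier is preserved, the projections can be joined back onto the original records, for example (all column names and values are hypothetical):

import pandas as pd

# `original` stands for your input data; `projections` for the TSNE Projections output.
original = pd.DataFrame({"Product_ID": ["A1", "A2", "A3"], "price": [9.5, 12.0, 7.25]})
projections = pd.DataFrame(
    {"Product_ID": ["A1", "A2", "A3"], "TSNE1": [-3.1, 4.2, 0.7], "TSNE2": [1.8, -2.4, 5.0]}
)

enriched = original.merge(projections, on="Product_ID", how="left")
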
Hue Column (for scatter)
The name of a column to use for coloring the points in the plot (e.g., "Category", "Status").
Exclude Hue
Determines if the column used for coloring (the Hue) should be part of the mathematical calculation. If the hue column is not numerical or boolean, it is automatically excluded.
  • True (Active): The Hue column is excluded from the dimensionality reduction. The algorithm ignores this data when calculating positions/clusters, using it only to assign colors in the final result. This is best when the Hue represents a label, outcome, or category (e.g., "Customer Type") that you want to visualize but don't want influencing the clustering logic.
  • False (Inactive): The Hue column is included in the dimensionality reduction. The values in this column will actively influence where the data points are positioned on the graph.
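
In practice, excluding the hue means the labeled column is dropped from the feature matrix and used only for coloring. A minimal sketch with a hypothetical "Customer Type" column:

import pandas as pd
from sklearn.manifold import TSNE

df = pd.DataFrame(
    {
        "spend": [10, 12, 95, 99, 50, 55],
        "visits": [1, 2, 9, 10, 5, 6],
        "Customer Type": ["new", "new", "vip", "vip", "regular", "regular"],
    }
)

features = df.drop(columns=["Customer Type"])  # Exclude Hue = True
coords = TSNE(n_components=2, perplexity=2, random_state=0).fit_transform(features)
colors = df["Customer Type"]  # used only to color the scatter plot
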
Color Palette
The color scheme used for the visualization.
Output Type
The format of the resulting image.
  • array: Returns a NumPy array representation of the image.
  • pil: Returns a PIL Image object.
  • bytes: Returns the raw image file bytes.
  • bytesio: Returns a BytesIO stream.
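
How you handle the image downstream depends on the chosen format. A minimal sketch of a helper (hypothetical, not part of the brick) that saves the image whichever format was selected:

import io
import numpy as np
from PIL import Image

def save_tsne_image(tsne_image, path="tsne.png"):
    # Persist the brick's "TSNE Image" output regardless of the selected Output Type.
    if isinstance(tsne_image, np.ndarray):  # "array"
        Image.fromarray(tsne_image).save(path)
    elif isinstance(tsne_image, Image.Image):  # "pil"
        tsne_image.save(path)
    elif isinstance(tsne_image, (bytes, bytearray)):  # "bytes"
        with open(path, "wb") as f:
            f.write(tsne_image)
    elif isinstance(tsne_image, io.BytesIO):  # "bytesio"
        with open(path, "wb") as f:
            f.write(tsne_image.getvalue())
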
Number of Jobs
How many CPU cores to use for parallel processing. More cores speed up training but use more system resources.
Random State
A seed number to ensure the results are reproducible. Using the same number ensures the plot looks the same every time you run it.
Brick Caching
If enabled, the result is saved temporarily. Running the brick again with the exact same inputs will load the result from the cache instead of recalculating, speeding up the workflow.
Verbose
If enabled, detailed logs about the processing steps will be generated.
Source Code

import logging
import io
import json
import xxhash
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import matplotlib
from scipy import sparse
from dataclasses import dataclass
from datetime import datetime
import hashlib
import tempfile
import sklearn
import joblib
from pathlib import Path

matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from sklearn.manifold import TSNE
from sklearn.preprocessing import StandardScaler
from coded_flows.types import (
    Union,
    DataFrame,
    ArrowTable,
    MediaData,
    PILImage,
    Tuple,
    Any,
    DataSeries,
    Str,
)
from coded_flows.utils import CodedFlowsLogger

logger = CodedFlowsLogger(name="Dim. Reduction T-SNE", level=logging.INFO)
DataType = Union[
    pd.DataFrame, pl.DataFrame, np.ndarray, sparse.spmatrix, pd.Series, pl.Series
]


@dataclass
class _DatasetFingerprint:
    """Lightweight fingerprint of a dataset."""

    hash: str
    shape: tuple
    computed_at: str
    data_type: str
    method: str


class _UniversalDatasetHasher:
    """
    High-performance dataset hasher optimizing for zero-copy operations.
    """

    def __init__(
        self,
        data_size: int,
        method: str = "auto",
        sample_size: int = 100000,
        verbose: bool = False,
    ):
        self.method = method
        self.sample_size = sample_size
        self.data_size = data_size
        self.verbose = verbose

    def hash_data(self, data: DataType) -> _DatasetFingerprint:
        if isinstance(data, pd.DataFrame):
            return self._hash_pandas(data)
        elif isinstance(data, pl.DataFrame):
            return self._hash_polars(data)
        elif isinstance(data, pd.Series):
            return self._hash_pandas_series(data)
        elif isinstance(data, pl.Series):
            return self._hash_polars_series(data)
        elif isinstance(data, np.ndarray):
            return self._hash_numpy(data)
        elif sparse.issparse(data):
            return self._hash_sparse(data)
        else:
            raise TypeError(f"Unsupported data type: {type(data)}")

    def _hash_pandas(self, df: pd.DataFrame) -> _DatasetFingerprint:
        method = self._determine_method(self.data_size, self.method)
        target_df = df
        if method == "sampled":
            target_df = self._get_pandas_sample(df)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns.tolist(),
                "dtypes": {k: str(v) for (k, v) in df.dtypes.items()},
                "shape": df.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(target_df, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception:
            self._hash_pandas_fallback(hasher, target_df)
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas",
            method=method,
        )

    def _get_pandas_sample(self, df: pd.DataFrame) -> pd.DataFrame:
        if self.data_size <= self.sample_size:
            return df
        chunk = self.sample_size // 3
        head = df.iloc[:chunk]
        mid_idx = self.data_size // 2
        mid = df.iloc[mid_idx : mid_idx + chunk]
        tail = df.iloc[-chunk:]
        return pd.concat([head, mid, tail])

    def _hash_pandas_fallback(self, hasher, df: pd.DataFrame):
        for col in df.columns:
            val = df[col].astype(str).values
            hasher.update(val.astype(np.bytes_).tobytes())

    def _hash_polars(self, df: pl.DataFrame) -> _DatasetFingerprint:
        method = self._determine_method(self.data_size, self.method)
        target_df = df
        if method == "sampled" and self.data_size > self.sample_size:
            indices = self._get_sample_indices(self.data_size, self.sample_size)
            target_df = df.gather(indices)
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "columns": df.columns,
                "dtypes": [str(t) for t in df.dtypes],
                "shape": df.shape,
            },
        )
        row_hashes = target_df.hash_rows()
        hasher.update(memoryview(row_hashes.to_numpy()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=df.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars",
            method=method,
        )

    def _hash_pandas_series(self, series: pd.Series) -> _DatasetFingerprint:
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {
                "name": series.name if series.name else "None",
                "dtype": str(series.dtype),
                "shape": series.shape,
            },
        )
        try:
            row_hashes = pd.util.hash_pandas_object(series, index=False)
            hasher.update(memoryview(row_hashes.values))
        except Exception:
            hasher.update(memoryview(series.astype(str).values.tobytes()))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="pandas_series",
            method="full",
        )

    def _hash_polars_series(self, series: pl.Series) -> _DatasetFingerprint:
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"name": series.name, "dtype": str(series.dtype), "shape": series.shape},
        )
        try:
            row_hashes = series.hash()
            hasher.update(memoryview(row_hashes.to_numpy()))
        except Exception:
            hasher.update(str(series.to_list()).encode())
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=series.shape,
            computed_at=datetime.now().isoformat(),
            data_type="polars_series",
            method="full",
        )

    def _hash_numpy(self, arr: np.ndarray) -> _DatasetFingerprint:
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher,
            {"shape": arr.shape, "dtype": str(arr.dtype), "strides": arr.strides},
        )
        if arr.flags["C_CONTIGUOUS"] or arr.flags["F_CONTIGUOUS"]:
            hasher.update(memoryview(arr))
        else:
            hasher.update(memoryview(np.ascontiguousarray(arr)))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=arr.shape,
            computed_at=datetime.now().isoformat(),
            data_type="numpy",
            method="full",
        )

    def _hash_sparse(self, matrix: sparse.spmatrix) -> _DatasetFingerprint:
        if not (sparse.isspmatrix_csr(matrix) or sparse.isspmatrix_csc(matrix)):
            matrix = matrix.tocsr()
        hasher = xxhash.xxh128()
        self._hash_schema(
            hasher, {"shape": matrix.shape, "format": matrix.format, "nnz": matrix.nnz}
        )
        hasher.update(memoryview(matrix.data))
        hasher.update(memoryview(matrix.indices))
        hasher.update(memoryview(matrix.indptr))
        return _DatasetFingerprint(
            hash=hasher.hexdigest(),
            shape=matrix.shape,
            computed_at=datetime.now().isoformat(),
            data_type=f"sparse_{matrix.format}",
            method="sparse",
        )

    def _determine_method(self, rows: int, requested: str) -> str:
        if requested != "auto":
            return requested
        if rows < 5000000:
            return "full"
        return "sampled"

    def _hash_schema(self, hasher, schema):
        hasher.update(
            json.dumps(schema, sort_keys=True, separators=(",", ":")).encode()
        )

    def _get_sample_indices(self, total_rows: int, sample_size: int) -> list:
        chunk = sample_size // 3
        indices = list(range(min(chunk, total_rows)))
        mid_start = max(0, total_rows // 2 - chunk // 2)
        mid_end = min(mid_start + chunk, total_rows)
        indices.extend(range(mid_start, mid_end))
        last_start = max(0, total_rows - chunk)
        indices.extend(range(last_start, total_rows))
        return sorted(list(set(indices)))


def _plot_tsne_scatter(
    ax, x_data, y_data, x_label, y_label, hue_data, hue_col, palette
):
    """Plot T-SNE scatter plot with selected axes."""
    if hue_data is not None:
        scatter_df = pd.DataFrame({x_label: x_data, y_label: y_data, hue_col: hue_data})
        sns.scatterplot(
            data=scatter_df,
            x=x_label,
            y=y_label,
            hue=hue_col,
            palette=palette,
            ax=ax,
            alpha=0.7,
            s=50,
        )
        ax.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
    else:
        ax.scatter(x_data, y_data, alpha=0.7, s=50)
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_title("t-SNE Projection")
    ax.grid(True, alpha=0.3)


def tsne_analysis(
    data: Union[DataFrame, ArrowTable], options=None
) -> Tuple[Any, Union[MediaData, PILImage], DataFrame]:
    options = options or {}
    verbose = options.get("verbose", True)
    random_state = options.get("random_state", 42)
    activate_caching = options.get("activate_caching", False)
    output_type = options.get("output_type", "array")
    columns = options.get("columns", None)
    n_components = options.get("n_components", 2)
    scale_data = options.get("scale_data", True)
    id_column = options.get("id_column", "")
    hue = options.get("hue", "")
    exclude_hue = options.get("exclude_hue", True)
    palette = options.get("palette", "husl")
    n_jobs_str = options.get("n_jobs", "1")
    n_jobs_int = -1 if n_jobs_str == "All" else int(n_jobs_str)
    pc_x = options.get("pc_x", 1)
    pc_y = options.get("pc_y", 2)
    perplexity = options.get("perplexity", 30.0)
    auto_lr = options.get("auto_learning_rate", True)
    custom_lr = options.get("learning_rate", 200.0)
    learning_rate = "auto" if auto_lr else custom_lr
    n_iter = options.get("n_iter", 1000)
    early_exaggeration = options.get("early_exaggeration", 12.0)
    metric_raw = options.get("metric", "Euclidean")
    metric = metric_raw.lower() if metric_raw else "euclidean"
    angle = options.get("angle", 0.5)
    min_grad_norm = options.get("min_grad_norm", 1e-07)
    dpi = 300
    verbose and logger.info(
        f"Starting T-SNE with {n_components} components. Displaying PC{pc_x} vs PC{pc_y}"
    )
    verbose and logger.info(
        f"Config: metric={metric}, lr={learning_rate}, angle={angle}"
    )
    if pc_x > n_components or pc_y > n_components:
        verbose and logger.warning(
            f"Selected PCs (X:{pc_x}, Y:{pc_y}) exceed n_components ({n_components}). Clamping to valid range."
        )
        pc_x = min(pc_x, n_components)
        pc_y = min(pc_y, n_components)
    TSNE_Image = None
    TSNE_Projections = pd.DataFrame()
    Scaler = None
    plot_components = None
    hue_data = None
    hue_col = None
    df = None
    try:
        if isinstance(data, pl.DataFrame):
            verbose and logger.info(f"Converting polars DataFrame to pandas")
            df = data.to_pandas()
        elif isinstance(data, pa.Table):
            verbose and logger.info(f"Converting Arrow table to pandas")
            df = data.to_pandas()
        elif isinstance(data, pd.DataFrame):
            verbose and logger.info(f"Input is already pandas DataFrame")
            df = data
        else:
            raise ValueError(f"Unsupported data type: {type(data).__name__}")
    except Exception as e:
        raise RuntimeError(
            f"Failed to convert input data to pandas DataFrame: {e}"
        ) from e
    if df.empty:
        raise ValueError("Input DataFrame is empty")
    verbose and logger.info(
        f"Processing DataFrame with {df.shape[0]:,} rows × {df.shape[1]:,} columns"
    )
    if hue and hue.strip():
        if hue not in df.columns:
            raise ValueError(f"Hue column '{hue}' not found")
        hue_col = hue
    if id_column and id_column.strip() and (id_column not in df.columns):
        raise ValueError(f"ID column '{id_column}' not found in DataFrame")
    if columns and len(columns) > 0:
        missing_cols = [col for col in columns if col not in df.columns]
        if missing_cols:
            raise ValueError(f"Columns not found in DataFrame: {missing_cols}")
        feature_cols = list(columns)
    else:
        feature_cols = df.select_dtypes(include=["number", "bool"]).columns.tolist()
        if not feature_cols:
            raise ValueError("No numeric columns found in DataFrame")
    if id_column and id_column.strip() and (id_column in feature_cols):
        feature_cols.remove(id_column)
    if exclude_hue and hue_col and (hue_col in feature_cols):
        feature_cols.remove(hue_col)
    if not feature_cols:
        raise ValueError("No feature columns remaining after excluding ID column")
    skip_computation = False
    cache_folder = None
    all_hash = None
    if activate_caching:
        verbose and logger.info(f"Caching is active")
        data_hasher = _UniversalDatasetHasher(df.shape[0], verbose=verbose)
        X_hash = data_hasher.hash_data(df[feature_cols]).hash
        all_hash_base_text = f"HASH BASE TEXT TSNE_V3Scikit Learn Version {sklearn.__version__}{X_hash}_{n_components}_{scale_data}_{random_state}{perplexity}_{learning_rate}_{n_iter}_{early_exaggeration}{metric}_{angle}_{min_grad_norm}{exclude_hue}{sorted(feature_cols)}"
        all_hash = hashlib.sha256(all_hash_base_text.encode("utf-8")).hexdigest()
        verbose and logger.info(f"Hash was computed: {all_hash}")
        temp_folder = Path(tempfile.gettempdir())
        cache_folder = temp_folder / "coded-flows-cache"
        cache_folder.mkdir(parents=True, exist_ok=True)
        tsne_proj_path = cache_folder / f"tsne_proj_{all_hash}.parquet"
        tsne_est_path = cache_folder / f"tsne_estimator_{all_hash}.joblib"
        scaler_path = cache_folder / f"pca_scaler_{all_hash}.joblib"
        if (
            tsne_proj_path.is_file()
            and tsne_est_path.is_file()
            and scaler_path.is_file()
        ):
            verbose and logger.info(f"Cache hit! Loading results.")
            try:
                TSNE_Projections = pd.read_parquet(tsne_proj_path)
                tsne = joblib.load(tsne_est_path)
                Scaler = joblib.load(scaler_path)
                comps_for_plot_df = TSNE_Projections.copy()
                if id_column and id_column in comps_for_plot_df.columns:
                    comps_for_plot_df = comps_for_plot_df.drop(columns=[id_column])
                plot_components = comps_for_plot_df.values
                skip_computation = True
            except Exception as e:
                verbose and logger.warning(f"Cache load failed, recomputing: {e}")
                skip_computation = False
    if not skip_computation:
        X = df[feature_cols]
        if X.isna().any().any():
            verbose and logger.warning(f"Removing rows with missing values")
            mask = ~X.isna().any(axis=1)
            X = X[mask]
            df_indices = df.index[mask]
        else:
            df_indices = df.index
        if X.shape[0] <= perplexity:
            verbose and logger.warning(
                f"Data samples ({X.shape[0]}) <= Perplexity ({perplexity}). Adjusting perplexity."
            )
            perplexity = max(1.0, float(X.shape[0] - 1))
        if scale_data:
            verbose and logger.info(f"Scaling features")
            Scaler = StandardScaler()
            X_transformed = Scaler.fit_transform(X)
        else:
            X_transformed = X
        verbose and logger.info(f"Fitting T-SNE model (metric={metric})")
        tsne = TSNE(
            n_components=n_components,
            perplexity=perplexity,
            learning_rate=learning_rate,
            max_iter=n_iter,
            early_exaggeration=early_exaggeration,
            metric=metric,
            angle=angle,
            min_grad_norm=min_grad_norm,
            random_state=random_state,
            n_jobs=n_jobs_int,
        )
        plot_components = tsne.fit_transform(X_transformed)
        tsne_cols = [f"TSNE{i + 1}" for i in range(n_components)]
        TSNE_Projections = pd.DataFrame(
            plot_components, columns=tsne_cols, index=df_indices
        )
        if id_column and id_column.strip():
            id_data = df.loc[df_indices, id_column].values
            TSNE_Projections.insert(0, id_column, id_data)
        if activate_caching and cache_folder and all_hash:
            try:
                tsne_proj_path = cache_folder / f"tsne_proj_{all_hash}.parquet"
                tsne_est_path = cache_folder / f"tsne_estimator_{all_hash}.joblib"
                scaler_path = cache_folder / f"pca_scaler_{all_hash}.joblib"
                TSNE_Projections.to_parquet(tsne_proj_path)
                joblib.dump(tsne, tsne_est_path)
                if Scaler is not None:
                    joblib.dump(Scaler, scaler_path)
                verbose and logger.info(f"Results saved to cache")
            except Exception as e:
                verbose and logger.warning(f"Failed to save cache: {e}")
    try:
        idx_x = pc_x - 1
        idx_y = pc_y - 1
        x_values = plot_components[:, idx_x]
        y_values = plot_components[:, idx_y]
        x_label = f"t-SNE Dimension {pc_x}"
        y_label = f"t-SNE Dimension {pc_y}"
        if hue_col:
            hue_data = df.loc[TSNE_Projections.index, hue].values
            verbose and logger.info(f"Using hue column: '{hue}'")
        verbose and logger.info(
            f"Creating scatter visualization ({x_label} vs {y_label})"
        )
        (fig, ax) = plt.subplots(figsize=(10, 7))
        _plot_tsne_scatter(
            ax, x_values, y_values, x_label, y_label, hue_data, hue_col, palette
        )
        verbose and logger.info(f"Rendering to {output_type}")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
        buf.seek(0)
        if output_type == "bytesio":
            TSNE_Image = buf
        elif output_type == "bytes":
            TSNE_Image = buf.getvalue()
            buf.close()
        elif output_type == "pil":
            TSNE_Image = Image.open(buf)
            TSNE_Image.load()
            buf.close()
        elif output_type == "array":
            img = Image.open(buf)
            TSNE_Image = np.array(img)
            buf.close()
        else:
            raise ValueError(f"Invalid output_type: '{output_type}'")
        plt.close(fig)
    except Exception as e:
        verbose and logger.error(f"Error during plotting: {e}")
        plt.close("all")
        raise RuntimeError(f"Plotting failed: {e}") from e
    verbose and logger.info(f"Successfully completed T-SNE analysis")
    return (Scaler, TSNE_Image, TSNE_Projections)
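
A minimal example call of the function above, using synthetic data and a small options dictionary (the column names are illustrative; options that are omitted fall back to the defaults read in the code):

import numpy as np
import pandas as pd

rng = np.random.RandomState(0)
demo = pd.DataFrame(
    {
        "feature_a": rng.rand(200),
        "feature_b": rng.rand(200),
        "feature_c": rng.rand(200),
        "label": rng.choice(["x", "y"], size=200),
    }
)

scaler, image, projections = tsne_analysis(
    demo,
    options={
        "n_components": 2,
        "perplexity": 30.0,
        "hue": "label",
        "exclude_hue": True,
        "output_type": "pil",
        "random_state": 0,
    },
)
# `projections` holds TSNE1/TSNE2 (plus the ID column if one was configured);
# `image` is a PIL.Image scatter plot colored by the "label" column.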

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • shap>=0.47.0
  • scikit-learn
  • pandas
  • pyarrow
  • polars[pyarrow]
  • numpy
  • numba>=0.56.0
  • joblib
  • matplotlib
  • seaborn
  • pillow
  • xxhash