Pairplot Image

Generate a pairplot visualization. Automatically adjusts DPI based on feature count to prevent memory errors.

Pairplot Image

Processing

This function takes tabular data (Pandas DataFrame, Polars DataFrame, or PyArrow Table) and generates a pairplot visualization, which displays pairwise relationships between features. It automatically selects all numeric columns if no specific columns are provided, and allows conditioning the visualization using a hue column. The resulting image is rendered to memory and returned in a user-specified format (NumPy array, PIL Image, bytes, or BytesIO stream).

Inputs

data
Input data used for visualization, typically containing multiple numeric features.

Inputs Types

Input Types
data DataFrame, ArrowTable

You can check the list of supported types here: Available Type Hints.

Outputs

image
The generated pairplot visualization. The specific format depends on the 'Output Type' option selected.

Outputs Types

Output Types
image MediaData, PILImage

You can check the list of supported types here: Available Type Hints.

Options

The Pairplot Image brick contains some changeable options:

Columns to Plot
List of specific columns to include in the pairplot matrix. If left empty, the function defaults to using all numeric columns found in the input data.
Hue Column
Name of the column used to color-code the points in the plot based on its categorical values. Leave empty to draw the plot without hue grouping.
Color Palette
The color scheme used for rendering the plot. Available choices include standard Seaborn palettes like husl, deep, muted, etc.
Diagonal Plot Type
Specifies the type of plot drawn on the diagonal axes: auto (the default, which lets Seaborn choose), hist (histogram), or kde (Kernel Density Estimate).
Only Lower
If enabled, only the lower triangle of the plot matrix is drawn. The upper triangle mirrors the lower one, so omitting it gives a cleaner output without losing information.
Output Type
Defines the format of the returned image object: NumPy array (array), PIL Image object (pil), raw bytes (bytes), or BytesIO stream (bytesio).
Verbose
If enabled, detailed logs and information about the execution process are printed.
import logging
import io
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import matplotlib

# Select the non-interactive Agg backend before pyplot is imported, so figures
# are rendered off-screen (the plot is only ever saved to an in-memory buffer).
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from coded_flows.types import Union, DataFrame, ArrowTable, MediaData, PILImage
from coded_flows.utils import CodedFlowsLogger

# Module-level logger for this brick, created at INFO level; `pairplot` only
# writes to it when its "verbose" option is truthy.
logger = CodedFlowsLogger(name="Pairplot Image", level=logging.INFO)


def pairplot(
    data: Union[DataFrame, ArrowTable], options=None
) -> Union[MediaData, PILImage]:
    """Render a Seaborn pairplot of ``data`` and return it as an in-memory image.

    The input is converted to a Pandas DataFrame, the columns to plot are
    resolved, the figure is rendered to a PNG buffer at a DPI that shrinks as
    the number of features grows, and the PNG is converted to the requested
    output representation.

    Args:
        data: A Pandas DataFrame, Polars DataFrame, or PyArrow Table.
        options: Optional dict of settings. Recognised keys:
            ``verbose`` (bool, default ``True``): emit log messages.
            ``output_type`` (str, default ``"array"``): one of ``"array"``,
                ``"pil"``, ``"bytes"``, ``"bytesio"``.
            ``columns`` (list of column names, default ``None``): columns to
                plot; when empty, all numeric columns are used.
            ``hue`` (str, default ``""``): column used for color grouping.
            ``palette`` (default ``"husl"``): palette passed to Seaborn, only
                applied when a hue column is set.
            ``diag_kind`` (default ``"auto"``): diagonal plot kind.
            ``corner`` (bool, default ``False``): draw only the lower triangle.

    Returns:
        The rendered image as a NumPy array, PIL image, ``bytes``, or
        ``io.BytesIO`` depending on ``output_type``.

    Raises:
        ValueError: If the data type or ``output_type`` is unsupported, the
            frame is empty, requested columns / hue column are missing, or no
            numeric columns are available.
        RuntimeError: If the Matplotlib figure cannot be obtained from the
            object returned by Seaborn.
    """
    options = options or {}
    verbose = options.get("verbose", True)
    output_type = options.get("output_type", "array")
    columns = options.get("columns", None)
    hue = options.get("hue", "")
    palette = options.get("palette", "husl")
    diag_kind = options.get("diag_kind", "auto")
    corner = options.get("corner", False)
    image = None

    def _info(message):
        # All informational logging is gated behind the "verbose" option.
        if verbose:
            logger.info(message)

    _info(f"Starting pairplot generation with output type: '{output_type}'")
    try:
        # Fail fast: reject an unknown output type before paying for the render.
        if output_type not in ("bytesio", "bytes", "pil", "array"):
            raise ValueError(f"Invalid output_type: '{output_type}'")

        if isinstance(data, pl.DataFrame):
            _info("Converting Polars DataFrame to Pandas")
            df = data.to_pandas()
        elif isinstance(data, pa.Table):
            _info("Converting Arrow Table to Pandas")
            df = data.to_pandas()
        elif isinstance(data, pd.DataFrame):
            _info("Input is already Pandas DataFrame")
            df = data
        else:
            raise ValueError(f"Unsupported data type: {type(data).__name__}")
        if df.empty:
            raise ValueError("Input DataFrame is empty")
        _info(
            f"Processing DataFrame with {df.shape[0]:,} rows × {df.shape[1]:,} columns"
        )

        # A bare string would otherwise be iterated character by character.
        if isinstance(columns, str):
            columns = [columns]
        if columns:
            missing_cols = [col for col in columns if col not in df.columns]
            if missing_cols:
                raise ValueError(f"Columns not found in DataFrame: {missing_cols}")
            # Copy (never mutate the caller's list) and drop repeated names
            # while preserving order, so the selection has unique labels.
            plot_cols = list(dict.fromkeys(columns))
            _info(f"Using specified columns: {plot_cols}")
        else:
            plot_cols = df.select_dtypes(include=[np.number]).columns.tolist()
            if not plot_cols:
                raise ValueError("No numeric columns found in DataFrame")
            _info(f"Using all numeric columns: {plot_cols}")

        hue_col = None
        if hue and hue.strip():
            if hue not in df.columns:
                raise ValueError(f"Hue column '{hue}' not found in DataFrame")
            # The hue column must be part of the frame handed to Seaborn.
            if hue not in plot_cols:
                plot_cols.append(hue)
            hue_col = hue
            _info(f"Using hue column: '{hue}'")

        # Pick the DPI from the largest threshold downwards. Testing ">20"
        # first would swallow every count >= 30 and make the lowest tier
        # unreachable.
        n_features = len(plot_cols)
        if n_features >= 30:
            dpi = 150
            _info(
                f"Feature count is {n_features} (>=30). Reducing DPI to {dpi} for performance."
            )
        elif n_features > 20:
            dpi = 200
            _info(
                f"Feature count is {n_features} (>20). Reducing DPI to {dpi} to prevent memory overflow."
            )
        else:
            dpi = 300
            _info(f"Feature count is {n_features}. Using standard high DPI ({dpi}).")

        pairplot_obj = sns.pairplot(
            df[plot_cols],
            hue=hue_col,
            palette=palette if hue_col else None,
            diag_kind=diag_kind,
            corner=corner,
        )
        _info("Rendering plot to buffer...")
        # Look up the figure under either attribute name ("figure" or "fig").
        fig = getattr(pairplot_obj, "figure", getattr(pairplot_obj, "fig", None))
        if fig is None:
            raise RuntimeError(
                "Could not retrieve Matplotlib Figure from Seaborn PairGrid object."
            )
        buf = io.BytesIO()
        fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
        buf.seek(0)
        plt.close(fig)

        # NOTE(review): this disables Pillow's pixel-count limit process-wide,
        # not just for this call; kept as-is because large pairplots are expected.
        Image.MAX_IMAGE_PIXELS = None
        if output_type == "bytesio":
            image = buf
        elif output_type == "bytes":
            image = buf.getvalue()
            buf.close()
        elif output_type == "pil":
            # The buffer is left open: the returned PIL image still reads from it.
            image = Image.open(buf)
        else:  # "array" (output_type was validated above)
            pil_img = Image.open(buf)
            image = np.array(pil_img)
            buf.close()
        _info(f"Successfully generated pairplot as {output_type}")
    except Exception as e:
        if verbose:
            logger.error(f"Failed to generate pairplot: {e}")
        # Release any figure created before the failure, then propagate.
        plt.close("all")
        raise
    return image

Brick Info

version v0.1.7
python 3.10, 3.11, 3.12, 3.13
requirements
  • matplotlib
  • pandas
  • pyarrow
  • polars[pyarrow]
  • numpy
  • seaborn
  • pillow