SHAP Impact Distribution

Visualizes the distribution of feature impacts. Shows how high/low feature values affect predictions.

Processing

This brick generates a visual summary (specifically a "beeswarm" plot) that reveals how your model makes decisions. It computes the SHAP value (the feature's contribution to the prediction) of every feature for every row in your dataset, showing two things at once: which features matter most overall, and how high or low values of those features shift the prediction.

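In the resulting plot, features are sorted top to bottom by importance (their mean absolute SHAP value). Each dot is one data point's SHAP value for that feature. The color shows the feature's value (red = high, blue = low), and the horizontal position shows the SHAP value itself: dots to the right of zero pushed that prediction higher, dots to the left pushed it lower, and the distance from zero reflects how strongly.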
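To read one row of the plot numerically, you can compare the average SHAP value of rows with high versus low values of a feature. The sketch below is only an illustration, not something the brick computes: the helper name high_low_effect is hypothetical, the median split is an assumption standing in for the continuous red/blue scale, and the numbers are made up.

```python
import numpy as np


def high_low_effect(feature_values, shap_column):
    """Mean SHAP value for rows above vs. at-or-below the feature's median."""
    x = np.asarray(feature_values, dtype=float)
    s = np.asarray(shap_column, dtype=float)
    median = np.median(x)
    above = x > median
    # Rows above the median sit toward the red (high) end of the color scale,
    # rows at or below it toward the blue (low) end.
    high = float(s[above].mean()) if above.any() else float("nan")
    low = float(s[~above].mean()) if (~above).any() else float("nan")
    return high, low


# Illustrative values only: one feature (e.g. age) and its SHAP values for four rows.
age = [20, 30, 40, 50]
age_shap = [-0.4, -0.1, 0.2, 0.5]
high, low = high_low_effect(age, age_shap)
# high is about 0.35 (high ages push right), low is about -0.25 (low ages push left)
```

A positive "high" and negative "low" corresponds to red dots clustering right of zero and blue dots clustering left, the pattern you would see in the beeswarm for this feature.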

Inputs

explainer
The initialized SHAP explainer for your specific machine learning model: an instance of shap.Explainer or one of its subclasses (e.g., shap.TreeExplainer). Any other object is rejected with an error.
data
The dataset you want to analyze. The brick calls the explainer on this data to compute the SHAP values, so it must be in the same format (columns and preprocessing) that the model expects.
original data (optional)
The raw dataset containing the original, human-readable values. It must contain the same rows and columns, in the same order, as data.
  • Why use this? The data input is often preprocessed (e.g., normalized between 0 and 1). When original data is provided, the brick replaces the feature values (and, when available, the column names) stored in the explanation with the original ones, so the red/blue coloring reflects real-world values (e.g., Age "25" instead of "0.3"). The SHAP values themselves are still computed from data.
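Because the original values are substituted row by row, the two tables have to line up exactly. The sketch below shows a hypothetical min-max scaled "age" column alongside its raw values, plus a hypothetical check_alignment helper that rejects mismatched shapes; neither the column nor the helper comes from the brick itself.

```python
import numpy as np

# Hypothetical example: an "age" column min-max scaled before it reaches the model.
age = np.array([19.0, 25.0, 39.0])
scaled_age = (age - age.min()) / (age.max() - age.min())  # roughly [0.0, 0.3, 1.0]


def check_alignment(data, original_data):
    """Raise if the two tables cannot describe the same rows and columns."""
    data = np.asarray(data)
    original_data = np.asarray(original_data)
    if data.shape != original_data.shape:
        raise ValueError(
            f"Shape mismatch: data {data.shape} vs original {original_data.shape}"
        )


# The model sees the scaled column; the plot would color row 1 by age 25.0, not 0.3.
check_alignment(scaled_age.reshape(-1, 1), age.reshape(-1, 1))
```

A shape check only catches missing or extra rows and columns; it cannot detect rows that were shuffled, so both tables should come from the same ordered source.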

Input Types

Input          Types
explainer      Any
data           DataFrame
original data  DataFrame

You can check the list of supported types here: Available Type Hints.

Outputs

image
The generated beeswarm plot, rendered as a PNG at 300 DPI. Depending on the Output Image Format option, it is returned as a PIL image, raw PNG bytes, or a numpy array.
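All three return formats are views of the same PNG. The sketch below shows one way to convert PNG bytes into each format with Pillow and numpy, similar in spirit to the brick's final step; the helper name convert_png is hypothetical, and the tiny red image stands in for a rendered plot (a real plot's pixel size depends on the figure size).

```python
import io

import numpy as np
from PIL import Image


def convert_png(png_bytes, image_format="pil"):
    """Turn PNG bytes into the representation named by image_format."""
    if image_format == "bytes":
        return png_bytes  # raw PNG file data, unchanged
    image = Image.open(io.BytesIO(png_bytes))
    if image_format == "pil":
        return image
    if image_format == "array":
        return np.array(image)  # shape (height, width, channels)
    raise ValueError(f"Unsupported image_format: {image_format!r}")


# Stand-in for a rendered plot: a 3x2 (width x height) opaque red RGBA image.
buf = io.BytesIO()
Image.new("RGBA", (3, 2), (255, 0, 0, 255)).save(buf, format="PNG")
png = buf.getvalue()

arr = convert_png(png, "array")
print(arr.shape)  # (2, 3, 4): rows are height, columns are width
```

Note that numpy puts height first while Pillow reports sizes as (width, height), which is easy to mix up when post-processing the array output.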

Output Types

Output  Types
image   MediaData, PILImage

You can check the list of supported types here: Available Type Hints.

Options

The SHAP Impact Distribution brick exposes the following options:

Class Index (Multiclass)
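Used only when the explanation contains one set of SHAP values per class (e.g., a multiclass model predicting "Red", "Green", or "Blue"). This zero-based integer selects which class to plot (default 0). An out-of-range index falls back to class 0, with a warning in the logs when Verbose is enabled.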
Max Features to Display
Controls how many rows are shown on the y-axis of the plot (default 10). Features are ranked by mean absolute SHAP value; if the model has more features than this limit, the lowest-ranked ones are collapsed into a single "Sum of N other features" row.
Output Image Format
Determines the technical format of the output image.
  • pil (default): Returns a PIL (Pillow) Image object. Best for further image manipulation within the workflow.
  • bytes: Returns the raw PNG file data.
  • array: Returns the image as a numpy array of pixel values (height × width × channels).
Verbose
When enabled (the default), progress details such as the detected output type, the number of classes, and the selected class index are written to the logs.
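The options arrive as a plain dictionary. The key names and defaults below are taken from the options.get(...) calls in the brick's code; the helpers resolve_options and select_class are hypothetical sketches of how the defaults merge and how the Class Index picks a 2-D slice out of 3-D multiclass SHAP values. Unlike this sketch, the listed code does not validate image_format up front.

```python
import numpy as np

# Defaults mirroring the options.get(...) calls in the brick's code.
DEFAULTS = {"class_index": 0, "max_display": 10, "image_format": "pil", "verbose": True}


def resolve_options(options=None):
    """Merge user options over the defaults and reject unknown image formats."""
    resolved = {**DEFAULTS, **(options or {})}
    if resolved["image_format"] not in ("pil", "bytes", "array"):
        raise ValueError(f"Unsupported image_format: {resolved['image_format']!r}")
    return resolved


def select_class(shap_values, class_index):
    """Return a (samples x features) slice; out-of-range indices fall back to 0."""
    values = np.asarray(shap_values)
    if values.ndim == 2:  # regression or single-output classifier: nothing to pick
        return values
    if values.ndim != 3:
        raise ValueError(f"Expected 2 or 3 dimensions, got {values.ndim}")
    num_classes = values.shape[2]
    if not 0 <= class_index < num_classes:
        class_index = 0  # same fallback the brick applies
    return values[:, :, class_index]


opts = resolve_options({"max_display": 5, "class_index": 1})
multiclass = np.arange(24).reshape(2, 4, 3)  # 2 samples, 4 features, 3 classes
picked = select_class(multiclass, opts["class_index"])  # shape (2, 4)
```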
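The interaction between importance ranking and Max Features to Display can be sketched as follows. This is a simplified reimplementation written under the assumption that shap.plots.beeswarm ranks by mean absolute SHAP value and, when there are more features than max_display, shows max_display - 1 features plus one "Sum of N other features" row; it is not shap's code, ignores custom orderings and feature clustering, and breaks ties by column order, which may differ from shap. The helper name displayed_rows is hypothetical.

```python
import numpy as np


def displayed_rows(shap_values, feature_names, max_display=10):
    """Labels of the y-axis rows, most important first."""
    importance = np.abs(np.asarray(shap_values)).mean(axis=0)  # mean |SHAP| per feature
    order = [int(i) for i in np.argsort(-importance, kind="stable")]
    if len(order) <= max_display:
        return [feature_names[i] for i in order]
    kept = order[: max_display - 1]  # leave one row for the collapsed remainder
    rest = len(order) - len(kept)
    return [feature_names[i] for i in kept] + [f"Sum of {rest} other features"]


values = np.array([[0.5, -0.1, 0.0, 0.2], [-0.3, 0.2, 0.05, -0.2]])
names = ["a", "b", "c", "d"]
print(displayed_rows(values, names, max_display=3))
# ['a', 'd', 'Sum of 2 other features']
```

Here feature "a" has mean |SHAP| 0.4 and "d" has 0.2, so they keep their own rows, while "b" (0.15) and "c" (0.025) are merged.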
```python
import inspect
import io
import logging

import matplotlib

matplotlib.use("Agg")  # headless backend: render figures in memory only

import matplotlib.pyplot as plt
import numpy as np
import shap
from PIL import Image
from coded_flows.types import DataFrame, Any, MediaData, Union, PILImage
from coded_flows.utils import CodedFlowsLogger

logger = CodedFlowsLogger(name="SHAP Impact Distribution", level=logging.INFO)
logging.getLogger("shap").setLevel(logging.ERROR)

SUPPORTED_FORMATS = ("pil", "bytes", "array")


def _call_explainer(explainer, data, **base_kwargs):
    """Run the explainer on `data`, suppressing its progress bar if it accepts `silent`."""
    sig = inspect.signature(explainer.__call__)
    if "silent" in sig.parameters:
        return explainer(data, silent=True, **base_kwargs)
    return explainer(data, **base_kwargs)


def _to_numpy(df):
    """Convert a pandas or polars DataFrame to a numpy array."""
    if "polars" in type(df).__module__:
        return df.to_numpy()
    if hasattr(df, "values"):  # pandas
        return df.values
    if hasattr(df, "to_numpy"):
        return df.to_numpy()
    return np.array(df)


def _get_columns(df):
    """Return the column names of a pandas or polars DataFrame as a list, or None."""
    if hasattr(df, "columns"):
        return list(df.columns)
    return None


def shap_impact(
    explainer: Any,
    data: DataFrame,
    original_data: DataFrame = None,
    options: dict = None,
) -> Union[MediaData, PILImage]:
    options = options or {}
    verbose = options.get("verbose", True)
    class_index = options.get("class_index", 0)
    max_display = options.get("max_display", 10)
    image_format = options.get("image_format", "pil")
    try:
        verbose and logger.info("Starting SHAP impact distribution analysis.")
        if not isinstance(explainer, shap.Explainer):
            raise ValueError("Expects a SHAP Explainer as input.")
        if image_format not in SUPPORTED_FORMATS:
            raise ValueError(
                f"Unsupported image_format {image_format!r}; expected one of {SUPPORTED_FORMATS}."
            )

        # SHAP values are always computed from the (possibly preprocessed) model input.
        shap_explanation = _call_explainer(explainer, data)

        if original_data is not None:
            verbose and logger.info(
                "Original data provided. Overwriting explanation data for visualization."
            )
            original_values = _to_numpy(original_data)
            # Rows and columns must line up with the explanation (samples x features).
            if tuple(original_values.shape) != tuple(shap_explanation.shape[:2]):
                raise ValueError(
                    f"original_data shape {original_values.shape} does not match "
                    f"the explanation shape {tuple(shap_explanation.shape[:2])}."
                )
            columns = _get_columns(original_data)
            if columns is not None:
                shap_explanation.feature_names = columns
            shap_explanation.data = original_values

        ndim = len(shap_explanation.shape)
        if ndim == 2:
            # One SHAP value per (sample, feature): regression or single-output model.
            verbose and logger.info(
                "Detected Regression or Binary Classification (single output)."
            )
            explanation_to_plot = shap_explanation
        elif ndim == 3:
            # One SHAP value per (sample, feature, class): plot a single class.
            num_classes = shap_explanation.shape[2]
            verbose and logger.info(
                f"Detected Multiclass output with {num_classes} classes."
            )
            if not 0 <= class_index < num_classes:
                verbose and logger.warning(
                    f"Selected class_index {class_index} is out of bounds "
                    f"(0-{num_classes - 1}). Defaulting to 0."
                )
                class_index = 0
            verbose and logger.info(f"Selecting class index: {class_index}")
            explanation_to_plot = shap_explanation[:, :, class_index]
        else:
            raise ValueError(
                f"Unexpected SHAP explanation dimensionality: {ndim}. Expected 2 or 3."
            )

        verbose and logger.info(
            f"Generating Beeswarm Plot with max_display={max_display}."
        )
        fig = plt.figure()
        try:
            # beeswarm draws on the current figure; show=False keeps it open for saving.
            shap.plots.beeswarm(explanation_to_plot, max_display=max_display, show=False)
            buf = io.BytesIO()
            plt.savefig(buf, format="png", bbox_inches="tight", dpi=300)
        finally:
            # Close the figures even if plotting fails, so repeated runs do not leak memory.
            plt.close(plt.gcf())
            plt.close(fig)
        buf.seek(0)

        if image_format == "bytes":
            image = buf.getvalue()  # raw PNG file data
        else:
            image = Image.open(buf)
            if image_format == "array":
                image = np.array(image)  # (height, width, channels) pixel array
        verbose and logger.info("Image generation complete.")
    except Exception as e:
        verbose and logger.error(f"Error computing SHAP impact distribution: {e}")
        raise
    return image
```

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • shap>=0.47.0
  • matplotlib
  • numpy
  • pandas
  • polars[pyarrow]
  • pillow
  • numba>=0.56.0