SHAP Explanation

Explains a single prediction (row) using a waterfall plot. Shows how each feature contributed to pushing the model output from the base value to the final prediction.

Processing

This brick analyzes a single specific prediction from your machine learning model to explain why that result occurred. It generates a "Waterfall" plot and a detailed data table that visualizes how each individual factor (feature) pushed the prediction outcome higher or lower than the average (base) value.

For example, if a model predicts a house price, this brick helps you understand that while the "Size" added $50k to the value, the "Age" removed $10k. It processes the explanation data to sort features by their impact, ensuring the most important factors are highlighted.
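
As a rough sketch of the arithmetic behind the plot (the numbers below are hypothetical, mirroring the house-price example):

base_value = 300_000  # average model output over the background data
contributions = {"Size": 50_000, "Age": -10_000}  # per-feature SHAP values
prediction = base_value + sum(contributions.values())
print(prediction)  # 340000: the base value plus every push up or down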

Inputs

explainer
The calculated SHAP Explainer object. This is usually the output from a previous "Calculate SHAP" brick. It contains the mathematical rules needed to explain the model.
data
The dataset containing the rows you want to explain. This must match the format of the data used to train the model.
original data
(Optional) A dataset containing the raw, human-readable values. If provided, the waterfall plot will display these values (e.g., "Married") instead of the numerical encodings (e.g., "1") used by the model. This makes the chart easier to read.
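
For instance, a model typically consumes numerical encodings while people read labels, so the two data inputs might look like this (a hypothetical pandas sketch; both frames must have the same shape and column order, since the brick only swaps the displayed values):

import pandas as pd

data = pd.DataFrame({"married": [1, 0], "age": [25, 40]})  # model-ready encoding
original_data = pd.DataFrame(
    {"married": ["Married", "Single"], "age": [25, 40]}  # labels shown on the plot
)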

Inputs Types

Input           Types
explainer       Any
data            DataFrame
original data   DataFrame

You can check the list of supported types here: Available Type Hints.

Outputs

image
The visual Waterfall plot showing the positive and negative forces acting on the prediction.
explanation
A structured list of the mathematical contributions for the specific row analyzed. It is returned as a DataFrame sorted by the magnitude of impact.

The explanation output contains the following specific data fields:

  • feature: The name of the data column (e.g., "Age", "Income"), or "Base Value (Expected)" for the first row.
  • value: The actual value of that feature for this specific row (e.g., "25", "High").
  • contribution: The numerical amount this feature added to or subtracted from the prediction.
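
For the house-price example above, the explanation DataFrame might look like this (a hypothetical sketch; the base value always comes first, and the feature rows are sorted by absolute contribution):

import numpy as np
import pandas as pd

explanation = pd.DataFrame(
    [
        {"feature": "Base Value (Expected)", "value": np.nan, "contribution": 300_000.0},
        {"feature": "Size", "value": 2400, "contribution": 50_000.0},
        {"feature": "Age", "value": 35, "contribution": -10_000.0},
    ]
)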

Outputs Types

Output        Types
image         MediaData, PILImage
explanation   DataFrame

You can check the list of supported types here: Available Type Hints.

Options

The SHAP Explanation brick exposes the following configurable options:

Row Index
The zero-based index of the row in your dataset that you want to explain.
Class Index (Multiclass)
Used only if your model predicts multiple categories (e.g., "Red", "Blue", "Green"). This determines which category's prediction to explain; an out-of-range index logs a warning and falls back to 0.
Max Display
The maximum number of features to show on the plot. If your data has 100 columns, setting this to 20 keeps the chart readable by grouping the smallest contributions into a single "other features" entry.
Image Format
Determines the technical format of the output image.
  • pil: Returns a Python Imaging Library (PIL) object. Best for further image processing in Python.
  • bytes: Returns the raw file bytes.
  • array: Returns the image as a grid of numbers (Numpy array). Best for computer vision tasks.
Verbose
When enabled, the brick logs detailed progress information to the console, which is helpful for troubleshooting.
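
All options are supplied to the brick as a single dictionary; any key you omit falls back to the default used in the source below. A sketch of a fully populated dictionary (the values shown are illustrative):

options = {
    "row_index": 0,         # zero-based row of `data` to explain
    "class_index": 1,       # only consulted for multiclass models
    "max_display": 20,      # features shown before the rest are grouped
    "image_format": "pil",  # "pil", "bytes", or "array"
    "verbose": True,        # log progress information to the console
}

Source Code
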
import logging
import io
import inspect
import pandas as pd
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image
from shap.explainers._explainer import Explainer
from coded_flows.types import DataFrame, Any, Tuple, MediaData, Union, PILImage
from coded_flows.utils import CodedFlowsLogger

matplotlib.use("Agg")
import shap

logger = CodedFlowsLogger(name="SHAP Explanation", level=logging.INFO)
logging.getLogger("shap").setLevel(logging.ERROR)


def _call_explainer(explainer, data, **base_kwargs):
    """Call the explainer, passing silent=True when its signature supports it."""
    sig = inspect.signature(explainer.__call__)
    if "silent" in sig.parameters:
        return explainer(data, silent=True, **base_kwargs)
    return explainer(data, **base_kwargs)


def _to_numpy(df):
    """Convert pandas or polars DataFrame to numpy array."""
    if hasattr(df, "__class__") and "polars" in df.__class__.__module__:
        return df.to_numpy()
    elif hasattr(df, "values"):
        return df.values
    elif hasattr(df, "to_numpy"):
        return df.to_numpy()
    else:
        return np.array(df)


def _get_columns(df):
    """Get column names from pandas or polars DataFrame."""
    if hasattr(df, "columns"):
        columns = df.columns
        return list(columns) if not isinstance(columns, list) else columns
    return None


def shap_explanation(
    explainer: Any,
    data: DataFrame,
    original_data: DataFrame = None,
    options: dict = None,
) -> Tuple[Union[MediaData, PILImage], DataFrame]:
    options = options or {}
    verbose = options.get("verbose", True)
    row_index = options.get("row_index", 0)
    class_index = options.get("class_index", 0)
    max_display = options.get("max_display", 20)
    image_format = options.get("image_format", "pil")
    image = None
    explanation = pd.DataFrame()
    try:
        verbose and logger.info(
            f"Starting SHAP Individual Waterfall for row {row_index}."
        )
        if not isinstance(explainer, Explainer):
            verbose and logger.error("Expects a SHAP Explainer as an input.")
            raise ValueError("Expects a SHAP Explainer as an input.")
        shap_explanation = _call_explainer(explainer, data)
        if original_data is not None:
            verbose and logger.info(
                "Original data provided. Overwriting explanation data."
            )
            columns = _get_columns(original_data)
            if columns is not None:
                shap_explanation.feature_names = columns
            shap_explanation.data = _to_numpy(original_data)
        if row_index >= shap_explanation.shape[0] or row_index < 0:
            raise ValueError(
                f"Row index {row_index} is out of bounds (0-{shap_explanation.shape[0] - 1})."
            )
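        # SHAP returns a 2-D explanation (rows, features) for regression and
        # binary classification, and a 3-D one (rows, features, classes) for
        # multiclass models.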
        ndim = len(shap_explanation.shape)
        single_explanation = None
        if ndim == 2:
            verbose and logger.info("Detected Regression or Binary Classification.")
            single_explanation = shap_explanation[row_index]
        elif ndim == 3:
            num_classes = shap_explanation.shape[2]
            verbose and logger.info(
                f"Detected Multiclass output with {num_classes} classes."
            )
            if class_index >= num_classes or class_index < 0:
                verbose and logger.warning(
                    f"Class index {class_index} out of bounds. Defaulting to 0."
                )
                class_index = 0
            verbose and logger.info(f"Selecting class index: {class_index}")
            single_explanation = shap_explanation[row_index, :, class_index]
        else:
            raise ValueError(f"Unexpected SHAP explanation dimensionality: {ndim}")
        feature_names = single_explanation.feature_names
        if not feature_names:
            feature_names = [f"Feature {i}" for i in range(single_explanation.shape[0])]
        values = np.array(single_explanation.data).flatten()
        contributions = np.array(single_explanation.values).flatten()
        rows = []
        for f, v, c in zip(feature_names, values, contributions):
            rows.append(
                {
                    "feature": f,
                    "value": v,
                    "contribution": c,
                    "abs_contribution": abs(c),
                }
            )
        rows.sort(key=lambda x: x["abs_contribution"], reverse=True)
        for r in rows:
            del r["abs_contribution"]
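        # base_values can be a scalar or a one-element array depending on the
        # explainer, so reduce it to a plain scalar before building the table.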
        base_value = single_explanation.base_values
        if isinstance(base_value, (np.ndarray, list)):
            base_value = base_value if np.ndim(base_value) == 0 else base_value[0]
        rows.insert(
            0,
            {
                "feature": "Base Value (Expected)",
                "value": np.nan,
                "contribution": base_value,
            },
        )
        explanation = pd.DataFrame(rows)
        verbose and logger.info(
            "Constructed sorted contribution DataFrame (Base Value + Features)."
        )
        verbose and logger.info("Generating Waterfall Plot.")
        fig = plt.figure()
        shap.plots.waterfall(single_explanation, max_display=max_display, show=False)
        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight", dpi=300)
        buf.seek(0)
        plt.close(plt.gcf())
        plt.close(fig)
        if image_format == "pil":
            image = Image.open(buf)
        elif image_format == "array":
            img = Image.open(buf)
            image = np.array(img)
        elif image_format == "bytes":
            image = buf.getvalue()
        verbose and logger.info("Image generation complete.")
    except Exception as e:
        verbose and logger.error(f"Error computing SHAP waterfall: {str(e)}")
        raise e
    return (image, explanation)
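
Usage Example

A minimal end-to-end sketch, assuming scikit-learn is installed (the model, dataset, and option values are illustrative; in a real flow, the explainer would come from a "Calculate SHAP" brick):

import pandas as pd
import shap
from sklearn.linear_model import LinearRegression

# Tiny synthetic housing dataset (hypothetical).
X = pd.DataFrame({"size": [1200, 2400, 1800, 3000], "age": [40, 10, 25, 5]})
y = [150_000, 340_000, 240_000, 420_000]
model = LinearRegression().fit(X, y)

explainer = shap.Explainer(model.predict, X)
image, explanation = shap_explanation(
    explainer, X, options={"row_index": 1, "image_format": "pil"}
)
print(explanation)  # base value first, then features sorted by impact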

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • shap>=0.47.0
  • matplotlib
  • numpy
  • pandas
  • pillow
  • numba>=0.56.0