SHAP Dependence

Visualizes the relationship between a feature's value and its SHAP value. Supports coloring by a second feature to reveal interactions.


Processing

This brick generates a "Dependence Plot" to visualize how a specific feature affects your model's predictions. It creates a scatter plot where the X-axis represents the actual value of a feature (e.g., "Age") and the Y-axis represents that feature's impact on the prediction (its SHAP value).
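To make the two axes concrete, here is a minimal sketch that computes exact Shapley values by enumerating every feature subset. Features that are "missing" from a subset take their values from a small background set (the interventional formulation). This is only practical for a handful of features and is not how the brick or SHAP computes values, since SHAP uses faster, model-specific algorithms. The helper name `exact_shap`, the linear model, and the data are all hypothetical. Each (x_j, phi_j) pair it produces for a chosen feature j is one dot on that feature's dependence plot.

```python
from itertools import combinations
from math import factorial
from statistics import fmean


def exact_shap(model, x, background):
    """Exact interventional Shapley values for one row `x` (hypothetical helper).

    Features outside a coalition take their values from each background row in
    turn, and the model output is averaged over the background rows.
    """
    n = len(x)

    def value(coalition):
        # Expected output when only the features in `coalition` keep x's values.
        return fmean(
            model([x[i] if i in coalition else b[i] for i in range(n)])
            for b in background
        )

    phis = []
    for j in range(n):
        others = [i for i in range(n) if i != j]
        phi = 0.0
        for size in range(n):
            # Shapley weight for coalitions of this size that exclude feature j.
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                phi += weight * (value(s | {j}) - value(s))
        phis.append(phi)
    return phis


# Hypothetical linear model and data, for illustration only.
weights = [2.0, -1.0, 0.5]


def model(row):
    return sum(w * v for w, v in zip(weights, row))


background = [[0.0, 1.0, 2.0], [2.0, 3.0, 0.0]]
rows = [[1.0, 2.0, 1.0], [3.0, 0.0, 4.0]]

# Each (feature value, SHAP value) pair is one dot on feature 0's dependence plot.
points = [(row[0], exact_shap(model, row, background)[0]) for row in rows]
print(points)
```

For a linear model the result reduces to w_j(x_j minus the background mean of feature j), so its dependence plot is a straight line. In a real plot, curvature suggests a non-linear effect, and vertical spread at the same X value suggests that other features (interactions) are involved.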

The plot answers the question: "How does changing this specific variable change the result?" It can also color the dots by a second feature to reveal interactions: situations where the impact of one feature depends on the value of another (e.g., "Age" might matter more when "Income" is high).

Inputs

explainer
The SHAP Explainer object created by a previous brick (e.g., "Calculate SHAP Values"). It contains the logic needed to interpret your model and must be a shap Explainer instance; any other object raises an error.
data
The dataset containing the rows you want to analyze. The explainer uses this to calculate the impact values.
original data (optional)
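If the data input was normalized, scaled, or encoded (making it hard to read), provide the original, human-readable dataset here. SHAP values are still computed from data, but the X-axis values and feature names are taken from original data, which makes the chart easier to read. The brick replaces the explanation's values and column names without checking them, so original data must contain the same rows, in the same order, and the same columns as data.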
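Because SHAP values are row-aligned with the input, swapping in the original values only relabels the X-axis: every point keeps its SHAP value, and for a monotonic transform such as standardization the left-to-right order of the points is unchanged. The sketch assumes a simple z-score scaler; the ages and SHAP values are hypothetical.

```python
from statistics import fmean, pstdev

# Hypothetical human-readable ages and their standardized (z-score) version.
ages = [23.0, 35.0, 47.0, 59.0, 71.0]
mu, sigma = fmean(ages), pstdev(ages)
scaled = [(a - mu) / sigma for a in ages]

# Hypothetical SHAP values computed from the *scaled* column (row-aligned).
shap_age = [-0.8, -0.3, 0.1, 0.4, 0.9]

# Plot points built from the scaled data vs. the original data.
points_scaled = list(zip(scaled, shap_age))
points_original = list(zip(ages, shap_age))

# Same Y values and the same left-to-right order; only the X-axis labels differ.
order_scaled = sorted(range(len(ages)), key=lambda i: points_scaled[i][0])
order_original = sorted(range(len(ages)), key=lambda i: points_original[i][0])

# Undoing the scaling recovers the readable values used for the X-axis.
recovered = [s * sigma + mu for s in scaled]
print(order_scaled == order_original, recovered)
```

The same reasoning fails if rows are shuffled or columns are reordered between the two datasets: the SHAP values would then be drawn against the wrong X values without any error.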

Input Types

Input          Type
explainer      Any
data           DataFrame
original data  DataFrame

You can check the list of supported types here: Available Type Hints.

Outputs

image
The generated dependence plot, rendered as a PNG. Its Python type (a PIL image, raw PNG bytes, or a NumPy array) is set by the "Output Image Format" option.
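All three output formats carry the same PNG rendering (the brick saves the figure with `format="png"`), so downstream code can convert between them. The sketch below uses Pillow and NumPy, both listed in the brick's requirements; the small blank image is a stand-in for a real plot.

```python
import io

import numpy as np
from PIL import Image

# Stand-in for a rendered plot: a small RGBA image encoded as PNG bytes.
buf = io.BytesIO()
Image.new("RGBA", (40, 30), (255, 255, 255, 255)).save(buf, format="PNG")
png_bytes = buf.getvalue()  # comparable to the "bytes" output

# bytes -> PIL image (comparable to the "pil" output).
pil_image = Image.open(io.BytesIO(png_bytes))

# PIL image -> NumPy array (comparable to the "array" output): (height, width, channels).
array = np.array(pil_image)

# array -> PIL image -> PNG bytes again, e.g. before writing to a file.
out = io.BytesIO()
Image.fromarray(array).save(out, format="PNG")

print(pil_image.size, array.shape)
```

Note that PIL reports size as (width, height), while the NumPy array's shape starts with the height.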

Output Types

Output  Type
image   MediaData, PILImage

You can check the list of supported types here: Available Type Hints.

Options

The SHAP Dependence brick has the following configurable options:

Feature to Plot
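The exact name of the feature (column) to place on the X-axis; this is the primary variable you are analyzing. If no column has that name, a positional name such as "Feature 2" (zero-based) is accepted; anything else falls back to the first feature with a warning. The default, "Feature 0", selects the first column.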
Interaction Feature
Determines how the dots in the scatter plot are colored. If left empty or set to "Auto" (the default), SHAP automatically selects the feature that appears to interact most strongly with the "Feature to Plot" and colors by it. Type the name of a specific feature to force coloring by that variable; a name that does not match any column falls back to automatic selection.
Class Index (Multiclass)
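Used only for classification models that predict multiple categories (e.g., Red, Green, Blue). It is the zero-based index of the class you want to explain; an out-of-range index falls back to 0. It is ignored when the explainer returns a single output per row (e.g., regression).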
Output Image Format
Defines the technical format of the output image.
  • pil: Returns a Python Imaging Library (PIL) image object (e.g., for saving or further processing).
  • bytes: Returns the raw PNG-encoded image data.
  • array: Returns the image as a NumPy array of pixel values.
Verbose
When enabled (the default), the brick logs detailed progress information to the console, which is helpful for troubleshooting.
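The brick resolves "Feature to Plot" and "Class Index" leniently instead of failing. The standalone sketch below mirrors that fallback logic from the brick's source so it can be tried without SHAP: an exact column name wins; otherwise a "Feature N" string is read as a zero-based column index; anything else falls back to the first feature with a warning, and an out-of-range class index falls back to 0. The function names `resolve_feature` and `resolve_class_index` and the sample column names are hypothetical.

```python
import logging

logger = logging.getLogger("shap_dependence_sketch")


def resolve_feature(feature_name, available_features):
    """Mirror of the brick's Feature to Plot fallback (hypothetical helper)."""
    if feature_name in available_features:
        return feature_name
    try:
        # "Feature 2" -> 2, interpreted as a zero-based column position.
        idx = int(str(feature_name).replace("Feature ", ""))
        if 0 <= idx < len(available_features):
            return available_features[idx]
    except ValueError:
        pass
    logger.warning("Feature %r not found. Defaulting to first feature.", feature_name)
    return available_features[0]


def resolve_class_index(class_index, num_classes):
    """Mirror of the brick's Class Index fallback for multiclass output."""
    if 0 <= class_index < num_classes:
        return class_index
    logger.warning("Class index %s out of bounds. Defaulting to 0.", class_index)
    return 0


features = ["age", "income", "tenure"]
print(resolve_feature("income", features))     # exact name -> "income"
print(resolve_feature("Feature 2", features))  # positional -> "tenure"
print(resolve_feature("salary", features))     # unknown -> "age"
print(resolve_class_index(5, num_classes=3))   # out of range -> 0
```

Because these fallbacks only log a warning, a typo in an option silently produces a plot of a different feature or class; keeping Verbose enabled makes such substitutions visible.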
import logging
import io
import inspect
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image
from shap.explainers._explainer import Explainer
from coded_flows.types import DataFrame, Any, MediaData, Union, PILImage
from coded_flows.utils import CodedFlowsLogger

matplotlib.use("Agg")
import shap

logger = CodedFlowsLogger(name="SHAP Dependence", level=logging.INFO)
logging.getLogger("shap").setLevel(logging.ERROR)


def _call_explainer(explainer, data, **base_kwargs):
    sig = inspect.signature(explainer.__call__)
    if "silent" in sig.parameters:
        return explainer(data, silent=True, **base_kwargs)
    return explainer(data, **base_kwargs)


def _to_numpy(df):
    """Convert pandas or polars DataFrame to numpy array."""
    if hasattr(df, "__class__") and "polars" in df.__class__.__module__:
        return df.to_numpy()
    elif hasattr(df, "values"):
        return df.values
    elif hasattr(df, "to_numpy"):
        return df.to_numpy()
    else:
        return np.array(df)


def _get_columns(df):
    """Get column names from pandas or polars DataFrame."""
    if hasattr(df, "columns"):
        columns = df.columns
        return list(columns) if not isinstance(columns, list) else columns
    return None


def shap_dependence(
    explainer: Any,
    data: DataFrame,
    original_data: DataFrame = None,
    options: dict = None,
) -> Union[MediaData, PILImage]:
    options = options or {}
    verbose = options.get("verbose", True)
    class_index = options.get("class_index", 0)
    feature_name = options.get("feature_name", "Feature 0")
    interaction_feature = options.get("interaction_feature", "Auto")
    image_format = options.get("image_format", "pil")
    image = None
    try:
        verbose and logger.info("Starting SHAP Interaction Scatter generation.")
        if not isinstance(explainer, Explainer):
            verbose and logger.error("Expects a Shap Explainer as an input.")
            raise ValueError("Expects a Shap Explainer as an input.")
        shap_explanation = _call_explainer(explainer, data)
        if original_data is not None:
            verbose and logger.info(
                "Original data provided. Overwriting explanation data."
            )
            columns = _get_columns(original_data)
            if columns is not None:
                shap_explanation.feature_names = columns
            shap_explanation.data = _to_numpy(original_data)
        ndim = len(shap_explanation.shape)
        explanation_to_process = None
        if ndim == 2:
            verbose and logger.info("Detected Regression or Binary Classification.")
            explanation_to_process = shap_explanation
        elif ndim == 3:
            num_classes = shap_explanation.shape[2]
            verbose and logger.info(
                f"Detected Multiclass output with {num_classes} classes."
            )
            if class_index >= num_classes or class_index < 0:
                verbose and logger.warning(
                    f"Class index {class_index} out of bounds. Defaulting to 0."
                )
                class_index = 0
            verbose and logger.info(f"Selecting class index: {class_index}")
            explanation_to_process = shap_explanation[:, :, class_index]
        else:
            raise ValueError(f"Unexpected SHAP explanation dimensionality: {ndim}")
        available_features = explanation_to_process.feature_names
        if not available_features:
            available_features = [
                f"Feature {i}" for i in range(explanation_to_process.shape[1])
            ]
            explanation_to_process.feature_names = available_features
        if feature_name not in available_features:
            try:
                f_idx = int(feature_name.replace("Feature ", ""))
                if 0 <= f_idx < len(available_features):
                    feature_name = available_features[f_idx]
                else:
                    raise ValueError
            except (ValueError, AttributeError):
                verbose and logger.warning(
                    f"Feature '{feature_name}' not found. Defaulting to first feature."
                )
                feature_name = available_features[0]
        verbose and logger.info(f"Plotting scatter for feature: '{feature_name}'")
        fig = plt.figure()
        color_arg = None
        if interaction_feature in available_features:
            # An explicit column name forces coloring by that feature.
            color_arg = explanation_to_process[:, interaction_feature]
            verbose and logger.info(f"Interaction coloring: '{interaction_feature}'.")
        else:
            # Passing the whole explanation lets SHAP pick the interaction feature.
            color_arg = explanation_to_process
            if interaction_feature in (None, "", "Auto"):
                verbose and logger.info(
                    "Interaction coloring: Auto (Strongest interaction)."
                )
            else:
                verbose and logger.warning(
                    f"Interaction feature '{interaction_feature}' not found. "
                    "Defaulting to Auto."
                )
        shap.plots.scatter(
            explanation_to_process[:, feature_name], color=color_arg, show=False
        )
        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight", dpi=300)
        buf.seek(0)
        plt.close(plt.gcf())
        plt.close(fig)
        if image_format == "pil":
            image = Image.open(buf)
        elif image_format == "array":
            image = np.array(Image.open(buf))
        elif image_format == "bytes":
            image = buf.getvalue()
        else:
            # Fail loudly instead of silently returning None.
            raise ValueError(
                f"Unsupported image format '{image_format}'. "
                "Expected 'pil', 'bytes', or 'array'."
            )
        verbose and logger.info("Image generation complete.")
    except Exception as e:
        verbose and logger.error(f"Error computing SHAP scatter: {str(e)}")
        raise
    return image

Brick Info

version v0.1.4
python 3.11, 3.12, 3.13
requirements
  • shap>=0.47.0
  • matplotlib
  • numpy
  • pandas
  • pillow
  • numba>=0.56.0