SHAP Impact Distribution
Visualizes the distribution of feature impacts. Shows how high/low feature values affect predictions.
Processing
This brick generates a visual summary (specifically a "beeswarm" plot) that reveals how your model makes decisions. It calculates the impact of every feature for every item in your dataset, showing you two things at once: which features are the most important overall, and how high or low values of those features shift the prediction.
In the resulting plot, features are sorted by importance. Each dot represents a single data point. The color represents the feature's value (Red = High, Blue = Low), and the horizontal position shows whether that value pushed the prediction higher or lower.
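The "sorted by importance" ordering can be illustrated with a small sketch. It assumes the common beeswarm default of ranking features by their mean absolute SHAP value; the function name `rank_features` and the sample numbers are hypothetical and are not part of the brick.

```python
import numpy as np

def rank_features(shap_values, feature_names, max_display=10):
    """Return the top features ordered by mean absolute SHAP value."""
    values = np.asarray(shap_values, dtype=float)
    # Average magnitude of each feature's impact across all rows.
    importance = np.abs(values).mean(axis=0)
    # Largest importance first, truncated to the requested number of features.
    order = np.argsort(-importance, kind="stable")[:max_display]
    return [(feature_names[i], float(importance[i])) for i in order]

# Hypothetical SHAP matrix: 3 samples (rows) x 3 features (columns).
example_values = [
    [0.50, -0.10, 0.00],
    [-0.70, 0.20, 0.05],
    [0.60, -0.30, -0.05],
]
ranking = rank_features(example_values, ["age", "income", "tenure"], max_display=2)
print(ranking)
```

Note that the sign of each value is ignored for ranking: a feature that pushes predictions strongly in either direction counts as important.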
Inputs
- explainer
- An initialized SHAP explainer that has already been built for your specific machine learning model. The brick raises an error if the object is not a SHAP `Explainer` instance.
- data
- The dataset you want to analyze. The explainer will calculate impact scores based on this data.
- original data (optional)
- The raw dataset containing the original, human-readable values.
- Why use this? Often, the `data` input is pre-processed (e.g., normalized between 0 and 1). Providing the `original data` allows the plot to map the impact scores back to the real-world values (e.g., Age "25" instead of "0.3"), ensuring the red/blue color scale corresponds to meaningful numbers.
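Because the brick replaces the values shown in the plot with those from `original data`, both tables should describe the same rows in the same order and have the same columns. The sketch below illustrates this with min-max scaling as a stand-in for any pre-processing step; `min_max_scale` and `same_shape` are hypothetical helpers, not functions provided by the brick.

```python
def min_max_scale(column):
    """Scale a list of numbers to the 0-1 range (stand-in for any pre-processing)."""
    lo, hi = min(column), max(column)
    if hi == lo:
        return [0.0 for _ in column]
    return [(v - lo) / (hi - lo) for v in column]

def same_shape(rows_a, rows_b):
    """Check that two row-major tables have matching row and column counts."""
    if len(rows_a) != len(rows_b):
        return False
    return all(len(a) == len(b) for a, b in zip(rows_a, rows_b))

# Hypothetical raw ages, one row per person.
original_rows = [[25], [40], [65]]
ages = [row[0] for row in original_rows]

# What the model (and the explainer) would see after scaling.
model_rows = [[v] for v in min_max_scale(ages)]

print(model_rows)
print(same_shape(model_rows, original_rows))
```

Running a check like `same_shape` before calling the brick is one way to catch a mismatched `original data` table early, since a misaligned table would color the dots with the wrong values.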
Inputs Types
| Input | Types |
|---|---|
| explainer | Any |
| data | DataFrame |
| original data | DataFrame |
You can check the list of supported types here: Available Type Hints.
Outputs
- image
- The generated visualization of the impact distribution. Depending on your settings, this is returned as an image object, raw bytes, or a numerical array.
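When the output format is `bytes`, the brick's source code saves the figure as PNG, so the payload can be written directly to a `.png` file. The following is a minimal sketch of a downstream step under that assumption; `is_png` and `save_png` are hypothetical names, and the payload used here is a placeholder rather than a real image.

```python
import os
import tempfile

PNG_SIGNATURE = b"\x89PNG\r\n\x1a\n"  # the first 8 bytes of every PNG file

def is_png(payload: bytes) -> bool:
    """Return True if the payload starts with the PNG file signature."""
    return payload[:8] == PNG_SIGNATURE

def save_png(payload: bytes, path: str) -> int:
    """Write PNG bytes to disk and return the number of bytes written."""
    if not is_png(payload):
        raise ValueError("payload does not look like PNG data")
    with open(path, "wb") as handle:
        return handle.write(payload)

# Placeholder bytes standing in for the brick's output (not a complete image).
fake_payload = PNG_SIGNATURE + b"placeholder"

with tempfile.TemporaryDirectory() as folder:
    target = os.path.join(folder, "shap_impact_example.png")
    written = save_png(fake_payload, target)

print(written)
```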
Outputs Types
| Output | Types |
|---|---|
| image | MediaData, PILImage |
You can check the list of supported types here: Available Type Hints.
Options
The SHAP Impact Distribution brick has the following configurable options:
- Class Index (Multiclass)
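- Used only for models that predict multiple categories (e.g., predicting "Red," "Green," or "Blue"). This zero-based integer determines which category to analyze (default is 0). If the value is out of range, the brick logs a warning and falls back to 0.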
- Max Features to Display
- Controls how many features are shown on the y-axis of the plot. The brick automatically selects the most important features, up to this number (default is 10).
- Output Image Format
- Determines the technical format of the output image.
- pil: Returns a Python Imaging Library object. Best for further image manipulation within the workflow.
- bytes: Returns the raw image file data.
- array: Returns the image as a NumPy array of pixel values.
- Verbose
- When enabled, details about the analysis progress (e.g., detected class types, matrix shapes) will be printed to the logs.
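The Class Index behavior can be sketched with plain NumPy. The example assumes SHAP values are laid out as (samples, features) for single-output models and (samples, features, classes) for multiclass models, which matches how the brick's source slices the explanation; `select_class` is a hypothetical helper written for illustration.

```python
import numpy as np

def select_class(values, class_index=0):
    """Pick one class from SHAP values shaped (samples, features[, classes])."""
    values = np.asarray(values)
    if values.ndim == 2:
        return values  # single output: nothing to select
    if values.ndim != 3:
        raise ValueError(f"Expected 2 or 3 dimensions, got {values.ndim}.")
    num_classes = values.shape[2]
    if not 0 <= class_index < num_classes:
        class_index = 0  # mirror the brick's fallback for out-of-range values
    return values[:, :, class_index]

# Hypothetical values: 4 samples, 2 features, 3 classes.
multiclass = np.arange(24).reshape(4, 2, 3)

picked = select_class(multiclass, class_index=2)
fallback = select_class(multiclass, class_index=7)

print(picked.shape)
print(fallback[0])
```

The result is always a two-dimensional matrix, which is the shape the plot expects regardless of how many classes the model predicts.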
import logging
import io
import inspect
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image
from shap.explainers._explainer import Explainer
from coded_flows.types import DataFrame, Any, MediaData, Union, PILImage
from coded_flows.utils import CodedFlowsLogger
matplotlib.use("Agg")
import shap
logger = CodedFlowsLogger(name="SHAP Impact Distribution", level=logging.INFO)
logging.getLogger("shap").setLevel(logging.ERROR)
def _call_explainer(explainer, data, **base_kwargs):
    sig = inspect.signature(explainer.__call__)
    if "silent" in sig.parameters:
        return explainer(data, silent=True, **base_kwargs)
    return explainer(data, **base_kwargs)
def _to_numpy(df):
    """Convert pandas or polars DataFrame to numpy array."""
    if hasattr(df, "__class__") and "polars" in df.__class__.__module__:
        return df.to_numpy()
    elif hasattr(df, "values"):
        return df.values
    elif hasattr(df, "to_numpy"):
        return df.to_numpy()
    else:
        return np.array(df)
def _get_columns(df):
    """Get column names from pandas or polars DataFrame."""
    if hasattr(df, "columns"):
        columns = df.columns
        return list(columns) if not isinstance(columns, list) else columns
    return None
def shap_impact(
    explainer: Any,
    data: DataFrame,
    original_data: DataFrame = None,
    options: dict = None,
) -> Union[MediaData, PILImage]:
    options = options or {}
    verbose = options.get("verbose", True)
    class_index = options.get("class_index", 0)
    max_display = options.get("max_display", 10)
    image_format = options.get("image_format", "pil")
    image = None
    try:
        verbose and logger.info("Starting SHAP impact distribution analysis.")
        if not isinstance(explainer, Explainer):
            verbose and logger.error("Expects a Shap Explainer as an input.")
            raise ValueError("Expects a Shap Explainer as an input.")
        shap_explanation = _call_explainer(explainer, data)
        if original_data is not None:
            verbose and logger.info(
                "Original data provided. Overwriting explanation data for visualization."
            )
            columns = _get_columns(original_data)
            if columns is not None:
                shap_explanation.feature_names = columns
            shap_explanation.data = _to_numpy(original_data)
        ndim = len(shap_explanation.shape)
        explanation_to_plot = None
        if ndim == 2:
            verbose and logger.info(
                "Detected Regression or Binary Classification (single output)."
            )
            explanation_to_plot = shap_explanation
        elif ndim == 3:
            num_classes = shap_explanation.shape[2]
            verbose and logger.info(
                f"Detected Multiclass output with {num_classes} classes."
            )
            if class_index >= num_classes or class_index < 0:
                verbose and logger.warning(
                    f"Selected class_index {class_index} is out of bounds (0-{num_classes - 1}). Defaulting to 0."
                )
                class_index = 0
            verbose and logger.info(f"Selecting class index: {class_index}")
            explanation_to_plot = shap_explanation[:, :, class_index]
        else:
            raise ValueError(
                f"Unexpected SHAP explanation dimensionality: {ndim}. Expected 2 or 3."
            )
        verbose and logger.info(
            f"Generating Beeswarm Plot with max_display={max_display}."
        )
        fig = plt.figure()
        shap.plots.beeswarm(explanation_to_plot, max_display=max_display, show=False)
        buf = io.BytesIO()
        plt.savefig(buf, format="png", bbox_inches="tight", dpi=300)
        buf.seek(0)
        plt.close(plt.gcf())
        plt.close(fig)
        if image_format == "pil":
            image = Image.open(buf)
        elif image_format == "array":
            image = Image.open(buf)
            image = np.array(image)
        elif image_format == "bytes":
            image = buf.getvalue()
        verbose and logger.info("Image generation complete.")
    except Exception as e:
        verbose and logger.error(f"Error computing SHAP impact distribution: {str(e)}")
        raise e
    return image
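The `_call_explainer` helper in the source inspects the explainer's call signature and only passes `silent=True` when that parameter exists. The same technique can be tried in isolation with the sketch below; the two stand-in functions are hypothetical and only mimic explainers that do or do not accept a `silent` argument.

```python
import inspect

def call_with_optional_silent(func, data, **kwargs):
    """Pass silent=True only if the callable declares a `silent` parameter."""
    params = inspect.signature(func).parameters
    if "silent" in params:
        return func(data, silent=True, **kwargs)
    return func(data, **kwargs)

# Hypothetical stand-ins for two explainer flavours.
def quiet_explainer(data, silent=False):
    return ("quiet", silent, len(data))

def plain_explainer(data):
    return ("plain", len(data))

quiet_result = call_with_optional_silent(quiet_explainer, [1, 2, 3])
plain_result = call_with_optional_silent(plain_explainer, [1, 2, 3])

print(quiet_result)
print(plain_result)
```

Checking the signature first avoids a `TypeError` on callables that would reject an unexpected keyword argument.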
Brick Info
- shap>=0.47.0
- matplotlib
- numpy
- pandas
- polars[pyarrow]
- pillow
- numba>=0.56.0