SHAP Dependence
Visualizes the relationship between a feature's value and its SHAP value. Supports coloring by a second feature to reveal interactions.
Processing
This brick generates a "Dependence Plot" to visualize how a specific feature affects your model's predictions. It creates a scatter plot where the X-axis represents the actual value of a feature (e.g., "Age") and the Y-axis represents that feature's impact on the prediction (its SHAP value).
It effectively answers the question: "How does changing this specific variable change the result?" Additionally, it can color the dots based on a second feature to reveal interactions—situations where the impact of one feature depends on the value of another (e.g., "Age" might matter more if "Income" is high).
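To make the idea concrete, here is a minimal sketch of what such a plot looks like, built with plain matplotlib and synthetic SHAP values (this is not the brick itself; the "Age"/"Income" features and all numbers are made up for illustration). Note how coloring by the second feature exposes the interaction: the slope of the Age effect depends on Income.

```python
import io
import numpy as np
import matplotlib
matplotlib.use("Agg")  # headless backend, as the brick itself uses
import matplotlib.pyplot as plt

rng = np.random.default_rng(0)
age = rng.uniform(18, 70, 300)       # feature shown on the X-axis
income = rng.uniform(20, 120, 300)   # second feature used for coloring
# Synthetic SHAP values: the impact of "Age" grows when "Income" is high,
# which is exactly the kind of interaction the coloring is meant to reveal.
shap_age = 0.02 * (age - 40) * (income / 70) + rng.normal(0, 0.05, 300)

fig, ax = plt.subplots()
sc = ax.scatter(age, shap_age, c=income, cmap="coolwarm", s=12)
ax.set_xlabel("Age")
ax.set_ylabel("SHAP value for Age")
fig.colorbar(sc, label="Income")

buf = io.BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight")
plt.close(fig)
png_bytes = buf.getvalue()
```

The real brick produces the same kind of chart via `shap.plots.scatter`, with the SHAP values computed by the explainer rather than simulated.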
Inputs
- explainer
- The SHAP Explainer object created by a previous brick (e.g., "Calculate SHAP Values"). This contains the logic needed to interpret your model.
- data
- The dataset containing the rows you want to analyze. The explainer uses this to calculate the impact values.
- original data (optional)
- If your primary "data" input was normalized, scaled, or encoded (making it hard to read), provide the original, human-readable dataset here. The brick will use "data" for calculation but "original data" for the X-axis labels, making the chart easier to understand.
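The relabeling idea can be sketched in a few lines. `FakeExplanation` below is a hypothetical stand-in for a SHAP explanation object, holding SHAP values plus the data and column names used for display; the brick performs the same swap on the real object.

```python
import numpy as np
import pandas as pd

class FakeExplanation:
    # Hypothetical stand-in for a SHAP explanation object: it carries the
    # computed values plus the data/column names used for plot labels.
    def __init__(self, values, data, feature_names):
        self.values = values
        self.data = data
        self.feature_names = feature_names

# "data": scaled features the model actually saw (hard to read).
scaled = pd.DataFrame({"x0": [0.1, 0.9], "x1": [-1.2, 0.4]})
# "original data": the same rows before preprocessing (human-readable).
original = pd.DataFrame({"Age": [25, 61], "Income": [38.0, 92.5]})

exp = FakeExplanation(values=np.zeros((2, 2)), data=scaled.to_numpy(),
                      feature_names=list(scaled.columns))
# Keep the SHAP values computed on `data`, but swap in the readable rows
# and column names so the axes show "Age"/"Income" instead of "x0"/"x1":
exp.feature_names = list(original.columns)
exp.data = original.to_numpy()
```

The SHAP values themselves are untouched; only the labels and axis values change.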
Inputs Types
| Input | Types |
|---|---|
| explainer | Any |
| data | DataFrame |
| original data | DataFrame |
You can check the list of supported types here: Available Type Hints.
Outputs
- image
- The generated dependence plot. The format of this output (an image object, raw bytes, or a data array) is determined by the "Output Image Format" option.
Outputs Types
| Output | Types |
|---|---|
| image | MediaData, PILImage |
You can check the list of supported types here: Available Type Hints.
Options
The SHAP Dependence brick provides the following configurable options:
- Feature to Plot
- The exact name of the feature (column) you want to visualize on the X-axis. This is the primary variable you are analyzing.
- Interaction Feature
- Determines how the dots in the scatter plot are colored. If left empty, the algorithm automatically finds the feature that interacts most strongly with the "Feature to Plot" and uses it for coloring. You can instead type the name of a specific feature to force the plot to color by that variable.
- Class Index (Multiclass)
- Used only for classification models that predict multiple categories (e.g., Red, Green, Blue). It selects which class you want to explain.
- Output Image Format
- Defines the technical format of the output image.
- pil: Returns a Pillow (Python Imaging Library) image object (e.g., for saving or further processing).
- bytes: Returns the raw image file data.
- array: Returns the image as a NumPy matrix of numbers.
- Verbose
- When enabled, the brick logs detailed progress information to the console, which is helpful for troubleshooting.
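The three output formats all come from the same rendered figure. The sketch below mirrors the brick's conversion step (the `as_format` helper and its format keys are illustrative, matching the option values above):

```python
import io
import numpy as np
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from PIL import Image

# Render any figure into an in-memory PNG buffer.
fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1])
buf = io.BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight", dpi=100)
plt.close(fig)
buf.seek(0)

def as_format(buf, image_format):
    # Illustrative helper mirroring the brick's "Output Image Format" option.
    if image_format == "pil":
        return Image.open(buf)            # Pillow image object
    if image_format == "array":
        return np.array(Image.open(buf))  # H x W x channels uint8 matrix
    if image_format == "bytes":
        return buf.getvalue()             # raw PNG file bytes
    raise ValueError(f"Unknown format: {image_format}")

arr = as_format(buf, "array")
```

Pick "bytes" when handing the image to storage or HTTP, "array" for further numeric processing, and "pil" for everything else.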
import logging
import io
import inspect
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from PIL import Image
from shap.explainers._explainer import Explainer
from coded_flows.types import DataFrame, Any, MediaData, Union, PILImage
from coded_flows.utils import CodedFlowsLogger
matplotlib.use("Agg")
import shap
logger = CodedFlowsLogger(name="SHAP Dependence", level=logging.INFO)
logging.getLogger("shap").setLevel(logging.ERROR)
def _call_explainer(explainer, data, **base_kwargs):
sig = inspect.signature(explainer.__call__)
if "silent" in sig.parameters:
return explainer(data, silent=True, **base_kwargs)
return explainer(data, **base_kwargs)
def _to_numpy(df):
"""Convert pandas or polars DataFrame to numpy array."""
if hasattr(df, "__class__") and "polars" in df.__class__.__module__:
return df.to_numpy()
elif hasattr(df, "values"):
return df.values
elif hasattr(df, "to_numpy"):
return df.to_numpy()
else:
return np.array(df)
def _get_columns(df):
"""Get column names from pandas or polars DataFrame."""
if hasattr(df, "columns"):
columns = df.columns
return list(columns) if not isinstance(columns, list) else columns
return None
def shap_dependence(
explainer: Any,
data: DataFrame,
original_data: DataFrame = None,
options: dict = None,
) -> Union[MediaData, PILImage]:
options = options or {}
verbose = options.get("verbose", True)
class_index = options.get("class_index", 0)
feature_name = options.get("feature_name", "Feature 0")
interaction_feature = options.get("interaction_feature", "Auto")
image_format = options.get("image_format", "pil")
image = None
try:
        verbose and logger.info("Starting SHAP Dependence plot generation.")
if not isinstance(explainer, Explainer):
verbose and logger.error("Expects a Shap Explainer as an input.")
raise ValueError("Expects a Shap Explainer as an input.")
shap_explanation = _call_explainer(explainer, data)
if original_data is not None:
verbose and logger.info(
"Original data provided. Overwriting explanation data."
)
columns = _get_columns(original_data)
if columns is not None:
shap_explanation.feature_names = columns
shap_explanation.data = _to_numpy(original_data)
ndim = len(shap_explanation.shape)
explanation_to_process = None
if ndim == 2:
verbose and logger.info("Detected Regression or Binary Classification.")
explanation_to_process = shap_explanation
elif ndim == 3:
num_classes = shap_explanation.shape[2]
verbose and logger.info(
f"Detected Multiclass output with {num_classes} classes."
)
if class_index >= num_classes or class_index < 0:
verbose and logger.warning(
f"Class index {class_index} out of bounds. Defaulting to 0."
)
class_index = 0
verbose and logger.info(f"Selecting class index: {class_index}")
explanation_to_process = shap_explanation[:, :, class_index]
else:
raise ValueError(f"Unexpected SHAP explanation dimensionality: {ndim}")
available_features = explanation_to_process.feature_names
if not available_features:
available_features = [
f"Feature {i}" for i in range(explanation_to_process.shape[1])
]
explanation_to_process.feature_names = available_features
if feature_name not in available_features:
try:
f_idx = int(feature_name.replace("Feature ", ""))
if 0 <= f_idx < len(available_features):
feature_name = available_features[f_idx]
else:
raise ValueError
            except (ValueError, AttributeError):
verbose and logger.warning(
f"Feature '{feature_name}' not found. Defaulting to first feature."
)
feature_name = available_features[0]
verbose and logger.info(f"Plotting scatter for feature: '{feature_name}'")
fig = plt.figure()
color_arg = None
        if interaction_feature is None or interaction_feature == "Auto":
color_arg = explanation_to_process
verbose and logger.info(
"Interaction coloring: Auto (Strongest interaction)."
)
elif interaction_feature in available_features:
color_arg = explanation_to_process[:, interaction_feature]
verbose and logger.info(f"Interaction coloring: '{interaction_feature}'.")
else:
verbose and logger.info(
"Interaction coloring: None or Invalid (Defaulting to Auto)."
)
color_arg = explanation_to_process
shap.plots.scatter(
explanation_to_process[:, feature_name], color=color_arg, show=False
)
buf = io.BytesIO()
plt.savefig(buf, format="png", bbox_inches="tight", dpi=300)
buf.seek(0)
plt.close(plt.gcf())
plt.close(fig)
if image_format == "pil":
image = Image.open(buf)
elif image_format == "array":
image = Image.open(buf)
image = np.array(image)
elif image_format == "bytes":
image = buf.getvalue()
verbose and logger.info("Image generation complete.")
except Exception as e:
verbose and logger.error(f"Error computing SHAP scatter: {str(e)}")
raise e
return image
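Two details in the listing are worth seeing in isolation: how a multiclass (3-D) explanation is reduced to one class, and how a missing feature name falls back to "Feature N" index parsing. The sketch below reproduces that logic with plain NumPy and illustrative names ("Age", "Income", "Tenure" are made up):

```python
import numpy as np

# A multiclass explanation has shape (samples, features, classes).
values = np.arange(2 * 3 * 4).reshape(2, 3, 4)
class_index = 5
num_classes = values.shape[2]
# Out-of-bounds class indices fall back to 0, as in the brick:
if class_index >= num_classes or class_index < 0:
    class_index = 0
per_class = values[:, :, class_index]  # shape (2, 3): SHAP matrix for one class

# Feature-name fallback: "Feature 2" resolves to the third column name.
available = ["Age", "Income", "Tenure"]
feature_name = "Feature 2"
if feature_name not in available:
    try:
        idx = int(feature_name.replace("Feature ", ""))
        feature_name = available[idx] if 0 <= idx < len(available) else available[0]
    except ValueError:
        feature_name = available[0]
```

Slicing out one class is what lets the same scatter-plot code serve regression, binary, and multiclass models alike.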
Brick Info
- shap>=0.47.0
- matplotlib
- numpy
- pandas
- pillow
- numba>=0.56.0