Correlation Matrix

Computes the correlation between features and generates a heatmap. Returns both the correlation table and the heatmap.

Correlation Matrix

Processing

Computes the correlation matrix between the features of the input dataset using a selected statistical method. It supports the standard Pearson, Spearman, and Kendall correlation coefficients, as well as the association measures Phi K (φK) and Cramér's V for categorical and mixed data types.

Inputs

data
The input dataset (DataFrame or Arrow Table) containing the features for which correlation must be computed.

Inputs Types

Input Types
data DataFrame, ArrowTable

You can check the list of supported types here: Available Type Hints.

Outputs

image
The generated heatmap visualization of the correlation matrix.
matrix
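A DataFrame of the pairwise correlation (or association) coefficients between features. The feature names are moved from the index into a leading column (typically named `index`), followed by one column per feature.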

Outputs Types

Output Types
image MediaData, PILImage
matrix DataFrame

You can check the list of supported types here: Available Type Hints.

Options

The Correlation Matrix brick has the following configurable options:

Correlation Method
Selects the statistical method used to compute correlation. Choices are pearson, spearman, kendall (typically for continuous data; only numeric columns are used), phi_k (for mixed data types), or cramers (for association between categorical variables; every column, including numeric ones, is treated as categorical). (Default: pearson)
Verbose Output
If enabled, detailed log messages about the data conversion and computation steps are displayed. (Default: True)
import io
import logging
import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from coded_flows.types import DataFrame, ArrowTable, Union, MediaData, PILImage, Tuple
from coded_flows.utils import CodedFlowsLogger

# Use matplotlib's non-interactive Agg backend so figures can be rendered
# to in-memory PNG buffers without a display.
matplotlib.use("Agg")
# Shared logger for the progress and error messages emitted by this brick.
logger = CodedFlowsLogger(name="Correlation Matrix", level=logging.INFO)


_IMAGE_FORMATS = ("pil", "array", "bytes")


def _info(verbose: bool, message: str) -> None:
    """Log ``message`` at INFO level only when ``verbose`` is truthy."""
    if verbose:
        logger.info(message)


def _to_pandas(data, verbose: bool) -> pd.DataFrame:
    """Return ``data`` as a pandas DataFrame.

    Polars DataFrames and Arrow tables are converted with their own
    ``to_pandas()``; pandas DataFrames are returned unchanged.

    Raises:
        ValueError: If ``data`` is not one of the supported types.
    """
    if isinstance(data, pl.DataFrame):
        _info(verbose, "Converting polars DataFrame to pandas")
        return data.to_pandas()
    if isinstance(data, pa.Table):
        _info(verbose, "Converting Arrow table to pandas")
        return data.to_pandas()
    if isinstance(data, pd.DataFrame):
        _info(verbose, "Input is already pandas DataFrame")
        return data
    error_msg = f"Unsupported data type: {type(data).__name__}"
    if verbose:
        logger.error(error_msg)
    raise ValueError(error_msg)


def _cramers_v_matrix(data: pd.DataFrame) -> pd.DataFrame:
    """Return the symmetric matrix of pairwise Cramér's V between all columns.

    Every column is treated as categorical. Rows where either column of a
    pair is missing are dropped by ``pd.crosstab``. Pairs with no jointly
    observed rows, or where either side has a single observed category,
    get 0.0. The diagonal is fixed at 1.0.

    Raises:
        RuntimeError: If scipy is not installed.
    """
    try:
        from scipy.stats import chi2_contingency
    except ImportError as ie:
        raise RuntimeError("Install 'scipy' to use cramers method.") from ie

    columns = data.columns.tolist()
    size = len(columns)
    values = np.eye(size)
    for i in range(size):
        for j in range(i + 1, size):
            contingency = pd.crosstab(data[columns[i]], data[columns[j]])
            n_obs = contingency.to_numpy().sum()
            min_dim = min(contingency.shape) - 1
            if n_obs == 0 or min_dim < 1:
                # Empty table (chi2_contingency raises on it) or a single
                # category on one side: no measurable association.
                cramers_v = 0.0
            else:
                # Cramér's V is defined on the uncorrected Pearson chi-squared
                # statistic. scipy applies Yates' continuity correction to 2x2
                # tables by default, which biases V downward (two identical
                # 4-row binary columns would score 0.5 instead of 1.0).
                chi2 = chi2_contingency(contingency, correction=False)[0]
                cramers_v = float(np.sqrt(chi2 / (n_obs * min_dim)))
            values[i, j] = values[j, i] = cramers_v
    return pd.DataFrame(values, index=columns, columns=columns)


def _compute_matrix(
    data: pd.DataFrame, correlation_type: str, verbose: bool
) -> pd.DataFrame:
    """Dispatch to the requested correlation / association method.

    Raises:
        ValueError: If ``correlation_type`` is not a supported method.
        RuntimeError: If the optional dependency for the method is missing.
    """
    if correlation_type == "pearson":
        _info(verbose, "Computing Pearson correlation")
        return data.corr(method="pearson", numeric_only=True)
    if correlation_type == "spearman":
        _info(verbose, "Computing Spearman rank correlation")
        return data.corr(method="spearman", numeric_only=True)
    if correlation_type == "kendall":
        _info(verbose, "Computing Kendall rank correlation")
        return data.corr(method="kendall", numeric_only=True)
    if correlation_type == "phi_k":
        _info(verbose, "Computing Phi K correlation")
        try:
            # Imported for its side effect of providing DataFrame.phik_matrix.
            import phik  # noqa: F401
        except ImportError as ie:
            raise RuntimeError("Install 'phik' to use this method.") from ie
        return data.phik_matrix()
    if correlation_type == "cramers":
        _info(verbose, "Computing Cramers V association")
        return _cramers_v_matrix(data)
    raise ValueError(f"Invalid correlation type: '{correlation_type}'")


def _render_heatmap(
    matrix: pd.DataFrame, correlation_type: str, mask_upper: bool, image_format: str
):
    """Draw ``matrix`` as an annotated heatmap and return it as an image.

    Returns a PIL image for ``"pil"``, a NumPy array for ``"array"`` and the
    raw PNG bytes for ``"bytes"``.

    Raises:
        ValueError: If ``image_format`` is not one of the above.
    """
    corr = matrix.round(2)
    # np.triu includes the diagonal, so a masked heatmap shows only the
    # strictly lower triangle.
    mask = np.triu(np.ones_like(corr, dtype=bool)) if mask_upper else None
    # Scale the figure with the number of features, clamped to 10-25 inches.
    fig_size = max(10, min(25, len(corr.columns) * 0.8))
    fig, ax = plt.subplots(figsize=(fig_size, fig_size))
    buf = io.BytesIO()
    try:
        sns.heatmap(
            corr,
            mask=mask,
            cmap=sns.diverging_palette(220, 10, as_cmap=True),
            vmin=-1,
            vmax=1,
            center=0,
            square=True,
            linewidths=0.5,
            cbar_kws={"shrink": 0.5},
            annot=True,
            fmt=".2f",
            ax=ax,
        )
        # Operate on this figure/axes explicitly rather than pyplot's implicit
        # "current" figure.
        ax.set_title(f"Correlation Matrix ({correlation_type})", fontsize=16)
        fig.tight_layout()
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=300)
    finally:
        # Close even when drawing fails; otherwise pyplot keeps the figure
        # registered and its memory is never released.
        plt.close(fig)
    buf.seek(0)
    if image_format == "pil":
        return Image.open(buf)
    if image_format == "array":
        return np.array(Image.open(buf))
    if image_format == "bytes":
        return buf.getvalue()
    raise ValueError(f"Invalid image format: '{image_format}'")


def compute_correlation_matrix(
    data: Union[DataFrame, ArrowTable], options=None
) -> Tuple[Union[MediaData, PILImage], DataFrame]:
    """Compute a feature correlation matrix and render it as a heatmap.

    Args:
        data: A pandas DataFrame, polars DataFrame or pyarrow Table.
        options: Optional dict with keys:
            - ``correlation_type``: ``"pearson"`` (default), ``"spearman"``,
              ``"kendall"`` (numeric columns only), ``"phi_k"`` or
              ``"cramers"`` (all columns treated as categorical).
            - ``image_format``: ``"pil"`` (default), ``"array"`` or ``"bytes"``.
            - ``mask_upper``: hide the upper triangle and the diagonal
              (default False).
            - ``verbose``: log progress messages (default True).

    Returns:
        ``(image, matrix)``: ``image`` is the heatmap in the requested format
        and ``matrix`` is the coefficient table with the feature names moved
        from the index into a leading column.

    Raises:
        ValueError: Unsupported input type, empty input, or an invalid
            ``correlation_type`` / ``image_format``.
        RuntimeError: A required optional dependency (phik, scipy) is
            missing, or the computation produced an empty matrix.
    """
    options = options or {}
    verbose = options.get("verbose", True)
    correlation_type = options.get("correlation_type", "pearson")
    image_format = options.get("image_format", "pil")
    mask_upper = options.get("mask_upper", False)
    _info(verbose, f"Starting process. Method: '{correlation_type}'")
    try:
        # Validate up front so an unsupported format fails before any
        # expensive computation or rendering.
        if image_format not in _IMAGE_FORMATS:
            raise ValueError(
                f"Invalid image format: '{image_format}'. "
                f"Expected one of: {', '.join(_IMAGE_FORMATS)}"
            )
        data = _to_pandas(data, verbose)
        if data.empty:
            raise ValueError("Input DataFrame is empty")
        _info(
            verbose,
            f"Processing DataFrame with {data.shape[0]:,} rows × "
            f"{data.shape[1]:,} columns",
        )
        matrix = _compute_matrix(data, correlation_type, verbose)
        if matrix is None or matrix.empty:
            raise RuntimeError("Correlation computation returned empty result")
        _info(verbose, "Generating heatmap visualization.")
        image = _render_heatmap(matrix, correlation_type, mask_upper, image_format)
        _info(verbose, f"Image generated successfully ({image_format}).")
        # Move the feature names from the index into a regular column.
        matrix = matrix.reset_index()
        _info(verbose, "Formatted DataFrame for output.")
    except Exception as e:
        if verbose:
            logger.error(f"Computation failed: {e}")
        raise
    return image, matrix

Brick Info

version v0.1.7
python 3.10, 3.11, 3.12, 3.13
requirements
  • pandas
  • pyarrow
  • polars[pyarrow]
  • numpy
  • matplotlib
  • seaborn
  • pillow
  • scipy
  • phik