Skip to content

SpRAy API

SpRAy performs spectral clustering of explanation heatmaps to reveal systematic attention artifacts (Clever Hans behavior).

Class Reference

SpRAyDetector

SpRAyDetector(
    *,
    n_clusters: int | None = None,
    cluster_selection: str = "auto",
    affinity: str = "cosine",
    nearest_neighbors: int = 10,
    rbf_gamma: float | None = None,
    min_clusters: int = 2,
    max_clusters: int = 10,
    downsample_size: int | tuple[int, int] | None = 32,
    random_state: int = 42,
    small_cluster_threshold: float = 0.05,
    purity_threshold: float = 0.8,
    focus_threshold: float = 0.6,
    separation_threshold: float = 0.3,
    alignment_threshold: float = 0.6,
    heatmap_generator: HeatmapGenerator | None = None
)

Bases: DetectorBase

Spectral clustering of explanation heatmaps for Clever Hans detection.

Source code in shortcut_detect/xai/spray_detector.py
def __init__(
    self,
    *,
    n_clusters: int | None = None,
    cluster_selection: str = "auto",
    affinity: str = "cosine",
    nearest_neighbors: int = 10,
    rbf_gamma: float | None = None,
    min_clusters: int = 2,
    max_clusters: int = 10,
    downsample_size: int | tuple[int, int] | None = 32,
    random_state: int = 42,
    small_cluster_threshold: float = 0.05,
    purity_threshold: float = 0.8,
    focus_threshold: float = 0.6,
    separation_threshold: float = 0.3,
    alignment_threshold: float = 0.6,
    heatmap_generator: HeatmapGenerator | None = None,
) -> None:
    """Record the clustering configuration and reset the fitted state.

    Every argument is keyword-only and is stored unchanged on the instance
    under the same name. This method itself performs no validation.

    Parameters
    ----------
    n_clusters:
        Requested number of clusters, or ``None``.
        NOTE(review): presumably only honoured for some ``cluster_selection``
        modes -- confirm against ``_resolve_cluster_count``.
    cluster_selection:
        Cluster-count strategy. ``fit`` accepts ``"fixed"``, ``"eigengap"``
        or ``"auto"`` and raises ``ValueError`` for anything else.
    affinity:
        Affinity type. ``fit`` precomputes an affinity matrix when this is
        ``"cosine"`` or ``"rbf"``; handling of other values is not visible
        in this method.
    nearest_neighbors, rbf_gamma:
        NOTE(review): presumably affinity-specific settings (neighbour count
        and RBF kernel width) -- confirm where they are consumed.
    min_clusters, max_clusters:
        NOTE(review): presumably bounds for automatic cluster-count
        selection -- confirm against ``_resolve_cluster_count``.
    downsample_size:
        Stored for heatmap preprocessing and echoed in the result metadata
        by ``fit``. NOTE(review): presumably the target size heatmaps are
        resized to, with ``None`` disabling resizing -- confirm.
    random_state:
        NOTE(review): presumably the seed handed to the clustering backend.
    small_cluster_threshold, purity_threshold, focus_threshold,
    separation_threshold, alignment_threshold:
        NOTE(review): presumably decision thresholds for the Clever Hans
        evaluation -- confirm against ``_evaluate_clever_hans``.
    heatmap_generator:
        Optional default heatmap generator stored on the instance.
    """
    # Register with the base detector under the "spray" method name.
    super().__init__(method="spray")
    # Hyperparameters: stored verbatim, no validation at construction time.
    self.n_clusters = n_clusters
    self.cluster_selection = cluster_selection
    self.affinity = affinity
    self.nearest_neighbors = nearest_neighbors
    self.rbf_gamma = rbf_gamma
    self.min_clusters = min_clusters
    self.max_clusters = max_clusters
    self.downsample_size = downsample_size
    self.random_state = random_state
    self.small_cluster_threshold = small_cluster_threshold
    self.purity_threshold = purity_threshold
    self.focus_threshold = focus_threshold
    self.separation_threshold = separation_threshold
    self.alignment_threshold = alignment_threshold
    self.heatmap_generator = heatmap_generator

    # Fitted state: empty until fit() populates it.
    self.heatmaps_: np.ndarray | None = None
    self.cluster_labels_: np.ndarray | None = None
    self.cluster_summaries_: Sequence[SpRAyClusterSummary] = []
    self.representative_heatmaps_: dict[int, np.ndarray] = {}

Functions

fit

fit(
    heatmaps: ndarray | None = None,
    *,
    labels: ndarray | None = None,
    group_labels: ndarray | None = None,
    inputs: Any | None = None,
    heatmap_generator: HeatmapGenerator | None = None,
    model: Any | None = None,
    target_layer: str | Any | None = None,
    head: str | int = "logits",
    target_index: int | None = None,
    batch_size: int = 16
) -> SpRAyDetector

Fit SpRAy detector on heatmaps or raw inputs.

Parameters

- heatmaps: Precomputed heatmaps shaped (N, H, W). If provided, inputs and model are ignored.
- labels: Optional task labels for cluster purity analysis.
- group_labels: Optional protected attribute labels for cluster purity analysis.
- inputs: Raw inputs to generate heatmaps from if heatmaps are not provided.
- heatmap_generator: Optional callable or GradCAMHeatmapGenerator used to produce heatmaps from inputs.
- model / target_layer: If provided (and heatmaps is None), a GradCAMHeatmapGenerator will be created internally.
- head / target_index: Parameters forwarded to GradCAM when generating heatmaps.
- batch_size: Batch size for heatmap generation when inputs are provided.

Source code in shortcut_detect/xai/spray_detector.py
def fit(
    self,
    heatmaps: np.ndarray | None = None,
    *,
    labels: np.ndarray | None = None,
    group_labels: np.ndarray | None = None,
    inputs: Any | None = None,
    heatmap_generator: HeatmapGenerator | None = None,
    model: Any | None = None,
    target_layer: str | Any | None = None,
    head: str | int = "logits",
    target_index: int | None = None,
    batch_size: int = 16,
) -> SpRAyDetector:
    """Fit SpRAy detector on heatmaps or raw inputs.

    Parameters
    ----------
    heatmaps:
        Precomputed heatmaps shaped (N,H,W). If provided, inputs and model are ignored.
    labels:
        Optional task labels for cluster purity analysis.
    group_labels:
        Optional protected attribute labels for cluster purity analysis.
    inputs:
        Raw inputs to generate heatmaps from if heatmaps are not provided.
    heatmap_generator:
        Optional callable or GradCAMHeatmapGenerator used to produce heatmaps from inputs.
    model / target_layer:
        If provided (and heatmaps is None), a GradCAMHeatmapGenerator will be created internally.
    head / target_index:
        Parameters forwarded to GradCAM when generating heatmaps.
    batch_size:
        Batch size for heatmap generation when inputs are provided.

    Returns
    -------
    SpRAyDetector
        This detector, with ``heatmaps_``, ``cluster_labels_``,
        ``cluster_summaries_``, ``representative_heatmaps_`` and
        ``shortcut_detected_`` populated and the results recorded through
        ``_set_results``.

    Raises
    ------
    ValueError
        If ``cluster_selection`` is not ``"fixed"``, ``"eigengap"`` or
        ``"auto"``, or if ``labels`` / ``group_labels`` do not have one
        entry per heatmap.
    """

    # Fail fast: reject a bad configuration before paying for heatmap
    # generation, preprocessing and the affinity matrix.
    if self.cluster_selection not in {"fixed", "eigengap", "auto"}:
        raise ValueError("cluster_selection must be 'fixed', 'eigengap', or 'auto'")

    if heatmaps is None:
        heatmaps = self._generate_heatmaps(
            inputs=inputs,
            heatmap_generator=heatmap_generator,
            model=model,
            target_layer=target_layer,
            head=head,
            target_index=target_index,
            batch_size=batch_size,
        )

    heatmaps = self._validate_heatmaps(heatmaps)
    self.heatmaps_ = heatmaps
    n_samples = heatmaps.shape[0]

    labels_arr = np.asarray(labels) if labels is not None else None
    if labels_arr is not None and len(labels_arr) != n_samples:
        raise ValueError("labels length must match number of heatmaps")

    group_arr = np.asarray(group_labels) if group_labels is not None else None
    if group_arr is not None and len(group_arr) != n_samples:
        raise ValueError("group_labels length must match number of heatmaps")

    # One flat feature vector per heatmap for clustering.
    processed = self._preprocess_heatmaps(heatmaps)
    features = processed.reshape(n_samples, -1)

    # Only these two affinities are precomputed here; for any other value
    # the helpers below receive affinity_matrix=None.
    affinity_matrix = None
    if self.affinity in {"cosine", "rbf"}:
        affinity_matrix = self._compute_affinity(features)

    n_clusters = self._resolve_cluster_count(
        features=features,
        affinity_matrix=affinity_matrix,
    )

    cluster_labels = self._cluster(features, affinity_matrix, n_clusters)

    summaries = self._summarize_clusters(
        cluster_labels,
        labels_arr,
        group_arr,
        processed,
    )
    representatives = self._compute_representative_heatmaps(processed, cluster_labels)

    clever_hans = self._evaluate_clever_hans(
        summaries=summaries,
        cluster_labels=cluster_labels,
        labels=labels_arr,
        features=features,
    )

    silhouette = self._compute_silhouette(features, cluster_labels)
    # Purity is only defined for summaries that carry one; the aggregates
    # stay None when there are none.
    purity_values = [s.label_purity for s in summaries if s.label_purity is not None]
    mean_purity = float(np.mean(purity_values)) if purity_values else None
    max_purity = float(np.max(purity_values)) if purity_values else None
    focus_scores = [s.focus_mean for s in summaries]
    mean_focus = float(np.mean(focus_scores)) if focus_scores else None

    self.cluster_labels_ = cluster_labels
    self.cluster_summaries_ = summaries
    self.representative_heatmaps_ = representatives
    self.shortcut_detected_ = clever_hans["shortcut_detected"]

    metrics = {
        "n_clusters": int(n_clusters),
        "silhouette": silhouette,
        "mean_label_purity": mean_purity,
        "max_label_purity": max_purity,
        "mean_focus": mean_focus,
    }
    metadata = {
        "n_samples": int(n_samples),
        "heatmap_shape": heatmaps.shape[1:],
        "affinity": self.affinity,
        "cluster_selection": self.cluster_selection,
        "downsample_size": self.downsample_size,
    }
    report = {
        # Shallow copies: editing the report must not silently rewrite the
        # summary objects kept in cluster_summaries_.
        "clusters": [dict(vars(summary)) for summary in summaries],
        "clever_hans": clever_hans,
    }
    details = {
        "cluster_labels": cluster_labels,
        "representative_heatmaps": representatives,
    }

    risk_level = clever_hans.get("risk_level", "unknown")
    self._set_results(
        shortcut_detected=self.shortcut_detected_,
        risk_level=risk_level,
        metrics=metrics,
        notes="SpRAy spectral clustering on explanation heatmaps.",
        metadata=metadata,
        report=report,
        details=details,
    )

    self._is_fitted = True
    return self

Usage Example

import numpy as np
from shortcut_detect import SpRAyDetector

# Precomputed explanation heatmaps; fit expects an array shaped (N, H, W).
heatmaps = np.load("heatmaps.npy")

# fit returns the detector itself, so construction and fitting can be chained.
detector = SpRAyDetector(affinity="cosine", cluster_selection="auto").fit(heatmaps)

report = detector.get_report()
clever_hans = report["report"]["clever_hans"]
print(clever_hans)