Skip to content

GradCAM API

The GradCAM module provides visual shortcut detection for image models.

Class Reference

GradCAMHeatmapGenerator

GradCAMHeatmapGenerator(
    model: Module,
    target_layer: str | Module,
    head_mappings: (
        dict[str, int | Callable[[Any], Tensor]] | None
    ) = None,
    device: device | None = None,
)

Compute GradCAM heatmaps for multiple prediction heads.

Parameters

model: ``torch.nn.Module`` that produces the predictions. The module is always used in eval mode and gradients are enabled only for the current forward pass.

target_layer: Layer whose activations will be used for GradCAM. Can be the actual ``nn.Module`` instance or the dotted path to the module.

head_mappings: Optional mapping from head names to either an index (for tuple outputs) or a callable ``fn(output) -> Tensor``. Defaults to a mapping that treats ``("disease", "attribute")`` as the first two entries of a tuple/list output.

device: Device to run inference on. Defaults to the model's first parameter device or CPU.

Source code in shortcut_detect/gradcam.py
def __init__(
    self,
    model: torch.nn.Module,
    target_layer: str | torch.nn.Module,
    head_mappings: dict[str, int | Callable[[Any], torch.Tensor]] | None = None,
    device: torch.device | None = None,
) -> None:
    """Set up the GradCAM generator around *model* and *target_layer*.

    Parameters
    ----------
    model:
        Module producing the predictions; switched to eval mode here.
    target_layer:
        Layer whose activations feed GradCAM — an ``nn.Module`` instance
        or a dotted path resolved by ``_resolve_target_layer``.
    head_mappings:
        Optional mapping from head name to an output-tuple index or a
        callable ``fn(output) -> Tensor``.  Entries are merged on top of
        the defaults below.
    device:
        Device for inference; inferred from the model when omitted.
    """

    self.model = model.eval()
    self.device = device or self._infer_device()
    self.model.to(self.device)

    self._target_layer = self._resolve_target_layer(target_layer)
    # Hook state: activations/gradients captured during the forward/backward
    # pass, plus the hook handles so they can be detached later.
    self._activations: torch.Tensor | None = None
    self._gradients: torch.Tensor | None = None
    self._handles: Sequence[torch.utils.hooks.RemovableHandle] = []

    # Documented default: treat ("disease", "attribute") as the first two
    # entries of a tuple/list output.  The previous default of
    # {"logits": None} violated the declared value type (int | Callable)
    # and broke generate_attention_overlap(), whose default head names are
    # "disease"/"attribute".
    default_mappings: dict[str, int | Callable[[Any], torch.Tensor]] = {
        "disease": 0,
        "attribute": 1,
    }
    if head_mappings:
        default_mappings.update(head_mappings)
    self.head_mappings = default_mappings

Functions

generate_heatmap

generate_heatmap(
    inputs: TensorOrArray,
    head: str | int,
    target_index: int | None = None,
) -> np.ndarray

Generate GradCAM heatmap for a single prediction head.

Parameters

inputs: Input image tensor/array shaped ``(B,C,H,W)`` or ``(C,H,W)``.

head: Identifier for the prediction head. Can be a string key, integer index, or alias defined in ``head_mappings``.

target_index: Target class index. If ``None`` the argmax of the head's logits for each sample is used.

Source code in shortcut_detect/gradcam.py
def generate_heatmap(
    self,
    inputs: TensorOrArray,
    head: str | int,
    target_index: int | None = None,
) -> np.ndarray:
    """Compute a GradCAM heatmap for one prediction head.

    Parameters
    ----------
    inputs:
        Image tensor/array of shape ``(B,C,H,W)`` or ``(C,H,W)``.
    head:
        Prediction-head identifier — a string key, an integer index,
        or an alias registered in ``head_mappings``.
    target_index:
        Class index to explain.  When ``None``, the per-sample argmax
        of the head's logits is used instead.
    """

    prepared = self._prepare_inputs(inputs)
    # Gradients must be enabled even under an outer no_grad/inference
    # context — GradCAM needs the backward pass through the target layer.
    with torch.enable_grad():
        heatmap = self._run_gradcam(prepared, head, target_index)
    return heatmap

generate_attention_overlap

generate_attention_overlap(
    inputs: TensorOrArray,
    disease_target: int | None = None,
    attribute_target: int | None = None,
    disease_head: str | int = "disease",
    attribute_head: str | int = "attribute",
    threshold: float = 0.5,
) -> AttentionOverlapResult

Generate disease/attribute heatmaps and overlap metrics.

Source code in shortcut_detect/gradcam.py
def generate_attention_overlap(
    self,
    inputs: TensorOrArray,
    disease_target: int | None = None,
    attribute_target: int | None = None,
    disease_head: str | int = "disease",
    attribute_head: str | int = "attribute",
    threshold: float = 0.5,
) -> AttentionOverlapResult:
    """Produce both heatmaps plus their overlap metrics in one call.

    Runs GradCAM once per head, computes Dice/IoU/cosine overlap at the
    given binarization ``threshold``, and bundles everything into an
    ``AttentionOverlapResult``.  The headline ``overlap_score`` is the
    Dice coefficient (0.0 when unavailable).
    """

    d_map = self.generate_heatmap(inputs, disease_head, disease_target)
    a_map = self.generate_heatmap(inputs, attribute_head, attribute_target)
    overlap_metrics = self.calculate_overlap(d_map, a_map, threshold)
    return AttentionOverlapResult(
        disease_heatmap=d_map,
        attribute_heatmap=a_map,
        overlap_score=overlap_metrics.get("dice", 0.0),
        metrics=overlap_metrics,
    )

calculate_overlap staticmethod

calculate_overlap(
    disease_heatmap: ndarray,
    attribute_heatmap: ndarray,
    threshold: float = 0.5,
) -> dict[str, float]

Compute overlap metrics between two normalized heatmaps.

Source code in shortcut_detect/gradcam.py
@staticmethod
def calculate_overlap(
    disease_heatmap: np.ndarray,
    attribute_heatmap: np.ndarray,
    threshold: float = 0.5,
) -> dict[str, float]:
    """Compute overlap metrics between two normalized heatmaps."""

    dh = np.asarray(disease_heatmap, dtype=np.float32)
    ah = np.asarray(attribute_heatmap, dtype=np.float32)
    if dh.shape != ah.shape:
        raise ValueError("Heatmaps must share the same shape to compute overlap")

    if dh.ndim == 2:
        dh = dh[None, ...]
        ah = ah[None, ...]

    eps = 1e-8
    dh_bin = (dh >= threshold).astype(np.float32)
    ah_bin = (ah >= threshold).astype(np.float32)

    intersection = (dh_bin * ah_bin).sum(axis=(1, 2))
    union = (dh_bin + ah_bin - dh_bin * ah_bin).sum(axis=(1, 2))
    dice_den = dh_bin.sum(axis=(1, 2)) + ah_bin.sum(axis=(1, 2)) + eps

    dice = (2.0 * intersection / dice_den).mean().item() if dice_den.size else 0.0
    iou = (intersection / (union + eps)).mean().item() if union.size else 0.0

    flat_d = dh.reshape(dh.shape[0], -1)
    flat_a = ah.reshape(ah.shape[0], -1)
    cosines = []
    for d_vec, a_vec in zip(flat_d, flat_a, strict=False):
        d_norm = np.linalg.norm(d_vec)
        a_norm = np.linalg.norm(a_vec)
        if d_norm < eps or a_norm < eps:
            cosines.append(0.0)
        else:
            cosines.append(float(np.dot(d_vec, a_vec) / (d_norm * a_norm)))
    cosine = float(np.mean(cosines)) if cosines else 0.0

    return {"dice": float(dice), "iou": float(iou), "cosine": cosine}

close

close() -> None

Remove any lingering hooks (safe to call multiple times).

Source code in shortcut_detect/gradcam.py
def close(self) -> None:
    """Detach all registered hooks; idempotent, so repeated calls are safe."""
    self._detach_hooks()

GradCAMHeatmapGenerator

Constructor

GradCAMHeatmapGenerator(
    model: torch.nn.Module,
    target_layer: torch.nn.Module,
    device: str = 'cuda',
    use_guided: bool = False
)

Parameters

Parameter Type Default Description
model nn.Module required PyTorch model
target_layer nn.Module required Layer for GradCAM
device str 'cuda' Computation device
use_guided bool False Use Guided GradCAM

Methods

generate()

def generate(
    input_tensor: torch.Tensor,
    target_class: int | None = None
) -> np.ndarray

Generate GradCAM heatmap for a single input.

Parameters:

Parameter Type Description
input_tensor Tensor Shape (C, H, W) or (1, C, H, W)
target_class int Class to explain (None = predicted)

Returns: Heatmap array (H, W)

generate_batch()

def generate_batch(
    inputs: torch.Tensor,
    target_classes: list[int] | None = None
) -> list[np.ndarray]

Generate heatmaps for a batch of inputs.

visualize()

def visualize(
    input_tensor: torch.Tensor,
    heatmap: np.ndarray,
    alpha: float = 0.4,
    colormap: str = 'jet',
    save_path: str | None = None
) -> np.ndarray

Overlay heatmap on input image.

compare_groups()

def compare_groups(
    heatmaps: np.ndarray,
    group_labels: np.ndarray
) -> AttentionOverlapResult

Compare attention patterns between groups.

Returns: AttentionOverlapResult dataclass

AttentionOverlapResult

@dataclass
class AttentionOverlapResult:
    overlap_score: float        # Attention overlap (0-1)
    group_heatmaps: dict        # Average heatmap per group
    divergence_regions: ndarray # Regions with different attention
    summary: str                # Human-readable summary

Usage Examples

Basic Usage

from shortcut_detect import GradCAMHeatmapGenerator
import torch

model = torch.load("model.pth")
target_layer = model.layer4[-1]

gradcam = GradCAMHeatmapGenerator(model, target_layer)

heatmap = gradcam.generate(image_tensor)
gradcam.visualize(image_tensor, heatmap, save_path="attention.png")

Group Comparison

# Generate heatmaps for all images
heatmaps = []
for img in images:
    heatmaps.append(gradcam.generate(img))
heatmaps = np.stack(heatmaps)

# Compare groups
result = gradcam.compare_groups(heatmaps, group_labels)
print(f"Overlap: {result.overlap_score:.2f}")
print(result.summary)

Batch Processing

from torch.utils.data import DataLoader

all_heatmaps = []
for batch in DataLoader(dataset, batch_size=32):
    images, labels = batch
    heatmaps = gradcam.generate_batch(images.cuda(), labels)
    all_heatmaps.extend(heatmaps)

Layer Selection Tips

# ResNet
target_layer = model.layer4[-1]

# VGG
target_layer = model.features[-1]

# DenseNet
target_layer = model.features.denseblock4

# EfficientNet
target_layer = model.features[-1]

See Also