GradCAM API¶
The GradCAM module computes gradient-based attention heatmaps and compares them across prediction heads to surface visual shortcuts in image models.
Class Reference¶
GradCAMHeatmapGenerator¶
GradCAMHeatmapGenerator(
model: Module,
target_layer: str | Module,
head_mappings: (
dict[str, int | Callable[[Any], Tensor]] | None
) = None,
device: device | None = None,
)
Compute GradCAM heatmaps for multiple prediction heads.
Parameters¶
model:
torch.nn.Module that produces the predictions. The module
is always used in eval mode and gradients are enabled only
for the current forward pass.
target_layer:
Layer whose activations will be used for GradCAM. Can be the
actual nn.Module instance or the dotted path to the module.
head_mappings:
Optional mapping from head names to either an index (for tuple
outputs) or a callable fn(output) -> Tensor. Defaults to a
mapping that treats ("disease", "attribute") as the first
two entries of a tuple/list output.
device:
Device to run inference on. Defaults to the model's first
parameter device or CPU.
Source code in shortcut_detect/gradcam.py
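The `head_mappings` contract can be sketched without the library: each value is either an integer index into a tuple/list output or a callable that extracts the head's logits. The stand-in output below uses lists in place of tensors, purely for illustration:

```python
# Illustrative head_mappings: each value is either an integer index into a
# tuple/list model output, or a callable fn(output) -> Tensor extracting
# the head's logits. "disease"/"attribute" match the documented defaults.
head_mappings = {
    "disease": 0,                     # tuple index
    "attribute": lambda out: out[1],  # callable extractor
}

# Resolve heads against a stand-in model output (lists in place of tensors):
output = (["disease_logits"], ["attribute_logits"])
disease = output[head_mappings["disease"]]
attribute = head_mappings["attribute"](output)
```

The callable form also covers models whose forward returns a dict or a custom output object.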
Functions¶
generate_heatmap¶
generate_heatmap(
inputs: TensorOrArray,
head: str | int,
target_index: int | None = None,
) -> np.ndarray
Generate GradCAM heatmap for a single prediction head.
Parameters¶
inputs:
Input image tensor/array shaped (B,C,H,W) or (C,H,W).
head:
Identifier for the prediction head. Can be a string key,
integer index, or alias defined in head_mappings.
target_index:
Target class index. If None the argmax of the head's
logits for each sample is used.
Source code in shortcut_detect/gradcam.py
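When `target_index` is `None`, the per-sample argmax of the head's logits is used as the target. That default is equivalent to the following (stand-in logits, shown with numpy for illustration):

```python
import numpy as np

# Stand-in logits for a batch of two samples over three classes.
logits = np.array([[0.1, 2.0, -1.0],
                   [1.5, 0.2, 0.3]])

# Default target selection when target_index is None: per-sample argmax.
targets = logits.argmax(axis=1)
```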
generate_attention_overlap¶
generate_attention_overlap(
inputs: TensorOrArray,
disease_target: int | None = None,
attribute_target: int | None = None,
disease_head: str | int = "disease",
attribute_head: str | int = "attribute",
threshold: float = 0.5,
) -> AttentionOverlapResult
Generate disease/attribute heatmaps and overlap metrics.
Source code in shortcut_detect/gradcam.py
calculate_overlap (staticmethod)¶
calculate_overlap(
disease_heatmap: ndarray,
attribute_heatmap: ndarray,
threshold: float = 0.5,
) -> dict[str, float]
Compute overlap metrics between two normalized heatmaps.
Source code in shortcut_detect/gradcam.py
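To make the threshold parameter concrete, here is a minimal sketch of what such metrics can look like (an illustrative re-implementation, not the library's actual code): binarize both normalized heatmaps at the threshold and compare the resulting masks.

```python
import numpy as np

def overlap_metrics(disease_hm, attribute_hm, threshold=0.5):
    """Illustrative IoU/Dice between two normalized (H, W) heatmaps."""
    d = disease_hm >= threshold          # binary disease-attention mask
    a = attribute_hm >= threshold        # binary attribute-attention mask
    inter = np.logical_and(d, a).sum()
    union = np.logical_or(d, a).sum()
    iou = inter / union if union else 0.0
    denom = d.sum() + a.sum()
    dice = 2 * inter / denom if denom else 0.0
    return {"iou": float(iou), "dice": float(dice)}

d = np.array([[1.0, 0.8], [0.1, 0.0]])
a = np.array([[0.9, 0.2], [0.1, 0.0]])
metrics = overlap_metrics(d, a)
```

High overlap between disease and attribute attention is the signal of interest here: it suggests the model may be attending to the attribute when predicting the disease.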
GradCAMHeatmapGenerator¶
Constructor¶
GradCAMHeatmapGenerator(
model: torch.nn.Module,
target_layer: torch.nn.Module,
device: str = 'cuda',
use_guided: bool = False
)
Parameters¶
| Parameter | Type | Default | Description |
|---|---|---|---|
| `model` | `nn.Module` | required | PyTorch model |
| `target_layer` | `nn.Module` | required | Layer for GradCAM |
| `device` | `str` | `'cuda'` | Computation device |
| `use_guided` | `bool` | `False` | Use Guided GradCAM |
Methods¶
generate()¶
Generate GradCAM heatmap for a single input.
Parameters:
| Parameter | Type | Description |
|---|---|---|
| `input_tensor` | `Tensor` | Shape `(C, H, W)` or `(1, C, H, W)` |
| `target_class` | `int \| None` | Class to explain (`None` = predicted class) |
Returns: Heatmap array (H, W)
generate_batch()¶
Generate heatmaps for a batch of inputs.
visualize()¶
def visualize(
    input_tensor: torch.Tensor,
    heatmap: np.ndarray,
    alpha: float = 0.4,
    colormap: str = 'jet',
    save_path: str | None = None
) -> np.ndarray
Overlay heatmap on input image.
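At its core the overlay is alpha blending of a colorized heatmap with the input image. A minimal numpy sketch of the idea (a hypothetical helper with a single-channel "colormap", not the library's implementation, which uses a real colormap such as `'jet'`):

```python
import numpy as np

def overlay(image, heatmap, alpha=0.4):
    """Blend a normalized (H, W) heatmap onto an (H, W, 3) image in [0, 1]."""
    # Map heatmap values onto the red channel as a stand-in colormap.
    colored = np.zeros(image.shape)
    colored[..., 0] = heatmap
    # Alpha-blend: alpha controls heatmap opacity over the image.
    return (1 - alpha) * image + alpha * colored

img = np.full((4, 4, 3), 0.5)
hm = np.zeros((4, 4))
hm[1, 1] = 1.0
out = overlay(img, hm)
```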
compare_groups()¶
Compare attention patterns between groups.
Returns: AttentionOverlapResult dataclass
AttentionOverlapResult¶
@dataclass
class AttentionOverlapResult:
overlap_score: float # Attention overlap (0-1)
group_heatmaps: dict # Average heatmap per group
divergence_regions: np.ndarray # Regions with different attention
summary: str # Human-readable summary
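To show how a result like this might be consumed, the sketch below constructs a stand-in instance by hand (the mirror dataclass, group names, and 0.5 cutoff are illustrative, not library defaults):

```python
import numpy as np
from dataclasses import dataclass, field

@dataclass
class AttentionOverlapResult:  # mirror of the documented fields
    overlap_score: float
    group_heatmaps: dict
    divergence_regions: np.ndarray
    summary: str

result = AttentionOverlapResult(
    overlap_score=0.42,
    group_heatmaps={"group_a": np.zeros((7, 7)), "group_b": np.zeros((7, 7))},
    divergence_regions=np.zeros((7, 7)),
    summary="Moderate overlap; inspect divergence regions.",
)

# Low overlap between groups may indicate divergent attention patterns.
flagged = result.overlap_score < 0.5
```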
Usage Examples¶
Basic Usage¶
from shortcut_detect import GradCAMHeatmapGenerator
import torch
# weights_only=False is needed for full-model checkpoints on recent PyTorch
model = torch.load("model.pth", weights_only=False)
model.eval()
target_layer = model.layer4[-1]
gradcam = GradCAMHeatmapGenerator(model, target_layer)
heatmap = gradcam.generate(image_tensor)
gradcam.visualize(image_tensor, heatmap, save_path="attention.png")
Group Comparison¶
import numpy as np
# Generate heatmaps for all images
heatmaps = np.stack([gradcam.generate(img) for img in images])
# Compare groups
result = gradcam.compare_groups(heatmaps, group_labels)
print(f"Overlap: {result.overlap_score:.2f}")
print(result.summary)
Batch Processing¶
from torch.utils.data import DataLoader
all_heatmaps = []
for batch in DataLoader(dataset, batch_size=32):
    images, labels = batch
    heatmaps = gradcam.generate_batch(images.cuda(), labels)
    all_heatmaps.extend(heatmaps)
Layer Selection Tips¶
# ResNet
target_layer = model.layer4[-1]
# VGG
target_layer = model.features[-1]
# DenseNet
target_layer = model.features.denseblock4
# EfficientNet
target_layer = model.features[-1]