Shortcut Feature Masking (M01) API
ShortcutMasker
ShortcutMasker(
    strategy: str = "randomize",
    heatmap_threshold: float = 0.5,
    augment_fraction: float = 1.0,
    random_state: int | None = None,
)
Mask or randomize detected shortcut regions (images) or dimensions (embeddings).
Implements the data augmentation mitigation from Teso & Kersting (2019): counterexamples are created by randomizing or zeroing shortcut components.
Parameters
strategy : str
    For images: "zero", "randomize", or "inpaint". For embeddings: "zero" or "randomize".
heatmap_threshold : float
    Binarization threshold for heatmaps when converting to shortcut masks (0–1).
augment_fraction : float
    Fraction of samples to augment (0–1). 1.0 = all samples.
random_state : Optional[int]
    Seed for reproducible randomization.
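The interaction of augment_fraction and random_state can be pictured with a small NumPy sketch. This is not the library's code: select_samples_to_augment is a hypothetical helper, and both the rounding rule and sampling without replacement are assumptions made for illustration.

```python
import numpy as np


def select_samples_to_augment(n_samples, augment_fraction=1.0, random_state=None):
    """Return sorted indices of the samples chosen for augmentation.

    Hypothetical helper: ShortcutMasker may select samples differently.
    """
    if not 0.0 <= augment_fraction <= 1.0:
        raise ValueError("augment_fraction must lie in [0, 1]")
    # A seeded generator makes the selection reproducible across runs.
    rng = np.random.default_rng(random_state)
    n_selected = int(round(augment_fraction * n_samples))
    # Sample without replacement so no sample is picked twice.
    chosen = rng.choice(n_samples, size=n_selected, replace=False)
    return np.sort(chosen)


indices = select_samples_to_augment(10, augment_fraction=0.5, random_state=0)
print(len(indices))  # 5 of the 10 samples are selected
```

With the same random_state the same indices come back on every call, which is the property the parameter description refers to.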
Source code in shortcut_detect/mitigation/shortcut_masking.py
Functions
mask_images
mask_images(
    images: ndarray,
    shortcut_masks: ndarray | None = None,
    heatmaps: ndarray | None = None,
) -> np.ndarray
Produce augmented images by masking shortcut regions.
Parameters
images : np.ndarray
    Images of shape (N, H, W) or (N, H, W, C), values in [0, 1] or [0, 255].
shortcut_masks : np.ndarray, optional
    Binary masks (N, H, W), 1 = shortcut region. If None, heatmaps are used.
heatmaps : np.ndarray, optional
    Heatmaps (N, H, W) in [0, 1]. Used if shortcut_masks is None; binarized with heatmap_threshold.
Returns
np.ndarray
    Augmented images, same shape and dtype as images.
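The heatmap path described above (binarize with a threshold, then zero or randomize the flagged pixels) can be sketched in plain NumPy. This is an illustration under stated assumptions, not the code in shortcut_masking.py: mask_images_sketch is a hypothetical name, the strict > comparison and the uniform noise range are guesses, and the "inpaint" strategy is not covered.

```python
import numpy as np


def mask_images_sketch(images, heatmaps, threshold=0.5, strategy="zero", random_state=None):
    """Zero or randomize pixels whose heatmap value exceeds `threshold`.

    Hypothetical stand-in for the heatmap branch of mask_images.
    """
    rng = np.random.default_rng(random_state)
    masks = heatmaps > threshold  # (N, H, W) boolean, True = shortcut region
    if images.ndim == 4:
        # Repeat the mask over the channel axis for (N, H, W, C) inputs.
        masks = np.broadcast_to(masks[..., None], images.shape)
    out = images.copy()  # leave the input untouched
    if strategy == "zero":
        out[masks] = 0
    elif strategy == "randomize":
        # Assumption: replacement values are uniform over the input's value range.
        lo, hi = float(images.min()), float(images.max())
        noise = rng.uniform(lo, hi, size=images.shape)
        out[masks] = noise[masks].astype(images.dtype)
    else:
        raise ValueError(f"unsupported strategy: {strategy!r}")
    return out


images = np.ones((2, 4, 4, 3), dtype=np.float32)
heatmaps = np.zeros((2, 4, 4))
heatmaps[:, :2, :2] = 0.9  # top-left corner flagged as a shortcut
augmented = mask_images_sketch(images, heatmaps, threshold=0.5, strategy="zero")
```

The output keeps the shape and dtype of images, matching the return contract stated above; only pixels inside the binarized mask change.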
Source code in shortcut_detect/mitigation/shortcut_masking.py
mask_embeddings
Produce augmented embeddings by masking flagged shortcut dimensions.
Parameters
embeddings : np.ndarray
    Shape (N, D).
flagged_dim_indices : list or array of int
    Dimension indices to mask (0-based).
Returns
np.ndarray
    Augmented embeddings, same shape as embeddings.
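A minimal NumPy sketch of masking flagged embedding dimensions follows. The function name mask_embeddings_sketch is hypothetical, and the "randomize" branch assumes a per-column permutation across samples; the library may draw replacement values differently.

```python
import numpy as np


def mask_embeddings_sketch(embeddings, flagged_dim_indices, strategy="zero", random_state=None):
    """Zero or shuffle the flagged columns of an (N, D) embedding matrix.

    Hypothetical stand-in for mask_embeddings.
    """
    rng = np.random.default_rng(random_state)
    dims = np.asarray(flagged_dim_indices, dtype=int)
    out = embeddings.copy()  # leave the input untouched
    if strategy == "zero":
        out[:, dims] = 0.0
    elif strategy == "randomize":
        # Assumption: shuffle each flagged column across samples, which keeps
        # its marginal distribution but breaks its link to individual samples.
        for d in dims:
            out[:, d] = rng.permutation(embeddings[:, d])
    else:
        raise ValueError(f"unsupported strategy: {strategy!r}")
    return out


emb = np.arange(12, dtype=float).reshape(4, 3)
zeroed = mask_embeddings_sketch(emb, [1])
shuffled = mask_embeddings_sketch(emb, [0, 2], strategy="randomize", random_state=0)
```

Unflagged dimensions pass through unchanged in both strategies, so the result can be appended to the original embeddings as counterexamples.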