ShortcutDetector API
The ShortcutDetector class provides a unified interface for running all shortcut detection methods.
Class Reference
ShortcutDetector
```python
ShortcutDetector(
    methods: list[str] = None,
    seed: int = 42,
    condition_name: str = "indicator_count",
    condition_kwargs: dict[str, Any] | None = None,
    **kwargs
)
```
Unified interface for detecting shortcuts using multiple methods.
Combines five detection approaches:
1. HBAC (Hierarchical Bias-Aware Clustering)
2. Probe-based detection (train classifiers on embeddings)
3. Statistical testing (feature-wise group differences)
4. Geometric shortcut analysis (subspace monitoring)
5. Equalized odds gap analysis (fairness-based)
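The probe-based approach listed above rests on a simple idea: if a lightweight classifier can recover the group label from the embeddings, then group information is encoded in them and could act as a shortcut. The following is a minimal, library-independent sketch of that idea using a nearest-centroid probe with leave-one-out evaluation in plain Python. The function names are hypothetical, and this is not the probe that ShortcutDetector trains internally.

```python
from collections import defaultdict
from math import dist


def centroid(points):
    """Component-wise mean of equal-length vectors."""
    return [sum(coords) / len(points) for coords in zip(*points)]


def loo_probe_accuracy(embeddings, groups):
    """Leave-one-out accuracy of a nearest-centroid probe that predicts
    each sample's group label from its embedding (hypothetical helper)."""
    correct = 0
    for i, (x, g) in enumerate(zip(embeddings, groups)):
        # Fit the probe on every sample except the held-out one
        by_group = defaultdict(list)
        for j, (y, h) in enumerate(zip(embeddings, groups)):
            if j != i:
                by_group[h].append(y)
        centroids = {h: centroid(pts) for h, pts in by_group.items()}
        # Predict the group whose centroid is closest to the held-out sample
        predicted = min(centroids, key=lambda h: dist(x, centroids[h]))
        correct += predicted == g
    return correct / len(groups)


emb = [[0.0, 0.0], [0.1, 0.0], [0.0, 0.1], [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]]
grp = [0, 0, 0, 1, 1, 1]
print(loo_probe_accuracy(emb, grp))  # groups are perfectly separable here
```

Accuracy well above the majority-class rate suggests the embeddings encode the group attribute; accuracy near that rate suggests they do not (at least not in a form this simple probe can exploit).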
Example
```python
detector = ShortcutDetector(
    methods=['hbac', 'probe', 'statistical', 'geometric', 'equalized_odds']
)
detector.fit(embeddings, labels, group_labels=groups)  # fit() returns the detector itself
print(detector.summary())
```
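The equalized odds method listed above is fairness-based: a classifier satisfies equalized odds when its true-positive rate (TPR) and false-positive rate (FPR) are the same across groups. The sketch below computes the per-group rates and the largest between-group gap for binary labels and predictions. It assumes predictions are already available from some classifier; how ShortcutDetector obtains predictions and reports the gap is not shown here, and the helper names are hypothetical.

```python
def tpr_fpr(y_true, y_pred):
    """True-positive and false-positive rates for binary (0/1) labels."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    tn = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 0)
    # Simplification: an undefined rate (no positives or no negatives) is reported as 0.0
    tpr = tp / (tp + fn) if tp + fn else 0.0
    fpr = fp / (fp + tn) if fp + tn else 0.0
    return tpr, fpr


def equalized_odds_gap(y_true, y_pred, groups):
    """Largest between-group difference in TPR and in FPR."""
    per_group = {}
    for g in sorted(set(groups)):
        idx = [i for i, h in enumerate(groups) if h == g]
        per_group[g] = tpr_fpr([y_true[i] for i in idx], [y_pred[i] for i in idx])
    tprs = [tpr for tpr, _ in per_group.values()]
    fprs = [fpr for _, fpr in per_group.values()]
    return {
        "tpr_gap": max(tprs) - min(tprs),
        "fpr_gap": max(fprs) - min(fprs),
        "per_group": per_group,
    }


y_true = [1, 1, 0, 0, 1, 1, 0, 0]
y_pred = [1, 1, 0, 0, 1, 0, 1, 0]
groups = ["a"] * 4 + ["b"] * 4
print(equalized_odds_gap(y_true, y_pred, groups))
```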
Initialize the unified shortcut detector.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| methods | list[str] | List of methods to use. Options: 'hbac', 'probe', 'statistical', 'geometric', 'equalized_odds', and 'cav' (loader mode). If None, defaults to ['hbac', 'probe', 'statistical']. | None |
| seed | int | Random seed for reproducibility | 42 |
| condition_name | str | Registered overall assessment condition name | 'indicator_count' |
| condition_kwargs | dict[str, Any] \| None | Keyword arguments passed to the selected condition | None |
| **kwargs | | Additional arguments passed to individual detectors | {} |
Source code in shortcut_detect/unified.py
Functions
fit
```python
fit(
    embeddings: ndarray | None,
    labels: ndarray,
    group_labels: ndarray | None = None,
    feature_names: list[str] | None = None,
    raw_inputs: Sequence[Any] | None = None,
    embedding_source: EmbeddingSource | None = None,
    embedding_cache_path: str | None = None,
    force_embedding_recompute: bool = False,
    splits: dict[str, ndarray] | None = None,
    extra_labels: dict[str, ndarray] | None = None,
) -> ShortcutDetector
```
Fit all detection methods on embeddings.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| embeddings | ndarray \| None | (n_samples, embedding_dim) array. Pass None to use embedding-only mode, in which embeddings are generated from raw_inputs. | required |
| labels | ndarray | (n_samples,) target labels | required |
| group_labels | ndarray \| None | (n_samples,) group labels (e.g., demographic attributes). If None, uses | None |
| feature_names | list[str] \| None | Optional list of feature names | None |
| raw_inputs | Sequence[Any] \| None | Optional raw inputs (text, ids, etc.). Required when embeddings is None (embedding-only mode). | None |
| embedding_source | EmbeddingSource \| None | Source used to generate embeddings when they are not pre-computed. | None |
| embedding_cache_path | str \| None | Optional path to cache generated embeddings. | None |
| force_embedding_recompute | bool | Whether to ignore cached embeddings. | False |
| splits | dict[str, ndarray] \| None | Optional dictionary of named index sets for semi-supervised methods. Expected keys: 'train_l' (labeled), 'train_u' (unlabeled). | None |
| extra_labels | dict[str, ndarray] \| None | Optional dictionary of named per-sample arrays for additional supervision signals (e.g., 'spurious' labels). Use -1 for unknown labels. | None |
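The splits and extra_labels parameters above expect index sets keyed 'train_l'/'train_u' and per-sample arrays that use -1 for unknown values. A minimal sketch of building those structures from a boolean "is labeled" mask and a list with missing entries is shown below. The helper names are hypothetical, and the sketch returns plain lists; since fit() documents ndarray values, you would presumably wrap them with np.asarray before passing them in.

```python
def build_splits(labeled_mask):
    """Split sample indices into labeled ('train_l') and unlabeled ('train_u') sets."""
    return {
        "train_l": [i for i, is_labeled in enumerate(labeled_mask) if is_labeled],
        "train_u": [i for i, is_labeled in enumerate(labeled_mask) if not is_labeled],
    }


def encode_extra_labels(values, unknown=-1):
    """Replace missing (None) per-sample labels with the -1 'unknown' marker."""
    return [unknown if v is None else v for v in values]


splits = build_splits([True, False, True, False])
spurious = encode_extra_labels([1, None, 0, None])
print(splits, spurious)
```

The results would then be passed as fit(..., splits=splits, extra_labels={"spurious": spurious}) after conversion to arrays.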
Returns:
| Type | Description |
|---|---|
| ShortcutDetector | self |
summary
Generate a text summary of detection results.
Returns:
| Type | Description |
|---|---|
| str | Formatted summary string |
generate_report
```python
generate_report(
    output_path: str = None,
    format: str = "html",
    include_visualizations: bool = True,
    export_csv: bool = False,
    csv_dir: str = None,
)
```
Generate comprehensive report with visualizations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| output_path | str | Path to save report (required for HTML/PDF) | None |
| format | str | Report format ('html' or 'pdf') | 'html' |
| include_visualizations | bool | Whether to include plots in HTML/PDF | True |
| export_csv | bool | Whether to export results to CSV files | False |
| csv_dir | str | Directory to save CSV files (default: same directory as output_path) | None |
Raises:
| Type | Description |
|---|---|
| ValueError | If format is not supported or required paths are missing |
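The Raises entry and the csv_dir default above describe a small amount of argument checking. The sketch below reconstructs that behaviour under stated assumptions: only 'html' and 'pdf' are accepted, output_path must be given, and csv_dir falls back to the directory containing output_path when CSV export is requested. resolve_report_paths is a hypothetical helper, not a library function, and the error messages are illustrative.

```python
import os

SUPPORTED_FORMATS = {"html", "pdf"}  # assumption: the two formats documented above


def resolve_report_paths(output_path=None, format="html", export_csv=False, csv_dir=None):
    """Validate report arguments and fill in the default CSV directory (hypothetical)."""
    if format not in SUPPORTED_FORMATS:
        raise ValueError(f"Unsupported report format: {format!r}")
    if output_path is None:
        raise ValueError(f"output_path is required for {format.upper()} reports")
    if export_csv and csv_dir is None:
        # Default: same directory as the report file
        csv_dir = os.path.dirname(os.path.abspath(output_path))
    return output_path, csv_dir


print(resolve_report_paths("out/report.html", export_csv=True))
```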
get_results
Get raw results dictionary.
Returns:
| Type | Description |
|---|---|
| dict[str, Any] | Dictionary with results from all methods |
Quick Reference
Constructor
```python
ShortcutDetector(
    methods: list[str] = ['hbac', 'probe', 'statistical'],
    seed: int = 42,
    condition_name: str = 'indicator_count',
    condition_kwargs: dict | None = None,
    **kwargs
)
```
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| methods | list[str] | ['hbac', 'probe', 'statistical'] | Methods to run |
| seed | int | 42 | Random seed for reproducibility |
| condition_name | str | 'indicator_count' | Overall assessment condition to use |
| condition_kwargs | dict | None | Keyword args passed to the selected condition |
| **kwargs | dict | - | Additional detector-specific configuration |
Available Methods
| Method Key | Description |
|---|---|
| 'hbac' | Hierarchical Bias-Aware Clustering |
| 'probe' | Probe-based classifier detection |
| 'statistical' | Feature-wise statistical testing |
| 'geometric' | Geometric subspace analysis |
| 'equalized_odds' | Equalized odds gap analysis (fairness-based) |
| 'cav' | Concept Activation Vector testing (loader mode) |
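The 'statistical' method tests each embedding dimension for a difference between groups. A minimal sketch of that idea for two groups is shown below: Welch's t statistic per feature, a two-sided p-value from a normal approximation to the t distribution (a simplification that is reasonable only for moderately large groups), and a Bonferroni correction across features. The function names and the choice of test are assumptions for illustration; the library's own test, correction, and reported fields (such as n_significant) may differ.

```python
import math
from statistics import NormalDist, mean, variance


def welch_t(a, b):
    """Welch's t statistic for the difference in means of samples a and b."""
    diff = mean(a) - mean(b)
    se = math.sqrt(variance(a) / len(a) + variance(b) / len(b))
    if se == 0:
        # Both samples are constant: any difference is treated as infinitely large
        return 0.0 if diff == 0 else math.copysign(math.inf, diff)
    return diff / se


def significant_features(embeddings, groups, alpha=0.05):
    """Indices of dimensions whose group means differ (two groups, Bonferroni-corrected)."""
    labels = sorted(set(groups))
    if len(labels) != 2:
        raise ValueError("this sketch handles exactly two groups")
    rows_a = [x for x, g in zip(embeddings, groups) if g == labels[0]]
    rows_b = [x for x, g in zip(embeddings, groups) if g == labels[1]]
    n_features = len(embeddings[0])
    threshold = alpha / n_features  # Bonferroni-corrected per-feature level
    hits = []
    for k in range(n_features):
        t = welch_t([r[k] for r in rows_a], [r[k] for r in rows_b])
        p = 2 * (1 - NormalDist().cdf(abs(t)))  # two-sided p-value (normal approximation)
        if p < threshold:
            hits.append(k)
    return hits
```

With group-separated values in dimension 0 and identically distributed values in dimension 1, only index 0 is returned.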
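Methods
fit()
```python
def fit(
    embeddings: np.ndarray | None,
    labels: np.ndarray,
    group_labels: np.ndarray | None = None,
    feature_names: list[str] | None = None,
    raw_inputs: Sequence[Any] | None = None,
    embedding_source: EmbeddingSource | None = None,
    embedding_cache_path: str | None = None,
    force_embedding_recompute: bool = False,
    splits: dict[str, np.ndarray] | None = None,
    extra_labels: dict[str, np.ndarray] | None = None,
) -> ShortcutDetector
```
Fit all detection methods on the data.
Parameters:
| Parameter | Type | Description |
|---|---|---|
| embeddings | ndarray | Shape (n_samples, n_features). Pass None for embedding-only mode. |
| labels | ndarray | Target (task) labels, shape (n_samples,) |
| group_labels | ndarray | Protected attribute labels, shape (n_samples,) |
| raw_inputs | list | Raw inputs for embedding generation |
| embedding_source | EmbeddingSource | Embedding generator for embedding-only mode |
| embedding_cache_path | str | Path to cache generated embeddings |
| force_embedding_recompute | bool | Ignore cached embeddings and regenerate them |
| feature_names, splits, extra_labels | - | See the Class Reference |
Returns: self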
summary()
Get a human-readable summary of all detection results.
Returns: Multi-line string with risk assessment and key metrics.
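generate_report()
```python
def generate_report(
    output_path: str = None,
    format: str = 'html',
    include_visualizations: bool = True,
    export_csv: bool = False,
    csv_dir: str = None,
) -> None
```
Generate a comprehensive report.
Parameters:
| Parameter | Type | Default | Description |
|---|---|---|---|
| output_path | str | None | Path for output file (required for HTML and PDF) |
| format | str | 'html' | Output format: 'html' or 'pdf' |
| include_visualizations | bool | True | Include plots in report |
| export_csv | bool | False | Also export results to CSV files |
| csv_dir | str | None | Directory for CSV files (defaults to the directory of output_path) |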
get_results()
Get raw results from all methods.
Returns: Dictionary with results from each method.
Attributes (after fit)
| Attribute | Type | Description |
|---|---|---|
| results_ | dict | Results from all detection methods |
| detectors_ | dict | Fitted detector instances for each method |
| embeddings_ | ndarray | Stored embeddings used for fitting |
| labels_ | ndarray | Stored labels used for fitting |
| group_labels_ | ndarray | Stored group labels (if provided) |
Usage Examples
Basic Usage
```python
from shortcut_detect import ShortcutDetector
import numpy as np

# Load data
embeddings = np.load("embeddings.npy")
labels = np.load("labels.npy")
group_labels = np.load("groups.npy")

# Create detector with four of the available methods
detector = ShortcutDetector(
    methods=['hbac', 'probe', 'statistical', 'geometric']
)

# Fit: labels are the task targets, group_labels the protected attribute
detector.fit(embeddings, labels, group_labels=group_labels)

# Results
print(detector.summary())
results = detector.get_results()
print(f"Probe accuracy: {results['probe']['accuracy']}")
```
Custom Parameters
```python
detector = ShortcutDetector(
    methods=['hbac', 'probe', 'statistical'],
    seed=42,
    # Extra keyword arguments are passed through to the individual detectors
    n_bootstraps=1000,
    probe_backend='sklearn'
)
```
Custom Overall Assessment
The condition_name parameter selects how method-level results are aggregated into
the final risk summary. Five built-in conditions are available:
| Condition | Description | Key Parameters |
|---|---|---|
| indicator_count | Counts total risk indicators across methods (default) | - |
| majority_vote | Counts each method with at least one indicator as one vote | high_threshold (int) |
| weighted_risk | Weights each detector by evidence strength (probe accuracy, statistical significance ratio, HBAC confidence, geometric effect size) | high_threshold (float), moderate_threshold (float) |
| multi_attribute | Cross-references risk across sensitive attributes; escalates when multiple attributes independently flag shortcuts | high_threshold (int) |
| meta_classifier | Uses a trained sklearn meta-classifier on detector features, or a heuristic fallback | model_path (str), high_threshold (float), moderate_threshold (float) |
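To make the difference between the first two conditions in the table concrete, here is a minimal sketch of the aggregation each one describes, starting from a hypothetical mapping of method name to the list of risk indicators that method raised. The risk-level names ("HIGH", "MODERATE", "LOW") and the rule for MODERATE are assumptions for illustration; the library's conditions may map counts to levels differently.

```python
def indicator_count(indicators):
    """Total number of risk indicators across all methods."""
    return sum(len(found) for found in indicators.values())


def majority_vote(indicators, high_threshold=2):
    """Each method with at least one indicator casts one vote (hypothetical levels)."""
    votes = sum(1 for found in indicators.values() if found)
    if votes >= high_threshold:
        level = "HIGH"
    elif votes > 0:
        level = "MODERATE"
    else:
        level = "LOW"
    return votes, level


indicators = {
    "probe": ["accuracy above chance"],
    "statistical": ["significant features", "large effect"],
    "hbac": [],
}
print(indicator_count(indicators))  # three indicators in total
print(majority_vote(indicators))    # but only two methods voted
```

The example shows why the two can disagree: one method that raises many indicators inflates indicator_count, whereas majority_vote counts it only once.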
```python
# Majority vote
detector = ShortcutDetector(
    methods=['probe', 'statistical'],
    condition_name='majority_vote',
    condition_kwargs={'high_threshold': 2},
)

# Weighted risk scoring
detector = ShortcutDetector(
    methods=['hbac', 'probe', 'statistical', 'geometric'],
    condition_name='weighted_risk',
    condition_kwargs={'high_threshold': 0.6, 'moderate_threshold': 0.3},
)

# Multi-attribute intersection
detector = ShortcutDetector(
    methods=['probe', 'statistical'],
    condition_name='multi_attribute',
    condition_kwargs={'high_threshold': 2},
)

# Meta-classifier (heuristic fallback when no trained model provided)
detector = ShortcutDetector(
    methods=['probe', 'statistical', 'hbac'],
    condition_name='meta_classifier',
)

# Meta-classifier with a trained model
detector = ShortcutDetector(
    methods=['probe', 'statistical', 'hbac'],
    condition_name='meta_classifier',
    condition_kwargs={'model_path': 'path/to/meta_model.joblib'},
)
```
The meta_classifier condition also exposes MetaClassifierCondition.extract_features(ctx)
for building training datasets from synthetic benchmark runs.
The default indicator_count condition preserves the library's existing summary semantics.
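The weighted_risk and multi_attribute conditions take the thresholds shown in the examples above. The sketch below illustrates one plausible reading of each: a weighted average of per-method evidence scores in [0, 1] compared against high_threshold and moderate_threshold, and a count of sensitive attributes that independently flag a shortcut compared against high_threshold. The evidence scores, weights, level names, and helper names are all hypothetical; the library's actual scoring is not reproduced here.

```python
def weighted_risk(scores, weights, high_threshold=0.6, moderate_threshold=0.3):
    """Weighted mean of per-method evidence scores (each in [0, 1]) mapped to a level."""
    total_weight = sum(weights[m] for m in scores)
    if total_weight == 0:
        raise ValueError("weights must not all be zero")
    score = sum(scores[m] * weights[m] for m in scores) / total_weight
    if score >= high_threshold:
        level = "HIGH"
    elif score >= moderate_threshold:
        level = "MODERATE"
    else:
        level = "LOW"
    return score, level


def multi_attribute(flags_by_attribute, high_threshold=2):
    """Escalate when at least high_threshold attributes independently flag a shortcut."""
    flagged = sorted(attr for attr, flagged_ in flags_by_attribute.items() if flagged_)
    if len(flagged) >= high_threshold:
        level = "HIGH"
    elif flagged:
        level = "MODERATE"
    else:
        level = "LOW"
    return flagged, level


print(weighted_risk({"probe": 0.8, "statistical": 0.2}, {"probe": 3.0, "statistical": 1.0}))
print(multi_attribute({"sex": True, "age": True, "site": False}))
```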
Embedding-Only Mode
```python
from shortcut_detect import ShortcutDetector, HuggingFaceEmbeddingSource

# Create embedding source
hf_source = HuggingFaceEmbeddingSource(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)

# Detect shortcuts from raw inputs
# texts, labels and groups are per-sample sequences of equal length
detector = ShortcutDetector(methods=['probe', 'statistical'])
detector.fit(
    embeddings=None,  # Triggers embedding-only mode
    labels=labels,
    group_labels=groups,
    raw_inputs=texts,
    embedding_source=hf_source,
    embedding_cache_path="embeddings.npy"
)
```
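The embedding-only example above combines three parameters: raw_inputs, an embedding source, and embedding_cache_path (with force_embedding_recompute to bypass the cache). A minimal sketch of the load-or-compute logic this implies is shown below. It is an assumption-laden stand-in: embed_fn is a plain callable substituting for an EmbeddingSource, the cache is JSON rather than .npy so the sketch needs only the standard library, and resolve_embeddings is a hypothetical name.

```python
import json
import os


def resolve_embeddings(embeddings, raw_inputs=None, embed_fn=None,
                       cache_path=None, force_recompute=False):
    """Return precomputed embeddings, or load/compute them from raw inputs (hypothetical)."""
    if embeddings is not None:
        return embeddings  # normal mode: nothing to generate
    if raw_inputs is None or embed_fn is None:
        raise ValueError("embedding-only mode needs raw_inputs and an embedding source")
    if cache_path and os.path.exists(cache_path) and not force_recompute:
        with open(cache_path) as fh:
            return json.load(fh)  # reuse cached vectors
    vectors = [embed_fn(x) for x in raw_inputs]
    if cache_path:
        with open(cache_path, "w") as fh:
            json.dump(vectors, fh)
    return vectors
```

Note that this sketch does not check whether the cached vectors were produced from the same raw_inputs; a stale cache is exactly the situation a force-recompute flag is meant to handle.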
Report Generation
```python
# HTML report (recommended)
detector.generate_report(
    output_path="report.html",
    format="html",
    include_visualizations=True
)

# PDF report
detector.generate_report(
    output_path="report.pdf",
    format="pdf"
)
```
Accessing Raw Results
```python
# Get all results
results = detector.get_results()

# Access specific method results
hbac = results['hbac']
print(f"HBAC purity: {hbac['purity']:.2f}")

probe = results['probe']
print(f"Probe accuracy: {probe['accuracy']:.2%}")

statistical = results['statistical']
print(f"Significant features: {statistical['n_significant']}")

geometric = results['geometric']
print(f"Effect size: {geometric['effect_size']:.2f}")
```
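The purity and effect_size values printed above are easier to interpret with standard definitions in mind. The sketch below assumes that purity means ordinary cluster purity with respect to group labels (the fraction of samples belonging to their cluster's majority group) and illustrates an effect size as Cohen's d of the projections onto the line joining the two group centroids. Both readings are assumptions; the library's HBAC and geometric detectors may define these fields differently, and the helper names are hypothetical.

```python
from collections import Counter
from statistics import mean, variance


def cluster_purity(clusters, groups):
    """Fraction of samples that belong to the majority group of their cluster."""
    counts = {}
    for c, g in zip(clusters, groups):
        counts.setdefault(c, Counter())[g] += 1
    majority_total = sum(cnt.most_common(1)[0][1] for cnt in counts.values())
    return majority_total / len(groups)


def centroid_effect_size(embeddings, groups):
    """Cohen's d of projections onto the unit vector from group A's centroid to group B's."""
    label_a, label_b = sorted(set(groups))  # exactly two groups expected
    rows_a = [x for x, g in zip(embeddings, groups) if g == label_a]
    rows_b = [x for x, g in zip(embeddings, groups) if g == label_b]
    centroid_a = [mean(col) for col in zip(*rows_a)]
    centroid_b = [mean(col) for col in zip(*rows_b)]
    direction = [cb - ca for ca, cb in zip(centroid_a, centroid_b)]
    norm = sum(d * d for d in direction) ** 0.5
    unit = [d / norm for d in direction]
    proj_a = [sum(u * v for u, v in zip(unit, x)) for x in rows_a]
    proj_b = [sum(u * v for u, v in zip(unit, x)) for x in rows_b]
    n_a, n_b = len(proj_a), len(proj_b)
    pooled_var = ((n_a - 1) * variance(proj_a) + (n_b - 1) * variance(proj_b)) / (n_a + n_b - 2)
    return (mean(proj_b) - mean(proj_a)) / pooled_var ** 0.5
```

Because the direction is estimated from the same samples it is evaluated on, this in-sample effect size is optimistic, especially for high-dimensional embeddings with few samples.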
Loader-based CAV Usage
```python
detector = ShortcutDetector(methods=["cav"])
detector.fit_from_loaders({
    "cav": {
        "concept_sets": {"shortcut": concept_arr},
        "random_set": random_arr,
        "target_directional_derivatives": directional_derivatives,  # optional but needed for TCAV risk
    }
})
print(detector.get_results()["cav"]["metrics"])
```
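The comment above notes that target_directional_derivatives are needed for TCAV risk. In the original TCAV formulation (Kim et al., 2018), a directional derivative is the dot product of the gradient of a class logit with respect to layer activations and the concept activation vector, and the TCAV score is the fraction of examples whose directional derivative is positive. A minimal sketch of that score is below; it assumes the derivatives are per-example scalars and does not reproduce how the cav method compares against the random set or derives its risk label.

```python
def directional_derivative(gradient, cav):
    """Dot product of a logit gradient (w.r.t. activations) with a concept activation vector."""
    return sum(g * v for g, v in zip(gradient, cav))


def tcav_score(directional_derivatives):
    """Fraction of examples whose directional derivative along the concept is positive."""
    if not directional_derivatives:
        raise ValueError("need at least one directional derivative")
    positive = sum(1 for d in directional_derivatives if d > 0)
    return positive / len(directional_derivatives)


gradients = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0], [2.0, 2.0]]
cav = [1.0, 0.0]
derivs = [directional_derivative(g, cav) for g in gradients]
print(derivs, tcav_score(derivs))
```

A score far from the scores obtained with random concept vectors suggests the model's predictions are sensitive to the concept.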
Error Handling
```python
try:
    detector.fit(embeddings, labels, group_labels=group_labels)
except ValueError as e:
    print(f"Invalid input: {e}")
except RuntimeError as e:
    print(f"Detection failed: {e}")
```
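The ValueError branch above is where malformed inputs surface. The sketch below shows the kind of shape checks that could trigger it, based on the documented shapes ((n_samples, embedding_dim) embeddings and (n_samples,) labels and group labels). The specific checks, messages, and the validate_inputs name are assumptions for illustration, not the library's actual validation.

```python
def validate_inputs(embeddings, labels, group_labels=None):
    """Raise ValueError on shape mismatches; return None when inputs are consistent."""
    n = len(labels)
    if embeddings is not None:
        if len(embeddings) != n:
            raise ValueError(f"embeddings has {len(embeddings)} rows but labels has {n} entries")
        if len({len(row) for row in embeddings}) != 1:
            raise ValueError("embeddings must be a non-empty 2-D array with equal-length rows")
    if group_labels is not None and len(group_labels) != n:
        raise ValueError(f"group_labels has {len(group_labels)} entries but labels has {n}")


try:
    validate_inputs([[0.1, 0.2]], [0, 1])
except ValueError as e:
    print(f"Invalid input: {e}")
```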