Probes API¶
The probes module provides classifier-based shortcut detection.
Class Reference¶
SKLearnProbe¶
```python
SKLearnProbe(
    estimator: BaseEstimator | None = None,
    *,
    metric: MetricName = "f1",
    threshold: float = 0.7,
    average: str = "macro",
    evaluation: EvaluationName = "holdout",
    test_size: float = 0.2,
    cv_folds: int = 5,
    random_state: int = 0
)
```
Bases: DetectorBase
Shortcut detector based on training a classifier to predict a demographic target.
Idea
- Train a probe classifier to predict a sensitive/demographic attribute y from embeddings X.
- If the probe performs above a user-defined threshold on a metric (e.g., F1), treat this as evidence that embeddings encode the attribute (potential shortcut).
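The idea can be sketched with plain scikit-learn, independent of this library. The example below is illustrative only: the synthetic data, the leaked dimension, and the pipeline are assumptions made for the sketch, not the internals of `SKLearnProbe`.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler

rng = np.random.default_rng(0)

# Synthetic stand-in data: 400 samples, 16-dim embeddings, binary attribute y.
y = rng.integers(0, 2, size=400)
X = rng.normal(size=(400, 16))
X[:, 0] += 3.0 * y  # deliberately leak the attribute into one dimension

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)

# Standardize, then fit a linear probe on the training split.
probe = make_pipeline(StandardScaler(), LogisticRegression(max_iter=1000))
probe.fit(X_train, y_train)

# Compare the held-out metric against a user-defined threshold.
metric_value = f1_score(y_test, probe.predict(X_test), average="binary")
threshold = 0.7
shortcut_detected = metric_value > threshold
print(f"F1={metric_value:.3f}, shortcut_detected={shortcut_detected}")
```

Because the attribute is injected into the embeddings on purpose, the probe should clear the threshold easily; with embeddings that carry no attribute signal, the score would sit near chance.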
Parameters¶
- estimator: Any scikit-learn estimator supporting fit/predict (optionally predict_proba or decision_function). If None, uses a standardized LogisticRegression.
- metric: One of: "accuracy", "f1", "precision", "recall", "roc_auc".
- threshold: Shortcut is detected if metric_value > threshold.
- average: Averaging strategy for multiclass f1/precision/recall ("macro", "micro", "weighted"). For binary problems, "binary" is used automatically for these metrics.
- evaluation: "holdout" (train/test split), "cv" (StratifiedKFold cross-validation), or "train" (no splitting).
- test_size: Used for holdout split.
- cv_folds: Used for CV.
- random_state: Reproducibility for splitting.
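As a rough illustration of what the three `evaluation` modes mean, the sketch below reproduces each strategy with scikit-learn directly. It is not the library's code; the synthetic three-class data and the choice to shuffle folds are assumptions made for the example.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedKFold, cross_val_score, train_test_split

rng = np.random.default_rng(0)
y = rng.integers(0, 3, size=300)   # three-class attribute
X = rng.normal(size=(300, 8))
X[np.arange(300), y] += 2.5        # each class shifts a different dimension


def make_clf():
    return LogisticRegression(max_iter=1000)


# "holdout": fit on a train split, score on the held-out split.
X_tr, X_te, y_tr, y_te = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=0
)
holdout = f1_score(y_te, make_clf().fit(X_tr, y_tr).predict(X_te), average="macro")

# "cv": one score per stratified fold, usually summarized by the mean.
folds = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
cv_scores = cross_val_score(make_clf(), X, y, cv=folds, scoring="f1_macro")

# "train": fit and score on the same data (optimistic, no held-out estimate).
train = f1_score(y, make_clf().fit(X, y).predict(X), average="macro")

print(f"holdout={holdout:.3f} cv={cv_scores.mean():.3f} train={train:.3f}")
```

The "train" figure is the least trustworthy of the three, since a flexible estimator can memorize the target; holdout or CV gives a fairer estimate of how much attribute information the embeddings carry.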
Fit inputs¶
- embeddings: np.ndarray, shape (n_samples, n_features).
- target: np.ndarray, shape (n_samples,). Demographic/sensitive attribute labels (e.g., gender).
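A quick shape check before fitting can catch mismatched inputs early. `check_probe_inputs` below is a hypothetical helper written for this example; it is not part of the library.

```python
import numpy as np


def check_probe_inputs(embeddings, target):
    """Hypothetical helper: validate (n_samples, n_features) and (n_samples,) inputs."""
    embeddings = np.asarray(embeddings)
    target = np.asarray(target)
    if embeddings.ndim != 2:
        raise ValueError(f"embeddings must be 2-D, got shape {embeddings.shape}")
    if target.ndim != 1:
        raise ValueError(f"target must be 1-D, got shape {target.shape}")
    if embeddings.shape[0] != target.shape[0]:
        raise ValueError(
            f"n_samples mismatch: {embeddings.shape[0]} embeddings, "
            f"{target.shape[0]} targets"
        )
    return embeddings, target


embeddings = np.zeros((4, 3), dtype=np.float32)
target = np.array(["f", "m", "f", "m"])
X, y = check_probe_inputs(embeddings, target)
print(X.shape, y.shape)  # (4, 3) (4,)
```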
Source code in shortcut_detect/probes/sklearn_probe.py
TorchProbe¶
```python
TorchProbe(
    model: Module,
    loss_fn: Any,
    *,
    optimizer_class: Any = torch.optim.Adam,
    optimizer_kwargs: dict[str, Any] | None = None,
    device: str | None = None,
    metric: MetricName = "accuracy",
    threshold: float = 0.7,
    test_size: float = 0.2,
    random_state: int = 0,
    epochs: int = 10,
    batch_size: int = 128,
    num_workers: int = 0,
    early_stopping: int | None = None,
    use_amp: bool = False,
    verbose: bool = False,
    loader_factory: LoaderFactory | None = None,
    stage_loader_overrides: StageLoaderOverrides = None
)
```
Bases: DetectorBase
Probe-based shortcut detector using a PyTorch model.
Fits a torch model to predict a demographic target from embeddings and flags a shortcut if the chosen metric exceeds a threshold.
Parameters¶
- model: torch.nn.Module that maps embeddings -> logits (classification) or scalar (regression). For classification, return logits of shape (N, C).
- loss_fn: Loss function (e.g., nn.CrossEntropyLoss()).
- optimizer_class / optimizer_kwargs: Optimizer configuration.
- device: "cpu" or "cuda"; defaults to CUDA if available.
- metric: One of: "accuracy", "f1", "roc_auc", "loss". For multiclass, "f1" uses macro averaging; "roc_auc" is only supported for binary targets.
- threshold: shortcut_detected is True when metric_value > threshold. The exception is "loss", where the comparison is reversed: shortcut_detected is True when loss < threshold (see notes).
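The threshold rule, including the reversed comparison for "loss", can be written out as a small function. This is a sketch of the rule as described above; the function name and signature are invented for illustration.

```python
def shortcut_detected(metric: str, value: float, threshold: float) -> bool:
    """Apply the documented rule: higher means detected, except for loss."""
    if metric == "loss":
        # Lower loss means the probe predicts the attribute well.
        return value < threshold
    return value > threshold


print(shortcut_detected("accuracy", 0.85, 0.7))  # True
print(shortcut_detected("accuracy", 0.70, 0.7))  # False: strictly greater is required
print(shortcut_detected("loss", 0.35, 0.7))      # True
```

Note that a threshold chosen for accuracy (a value in [0, 1]) does not carry over to "loss", whose scale depends on the loss function.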
Source code in shortcut_detect/probes/torch_probe.py
Functions¶
fit_dataset¶
```python
fit_dataset(
    dataset: Dataset,
    *,
    val_dataset: Dataset | None = None,
    target_extractor: Callable[[Any], Any] | None = None,
    data_spec: DataSpec | dict[str, Any] | None = None
) -> TorchProbe
```
Train using map-style or iterable datasets without materializing full arrays.
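For context, a map-style dataset produces one item per index, so the full embedding matrix never has to exist in memory. The sketch below shows such a dataset using only PyTorch; `LazyEmbeddingDataset` and its `(embedding, target)` item layout are assumptions for illustration, and the item format that `fit_dataset`, `target_extractor`, and `data_spec` actually expect is not specified here.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class LazyEmbeddingDataset(Dataset):
    """Hypothetical map-style dataset yielding one (embedding, target) pair per index."""

    def __init__(self, n_samples: int, dim: int):
        self.n_samples = n_samples
        self.dim = dim

    def __len__(self) -> int:
        return self.n_samples

    def __getitem__(self, idx: int):
        # Stand-in for reading one embedding from disk; seeded per index
        # so repeated access returns the same values.
        gen = torch.Generator().manual_seed(idx)
        embedding = torch.randn(self.dim, generator=gen)
        target = idx % 2
        return embedding, target


dataset = LazyEmbeddingDataset(n_samples=1000, dim=16)
loader = DataLoader(dataset, batch_size=128)

# Only one batch is materialized at a time.
xb, yb = next(iter(loader))
print(xb.shape, yb.shape)  # torch.Size([128, 16]) torch.Size([128])
```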
Source code in shortcut_detect/probes/torch_probe.py
fit_loaders¶
```python
fit_loaders(
    train_loader: DataLoader,
    *,
    val_loader: DataLoader,
    target_extractor: Callable[[Any], Any] | None = None,
    data_spec: DataSpec | dict[str, Any] | None = None
) -> TorchProbe
```
Train/evaluate from user-provided loaders.
Source code in shortcut_detect/probes/torch_probe.py
predict¶
Predict class labels for embeddings (requires prior fit).
Source code in shortcut_detect/probes/torch_probe.py
SKLearnProbe¶
Probe using scikit-learn classifiers.
Constructor¶
Parameters¶
| Parameter | Type | Default | Description |
|---|---|---|---|
| `classifier` | ClassifierMixin | LogisticRegression | sklearn classifier |
| `cv` | int | 5 | Cross-validation folds |
Methods¶
fit()¶
Train the probe classifier.
score()¶
Evaluate accuracy on test data.
predict()¶
Predict group labels.
predict_proba()¶
Predict class probabilities.
Attributes¶
| Attribute | Type | Description |
|---|---|---|
| `accuracy_` | float | Training accuracy |
| `cv_scores_` | ndarray | Cross-validation scores |
| `classifier` | object | Fitted classifier |
Usage¶
```python
from shortcut_detect import SKLearnProbe
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# `embeddings` (n_samples, n_features) and `group_labels` (n_samples,)
# are assumed to be NumPy arrays that already exist.
X_train, X_test, y_train, y_test = train_test_split(
    embeddings, group_labels, test_size=0.2
)

probe = SKLearnProbe(LogisticRegression(max_iter=1000))
probe.fit(X_train, y_train)

accuracy = probe.score(X_test, y_test)
print(f"Accuracy: {accuracy:.2%}")
```
TorchProbe¶
Probe using PyTorch models with GPU support.
Constructor¶
```python
TorchProbe(
    model: torch.nn.Module = None,
    device: str = 'cpu',
    epochs: int = 100,
    learning_rate: float = 1e-3,
    batch_size: int = 64,
    early_stopping: int = 10
)
```
Parameters¶
| Parameter | Type | Default | Description |
|---|---|---|---|
| `model` | nn.Module | MLP | PyTorch model |
| `device` | str | 'cpu' | Device ('cpu' or 'cuda') |
| `epochs` | int | 100 | Training epochs |
| `learning_rate` | float | 1e-3 | Learning rate |
| `batch_size` | int | 64 | Batch size |
| `early_stopping` | int | 10 | Early stopping patience |
| `loader_factory` | callable or None | None | Optional hook to build loaders by stage |
| `stage_loader_overrides` | dict or None | None | Per-stage DataLoader kwargs overrides |
Methods¶
Same as SKLearnProbe: fit(), score(), predict(), predict_proba()
Additional Attributes¶
| Attribute | Type | Description |
|---|---|---|
| `train_losses_` | list | Training loss history |
| `val_losses_` | list | Validation loss history |
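Early stopping patience is usually defined over a validation loss history like `val_losses_`: training stops once the loss has failed to improve for `patience` consecutive epochs. The function below is a generic sketch of that rule, not the library's implementation; details such as a minimum improvement delta or restoring the best weights may differ.

```python
def stopping_epoch(val_losses, patience):
    """Return the 0-based epoch at which training would stop, or None if it never stops."""
    best = float("inf")
    epochs_without_improvement = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            # New best validation loss: reset the patience counter.
            best = loss
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                return epoch
    return None


history = [0.90, 0.70, 0.65, 0.66, 0.67, 0.64, 0.65, 0.66, 0.68]
print(stopping_epoch(history, patience=3))  # 8
print(stopping_epoch(history, patience=2))  # 4
```

With a patience of 2, training stops before the later improvement at epoch 5 is ever seen, which is the usual trade-off between saving compute and riding out a noisy validation curve.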
Usage¶
```python
from shortcut_detect import TorchProbe
import torch.nn as nn


class CustomProbe(nn.Module):
    def __init__(self, input_dim, n_classes):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, n_classes)
        )

    def forward(self, x):
        # Returns logits of shape (N, n_classes).
        return self.net(x)


# 512-dimensional embeddings, 3 groups.
probe = TorchProbe(
    model=CustomProbe(512, 3),
    device='cuda',
    epochs=50
)

# X_train, y_train, X_test, y_test are assumed to be existing splits.
probe.fit(X_train, y_train)
accuracy = probe.score(X_test, y_test)
```
Base Probe Class¶
Probe¶
Abstract base class for all probes.
```python
from shortcut_detect.probes import Probe


class MyCustomProbe(Probe):
    def fit(self, X, y):
        # Training logic
        return self

    def score(self, X, y):
        # Evaluation logic
        return accuracy

    def predict(self, X):
        # Prediction logic
        return predictions
```
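To make the skeleton concrete, here is one way the three methods might be filled in, using a nearest-centroid rule. So that the example runs on its own, it defines a stand-in `Probe` base class with the same three methods; the real `shortcut_detect.probes.Probe` may require additional methods or attributes, and `NearestCentroidProbe` is a hypothetical name.

```python
from abc import ABC, abstractmethod

import numpy as np


class Probe(ABC):
    """Stand-in for shortcut_detect.probes.Probe (assumed interface)."""

    @abstractmethod
    def fit(self, X, y): ...

    @abstractmethod
    def score(self, X, y): ...

    @abstractmethod
    def predict(self, X): ...


class NearestCentroidProbe(Probe):
    def fit(self, X, y):
        X, y = np.asarray(X, dtype=float), np.asarray(y)
        self.classes_ = np.unique(y)
        # One mean embedding per class.
        self.centroids_ = np.stack([X[y == c].mean(axis=0) for c in self.classes_])
        return self

    def predict(self, X):
        X = np.asarray(X, dtype=float)
        # Squared distance from every sample to every centroid: (n_samples, n_classes).
        d2 = ((X[:, None, :] - self.centroids_[None, :, :]) ** 2).sum(axis=2)
        return self.classes_[d2.argmin(axis=1)]

    def score(self, X, y):
        # Plain accuracy.
        return float(np.mean(self.predict(X) == np.asarray(y)))


X = np.array([[0.0, 0.0], [0.2, 0.1], [3.0, 3.0], [3.1, 2.9]])
y = np.array([0, 0, 1, 1])
probe = NearestCentroidProbe().fit(X, y)
print(probe.predict([[0.1, 0.0], [2.8, 3.2]]))  # [0 1]
print(probe.score(X, y))  # 1.0
```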