
Probes API

The probes module provides classifier-based shortcut detection.

Class Reference

SKLearnProbe

SKLearnProbe(
    estimator: BaseEstimator | None = None,
    *,
    metric: MetricName = "f1",
    threshold: float = 0.7,
    average: str = "macro",
    evaluation: EvaluationName = "holdout",
    test_size: float = 0.2,
    cv_folds: int = 5,
    random_state: int = 0
)

Bases: DetectorBase

Shortcut detector based on training a classifier to predict a demographic target.

Idea
  • Train a probe classifier to predict a sensitive/demographic attribute y from embeddings X.
  • If the probe performs above a user-defined threshold on a metric (e.g., F1), treat this as evidence that embeddings encode the attribute (potential shortcut).
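The idea above can be sketched directly with scikit-learn, independently of SKLearnProbe. This is a minimal illustration, not the library's implementation. The helper name `probe_detects_shortcut`, the stratified split and the synthetic data are assumptions of this sketch. The binary/macro F1 choice follows the averaging rule stated in the parameter list.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import f1_score
from sklearn.model_selection import train_test_split
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def probe_detects_shortcut(X, y, threshold=0.7, test_size=0.2, random_state=0):
    """Fit a linear probe on a holdout split and compare its F1 to a threshold."""
    X_tr, X_te, y_tr, y_te = train_test_split(
        X, y, test_size=test_size, stratify=y, random_state=random_state
    )
    # Standardize, then logistic regression (similar to SKLearnProbe's default).
    probe = make_pipeline(StandardScaler(), LogisticRegression(max_iter=2000))
    probe.fit(X_tr, y_tr)
    # Binary targets use "binary" F1; multiclass targets use macro averaging.
    average = "binary" if len(np.unique(y)) == 2 else "macro"
    score = f1_score(y_te, probe.predict(X_te), average=average)
    return bool(score > threshold), float(score)


# Synthetic example: dimension 0 of the embeddings leaks a binary attribute.
rng = np.random.default_rng(0)
y = rng.integers(0, 2, size=400)
X = rng.normal(size=(400, 16))
X[:, 0] += 3.0 * y
leaky, f1 = probe_detects_shortcut(X, y)
print(leaky, round(f1, 3))
```

Because one embedding dimension is shifted by the attribute, a linear probe recovers it well above the 0.7 threshold. On embeddings with no such signal, the F1 would stay near chance.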

Parameters

  • estimator: Any scikit-learn estimator supporting fit/predict (optionally predict_proba or decision_function). If None, a Pipeline of StandardScaler followed by LogisticRegression(max_iter=2000) is used.
  • metric: One of "accuracy", "f1", "precision", "recall", "roc_auc".
  • threshold: A shortcut is detected if metric_value > threshold.
  • average: Averaging strategy for multiclass f1/precision/recall ("macro", "micro", "weighted"). For binary problems, "binary" is used automatically for these metrics.
  • evaluation: "holdout" (train/test split), "cv" (StratifiedKFold cross-validation), or "train" (no splitting; the probe is fit and scored on the same data).
  • test_size: Fraction of samples held out when evaluation="holdout".
  • cv_folds: Number of folds when evaluation="cv".
  • random_state: Seed for reproducible splitting (also passed to the default LogisticRegression).
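The three evaluation settings can be sketched with standard scikit-learn utilities to show what each one measures. This is not SKLearnProbe's code. The helper `evaluate`, the stratified holdout split and shuffling inside StratifiedKFold are assumptions of this sketch.

```python
import numpy as np
from sklearn.base import clone
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score
from sklearn.model_selection import StratifiedKFold, cross_val_predict, train_test_split


def evaluate(estimator, X, y, evaluation="holdout", test_size=0.2, cv_folds=5, random_state=0):
    """Return (y_true, y_pred) for one evaluation protocol."""
    if evaluation == "holdout":
        # Fit on a training split, predict the held-out split.
        X_tr, X_te, y_tr, y_te = train_test_split(
            X, y, test_size=test_size, stratify=y, random_state=random_state
        )
        return y_te, clone(estimator).fit(X_tr, y_tr).predict(X_te)
    if evaluation == "cv":
        # Each sample is predicted by a fold model that did not train on it.
        cv = StratifiedKFold(n_splits=cv_folds, shuffle=True, random_state=random_state)
        return y, cross_val_predict(clone(estimator), X, y, cv=cv)
    if evaluation == "train":
        # No split: the score is optimistic (measures fit, not generalization).
        return y, clone(estimator).fit(X, y).predict(X)
    raise ValueError(f"Unknown evaluation: {evaluation!r}")


# Synthetic 3-class example: class k is shifted along embedding dimension k.
rng = np.random.default_rng(1)
y = np.repeat([0, 1, 2], 100)
X = rng.normal(size=(300, 8))
X[np.arange(300), y] += 4.0
probe = LogisticRegression(max_iter=1000)
scores = {
    mode: accuracy_score(*evaluate(probe, X, y, evaluation=mode))
    for mode in ("holdout", "cv", "train")
}
print(scores)
```

"cv" scores every sample exactly once, which makes it less noisy than a single holdout split on small datasets. "train" tends to overstate how much of the attribute is linearly recoverable.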

Fit inputs

  • embeddings: np.ndarray, shape (n_samples, n_features)
  • target: np.ndarray, shape (n_samples,). Demographic/sensitive attribute labels (e.g., gender).

Source code in shortcut_detect/probes/sklearn_probe.py
def __init__(
    self,
    estimator: BaseEstimator | None = None,
    *,
    metric: MetricName = "f1",
    threshold: float = 0.70,
    average: str = "macro",
    evaluation: EvaluationName = "holdout",
    test_size: float = 0.2,
    cv_folds: int = 5,
    random_state: int = 0,
):
    super().__init__(method="ml_probe")

    if estimator is None:
        # Reasonable default probe: scale then logistic regression
        estimator = Pipeline(
            steps=[
                ("scaler", StandardScaler(with_mean=True, with_std=True)),
                (
                    "clf",
                    LogisticRegression(
                        max_iter=2000,
                        solver="lbfgs",
                        random_state=random_state,
                    ),
                ),
            ]
        )

    self.estimator = estimator
    self.config = MLProbeConfig(
        metric=metric,
        threshold=float(threshold),
        average=average,
        evaluation=evaluation,
        test_size=float(test_size),
        cv_folds=int(cv_folds),
        random_state=int(random_state),
    )

    # Fitted artifacts
    self.estimator_: BaseEstimator | None = None
    self.metric_value_: float | None = None
    self.y_true_eval_: np.ndarray | None = None
    self.y_pred_eval_: np.ndarray | None = None

Functions

predict

predict(X: ndarray) -> np.ndarray

Predict class labels for embeddings (requires prior fit).

Source code in shortcut_detect/probes/sklearn_probe.py
def predict(self, X: np.ndarray) -> np.ndarray:
    """Predict class labels for embeddings (requires prior fit)."""
    self._ensure_fitted()
    X = np.asarray(X)
    if X.ndim != 2:
        raise ValueError(f"X must be 2D (n_samples, n_features). Got shape={X.shape}.")
    return np.asarray(self.estimator_.predict(X))

TorchProbe

TorchProbe(
    model: Module,
    loss_fn: Any,
    *,
    optimizer_class: Any = torch.optim.Adam,
    optimizer_kwargs: dict[str, Any] | None = None,
    device: str | None = None,
    metric: MetricName = "accuracy",
    threshold: float = 0.7,
    test_size: float = 0.2,
    random_state: int = 0,
    epochs: int = 10,
    batch_size: int = 128,
    num_workers: int = 0,
    early_stopping: int | None = None,
    use_amp: bool = False,
    verbose: bool = False,
    loader_factory: LoaderFactory | None = None,
    stage_loader_overrides: StageLoaderOverrides = None
)

Bases: DetectorBase

Probe-based shortcut detector using a PyTorch model.

Fits a torch model to predict a demographic target from embeddings and flags a shortcut if the chosen metric exceeds a threshold.

Parameters

  • model: torch.nn.Module that maps embeddings to logits (classification) or a scalar (regression). For classification, return logits of shape (N, C).
  • loss_fn: Loss function (e.g., nn.CrossEntropyLoss()).
  • optimizer_class / optimizer_kwargs: Optimizer configuration. Defaults to torch.optim.Adam with {"lr": 1e-3}.
  • device: "cpu" or "cuda"; defaults to CUDA if available.
  • metric: One of "accuracy", "f1", "roc_auc", "loss". For multiclass, "f1" uses macro averaging; "roc_auc" is only supported for binary targets.
  • threshold: shortcut_detected is True when metric_value > threshold. The exception is "loss", where shortcut_detected is True when loss < threshold, since a lower loss means the attribute is easier to predict. The rule applied is also recorded in the result notes.
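To make the metric options concrete, the sketch below computes accuracy, F1 and ROC AUC from a probe's validation logits with NumPy and scikit-learn. It is not TorchProbe's evaluation code. The function name `metrics_from_logits`, binary F1 averaging for two classes, and the use of the softmax class-1 probability for ROC AUC are assumptions.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


def metrics_from_logits(logits, y_true):
    """Compute accuracy, F1 and (binary only) ROC AUC from (N, C) logits, C >= 2."""
    logits = np.asarray(logits, dtype=float)
    y_true = np.asarray(y_true)
    y_pred = logits.argmax(axis=1)
    n_classes = logits.shape[1]
    # Assumption: "binary" F1 for two classes; the docs only fix macro for multiclass.
    average = "binary" if n_classes == 2 else "macro"
    out = {
        "accuracy": accuracy_score(y_true, y_pred),
        "f1": f1_score(y_true, y_pred, average=average),
    }
    if n_classes == 2:
        # Numerically stable softmax; ROC AUC uses the class-1 probability.
        shifted = logits - logits.max(axis=1, keepdims=True)
        probs = np.exp(shifted) / np.exp(shifted).sum(axis=1, keepdims=True)
        out["roc_auc"] = roc_auc_score(y_true, probs[:, 1])
    return out


print(metrics_from_logits([[2.0, 0.0], [0.0, 2.0], [1.0, 0.0], [0.0, 1.0]], [0, 1, 1, 1]))
```

ROC AUC depends only on how the scores rank the samples, so it can reach 1.0 even when some argmax predictions are wrong, as in the example call.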

Source code in shortcut_detect/probes/torch_probe.py
def __init__(
    self,
    model: nn.Module,
    loss_fn: Any,
    *,
    optimizer_class: Any = torch.optim.Adam,
    optimizer_kwargs: dict[str, Any] | None = None,
    device: str | None = None,
    metric: MetricName = "accuracy",
    threshold: float = 0.70,
    test_size: float = 0.2,
    random_state: int = 0,
    epochs: int = 10,
    batch_size: int = 128,
    num_workers: int = 0,
    early_stopping: int | None = None,
    use_amp: bool = False,
    verbose: bool = False,
    loader_factory: LoaderFactory | None = None,
    stage_loader_overrides: StageLoaderOverrides = None,
):
    super().__init__(method="torch_probe")

    self.model = model
    self.loss_fn = loss_fn
    self.optimizer_class = optimizer_class
    self.optimizer_kwargs = optimizer_kwargs or {"lr": 1e-3}
    self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")

    self.config = TorchProbeConfig(
        metric=metric,
        threshold=float(threshold),
        test_size=float(test_size),
        random_state=int(random_state),
        epochs=int(epochs),
        batch_size=int(batch_size),
        num_workers=int(num_workers),
        early_stopping=early_stopping,
        use_amp=bool(use_amp),
        verbose=bool(verbose),
    )

    self.shortcut_detected_: bool | None = None
    self.metric_value_: float | None = None

    self._optimizer: torch.optim.Optimizer | None = None
    self._scaler: torch.cuda.amp.GradScaler | None = None
    self.loader_factory = loader_factory
    self.stage_loader_overrides = stage_loader_overrides

Functions

fit_dataset

fit_dataset(
    dataset: Dataset,
    *,
    val_dataset: Dataset | None = None,
    target_extractor: Callable[[Any], Any] | None = None,
    data_spec: DataSpec | dict[str, Any] | None = None
) -> TorchProbe

Train using map-style or iterable datasets without materializing full arrays.

Source code in shortcut_detect/probes/torch_probe.py
def fit_dataset(
    self,
    dataset: Dataset,
    *,
    val_dataset: Dataset | None = None,
    target_extractor: Callable[[Any], Any] | None = None,
    data_spec: DataSpec | dict[str, Any] | None = None,
) -> TorchProbe:
    """Train using map-style or iterable datasets without materializing full arrays."""
    resolved_spec = resolve_data_spec(data_spec)
    is_iterable = is_iterable_dataset(dataset)

    if is_iterable and val_dataset is None:
        raise ValueError("IterableDataset requires `val_dataset` for TorchProbe.fit_dataset.")

    if val_dataset is None:
        n = safe_len(dataset)
        if n is None or n < 4:
            raise ValueError("Dataset must expose len() and have at least 4 samples.")
        rng = np.random.RandomState(self.config.random_state)
        idx = np.arange(n)
        rng.shuffle(idx)
        test_n = max(1, int(round(self.config.test_size * n)))
        val_idx = idx[:test_n]
        train_idx = idx[test_n:]
        train_dataset = Subset(dataset, train_idx.tolist())
        eval_dataset = Subset(dataset, val_idx.tolist())
    else:
        train_dataset = dataset
        eval_dataset = val_dataset

    train_loader = self._make_loader_from_dataset(
        train_dataset,
        stage="train",
        shuffle=not is_iterable_dataset(train_dataset),
    )
    val_loader = self._make_loader_from_dataset(
        eval_dataset,
        stage="val",
        shuffle=False,
    )
    return self.fit_loaders(
        train_loader,
        val_loader=val_loader,
        target_extractor=target_extractor,
        data_spec=resolved_spec,
    )
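When val_dataset is omitted, fit_dataset shuffles indices with np.random.RandomState(random_state) and holds out the first max(1, round(test_size * n)) of them. The helper below restates that index logic with NumPy alone, so you can recover which samples were used for evaluation. The name `holdout_indices` is hypothetical and not part of the library.

```python
import numpy as np


def holdout_indices(n, test_size=0.2, random_state=0):
    """Return (train_idx, val_idx) using the same shuffling rule as fit_dataset."""
    if n < 4:
        raise ValueError("Dataset must expose len() and have at least 4 samples.")
    rng = np.random.RandomState(random_state)
    idx = np.arange(n)
    rng.shuffle(idx)
    # At least one validation sample; Python's round() rounds halves to even.
    test_n = max(1, int(round(test_size * n)))
    return idx[test_n:], idx[:test_n]


train_idx, val_idx = holdout_indices(10)
print(train_idx, val_idx)
```

Because round() rounds halves to even, test_size=0.25 with n=10 holds out 2 samples rather than 3.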

fit_loaders

fit_loaders(
    train_loader: DataLoader,
    *,
    val_loader: DataLoader,
    target_extractor: Callable[[Any], Any] | None = None,
    data_spec: DataSpec | dict[str, Any] | None = None
) -> TorchProbe

Train/evaluate from user-provided loaders.

Source code in shortcut_detect/probes/torch_probe.py
def fit_loaders(
    self,
    train_loader: DataLoader,
    *,
    val_loader: DataLoader,
    target_extractor: Callable[[Any], Any] | None = None,
    data_spec: DataSpec | dict[str, Any] | None = None,
) -> TorchProbe:
    """Train/evaluate from user-provided loaders."""
    resolved_spec = resolve_data_spec(data_spec)
    device = torch.device(self.device)

    train_loss_hist, best_val_loss, best_epoch = self._train_with_loaders(
        train_loader,
        val_loader,
        target_extractor=target_extractor,
    )
    eval_metrics = self._evaluate_loader(
        val_loader,
        device,
        target_extractor=target_extractor,
    )

    metric = self.config.metric
    if metric not in eval_metrics:
        raise ValueError(
            f"Metric '{metric}' not available. Available: {sorted(eval_metrics.keys())}"
        )
    metric_value = float(eval_metrics[metric])
    self.metric_value_ = metric_value

    if metric == "loss":
        shortcut = bool(metric_value < self.config.threshold)
        notes_rule = "shortcut_detected is True when loss < threshold."
        if metric_value <= min(self.config.threshold, 0.35):
            risk_level = "high"
        elif metric_value < self.config.threshold:
            risk_level = "moderate"
        else:
            risk_level = "low"
    else:
        shortcut = bool(metric_value > self.config.threshold)
        notes_rule = f"shortcut_detected is True when {metric} > threshold."
        if metric_value >= max(self.config.threshold, 0.85):
            risk_level = "high"
        elif metric_value >= self.config.threshold:
            risk_level = "moderate"
        else:
            risk_level = "low"

    self.shortcut_detected_ = shortcut
    n_train = (
        resolved_spec.train_size
        if resolved_spec is not None
        else safe_len(getattr(train_loader, "dataset", None))
    )
    n_test = (
        resolved_spec.val_size
        if resolved_spec is not None
        else safe_len(getattr(val_loader, "dataset", None))
    )

    self._set_results(
        shortcut_detected=shortcut,
        risk_level=risk_level,
        metrics={
            "metric": metric,
            "metric_value": metric_value,
            "threshold": self.config.threshold,
            "protocol": "loader",
        },
        notes=("Trained a PyTorch probe model from provided data loaders. " + notes_rule),
        metadata={
            "device": self.device,
            "epochs": self.config.epochs,
            "batch_size": self.config.batch_size,
            "optimizer": getattr(self.optimizer_class, "__name__", str(self.optimizer_class)),
            "optimizer_kwargs": dict(self.optimizer_kwargs),
            "n_train": n_train,
            "n_test": n_test,
        },
        report={
            "protocol": "loader",
            "test_size": self.config.test_size,
            "train_loss_history": train_loss_hist,
            "best_val_loss": best_val_loss,
            "best_epoch": best_epoch,
            "eval_metrics": {k: float(v) for k, v in eval_metrics.items()},
        },
    )
    self._is_fitted = True
    return self
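The detection rule and risk levels in fit_loaders can be read as a small pure function. This restates the branches of the source above for easier reading. The function name `decide` is hypothetical and not part of the library.

```python
def decide(metric, metric_value, threshold):
    """Return (shortcut_detected, risk_level) following fit_loaders' branches."""
    if metric == "loss":
        # Lower loss means the target is easier to predict from the embeddings.
        shortcut = metric_value < threshold
        if metric_value <= min(threshold, 0.35):
            risk = "high"
        elif metric_value < threshold:
            risk = "moderate"
        else:
            risk = "low"
    else:
        # Higher accuracy / f1 / roc_auc means the target is easier to predict.
        shortcut = metric_value > threshold
        if metric_value >= max(threshold, 0.85):
            risk = "high"
        elif metric_value >= threshold:
            risk = "moderate"
        else:
            risk = "low"
    return shortcut, risk


print(decide("accuracy", 0.9, 0.7))
print(decide("loss", 0.3, 0.7))
```

For the non-loss metrics, a value exactly equal to the threshold gives shortcut_detected=False but risk_level "moderate". Detection uses a strict comparison, while the moderate band uses >=.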

predict

predict(X: ndarray) -> np.ndarray

Predict class labels for embeddings (requires prior fit).

Source code in shortcut_detect/probes/torch_probe.py
def predict(self, X: np.ndarray) -> np.ndarray:
    """Predict class labels for embeddings (requires prior fit)."""
    self._ensure_fitted()
    X_arr = np.asarray(X)
    if X_arr.ndim != 2:
        raise ValueError(f"X must be 2D (n_samples, n_features). Got shape={X_arr.shape}.")
    device = torch.device(self.device)
    self.model.to(device)
    self.model.eval()
    loader = self._make_loader(
        X_arr,
        np.zeros(X_arr.shape[0], dtype=np.int64),
        stage="predict",
        batch_size=max(64, min(256, int(X_arr.shape[0]))),
        shuffle=False,
        num_workers=0,
    )
    preds_list = []
    with torch.no_grad():
        for xb, _ in loader:
            xb = xb.to(device, non_blocking=True)
            out = self.model(xb)
            if out.dim() == 2 and out.shape[1] > 1:
                preds_list.append(out.argmax(dim=1).detach().cpu().numpy())
            else:
                preds_list.append(out.squeeze(-1).detach().cpu().numpy().astype(np.int64))
    return np.concatenate(preds_list, axis=0)

SKLearnProbe

Probe using scikit-learn classifiers.

Constructor

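SKLearnProbe(
    estimator: sklearn.base.BaseEstimator | None = None,
    *,
    evaluation: str = "holdout",
    cv_folds: int = 5,
    # other keyword arguments (metric, threshold, average, test_size, random_state): see the class reference above
)

Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| estimator | BaseEstimator | None | scikit-learn estimator used as the probe; None uses StandardScaler + LogisticRegression |
| evaluation | str | "holdout" | "holdout", "cv", or "train" |
| cv_folds | int | 5 | Cross-validation folds (used when evaluation="cv") |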

Methods

fit()

def fit(X: np.ndarray, y: np.ndarray) -> SKLearnProbe

Train the probe classifier.

score()

def score(X: np.ndarray, y: np.ndarray) -> float

Evaluate accuracy on test data.

predict()

def predict(X: np.ndarray) -> np.ndarray

Predict group labels.

predict_proba()

def predict_proba(X: np.ndarray) -> np.ndarray

Predict class probabilities.

Attributes

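| Attribute | Type | Description |
|---|---|---|
| estimator_ | BaseEstimator or None | Fitted estimator (None before fit) |
| metric_value_ | float or None | Value of the configured metric from the last fit |
| y_true_eval_ | ndarray or None | Labels of the evaluation samples |
| y_pred_eval_ | ndarray or None | Probe predictions on the evaluation samples |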

Usage

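import numpy as np
from shortcut_detect import SKLearnProbe
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

# Placeholder data: replace with your own embeddings and group labels
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(600, 512))
group_labels = rng.integers(0, 3, size=600)

X_train, X_test, y_train, y_test = train_test_split(
    embeddings, group_labels, test_size=0.2, stratify=group_labels, random_state=0
)

probe = SKLearnProbe(LogisticRegression(max_iter=1000))
probe.fit(X_train, y_train)

accuracy = probe.score(X_test, y_test)
print(f"Accuracy: {accuracy:.2%}")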

TorchProbe

Probe using PyTorch models with GPU support.

Constructor

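TorchProbe(
    model: torch.nn.Module,
    loss_fn: Any,
    *,
    device: str | None = None,
    epochs: int = 10,
    optimizer_kwargs: dict | None = None,
    batch_size: int = 128,
    early_stopping: int | None = None,
    loader_factory: LoaderFactory | None = None,
    stage_loader_overrides: StageLoaderOverrides = None,
    # other keyword arguments (optimizer_class, metric, threshold, test_size, random_state,
    # num_workers, use_amp, verbose): see the class reference above
)

Parameters

| Parameter | Type | Default | Description |
|---|---|---|---|
| model | nn.Module | required | PyTorch model mapping embeddings to (N, C) logits |
| loss_fn | callable | required | Loss function, e.g. nn.CrossEntropyLoss() |
| device | str or None | None | 'cpu' or 'cuda'; None selects CUDA when available |
| epochs | int | 10 | Training epochs |
| optimizer_kwargs | dict or None | None | Optimizer arguments; None means {"lr": 1e-3} |
| batch_size | int | 128 | Batch size |
| early_stopping | int or None | None | Early stopping patience |
| loader_factory | callable or None | None | Optional hook to build loaders by stage |
| stage_loader_overrides | dict or None | None | Per-stage DataLoader kwargs overrides |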

Methods

Same as SKLearnProbe: fit(), score(), predict(), predict_proba(). In addition, fit_dataset() and fit_loaders() train from a torch Dataset or from DataLoaders.

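Attributes

| Attribute | Type | Description |
|---|---|---|
| shortcut_detected_ | bool or None | Whether the metric crossed the threshold in the last fit |
| metric_value_ | float or None | Value of the configured metric on the validation data |

The per-epoch training loss is recorded as "train_loss_history" in the results report, together with "best_val_loss" and "best_epoch".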

Usage

import torch
import torch.nn as nn
from shortcut_detect import TorchProbe

class CustomProbe(nn.Module):
    def __init__(self, input_dim, n_classes):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(128, n_classes)
        )

    def forward(self, x):
        # Returns (N, n_classes) logits
        return self.net(x)

# X_train, y_train, X_test, y_test: 512-d embeddings with 3 group labels
probe = TorchProbe(
    model=CustomProbe(512, 3),
    loss_fn=nn.CrossEntropyLoss(),  # required argument
    device="cuda" if torch.cuda.is_available() else "cpu",
    epochs=50
)
probe.fit(X_train, y_train)
accuracy = probe.score(X_test, y_test)

Base Probe Class

Probe

Abstract base class for all probes.

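import numpy as np
from sklearn.linear_model import LogisticRegression
from shortcut_detect.probes import Probe

class MyCustomProbe(Probe):
    def fit(self, X, y):
        # Training logic: here, a logistic-regression classifier
        self.clf_ = LogisticRegression(max_iter=1000).fit(X, y)
        return self

    def score(self, X, y):
        # Evaluation logic: accuracy on (X, y)
        return float(np.mean(self.predict(X) == np.asarray(y)))

    def predict(self, X):
        # Prediction logic: predicted group labels
        return self.clf_.predict(X)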

See Also