Skip to content

VAE API

VAEDetector

VAEDetector(
    *,
    latent_dim: int = 10,
    kld_weight: float = 3.0,
    lr: float = 0.001,
    batch_size: int = 32,
    epochs: int = 50,
    classifier_epochs: int = 20,
    device: str | torch.device | None = None,
    predictiveness_threshold: float = 0.5,
    random_state: int = 42
)

Bases: DetectorBase

Detect shortcuts via VAE latent disentanglement and predictiveness analysis.

Trains a Beta-VAE on images, then trains a classifier on frozen VAE encoder. Latent dimensions with high classifier weight (predictiveness) are candidate shortcuts. MPWD (max pairwise Wasserstein distance) per dimension indicates class separability.

Parameters:

Name Type Description Default
latent_dim int

VAE latent dimension.

10
kld_weight float

Beta-VAE KL weight (higher = more disentanglement).

3.0
lr float

VAE learning rate.

0.001
batch_size int

Training batch size.

32
epochs int

VAE training epochs.

50
classifier_epochs int

Classifier training epochs (frozen encoder).

20
device str | torch.device | None

Device for training (cuda/cpu).

None
predictiveness_threshold float

Normalized predictiveness [0,1] above which a dimension is flagged as shortcut. Default 0.5 = top half.

0.5
random_state int

Random seed.

42
Source code in shortcut_detect/vae/vae_detector.py
def __init__(
    self,
    *,
    latent_dim: int = 10,
    kld_weight: float = 3.0,
    lr: float = 0.001,
    batch_size: int = 32,
    epochs: int = 50,
    classifier_epochs: int = 20,
    device: str | torch.device | None = None,
    predictiveness_threshold: float = 0.5,
    random_state: int = 42,
) -> None:
    """
    Configure the VAE-based shortcut detector.

    Args:
        latent_dim: VAE latent dimension.
        kld_weight: Beta-VAE KL weight (higher = more disentanglement).
        lr: VAE learning rate.
        batch_size: Training batch size.
        epochs: VAE training epochs.
        classifier_epochs: Classifier training epochs (frozen encoder).
        device: Device for training (cuda/cpu).
        predictiveness_threshold: Normalized predictiveness [0,1] above which
            a dimension is flagged as shortcut. Default 0.5 = top half.
        random_state: Random seed.
    """
    super().__init__(method="vae")

    # Coerce every hyperparameter to its canonical type up front so the
    # training code never has to defend against e.g. a float epoch count.
    self.latent_dim = int(latent_dim)
    self.kld_weight = float(kld_weight)
    self.lr = float(lr)
    self.batch_size = int(batch_size)
    self.epochs = int(epochs)
    self.classifier_epochs = int(classifier_epochs)
    self.predictiveness_threshold = float(predictiveness_threshold)
    self.random_state = int(random_state)

    # A falsy device (None) falls back to CUDA when available, else CPU.
    resolved = device or ("cuda" if torch.cuda.is_available() else "cpu")
    self.device = torch.device(resolved)

    # Populated by fit(); all remain None until then.
    self._vae: ResnetVAE | None = None
    self._classifier: VAEClassifier | None = None
    self._latents_: np.ndarray | None = None
    self._labels_: np.ndarray | None = None

Functions

fit

fit(
    *,
    images: ndarray | Tensor | None = None,
    labels: ndarray | Tensor | None = None,
    img_size: int = 0,
    channels: int = 3,
    num_classes: int = 2,
    group_labels: ndarray | Tensor | None = None,
    vae_checkpoint: str | None = None,
    train_dl: DataLoader | None = None,
    val_dl: DataLoader | None = None,
    test_dl: DataLoader | None = None
) -> VAEDetector

Fit VAE detector on images.

Provide either (images, labels) or (train_dl, val_dl, test_dl).

Source code in shortcut_detect/vae/vae_detector.py
def fit(
    self,
    *,
    images: np.ndarray | torch.Tensor | None = None,
    labels: np.ndarray | torch.Tensor | None = None,
    img_size: int = 0,
    channels: int = 3,
    num_classes: int = 2,
    group_labels: np.ndarray | torch.Tensor | None = None,
    vae_checkpoint: str | None = None,
    train_dl: DataLoader | None = None,
    val_dl: DataLoader | None = None,
    test_dl: DataLoader | None = None,
) -> VAEDetector:
    """
    Fit VAE detector on images.

    Provide either (images, labels) or (train_dl, val_dl, test_dl).

    Args:
        images: Image data; used together with ``labels`` when no
            dataloaders are given.
        labels: Class labels aligned with ``images``.
        img_size: Image size forwarded to the fitting helpers.
        channels: Number of image channels.
        num_classes: Number of target classes.
        group_labels: Optional group labels (currently unused here;
            accepted for interface compatibility).
        vae_checkpoint: Optional path to a pre-trained VAE checkpoint.
        train_dl: Training dataloader (requires ``val_dl`` as well).
        val_dl: Validation dataloader.
        test_dl: Test dataloader; falls back to ``val_dl`` when omitted.

    Returns:
        ``self``, with detection results populated.

    Raises:
        ValueError: If neither (images, labels) nor (train_dl, val_dl)
            is provided.
        RuntimeError: If fitting failed to populate latents/labels.
    """
    if train_dl is not None and val_dl is not None:
        self._fit_from_dataloaders(
            train_dl=train_dl,
            val_dl=val_dl,
            # Explicit None check: `test_dl or val_dl` would go through
            # DataLoader.__len__ (an empty test loader is falsy and would
            # silently be replaced; iterable-style loaders may not even
            # support len()).
            test_dl=val_dl if test_dl is None else test_dl,
            img_size=img_size,
            channels=channels,
            num_classes=num_classes,
            vae_checkpoint=vae_checkpoint,
        )
    else:
        if images is None or labels is None:
            raise ValueError("Provide images and labels, or train_dl/val_dl/test_dl.")
        self._fit_from_arrays(
            images=images,
            labels=labels,
            img_size=img_size,
            channels=channels,
            num_classes=num_classes,
            vae_checkpoint=vae_checkpoint,
        )

    latent_matrix = self._latents_
    labels_arr = self._labels_
    # Explicit error instead of `assert`: asserts are stripped under -O,
    # which would turn a broken fit into an obscure downstream failure.
    if latent_matrix is None or labels_arr is None:
        raise RuntimeError("Fitting did not populate latents/labels; internal error.")

    # Class-separability score per latent dimension.
    mpwd = compute_mpwd_per_dimension(
        latent_matrix,
        labels_arr.astype(np.int64),
        self.latent_dim,
        num_classes,
    )

    # Classifier-weight-based predictiveness per latent dimension.
    predictiveness = compute_predictiveness_per_dimension(
        self._classifier,
        self.latent_dim,
    )

    # Normalize predictiveness to [0, 1] for thresholding; an all-zero
    # vector is left untouched to avoid division by zero.
    pred_max = float(np.max(predictiveness))
    pred_norm = predictiveness / pred_max if pred_max > 0 else predictiveness

    _, flagged_indices = rank_candidate_dimensions(
        pred_norm,
        mpwd,
        self.predictiveness_threshold,
    )

    # Hoist membership tests into a set: O(1) per dimension instead of a
    # linear scan of flagged_indices inside the loop below.
    flagged_set = set(flagged_indices)
    n_flagged = len(flagged_indices)
    shortcut_detected = n_flagged > 0
    max_pred = float(np.max(pred_norm))

    if shortcut_detected:
        # "high" when at least half the latent space is flagged.
        risk_level = "high" if n_flagged >= self.latent_dim // 2 else "moderate"
        notes = f"{n_flagged} latent dimension(s) exceed predictiveness threshold."
    else:
        risk_level = "low"
        notes = "No latent dimension exceeded predictiveness threshold."

    per_dim = [
        {
            "dimension": i,
            "predictiveness": float(pred_norm[i]),
            "mpwd": float(mpwd[i]),
            "flagged": i in flagged_set,
        }
        for i in range(self.latent_dim)
    ]

    metrics = {
        "n_candidate_dims": n_flagged,
        "max_predictiveness": max_pred,
        "n_flagged": n_flagged,
        "latent_dim": self.latent_dim,
    }

    metadata = {
        "kld_weight": self.kld_weight,
        "epochs": self.epochs,
        "predictiveness_threshold": self.predictiveness_threshold,
    }

    self._set_results(
        shortcut_detected=shortcut_detected,
        risk_level=risk_level,
        metrics=metrics,
        notes=notes,
        metadata=metadata,
        report={"per_dimension": per_dim},
        details={"mpwd": mpwd.tolist(), "predictiveness": pred_norm.tolist()},
    )
    self.shortcut_detected_ = shortcut_detected
    self._is_fitted = True
    return self