Skip to content

Generative CVAE API

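Generative CVAE detector for shortcut detection: it generates counterfactuals in embedding space with a conditional variational autoencoder (CVAE). Note that the public class is spelled `GenerativeCVEDetector` (without the "A").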

Class Reference

GenerativeCVEDetector

GenerativeCVEDetector(
    dim: int = 128,
    hidden: int = 256,
    zdim: int = 64,
    lr: float = 0.001,
    batch_size: int = 256,
    epochs: int = 50,
    device: str = "cpu",
    recon_loss_weight: float = 1.0,
    kld_weight: float = 0.001,
    verbose: bool = False,
    random_state: int = 42,
    probe_classifier: Any | None = None,
    method: str = "generative_cvae",
)

Bases: DetectorBase

Generative CVAE detector for embedding counterfactuals.

Source code in shortcut_detect/causal/generative_cvae/src/detector.py
def __init__(
    self,
    dim: int = 128,
    hidden: int = 256,
    zdim: int = 64,
    lr: float = 1e-3,
    batch_size: int = 256,
    epochs: int = 50,
    device: str = "cpu",
    recon_loss_weight: float = 1.0,
    kld_weight: float = 1e-3,
    verbose: bool = False,
    random_state: int = 42,
    probe_classifier: Any | None = None,
    method: str = "generative_cvae",
):
    """Store the CVAE configuration and reset the detector's fitted state.

    Every architecture/optimisation hyperparameter is forwarded unchanged,
    by keyword, into a ``CVAEConfig`` kept on ``self.cfg``. The model,
    scaler and detection flag all start as ``None``.

    Args:
        dim: Forwarded to ``CVAEConfig`` (presumably the input embedding
            width — confirm against ``CVAEConfig``).
        hidden: Forwarded to ``CVAEConfig`` (presumably a hidden-layer width).
        zdim: Forwarded to ``CVAEConfig`` (presumably the latent size).
        lr: Forwarded to ``CVAEConfig``.
        batch_size: Forwarded to ``CVAEConfig``.
        epochs: Forwarded to ``CVAEConfig``.
        device: Forwarded to ``CVAEConfig``.
        recon_loss_weight: Forwarded to ``CVAEConfig``.
        kld_weight: Forwarded to ``CVAEConfig``.
        verbose: Forwarded to ``CVAEConfig``.
        random_state: Forwarded to ``CVAEConfig``.
        probe_classifier: Optional caller-supplied probe, kept as
            ``self.external_probe``.
        method: Method identifier handed to the base-class initialiser.
    """
    super().__init__(method=method)

    # Collect the hyperparameters in one mapping so they reach the config
    # object together, each under the same name it has in this signature.
    cvae_settings = {
        "dim": dim,
        "hidden": hidden,
        "zdim": zdim,
        "lr": lr,
        "batch_size": batch_size,
        "epochs": epochs,
        "device": device,
        "recon_loss_weight": recon_loss_weight,
        "kld_weight": kld_weight,
        "verbose": verbose,
        "random_state": random_state,
    }
    self.cfg = CVAEConfig(**cvae_settings)

    # Optional external probe; None when the caller did not provide one.
    self.external_probe: Any | None = probe_classifier

    # Fitted state: left unset here (presumably populated by fit() — confirm).
    self.model: CVAE | None = None
    self.scaler: StandardScaler | None = None
    self.shortcut_detected_ = None

Quick Example

from shortcut_detect.causal import GenerativeCVEDetector

# `embeddings`, `group_labels` and `labels` are your own data and are not
# defined in this snippet. NOTE: their expected shapes/dtypes are not shown
# on this page — check the `fit` signature before use.
detector = GenerativeCVEDetector(epochs=50, random_state=42)
# Arguments are passed positionally here, in the order
# (embeddings, group_labels, labels).
detector.fit(embeddings, group_labels, labels)

# After fitting, results are read from the `results_` attribute, which is
# indexed by string keys.
results = detector.results_
print(results["shortcut_detected"])
print(results["metrics"])
print(detector.summary())

Unified API Example

from shortcut_detect import ShortcutDetector

# Select only the generative CVAE method through the unified entry point.
detector = ShortcutDetector(methods=["generative_cvae"])
# `emb`, `labels` and `groups` are your own data (not defined here); they
# are passed by keyword, so argument order does not matter.
detector.fit(embeddings=emb, labels=labels, group_labels=groups)

# Results are looked up with the same method name passed in `methods`.
print(detector.get_results()["generative_cvae"])