Generative CVAE API
Detects shortcuts by generating counterfactuals in embedding space with a conditional variational autoencoder (CVAE).
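The core idea is to encode an embedding together with a conditioning label, then decode the same latent code under a different label. The NumPy sketch below shows only that data flow. It makes several assumptions: the condition is a one-hot group label, the encoder and decoder are single linear maps with random, untrained weights, and `ToyCVAE` and `one_hot` are hypothetical names. It is not the detector's actual architecture (which also has a `hidden` layer size and a training loop).

```python
import numpy as np


def one_hot(groups, n_groups):
    """Encode integer group ids as one-hot rows."""
    return np.eye(n_groups)[groups]


class ToyCVAE:
    """Hypothetical untrained linear CVAE; illustrates shapes and data flow only."""

    def __init__(self, dim, zdim, n_groups, seed=0):
        rng = np.random.default_rng(seed)
        # Encoder: [x, c] -> mean and log-variance of q(z | x, c).
        self.w_mu = rng.normal(scale=0.1, size=(dim + n_groups, zdim))
        self.w_logvar = rng.normal(scale=0.1, size=(dim + n_groups, zdim))
        # Decoder: [z, c] -> reconstructed embedding.
        self.w_dec = rng.normal(scale=0.1, size=(zdim + n_groups, dim))
        self.n_groups = n_groups
        self.rng = rng

    def encode(self, x, groups):
        h = np.concatenate([x, one_hot(groups, self.n_groups)], axis=1)
        return h @ self.w_mu, h @ self.w_logvar

    def reparameterize(self, mu, logvar):
        # Sampling used during training: z = mu + sigma * eps, eps ~ N(0, I).
        eps = self.rng.standard_normal(mu.shape)
        return mu + np.exp(0.5 * logvar) * eps

    def decode(self, z, groups):
        h = np.concatenate([z, one_hot(groups, self.n_groups)], axis=1)
        return h @ self.w_dec

    def counterfactual(self, x, groups, target_groups):
        # Encode under the observed group, decode under the target group.
        # The posterior mean is used so the counterfactual is deterministic.
        mu, _ = self.encode(x, groups)
        return self.decode(mu, target_groups)


x = np.random.default_rng(1).normal(size=(8, 128))
g = np.array([0, 1] * 4)
cvae = ToyCVAE(dim=128, zdim=64, n_groups=2)
x_cf = cvae.counterfactual(x, g, 1 - g)  # swap each sample's group
print(x_cf.shape)  # (8, 128)
```

Because the latent code is held fixed and only the condition changes, any difference between the original reconstruction and `x_cf` is attributable to the conditioning label in this toy model.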
Class Reference
GenerativeCVEDetector
GenerativeCVEDetector(
dim: int = 128,
hidden: int = 256,
zdim: int = 64,
lr: float = 0.001,
batch_size: int = 256,
epochs: int = 50,
device: str = "cpu",
recon_loss_weight: float = 1.0,
kld_weight: float = 0.001,
verbose: bool = False,
random_state: int = 42,
probe_classifier: Any | None = None,
method: str = "generative_cvae",
)
Bases: DetectorBase
Generative CVAE detector for embedding counterfactuals.
Source code in shortcut_detect/causal/generative_cvae/src/detector.py
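The constructor exposes `recon_loss_weight` and `kld_weight`. A common reading of such parameters is a weighted sum of a reconstruction term and a KL-divergence term. The function below sketches that objective under stated assumptions: mean squared error for reconstruction, the closed-form Gaussian KL against a standard normal prior, and `cvae_loss` as a hypothetical name. The detector's actual loss and reductions may differ.

```python
import numpy as np


def cvae_loss(x, x_hat, mu, logvar, recon_loss_weight=1.0, kld_weight=0.001):
    """Assumed objective: recon_loss_weight * MSE + kld_weight * KL."""
    # Reconstruction: mean squared error over samples and embedding dims.
    recon = np.mean((x - x_hat) ** 2)
    # KL(N(mu, sigma^2) || N(0, I)) in closed form,
    # summed over latent dims, then averaged over samples.
    kld = np.mean(-0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar), axis=1))
    return recon_loss_weight * recon + kld_weight * kld
```

With the default `kld_weight=0.001`, the KL term is heavily down-weighted relative to reconstruction, which favors faithful reconstructions over a tightly regularized latent space.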
Quick Example
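import numpy as np
from shortcut_detect.causal import GenerativeCVEDetector
# Placeholder inputs; shapes are assumptions (128-d embeddings, binary group and task labels).
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(1000, 128)).astype(np.float32)
group_labels = rng.integers(0, 2, size=1000)
labels = rng.integers(0, 2, size=1000)
detector = GenerativeCVEDetector(epochs=50, random_state=42)
detector.fit(embeddings, group_labels, labels)
results = detector.results_
print(results["shortcut_detected"])
print(results["metrics"])
print(detector.summary())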
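The constructor also accepts a `probe_classifier`. One plausible way to turn counterfactual embeddings into a shortcut signal is to check how often a probe's prediction changes when only the conditioning group is swapped: a high rate suggests the probe relies on group information. The sketch below assumes that approach. `prediction_flip_rate` and `toy_probe` are hypothetical, and the metrics actually reported in `results["metrics"]` may be computed differently.

```python
import numpy as np


def prediction_flip_rate(predict, x, x_cf):
    """Hypothetical metric: fraction of samples whose prediction changes under the counterfactual."""
    return float(np.mean(predict(x) != predict(x_cf)))


def toy_probe(x):
    # Toy linear probe: predicts 1 when the first embedding coordinate is positive.
    return (x[:, 0] > 0).astype(int)


x = np.array([[1.0, 0.0], [-1.0, 0.0], [2.0, 0.0], [-2.0, 0.0]])
x_cf = np.array([[-1.0, 0.0], [-1.0, 0.0], [2.0, 0.0], [1.0, 0.0]])
print(prediction_flip_rate(toy_probe, x, x_cf))  # 0.5
```

Here the probe's predictions change for two of the four samples (the first and the last), so the flip rate is 0.5.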