Generative CVAE Counterfactual Detection¶
The Generative CVAE detector trains a Conditional Variational Autoencoder (CVAE) on embeddings conditioned on a binary spurious attribute. It generates counterfactual embeddings by encoding with the original attribute and decoding with the flipped attribute, then measures how a probe classifier's predictions change.
What It Detects¶
- Whether a spurious (group) attribute causally influences model predictions in embedding space.
- Large probe prediction shifts after counterfactual attribute flipping indicate shortcut reliance.
How It Works¶
- Train CVAE: A conditional VAE learns to reconstruct embeddings conditioned on the binary group label (spurious attribute).
- Train Attribute Predictor: A small linear network learns to predict the group label from embeddings.
- Generate Counterfactuals: For each sample, encode with the original group label and decode with the flipped label. Latent guidance optimizes the counterfactual to actually flip the predicted attribute while staying close to the original embedding.
- Evaluate Probe Shift: A probe classifier (internal LogisticRegression or user-provided) scores both original and counterfactual embeddings. Large prediction shifts indicate the model relies on the spurious attribute.
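The latent guidance in step 3 can be illustrated with a toy objective. This is a minimal sketch and not the library's implementation: it assumes a fixed linear decoder in place of a trained CVAE, a logistic attribute predictor, and plain gradient descent. The function `guided_counterfactual` and the arrays `W`, `cond_bias`, and `w_attr` are hypothetical; only `guidance_steps`, `guidance_weight`, and `proximity_weight` are names taken from this page.

```python
import numpy as np


def guided_counterfactual(x, a, W, cond_bias, w_attr, c_attr=0.0,
                          guidance_steps=50, guidance_weight=5.0,
                          proximity_weight=1.0):
    """Toy latent guidance with a fixed linear decoder (not a trained CVAE).

    Assumed decoder:   decode(z, a) = W @ z + cond_bias[a]
    Assumed predictor: P(group=1 | x) = sigmoid(w_attr @ x + c_attr)
    """
    target = 1 - a  # flipped group label

    def decode(z, attr):
        return W @ z + cond_bias[attr]

    def loss_and_grad(z):
        x_cf = decode(z, target)
        logit = w_attr @ x_cf + c_attr
        # Numerically stable binary cross-entropy against the flipped label.
        flip_loss = np.logaddexp(0.0, logit) - target * logit
        # Squared distance to the original embedding.
        prox_loss = np.sum((x_cf - x) ** 2)
        loss = guidance_weight * flip_loss + proximity_weight * prox_loss
        p = np.exp(-np.logaddexp(0.0, -logit))  # stable sigmoid
        grad_x = (guidance_weight * (p - target) * w_attr
                  + 2.0 * proximity_weight * (x_cf - x))
        return float(loss), W.T @ grad_x

    # "Encode" with the original label: least-squares inverse of the decoder.
    z, *_ = np.linalg.lstsq(W, x - cond_bias[a], rcond=None)

    # Step size 1/L from a smoothness bound, so gradient descent on this
    # convex objective does not increase the loss (up to rounding).
    lip = np.linalg.norm(W, 2) ** 2 * (
        2.0 * proximity_weight + guidance_weight * (w_attr @ w_attr) / 4.0)

    loss_before, _ = loss_and_grad(z)
    for _ in range(guidance_steps):
        _, grad_z = loss_and_grad(z)
        z = z - grad_z / lip
    loss_after, _ = loss_and_grad(z)
    return decode(z, target), loss_before, loss_after
```

Raising `guidance_weight` relative to `proximity_weight` pushes the counterfactual harder toward the flipped attribute at the cost of moving further from the original embedding.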
Required Inputs¶
- embeddings: np.ndarray of shape (n, d); representation space
- group_labels: np.ndarray of shape (n,); binary spurious attribute labels (0/1)
Optional:
- labels: np.ndarray of shape (n,); task labels, required for probe training (if omitted, detection is inconclusive)
- probe_classifier: pre-trained sklearn-like classifier or callable
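The input requirements above can be checked before fitting. The helper below is a hypothetical pre-flight check written for illustration; `check_inputs` is not part of the library, and the library may perform its own validation.

```python
import numpy as np


def check_inputs(embeddings, group_labels, labels=None):
    """Hypothetical check mirroring the documented input shapes."""
    embeddings = np.asarray(embeddings, dtype=float)
    group_labels = np.asarray(group_labels)
    if embeddings.ndim != 2:
        raise ValueError("embeddings must have shape (n, d)")
    n = embeddings.shape[0]
    if group_labels.shape != (n,):
        raise ValueError("group_labels must have shape (n,)")
    if not np.isin(group_labels, (0, 1)).all():
        raise ValueError("group_labels must be binary (0/1)")
    if labels is not None:
        labels = np.asarray(labels)
        if labels.shape != (n,):
            raise ValueError("labels must have shape (n,)")
    return embeddings, group_labels, labels
```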
Unified API Example¶
from shortcut_detect import ShortcutDetector
detector = ShortcutDetector(methods=["generative_cvae"])
detector.fit(embeddings=emb, labels=labels, group_labels=groups)
result = detector.get_results()["generative_cvae"]
print(result["metrics"])
print(result["shortcut_detected"])
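The variables `emb`, `labels`, and `groups` in the example above are not defined on this page. For a quick trial, synthetic inputs with a built-in spurious correlation could be generated along these lines. The helper `make_toy_inputs` and all of its parameters are illustrative assumptions, not part of the library.

```python
import numpy as np


def make_toy_inputs(n=200, d=16, leak=2.0, seed=0):
    """Synthetic embeddings in which the group attribute leaks into dimension 0."""
    rng = np.random.default_rng(seed)
    groups = rng.integers(0, 2, size=n)
    # Task labels agree with the group about 90% of the time.
    disagree = rng.random(n) < 0.1
    labels = np.where(disagree, 1 - groups, groups)
    emb = rng.normal(size=(n, d))
    emb[:, 0] += leak * (2 * groups - 1)  # strong group signal
    emb[:, 1] += 0.5 * (2 * labels - 1)   # weaker task signal
    return emb, labels, groups


emb, labels, groups = make_toy_inputs()
```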
Direct API Example¶
from shortcut_detect.causal import GenerativeCVEDetector
detector = GenerativeCVEDetector(epochs=50, random_state=42)
detector.fit(embeddings, group_labels, labels)
print(detector.results_["shortcut_detected"])
print(detector.results_["metrics"])
print(detector.summary())
Key Metrics¶
| Metric | Description |
|---|---|
| mean_delta | Mean probe prediction shift (original - counterfactual) |
| frac_large_change | Fraction of samples with a large absolute probe prediction shift |
| mean_cosine_similarity | Average cosine similarity between original and counterfactual embeddings |
| probe_accuracy | Accuracy of the internal/external probe |
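As a rough guide to how these metrics relate, the sketch below computes them from probe outputs. It assumes the shift is the probe's positive-class probability on the original embedding minus that on the counterfactual. The function name `counterfactual_metrics` and the `large_change_cutoff` value are hypothetical, since this page does not state the cutoff the library uses.

```python
import numpy as np


def counterfactual_metrics(p_orig, p_cf, emb_orig, emb_cf, labels=None,
                           large_change_cutoff=0.1):
    """Sketch of the documented metrics; the cutoff value is an assumption."""
    p_orig = np.asarray(p_orig, dtype=float)
    p_cf = np.asarray(p_cf, dtype=float)
    emb_orig = np.asarray(emb_orig, dtype=float)
    emb_cf = np.asarray(emb_cf, dtype=float)

    delta = p_orig - p_cf  # per-sample shift (original - counterfactual)

    # Row-wise cosine similarity, guarded against zero-norm rows.
    num = np.sum(emb_orig * emb_cf, axis=1)
    den = np.linalg.norm(emb_orig, axis=1) * np.linalg.norm(emb_cf, axis=1)
    cos = num / np.maximum(den, 1e-12)

    metrics = {
        "mean_delta": float(delta.mean()),
        "frac_large_change": float((np.abs(delta) > large_change_cutoff).mean()),
        "mean_cosine_similarity": float(cos.mean()),
    }
    if labels is not None:
        # Accuracy of the probe on the original embeddings at a 0.5 threshold.
        preds = (p_orig >= 0.5).astype(int)
        metrics["probe_accuracy"] = float((preds == np.asarray(labels)).mean())
    return metrics
```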
Detection Rule¶
A shortcut is detected when both conditions hold:
- abs(mean_delta) > mean_delta_threshold (default: 1e-4)
- frac_large_change > frac_large_threshold (default: 0.01)
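Written as code, the rule is a simple conjunction. The function below is an illustrative restatement using the documented threshold names and defaults; the function name `shortcut_detected` is an assumption rather than a library call.

```python
def shortcut_detected(mean_delta, frac_large_change,
                      mean_delta_threshold=1e-4, frac_large_threshold=0.01):
    """Both conditions must hold for a shortcut to be flagged."""
    return (abs(mean_delta) > mean_delta_threshold
            and frac_large_change > frac_large_threshold)
```

Because the first condition uses the mean of signed shifts, large shifts in opposite directions can cancel. In that case `frac_large_change` may be high while the rule still reports no shortcut.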
Interpretation¶
- High risk: Large prediction shifts indicate the model relies heavily on the spurious attribute.
- Low risk: Minimal prediction shifts suggest the spurious attribute has little influence.
- Unknown: No probe available (labels not provided).
Configuration¶
| Parameter | Default | Description |
|---|---|---|
| epochs | 50 | CVAE training epochs |
| hidden | 256 | Hidden layer size |
| zdim | 64 | Latent dimension |
| guidance_steps | 50 | Latent optimization steps for counterfactual generation |
| guidance_weight | 5.0 | Weight for attribute-flip loss |
| proximity_weight | 1.0 | Weight for staying close to original embedding |
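When experimenting with several settings, it can help to keep the documented defaults in one place and reject misspelled names early. The snippet below is a hypothetical convenience, not a library feature; `DOCUMENTED_DEFAULTS` and `build_config` are names introduced here for illustration.

```python
# Defaults copied from the configuration table on this page.
DOCUMENTED_DEFAULTS = {
    "epochs": 50,
    "hidden": 256,
    "zdim": 64,
    "guidance_steps": 50,
    "guidance_weight": 5.0,
    "proximity_weight": 1.0,
}


def build_config(**overrides):
    """Merge overrides into the defaults, rejecting unknown parameter names."""
    unknown = set(overrides) - set(DOCUMENTED_DEFAULTS)
    if unknown:
        raise KeyError(f"unknown parameter(s): {sorted(unknown)}")
    return {**DOCUMENTED_DEFAULTS, **overrides}
```

This page confirms only `epochs` and `random_state` as constructor arguments (see the Direct API example). Unpacking the full dict into the detector's constructor assumes the remaining names are accepted as keyword arguments as well.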
Limitations¶
- Requires binary group labels (0/1).
- Detection thresholds are sensitive to dataset size and noise level.
- CVAE quality depends on sufficient training data and appropriate hyperparameters.
- Results can vary across random seeds, especially with small datasets.
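One way to gauge seed sensitivity is to repeat the run with several seeds and summarize the outcomes. The helper below is a sketch under the assumption that each run returns a dict shaped like `results_` in the Direct API example, with `shortcut_detected` and `metrics["mean_delta"]` entries. `seed_stability` and `run_once` are hypothetical names.

```python
from statistics import mean, pstdev


def seed_stability(run_once, seeds):
    """Summarize detection outcomes across seeds.

    run_once(seed) is assumed to return a dict with a "shortcut_detected"
    flag and a "metrics" dict containing "mean_delta".
    """
    flags, deltas = [], []
    for seed in seeds:
        res = run_once(seed)
        flags.append(bool(res["shortcut_detected"]))
        deltas.append(float(res["metrics"]["mean_delta"]))
    return {
        "detection_rate": sum(flags) / len(flags),
        "mean_delta_mean": mean(deltas),
        "mean_delta_std": pstdev(deltas),
    }


# A run_once wrapper around the detector might look like this (untested here):
# def run_once(seed):
#     det = GenerativeCVEDetector(epochs=50, random_state=seed)
#     det.fit(embeddings, group_labels, labels)
#     return det.results_
```

A detection rate well away from 0 or 1, or a `mean_delta_std` comparable in size to the mean, suggests the result should not be trusted from a single seed.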