Skip to content

SSA Detector API

The SSADetector class implements semi-supervised shortcut detection by pseudo-labeling spurious attributes and running GroupDRO on pseudo-groups.

Class Reference

SSADetector

SSADetector(config: SSAConfig | None = None)

Bases: DetectorBase

Embeddings-first SSA

Phase 1: pseudo-label the spurious attribute a_hat(x) for the group-unlabeled set DU, training on the group-labeled set DL together with DU under adaptive group-wise confidence thresholds. Phase 2: run GroupDRO on (X_U, y_U, g_hat = (y, a_hat)), validating on DL with its true groups.

Source code in shortcut_detect/ssa/ssa.py
def __init__(self, config: SSAConfig | None = None):
    """Create an unfitted SSA detector.

    Args:
        config: Settings for both SSA phases. A falsy value (including
            ``None``) is replaced by a default-constructed ``SSAConfig``.
    """
    super().__init__(method="ssa")

    # Resolve the configuration first; the seed is mirrored from it.
    if not config:
        config = SSAConfig()
    self.config = config
    self.seed = self.config.seed

    # Artifacts produced by fit(); empty / None until then.
    self.a_hat_: np.ndarray | None = None  # pseudo spurious attribute for DU
    self.attr_models_: list[nn.Module] = []  # one attribute predictor per fold
    self.groupdro_: GroupDRODetector | None = None
    self.report_: dict[str, Any] = {}

    # Detection outcome and fitted-state bookkeeping.
    self.results_ = {}
    self.shortcut_detected_ = None
    self._is_fitted = False

Functions

fit

fit(
    du_embeddings: ndarray,
    du_labels: ndarray,
    dl_embeddings: ndarray,
    dl_labels: ndarray,
    dl_spurious: ndarray,
) -> SSADetector
End-to-end SSA

Phase 1: pseudo-label spurious attribute for DU (Algorithm 1, Eq. 5–9). Phase 2: GroupDRO on pseudo-groups with validation on DL.

Source code in shortcut_detect/ssa/ssa.py
def fit(
    self,
    # DU: group-unlabeled
    du_embeddings: np.ndarray,
    du_labels: np.ndarray,
    # DL: group-labeled (has spurious attribute labels)
    dl_embeddings: np.ndarray,
    dl_labels: np.ndarray,
    dl_spurious: np.ndarray,
) -> SSADetector:
    """
    End-to-end SSA:
      Phase 1: pseudo-label spurious attribute for DU (Algorithm 1, Eq. 5–9).
      Phase 2: GroupDRO on pseudo-groups with validation on DL.

    Args:
        du_embeddings: Embeddings of the group-unlabeled set DU; the feature
            dimension is read from ``shape[1]``.
        du_labels: Task labels for DU, one per row of ``du_embeddings``.
        dl_embeddings: Embeddings of the group-labeled set DL.
        dl_labels: Task labels for DL, one per row of ``dl_embeddings``.
        dl_spurious: Spurious-attribute labels for DL, one per row of
            ``dl_embeddings``.

        Labels are treated as non-negative integer codes: the number of
        classes and attribute values is inferred as ``max + 1``.

    Returns:
        ``self``, with ``a_hat_``, ``attr_models_``, ``groupdro_``,
        ``report_`` and ``shortcut_detected_`` populated.

    Raises:
        ValueError: If DU or DL is empty, if label arrays do not match their
            embeddings in length, or if ``config.K`` is smaller than 2.
    """
    # Fail fast with explicit messages; otherwise these cases surface later as
    # opaque NumPy errors (np.max on an empty array, np.concatenate([])).
    n_du, n_dl = len(du_embeddings), len(dl_embeddings)
    if n_du == 0 or n_dl == 0:
        raise ValueError("SSA requires non-empty DU and DL sets.")
    if len(du_labels) != n_du:
        raise ValueError(
            f"du_labels has {len(du_labels)} entries but du_embeddings has {n_du}."
        )
    if len(dl_labels) != n_dl or len(dl_spurious) != n_dl:
        raise ValueError(
            "dl_labels and dl_spurious must each have one entry per row of "
            f"dl_embeddings ({n_dl}); got {len(dl_labels)} and {len(dl_spurious)}."
        )

    cfg = self.config
    if cfg.K < 2:
        # Each fold trains on the union of the *other* folds, which is empty for K=1.
        raise ValueError(f"SSA needs at least K=2 folds; got K={cfg.K}.")

    self._set_seed()
    device = self._device()
    rng = np.random.RandomState(cfg.seed)

    def _make_loader(stage: str, dataset: torch.utils.data.Dataset, *, shuffle: bool):
        """Build a loader for ``stage`` using the shared SSA loader settings."""
        return build_loader(
            LoaderRequest(
                stage=stage,
                dataset=dataset,
                batch_size=cfg.batch_size,
                shuffle=shuffle,
                num_workers=0,
                drop_last=False,
            ),
            loader_factory=cfg.loader_factory,
            stage_loader_overrides=cfg.stage_loader_overrides,
        )

    # infer cardinalities
    n_y = int(max(np.max(du_labels), np.max(dl_labels))) + 1
    n_a = int(np.max(dl_spurious)) + 1
    G = n_y * n_a

    # Phase 1: Algorithm 1 pseudo-labeling with K folds
    du_ds = EmbeddingUnlabeledDataset(du_embeddings, du_labels)
    dl_full_ds = EmbeddingLabeledAttrDataset(dl_embeddings, dl_labels, dl_spurious)

    # DL is split once: one part supervises the attribute predictor, the other
    # is held out to select the best predictor checkpoint.
    dl_train_idx, dl_val_idx = self._split_indices(len(dl_full_ds), cfg.dl_val_fraction, rng)
    dl_train_ds = torch.utils.data.Subset(dl_full_ds, dl_train_idx.tolist())
    dl_val_ds = torch.utils.data.Subset(dl_full_ds, dl_val_idx.tolist())

    dl_train_loader = _make_loader("dl_train", dl_train_ds, shuffle=True)
    dl_val_loader = _make_loader("dl_val", dl_val_ds, shuffle=False)

    folds = self._kfold_indices(len(du_ds), cfg.K, rng)
    a_hat_all = np.zeros(len(du_ds), dtype=np.int64)

    # Precompute DL^o group counts: |DL^o(g)| for g=(y,a)
    dl_y_train = dl_labels[dl_train_idx]
    dl_a_train = dl_spurious[dl_train_idx]
    dl_g_train = self._encode_group(dl_y_train, dl_a_train, n_a)
    dl_group_counts = np.bincount(dl_g_train, minlength=G).astype(np.int64)

    self.attr_models_.clear()

    # Fold-invariant pieces, hoisted out of the loop.
    d = du_embeddings.shape[1]
    ce = nn.CrossEntropyLoss(reduction="none")
    # Interval (in iterations) between checkpoint evaluations on held-out DL.
    # Runs shorter than this never evaluate and keep the final-iteration weights.
    ckpt_eval_every = 200

    for k in range(cfg.K):
        du_bullet_idx = folds[k]  # D_U^(k): fold to be pseudo-labeled
        # Remaining folds: unlabeled training data for this fold's predictor.
        du_circ_idx = np.concatenate([folds[j] for j in range(cfg.K) if j != k]).astype(
            np.int64
        )

        du_circ_ds = torch.utils.data.Subset(du_ds, du_circ_idx.tolist())
        du_bullet_ds = torch.utils.data.Subset(du_ds, du_bullet_idx.tolist())

        du_circ_loader = _make_loader("du_train", du_circ_ds, shuffle=True)
        # a second loader (no shuffle) for threshold updates so subsampling is stable-ish
        du_circ_loader_eval = _make_loader("du_eval", du_circ_ds, shuffle=False)
        du_bullet_loader = _make_loader("du_bullet", du_bullet_ds, shuffle=False)

        model = AttrPredictor(d=d, n_a=n_a, hidden_dim=cfg.hidden_dim, dropout=cfg.dropout).to(
            device
        )

        opt = torch.optim.SGD(
            model.parameters(),
            lr=cfg.lr,
            momentum=cfg.momentum,
            weight_decay=cfg.weight_decay,
        )

        # initialize thresholds: every group starts at tau_gmin
        tau = np.full(G, cfg.tau_gmin, dtype=np.float32)

        # train loop for T iterations (Algorithm 1 lines 7–12)
        it_dl = iter(dl_train_loader)
        it_du = iter(du_circ_loader)

        best_worst = -1.0
        best_state = None

        for t in range(1, cfg.T + 1):
            model.train()

            # Cycle each loader: restart it once it is exhausted.
            try:
                xL, yL, aL = next(it_dl)
            except StopIteration:
                it_dl = iter(dl_train_loader)
                xL, yL, aL = next(it_dl)

            try:
                xU, yU, _idxU = next(it_du)
            except StopIteration:
                it_du = iter(du_circ_loader)
                xU, yU, _idxU = next(it_du)

            xL = xL.to(device)
            aL = aL.to(device)
            xU = xU.to(device)

            # supervised loss on DL^o (Eq. 3/4 supervised term)
            logitsL = model(xL)
            loss_sup = ce(logitsL, aL).mean()

            # unsupervised loss on DU^o with group-wise threshold (Eq. 8)
            logitsU = model(xU)
            probsU = torch.softmax(logitsU, dim=1)
            confU, a_hatU = probsU.max(dim=1)

            # pseudo-group g_hat=(y, a_hat(x)); go through .cpu() so labels from
            # a loader that yields non-CPU tensors still convert to NumPy
            yU_np = yU.detach().cpu().numpy().astype(np.int64)
            a_hatU_np = a_hatU.detach().cpu().numpy().astype(np.int64)
            g_hat_np = (yU_np * n_a + a_hatU_np).astype(np.int64)

            # Only samples whose confidence reaches their pseudo-group's
            # threshold contribute to the unsupervised loss.
            tau_batch = torch.from_numpy(tau[g_hat_np]).to(device)
            mask = (confU >= tau_batch).float()

            # clamp_min avoids a division by zero when no sample passes the threshold
            loss_unsup = (mask * ce(logitsU, a_hatU)).sum() / mask.sum().clamp_min(1.0)

            loss = loss_sup + loss_unsup

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

            # update thresholds periodically (Algorithm 1 line 10; Eq. 7)
            if t % max(1, cfg.threshold_update_every) == 0:
                tau = self._compute_groupwise_thresholds(
                    model=model,
                    du_loader=du_circ_loader_eval,
                    dl_group_counts=dl_group_counts,
                    n_a=n_a,
                    tau_gmin=cfg.tau_gmin,
                    max_items=cfg.threshold_update_max_items,
                    device=device,
                )

            # optional: track best pseudo-labeler via worst-group attr acc on DL^bullet (Appendix A.3)
            if t % ckpt_eval_every == 0:
                worst = self._eval_attr_worst_group_acc(model, dl_val_loader, n_y, n_a, device)
                if worst > best_worst:
                    best_worst = worst
                    # Snapshot on CPU so later optimizer steps cannot alter it.
                    best_state = {
                        kk: vv.detach().cpu().clone() for kk, vv in model.state_dict().items()
                    }

        if best_state is not None:
            model.load_state_dict(best_state)

        self.attr_models_.append(model)

        # Predict pseudo-attributes on D_U^bullet (Algorithm 1 lines 13–14).
        # Pure inference: no_grad avoids building autograd graphs.
        model.eval()
        with torch.no_grad():
            for xU, _yU, idxU in du_bullet_loader:
                logits = model(xU.to(device))
                pred_a = logits.argmax(dim=1).cpu().numpy().astype(np.int64)
                # idxU indexes into the full DU set, so folds fill disjoint slots.
                idxU_np = idxU.detach().cpu().numpy().astype(np.int64)
                a_hat_all[idxU_np] = pred_a

    # Eq. (9): final pseudo-labels for all DU without threshold
    self.a_hat_ = a_hat_all

    # Phase 2: robust training on pseudo-groups, validate on DL (Section 4.3)
    g_train = self._encode_group(du_labels, a_hat_all, n_a)
    g_val = self._encode_group(dl_labels, dl_spurious, n_a)

    group_ids = np.arange(G, dtype=np.int64)

    gdro = GroupDRODetector(cfg.groupdro)
    gdro.fit_with_val(
        train_embeddings=du_embeddings,
        train_labels=du_labels,
        train_group_labels=g_train,
        val_embeddings=dl_embeddings,
        val_labels=dl_labels,
        val_group_labels=g_val,
        group_ids=group_ids,
    )

    self.groupdro_ = gdro

    # Missing report entries degrade to NaN rather than raising.
    gdro_report = gdro.get_report()
    gdro_detail = gdro_report.get("report", {})
    final = gdro_detail.get("final", {})
    avg_acc = final.get("avg_acc", float("nan"))
    worst_acc = final.get("worst_group_acc", float("nan"))
    gap = (
        avg_acc - worst_acc if np.isfinite(avg_acc) and np.isfinite(worst_acc) else float("nan")
    )

    self.shortcut_detected_ = self.get_shortcut_detected(avg_acc, worst_acc)

    self.report_ = {
        "pseudo_attr_hat": a_hat_all,  # length |DU|
        "groupdro_report": gdro_detail,
    }
    metrics = {
        "n_labeled": int(len(dl_embeddings)),
        "n_unlabeled": int(len(du_embeddings)),
        "avg_acc": avg_acc,
        "worst_group_acc": worst_acc,
        "gap": gap,
    }
    metadata = {
        "K": cfg.K,
        "T": cfg.T,
        "tau_gmin": cfg.tau_gmin,
        "n_y": n_y,
        "n_a": n_a,
        "n_groups": G,
        "ssa_gap_threshold": cfg.ssa_gap_threshold,
    }
    # None means the gap could not be evaluated (non-finite accuracies).
    if self.shortcut_detected_ is None:
        risk_level = "unknown"
    elif self.shortcut_detected_:
        risk_level = "moderate"
    else:
        risk_level = "low"

    self._set_results(
        shortcut_detected=self.shortcut_detected_,
        risk_level=risk_level,
        metrics=metrics,
        notes="SSA pseudo-labeling + GroupDRO gap-based detection.",
        metadata=metadata,
        report=self.report_,
    )
    self._is_fitted = True
    return self

get_shortcut_detected

get_shortcut_detected(
    avg_acc: float, worst_acc: float
) -> bool | None

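Detect a shortcut if the gap between average and worst-group accuracy is at least the configured ssa_gap_threshold. Returns None when the gap is not finite.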

Source code in shortcut_detect/ssa/ssa.py
def get_shortcut_detected(self, avg_acc: float, worst_acc: float) -> bool | None:
    """Decide whether the average-to-worst-group accuracy gap signals a shortcut.

    Args:
        avg_acc: Average accuracy.
        worst_acc: Worst-group accuracy.

    Returns:
        ``True`` when ``avg_acc - worst_acc`` is at least
        ``self.config.ssa_gap_threshold``, ``False`` when it is smaller, and
        ``None`` when the gap is not finite (NaN or infinite).
    """
    gap = avg_acc - worst_acc
    gap_thresh = self.config.ssa_gap_threshold
    if not np.isfinite(gap):
        return None
    # Coerce to a builtin bool: comparing NumPy scalars yields ``np.bool_``,
    # which breaks identity checks (``is True``) and JSON serialization.
    return bool(gap >= gap_thresh)

Quick Reference

Constructor

SSADetector(config: SSAConfig | None = None)

SSAConfig Parameters

Parameter Type Default Description
K int 3 Number of K-fold splits
T int 2000 Training iterations per fold
batch_size int 128 Batch size
lr float 1e-3 Learning rate
weight_decay float 1e-4 Weight decay
momentum float 0.9 SGD momentum
hidden_dim int or None None Hidden layer size (None = linear)
dropout float 0.0 Dropout rate
tau_gmin float 0.95 Confidence threshold for smallest group
threshold_update_every int 200 Threshold update frequency
dl_val_fraction float 0.5 Fraction of DL used for validation
seed int 0 Random seed
device str or None None PyTorch device
groupdro GroupDROConfig default GroupDRO configuration for Phase 2
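ssa_gap_threshold float 0.10 Accuracy gap threshold for detection

Note: fit() also reads threshold_update_max_items, loader_factory, and stage_loader_overrides from the config; these fields are not listed in the table above.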

Methods

fit()

def fit(
    du_embeddings: np.ndarray,
    du_labels: np.ndarray,
    dl_embeddings: np.ndarray,
    dl_labels: np.ndarray,
    dl_spurious: np.ndarray,
) -> SSADetector

Run end-to-end SSA: Phase 1 pseudo-labeling + Phase 2 GroupDRO.

Parameters:

Parameter Type Description
du_embeddings ndarray (n_unlabeled, n_features) embeddings of the group-unlabeled set DU
du_labels ndarray (n_unlabeled,) task labels for DU (DU has task labels but no spurious-attribute labels)
dl_embeddings ndarray (n_labeled, n_features) embeddings of the group-labeled set DL
dl_labels ndarray (n_labeled,) task labels for DL
dl_spurious ndarray (n_labeled,) spurious-attribute labels for DL

Returns: self

get_report()

def get_report() -> dict

Get the detection report after fitting. Inherited from DetectorBase.

get_shortcut_detected()

def get_shortcut_detected(avg_acc: float, worst_acc: float) -> bool | None

Detect a shortcut if the gap between average and worst-group accuracy is at least ssa_gap_threshold. Returns None when the gap is not finite.

Attributes (after fit)

Attribute Type Description
a_hat_ ndarray Pseudo spurious-attribute labels for DU
groupdro_ GroupDRODetector Fitted GroupDRO detector from Phase 2
attr_models_ list[nn.Module] Attribute predictor models (one per fold)
shortcut_detected_ bool or None Whether a shortcut was detected (None if the accuracy gap is not finite)
report_ dict Detailed report including pseudo labels and GroupDRO results

Usage Examples

Basic Usage

from shortcut_detect.ssa import SSADetector, SSAConfig

# With no config argument the detector uses a default-constructed SSAConfig.
detector = SSADetector()
# DU (du_*) is the group-unlabeled set: embeddings plus task labels.
# DL (dl_*) is the group-labeled set: it additionally carries spurious-attribute labels.
detector.fit(
    du_embeddings=unlabeled_emb,
    du_labels=unlabeled_y,
    dl_embeddings=labeled_emb,
    dl_labels=labeled_y,
    dl_spurious=labeled_attrs,
)
# get_report() is inherited from DetectorBase.
report = detector.get_report()
print(report["shortcut_detected"])

Custom Configuration

# Override selected SSAConfig fields; anything not listed keeps its default.
config = SSAConfig(
    K=5,  # number of K-fold splits over DU
    T=3000,  # training iterations per fold
    tau_gmin=0.90,  # confidence threshold for the smallest group
    hidden_dim=64,  # hidden layer size (None would mean a linear predictor)
    dropout=0.1,
    seed=42,
)
detector = SSADetector(config=config)
detector.fit(
    du_embeddings=unlabeled_emb,
    du_labels=unlabeled_y,
    dl_embeddings=labeled_emb,
    dl_labels=labeled_y,
    dl_spurious=labeled_attrs,
)

Via Unified ShortcutDetector

from shortcut_detect import ShortcutDetector

# Run SSA through the unified front end by selecting it by name.
detector = ShortcutDetector(methods=["ssa"])
# In this example the group-unlabeled (DU) arrays are passed as
# embeddings/labels, and the group-labeled (DL) arrays via the dl_* keywords.
detector.fit(
    embeddings=unlabeled_emb,
    labels=unlabeled_y,
    dl_embeddings=labeled_emb,
    dl_labels=labeled_y,
    dl_spurious=labeled_attrs,
)
print(detector.summary())

See Also