Skip to content

Adversarial Debiasing (M04) API

AdversarialDebiasing

AdversarialDebiasing(
    hidden_dim: int | None = None,
    adversary_weight: float = 0.5,
    n_epochs: int = 50,
    batch_size: int = 64,
    lr: float = 0.001,
    dropout: float = 0.1,
    device: str | torch.device | None = None,
    random_state: int | None = None,
)

Adversarial debiasing to remove demographic encoding from embeddings (Zhang et al. 2018).

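Trains an encoder that maps embeddings to a hidden representation. An adversary, connected through a Gradient Reversal Layer (GRL), tries to predict the protected attribute from that representation, which pushes the encoder to make the representation uninformative about the protected attribute. When task labels are passed to `fit`, a task head is trained jointly so the representation also stays informative for the task.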

Parameters

- hidden_dim : int, optional — Hidden dimension of the encoder. If None, uses min(64, embed_dim).
- adversary_weight : float — Weight (lambda) for the adversarial loss. Higher values push harder to remove protected-attribute information. Default 0.5.
- n_epochs : int — Number of training epochs.
- batch_size : int — Batch size for training.
- lr : float — Learning rate.
- dropout : float — Dropout rate in the encoder.
- device : str or torch.device, optional — Device to train on. Defaults to cuda if available, else cpu.
- random_state : int, optional — Seed for reproducibility.

Source code in shortcut_detect/mitigation/adversarial_debiasing.py
def __init__(
    self,
    hidden_dim: int | None = None,
    adversary_weight: float = 0.5,
    n_epochs: int = 50,
    batch_size: int = 64,
    lr: float = 1e-3,
    dropout: float = 0.1,
    device: str | torch.device | None = None,
    random_state: int | None = None,
):
    """
    Record hyperparameters and start in an unfitted state.

    Numeric hyperparameters are coerced to built-in ``float`` / ``int`` here.
    No range checking happens at construction time; invalid values surface
    when ``fit`` builds the model.
    """
    # Encoder width. ``None`` is resolved inside fit() to min(64, embed_dim).
    self.hidden_dim = hidden_dim
    # Strength passed to the gradient-reversal layer during fit().
    self.adversary_weight = float(adversary_weight)
    # Optimisation schedule (coerced in the same order as the signature).
    self.n_epochs, self.batch_size = int(n_epochs), int(batch_size)
    self.lr, self.dropout = float(lr), float(dropout)
    # Stored verbatim; interpreted later by the device / seeding helpers.
    self.device, self.random_state = device, random_state

    # ---- state populated by fit() ------------------------------------
    self._encoder: nn.Module | None = None
    self._adversary: nn.Module | None = None
    self._task_head: nn.Module | None = None
    self._embed_dim: int | None = None
    self._n_protected: int | None = None
    self._n_task: int | None = None
    self._fitted = False

Functions

fit

fit(
    embeddings: ndarray,
    protected_labels: ndarray,
    task_labels: ndarray | None = None,
) -> AdversarialDebiasing

Fit the adversarial debiasing model.

Parameters

- embeddings : np.ndarray — Shape (n_samples, embed_dim).
- protected_labels : np.ndarray — Shape (n_samples,); demographic/protected attribute labels.
- task_labels : np.ndarray, optional — Shape (n_samples,); task labels. If provided, the encoder also minimizes the task loss to preserve utility.

Returns

self : AdversarialDebiasing

Source code in shortcut_detect/mitigation/adversarial_debiasing.py
def fit(
    self,
    embeddings: np.ndarray,
    protected_labels: np.ndarray,
    task_labels: np.ndarray | None = None,
) -> AdversarialDebiasing:
    """
    Fit the adversarial debiasing model.

    An encoder (Linear -> ReLU -> Dropout) is trained jointly with a linear
    adversary that predicts ``protected_labels`` from the encoder output.
    The adversary sees the hidden representation through a gradient
    reversal layer, so the encoder is pushed to make the protected attribute
    harder to predict. When ``task_labels`` are given, a linear task head is
    trained as well and its loss is added to the objective so the
    representation stays useful for the task.

    Parameters
    ----------
    embeddings : np.ndarray
        Shape (n_samples, embed_dim). Converted to contiguous float32.
    protected_labels : np.ndarray
        Shape (n_samples,) – demographic/protected attribute labels. Values
        are re-indexed to 0..K-1 in sorted order.
    task_labels : np.ndarray, optional
        Shape (n_samples,) – task labels. If provided, the encoder also
        minimizes task loss to preserve utility.

    Returns
    -------
    self : AdversarialDebiasing

    Raises
    ------
    ValueError
        If inputs have the wrong rank or mismatched lengths, contain no
        samples or no features, or ``hidden_dim`` is not positive.
    """
    self._setup_seed()
    # ascontiguousarray rather than asarray: torch.from_numpy rejects arrays
    # with negative strides (e.g. a reversed view such as ``X[::-1]``).
    X = np.ascontiguousarray(embeddings, dtype=np.float32)
    s = np.asarray(protected_labels)
    if X.ndim != 2:
        raise ValueError("embeddings must be 2D (n_samples, embed_dim)")
    if s.ndim != 1:
        raise ValueError("protected_labels must be 1D")
    if X.shape[0] != s.shape[0]:
        raise ValueError("embeddings and protected_labels must have same length")
    if X.shape[0] == 0:
        raise ValueError("embeddings must contain at least one sample")
    if X.shape[1] == 0:
        raise ValueError("embeddings must have at least one feature")

    embed_dim = X.shape[1]
    # Map arbitrary label values to contiguous class ids 0..K-1 (sorted
    # order), which CrossEntropyLoss requires. Vectorised equivalent of a
    # {value: index} lookup.
    uniq, s_inv = np.unique(s, return_inverse=True)
    n_protected = len(uniq)
    s_idx = s_inv.reshape(-1).astype(np.int64)

    n_task: int | None = None
    y_idx: np.ndarray | None = None
    if task_labels is not None:
        y = np.asarray(task_labels)
        if y.ndim != 1 or y.shape[0] != X.shape[0]:
            raise ValueError("task_labels must be 1D with same length as embeddings")
        y_uniq, y_inv = np.unique(y, return_inverse=True)
        n_task = len(y_uniq)
        y_idx = y_inv.reshape(-1).astype(np.int64)

    h_dim = self.hidden_dim
    if h_dim is None:
        h_dim = min(64, embed_dim)
    if h_dim < 1:
        raise ValueError("hidden_dim must be a positive integer")

    device = self._device()
    encoder = nn.Sequential(
        nn.Linear(embed_dim, h_dim),
        nn.ReLU(),
        nn.Dropout(self.dropout),
    ).to(device)
    adversary = nn.Linear(h_dim, n_protected).to(device)
    task_head = nn.Linear(h_dim, n_task).to(device) if n_task is not None else None

    modules = [encoder, adversary] if task_head is None else [encoder, adversary, task_head]
    # A single optimizer updates all parts; the sign flip for the encoder
    # comes from the gradient reversal layer, not from separate optimizers.
    opt = torch.optim.Adam([p for m in modules for p in m.parameters()], lr=self.lr)
    ce = nn.CrossEntropyLoss()

    tensors = [
        torch.from_numpy(X).float().to(device),
        torch.from_numpy(s_idx).long().to(device),
    ]
    if y_idx is not None:
        tensors.append(torch.from_numpy(y_idx).long().to(device))
    loader = DataLoader(
        TensorDataset(*tensors),
        batch_size=self.batch_size,
        shuffle=True,
        drop_last=False,
    )

    # Nothing in the loop switches to eval mode, so set training mode once.
    for module in modules:
        module.train()

    for _ in range(self.n_epochs):
        for batch in loader:
            x_b, s_b = batch[0], batch[1]
            y_b = batch[2] if len(batch) > 2 else None

            hidden = encoder(x_b)
            # _grl is the gradient reversal layer: the adversary learns to
            # predict s, while the encoder receives the reversed gradient
            # (presumably scaled by adversary_weight).
            adv_logits = adversary(_grl(hidden, self.adversary_weight))
            loss_adv = ce(adv_logits, s_b)

            if task_head is not None and y_b is not None:
                loss_task = ce(task_head(hidden), y_b)
                loss = loss_task + loss_adv
            else:
                loss = loss_adv

            opt.zero_grad(set_to_none=True)
            loss.backward()
            opt.step()

    self._encoder = encoder
    self._adversary = adversary
    self._task_head = task_head
    self._embed_dim = embed_dim
    self._n_protected = n_protected
    self._n_task = n_task
    self._fitted = True
    return self

transform

transform(embeddings: ndarray) -> np.ndarray

Transform embeddings to debiased representations.

Parameters

embeddings : np.ndarray Shape (n_samples, embed_dim). Must match embed_dim from fit.

Returns

np.ndarray Debiased embeddings, shape (n_samples, hidden_dim).

Source code in shortcut_detect/mitigation/adversarial_debiasing.py
def transform(self, embeddings: np.ndarray) -> np.ndarray:
    """
    Transform embeddings to debiased representations.

    Parameters
    ----------
    embeddings : np.ndarray
        Shape (n_samples, embed_dim). Must match embed_dim from fit.
        Converted to contiguous float32 before encoding.

    Returns
    -------
    np.ndarray
        Debiased embeddings as float64, shape (n_samples, hidden_dim).
        The encoder runs in eval mode, so dropout is disabled.

    Raises
    ------
    ValueError
        If the model is not fitted, or ``embeddings`` is not 2D or has a
        different width than the data seen in ``fit``.
    """
    if not self._fitted or self._encoder is None:
        raise ValueError("AdversarialDebiasing has not been fitted")
    # ascontiguousarray rather than asarray: torch.from_numpy rejects arrays
    # with negative strides (e.g. ``X[::-1]`` on float32 input).
    X = np.ascontiguousarray(embeddings, dtype=np.float32)
    if X.ndim != 2:
        raise ValueError("embeddings must be 2D")
    if X.shape[1] != self._embed_dim:
        raise ValueError(
            f"embed_dim {X.shape[1]} does not match fitted embed_dim {self._embed_dim}"
        )

    device = self._device()
    self._encoder.eval()  # disable dropout for a deterministic projection
    with torch.no_grad():
        x_t = torch.from_numpy(X).float().to(device)
        hidden = self._encoder(x_t)
        out = hidden.cpu().numpy()
    return out.astype(np.float64)

fit_transform

fit_transform(
    embeddings: ndarray,
    protected_labels: ndarray,
    task_labels: ndarray | None = None,
) -> np.ndarray

Fit and transform in one step.

Source code in shortcut_detect/mitigation/adversarial_debiasing.py
def fit_transform(
    self,
    embeddings: np.ndarray,
    protected_labels: np.ndarray,
    task_labels: np.ndarray | None = None,
) -> np.ndarray:
    """
    Fit on ``embeddings`` and return their debiased projection.

    Equivalent to calling ``fit`` and then ``transform`` on the same
    embeddings the model was trained on.
    """
    self.fit(embeddings, protected_labels, task_labels=task_labels)
    debiased = self.transform(embeddings)
    return debiased