
Last Layer Retraining (M06 DFR) API

LastLayerRetraining

LastLayerRetraining(
    C: float = 1.0,
    penalty: str = "l1",
    solver: str = "liblinear",
    class_weight: str | dict | None = "balanced",
    random_state: int | None = None,
)

Last Layer Retraining, also known as Deep Feature Reweighting (DFR), per Kirichenko et al. 2023.

Retrains only the last linear layer (a logistic regression) on a group-balanced subset of embeddings. The feature extractor that produced the embeddings stays frozen; only the classifier head is retrained. This simple approach can match or outperform more complex debiasing methods.

Parameters

C : float
    Inverse regularization strength; smaller values mean stronger regularization. Default 1.0.
penalty : str
    Regularization penalty: "l1" or "l2". Default "l1".
solver : str
    Solver for LogisticRegression; "liblinear" works for both L1 and L2. Default "liblinear".
class_weight : str or dict, optional
    Class weights for imbalanced task labels: "balanced", a dict mapping class to weight, or None. Default "balanced".
random_state : int, optional
    Seed for reproducibility.
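A minimal construction sketch. The import path is assumed from the source file location shown below (shortcut_detect/mitigation/last_layer_retraining.py) and may differ in your install:

from shortcut_detect.mitigation.last_layer_retraining import LastLayerRetraining  # assumed path

# Smaller C means stronger regularization; with penalty="l1" this prunes
# more embedding dimensions from the retrained classifier.
llr = LastLayerRetraining(C=0.1, penalty="l1", random_state=0)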

Source code in shortcut_detect/mitigation/last_layer_retraining.py
def __init__(
    self,
    C: float = 1.0,
    penalty: str = "l1",
    solver: str = "liblinear",
    class_weight: str | dict | None = "balanced",
    random_state: int | None = None,
):
    if penalty not in ("l1", "l2"):
        raise ValueError("penalty must be 'l1' or 'l2'")
    self.C = float(C)
    self.penalty = penalty
    self.solver = solver
    self.class_weight = class_weight
    self.random_state = random_state

    self._scaler: StandardScaler | None = None
    self._classifier: LogisticRegression | None = None
    self._task_map: dict | None = None  # original label -> index
    self._embed_dim: int | None = None
    self._n_balanced: int | None = None
    self._n_groups: int | None = None
    self._fitted = False

Attributes

scaler_ property

scaler_: StandardScaler | None

Fitted StandardScaler (None if not fitted).

classifier_ property

classifier_: LogisticRegression | None

Fitted LogisticRegression (None if not fitted).
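A sketch of inspecting both fitted components after a successful fit, using illustrative random data (import path assumed as above):

import numpy as np
from shortcut_detect.mitigation.last_layer_retraining import LastLayerRetraining  # assumed path

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 16))      # embeddings
y = rng.integers(0, 2, size=200)    # task labels
g = rng.integers(0, 2, size=200)    # group labels

llr = LastLayerRetraining(C=0.5, random_state=0).fit(X, y, g)
print(llr.scaler_.mean_.shape)                  # (16,) -- per-dimension means
print(np.count_nonzero(llr.classifier_.coef_))  # L1 zeroes out weak dimensions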

Functions

fit

fit(
    embeddings: ndarray,
    task_labels: ndarray,
    group_labels: ndarray,
) -> LastLayerRetraining

Build group-balanced subset and fit the logistic regression classifier.

Parameters

embeddings : np.ndarray
    Shape (n_samples, embed_dim).
task_labels : np.ndarray
    Shape (n_samples,) – task/target labels to predict.
group_labels : np.ndarray
    Shape (n_samples,) – protected/group labels for balancing.

Returns

self : LastLayerRetraining

Source code in shortcut_detect/mitigation/last_layer_retraining.py
def fit(
    self,
    embeddings: np.ndarray,
    task_labels: np.ndarray,
    group_labels: np.ndarray,
) -> LastLayerRetraining:
    """
    Build group-balanced subset and fit the logistic regression classifier.

    Parameters
    ----------
    embeddings : np.ndarray
        Shape (n_samples, embed_dim).
    task_labels : np.ndarray
        Shape (n_samples,) – task/target labels to predict.
    group_labels : np.ndarray
        Shape (n_samples,) – protected/group labels for balancing.

    Returns
    -------
    self : LastLayerRetraining
    """
    X = np.asarray(embeddings, dtype=np.float64)
    y = np.asarray(task_labels)
    g = np.asarray(group_labels)

    if X.ndim != 2:
        raise ValueError("embeddings must be 2D (n_samples, embed_dim)")
    if y.ndim != 1 or g.ndim != 1:
        raise ValueError("task_labels and group_labels must be 1D")
    if X.shape[0] != y.shape[0] or X.shape[0] != g.shape[0]:
        raise ValueError("embeddings, task_labels, and group_labels must have same length")

    # Map task labels to 0..n_classes-1
    y_uniq = np.unique(y)
    task_map = {v: i for i, v in enumerate(y_uniq)}
    y_idx = np.array([task_map[v] for v in y], dtype=np.int64)

    # Map group labels to 0..n_groups-1
    g_uniq = np.unique(g)
    g_map = {v: i for i, v in enumerate(g_uniq)}
    g_idx = np.array([g_map[v] for v in g], dtype=np.int64)

    n_groups = len(g_uniq)
    g_indices = [np.where(g_idx == grp)[0] for grp in range(n_groups)]
    min_g = min(len(gi) for gi in g_indices)

    if min_g == 0:
        raise ValueError(
            "At least one group has 0 samples. "
            "Cannot build balanced subset. Check group_labels."
        )

    # Build balanced subset: take min_g samples per group
    if self.random_state is not None:
        rng = np.random.default_rng(self.random_state)
    else:
        rng = np.random.default_rng()

    balanced_indices = []
    for gi in g_indices:
        idx = gi.copy()
        rng.shuffle(idx)
        balanced_indices.extend(idx[:min_g])

    X_bal = X[balanced_indices]
    y_bal = y_idx[balanced_indices]

    # Preprocess
    scaler = StandardScaler()
    X_bal_scaled = scaler.fit_transform(X_bal)

    # Fit classifier (liblinear is binary-only; saga supports multiclass+L1, lbfgs supports multiclass+L2)
    n_classes = len(np.unique(y_bal))
    solver = self.solver
    if n_classes > 2:
        solver = "saga" if self.penalty == "l1" else "lbfgs"

    clf = LogisticRegression(
        C=self.C,
        penalty=self.penalty,
        solver=solver,
        class_weight=self.class_weight,
        random_state=self.random_state,
        max_iter=1000,
    )
    clf.fit(X_bal_scaled, y_bal)

    self._scaler = scaler
    self._classifier = clf
    self._task_map = task_map
    self._embed_dim = X.shape[1]
    self._n_balanced = len(balanced_indices)
    self._n_groups = n_groups
    self._fitted = True
    return self
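A usage sketch on illustrative random data, showing how group imbalance is absorbed by the balanced subsampling (import path assumed):

import numpy as np
from shortcut_detect.mitigation.last_layer_retraining import LastLayerRetraining  # assumed path

rng = np.random.default_rng(42)
X = rng.normal(size=(200, 32))      # embeddings from a frozen backbone
y = rng.integers(0, 2, size=200)    # task labels
g = np.array([0] * 150 + [1] * 50)  # imbalanced groups: 150 vs 50 samples

llr = LastLayerRetraining(random_state=42)
llr.fit(X, y, g)  # internally trains on 2 * min(150, 50) = 100 balanced samples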

predict

predict(embeddings: ndarray) -> np.ndarray

Predict task labels for given embeddings.

Parameters

embeddings : np.ndarray
    Shape (n_samples, embed_dim). Must match embed_dim from fit.

Returns

np.ndarray
    Predicted task labels (original label values, not indices).

Source code in shortcut_detect/mitigation/last_layer_retraining.py
def predict(self, embeddings: np.ndarray) -> np.ndarray:
    """
    Predict task labels for given embeddings.

    Parameters
    ----------
    embeddings : np.ndarray
        Shape (n_samples, embed_dim). Must match embed_dim from fit.

    Returns
    -------
    np.ndarray
        Predicted task labels (original label values, not indices).
    """
    if not self._fitted or self._classifier is None or self._scaler is None:
        raise ValueError("LastLayerRetraining has not been fitted")
    X = np.asarray(embeddings, dtype=np.float64)
    if X.ndim != 2:
        raise ValueError("embeddings must be 2D")
    if X.shape[1] != self._embed_dim:
        raise ValueError(
            f"embed_dim {X.shape[1]} does not match fitted embed_dim {self._embed_dim}"
        )

    X_scaled = self._scaler.transform(X)
    pred_idx = self._classifier.predict(X_scaled)

    # Map indices back to original label values
    inv_map = {i: v for v, i in self._task_map.items()}
    return np.array([inv_map[int(p)] for p in pred_idx])
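A sketch showing that predictions come back as the original label values, here strings (illustrative data, import path assumed):

import numpy as np
from shortcut_detect.mitigation.last_layer_retraining import LastLayerRetraining  # assumed path

rng = np.random.default_rng(1)
X = rng.normal(size=(100, 8))
y = np.array(["cat", "dog"])[rng.integers(0, 2, size=100)]  # string task labels
g = rng.integers(0, 2, size=100)

llr = LastLayerRetraining(random_state=1).fit(X, y, g)
print(llr.predict(X)[:5])  # e.g. ['dog' 'cat' ...] -- original labels, not indices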

fit_predict

fit_predict(
    embeddings: ndarray,
    task_labels: ndarray,
    group_labels: ndarray,
) -> np.ndarray

Fit and predict in one call.

Source code in shortcut_detect/mitigation/last_layer_retraining.py
def fit_predict(
    self,
    embeddings: np.ndarray,
    task_labels: np.ndarray,
    group_labels: np.ndarray,
) -> np.ndarray:
    """Fit and predict in one call."""
    self.fit(embeddings, task_labels, group_labels)
    return self.predict(embeddings)
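Equivalent to calling fit followed by predict on the same embeddings; a brief sketch on illustrative data (import path assumed):

import numpy as np
from shortcut_detect.mitigation.last_layer_retraining import LastLayerRetraining  # assumed path

rng = np.random.default_rng(7)
X, y, g = rng.normal(size=(120, 8)), rng.integers(0, 2, 120), rng.integers(0, 2, 120)

preds = LastLayerRetraining(random_state=7).fit_predict(X, y, g)  # shape (120,)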