GCE Detector API

The GCEDetector class identifies minority/bias-conflicting samples by training a linear classifier with Generalized Cross Entropy loss and flagging high-loss samples.

Class Reference

GCEDetector

GCEDetector(
    q: float = 0.7,
    loss_percentile_threshold: float = 90.0,
    max_iter: int = 500,
    random_state: int | None = 42,
)

Detect minority/bias-conflicting samples via Generalized Cross Entropy.

Trains a linear classifier on embeddings with GCE loss (q ≈ 0.7). Samples with high per-sample GCE loss are flagged as minority or bias-conflicting, as they are harder for the biased classifier to fit.
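The per-sample GCE loss that drives the flagging can be sketched in a few lines of NumPy. This is an illustrative standalone function (the name `gce_per_sample` is not part of the package API; the internal helper is `_gce_loss_per_sample`):

```python
import numpy as np

def gce_per_sample(probs: np.ndarray, y: np.ndarray, q: float = 0.7) -> np.ndarray:
    """Generalized Cross Entropy per sample: L_q = (1 - p_y**q) / q.

    As q -> 0 this approaches standard cross-entropy; q = 1 gives a
    MAE-like loss that is more robust but harder to optimize.
    """
    p_true = probs[np.arange(len(y)), y]  # predicted probability of the true class
    return (1.0 - p_true ** q) / q

# A confidently correct sample gets a low loss; a sample the biased
# classifier gets wrong (bias-conflicting) gets a high loss.
probs = np.array([[0.9, 0.1],   # correct and confident
                  [0.2, 0.8]])  # true label is 0, so this one is "hard"
y = np.array([0, 0])
losses = gce_per_sample(probs, y)
```

Samples whose loss lands above the configured percentile of this distribution are the ones `fit` marks in `is_minority_`.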

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `q` | `float` | GCE parameter in (0, 1]. q ≈ 0.7 downweights easy samples and emphasizes hard/minority ones. Smaller q is more robust but harder to optimize. | `0.7` |
| `loss_percentile_threshold` | `float` | Samples with loss >= this percentile (0–100) are labeled as minority/bias-conflicting. | `90.0` |
| `max_iter` | `int` | Maximum iterations for training the linear classifier. | `500` |
| `random_state` | `int \| None` | Random seed for reproducibility. | `42` |
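To make the `loss_percentile_threshold` semantics concrete, here is a standalone sketch of the thresholding step using synthetic losses (this mirrors, but is not, the package's internal code):

```python
import numpy as np

rng = np.random.default_rng(42)
per_sample_losses = rng.exponential(size=100)  # stand-in for per-sample GCE losses

# Flag everything at or above the 90th percentile, mirroring the default
# loss_percentile_threshold=90.0: roughly the hardest 10% of samples.
threshold = float(np.percentile(per_sample_losses, 90.0))
is_minority = per_sample_losses >= threshold
```

Lowering the percentile flags more samples as minority/bias-conflicting; raising it flags fewer, more extreme ones.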
Source code in shortcut_detect/gce/gce_detector.py
def __init__(
    self,
    q: float = 0.7,
    loss_percentile_threshold: float = 90.0,
    max_iter: int = 500,
    random_state: int | None = 42,
) -> None:
    """
    Args:
        q: GCE parameter in (0, 1]. q≈0.7 downweights easy samples and
           emphasizes hard/minority ones. Smaller q is more robust but
           harder to optimize.
        loss_percentile_threshold: Samples with loss >= this percentile
           (0–100) are labeled as minority/bias-conflicting. Default 90.
        max_iter: Maximum iterations for training the linear classifier.
        random_state: Random seed for reproducibility.
    """
    if not 0 < q <= 1:
        raise ValueError("q must be in (0, 1]")
    if not 0 <= loss_percentile_threshold <= 100:
        raise ValueError("loss_percentile_threshold must be in [0, 100]")
    self.q = q
    self.loss_percentile_threshold = loss_percentile_threshold
    self.max_iter = max_iter
    self.random_state = random_state

    self.coef_: np.ndarray | None = None
    self.intercept_: np.ndarray | None = None
    self.classes_: np.ndarray | None = None
    self.per_sample_losses_: np.ndarray | None = None
    self.is_minority_: np.ndarray | None = None
    self.loss_threshold_: float | None = None
    self.report_: GCEDetectorReport | None = None

Functions

fit

fit(embeddings: ndarray, labels: ndarray) -> GCEDetector

Fit a GCE classifier and flag high-loss (minority/bias-conflicting) samples.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `embeddings` | `ndarray` | (n_samples, n_features) embedding matrix | required |
| `labels` | `ndarray` | (n_samples,) integer or binary labels | required |

Returns:

| Type | Description |
|------|-------------|
| `GCEDetector` | self |

Source code in shortcut_detect/gce/gce_detector.py
def fit(
    self,
    embeddings: np.ndarray,
    labels: np.ndarray,
) -> GCEDetector:
    """
    Fit a GCE classifier and flag high-loss (minority/bias-conflicting) samples.

    Args:
        embeddings: (n_samples, n_features) embedding matrix
        labels: (n_samples,) integer or binary labels

    Returns:
        self
    """
    X = np.asarray(embeddings, dtype=float)
    y = np.asarray(labels)
    if X.ndim != 2:
        raise ValueError("embeddings must be 2D (n_samples, n_features)")
    if y.ndim != 1:
        raise ValueError("labels must be 1D")
    if X.shape[0] != y.shape[0]:
        raise ValueError("embeddings and labels must have same length")

    # Map labels to 0, 1, ..., n_classes-1
    classes = np.unique(y)
    self.classes_ = classes
    n_classes = len(classes)
    if n_classes < 2:
        raise ValueError("At least 2 distinct labels are required")
    y_int = np.searchsorted(classes, y)

    # Train linear classifier with GCE
    W, b = _train_linear_gce(X, y_int, n_classes, self.q, self.max_iter, self.random_state)
    self.coef_ = W
    self.intercept_ = b

    # Per-sample losses on training set
    logits = X @ W + b
    probs = _softmax_stable(logits)
    per_sample_losses = _gce_loss_per_sample(probs, y_int, self.q)
    self.per_sample_losses_ = per_sample_losses

    # Threshold: percentile of loss distribution
    threshold = float(np.percentile(per_sample_losses, self.loss_percentile_threshold))
    self.loss_threshold_ = threshold
    self.is_minority_ = per_sample_losses >= threshold

    n_samples = X.shape[0]
    n_minority = int(np.sum(self.is_minority_))
    minority_ratio = n_minority / n_samples if n_samples else 0.0
    loss_mean = float(np.mean(per_sample_losses))
    loss_std = float(np.std(per_sample_losses))
    loss_min = float(np.min(per_sample_losses))
    loss_max = float(np.max(per_sample_losses))

    risk_level, notes = _assess_risk(
        minority_ratio=minority_ratio,
        loss_mean=loss_mean,
        n_minority=n_minority,
    )

    self.report_ = GCEDetectorReport(
        n_samples=n_samples,
        n_minority=n_minority,
        minority_ratio=minority_ratio,
        loss_mean=loss_mean,
        loss_std=loss_std,
        loss_min=loss_min,
        loss_max=loss_max,
        threshold=threshold,
        q=self.q,
        risk_level=risk_level.value,
        notes=notes,
        reference="GCE bias detector (high-loss = minority/bias-conflicting)",
    )
    return self

predict

predict(embeddings: ndarray) -> np.ndarray

Predict class labels from embeddings.

Source code in shortcut_detect/gce/gce_detector.py
def predict(self, embeddings: np.ndarray) -> np.ndarray:
    """Predict class labels from embeddings."""
    self._ensure_fitted()
    logits = embeddings @ self.coef_ + self.intercept_
    pred_idx = np.argmax(logits, axis=1)
    return self.classes_[pred_idx]

get_minority_indices

get_minority_indices() -> np.ndarray

Return indices of samples flagged as minority/bias-conflicting.

Source code in shortcut_detect/gce/gce_detector.py
def get_minority_indices(self) -> np.ndarray:
    """Return indices of samples flagged as minority/bias-conflicting."""
    self._ensure_fitted()
    return np.where(self.is_minority_)[0]

Quick Reference

Constructor

GCEDetector(
    q: float = 0.7,
    loss_percentile_threshold: float = 90.0,
    max_iter: int = 500,
    random_state: int | None = 42,
)

Parameters

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `q` | `float` | `0.7` | GCE parameter in (0, 1] |
| `loss_percentile_threshold` | `float` | `90.0` | Percentile threshold for flagging minority samples |
| `max_iter` | `int` | `500` | Maximum L-BFGS-B iterations |
| `random_state` | `int \| None` | `42` | Random seed |

Methods

fit()

def fit(
    embeddings: np.ndarray,
    labels: np.ndarray,
) -> GCEDetector

Train a GCE classifier and flag high-loss (minority/bias-conflicting) samples.

Parameters:

| Parameter | Type | Description |
|-----------|------|-------------|
| `embeddings` | `ndarray` | Shape (n_samples, n_features), 2D array |
| `labels` | `ndarray` | Shape (n_samples,), integer or binary labels |

Returns: self

Raises:

  • ValueError if embeddings is not 2D or labels is not 1D
  • ValueError if embeddings and labels have different lengths
  • ValueError if fewer than 2 distinct labels

predict()

def predict(embeddings: np.ndarray) -> np.ndarray

Predict class labels from embeddings using the fitted linear classifier.

get_minority_indices()

def get_minority_indices() -> np.ndarray

Return indices of samples flagged as minority/bias-conflicting.

Attributes (after fit)

| Attribute | Type | Description |
|-----------|------|-------------|
| `coef_` | `ndarray` | Fitted weight matrix (n_features, n_classes) |
| `intercept_` | `ndarray` | Fitted bias vector (n_classes,) |
| `classes_` | `ndarray` | Unique class labels |
| `per_sample_losses_` | `ndarray` | Per-sample GCE losses |
| `is_minority_` | `ndarray` | Boolean mask of flagged samples |
| `loss_threshold_` | `float` | Computed loss threshold |
| `report_` | `GCEDetectorReport` | Detailed report dataclass |

Usage Examples

Basic Usage

from shortcut_detect.gce import GCEDetector

detector = GCEDetector()
detector.fit(embeddings, labels)
print(detector.report_.risk_level)
print(detector.report_.n_minority)

Custom Parameters

detector = GCEDetector(
    q=0.5,
    loss_percentile_threshold=85.0,
    max_iter=1000,
)
detector.fit(embeddings, labels)
minority_idx = detector.get_minority_indices()

Via Unified ShortcutDetector

from shortcut_detect import ShortcutDetector

detector = ShortcutDetector(methods=["gce"])
detector.fit(embeddings, labels)
print(detector.summary())
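A common follow-up, not shown above, is to upweight the flagged samples when training a downstream debiased model. This NumPy-only sketch uses placeholder indices in place of a real `detector.get_minority_indices()` result:

```python
import numpy as np

n_samples = 10
minority_idx = np.array([2, 7])  # placeholder for detector.get_minority_indices()

# Upweight bias-conflicting samples so a second-stage model pays more
# attention to them, then renormalize so the weights sum to 1.
weights = np.ones(n_samples)
weights[minority_idx] *= 5.0
weights /= weights.sum()
```

The upweighting factor (5.0 here) is a hypothetical choice; in practice it is tuned, or weights are derived from the per-sample losses themselves.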

See Also