Skip to content

Early Epoch Clustering Detector API

The EarlyEpochClusteringDetector detects shortcut bias by clustering early-epoch representations and identifying imbalanced clusters that suggest shortcut reliance.

Class Reference

EarlyEpochClusteringDetector

EarlyEpochClusteringDetector(
    n_clusters: int = 4,
    cluster_method: str = "kmeans",
    min_cluster_ratio: float = 0.1,
    entropy_threshold: float = 0.7,
    random_state: int = 42,
)

Detect shortcut bias using early-epoch clustering (SPARE 2023).

Source code in shortcut_detect/training/early_epoch_clustering.py
def __init__(
    self,
    n_clusters: int = 4,
    cluster_method: str = "kmeans",
    min_cluster_ratio: float = 0.1,
    entropy_threshold: float = 0.7,
    random_state: int = 42,
) -> None:
    """Store the detector configuration; no validation or clustering happens here.

    Args:
        n_clusters: Number of clusters requested from the clustering step.
        cluster_method: Name of the clustering algorithm to use.
        min_cluster_ratio: Threshold forwarded to the risk assessment.
        entropy_threshold: Threshold forwarded to the risk assessment.
        random_state: Seed handed to the clustering algorithm.
    """
    # Fitted state: both stay ``None`` until ``fit`` has been called.
    self.report_: EarlyEpochClusteringReport | None = None
    self.cluster_labels_: np.ndarray | None = None

    # Clustering configuration.
    self.cluster_method = cluster_method
    self.n_clusters = n_clusters
    self.random_state = random_state

    # Thresholds consumed when the risk level is assessed.
    self.entropy_threshold = entropy_threshold
    self.min_cluster_ratio = min_cluster_ratio

Functions

fit

fit(
    representations: ndarray,
    labels: ndarray | None = None,
    n_epochs: int = 1,
) -> EarlyEpochClusteringDetector

Cluster early-epoch representations and compute bias indicators.

Source code in shortcut_detect/training/early_epoch_clustering.py
def fit(
    self,
    representations: np.ndarray,
    labels: np.ndarray | None = None,
    n_epochs: int = 1,
) -> EarlyEpochClusteringDetector:
    """Cluster early-epoch representations and compute bias indicators.

    Args:
        representations: 2D array (or array-like such as a nested list) of
            shape ``(n_samples, n_features)``.
        labels: Optional 1D array (or array-like) with one entry per sample.
            When given, a cluster/label agreement value is added to the report.
        n_epochs: Stored verbatim in the report; it does not influence the
            clustering itself.

    Returns:
        ``self``, with ``cluster_labels_`` and ``report_`` populated.

    Raises:
        ValueError: If ``representations`` is ``None`` or not 2D, if there are
            fewer than 2 samples, if ``n_clusters`` is below 2 or exceeds the
            number of samples, if ``labels`` is not 1D or does not have one
            entry per sample, or if ``cluster_method`` is not ``"kmeans"``.
    """
    if representations is None:
        raise ValueError("representations must be provided for early-epoch clustering")
    # Coerce array-likes up front so lists/tuples reach the shape checks below
    # instead of failing with AttributeError on ``.ndim``.
    representations = np.asarray(representations)
    if representations.ndim != 2:
        raise ValueError("representations must be 2D (n_samples, n_features)")

    n_samples = representations.shape[0]
    if n_samples < 2:
        raise ValueError("Need at least 2 samples for clustering")
    if self.n_clusters < 2:
        raise ValueError("n_clusters must be >= 2")
    if self.n_clusters > n_samples:
        raise ValueError("n_clusters cannot exceed number of samples")

    if labels is not None:
        labels = np.asarray(labels)
        if labels.ndim != 1:
            raise ValueError("labels must be 1D")
        if labels.shape[0] != n_samples:
            raise ValueError("labels must align with representations")

    if self.cluster_method != "kmeans":
        raise ValueError(f"Unsupported cluster_method: {self.cluster_method}")

    kmeans = KMeans(n_clusters=self.n_clusters, n_init=10, random_state=self.random_state)
    cluster_labels = kmeans.fit_predict(representations)
    self.cluster_labels_ = cluster_labels

    # ``minlength`` keeps one slot per requested cluster even if a cluster
    # ends up with no samples assigned.
    counts = np.bincount(cluster_labels, minlength=self.n_clusters).astype(float)
    ratios = counts / float(n_samples)

    entropy = _normalized_entropy(ratios)
    # ``ratios`` always has at least two entries here (n_clusters >= 2), so
    # min/max are well defined; the zero guard only protects the division.
    largest_ratio = np.max(ratios)
    smallest_ratio = np.min(ratios)
    minority_ratio = (
        float(smallest_ratio / largest_ratio) if largest_ratio > 0 else float("nan")
    )
    largest_gap = float(largest_ratio - smallest_ratio)

    agreement = None
    if labels is not None:
        agreement = _cluster_label_agreement(cluster_labels, labels, self.n_clusters)

    risk_level, notes = _assess_risk(
        minority_ratio=minority_ratio,
        entropy=entropy,
        min_cluster_ratio=self.min_cluster_ratio,
        entropy_threshold=self.entropy_threshold,
    )

    self.report_ = EarlyEpochClusteringReport(
        n_epochs=n_epochs,
        cluster_method=self.cluster_method,
        n_clusters=self.n_clusters,
        cluster_sizes={str(i): int(counts[i]) for i in range(self.n_clusters)},
        cluster_ratios={str(i): float(ratios[i]) for i in range(self.n_clusters)},
        size_entropy=float(entropy),
        minority_ratio=float(minority_ratio),
        largest_gap=float(largest_gap),
        cluster_label_agreement=agreement,
        risk_level=risk_level,
        notes=notes,
        reference="Yang et al. 2023 (SPARE)",
    )
    return self

Quick Reference

Constructor

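EarlyEpochClusteringDetector(
    n_clusters: int = 4,
    cluster_method: str = "kmeans",
    min_cluster_ratio: float = 0.1,
    entropy_threshold: float = 0.7,
    random_state: int = 42,
)

Parameters

Parameter Type Default Description
n_clusters int 4 Number of clusters to form (must be at least 2 and no larger than the number of samples)
cluster_method str "kmeans" Clustering algorithm; only "kmeans" is currently supported
min_cluster_ratio float 0.1 Minimum cluster size ratio for balance check
entropy_threshold float 0.7 Entropy threshold for shortcut detection
random_state int 42 Random seed passed to KMeans for reproducible clustering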

Methods

fit()

def fit(
    representations: np.ndarray,
    labels: np.ndarray | None = None,
    n_epochs: int = 1,
) -> EarlyEpochClusteringDetector

Fit the clustering detector on early-epoch representations.

Parameters:

Parameter Type Description
representations ndarray Shape (n_samples, n_features), early-epoch representations
labels ndarray or None Optional labels of shape (n_samples,) for cluster-label agreement
n_epochs int Number of early epochs the representations come from

Returns: self

Attributes (after fit)

Attribute Type Description
cluster_labels_ ndarray Cluster index assigned to each sample
report_ EarlyEpochClusteringReport Report with cluster statistics and risk assessment

Usage Examples

Basic Usage

from shortcut_detect.training import EarlyEpochClusteringDetector

detector = EarlyEpochClusteringDetector(n_clusters=4)
detector.fit(early_epoch_reps, labels=labels, n_epochs=1)
print(detector.report_)

Custom Parameters

detector = EarlyEpochClusteringDetector(
    n_clusters=6,
    min_cluster_ratio=0.05,
    entropy_threshold=0.6,
)
detector.fit(early_epoch_reps)

Via Unified ShortcutDetector

from shortcut_detect import ShortcutDetector

detector = ShortcutDetector(methods=["early_epoch_clustering"])
detector.fit(embeddings, labels)
print(detector.summary())

See Also