Early Epoch Clustering Detector API
The EarlyEpochClusteringDetector detects shortcut bias by clustering early-epoch representations and identifying imbalanced clusters that suggest shortcut reliance.
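To make the idea of an "imbalanced cluster" concrete, the sketch below computes two balance statistics from a list of cluster assignments: each cluster's share of the samples and the normalized entropy of the cluster-size distribution. It is an illustration only, not the library's implementation. The helper name `cluster_balance`, the use of normalized Shannon entropy, and the way the two thresholds are combined are all assumptions.

```python
import math
from collections import Counter


def cluster_balance(assignments, n_clusters, min_cluster_ratio=0.1, entropy_threshold=0.7):
    """Summarize how evenly samples are spread across clusters.

    Hypothetical helper for illustration; not part of shortcut_detect.
    """
    total = len(assignments)
    if total == 0:
        raise ValueError("assignments must not be empty")
    counts = Counter(assignments)
    # Share of all samples that falls in each cluster (0.0 for empty clusters).
    ratios = [counts.get(k, 0) / total for k in range(n_clusters)]
    # Shannon entropy of the size distribution, scaled to [0, 1] by log(n_clusters).
    entropy = -sum(r * math.log(r) for r in ratios if r > 0)
    normalized_entropy = entropy / math.log(n_clusters) if n_clusters > 1 else 1.0
    # Assumption: a cluster below the ratio, or low overall entropy, signals imbalance.
    small_clusters = [k for k, r in enumerate(ratios) if r < min_cluster_ratio]
    return {
        "ratios": ratios,
        "normalized_entropy": normalized_entropy,
        "small_clusters": small_clusters,
        "imbalanced": bool(small_clusters) or normalized_entropy < entropy_threshold,
    }


# Four clusters: one evenly split assignment, one dominated by cluster 0.
balanced = cluster_balance([0, 1, 2, 3] * 25, n_clusters=4)
skewed = cluster_balance([0] * 90 + [1] * 5 + [2] * 3 + [3] * 2, n_clusters=4)
```

Under these assumptions the evenly split assignment is not flagged, while the skewed one is flagged both because three clusters fall below `min_cluster_ratio` and because its normalized entropy is well under `entropy_threshold`.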
Class Reference
EarlyEpochClusteringDetector
EarlyEpochClusteringDetector(
n_clusters: int = 4,
cluster_method: str = "kmeans",
min_cluster_ratio: float = 0.1,
entropy_threshold: float = 0.7,
random_state: int = 42,
)
Detect shortcut bias using early-epoch clustering (SPARE 2023).
Source code in shortcut_detect/training/early_epoch_clustering.py
Functions
fit
fit(
representations: ndarray,
labels: ndarray | None = None,
n_epochs: int = 1,
) -> EarlyEpochClusteringDetector
Cluster early-epoch representations and compute bias indicators.
Source code in shortcut_detect/training/early_epoch_clustering.py
Quick Reference
Constructor
EarlyEpochClusteringDetector(
n_clusters: int = 4,
min_cluster_ratio: float = 0.1,
entropy_threshold: float = 0.7,
)
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `n_clusters` | int | 4 | Number of clusters to form |
| `min_cluster_ratio` | float | 0.1 | Minimum cluster size ratio for balance check |
| `entropy_threshold` | float | 0.7 | Entropy threshold for shortcut detection |
Methods
fit()
def fit(
early_epoch_reps: np.ndarray,
labels: np.ndarray | None = None,
n_epochs: int = 1,
) -> EarlyEpochClusteringDetector
Fit the clustering detector on early-epoch representations.
Parameters:
| Parameter | Type | Description |
|---|---|---|
| `early_epoch_reps` | ndarray | Shape (n_samples, n_features), early-epoch representations |
| `labels` | ndarray or None | Optional labels for cluster-label agreement |
| `n_epochs` | int | Number of early epochs the representations come from |
Returns: self
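The `labels` argument is described as enabling a cluster-label agreement check, but the metric itself is not specified here. As a stand-in, the sketch below uses cluster purity: for each cluster, the fraction of its samples that carry the cluster's most common label. The function name `cluster_purity` is hypothetical, and purity is an assumed choice rather than the detector's documented measure.

```python
from collections import Counter, defaultdict


def cluster_purity(assignments, labels):
    """Return per-cluster purity and the sample-weighted overall purity.

    Hypothetical helper; purity is assumed here as the agreement measure.
    """
    if len(assignments) != len(labels):
        raise ValueError("assignments and labels must have the same length")
    if len(labels) == 0:
        raise ValueError("inputs must not be empty")
    # Group the labels of the samples that landed in each cluster.
    by_cluster = defaultdict(list)
    for cluster, label in zip(assignments, labels):
        by_cluster[cluster].append(label)
    per_cluster = {}
    majority_total = 0
    for cluster, members in by_cluster.items():
        # Count of the most common label within this cluster.
        majority = Counter(members).most_common(1)[0][1]
        per_cluster[cluster] = majority / len(members)
        majority_total += majority
    return per_cluster, majority_total / len(labels)


assignments = [0, 0, 0, 1, 1, 1, 1, 1]
labels = [0, 0, 1, 1, 1, 1, 1, 0]
per_cluster, overall = cluster_purity(assignments, labels)
```

In this toy input, cluster 0 holds two samples of label 0 and one of label 1, and cluster 1 holds four samples of label 1 and one of label 0, so six of the eight samples match their cluster's majority label.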
Attributes (after fit)
| Attribute | Type | Description |
|---|---|---|
| `report_` | EarlyEpochClusteringReport | Report with cluster statistics and risk assessment |
Usage Examples
Basic Usage
from shortcut_detect.training import EarlyEpochClusteringDetector
detector = EarlyEpochClusteringDetector(n_clusters=4)
detector.fit(early_epoch_reps, labels=labels, n_epochs=1)
print(detector.report_)
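The snippet above assumes `early_epoch_reps` and `labels` already exist. For a quick smoke test, placeholder inputs with the documented shapes can be generated with NumPy. The two Gaussian blobs below are synthetic stand-ins, not real early-epoch activations, and the sizes and offsets are arbitrary choices.

```python
import numpy as np

rng = np.random.default_rng(42)
n_per_class, n_features = 100, 16

# Two well-separated blobs standing in for early-epoch representations.
class0 = rng.normal(loc=0.0, scale=1.0, size=(n_per_class, n_features))
class1 = rng.normal(loc=3.0, scale=1.0, size=(n_per_class, n_features))

# Stack into the documented (n_samples, n_features) layout.
early_epoch_reps = np.vstack([class0, class1])
# One integer label per sample, in the same order as the rows above.
labels = np.concatenate(
    [np.zeros(n_per_class, dtype=int), np.ones(n_per_class, dtype=int)]
)
```

In practice these arrays would come from a model checkpoint saved after the first epoch or first few epochs, with one row per training sample.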
Custom Parameters
detector = EarlyEpochClusteringDetector(
n_clusters=6,
min_cluster_ratio=0.05,
entropy_threshold=0.6,
)
detector.fit(early_epoch_reps)
Via Unified ShortcutDetector
from shortcut_detect import ShortcutDetector
detector = ShortcutDetector(methods=["early_epoch_clustering"])
detector.fit(embeddings, labels)
print(detector.summary())