HBAC Clustering
Hierarchical Bias-Aware Clustering (HBAC) detects whether your embeddings naturally cluster by protected attributes rather than by task-relevant features.
How It Works
HBAC performs hierarchical clustering on embeddings and measures how well the resulting clusters align with protected group labels:
- Cluster: group the embeddings with agglomerative clustering.
- Measure purity: how homogeneous each cluster is with respect to the protected attribute.
- Assess linearity: how well a linear boundary separates the resulting clusters.
- Iterate: recursively analyze sub-clusters.
```mermaid
graph TD
    A[Embeddings] --> B[Agglomerative Clustering]
    B --> C[Cluster Assignments]
    C --> D[Measure Purity]
    D --> E{Purity High?}
    E -->|Yes| F[Shortcut Detected]
    E -->|No| G[Continue to Sub-clusters]
    G --> B
```
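The loop in the diagram can be sketched with scikit-learn's `AgglomerativeClustering`. This is a minimal sketch under stated assumptions, not the `HBACDetector` implementation: each step is a binary Ward split, `purity_threshold=0.8` is a hypothetical stopping rule borrowed from the high-risk boundary in the Interpretation table, and `min_cluster_size` is assumed to be a fraction of the full dataset. `hbac_sketch` and `purity` are illustrative helpers, not part of `shortcut_detect`.

```python
import numpy as np
from sklearn.cluster import AgglomerativeClustering


def purity(cluster_labels, group_labels):
    """Fraction of samples that fall in their cluster's majority group."""
    total = 0
    for c in np.unique(cluster_labels):
        _, counts = np.unique(group_labels[cluster_labels == c], return_counts=True)
        total += counts.max()
    return total / len(group_labels)


def hbac_sketch(X, groups, max_iterations=3, min_cluster_size=0.05,
                purity_threshold=0.8, depth=0, n_total=None):
    """Return (depth, n_samples, purity) for every split that was evaluated."""
    n_total = len(X) if n_total is None else n_total
    min_size = max(2, int(np.ceil(min_cluster_size * n_total)))
    # Stop at the depth limit, or when this node is too small to hold two clusters of min_size
    if depth >= max_iterations or len(X) < 2 * min_size:
        return []
    labels = AgglomerativeClustering(n_clusters=2, linkage="ward").fit_predict(X)
    p = purity(labels, groups)
    findings = [(depth, len(X), p)]
    if p > purity_threshold:
        return findings  # clusters align with groups: shortcut evidence, stop here
    for c in (0, 1):  # otherwise descend into each sub-cluster
        mask = labels == c
        findings += hbac_sketch(X[mask], groups[mask], max_iterations,
                                min_cluster_size, purity_threshold, depth + 1, n_total)
    return findings


# Synthetic embeddings (hypothetical data) that encode group membership along one axis
rng = np.random.default_rng(0)
groups = np.repeat([0, 1], 100)
X = rng.normal(size=(200, 8))
X[:, 0] += 10.0 * groups
findings = hbac_sketch(X, groups)
print(findings)
```

Because the group signal is strong here, the first split already separates the groups, purity exceeds the threshold, and the recursion stops at depth 0. Without a group signal, purity tends to stay near 0.5 and the sketch descends further until `max_iterations` or the size limit stops it.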
Basic Usage
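```python
import numpy as np
from shortcut_detect import HBACDetector

# Placeholder inputs: replace with your own data
embeddings = np.random.randn(500, 128)            # shape (n_samples, n_features)
group_labels = np.random.randint(0, 2, size=500)  # protected attribute per sample

# Create detector
detector = HBACDetector(
    max_iterations=3,      # Maximum depth of recursion
    min_cluster_size=0.05  # Minimum cluster size, as a fraction of the dataset
)

# Fit on embeddings and group labels
detector.fit(embeddings, group_labels)

# Access results
print(f"Purity: {detector.purity_:.2f}")
print(f"Linearity: {detector.linearity_:.2f}")
print(f"Shortcut detected: {detector.shortcut_detected_}")
```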
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `max_iterations` | int | 3 | Maximum recursion depth |
| `min_cluster_size` | float | 0.05 | Minimum cluster size (fraction) |
| `linkage` | str | `'ward'` | Clustering linkage method |
| `distance_metric` | str | `'euclidean'` | Distance metric for clustering |
Outputs
Attributes
| Attribute | Type | Description |
|---|---|---|
| `purity_` | float | Cluster purity (0-1) |
| `linearity_` | float | Linear separability score (0-1) |
| `shortcut_detected_` | bool | Whether a shortcut was detected |
| `cluster_labels_` | ndarray | Cluster assignments |
| `dendrogram_` | dict | Dendrogram data for visualization |
Interpretation
| Metric | Low Risk | Medium Risk | High Risk |
|---|---|---|---|
| Purity | < 0.6 | 0.6 - 0.8 | > 0.8 |
| Linearity | < 0.5 | 0.5 - 0.7 | > 0.7 |
High purity combined with high linearity is strong evidence of a shortcut.
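The Interpretation table can be encoded as a small lookup. The helper names (`THRESHOLDS`, `risk_level`, `strong_shortcut_evidence`) are hypothetical and not part of `shortcut_detect`, and how a value exactly on a boundary (such as 0.6 or 0.8) is classified is an assumption: here it counts as medium risk.

```python
# Risk bands from the Interpretation table: (low-risk upper bound, high-risk lower bound)
THRESHOLDS = {"purity": (0.6, 0.8), "linearity": (0.5, 0.7)}


def risk_level(metric: str, value: float) -> str:
    """Map a purity or linearity score onto the low/medium/high risk bands."""
    low, high = THRESHOLDS[metric]
    if value < low:
        return "low"
    if value > high:
        return "high"
    return "medium"  # boundary values are treated as medium (an assumption)


def strong_shortcut_evidence(purity: float, linearity: float) -> bool:
    """True when both purity and linearity fall in the high-risk band."""
    return (risk_level("purity", purity) == "high"
            and risk_level("linearity", linearity) == "high")


print(risk_level("purity", 0.9), risk_level("linearity", 0.55))  # high medium
print(strong_shortcut_evidence(0.9, 0.85))                       # True
```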
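Visualization
```python
import matplotlib.pyplot as plt
from scipy.cluster.hierarchy import dendrogram

# Plot the dendrogram stored on a fitted detector (see Basic Usage).
# scipy's dendrogram() expects a linkage matrix (as returned by
# scipy.cluster.hierarchy.linkage); check that detector.dendrogram_ holds one.
fig, ax = plt.subplots(figsize=(10, 6))
dendrogram(
    detector.dendrogram_,
    ax=ax,
    leaf_rotation=90,
    leaf_font_size=8
)
plt.title("HBAC Dendrogram")
plt.xlabel("Sample Index")
plt.ylabel("Distance")
plt.tight_layout()
plt.savefig("hbac_dendrogram.png")
```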
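SciPy's `scipy.cluster.hierarchy.dendrogram` takes a linkage matrix, such as the one returned by `scipy.cluster.hierarchy.linkage`, and returns a dict describing the plotted tree. If you want a dendrogram that does not depend on the format of `detector.dendrogram_`, one option is to compute the linkage yourself, as in this sketch. It uses random stand-in embeddings and assumes Ward linkage with Euclidean distance (the defaults in the Parameters table); whether `HBACDetector` builds its tree the same way is not confirmed here.

```python
import matplotlib
matplotlib.use("Agg")  # render off-screen; drop this line in a notebook
import matplotlib.pyplot as plt
import numpy as np
from scipy.cluster.hierarchy import dendrogram, linkage

rng = np.random.default_rng(0)
embeddings = rng.normal(size=(40, 16))  # stand-in for real embeddings

# Build an (n - 1) x 4 linkage matrix, which is what dendrogram() expects
Z = linkage(embeddings, method="ward", metric="euclidean")

fig, ax = plt.subplots(figsize=(10, 6))
tree = dendrogram(Z, ax=ax, leaf_rotation=90, leaf_font_size=8)  # returns a dict
ax.set_title("Ward linkage dendrogram")
ax.set_xlabel("Sample index")
ax.set_ylabel("Distance")
fig.tight_layout()
plt.close(fig)
```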
Example with Synthetic Data
```python
from shortcut_detect import HBACDetector, generate_linear_shortcut

# Generate data with strong shortcut
X, y_task, y_group = generate_linear_shortcut(
    n_samples=500,
    n_features=100,
    shortcut_strength=0.9,
    random_state=42
)

# Detect shortcut
detector = HBACDetector()
detector.fit(X, y_group)
print(f"Purity: {detector.purity_:.2f}")  # Expected: ~0.90
print(f"Linearity: {detector.linearity_:.2f}")  # Expected: ~0.85
print(f"Detected: {detector.shortcut_detected_}")  # Expected: True
```
When to Use HBAC
Use HBAC when:
- You want interpretable hierarchical structure
- Your embeddings may have natural cluster boundaries
- You need fast analysis (no training required)
- You want to visualize bias structure
Don't use HBAC when:
- Groups are uniformly mixed (no clusters)
- You have very few samples (< 50)
- Groups have highly overlapping distributions
Advanced Configuration
Custom Linkage
```python
from shortcut_detect import HBACDetector

# Different linkage methods (pick one)
detector = HBACDetector(linkage='complete')  # Complete (maximum) linkage
detector = HBACDetector(linkage='average')   # Average linkage
detector = HBACDetector(linkage='ward')      # Ward's method (default)
```
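To see why the linkage choice matters, the sketch below clusters the same synthetic embeddings with each method using scikit-learn's `AgglomerativeClustering` and compares the top-level purity. It does not call `HBACDetector`; the data and the `purity` helper are illustrative assumptions.

```python
import numpy as np
from sklearn.cluster import AgglomerativeClustering


def purity(cluster_labels, group_labels):
    """Fraction of samples that fall in their cluster's majority group."""
    total = 0
    for c in np.unique(cluster_labels):
        _, counts = np.unique(group_labels[cluster_labels == c], return_counts=True)
        total += counts.max()
    return total / len(group_labels)


rng = np.random.default_rng(1)
groups = np.repeat([0, 1], 150)
X = rng.normal(size=(300, 10))
X[:, 0] += 4.0 * groups  # moderate group signal along one axis

scores = {}
for method in ("ward", "complete", "average"):
    labels = AgglomerativeClustering(n_clusters=2, linkage=method).fit_predict(X)
    scores[method] = purity(labels, groups)
print(scores)
```

Complete and average linkage can split off a small set of outliers instead of separating the two groups, which shows up as purity close to 0.5; Ward's method tends to produce more evenly sized clusters.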
Custom Distance Metric
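The Parameters table lists a `distance_metric` option (default `'euclidean'`), but the values `HBACDetector` accepts are not documented here. The SciPy sketch below shows how a cosine metric is typically paired with agglomerative clustering, and why Ward linkage constrains the metric choice. It does not call `HBACDetector`, and the embeddings are random stand-ins.

```python
import numpy as np
from scipy.cluster.hierarchy import fcluster, linkage

rng = np.random.default_rng(2)
embeddings = rng.normal(size=(60, 32))

# Cosine distance pairs naturally with average (or complete) linkage
Z_cos = linkage(embeddings, method="average", metric="cosine")
labels = fcluster(Z_cos, t=2, criterion="maxclust")  # cut the tree into at most 2 clusters

# Ward's method is only well defined for Euclidean distance. To approximate
# cosine geometry with Ward, L2-normalize first and keep the Euclidean metric.
unit = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
Z_ward = linkage(unit, method="ward", metric="euclidean")
```

For unit-length vectors the squared Euclidean distance equals 2(1 - cos θ), so normalized embeddings rank pairs the same way under both metrics.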
Fine-grained Iteration Control
```python
from shortcut_detect import HBACDetector

detector = HBACDetector(
    max_iterations=5,       # Deeper analysis
    min_cluster_size=0.02,  # Allow smaller clusters
)
```
Theory
HBAC is based on the observation that if a model has learned shortcuts, embeddings from the same protected group tend to be more similar to each other than to embeddings from other groups.
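One way to check this observation directly is to compare mean pairwise distances within and across protected groups. The sketch below uses synthetic embeddings in which one axis is shifted by group (an assumption standing in for a learned shortcut); with real data, substitute your own embeddings and group labels.

```python
import numpy as np
from scipy.spatial.distance import cdist

rng = np.random.default_rng(3)
groups = np.repeat([0, 1], 50)
X = rng.normal(size=(100, 16))
X[:, 0] += 4.0 * groups  # simulate embeddings that encode group membership

D = cdist(X, X)  # pairwise Euclidean distances
same = groups[:, None] == groups[None, :]
off_diag = ~np.eye(len(X), dtype=bool)
intra = D[same & off_diag].mean()  # mean distance between samples of the same group
inter = D[~same].mean()            # mean distance between samples of different groups
print(f"intra={intra:.2f} inter={inter:.2f}")
```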
Purity measures this clustering:
$$\text{Purity} = \frac{1}{N} \sum_{k=1}^{K} \max_j |c_k \cap g_j|$$
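Where $N$ is the total number of samples, $K$ is the number of clusters, $c_k$ is the set of samples in cluster $k$, and $g_j$ is the set of samples in group $j$.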
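The purity formula can be computed from a cluster-by-group contingency table: take the largest count in each row (cluster), sum those maxima, and divide by $N$. A minimal NumPy sketch with a hand-checkable example; the `purity` helper is illustrative, not the library's implementation.

```python
import numpy as np


def purity(cluster_labels, group_labels):
    """Purity = (1/N) * sum over clusters k of max_j |c_k ∩ g_j|."""
    clusters, c_idx = np.unique(cluster_labels, return_inverse=True)
    groups, g_idx = np.unique(group_labels, return_inverse=True)
    # contingency[k, j] = number of samples in cluster k that belong to group j
    contingency = np.zeros((len(clusters), len(groups)), dtype=int)
    np.add.at(contingency, (c_idx, g_idx), 1)
    return contingency.max(axis=1).sum() / len(cluster_labels)


clusters = np.array([0, 0, 0, 1, 1, 1])
groups = np.array(["a", "a", "b", "b", "b", "b"])
print(purity(clusters, groups))  # (2 + 3) / 6 ≈ 0.833
```

Cluster 0 contains two samples from group a and one from group b (majority count 2); cluster 1 contains three from group b (majority count 3).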
Linearity measures how well a linear classifier can separate the clusters:
$$\text{Linearity} = \text{Accuracy}(\text{LinearSVM}(X, \text{clusters}))$$
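A sketch of the linearity score with scikit-learn's `LinearSVC`: fit a linear SVM that predicts cluster membership from the embeddings and report its accuracy. The formula does not say whether accuracy is measured on the training data or on held-out data; this sketch scores on the training data to match the formula literally, and `sklearn.model_selection.cross_val_score` would give a held-out estimate instead. The data and cluster labels are synthetic.

```python
import numpy as np
from sklearn.cluster import AgglomerativeClustering
from sklearn.svm import LinearSVC

rng = np.random.default_rng(4)
X = np.vstack([rng.normal(0, 1, size=(100, 8)),
               rng.normal(6, 1, size=(100, 8))])  # two well-separated blobs
clusters = AgglomerativeClustering(n_clusters=2, linkage="ward").fit_predict(X)

# Fit a linear SVM to predict cluster membership from the embeddings,
# then score it on the same data, as in the formula above
svm = LinearSVC(max_iter=10_000).fit(X, clusters)
linearity = svm.score(X, clusters)  # mean accuracy
print(f"Linearity: {linearity:.2f}")
```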
See Also
- Probe-based Detection: complementary classifier approach
- API Reference: full API documentation
- Overview: compare all methods