def fit(
    self,
    embeddings: np.ndarray,
    labels: np.ndarray,
    extra_labels: dict[str, np.ndarray],
) -> IntersectionalDetector:
    """Compute fairness metrics across demographic intersections.

    Demographic attributes from ``extra_labels`` are combined into
    intersection groups by ``_build_intersection_labels``. Groups with at
    least ``min_group_size`` valid samples are kept, and an
    ``EqualizedOddsDetector`` and a ``DemographicParityDetector`` are fitted
    on the remaining samples with the intersection label as the group.
    Their per-group metrics and gaps are merged into ``report_``.

    Parameters
    ----------
    embeddings
        2-D array of shape ``(n_samples, embedding_dim)``.
    labels
        1-D array aligned with ``embeddings`` holding exactly two distinct
        values.
    extra_labels
        Mapping from attribute name to per-sample attribute values. It must
        hold at least two keys that are not in ``_RESERVED_EXTRA_LABELS``.
        When ``self.intersection_attributes`` is set, every listed attribute
        present in ``extra_labels`` is used; otherwise the first two
        non-reserved keys are used.

    Returns
    -------
    IntersectionalDetector
        ``self``, with ``attribute_names_`` and ``report_`` populated.

    Raises
    ------
    ValueError
        If fewer than two usable attributes are supplied, if ``embeddings``
        is not 2-D, if ``labels`` is not 1-D or not binary, or if
        ``embeddings``, ``labels`` and the intersection labels do not have
        the same number of samples.

    Notes
    -----
    When fewer than two intersection groups reach ``min_group_size`` no
    detector is run: ``shortcut_detected_`` is set to ``None``, every gap is
    NaN and the report carries ``RiskLevel.UNKNOWN``.
    """
    if extra_labels is None or len(extra_labels) < 2:
        raise ValueError(
            "IntersectionalDetector requires extra_labels with at least "
            "2 demographic attribute arrays (e.g., {'race': ..., 'gender': ...})."
        )

    # Determine which attributes to use; reserved keys are not demographic
    # attributes and must not count towards the minimum of two.
    candidate_keys = [k for k in extra_labels if k not in _RESERVED_EXTRA_LABELS]
    if len(candidate_keys) < 2:
        raise ValueError(
            "Need at least 2 demographic attributes in extra_labels for "
            f"intersectional analysis. Found: {candidate_keys}."
        )

    if self.intersection_attributes is not None:
        # Keep the caller's ordering, silently dropping attributes that
        # were not supplied in extra_labels.
        attr_names = [a for a in self.intersection_attributes if a in extra_labels]
        if len(attr_names) < 2:
            raise ValueError(
                f"intersection_attributes {self.intersection_attributes} "
                f"must include at least 2 keys present in extra_labels: "
                f"{list(extra_labels.keys())}."
            )
    else:
        attr_names = candidate_keys[:2]  # Use first 2
    self.attribute_names_ = attr_names

    if embeddings.ndim != 2:
        raise ValueError("Embeddings must be 2D (n_samples, embedding_dim).")
    if labels.ndim != 1:
        raise ValueError("Labels must be 1D.")
    if embeddings.shape[0] != labels.shape[0]:
        raise ValueError("Embeddings and labels must align.")
    unique_labels = np.unique(labels)
    if unique_labels.size != 2:
        raise ValueError("Intersectional analysis requires binary labels.")

    # Build intersection labels
    intersection_labels, valid_mask = _build_intersection_labels(
        extra_labels, attr_names, self.separator
    )
    # The labels are used below as a per-sample mask over `embeddings`, so a
    # length mismatch must be reported here rather than surfacing later as
    # an indexing error (or passing unnoticed on the early-return path).
    if len(intersection_labels) != embeddings.shape[0]:
        raise ValueError(
            "extra_labels arrays must align with embeddings: expected "
            f"{embeddings.shape[0]} samples, got {len(intersection_labels)}."
        )

    # Count valid samples per intersection and keep the groups that are
    # large enough to produce meaningful rates.
    unique_intersections, counts = np.unique(
        intersection_labels[valid_mask], return_counts=True
    )
    large_groups = unique_intersections[counts >= self.min_group_size]

    if large_groups.size < 2:
        # Reset every fitted attribute so that a refit never exposes values
        # left over from an earlier, successful fit.
        nan = float("nan")
        self.shortcut_detected_ = None
        self.intersection_metrics_ = {}
        self.tpr_gap_ = nan
        self.fpr_gap_ = nan
        self.dp_gap_ = nan
        self.overall_accuracy_ = nan
        self.overall_positive_rate_ = nan
        self.report_ = IntersectionalReport(
            intersection_metrics=self.intersection_metrics_,
            attribute_names=attr_names,
            tpr_gap=nan,
            fpr_gap=nan,
            dp_gap=nan,
            overall_accuracy=nan,
            overall_positive_rate=nan,
            reference="Buolamwini & Gebru 2018",
            risk_level=RiskLevel.UNKNOWN.value,
            notes=(
                f"Fewer than 2 intersection groups with >= {self.min_group_size} "
                "samples. Cannot compute intersectional fairness metrics."
            ),
        )
        self._finalize_results()
        self._is_fitted = True
        return self

    # Build mask for samples in large groups only (vectorised membership
    # test instead of a per-sample Python loop).
    in_large = np.isin(intersection_labels, large_groups) & valid_mask
    X_sub = embeddings[in_large]
    y_sub = labels[in_large]
    groups_sub = intersection_labels[in_large]

    # Run EqualizedOddsDetector; each detector gets its own clone of the
    # estimator so the two fits do not share one estimator instance.
    eo = EqualizedOddsDetector(
        estimator=clone(self.estimator),
        min_group_size=self.min_group_size,
        tpr_gap_threshold=self.tpr_gap_threshold,
        fpr_gap_threshold=self.fpr_gap_threshold,
    )
    eo.fit(X_sub, y_sub, groups_sub)

    # Run DemographicParityDetector
    dp = DemographicParityDetector(
        estimator=clone(self.estimator),
        min_group_size=self.min_group_size,
        dp_gap_threshold=self.dp_gap_threshold,
    )
    dp.fit(X_sub, y_sub, groups_sub)

    # Merge metrics into intersection_metrics. The equalized-odds groups
    # drive the iteration; a group missing from the demographic-parity
    # report gets a NaN positive rate.
    eo_report: EqualizedOddsReport = eo.report_
    dp_report: DemographicParityReport = dp.report_
    self.intersection_metrics_ = {}
    for group, eo_m in eo_report.group_metrics.items():
        dp_m = dp_report.group_rates.get(group, {})
        self.intersection_metrics_[group] = {
            "tpr": eo_m["tpr"],
            "fpr": eo_m["fpr"],
            "positive_rate": dp_m.get("positive_rate", float("nan")),
            "support": eo_m["support"],
        }

    self.tpr_gap_ = eo_report.tpr_gap
    self.fpr_gap_ = eo_report.fpr_gap
    self.dp_gap_ = dp_report.dp_gap
    self.overall_accuracy_ = eo_report.overall_accuracy
    self.overall_positive_rate_ = dp_report.overall_positive_rate

    # _assess_risk is called only after the gap attributes above are set.
    risk_level, notes = self._assess_risk()
    self.report_ = IntersectionalReport(
        intersection_metrics=self.intersection_metrics_,
        attribute_names=self.attribute_names_,
        tpr_gap=self.tpr_gap_,
        fpr_gap=self.fpr_gap_,
        dp_gap=self.dp_gap_,
        overall_accuracy=self.overall_accuracy_,
        overall_positive_rate=self.overall_positive_rate_,
        reference="Buolamwini & Gebru 2018",
        risk_level=risk_level,
        notes=notes,
    )
    self._finalize_results()
    self._is_fitted = True
    return self