def fit(
    self,
    *,
    concept_sets: dict[str, np.ndarray],
    random_set: ArrayLike,
    target_activations: np.ndarray | None = None,
    target_directional_derivatives: np.ndarray | None = None,
) -> CAVDetector:
    """Fit CAVs for concept-vs-random discrimination and compute TCAV metrics.

    For each concept, a logistic-regression probe is trained to separate the
    concept examples (label 1) from that concept's random examples (label 0).
    The unit-normalised coefficient vector of the probe is the concept
    activation vector (CAV); the probe's ROC AUC on a held-out stratified
    split is reported as the CAV quality.

    Args:
        concept_sets: Mapping of concept name to a 2-D array whose rows are
            concept examples; validated by ``_validate_concept_sets``. The
            feature dimension is taken from the first set's ``shape[1]``.
        random_set: Random (negative) examples; ``_normalize_random_set``
            turns it into one random array per concept name.
        target_activations: Optional matrix validated by
            ``_validate_optional_matrix`` (which receives the concept feature
            dimension). When given, it is projected onto each CAV and the
            mean and 95th percentile of the projections are reported.
        target_directional_derivatives: Optional matrix validated the same
            way. When given, a concept's TCAV score is the fraction of rows
            whose dot product with the CAV is strictly positive.

    Returns:
        ``self``, with ``cav_vectors_`` and ``concept_results_`` populated and
        the aggregate summary stored via ``_set_results``.

    Raises:
        ValueError: If a concept's probe yields a zero-norm coefficient
            vector. The validation helpers may also raise for malformed
            inputs.
    """
    # TCAV score at or above which a flagged concept counts as high risk.
    high_risk_tcav = 0.75

    concept_sets_arr = self._validate_concept_sets(concept_sets)
    random_map = self._normalize_random_set(random_set, concept_sets_arr.keys())
    dim = next(iter(concept_sets_arr.values())).shape[1]
    target_acts_arr = self._validate_optional_matrix(
        target_activations, dim, "target_activations"
    )
    target_dd_arr = self._validate_optional_matrix(
        target_directional_derivatives,
        dim,
        "target_directional_derivatives",
    )

    concept_results: list[ConceptResult] = []
    cav_vectors: dict[str, np.ndarray] = {}
    for concept_name, concept_examples in concept_sets_arr.items():
        random_examples = random_map[concept_name]
        self._validate_min_examples(concept_examples, random_examples, concept_name)

        # Concept rows are the positive class, random rows the negative class,
        # so the learned coefficient vector points towards the concept.
        X = np.vstack([concept_examples, random_examples])
        y = np.concatenate(
            [
                np.ones(concept_examples.shape[0], dtype=int),
                np.zeros(random_examples.shape[0], dtype=int),
            ]
        )
        # Stratify so both classes appear in the held-out split used for AUC.
        X_train, X_test, y_train, y_test = train_test_split(
            X,
            y,
            test_size=self.test_size,
            stratify=y,
            random_state=self.random_state,
        )
        # NOTE(review): metadata records ``self.classifier`` but the probe is
        # always LogisticRegression here — confirm that is intended.
        model = LogisticRegression(max_iter=2000, random_state=self.random_state)
        model.fit(X_train, y_train)

        coef = np.asarray(model.coef_[0], dtype=float)
        norm = float(np.linalg.norm(coef))
        if norm <= 0.0:
            raise ValueError(f"Concept '{concept_name}' produced a zero-norm CAV vector.")
        cav_vector = coef / norm
        cav_vectors[concept_name] = cav_vector

        y_prob = model.predict_proba(X_test)[:, 1]
        quality_auc = float(roc_auc_score(y_test, y_prob))

        tcav_score: float | None = None
        if target_dd_arr is not None:
            directional = target_dd_arr @ cav_vector
            tcav_score = float(np.mean(directional > 0.0))

        activation_mean: float | None = None
        activation_p95: float | None = None
        if target_acts_arr is not None:
            projections = target_acts_arr @ cav_vector
            activation_mean = float(np.mean(projections))
            activation_p95 = float(np.percentile(projections, 95))

        # A concept is flagged only when its CAV is trustworthy (quality gate)
        # AND its TCAV score is high enough.
        flagged = (
            tcav_score is not None
            and quality_auc >= self.quality_threshold
            and tcav_score >= self.shortcut_threshold
        )
        concept_results.append(
            ConceptResult(
                concept_name=concept_name,
                n_concept=int(concept_examples.shape[0]),
                n_random=int(random_examples.shape[0]),
                quality_auc=quality_auc,
                tcav_score=tcav_score,
                activation_mean=activation_mean,
                activation_p95=activation_p95,
                flagged=bool(flagged),
            )
        )

    self.cav_vectors_ = cav_vectors
    self.concept_results_ = concept_results

    tcav_values = [r.tcav_score for r in concept_results if r.tcav_score is not None]
    quality_values = [r.quality_auc for r in concept_results]
    flagged_tcav = [
        r.tcav_score for r in concept_results if r.flagged and r.tcav_score is not None
    ]
    n_flagged = len(flagged_tcav)
    max_tcav = float(max(tcav_values)) if tcav_values else None

    shortcut_detected: bool | None
    if not tcav_values:
        shortcut_detected = None
        risk_level = "unknown"
        notes = "Directional derivatives were not provided; TCAV scores unavailable."
    elif flagged_tcav:
        shortcut_detected = True
        # Escalate based only on concepts that passed the quality gate: a high
        # TCAV score from a low-AUC (unreliable) CAV must not raise the risk
        # level on its own.
        risk_level = "high" if max(flagged_tcav) >= high_risk_tcav else "moderate"
        notes = "At least one concept exceeded quality and TCAV thresholds."
    else:
        shortcut_detected = False
        risk_level = "low"
        notes = "No concept exceeded both quality and TCAV thresholds."

    metrics = {
        "n_concepts": int(len(concept_results)),
        "n_tested": int(len(tcav_values)),
        "max_tcav_score": max_tcav,
        "mean_tcav_score": float(np.mean(tcav_values)) if tcav_values else None,
        "max_concept_quality": float(max(quality_values)) if quality_values else None,
        "n_flagged": n_flagged,
        "shortcut_threshold": self.shortcut_threshold,
        "quality_threshold": self.quality_threshold,
    }
    metadata = {
        "classifier": self.classifier,
        "random_state": self.random_state,
        "test_size": self.test_size,
        "min_examples_per_set": self.min_examples_per_set,
        "concept_names": [r.concept_name for r in concept_results],
        "has_target_activations": target_acts_arr is not None,
        "has_target_directional_derivatives": target_dd_arr is not None,
    }
    report = {
        "per_concept": [self._concept_result_to_dict(result) for result in concept_results],
    }
    details = {
        "cav_vectors": {name: vec.tolist() for name, vec in cav_vectors.items()},
    }
    self._set_results(
        shortcut_detected=shortcut_detected,
        risk_level=risk_level,
        metrics=metrics,
        notes=notes,
        metadata=metadata,
        report=report,
        details=details,
    )
    self._is_fitted = True
    return self