Skip to content

CAV API

Concept Activation Vector detector for concept-level shortcut testing.

Class Reference

CAVDetector

CAVDetector(
    *,
    config: CAVConfig | None = None,
    classifier: str = "logreg",
    random_state: int = 42,
    test_size: float = 0.2,
    min_examples_per_set: int = 20,
    shortcut_threshold: float = 0.6,
    quality_threshold: float = 0.7
)

Bases: DetectorBase

Test shortcut concepts using Concept Activation Vectors (Kim et al., 2018).

Source code in shortcut_detect/xai/cav/src/detector.py
def __init__(
    self,
    *,
    config: CAVConfig | None = None,
    classifier: str = "logreg",
    random_state: int = 42,
    test_size: float = 0.2,
    min_examples_per_set: int = 20,
    shortcut_threshold: float = 0.6,
    quality_threshold: float = 0.7,
) -> None:
    """Configure a Concept Activation Vector (CAV) shortcut detector.

    Settings come either from ``config`` or, when ``config`` is falsy
    (e.g. ``None``), from the individual keyword arguments. If a truthy
    ``config`` is supplied, the individual keyword arguments are ignored.

    Args:
        config: Optional pre-built ``CAVConfig`` that takes precedence over
            the individual keyword arguments.
        classifier: Probe type; only ``"logreg"`` is accepted.
        random_state: Seed forwarded to the train/test split and the
            logistic-regression probe in ``fit``.
        test_size: Fraction of concept+random examples held out to score
            CAV quality; must lie strictly between 0 and 1.
        min_examples_per_set: Minimum number of examples expected per set;
            must be at least 2 (presumably enforced per concept in ``fit``
            by ``_validate_min_examples`` -- confirm there).
        shortcut_threshold: Minimum TCAV score, in [0, 1], for a concept to
            be flagged.
        quality_threshold: Minimum held-out ROC AUC, in [0, 1], a CAV needs
            before its concept can be flagged.

    Raises:
        ValueError: If any setting is outside its allowed range.
    """
    super().__init__(method="cav")

    cfg = config if config else CAVConfig(
        classifier=classifier,
        random_state=int(random_state),
        test_size=float(test_size),
        min_examples_per_set=int(min_examples_per_set),
        shortcut_threshold=float(shortcut_threshold),
        quality_threshold=float(quality_threshold),
    )
    self.config = cfg

    # Each rule returns True when the setting is invalid. The predicates are
    # wrapped in lambdas so they are evaluated lazily, in order: the first
    # violated rule wins.
    violation_rules = (
        (
            lambda: cfg.classifier != "logreg",
            "Only classifier='logreg' is currently supported for CAVDetector.",
        ),
        (
            lambda: not 0.0 < cfg.test_size < 1.0,
            "test_size must be in (0, 1).",
        ),
        (
            lambda: cfg.min_examples_per_set < 2,
            "min_examples_per_set must be >= 2.",
        ),
        (
            lambda: not 0.0 <= cfg.shortcut_threshold <= 1.0,
            "shortcut_threshold must be in [0, 1].",
        ),
        (
            lambda: not 0.0 <= cfg.quality_threshold <= 1.0,
            "quality_threshold must be in [0, 1].",
        ),
    )
    for is_violated, message in violation_rules:
        if is_violated():
            raise ValueError(message)

    # Mirror the validated configuration onto plain, coerced attributes.
    self.classifier = cfg.classifier
    self.random_state = int(cfg.random_state)
    self.test_size = float(cfg.test_size)
    self.min_examples_per_set = int(cfg.min_examples_per_set)
    self.shortcut_threshold = float(cfg.shortcut_threshold)
    self.quality_threshold = float(cfg.quality_threshold)

    # Populated by fit().
    self.cav_vectors_: dict[str, np.ndarray] = {}
    self.concept_results_: list[ConceptResult] = []

Functions

fit

fit(
    *,
    concept_sets: dict[str, ndarray],
    random_set: ArrayLike,
    target_activations: ndarray | None = None,
    target_directional_derivatives: ndarray | None = None
) -> CAVDetector

Fit CAVs for concept-vs-random discrimination and compute TCAV metrics.

Source code in shortcut_detect/xai/cav/src/detector.py
def fit(
    self,
    *,
    concept_sets: dict[str, np.ndarray],
    random_set: ArrayLike,
    target_activations: np.ndarray | None = None,
    target_directional_derivatives: np.ndarray | None = None,
) -> CAVDetector:
    """Fit CAVs for concept-vs-random discrimination and compute TCAV metrics.

    For every concept a logistic-regression probe is trained to separate the
    concept examples from that concept's random examples. The unit-normalised
    probe weights are the concept activation vector (CAV); the probe's ROC AUC
    on a held-out stratified split is recorded as the CAV quality.

    Args:
        concept_sets: Mapping of concept name to example activations, checked
            by ``_validate_concept_sets``. The feature dimension is read from
            the first concept's ``shape[1]``.
        random_set: Random (negative) examples; ``_normalize_random_set``
            turns it into one random matrix per concept name.
        target_activations: Optional activations projected onto each CAV to
            report the projection mean and 95th percentile.
        target_directional_derivatives: Optional gradients. A concept's TCAV
            score is the fraction of rows whose dot product with the CAV is
            strictly positive. Without this input no TCAV score exists, no
            concept can be flagged and the risk level is ``"unknown"``.

    Returns:
        ``self``, with ``cav_vectors_`` and ``concept_results_`` populated and
        the detector results recorded via ``_set_results``.

    Raises:
        ValueError: If a concept produces a zero-norm CAV. The validation
            helpers and scikit-learn may also raise for malformed or too-small
            inputs.
    """
    # TCAV score a flagged concept must reach for the risk to be "high".
    high_risk_tcav = 0.75

    concept_sets_arr = self._validate_concept_sets(concept_sets)
    random_map = self._normalize_random_set(random_set, concept_sets_arr.keys())

    # Optional target matrices are checked against the first concept's width.
    dim = next(iter(concept_sets_arr.values())).shape[1]
    target_acts_arr = self._validate_optional_matrix(
        target_activations, dim, "target_activations"
    )
    target_dd_arr = self._validate_optional_matrix(
        target_directional_derivatives,
        dim,
        "target_directional_derivatives",
    )

    concept_results: list[ConceptResult] = []
    cav_vectors: dict[str, np.ndarray] = {}

    for concept_name, concept_examples in concept_sets_arr.items():
        random_examples = random_map[concept_name]
        self._validate_min_examples(concept_examples, random_examples, concept_name)

        # Binary task: concept examples are label 1, random examples label 0.
        X = np.vstack([concept_examples, random_examples])
        y = np.concatenate(
            [
                np.ones(concept_examples.shape[0], dtype=int),
                np.zeros(random_examples.shape[0], dtype=int),
            ]
        )
        # Stratify so the held-out split keeps both labels, which the ROC AUC
        # quality score below requires.
        X_train, X_test, y_train, y_test = train_test_split(
            X,
            y,
            test_size=self.test_size,
            stratify=y,
            random_state=self.random_state,
        )

        model = LogisticRegression(max_iter=2000, random_state=self.random_state)
        model.fit(X_train, y_train)

        # The CAV is the unit-length normal of the probe's decision boundary,
        # pointing towards the concept class.
        coef = np.asarray(model.coef_[0], dtype=float)
        norm = float(np.linalg.norm(coef))
        if norm <= 0.0:
            raise ValueError(f"Concept '{concept_name}' produced a zero-norm CAV vector.")
        cav_vector = coef / norm
        cav_vectors[concept_name] = cav_vector

        y_prob = model.predict_proba(X_test)[:, 1]
        quality_auc = float(roc_auc_score(y_test, y_prob))

        tcav_score: float | None = None
        if target_dd_arr is not None:
            # Fraction of targets whose directional derivative along the CAV
            # is positive (TCAV, Kim et al., 2018).
            directional = target_dd_arr @ cav_vector
            tcav_score = float(np.mean(directional > 0.0))

        activation_mean: float | None = None
        activation_p95: float | None = None
        if target_acts_arr is not None:
            projections = target_acts_arr @ cav_vector
            activation_mean = float(np.mean(projections))
            activation_p95 = float(np.percentile(projections, 95))

        # Flag only when the CAV is reliable (quality gate) AND the model is
        # sensitive to the concept (TCAV gate).
        flagged = (
            tcav_score is not None
            and quality_auc >= self.quality_threshold
            and tcav_score >= self.shortcut_threshold
        )

        concept_results.append(
            ConceptResult(
                concept_name=concept_name,
                n_concept=int(concept_examples.shape[0]),
                n_random=int(random_examples.shape[0]),
                quality_auc=quality_auc,
                tcav_score=tcav_score,
                activation_mean=activation_mean,
                activation_p95=activation_p95,
                flagged=bool(flagged),
            )
        )

    self.cav_vectors_ = cav_vectors
    self.concept_results_ = concept_results

    tcav_values = [r.tcav_score for r in concept_results if r.tcav_score is not None]
    quality_values = [r.quality_auc for r in concept_results]
    n_flagged = int(sum(1 for r in concept_results if r.flagged))

    if not tcav_values:
        shortcut_detected = None
        risk_level = "unknown"
        notes = "Directional derivatives were not provided; TCAV scores unavailable."
    else:
        # Severity is graded only from concepts that passed both gates. A
        # concept that failed the quality gate has an unreliable CAV, so its
        # TCAV score must not escalate the risk level.
        flagged_tcav = [
            r.tcav_score
            for r in concept_results
            if r.flagged and r.tcav_score is not None
        ]
        if flagged_tcav:
            shortcut_detected = True
            risk_level = "high" if max(flagged_tcav) >= high_risk_tcav else "moderate"
            notes = "At least one concept exceeded quality and TCAV thresholds."
        else:
            shortcut_detected = False
            risk_level = "low"
            notes = "No concept exceeded both quality and TCAV thresholds."

    metrics = {
        "n_concepts": int(len(concept_results)),
        "n_tested": int(len(tcav_values)),
        "max_tcav_score": float(max(tcav_values)) if tcav_values else None,
        "mean_tcav_score": float(np.mean(tcav_values)) if tcav_values else None,
        "max_concept_quality": float(max(quality_values)) if quality_values else None,
        "n_flagged": n_flagged,
        "shortcut_threshold": self.shortcut_threshold,
        "quality_threshold": self.quality_threshold,
    }

    metadata = {
        "classifier": self.classifier,
        "random_state": self.random_state,
        "test_size": self.test_size,
        "min_examples_per_set": self.min_examples_per_set,
        "concept_names": [r.concept_name for r in concept_results],
        "has_target_activations": target_acts_arr is not None,
        "has_target_directional_derivatives": target_dd_arr is not None,
    }

    report = {
        "per_concept": [self._concept_result_to_dict(result) for result in concept_results],
    }

    details = {
        "cav_vectors": {name: vec.tolist() for name, vec in cav_vectors.items()},
    }

    self._set_results(
        shortcut_detected=shortcut_detected,
        risk_level=risk_level,
        metrics=metrics,
        notes=notes,
        metadata=metadata,
        report=report,
        details=details,
    )
    self._is_fitted = True
    return self

Loader Integration Example

from shortcut_detect import ShortcutDetector

# concept_examples, control_examples, random_examples, target_activations and
# target_directional_derivatives are assumed to be activation arrays prepared
# beforehand (e.g. extracted from a model layer) -- they are not defined here.
concept_sets = {
    "shortcut_concept": concept_examples,
    "control_concept": control_examples,
}

# Keyword inputs for the "cav" method, mirroring CAVDetector.fit's arguments.
# The two target_* entries are optional; without directional derivatives no
# TCAV scores can be computed.
cav_inputs = dict(
    concept_sets=concept_sets,
    random_set=random_examples,
    target_activations=target_activations,
    target_directional_derivatives=target_directional_derivatives,
)

detector = ShortcutDetector(methods=["cav"])
detector.fit_from_loaders({"cav": cav_inputs})

cav_metrics = detector.get_results()["cav"]["metrics"]
print(cav_metrics)