Skip to content

Causal Effect API

Causal effect regularization detector for shortcut detection via causal effect estimation.

Class Reference

CausalEffectDetector

CausalEffectDetector(
    *,
    effect_estimator: str = "direct",
    spurious_threshold: float = 0.1,
    random_state: int = 42
)

Bases: DetectorBase

Detect shortcut attributes via causal effect estimation.

Estimates the causal effect of each candidate attribute on the task label. Attributes with near-zero estimated effect are flagged as spurious (shortcuts), since changing them should not change the true label.

Parameters:

Name Type Description Default
effect_estimator str

Causal-effect estimator to use; only "direct" is currently supported.

'direct'
spurious_threshold float

Attributes with |TE_a| < threshold are flagged as spurious. Default 0.1.

0.1
random_state int

Random seed for reproducibility.

42
Source code in shortcut_detect/causal/causal_effect/src/detector.py
def __init__(
    self,
    *,
    effect_estimator: str = "direct",
    spurious_threshold: float = 0.1,
    random_state: int = 42,
) -> None:
    """
    Configure the causal-effect shortcut detector.

    Args:
        effect_estimator: Name of the causal-effect estimator. Only
            ``"direct"`` is accepted; any other value raises ``ValueError``.
        spurious_threshold: Cut-off on the absolute estimated effect;
            attributes with |TE_a| < threshold are flagged as spurious.
            Must lie in the closed interval [0, 1]. Default 0.1.
        random_state: Random seed for reproducibility (stored as ``int``).

    Raises:
        ValueError: If ``effect_estimator`` is not ``"direct"`` or
            ``spurious_threshold`` is outside [0, 1] (NaN is rejected too).
    """
    super().__init__(method="causal_effect")

    supported_estimator = "direct"
    if effect_estimator != supported_estimator:
        raise ValueError(f"effect_estimator must be 'direct'; got '{effect_estimator}'.")

    # Chained comparison on purpose: NaN fails both bounds, so `not` rejects it.
    threshold_in_range = 0.0 <= spurious_threshold <= 1.0
    if not threshold_in_range:
        raise ValueError("spurious_threshold must be in [0, 1].")

    self.effect_estimator = effect_estimator
    self.spurious_threshold = spurious_threshold
    self.random_state = int(random_state)

    # Populated by fit(); empty until the detector has been fitted.
    self.attribute_results_: list[AttributeEffectResult] = []

Functions

fit

fit(
    *,
    embeddings: ndarray,
    labels: ndarray,
    attributes: dict[str, ndarray],
    counterfactual_pairs: ndarray | list | None = None
) -> CausalEffectDetector

Fit causal effect estimator and detect spurious attributes.

Parameters:

Name Type Description Default
embeddings ndarray

(n_samples, n_features) representation space.

required
labels ndarray

(n_samples,) task labels (binary or multi-class).

required
attributes dict[str, ndarray]

Dict of attribute_name -> (n_samples,) values per sample. Binary (0/1) or categorical; multi-valued attributes are binarized.

required
counterfactual_pairs ndarray | list | None

Optional interventional data, reserved for Phase 2. It is ignored by the current Direct estimator.

None

Returns:

Type Description
CausalEffectDetector

self

Source code in shortcut_detect/causal/causal_effect/src/detector.py
def fit(
    self,
    *,
    embeddings: np.ndarray,
    labels: np.ndarray,
    attributes: dict[str, np.ndarray],
    counterfactual_pairs: np.ndarray | list | None = None,
) -> CausalEffectDetector:
    """
    Fit causal effect estimator and detect spurious attributes.

    Labels are mapped to integer codes 0..k-1 (in sorted order of the
    distinct label values). The direct estimator is then run once per
    attribute. Any attribute whose absolute estimated effect is below
    ``self.spurious_threshold`` is flagged as spurious.

    Risk levels:
        - "high": more than one attribute is spurious.
        - "moderate": exactly one attribute is spurious.
        - "low": no attribute is spurious.

    Args:
        embeddings: (n_samples, n_features) representation space.
        labels: (n_samples,) task labels (binary or multi-class).
        attributes: Dict of attribute_name -> (n_samples,) values per sample.
            Names must be non-empty strings. Values are converted with
            ``np.asarray`` before being passed to the estimator.
        counterfactual_pairs: Optional. Reserved for interventional data
            (Phase 2). Ignored by the current Direct estimator.

    Returns:
        self

    Raises:
        ValueError: If ``embeddings`` is not 2D, ``labels`` is not 1D, their
            lengths differ, ``attributes`` is empty or not a dict, an
            attribute name is empty or not a string, an attribute array is
            not 1D of length n_samples, or fewer than 2 distinct labels exist.
    """
    X = np.asarray(embeddings, dtype=float)
    y = np.asarray(labels)
    if X.ndim != 2:
        raise ValueError("embeddings must be 2D (n_samples, n_features)")
    if y.ndim != 1:
        raise ValueError("labels must be 1D")
    if X.shape[0] != y.shape[0]:
        raise ValueError("embeddings and labels must have same length")

    if not isinstance(attributes, dict) or not attributes:
        raise ValueError("attributes must be a non-empty dict of name -> (n,) array")

    n_samples = X.shape[0]
    # Keep the converted arrays so the estimator sees exactly what was validated
    # (positional ndarray), not the raw caller object (list, Series, ...).
    attribute_arrays: dict[str, np.ndarray] = {}
    for name, arr in attributes.items():
        if not isinstance(name, str) or not name.strip():
            raise ValueError("Attribute names must be non-empty strings")
        arr = np.asarray(arr)
        if arr.ndim != 1 or arr.shape[0] != n_samples:
            raise ValueError(f"attributes['{name}'] must be 1D of length {n_samples}")
        attribute_arrays[name] = arr

    # counterfactual_pairs is accepted for forward compatibility (Phase 2) and
    # is intentionally not used by the Direct estimator.

    # Map labels to 0, 1, ... for binary/multi-class
    classes = np.unique(y)
    if len(classes) < 2:
        raise ValueError("At least 2 distinct labels are required")
    y_int = np.searchsorted(classes, y)

    attribute_results: list[AttributeEffectResult] = []

    for attr_name, attr_values in attribute_arrays.items():
        raw_effect, n_a0, n_a1 = self._estimate_causal_effect_direct(X, y_int, attr_values)
        # Coerce to builtin types so metrics/report stay JSON-serializable
        # (NumPy bool_/int64 scalars are not).
        effect = float(raw_effect)
        # Note: a NaN effect compares False here, so it is never flagged spurious.
        is_spurious = abs(effect) < self.spurious_threshold
        attribute_results.append(
            AttributeEffectResult(
                attribute_name=attr_name,
                causal_effect=effect,
                is_spurious=is_spurious,
                n_samples_a0=int(n_a0),
                n_samples_a1=int(n_a1),
            )
        )

    self.attribute_results_ = attribute_results

    n_spurious = sum(1 for r in attribute_results if r.is_spurious)
    if n_spurious > 1:
        shortcut_detected = True
        risk_level = "high"
        notes = f"Multiple attributes ({n_spurious}) have low causal effect (spurious)."
    elif n_spurious == 1:
        shortcut_detected = True
        risk_level = "moderate"
        spurious_names = [r.attribute_name for r in attribute_results if r.is_spurious]
        notes = f"Attribute '{spurious_names[0]}' has low causal effect (spurious)."
    else:
        shortcut_detected = False
        risk_level = "low"
        notes = "No attribute has estimated causal effect below threshold."

    effects_dict = {r.attribute_name: r.causal_effect for r in attribute_results}
    # Largest absolute effect first (most label-relevant attribute leads).
    ranking = sorted(
        attribute_results,
        key=lambda r: abs(r.causal_effect),
        reverse=True,
    )

    metrics = {
        "n_attributes": len(attribute_results),
        "n_spurious": n_spurious,
        "spurious_threshold": self.spurious_threshold,
        "per_attribute_effects": effects_dict,
        "attribute_ranking": [r.attribute_name for r in ranking],
    }

    metadata = {
        "effect_estimator": self.effect_estimator,
        "random_state": self.random_state,
        "attribute_names": list(attributes.keys()),
    }

    report = {
        "per_attribute": [
            {
                "attribute_name": r.attribute_name,
                "causal_effect": r.causal_effect,
                "is_spurious": r.is_spurious,
                "n_samples_a0": r.n_samples_a0,
                "n_samples_a1": r.n_samples_a1,
            }
            for r in attribute_results
        ],
    }

    self._set_results(
        shortcut_detected=shortcut_detected,
        risk_level=risk_level,
        metrics=metrics,
        notes=notes,
        metadata=metadata,
        report=report,
    )
    self._is_fitted = True
    return self

Loader Integration Example

import numpy as np

from shortcut_detect import ShortcutDetector

# Synthetic stand-in data so the example runs as pasted; replace these
# arrays with your own embeddings, labels and attribute values.
rng = np.random.default_rng(0)
n_samples, n_features = 200, 16

embeddings = rng.normal(size=(n_samples, n_features))  # (n, d)
labels = rng.integers(0, 2, size=n_samples)             # (n,) task labels
race_labels = rng.integers(0, 2, size=n_samples)        # (n,) binary attribute
color_labels = rng.integers(0, 3, size=n_samples)       # (n,) categorical attribute

loader_data = {
    "embeddings": embeddings,   # (n, d)
    "labels": labels,           # (n,)
    "attributes": {
        "race": race_labels,    # (n,) binary or categorical
        "color": color_labels,
    },
}

detector = ShortcutDetector(
    methods=["causal_effect"],
    causal_effect_spurious_threshold=0.1,
)
detector.fit_from_loaders({"causal_effect": loader_data})

print(detector.get_results()["causal_effect"]["metrics"])