Skip to content

ShortcutDetector API

The ShortcutDetector class provides a unified interface for running all shortcut detection methods.

Class Reference

ShortcutDetector

ShortcutDetector(
    methods: list[str] = None,
    seed: int = 42,
    condition_name: str = "indicator_count",
    condition_kwargs: dict[str, Any] | None = None,
    **kwargs
)

Unified interface for detecting shortcuts using multiple methods.

Combines five detection approaches:

1. HBAC (Hierarchical Bias-Aware Clustering)
2. Probe-based detection (train classifiers on embeddings)
3. Statistical testing (feature-wise group differences)
4. Geometric shortcut analysis (subspace monitoring)
5. Equalized odds gap analysis (fairness-based)

Example

detector = ShortcutDetector(
    methods=['hbac', 'probe', 'statistical', 'geometric', 'equalized_odds']
)
results = detector.fit(embeddings, labels, group_labels=groups)
print(detector.summary())

Initialize unified shortcut detector.

Parameters:

Name Type Description Default
methods list[str]

List of methods to use. Options: 'hbac', 'probe', 'statistical'

None
seed int

Random seed for reproducibility

42
condition_name str

Registered overall assessment condition name

'indicator_count'
condition_kwargs dict[str, Any] | None

Keyword arguments passed to the selected condition

None
**kwargs

Additional arguments passed to individual detectors

{}
Source code in shortcut_detect/unified.py
def __init__(
    self,
    methods: list[str] | None = None,
    seed: int = 42,
    condition_name: str = "indicator_count",
    condition_kwargs: dict[str, Any] | None = None,
    **kwargs,
):
    """
    Initialize unified shortcut detector.

    Args:
        methods: List of methods to use. Options: 'hbac', 'probe', 'statistical'.
            Defaults to all three when None.
        seed: Random seed for reproducibility
        condition_name: Registered overall assessment condition name
        condition_kwargs: Keyword arguments passed to the selected condition
        **kwargs: Additional arguments passed to individual detectors
    """
    if methods is None:
        methods = ["hbac", "probe", "statistical"]
    # Copy the caller's list so later mutation of their argument cannot
    # silently change which methods this detector runs.
    self.methods = list(methods)
    self.seed = seed
    self.condition_name = condition_name
    self.condition_kwargs = dict(condition_kwargs or {})
    # Result is discarded — presumably a fail-fast check that the condition
    # name/kwargs are valid at construction time rather than at summary time.
    create_condition(self.condition_name, **self.condition_kwargs)
    self.kwargs = kwargs
    set_seed(seed)

    # Results storage (populated by fit())
    self.results_ = {}
    self.embeddings_ = None
    self.labels_ = None
    self.group_labels_ = None
    self.protected_labels_ = None
    self.embedding_source_ = None
    self.raw_inputs_ = None
    self.embedding_metadata_ = {}
    self.splits_ = None
    self.extra_labels_ = None

    # Initialize detectors
    self._init_detectors()

Functions

fit

fit(
    embeddings: ndarray | None,
    labels: ndarray,
    group_labels: ndarray | None = None,
    feature_names: list[str] | None = None,
    raw_inputs: Sequence[Any] | None = None,
    embedding_source: EmbeddingSource | None = None,
    embedding_cache_path: str | None = None,
    force_embedding_recompute: bool = False,
    splits: dict[str, ndarray] | None = None,
    extra_labels: dict[str, ndarray] | None = None,
) -> ShortcutDetector

Fit all detection methods on embeddings.

Parameters:

Name Type Description Default
embeddings ndarray | None

(n_samples, embedding_dim) array. Optional when using embedding-only mode.

required
labels ndarray

(n_samples,) target labels

required
group_labels ndarray | None

(n_samples,) group labels (e.g., demographic attributes) If None, uses labels for group-based tests

None
feature_names list[str] | None

Optional list of feature names

None
raw_inputs Sequence[Any] | None

Optional raw inputs (text, ids, etc.). Required when embedding_source is provided.

None
embedding_source EmbeddingSource | None

Source used to generate embeddings when they are not pre-computed.

None
embedding_cache_path str | None

Optional path to cache generated embeddings.

None
force_embedding_recompute bool

Whether to ignore cached embeddings.

False
splits dict[str, ndarray] | None

Optional dictionary of named index sets for semi-supervised methods. Expected keys: 'train_l' (labeled), 'train_u' (unlabeled).

None
extra_labels dict[str, ndarray] | None

Optional dictionary of named per-sample arrays for additional supervision signals (e.g., 'spurious' labels). Use -1 for unknown labels.

None

Returns:

Type Description
ShortcutDetector

self

Source code in shortcut_detect/unified.py
def fit(
    self,
    embeddings: np.ndarray | None,
    labels: np.ndarray,
    group_labels: np.ndarray | None = None,
    feature_names: list[str] | None = None,
    raw_inputs: Sequence[Any] | None = None,
    embedding_source: "EmbeddingSource | None" = None,
    embedding_cache_path: str | None = None,
    force_embedding_recompute: bool = False,
    splits: dict[str, np.ndarray] | None = None,
    extra_labels: dict[str, np.ndarray] | None = None,
) -> "ShortcutDetector":
    """
    Fit all detection methods on embeddings.

    Args:
        embeddings: (n_samples, embedding_dim) array. Optional when using
            embedding-only mode.
        labels: (n_samples,) target labels
        group_labels: (n_samples,) group labels (e.g., demographic attributes)
                     If None, uses `labels` for group-based tests
        feature_names: Optional list of feature names
        raw_inputs: Optional raw inputs (text, ids, etc.). Required when
            `embedding_source` is provided.
        embedding_source: Source used to generate embeddings when they are
            not pre-computed.
        embedding_cache_path: Optional path to cache generated embeddings.
        force_embedding_recompute: Whether to ignore cached embeddings.
        splits: Optional dictionary of named index sets for semi-supervised methods.
            Expected keys: 'train_l' (labeled), 'train_u' (unlabeled).
        extra_labels: Optional dictionary of named per-sample arrays for additional
            supervision signals (e.g., 'spurious' labels). Use -1 for unknown labels.

    Returns:
        self

    Raises:
        ValueError: If neither embeddings nor (raw_inputs + embedding_source)
            is supplied, fewer than 2 classes are present, or splits/extra_labels
            contain malformed arrays.
        TypeError: If splits/extra_labels are not dicts of numpy arrays.
    """
    # Embedding-only mode: derive embeddings from raw inputs via the source.
    if embeddings is None:
        if embedding_source is None or raw_inputs is None:
            raise ValueError(
                "Provide either `embeddings` directly or both `raw_inputs` and `embedding_source`."
            )
        embeddings = self._generate_embeddings_from_source(
            raw_inputs,
            embedding_source,
            cache_path=embedding_cache_path,
            force_recompute=force_embedding_recompute,
        )
    else:
        # Pre-computed embeddings take precedence: warn about (and ignore)
        # any embedding-generation arguments, then clear source state.
        if embedding_source is not None or raw_inputs is not None:
            warnings.warn(
                "embeddings were provided directly; ignoring raw_inputs/embedding_source parameters.",
                stacklevel=2,
            )
        self.embedding_source_ = None
        self.raw_inputs_ = None
        self.embedding_metadata_ = {"mode": "precomputed", "cached": False}

    # Validate shapes, finite values, and minimum requirements
    embeddings, labels = validate_embeddings_labels(
        embeddings, labels, min_samples=4, min_classes=0, check_finite=True
    )
    # Require at least 2 distinct classes in the effective group signal
    effective_groups = group_labels if group_labels is not None else labels
    if len(np.unique(effective_groups)) < 2:
        raise ValueError(
            "At least 2 distinct classes required in labels "
            "(or group_labels when provided), got 1."
        )
    n_samples = embeddings.shape[0]

    # Validate splits if provided
    if splits is not None:
        if not isinstance(splits, dict):
            raise TypeError("splits must be a dictionary")
        for split_name, indices in splits.items():
            if not isinstance(indices, np.ndarray):
                raise TypeError(
                    f"Split '{split_name}' must be a numpy array, got {type(indices)}"
                )
            if indices.ndim != 1:
                raise ValueError(
                    f"Split '{split_name}' must be 1D array of indices, got shape {indices.shape}"
                )
            # Empty splits are allowed; non-empty ones must index into embeddings.
            if len(indices) > 0 and (indices.min() < 0 or indices.max() >= n_samples):
                raise ValueError(
                    f"Split '{split_name}' contains invalid indices (must be in [0, {n_samples}))"
                )

    # Validate extra_labels if provided
    if extra_labels is not None:
        if not isinstance(extra_labels, dict):
            raise TypeError("extra_labels must be a dictionary")
        for label_name, label_array in extra_labels.items():
            if not isinstance(label_array, np.ndarray):
                raise TypeError(
                    f"Extra label '{label_name}' must be a numpy array, got {type(label_array)}"
                )
            if label_array.ndim != 1:
                raise ValueError(
                    f"Extra label '{label_name}' must be 1D, got shape {label_array.shape}"
                )
            if len(label_array) != n_samples:
                raise ValueError(
                    f"Extra label '{label_name}' must have same length as embeddings: {len(label_array)} != {n_samples}"
                )

    # Stash validated inputs; group_labels_ falls back to labels so
    # group-based tests always have a signal to work with.
    self.embeddings_ = embeddings
    self.labels_ = labels
    self.protected_labels_ = group_labels
    self.group_labels_ = group_labels if group_labels is not None else labels

    self.splits_ = splits
    self.extra_labels_ = extra_labels
    # Collect candidate attribute arrays (presumably derived from group_labels
    # and extra_labels — confirm against _get_attribute_sources). Multi-attribute
    # mode activates only when more than one source is available.
    attribute_sources = _get_attribute_sources(group_labels, extra_labels)
    use_multi = len(attribute_sources) > 1

    # Run each configured method. A failing detector is downgraded to a
    # warning + error record so one failure cannot abort the whole run.
    for method in self.methods:
        builder = self.detector_builders_.get(method)
        if builder is None:
            continue

        if method in _MULTI_ATTRIBUTE_METHODS and use_multi:
            # Multi-attribute path: run the method once per attribute,
            # then aggregate the per-attribute results under "by_attribute".
            by_attribute: dict[str, dict] = {}
            for attr_name, attr_array in attribute_sources.items():
                try:
                    res = builder.run(
                        embeddings=embeddings,
                        labels=labels,
                        group_labels=attr_array,
                        feature_names=feature_names,
                        protected_labels=attr_array,
                        splits=self.splits_,
                        extra_labels=self.extra_labels_,
                    )
                except Exception as exc:
                    warnings.warn(
                        f"Detection for '{method}' (attribute '{attr_name}') failed: {exc}",
                        stacklevel=2,
                    )
                    res = {"success": False, "error": str(exc)}
                apply_standardized_risk(method, res)
                by_attribute[attr_name] = res
            # Aggregate result is successful if any attribute run succeeded.
            result = {
                "success": any(r.get("success") for r in by_attribute.values()),
                "by_attribute": by_attribute,
            }
            apply_standardized_risk(method, result)
        else:
            # Single-attribute path. For multi-attribute-capable methods with
            # exactly one source, use that source as both group and protected.
            eff_group = self.group_labels_
            eff_protected = self.protected_labels_
            if method in _MULTI_ATTRIBUTE_METHODS and len(attribute_sources) == 1:
                eff_group = next(iter(attribute_sources.values()))
                eff_protected = eff_group
            try:
                result = builder.run(
                    embeddings=embeddings,
                    labels=labels,
                    group_labels=eff_group,
                    feature_names=feature_names,
                    protected_labels=eff_protected,
                    splits=self.splits_,
                    extra_labels=self.extra_labels_,
                )
            except Exception as exc:
                warnings.warn(f"Detection for '{method}' failed: {exc}", stacklevel=2)
                result = {"success": False, "error": str(exc)}
            apply_standardized_risk(method, result)

        self.results_[method] = result
        # Keep fitted detector instances for later inspection.
        # NOTE(review): self.detectors_ is assumed to exist already —
        # presumably created by _init_detectors(); confirm.
        detector_instance = result.get("detector")
        if detector_instance is not None:
            self.detectors_[method] = detector_instance
        elif "by_attribute" in result:
            for attr_name, sub in result["by_attribute"].items():
                d = sub.get("detector")
                if d is not None:
                    self.detectors_[f"{method}_{attr_name}"] = d

    print("✅ Detection complete!")
    return self

summary

summary() -> str

Generate a text summary of detection results.

Returns:

Type Description
str

Formatted summary string

Source code in shortcut_detect/unified.py
def summary(self) -> str:
    """
    Generate a text summary of detection results.

    Returns:
        Formatted summary string
    """
    if not self.results_:
        return "No results available. Call .fit() first."

    rule = "=" * 70
    divider = "-" * 70

    # Header: dataset shape and configured methods.
    parts = [
        rule,
        "UNIFIED SHORTCUT DETECTION SUMMARY",
        rule,
        f"Dataset: {len(self.embeddings_)} samples, {self.embeddings_.shape[1]} dimensions",
        f"Methods used: {', '.join(self.methods)}",
        "",
    ]

    # One titled section per method that produced a result.
    for method in self.methods:
        result = self.results_.get(method)
        if not result:
            continue

        title = result.get("summary_title") or method.replace("_", " ").title()

        if result.get("success"):
            body = result.get("summary_lines") or ["Summary unavailable."]
        else:
            # Failed or skipped: fall back to the recorded error message.
            error_msg = result.get("error", "Unknown error")
            body = result.get("summary_lines") or [f"⚠️  Skipped: {error_msg}"]

        parts += [divider, title, divider, *body, ""]

    # Overall Assessment footer.
    parts += [rule, "OVERALL ASSESSMENT", rule, self._generate_overall_assessment()]

    return "\n".join(parts)

generate_report

generate_report(
    output_path: str = None,
    format: str = "html",
    include_visualizations: bool = True,
    export_csv: bool = False,
    csv_dir: str = None,
)

Generate comprehensive report with visualizations.

Parameters:

Name Type Description Default
output_path str

Path to save report (required for HTML/PDF)

None
format str

Report format ('html' or 'pdf')

'html'
include_visualizations bool

Whether to include plots in HTML/PDF

True
export_csv bool

Whether to export results to CSV files

False
csv_dir str

Directory to save CSV files (default: same dir as output_path)

None

Raises:

Type Description
ValueError

If format is not supported or required paths are missing

Source code in shortcut_detect/unified.py
def generate_report(
    self,
    output_path: str = None,
    format: str = "html",
    include_visualizations: bool = True,
    export_csv: bool = False,
    csv_dir: str = None,
):
    """
    Generate comprehensive report with visualizations.

    Args:
        output_path: Path to save report (required for HTML/PDF)
        format: Report format ('html' or 'pdf')
        include_visualizations: Whether to include plots in HTML/PDF
        export_csv: Whether to export results to CSV files
        csv_dir: Directory to save CSV files (default: same dir as output_path)

    Raises:
        ValueError: If format is not supported or required paths are missing
    """
    import os

    from .reporting import ReportBuilder

    builder = ReportBuilder(self)

    # Generate HTML or PDF report
    if format in ["html", "pdf", "markdown"]:
        if not output_path:
            raise ValueError(f"output_path is required for {format} format")

        if format == "html":
            builder.to_html(output_path, include_visualizations=include_visualizations)
        elif format == "pdf":
            builder.to_pdf(output_path, include_visualizations=include_visualizations)
        elif format == "markdown":
            builder.to_markdown(output_path, include_visualizations=include_visualizations)
    else:
        raise ValueError(f"Unknown format: {format}. Use 'html', 'pdf', or 'markdown'")

    # Export CSV if requested
    if export_csv:
        if csv_dir is None:
            # Default to same directory as report
            if output_path:
                csv_dir = os.path.join(os.path.dirname(output_path) or ".", "csv_results")
            else:
                csv_dir = "csv_results"

        builder.to_csv(csv_dir)

get_results

get_results() -> dict[str, Any]

Get raw results dictionary.

Returns:

Type Description
dict[str, Any]

Dictionary with results from all methods

Source code in shortcut_detect/unified.py
def get_results(self) -> dict[str, Any]:
    """
    Get raw results dictionary.

    Returns:
        Dictionary with results from all methods, keyed by method name.
        Note: this is the live internal dict (not a copy), so mutating the
        returned object mutates the detector's stored results.
    """
    return self.results_

Quick Reference

Constructor

ShortcutDetector(
    methods: list[str] = ['hbac', 'probe', 'statistical'],
    seed: int = 42,
    condition_name: str = 'indicator_count',
    condition_kwargs: dict | None = None,
    **kwargs
)

Parameters

Parameter Type Default Description
methods list[str] ['hbac', 'probe', 'statistical'] Methods to run
seed int 42 Random seed for reproducibility
condition_name str 'indicator_count' Overall assessment condition to use
condition_kwargs dict None Keyword args passed to the selected condition
**kwargs dict - Additional detector-specific configuration

Available Methods

Method Key Description
'hbac' Hierarchical Bias-Aware Clustering
'probe' Probe-based classifier detection
'statistical' Feature-wise statistical testing
'geometric' Geometric subspace analysis
'cav' Concept Activation Vector testing (loader mode)

Methods

fit()

def fit(
    embeddings: np.ndarray | None,
    labels: np.ndarray,
    group_labels: np.ndarray | None = None,
    feature_names: list[str] | None = None,
    raw_inputs: Sequence[Any] | None = None,
    embedding_source: EmbeddingSource | None = None,
    embedding_cache_path: str | None = None,
    force_embedding_recompute: bool = False,
    splits: dict[str, np.ndarray] | None = None,
    extra_labels: dict[str, np.ndarray] | None = None,
) -> ShortcutDetector

Fit all detection methods on the data.

Parameters:

Parameter Type Description
embeddings ndarray | None Shape (n_samples, n_features). Pass None for embedding-only mode.
labels ndarray Target labels, shape (n_samples,)
group_labels ndarray | None Group labels (e.g., protected attributes); defaults to labels when None
feature_names list[str] | None Optional feature names
raw_inputs Sequence | None Raw inputs for embedding generation
embedding_source EmbeddingSource | None Embedding generator for embedding-only mode
embedding_cache_path str | None Path to cache generated embeddings

Returns: self

summary()

def summary() -> str

Get a human-readable summary of all detection results.

Returns: Multi-line string with risk assessment and key metrics.

generate_report()

def generate_report(
    output_path: str,
    format: str = 'html',
    include_visualizations: bool = True
) -> None

Generate a comprehensive report.

Parameters:

Parameter Type Default Description
output_path str required Path for output file
format str 'html' Output format: 'html', 'pdf', or 'markdown'
include_visualizations bool True Include plots in report

get_results()

def get_results() -> dict

Get raw results from all methods.

Returns: Dictionary with results from each method.

Attributes (after fit)

Attribute Type Description
results_ dict Results from all detection methods
detectors_ dict Fitted detector instances for each method
embeddings_ ndarray Stored embeddings used for fitting
labels_ ndarray Stored labels used for fitting
group_labels_ ndarray Stored group labels (if provided)

Usage Examples

Basic Usage

from shortcut_detect import ShortcutDetector
import numpy as np

# Load data
embeddings = np.load("embeddings.npy")
labels = np.load("labels.npy")
group_labels = np.load("groups.npy")

# Create detector with all methods
detector = ShortcutDetector(
    methods=['hbac', 'probe', 'statistical', 'geometric']
)

# Fit
detector.fit(embeddings, labels, group_labels=group_labels)

# Results
print(detector.summary())
results = detector.get_results()
print(f"Probe accuracy: {results['probe']['accuracy']}")

Custom Parameters

detector = ShortcutDetector(
    methods=['hbac', 'probe', 'statistical'],
    seed=42,
    n_bootstraps=1000,
    probe_backend='sklearn'
)

Custom Overall Assessment

The condition_name parameter selects how method-level results are aggregated into the final risk summary. Five built-in conditions are available:

Condition Description Key Parameters
indicator_count Counts total risk indicators across methods (default) -
majority_vote Counts methods with at least one indicator as votes high_threshold (int)
weighted_risk Weights each detector by evidence strength (probe accuracy, stat significance ratio, HBAC confidence, geometric effect size) high_threshold (float), moderate_threshold (float)
multi_attribute Cross-references risk across sensitive attributes; escalates when multiple attributes independently flag shortcuts high_threshold (int)
meta_classifier Trained sklearn meta-classifier on detector features, or heuristic fallback model_path (str), high_threshold (float), moderate_threshold (float)
# Majority vote
detector = ShortcutDetector(
    methods=['probe', 'statistical'],
    condition_name='majority_vote',
    condition_kwargs={'high_threshold': 2},
)
# Weighted risk scoring
detector = ShortcutDetector(
    methods=['hbac', 'probe', 'statistical', 'geometric'],
    condition_name='weighted_risk',
    condition_kwargs={'high_threshold': 0.6, 'moderate_threshold': 0.3},
)
# Multi-attribute intersection
detector = ShortcutDetector(
    methods=['probe', 'statistical'],
    condition_name='multi_attribute',
    condition_kwargs={'high_threshold': 2},
)
# Meta-classifier (heuristic fallback when no trained model provided)
detector = ShortcutDetector(
    methods=['probe', 'statistical', 'hbac'],
    condition_name='meta_classifier',
)

# Meta-classifier with a trained model
detector = ShortcutDetector(
    methods=['probe', 'statistical', 'hbac'],
    condition_name='meta_classifier',
    condition_kwargs={'model_path': 'path/to/meta_model.joblib'},
)

The meta_classifier condition also exposes MetaClassifierCondition.extract_features(ctx) for building training datasets from synthetic benchmark runs.

The default indicator_count condition preserves the library's existing summary semantics.

Embedding-Only Mode

from shortcut_detect import ShortcutDetector, HuggingFaceEmbeddingSource

# Create embedding source
hf_source = HuggingFaceEmbeddingSource(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)

# Detect shortcuts from raw inputs
detector = ShortcutDetector(methods=['probe', 'statistical'])
detector.fit(
    embeddings=None,  # Triggers embedding-only mode
    labels=labels,
    group_labels=groups,
    raw_inputs=texts,
    embedding_source=hf_source,
    embedding_cache_path="embeddings.npy"
)

Report Generation

# HTML report (recommended)
detector.generate_report(
    output_path="report.html",
    format="html",
    include_visualizations=True
)

# PDF report
detector.generate_report(
    output_path="report.pdf",
    format="pdf"
)

Accessing Raw Results

# Get all results
results = detector.get_results()

# Access specific method results
hbac = results['hbac']
print(f"HBAC purity: {hbac['purity']:.2f}")

probe = results['probe']
print(f"Probe accuracy: {probe['accuracy']:.2%}")

statistical = results['statistical']
print(f"Significant features: {statistical['n_significant']}")

geometric = results['geometric']
print(f"Effect size: {geometric['effect_size']:.2f}")

Loader-based CAV Usage

detector = ShortcutDetector(methods=["cav"])
detector.fit_from_loaders({
    "cav": {
        "concept_sets": {"shortcut": concept_arr},
        "random_set": random_arr,
        "target_directional_derivatives": directional_derivatives,  # optional but needed for TCAV risk
    }
})
print(detector.get_results()["cav"]["metrics"])

Error Handling

try:
    detector.fit(embeddings, group_labels)
except ValueError as e:
    print(f"Invalid input: {e}")
except RuntimeError as e:
    print(f"Detection failed: {e}")

See Also