
Demographic Parity Detector API

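The DemographicParityDetector trains a classifier on embeddings, compares the rate of positive predictions across demographic groups, and reports the demographic parity gap. A large gap suggests that predictions depend on a protected attribute, which the detector treats as possible shortcut reliance.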
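As a rough illustration of the quantity involved, the sketch below computes per-group positive prediction rates and a gap from a vector of binary predictions. It assumes the gap is the highest group positive rate minus the lowest, a common definition of the demographic parity difference. The detector's own `_compute_gap` is not shown on this page and may differ; the function name `demographic_parity_gap` is hypothetical.

```python
import numpy as np


def demographic_parity_gap(preds, group_labels, positive_label=1):
    """Hypothetical helper: per-group positive rates and the max-min gap."""
    preds = np.asarray(preds)
    group_labels = np.asarray(group_labels)
    rates = {}
    for group in np.unique(group_labels):
        mask = group_labels == group
        # Fraction of this group's predictions equal to the positive label
        rates[str(group)] = float(np.mean(preds[mask] == positive_label))
    # Assumed definition: largest rate minus smallest rate across groups
    gap = max(rates.values()) - min(rates.values())
    return rates, gap


preds = np.array([1, 1, 1, 0, 1, 0, 0, 0])
groups = np.array(["A", "A", "A", "A", "B", "B", "B", "B"])
rates, gap = demographic_parity_gap(preds, groups)
print(rates, gap)  # {'A': 0.75, 'B': 0.25} 0.5
```

A gap of 0 means every group receives positive predictions at the same rate; the larger the gap, the more the predictions differ by group.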

Class Reference

DemographicParityDetector

DemographicParityDetector(
    estimator: LogisticRegression | None = None,
    min_group_size: int = 10,
    dp_gap_threshold: float = 0.1,
)

Bases: DetectorBase

Compute demographic parity gap across demographic groups.

Source code in shortcut_detect/fairness/demographic_parity/src/detector.py
def __init__(
    self,
    estimator: LogisticRegression | None = None,
    min_group_size: int = 10,
    dp_gap_threshold: float = 0.1,
) -> None:
    super().__init__(method="demographic_parity")

    self.estimator = estimator or LogisticRegression(max_iter=1000)
    self.min_group_size = min_group_size
    self.dp_gap_threshold = dp_gap_threshold

    self.group_rates_: dict[str, dict[str, float]] = {}
    self.dp_gap_: float = float("nan")
    self.overall_positive_rate_: float = float("nan")
    self.report_: DemographicParityReport | None = None

Functions

fit

fit(
    embeddings: ndarray,
    labels: ndarray,
    group_labels: ndarray,
) -> DemographicParityDetector

Train classifier and compute demographic parity gap.

Source code in shortcut_detect/fairness/demographic_parity/src/detector.py
def fit(
    self,
    embeddings: np.ndarray,
    labels: np.ndarray,
    group_labels: np.ndarray,
) -> DemographicParityDetector:
    """Train classifier and compute demographic parity gap."""
    if group_labels is None:
        raise ValueError("DemographicParityDetector requires group_labels.")

    if embeddings.ndim != 2:
        raise ValueError("Embeddings must be 2D (n_samples, embedding_dim).")
    if labels.ndim != 1:
        raise ValueError("Labels must be 1D.")
    if group_labels.ndim != 1:
        raise ValueError("group_labels must be 1D.")

    if embeddings.shape[0] != labels.shape[0] or labels.shape[0] != group_labels.shape[0]:
        raise ValueError("Embeddings, labels, and group_labels must align.")

    unique_labels = np.unique(labels)
    if unique_labels.size != 2:
        raise ValueError("Demographic parity requires binary labels.")

    self.estimator.fit(embeddings, labels)
    preds = self.estimator.predict(embeddings)
    positive_label = self.estimator.classes_[1]
    self.overall_positive_rate_ = float(np.mean(preds == positive_label))

    self.group_rates_ = self._compute_group_rates(preds, group_labels, positive_label)
    self.dp_gap_ = self._compute_gap()

    risk_level, notes = self._assess_risk()
    self.report_ = DemographicParityReport(
        group_rates=self.group_rates_,
        dp_gap=self.dp_gap_,
        overall_positive_rate=self.overall_positive_rate_,
        reference="Feldman et al. 2015",
        risk_level=risk_level,
        notes=notes,
    )
    self._finalize_results()
    self._is_fitted = True
    return self
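`fit` counts `self.estimator.classes_[1]` as the positive label. For scikit-learn classifiers such as `LogisticRegression`, `classes_` holds the sorted unique training labels, so the positive label is whichever of the two values sorts last. The standalone sketch below reproduces that rule with `np.unique` (which also returns sorted values), so you can check in advance which label will be counted as positive. The helper name `positive_label_for` is hypothetical and not part of the package.

```python
import numpy as np


def positive_label_for(labels):
    """Hypothetical helper: mirror estimator.classes_[1] for binary labels."""
    classes = np.unique(labels)  # sorted unique values, like classes_
    if classes.size != 2:
        raise ValueError("Demographic parity requires binary labels.")
    return classes[1]


print(positive_label_for([0, 1, 1, 0]))            # 1
print(positive_label_for(["yes", "no", "no"]))     # yes
print(positive_label_for(["accept", "reject"]))    # reject
```

The last case shows why this matters: with string labels such as `"accept"` and `"reject"`, the alphabetically later value is treated as positive, so the reported rates describe `"reject"` predictions.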

Quick Reference

Constructor

DemographicParityDetector(
    estimator: LogisticRegression | None = None,
    min_group_size: int = 10,
    dp_gap_threshold: float = 0.1,
)

Parameters

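| Parameter        | Type                      | Default | Description |
|------------------|---------------------------|---------|-------------|
| estimator        | LogisticRegression \| None | None    | scikit-learn classifier to train on the embeddings. If None, LogisticRegression(max_iter=1000) is used. |
| min_group_size   | int                       | 10      | Minimum group size; groups with fewer samples get NaN rates. |
| dp_gap_threshold | float                     | 0.1     | Demographic parity gap threshold used to flag a shortcut. |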
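The sketch below illustrates how the two tuning parameters could interact. It follows the documented rule that groups with fewer than `min_group_size` samples get a NaN rate, and adds two assumptions that this page does not confirm: NaN groups are skipped when taking the gap, and a shortcut is flagged when the gap is at least `dp_gap_threshold`. The function name and the inner keys `"positive_rate"` and `"support"` are hypothetical.

```python
import math

import numpy as np


def gap_with_min_size(preds, group_labels, min_group_size=10, threshold=0.1,
                      positive_label=1):
    """Sketch: NaN rates for small groups, gap over the rest, threshold flag."""
    preds = np.asarray(preds)
    group_labels = np.asarray(group_labels)
    rates = {}
    for group in np.unique(group_labels):
        mask = group_labels == group
        support = int(mask.sum())
        if support >= min_group_size:
            rate = float(np.mean(preds[mask] == positive_label))
        else:
            rate = math.nan  # documented: small groups get NaN rates
        rates[str(group)] = {"positive_rate": rate, "support": support}
    valid = [r["positive_rate"] for r in rates.values()
             if not math.isnan(r["positive_rate"])]
    # Assumption: NaN groups are ignored; fewer than two valid groups -> NaN gap
    gap = max(valid) - min(valid) if len(valid) >= 2 else math.nan
    # Assumption: flag when gap >= threshold; a NaN gap never flags
    flagged = (not math.isnan(gap)) and gap >= threshold
    return rates, gap, flagged


# Groups A and B have 20 samples each; group C has only 4
preds = np.array([1] * 15 + [0] * 5 + [1] * 5 + [0] * 15 + [1] * 4)
groups = np.array(["A"] * 20 + ["B"] * 20 + ["C"] * 4)
rates, gap, flagged = gap_with_min_size(preds, groups)
print(gap, flagged)  # 0.5 True
print(rates["C"])    # {'positive_rate': nan, 'support': 4}
```

Group C predicts positive every time, but because it falls below `min_group_size` it does not widen the gap to 0.75. Raising `min_group_size` therefore trades coverage of small groups for more stable rate estimates.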

Methods

fit()

def fit(
    embeddings: np.ndarray,
    labels: np.ndarray,
    group_labels: np.ndarray,
) -> DemographicParityDetector

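Fit the estimator on embeddings and labels, predict on the same embeddings, and compute per-group positive prediction rates and the demographic parity gap. Because predictions are made on the training data, the rates reflect in-sample behavior.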

Parameters:

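| Parameter    | Type    | Description |
|--------------|---------|-------------|
| embeddings   | ndarray | 2D array of shape (n_samples, n_features) |
| labels       | ndarray | 1D array of shape (n_samples,) with binary labels |
| group_labels | ndarray | 1D array of shape (n_samples,) with demographic group labels |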

Returns: self (the fitted detector), so calls can be chained.

Raises:

  • ValueError if group_labels is None
  • ValueError if embeddings is not 2D, labels is not 1D, or group_labels is not 1D
  • ValueError if embeddings, labels, and group_labels do not have the same number of samples
  • ValueError if labels do not contain exactly two unique values
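To catch these errors before training, a standalone pre-check can mirror the checks in `fit`. The function below is a hypothetical helper, not part of the package; it applies the same conditions in the same order as the `fit` source and raises the same messages. Unlike `fit`, it converts inputs with `np.asarray`, so it also accepts Python lists; `fit` calls `.ndim` directly and therefore expects NumPy arrays.

```python
import numpy as np


def check_dp_inputs(embeddings, labels, group_labels):
    """Hypothetical pre-check mirroring the ValueError checks in fit()."""
    if group_labels is None:
        raise ValueError("DemographicParityDetector requires group_labels.")
    embeddings = np.asarray(embeddings)
    labels = np.asarray(labels)
    group_labels = np.asarray(group_labels)
    if embeddings.ndim != 2:
        raise ValueError("Embeddings must be 2D (n_samples, embedding_dim).")
    if labels.ndim != 1:
        raise ValueError("Labels must be 1D.")
    if group_labels.ndim != 1:
        raise ValueError("group_labels must be 1D.")
    if not (embeddings.shape[0] == labels.shape[0] == group_labels.shape[0]):
        raise ValueError("Embeddings, labels, and group_labels must align.")
    if np.unique(labels).size != 2:
        raise ValueError("Demographic parity requires binary labels.")


rng = np.random.default_rng(0)
X = rng.normal(size=(6, 3))
try:
    # Three distinct label values, so the binary-label check fails
    check_dp_inputs(X, np.array([0, 1, 2, 0, 1, 2]), np.array(list("AABBCC")))
except ValueError as err:
    print(err)  # Demographic parity requires binary labels.
```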

get_report()

def get_report() -> dict[str, Any]

Return the detection report after fitting. Inherited from DetectorBase.

Attributes (after fit)

| Attribute              | Type                              | Description |
|------------------------|-----------------------------------|-------------|
| group_rates_           | dict[str, dict[str, float]]       | Per-group positive rate and support |
| dp_gap_                | float                             | Computed demographic parity gap |
| overall_positive_rate_ | float                             | Overall positive prediction rate |
| report_                | DemographicParityReport           | Detailed report dataclass |
| shortcut_detected_     | bool or None                      | Whether a shortcut was detected |
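`group_rates_` maps each group to its positive rate and support, but the inner key names are not documented on this page. The sketch below assumes the hypothetical keys `"positive_rate"` and `"support"` and ranks groups from least to most often predicted positive, skipping NaN entries for groups below `min_group_size`. The `example` dict is hand-written for illustration, not output from the detector; adjust the keys to whatever your installed version stores.

```python
import math


def rank_groups(group_rates, rate_key="positive_rate", support_key="support"):
    """Sort groups by positive rate, dropping NaN (too-small) groups.

    rate_key and support_key are assumed names, not confirmed by the docs.
    """
    rows = [
        (name, stats[rate_key], int(stats[support_key]))
        for name, stats in group_rates.items()
        if not math.isnan(stats[rate_key])
    ]
    return sorted(rows, key=lambda row: row[1])


example = {
    "A": {"positive_rate": 0.75, "support": 20.0},
    "B": {"positive_rate": 0.25, "support": 20.0},
    "C": {"positive_rate": math.nan, "support": 4.0},
}
for name, rate, support in rank_groups(example):
    print(f"{name}: rate={rate:.2f} (n={support})")
# B: rate=0.25 (n=20)
# A: rate=0.75 (n=20)
```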

Usage Examples

Basic Usage

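import numpy as np
from shortcut_detect.fairness import DemographicParityDetector

# Illustrative random inputs; replace with your own embeddings and labels
rng = np.random.default_rng(0)
embeddings = rng.normal(size=(200, 16))          # (n_samples, n_features)
labels = rng.integers(0, 2, size=200)            # binary labels
group_labels = rng.choice(["A", "B"], size=200)  # demographic group per sample

detector = DemographicParityDetector()
detector.fit(embeddings, labels, group_labels=group_labels)
report = detector.get_report()
print(report["shortcut_detected"])
print(report["metrics"]["dp_gap"])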

Custom Threshold

detector = DemographicParityDetector(
    dp_gap_threshold=0.05,
    min_group_size=20,
)
detector.fit(embeddings, labels, group_labels=group_labels)

Via Unified ShortcutDetector

from shortcut_detect import ShortcutDetector

detector = ShortcutDetector(methods=["demographic_parity"])
detector.fit(embeddings, labels, group_labels=group_labels)
print(detector.summary())

See Also