Skip to content

Embedding Sources API

Embedding sources generate embeddings from raw inputs, enabling shortcut detection when only inference access to a model is available — no gradient or weight access required.

Class Reference

EmbeddingSource

EmbeddingSource

EmbeddingSource(name: str = 'embedding_source')

Bases: ABC

Abstract base class for embedding generators.

Source code in shortcut_detect/embedding_sources.py
def __init__(self, name: str = "embedding_source"):
    self.name = name

Functions

generate abstractmethod

generate(inputs: Sequence[Any]) -> np.ndarray

Generate embeddings for a sequence of inputs.

Parameters:

Name Type Description Default
inputs Sequence[Any]

Sequence of raw inputs (text, images, etc.)

required

Returns:

Type Description
ndarray

np.ndarray of shape (n_samples, embedding_dim)

Source code in shortcut_detect/embedding_sources.py
@abstractmethod
def generate(self, inputs: Sequence[Any]) -> np.ndarray:
    """
    Generate embeddings for a sequence of inputs.

    Args:
        inputs: Sequence of raw inputs (text, images, etc.)

    Returns:
        np.ndarray of shape (n_samples, embedding_dim)
    """

HuggingFaceEmbeddingSource

HuggingFaceEmbeddingSource

HuggingFaceEmbeddingSource(
    model_name: str,
    tokenizer_name: str | None = None,
    device: str | None = None,
    batch_size: int = 16,
    pooling: str = "cls",
    normalize: bool = True,
)

Bases: EmbeddingSource

Generate embeddings from any Hugging Face transformer model without requiring gradient access.

Parameters:

Name Type Description Default
model_name str

Name or path of the Hugging Face model to load.

required
tokenizer_name str | None

Optional tokenizer name (defaults to model_name).

None
device str | None

Device string ("cpu", "cuda", etc.). Auto-detect if None.

None
batch_size int

Batch size for inference.

16
pooling str

"cls" or "mean" pooling strategy.

'cls'
normalize bool

Whether to L2-normalize embeddings.

True
Source code in shortcut_detect/embedding_sources.py
def __init__(
    self,
    model_name: str,
    tokenizer_name: str | None = None,
    device: str | None = None,
    batch_size: int = 16,
    pooling: str = "cls",
    normalize: bool = True,
):
    """
    Args:
        model_name: Name or path of the Hugging Face model to load.
        tokenizer_name: Optional tokenizer name (defaults to model_name).
        device: Device string ("cpu", "cuda", etc.). Auto-detect if None.
        batch_size: Batch size for inference.
        pooling: "cls" or "mean" pooling strategy.
        normalize: Whether to L2-normalize embeddings.
    """
    super().__init__(name=f"huggingface:{model_name}")
    torch = _require_torch()
    self.model_name = model_name
    self.tokenizer_name = tokenizer_name or model_name
    self.batch_size = batch_size
    self.pooling = pooling
    self.normalize = normalize
    self.device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
    self._model = None
    self._tokenizer = None

Functions

CallableEmbeddingSource

CallableEmbeddingSource

CallableEmbeddingSource(
    fn: Callable[[Sequence[Any]], ndarray],
    name: str = "callable_source",
)

Bases: EmbeddingSource

Wrap an arbitrary callable so it can be used as an embedding source.

The callable should accept a sequence of inputs and return a 2D numpy array. This is useful for production or closed-source APIs where only inference access is available.

Source code in shortcut_detect/embedding_sources.py
def __init__(self, fn: Callable[[Sequence[Any]], np.ndarray], name: str = "callable_source"):
    super().__init__(name=name)
    self.fn = fn

EmbeddingSource (Base)

Abstract base class for embedding generators.

from shortcut_detect.embedding_sources import EmbeddingSource

class MyEmbeddingSource(EmbeddingSource):
    def __init__(self):
        super().__init__(name="my_source")

    def generate(self, inputs: Sequence[Any]) -> np.ndarray:
        # Return embeddings for inputs
        return embeddings

HuggingFaceEmbeddingSource

Generate embeddings using HuggingFace transformers.

Constructor

HuggingFaceEmbeddingSource(
    model_name: str,
    tokenizer_name: str | None = None,
    device: str | None = None,
    batch_size: int = 16,
    pooling: str = "cls",
    normalize: bool = True,
)

Parameters

Parameter Type Default Description
model_name str required HuggingFace model name or path
tokenizer_name str | None None Tokenizer name (defaults to model_name)
device str | None None Device string ("cpu", "cuda", etc.); auto-detected if None
batch_size int 16 Batch size for inference
pooling str 'cls' Pooling strategy ('cls' or 'mean')
normalize bool True Whether to L2-normalize embeddings

Pooling Options

Value Description
'cls' [CLS] token embedding
'mean' Mean of all token embeddings

Usage

from shortcut_detect import HuggingFaceEmbeddingSource

source = HuggingFaceEmbeddingSource(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    pooling="mean",
    batch_size=64
)

texts = ["Sample text 1", "Sample text 2", ...]
embeddings = source.generate(texts)
print(embeddings.shape)  # (n_samples, 384)

With ShortcutDetector

from shortcut_detect import ShortcutDetector, HuggingFaceEmbeddingSource

source = HuggingFaceEmbeddingSource("bert-base-uncased")

detector = ShortcutDetector(methods=['probe', 'statistical'])
detector.fit(
    embeddings=None,
    group_labels=groups,
    raw_inputs=texts,
    embedding_source=source,
    embedding_cache_path="cached_embeddings.npy"
)

CallableEmbeddingSource

Wrap any function as an embedding source.

Constructor

CallableEmbeddingSource(
    fn: Callable[[Sequence[Any]], np.ndarray],
    name: str = "callable_source"
)

Parameters

Parameter Type Description
fn callable Function taking a sequence of inputs, returning a 2D ndarray
name str Name for logging (default "callable_source")

Usage

from shortcut_detect import CallableEmbeddingSource
import numpy as np

# Wrap external API
def my_embedding_api(texts):
    # Call your API
    response = external_client.embed(texts)
    return np.array(response["embeddings"])

source = CallableEmbeddingSource(
    fn=my_embedding_api,
    name="my_api"
)

embeddings = source.generate(["text1", "text2"])

With Batching

def batched_api(texts, batch_size=32):
    all_embeddings = []
    for i in range(0, len(texts), batch_size):
        batch = texts[i:i+batch_size]
        embeddings = external_api.embed(batch)
        all_embeddings.append(embeddings)
    return np.vstack(all_embeddings)

source = CallableEmbeddingSource(
    fn=lambda x: batched_api(x, batch_size=64),
    name="batched_api"
)

Caching

All embedding sources support caching:

detector.fit(
    embeddings=None,
    group_labels=groups,
    raw_inputs=texts,
    embedding_source=source,
    embedding_cache_path="embeddings.npy"  # Cache here
)

# Second run loads from cache
detector2.fit(
    embeddings=None,
    group_labels=groups,
    raw_inputs=texts,
    embedding_source=source,
    embedding_cache_path="embeddings.npy"  # Loaded from cache
)

See Also