
feat: BERT encoder inference for cross-encoder reranking (.apr) #326

@noahgift


Summary

Verify and implement end-to-end BERT encoder inference in aprender, enabling cross-encoder reranking models (e.g., BAAI/bge-reranker-base) to run as .apr models via trueno SIMD. This is the sovereign-stack alternative to ONNX Runtime / fastembed for neural reranking.

Motivation

trueno-rag's RAG pipeline achieves MRR 0.952 with semantic hybrid retrieval (BGE-small dense embeddings + BM25, fused with Reciprocal Rank Fusion). Cross-encoder reranking is the standard next step to push MRR toward 0.97+. Lexical reranking was tested and rejected: it regressed MRR to 0.876 because term overlap disrupts the semantic ordering.

Rather than adding ort (ONNX Runtime) as a dependency, the sovereign approach is:

  1. Convert cross-encoder ONNX/SafeTensors → .apr via apr import
  2. Run BERT inference natively via aprender/trueno SIMD
  3. Same pattern as whisper-apr (whisper model → .apr → pure Rust inference)

Design

Model Architecture: BERT Cross-Encoder

Cross-encoders are BERT-family encoders with a single-output classification (regression) head that scores the query and passage jointly:

Input:  [CLS] query_tokens [SEP] passage_tokens [SEP]
         ↓
      BERT Encoder (base: 12 layers, 768d, 12 heads; MiniLM-L-6: 6 layers, 384d)
         ↓
      Pooling ([CLS] token embedding → dense + tanh)
         ↓
      Linear head (hidden_dim → 1) → sigmoid → relevance score in (0, 1)
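A minimal sketch of how the [CLS] query [SEP] passage [SEP] input and its segment (token_type) ids could be assembled, assuming the tokenizer has already produced WordPiece ids for each side. PairInput, build_pair_input, and the passage-first truncation policy are hypothetical; the special-token ids are parameters because they differ between vocabularies.

```rust
/// Token ids, segment (token_type) ids, and position ids for one query/passage pair.
#[derive(Debug, PartialEq)]
pub struct PairInput {
    pub input_ids: Vec<u32>,
    pub token_type_ids: Vec<u32>,
    pub position_ids: Vec<u32>,
}

/// Builds `[CLS] query [SEP] passage [SEP]`, truncating the passage first so the
/// sequence fits in `max_len` (hypothetical policy; assumes max_len >= 3).
pub fn build_pair_input(cls: u32, sep: u32, query: &[u32], passage: &[u32], max_len: usize) -> PairInput {
    // Reserve room for [CLS] and the two [SEP] tokens.
    let budget = max_len.saturating_sub(3);
    let q_len = query.len().min(budget);
    let p_len = passage.len().min(budget - q_len);

    let mut input_ids = Vec::with_capacity(q_len + p_len + 3);
    let mut token_type_ids = Vec::with_capacity(q_len + p_len + 3);

    // Segment A: [CLS] query [SEP] -> token_type 0
    input_ids.push(cls);
    input_ids.extend_from_slice(&query[..q_len]);
    input_ids.push(sep);
    token_type_ids.resize(input_ids.len(), 0);

    // Segment B: passage [SEP] -> token_type 1
    input_ids.extend_from_slice(&passage[..p_len]);
    input_ids.push(sep);
    token_type_ids.resize(input_ids.len(), 1);

    // BERT uses absolute positions 0..seq_len.
    let position_ids = (0..input_ids.len() as u32).collect();
    PairInput { input_ids, token_type_ids, position_ids }
}
```

For the uncased BERT vocabulary used by ms-marco-MiniLM-L-6-v2, [CLS] and [SEP] are ids 101 and 102.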

Target Models

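Model | Architecture | Params | Hidden dim | Source
--- | --- | --- | --- | ---
BAAI/bge-reranker-base | XLM-RoBERTa base | 278M | 768 | HuggingFace SafeTensors
cross-encoder/ms-marco-MiniLM-L-6-v2 | BERT (MiniLM, 6 layers) | 22M | 384 | HuggingFace SafeTensors

Note: bge-reranker-base is an XLM-RoBERTa checkpoint, so its tensors use the roberta.* prefix, its tokenizer is SentencePiece with <s> q </s></s> p </s> pair formatting, and its head is classifier.dense → tanh → classifier.out_proj. The BERT mapping below covers ms-marco-MiniLM-L-6-v2 directly; bge-reranker-base needs an equivalent mapping.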
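The encoder dimensions of the two target models could be captured as config presets, as sketched below. The field names follow the BertConfig comment in the forward-pass section; the preset values are as we understand the published config.json files and should be verified against them. The constant names and validate() are hypothetical.

```rust
/// Encoder hyperparameters (fields mirror the BertConfig comment in this issue).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BertConfig {
    pub n_layers: usize,
    pub n_heads: usize,
    pub hidden_dim: usize,
    pub intermediate_dim: usize,
}

impl BertConfig {
    /// Per-head dimension: each head attends over hidden_dim / n_heads features.
    pub fn head_dim(&self) -> usize {
        self.hidden_dim / self.n_heads
    }

    /// Rejects configs where hidden_dim cannot be split evenly across heads.
    pub fn validate(&self) -> Result<(), String> {
        if self.n_heads == 0 || self.hidden_dim % self.n_heads != 0 {
            return Err(format!(
                "hidden_dim {} not divisible by n_heads {}",
                self.hidden_dim, self.n_heads
            ));
        }
        Ok(())
    }
}

// Encoder dimensions only; verify against each model's config.json before relying on them.
pub const BGE_RERANKER_BASE: BertConfig =
    BertConfig { n_layers: 12, n_heads: 12, hidden_dim: 768, intermediate_dim: 3072 };
pub const MS_MARCO_MINILM_L6: BertConfig =
    BertConfig { n_layers: 6, n_heads: 12, hidden_dim: 384, intermediate_dim: 1536 };
```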

Required Components

1. BERT Tensor Name Mapping (Architecture::Bert)

Verify map_name() handles all BERT tensors:

  • bert.embeddings.word_embeddings.weight
  • bert.embeddings.position_embeddings.weight
  • bert.embeddings.token_type_embeddings.weight
  • bert.embeddings.LayerNorm.{weight,bias}
  • bert.encoder.layer.{N}.attention.self.{query,key,value}.{weight,bias}
  • bert.encoder.layer.{N}.attention.output.dense.{weight,bias}
  • bert.encoder.layer.{N}.attention.output.LayerNorm.{weight,bias}
  • bert.encoder.layer.{N}.intermediate.dense.{weight,bias}
  • bert.encoder.layer.{N}.output.dense.{weight,bias}
  • bert.encoder.layer.{N}.output.LayerNorm.{weight,bias}
  • bert.pooler.dense.{weight,bias} (pooler: dense + tanh over [CLS])
  • classifier.{weight,bias} (regression head)
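A minimal sketch of the kind of name classification map_name() needs for the tensors listed above. BertTensor and classify are hypothetical and are not aprender's map_name(); unknown names return None so an importer can reject them instead of silently dropping weights.

```rust
/// Where a tensor lives in a HuggingFace BERT checkpoint (hypothetical enum).
#[derive(Debug, PartialEq)]
pub enum BertTensor {
    Embedding(String),                    // e.g. "word_embeddings.weight"
    Layer { index: usize, rest: String }, // e.g. (3, "attention.self.query.weight")
    Pooler(String),                       // "dense.weight" / "dense.bias"
    Classifier(String),                   // "weight" / "bias"
}

/// Classifies a BERT tensor name; returns None for anything unrecognized.
pub fn classify(name: &str) -> Option<BertTensor> {
    if let Some(rest) = name.strip_prefix("classifier.") {
        return Some(BertTensor::Classifier(rest.to_string()));
    }
    let rest = name.strip_prefix("bert.")?;
    if let Some(r) = rest.strip_prefix("embeddings.") {
        return Some(BertTensor::Embedding(r.to_string()));
    }
    if let Some(r) = rest.strip_prefix("pooler.") {
        return Some(BertTensor::Pooler(r.to_string()));
    }
    // "encoder.layer.{N}.<rest>": split off the layer index.
    let r = rest.strip_prefix("encoder.layer.")?;
    let (idx, tail) = r.split_once('.')?;
    let index = idx.parse::<usize>().ok()?;
    Some(BertTensor::Layer { index, rest: tail.to_string() })
}
```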

2. BERT Encoder Forward Pass

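pub struct BertEncoder {
    config: BertConfig,  // n_layers, n_heads, hidden_dim, intermediate_dim
    tensors: AprTensorStore,
}

impl BertEncoder {
    /// Forward pass: token_ids + type_ids + position_ids → hidden states,
    /// flattened row-major as [seq_len × hidden_dim] (row 0 is the [CLS] token).
    pub fn forward(&self, input: &BertInput) -> Vec<f32> {
        // 1. Embedding lookup: sum of word + position + token_type embeddings
        // 2. Embedding LayerNorm (dropout is a no-op at inference)
        // 3. For each encoder layer (post-LN, as in BERT):
        //    a. Multi-head self-attention (trueno matmul + softmax), then output dense
        //    b. Residual + LayerNorm
        //    c. FFN: Linear(hidden_dim→intermediate_dim) → GELU → Linear(intermediate_dim→hidden_dim)
        //       (768→3072→768 for base-size models)
        //    d. Residual + LayerNorm
        // 4. Return all hidden states; callers slice out the [CLS] row as needed
        todo!("BERT encoder forward pass")
    }
}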
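Steps 3a and 3b above could look roughly like the scalar sketch below, before any trueno SIMD kernels are swapped in. It assumes Q, K, and V have already been projected with the query/key/value weights and are stored row-major as [seq_len × hidden_dim]; it handles a single unpadded sequence, so batched inputs would additionally need an attention mask. layer_norm and multi_head_attention are hypothetical helpers.

```rust
/// LayerNorm over each row of a row-major [rows × dim] matrix (BERT uses eps = 1e-12).
pub fn layer_norm(x: &mut [f32], dim: usize, gamma: &[f32], beta: &[f32], eps: f32) {
    for row in x.chunks_mut(dim) {
        let mean = row.iter().sum::<f32>() / dim as f32;
        let var = row.iter().map(|&v| (v - mean).powi(2)).sum::<f32>() / dim as f32;
        let inv_std = 1.0 / (var + eps).sqrt();
        for (i, v) in row.iter_mut().enumerate() {
            *v = (*v - mean) * inv_std * gamma[i] + beta[i];
        }
    }
}

/// Scaled dot-product self-attention for all heads. Returns [seq_len × hidden_dim];
/// the attention.output.dense projection is applied afterwards.
pub fn multi_head_attention(
    q: &[f32], k: &[f32], v: &[f32], seq_len: usize, n_heads: usize, head_dim: usize,
) -> Vec<f32> {
    let hidden = n_heads * head_dim;
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut out = vec![0.0f32; seq_len * hidden];
    let mut scores = vec![0.0f32; seq_len];
    for h in 0..n_heads {
        let off = h * head_dim; // this head's slice of each row
        for i in 0..seq_len {
            let qi = &q[i * hidden + off..i * hidden + off + head_dim];
            // Scores of query position i against every key position j.
            for j in 0..seq_len {
                let kj = &k[j * hidden + off..j * hidden + off + head_dim];
                scores[j] = qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>() * scale;
            }
            // Numerically stable softmax (subtract the row max before exp).
            let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let mut sum = 0.0f32;
            for s in scores.iter_mut() {
                *s = (*s - max).exp();
                sum += *s;
            }
            // Attention-weighted sum of value vectors.
            let oi = &mut out[i * hidden + off..i * hidden + off + head_dim];
            for j in 0..seq_len {
                let w = scores[j] / sum;
                let vj = &v[j * hidden + off..j * hidden + off + head_dim];
                for (o, &vv) in oi.iter_mut().zip(vj) {
                    *o += w * vv;
                }
            }
        }
    }
    out
}
```

BERT is post-LN: each layer computes h = LayerNorm(x + Attn(x)) and then LayerNorm(h + FFN(h)), so layer_norm is called after each residual add rather than before each sublayer.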

3. Cross-Encoder Scoring Wrapper

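pub struct CrossEncoder {
    encoder: BertEncoder,
    tokenizer: WordPieceTokenizer,
    pooler: Linear,      // bert.pooler.dense: hidden_dim → hidden_dim, followed by tanh
    classifier: Linear,  // hidden_dim → 1 (Linear::forward assumed to return Vec<f32>)
}

impl CrossEncoder {
    pub fn score(&self, query: &str, passage: &str) -> f32 {
        let input = self.tokenizer.encode_pair(query, passage);  // [CLS] q [SEP] p [SEP]
        let hidden = self.encoder.forward(&input);
        let cls = &hidden[..self.encoder.config.hidden_dim];  // [CLS] token is row 0
        let pooled: Vec<f32> = self.pooler.forward(cls).iter().map(|x| x.tanh()).collect();
        sigmoid(self.classifier.forward(&pooled)[0])  // single logit → relevance score
    }

    pub fn score_batch(&self, query: &str, passages: &[&str]) -> Vec<f32> {
        // Batch scoring for reranking top-N candidates. A padded, masked batch forward
        // pass would amortize weight loads; per-pair scoring is the correct baseline.
        passages.iter().map(|p| self.score(query, p)).collect()
    }
}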
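The scoring head and the final reranking step can be sketched on plain slices, independent of the encoder. This assumes PyTorch nn.Linear weight layout ([out_dim × in_dim], row-major) and mirrors the BertForSequenceClassification head (tanh pooler over [CLS], then a 1-output classifier); the XLM-RoBERTa head in bge-reranker-base has the same dense → tanh → projection shape under different tensor names. linear, relevance_from_cls, and rerank are hypothetical helpers.

```rust
/// Dense layer y = W·x + b with W stored row-major as [out_dim × in_dim].
fn linear(w: &[f32], b: &[f32], x: &[f32]) -> Vec<f32> {
    let n = x.len();
    b.iter()
        .enumerate()
        .map(|(o, bias)| bias + w[o * n..(o + 1) * n].iter().zip(x).map(|(a, c)| a * c).sum::<f32>())
        .collect()
}

fn sigmoid(x: f32) -> f32 {
    1.0 / (1.0 + (-x).exp())
}

/// Pooler (dense + tanh) over the [CLS] vector, then a 1-output classifier and sigmoid.
pub fn relevance_from_cls(
    cls: &[f32], pooler_w: &[f32], pooler_b: &[f32], cls_w: &[f32], cls_b: &[f32],
) -> f32 {
    let pooled: Vec<f32> = linear(pooler_w, pooler_b, cls).into_iter().map(f32::tanh).collect();
    sigmoid(linear(cls_w, cls_b, &pooled)[0])
}

/// Passage indices ordered by descending score; the stable sort keeps ties in input order.
pub fn rerank(scores: &[f32]) -> Vec<usize> {
    let mut idx: Vec<usize> = (0..scores.len()).collect();
    idx.sort_by(|&a, &b| scores[b].total_cmp(&scores[a]));
    idx
}
```

Because sigmoid is monotonic, ranking by the raw logit gives the same order; the sigmoid only matters when scores are compared against a fixed threshold such as 0.5.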

4. GELU Activation

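BERT uses GELU (not SiLU/SwiGLU like decoder models). HuggingFace BERT checkpoints with hidden_act = "gelu" use the exact erf-based form, so any approximation should be validated against the per-layer tolerance in the Verification Strategy. Verify trueno supports it:

// GELU(x) = x * Φ(x) ≈ x * sigmoid(1.702 * x)  [fast approximation]
trueno::gelu_scalar(x: f32) -> f32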
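To see how far each variant drifts from the exact form, the three can be compared directly. A minimal sketch; since stable Rust's std has no erf, it uses the Abramowitz and Stegun 7.1.26 approximation (absolute error below 1.5e-7), which is an assumption of this example rather than anything trueno does.

```rust
/// erf via Abramowitz & Stegun 7.1.26 (|error| < 1.5e-7).
fn erf(x: f32) -> f32 {
    let (a1, a2, a3, a4, a5, p) =
        (0.254829592f32, -0.284496736, 1.421413741, -1.453152027, 1.061405429, 0.3275911);
    let sign = x.signum();
    let x = x.abs();
    let t = 1.0 / (1.0 + p * x);
    let poly = ((((a5 * t + a4) * t + a3) * t + a2) * t + a1) * t;
    sign * (1.0 - poly * (-x * x).exp())
}

/// Exact GELU, x * Φ(x), as used by BERT checkpoints with hidden_act = "gelu".
pub fn gelu_erf(x: f32) -> f32 {
    0.5 * x * (1.0 + erf(x / std::f32::consts::SQRT_2))
}

/// tanh approximation ("gelu_new" / "gelu_pytorch_tanh" in transformers).
pub fn gelu_tanh(x: f32) -> f32 {
    let c = (2.0 / std::f32::consts::PI).sqrt();
    0.5 * x * (1.0 + (c * (x + 0.044715 * x * x * x)).tanh())
}

/// Sigmoid approximation from this issue: x * sigmoid(1.702 * x).
pub fn gelu_sigmoid(x: f32) -> f32 {
    x / (1.0 + (-1.702 * x).exp())
}
```

At x = 1 the tanh form is within about 2e-4 of the exact value, while the sigmoid form is off by about 4e-3, so the sigmoid shortcut is the one most likely to erode the cosine-similarity margin.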

Acceptance Criteria

  • apr import hf://BAAI/bge-reranker-base --architecture bert -o bge-reranker.apr succeeds
  • apr inspect bge-reranker.apr shows all expected tensors with correct shapes
  • BertEncoder::forward() produces correct hidden states (validated against HuggingFace reference output)
  • CrossEncoder::score("query", "passage") returns a plausible relevance score on known pairs (>0.5 for a relevant passage, <0.5 for an irrelevant one)
  • Inference is deterministic (same input → same output)
  • Tensor data is 64-byte aligned for SIMD and accessed zero-copy via trueno
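The alignment and zero-copy criterion could be checked along these lines. A minimal sketch, assuming tensor bytes are stored as f32 in the host's byte order (little-endian on x86-64 and aarch64); f32_view is a hypothetical helper, not trueno's API.

```rust
/// Zero-copy view of raw tensor bytes as f32. Returns None unless the bytes start on
/// a 64-byte boundary and contain a whole number of f32 values.
pub fn f32_view(bytes: &[u8]) -> Option<&[f32]> {
    if (bytes.as_ptr() as usize) % 64 != 0 || bytes.len() % 4 != 0 {
        return None;
    }
    // SAFETY: every bit pattern is a valid f32, and alignment was checked above.
    let (prefix, floats, suffix) = unsafe { bytes.align_to::<f32>() };
    // align_to may leave elements in prefix/suffix; treat that as "not zero-copy".
    if prefix.is_empty() && suffix.is_empty() { Some(floats) } else { None }
}
```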

Verification Strategy

  1. Export reference outputs from HuggingFace transformers (Python):
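    import torch
    from transformers import AutoTokenizer, AutoModelForSequenceClassification
    tok = AutoTokenizer.from_pretrained("BAAI/bge-reranker-base")
    model = AutoModelForSequenceClassification.from_pretrained("BAAI/bge-reranker-base").eval()
    with torch.no_grad():  # save out.hidden_states (embeddings + one per layer) and out.logits
        out = model(**tok("query", "passage", return_tensors="pt"), output_hidden_states=True)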
  2. Compare aprender inference output against reference at each layer
  3. Tolerance: cosine similarity > 0.999 for F32, > 0.99 for F16
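The per-layer tolerance check could be implemented as follows. A minimal sketch, assuming aprender's activations and the exported reference activations are each available as one flattened Vec<f32> per layer, in the same order; cosine_similarity and first_divergent_layer are hypothetical helpers.

```rust
/// Cosine similarity between two flattened activations of equal length.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "activation shapes differ");
    // Accumulate in f64 so the check itself does not lose precision.
    let (mut dot, mut na, mut nb) = (0.0f64, 0.0f64, 0.0f64);
    for (&x, &y) in a.iter().zip(b) {
        let (x, y) = (x as f64, y as f64);
        dot += x * y;
        na += x * x;
        nb += y * y;
    }
    if na == 0.0 || nb == 0.0 {
        return 0.0;
    }
    (dot / (na.sqrt() * nb.sqrt())) as f32
}

/// Index of the first layer whose similarity is at or below `threshold`, if any.
pub fn first_divergent_layer(ours: &[Vec<f32>], reference: &[Vec<f32>], threshold: f32) -> Option<usize> {
    ours.iter()
        .zip(reference)
        .position(|(o, r)| cosine_similarity(o, r) <= threshold)
}
```

Reporting the first divergent layer, rather than only the final score, localizes a bug (for example a wrong GELU variant or a transposed weight) to a single block.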

Related

Non-Goals

  • Training / fine-tuning BERT (inference only)
  • Decoder-only or encoder-decoder models (already supported via Llama/Whisper)
  • ONNX Runtime dependency (the whole point is sovereign inference)


Labels: enhancement (New feature or request)
