From 68aaadcce111c18d9561131d8108faed16916718 Mon Sep 17 00:00:00 2001 From: thomasht86 Date: Mon, 13 Apr 2026 13:16:34 +0200 Subject: [PATCH] Add colbert-pooling sample app: ColBERT with hierarchical token pooling PoolingColBertEmbedder runs ColBERT-Zero ONNX model inside Vespa and applies Ward's agglomerative clustering to reduce per-document token count. Binary pooled embeddings are HNSW-indexed in memory for fast ANN retrieval; full bfloat16 embeddings stay paged on disk for exact second-phase reranking. ColBertSearcher handles query embedding, binarisation, and nearestNeighbor query construction so users just send text. Verified against pylate (model authors' reference): cos_sim > 0.999 per token vs ONNX INT8 reference, 100% bit-exact binarisation. Co-Authored-By: Claude Opus 4.6 (1M context) --- colbert-pooling/.gitignore | 10 + colbert-pooling/README.md | 242 ++++++++++++ colbert-pooling/export_model.py | 93 +++++ colbert-pooling/ext/feed.jsonl | 100 +++++ colbert-pooling/feed.py | 174 +++++++++ colbert-pooling/pom.xml | 39 ++ colbert-pooling/pyproject.toml | 32 ++ .../src/main/application/schemas/doc.sd | 142 +++++++ .../search/query-profiles/default.xml | 4 + .../search/query-profiles/types/root.xml | 4 + .../src/main/application/services.xml | 49 +++ .../main/application/validation-overrides.xml | 3 + .../ai/vespa/colbert/ColBertSearcher.java | 125 +++++++ .../colbert/HierarchicalTokenPooling.java | 314 ++++++++++++++++ .../vespa/colbert/PoolingColBertEmbedder.java | 352 ++++++++++++++++++ .../pooling-colbert-embedder.def | 33 ++ .../system-test/feed-and-search-test.json | 84 +++++ .../colbert/HierarchicalTokenPoolingTest.java | 244 ++++++++++++ colbert-pooling/verify_correctness.py | 275 ++++++++++++++ 19 files changed, 2319 insertions(+) create mode 100644 colbert-pooling/.gitignore create mode 100644 colbert-pooling/README.md create mode 100644 colbert-pooling/export_model.py create mode 100644 colbert-pooling/ext/feed.jsonl create mode 100644 colbert-pooling/feed.py create mode 100644 colbert-pooling/pom.xml create mode 100644 colbert-pooling/pyproject.toml create mode 100644 colbert-pooling/src/main/application/schemas/doc.sd create mode 100644 colbert-pooling/src/main/application/search/query-profiles/default.xml create mode 100644 colbert-pooling/src/main/application/search/query-profiles/types/root.xml create mode 100644 colbert-pooling/src/main/application/services.xml create mode 100644 colbert-pooling/src/main/application/validation-overrides.xml create mode 100644 colbert-pooling/src/main/java/ai/vespa/colbert/ColBertSearcher.java create mode 100644 colbert-pooling/src/main/java/ai/vespa/colbert/HierarchicalTokenPooling.java create mode 100644 colbert-pooling/src/main/java/ai/vespa/colbert/PoolingColBertEmbedder.java create mode 100644 colbert-pooling/src/main/resources/configdefinitions/pooling-colbert-embedder.def create mode 100644 colbert-pooling/src/test/application/tests/system-test/feed-and-search-test.json create mode 100644 colbert-pooling/src/test/java/ai/vespa/colbert/HierarchicalTokenPoolingTest.java create mode 100644 colbert-pooling/verify_correctness.py diff --git a/colbert-pooling/.gitignore b/colbert-pooling/.gitignore new file mode 100644 index 000000000..5d60fa657 --- /dev/null +++ b/colbert-pooling/.gitignore @@ -0,0 +1,10 @@ +target/ +.venv/ +__pycache__/ +*.pyc +.pytest_cache/ +uv.lock + +# ONNX model files (large, export with export_model.py) +src/main/application/models/ +models/ diff --git a/colbert-pooling/README.md b/colbert-pooling/README.md new file mode 100644 index 000000000..d3156f6d2 --- /dev/null +++ b/colbert-pooling/README.md @@ -0,0 +1,242 @@ +# ColBERT with Hierarchical Token Pooling + +> **Motivation.** ColBERT's multi-vector representation delivers excellent +> retrieval quality, but it comes with storage and query-time costs: every +> document token gets its own embedding, and every query needs per-token +> `nearestNeighbor` operators plus custom binarisation logic. +> +> This sample application prototypes the components that would make ColBERT a +> first-class citizen in Vespa: a pooling-aware embedder, a query searcher that +> handles all the multi-vector plumbing, and a two-phase ranking strategy that +> combines cheap HNSW retrieval with precise late-interaction reranking. The +> goal is to explore how these could be generalised and integrated into Vespa +> itself, so that using ColBERT becomes as simple as using any single-vector +> embedder. + +## What this app does + +All embedding, pooling, binarisation, query rewriting, and ranking runs +**inside Vespa** -- the user just sends text. + +| Component | Role | +|---|---| +| `PoolingColBertEmbedder` | Runs the ONNX model at indexing time, optionally applies hierarchical token pooling (Ward's clustering), produces bfloat16 or int8 tensors | +| `ColBertSearcher` | At query time: embeds the query, binarises per-token vectors, builds `nearestNeighbor` OR query, sets all ranking tensors | +| Rank profile `colbert-pooled-ann` | HNSW ANN retrieval on binary pooled field, hamming MaxSim first phase, exact bfloat16 MaxSim second phase | + +### Fields + +| Field | Type | Storage | Description | +|---|---|---|---| +| `colbert` | `tensor(dt{}, x[128])` | paged (disk) | Full non-pooled ColBERT embeddings for exact reranking | +| `colbert_pooled_binary` | `tensor(dt{}, x[16])` | in-memory + HNSW | Pooled + binarised embeddings for fast ANN retrieval | + +### Memory savings + +The combination of token pooling and binarisation dramatically reduces the +in-memory footprint of the HNSW-indexed field: + +| Representation | Tokens (typical) | Bytes per token | Per-document (200 tokens) | 1B docs | +|---|---|---|---|---| +| Full bfloat16 | 200 | 256 (128 dims x 2B) | 50 KB | **~47 TB** | +| Pooled bfloat16 (factor=2) | 100 | 256 | 25 KB | ~24 TB | +| Pooled binary (factor=2) | 100 | 16 (128 bits packed) | 1.6 KB | ~1.5 TB | +| **Pooled binary (factor=4)** | **50** | **16** | **0.8 KB** | **~0.7 TB** | + +With pool factor 4 and binarisation, the in-memory HNSW index is **~64x +smaller** than full bfloat16 -- making ColBERT-scale retrieval feasible at +billion-document scale. The full bfloat16 embeddings stay on disk (paged) +and are only read for the top-100 second-phase rerank. + +### Ranking pipeline + +| Phase | What | Field | Cost | +|---|---|---|---| +| **Retrieval** | `nearestNeighbor` ANN via HNSW (hamming) | `colbert_pooled_binary` (in-memory) | Cheap | +| **First-phase scoring** | Hamming-based MaxSim on binary embeddings | `colbert_pooled_binary` | Cheap | +| **Second-phase rerank** | Exact MaxSim on full bfloat16 (top 100) | `colbert` (paged from disk) | Expensive but precise | + +## Why the custom searcher? + +Vespa's `nearestNeighbor` operator works on single-vector query tensors, but +ColBERT queries are multi-vector (one embedding per token). To bridge this gap, +`ColBertSearcher` runs at query time and: + +1. Calls the embedder to produce a float multi-vector query tensor (`qt`). +2. Binarises each token into a packed int8 vector. +3. Sets per-token `nearestNeighbor` query tensors (`rq0`..`rqN`) and ORs them + together for HNSW retrieval. +4. Sets the binary multi-vector `qtb` for hamming MaxSim scoring. + +This means the user just sends `query=planets in the solar system` and the +searcher handles all the plumbing. Without it, the client would need to +compute and pass 32+ binary tensors per request. + +The per-token input declarations (`rq0`..`rq31`) in the rank profile are a +Vespa platform requirement: the content node needs them to resolve +`nearestNeighbor` terms, even though the searcher populates them +programmatically. + +## Quick start + +### 1. Export the ONNX model + +```bash +uv pip install pylate-onnx-export onnxscript +uv run python export_model.py +``` + +This uses [pylate-onnx-export](https://github.com/lightonai/next-plaid/tree/main/next-plaid-onnx/python) +to export [ColBERT-Zero](https://huggingface.co/lightonai/ColBERT-Zero) to +ONNX with INT8 dynamic quantisation, and copies `model_int8.onnx` + +`tokenizer.json` into `src/main/application/models/`. + +### 2. Build the application + +```bash +mvn clean package -DskipTests +``` + +### 3. Deploy + +```bash +vespa config set target local +docker run --detach --name vespa --hostname vespa-container \ + --publish 8080:8080 --publish 19071:19071 \ + vespaengine/vespa + +vespa deploy target/application +``` + +### 4. Feed 100 documents + +```bash +vespa feed ext/feed.jsonl +``` + +Vespa computes both ColBERT representations at indexing time using the +`PoolingColBertEmbedder` component. + +### 5. Query + +Just send text -- the `ColBertSearcher` handles embedding, binarisation, and +`nearestNeighbor` query construction: + +```bash +vespa query 'query=planets in the solar system' +vespa query 'query=machine learning and artificial intelligence' +vespa query 'query=renewable energy sources' +``` + +### 6. Run tests + +```bash +mvn test +``` + +### 7. Verify correctness + +The `verify_correctness.py` script validates Vespa's embeddings against +[pylate](https://github.com/lightonai/pylate), the model authors' reference +implementation. + +``` +pylate (FP32 PyTorch) ← quantisation floor (cos ~0.98) → ONNX INT8 + ↕ identical (cos >0.999) + Vespa (Java + ONNX INT8) +``` + +Vespa and the Python ONNX INT8 reference produce near-identical embeddings +(cos_sim > 0.999 per token, 100% bit-exact binarisation). The ~2% gap to +pylate FP32 is the INT8 quantisation floor, not an implementation difference. + +```bash +uv pip install pylate onnxruntime tokenizers + +# Full comparison (Vespa must be running): +uv run python verify_correctness.py + +# Without Vespa (pylate vs ONNX INT8 only): +uv run python verify_correctness.py --no-vespa +``` + +## Hierarchical token pooling + +The `HierarchicalTokenPooling` class (invoked by `PoolingColBertEmbedder` +after the ONNX model produces per-token embeddings) implements: + +1. **Pairwise cosine distances** between all token embeddings (excluding CLS). +2. **Ward's agglomerative clustering** using the Nearest-Neighbor Chain (NNC) + algorithm -- O(n^2) time. +3. **Cut the dendrogram** to `ceil(n / poolFactor)` clusters. +4. **Replace each cluster with its L2-normalised centroid**. +5. **Prepend the original CLS token** (always preserved). + +The algorithm is a Java port of the +[hierarchy.rs](https://github.com/lightonai/next-plaid/blob/main/next-plaid-onnx/src/hierarchy.rs) +module from [next-plaid-onnx](https://github.com/lightonai/next-plaid). + +### Performance (Java, 128-dim embeddings) + +| Tokens | `poolTokens` | `pdistCosine` | +|--------|-------------|---------------| +| 128 | ~1 ms | < 1 ms | +| 512 | ~14 ms | ~9 ms | + +## Architecture + +### Embedder + +`PoolingColBertEmbedder` is a custom Vespa `Embedder` that runs the ColBERT +ONNX model and optionally applies hierarchical token pooling, controlled by +`poolFactor`: + +- **`poolFactor=0`** -- standard ColBERT (no pooling) +- **`poolFactor=2`** -- merge similar tokens, keep ~half the vectors +- **`poolFactor=3`** -- keep roughly a third, etc. + +Two instances are configured in `services.xml` with different `poolFactor` +values, sharing the same ONNX model and tokenizer. + +### Sequence construction + +Input sequences match [pylate](https://github.com/lightonai/pylate)'s +ColBERT encoding: + +``` +[CLS] [D] search_document: [SEP] (documents) +[CLS] [Q] search_query: [MASK]... (queries, padded) +``` + +The `search_document:` / `search_query:` prefixes are configurable via +`prependDocument` / `prependQuery` in the +[config definition](src/main/resources/configdefinitions/pooling-colbert-embedder.def). +ColBERT-Zero requires these prompts for full retrieval quality +([model card](https://huggingface.co/lightonai/ColBERT-Zero)). + +## Project structure + +``` +colbert-pooling/ +├── pom.xml # Maven build (container-plugin) +├── src/main/ +│ ├── application/ # Vespa application package +│ │ ├── schemas/doc.sd # Schema, HNSW index, rank profile +│ │ ├── services.xml # Embedder + Searcher config +│ │ └── search/query-profiles/ # Query tensor type definitions +│ ├── java/ai/vespa/colbert/ +│ │ ├── HierarchicalTokenPooling.java # Ward NNC + fcluster + poolTokens +│ │ ├── PoolingColBertEmbedder.java # Custom Vespa Embedder +│ │ └── ColBertSearcher.java # Query rewriting + embedding +│ └── resources/configdefinitions/ +│ └── pooling-colbert-embedder.def # Embedder config definition +├── src/test/ +│ ├── java/ai/vespa/colbert/ +│ │ └── HierarchicalTokenPoolingTest.java # JUnit tests + benchmarks +│ └── application/tests/system-test/ +│ └── feed-and-search-test.json # Vespa system test +├── verify_correctness.py # E2E: pylate vs ONNX INT8 vs Vespa +├── ext/feed.jsonl # 100 sample documents +├── export_model.py # ONNX export via pylate-onnx-export +└── feed.py # Feed JSONL generation helper +``` diff --git a/colbert-pooling/export_model.py b/colbert-pooling/export_model.py new file mode 100644 index 000000000..949a9df03 --- /dev/null +++ b/colbert-pooling/export_model.py @@ -0,0 +1,93 @@ +#!/usr/bin/env python3 +"""Export lightonai/ColBERT-Zero to ONNX with INT8 quantisation. + +Uses pylate-onnx-export to produce an optimised ONNX model and copies +the artefacts into the Vespa application package at app/models/. + +Install the export dependency first: + uv pip install pylate-onnx-export + +Usage: + uv run python export_model.py +""" + +from __future__ import annotations + +import json +import shutil +import sys +from pathlib import Path + +MODEL_NAME = "lightonai/ColBERT-Zero" +APP_MODELS_DIR = Path(__file__).resolve().parent / "src" / "main" / "application" / "models" + + +def main() -> None: + try: + from colbert_export.export import export_model + except ImportError: + sys.exit( + "pylate-onnx-export is not installed.\n" + "Install it with: uv pip install pylate-onnx-export" + ) + + print(f"Exporting {MODEL_NAME} to ONNX (FP32 + INT8) ...") + output_path = export_model( + model_name=MODEL_NAME, + output_dir=Path("models") / "ColBERT-Zero", + quantize=True, + verbose=True, + force=True, + ) + + # Copy artefacts into the Vespa application package + APP_MODELS_DIR.mkdir(parents=True, exist_ok=True) + + int8_model = output_path / "model_int8.onnx" + if not int8_model.exists(): + sys.exit(f"Expected INT8 model not found at {int8_model}") + + # Find tokenizer.json – pylate-onnx-export places it alongside the model + tokenizer = output_path / "tokenizer.json" + if not tokenizer.exists(): + # Fall back: download directly from HuggingFace + print("tokenizer.json not in export dir – downloading from HuggingFace ...") + try: + from huggingface_hub import hf_hub_download + + tokenizer = Path( + hf_hub_download(repo_id=MODEL_NAME, filename="tokenizer.json") + ) + except Exception as exc: + sys.exit(f"Could not locate tokenizer.json: {exc}") + + shutil.copy2(int8_model, APP_MODELS_DIR / "model_int8.onnx") + shutil.copy2(tokenizer, APP_MODELS_DIR / "tokenizer.json") + + # Write a small metadata file consumed by feed.py + config = { + "model_name": MODEL_NAME, + "embedding_dim": 128, + "onnx_model": str(APP_MODELS_DIR / "model_int8.onnx"), + "tokenizer": str(APP_MODELS_DIR / "tokenizer.json"), + } + # Also check onnx_config.json from export + onnx_cfg_path = output_path / "onnx_config.json" + if onnx_cfg_path.exists(): + with open(onnx_cfg_path) as f: + onnx_cfg = json.load(f) + config["embedding_dim"] = onnx_cfg.get("embedding_dim", 128) + + meta_path = APP_MODELS_DIR / "model_config.json" + with open(meta_path, "w") as f: + json.dump(config, f, indent=2) + + print(f"\nModel artefacts copied to {APP_MODELS_DIR}/") + print(f" model_int8.onnx ({int8_model.stat().st_size / 1e6:.1f} MB)") + print(f" tokenizer.json") + print(f" model_config.json") + print("\nVespa application is ready to deploy.") + + +if __name__ == "__main__": + main() diff --git a/colbert-pooling/ext/feed.jsonl b/colbert-pooling/ext/feed.jsonl new file mode 100644 index 000000000..f0402bcc0 --- /dev/null +++ b/colbert-pooling/ext/feed.jsonl @@ -0,0 +1,100 @@ +{"put": "id:doc:doc::0", "fields": {"doc_id": "0", "title": "Solar System Overview", "text": "The solar system consists of the Sun and the celestial bodies that orbit it, including eight planets, their moons, dwarf planets, asteroids, and comets. The four inner planets are rocky worlds, while the outer planets are gas and ice giants."}} +{"put": "id:doc:doc::1", "fields": {"doc_id": "1", "title": "Photosynthesis", "text": "Photosynthesis is the biological process by which green plants and certain organisms convert light energy into chemical energy stored in glucose. It takes place primarily in chloroplasts using chlorophyll pigments."}} +{"put": "id:doc:doc::2", "fields": {"doc_id": "2", "title": "Machine Learning Basics", "text": "Machine learning is a branch of artificial intelligence where computers learn patterns from data without being explicitly programmed. Common approaches include supervised learning, unsupervised learning, and reinforcement learning."}} +{"put": "id:doc:doc::3", "fields": {"doc_id": "3", "title": "The Roman Empire", "text": "The Roman Empire was one of the largest empires in ancient history, spanning from Britain to Mesopotamia at its peak. It profoundly influenced Western civilisation through its legal system, architecture, engineering, and language."}} +{"put": "id:doc:doc::4", "fields": {"doc_id": "4", "title": "DNA Structure", "text": "Deoxyribonucleic acid is a double helix molecule that carries the genetic instructions for the development and functioning of all known living organisms. Watson and Crick described its structure in 1953."}} +{"put": "id:doc:doc::5", "fields": {"doc_id": "5", "title": "Ocean Currents", "text": "Ocean currents are continuous movements of seawater driven by wind, temperature differences, salinity, and the Earth's rotation. The Gulf Stream carries warm water from the tropics toward northern Europe."}} +{"put": "id:doc:doc::6", "fields": {"doc_id": "6", "title": "Neural Networks", "text": "Artificial neural networks are computational models inspired by the human brain. They consist of layers of interconnected nodes that process information using weighted connections adjusted during training."}} +{"put": "id:doc:doc::7", "fields": {"doc_id": "7", "title": "The French Revolution", "text": "The French Revolution began in 1789 with the storming of the Bastille and resulted in profound social and political changes. It ended the absolute monarchy and established principles of citizenship and rights."}} +{"put": "id:doc:doc::8", "fields": {"doc_id": "8", "title": "Quantum Mechanics", "text": "Quantum mechanics describes the behaviour of matter and energy at the atomic and subatomic level. Key principles include wave-particle duality, the uncertainty principle, and quantum entanglement."}} +{"put": "id:doc:doc::9", "fields": {"doc_id": "9", "title": "Climate Change", "text": "Climate change refers to long-term shifts in global temperatures and weather patterns, largely driven by human activities since the Industrial Revolution, primarily the burning of fossil fuels."}} +{"put": "id:doc:doc::10", "fields": {"doc_id": "10", "title": "Protein Folding", "text": "Protein folding is the physical process by which a polypeptide chain acquires its functional three-dimensional structure. Misfolded proteins are associated with diseases like Alzheimer's and Parkinson's."}} +{"put": "id:doc:doc::11", "fields": {"doc_id": "11", "title": "The Silk Road", "text": "The Silk Road was an ancient network of trade routes connecting China to the Mediterranean, facilitating the exchange of goods, ideas, technologies, and cultural practices across civilisations for centuries."}} +{"put": "id:doc:doc::12", "fields": {"doc_id": "12", "title": "Blockchain Technology", "text": "Blockchain is a distributed ledger technology that records transactions across many computers so that records cannot be altered retroactively. It underpins cryptocurrencies like Bitcoin."}} +{"put": "id:doc:doc::13", "fields": {"doc_id": "13", "title": "Plate Tectonics", "text": "Plate tectonics is the theory that Earth's outer shell is divided into large plates that float on a semi-fluid layer of the mantle. Movements at plate boundaries cause earthquakes, volcanic activity, and mountain building."}} +{"put": "id:doc:doc::14", "fields": {"doc_id": "14", "title": "The Human Immune System", "text": "The immune system is a complex network of cells, tissues, and organs that work together to defend the body against pathogens. It includes innate immunity and adaptive immunity with memory cells."}} +{"put": "id:doc:doc::15", "fields": {"doc_id": "15", "title": "Renewable Energy Sources", "text": "Renewable energy comes from naturally replenished sources such as solar, wind, hydroelectric, geothermal, and biomass. These sources produce far fewer greenhouse gas emissions than fossil fuels."}} +{"put": "id:doc:doc::16", "fields": {"doc_id": "16", "title": "The Industrial Revolution", "text": "The Industrial Revolution began in Britain in the late 18th century and transformed manufacturing through mechanisation, steam power, and later electricity. It caused massive urbanisation and social change."}} +{"put": "id:doc:doc::17", "fields": {"doc_id": "17", "title": "General Relativity", "text": "Einstein's theory of general relativity describes gravity as the curvature of spacetime caused by mass and energy. It predicts phenomena such as gravitational lensing, time dilation, and black holes."}} +{"put": "id:doc:doc::18", "fields": {"doc_id": "18", "title": "Antibiotic Resistance", "text": "Antibiotic resistance occurs when bacteria evolve mechanisms to survive exposure to antibiotics. Overuse and misuse of antibiotics in medicine and agriculture accelerate this global health threat."}} +{"put": "id:doc:doc::19", "fields": {"doc_id": "19", "title": "The Renaissance", "text": "The Renaissance was a cultural movement from the 14th to 17th century originating in Italy. It marked a renewed interest in classical art, literature, science, and humanist philosophy."}} +{"put": "id:doc:doc::20", "fields": {"doc_id": "20", "title": "Deep Learning Architectures", "text": "Deep learning uses multi-layered neural networks to learn hierarchical representations of data. Key architectures include convolutional networks for images, recurrent networks for sequences, and transformers for language."}} +{"put": "id:doc:doc::21", "fields": {"doc_id": "21", "title": "The Water Cycle", "text": "The water cycle describes the continuous movement of water through evaporation, condensation, precipitation, and collection. It is driven by solar energy and gravity, redistributing water across the planet."}} +{"put": "id:doc:doc::22", "fields": {"doc_id": "22", "title": "CRISPR Gene Editing", "text": "CRISPR-Cas9 is a revolutionary gene editing technology that allows scientists to precisely modify DNA sequences. It has applications in treating genetic diseases, improving crops, and basic research."}} +{"put": "id:doc:doc::23", "fields": {"doc_id": "23", "title": "Ancient Egyptian Civilisation", "text": "Ancient Egypt flourished along the Nile River for over three thousand years. The Egyptians built pyramids, developed hieroglyphic writing, and made advances in medicine, engineering, and astronomy."}} +{"put": "id:doc:doc::24", "fields": {"doc_id": "24", "title": "Natural Language Processing", "text": "Natural language processing enables computers to understand, interpret, and generate human language. Modern NLP relies on transformer models pre-trained on large text corpora."}} +{"put": "id:doc:doc::25", "fields": {"doc_id": "25", "title": "Volcanic Eruptions", "text": "Volcanic eruptions occur when magma from beneath the Earth's crust reaches the surface. They can cause lava flows, ash clouds, pyroclastic flows, and global temperature changes from aerosol emissions."}} +{"put": "id:doc:doc::26", "fields": {"doc_id": "26", "title": "The Cold War", "text": "The Cold War was a geopolitical tension between the United States and the Soviet Union from 1947 to 1991. It involved proxy wars, nuclear arms races, and ideological competition without direct military conflict."}} +{"put": "id:doc:doc::27", "fields": {"doc_id": "27", "title": "Superconductivity", "text": "Superconductivity is a phenomenon where certain materials exhibit zero electrical resistance below a critical temperature. Superconductors are used in MRI machines, particle accelerators, and maglev trains."}} +{"put": "id:doc:doc::28", "fields": {"doc_id": "28", "title": "Coral Reef Ecosystems", "text": "Coral reefs are diverse underwater ecosystems built by colonies of coral polyps. They support roughly 25 percent of all marine species despite covering less than one percent of the ocean floor."}} +{"put": "id:doc:doc::29", "fields": {"doc_id": "29", "title": "Autonomous Vehicles", "text": "Self-driving cars use sensors, cameras, lidar, and artificial intelligence to navigate without human input. They promise safer roads but face challenges in edge cases, regulation, and public trust."}} +{"put": "id:doc:doc::30", "fields": {"doc_id": "30", "title": "The Periodic Table", "text": "The periodic table organises chemical elements by atomic number and electron configuration. Mendeleev's original 1869 table predicted the existence and properties of elements not yet discovered."}} +{"put": "id:doc:doc::31", "fields": {"doc_id": "31", "title": "Space Exploration", "text": "Space exploration began with the launch of Sputnik in 1957 and has included crewed Moon landings, robotic Mars rovers, and the International Space Station. Private companies now contribute significantly."}} +{"put": "id:doc:doc::32", "fields": {"doc_id": "32", "title": "Microbiome Research", "text": "The human microbiome comprises trillions of microorganisms living in and on the body. Research links gut microbiome composition to digestion, immunity, mental health, and chronic diseases."}} +{"put": "id:doc:doc::33", "fields": {"doc_id": "33", "title": "The Byzantine Empire", "text": "The Byzantine Empire was the continuation of the Eastern Roman Empire, lasting from the fall of Rome in 476 AD until 1453. Constantinople was its capital and a centre of trade, art, and scholarship."}} +{"put": "id:doc:doc::34", "fields": {"doc_id": "34", "title": "Computer Vision", "text": "Computer vision is the field of AI that enables machines to interpret visual information from images and videos. Applications include facial recognition, medical imaging analysis, and autonomous navigation."}} +{"put": "id:doc:doc::35", "fields": {"doc_id": "35", "title": "Earthquake Seismology", "text": "Seismology studies earthquakes and the propagation of seismic waves through the Earth. Seismographs measure ground motion and help scientists understand Earth's internal structure and predict hazards."}} +{"put": "id:doc:doc::36", "fields": {"doc_id": "36", "title": "The Enlightenment", "text": "The Enlightenment was an intellectual movement of the 17th and 18th centuries emphasising reason, science, and individual rights. Thinkers like Locke, Voltaire, and Kant shaped modern democratic thought."}} +{"put": "id:doc:doc::37", "fields": {"doc_id": "37", "title": "Fusion Energy", "text": "Nuclear fusion is the process that powers the Sun, combining light atomic nuclei to release vast amounts of energy. Achieving controlled fusion on Earth would provide a nearly limitless clean energy source."}} +{"put": "id:doc:doc::38", "fields": {"doc_id": "38", "title": "Biodiversity Loss", "text": "Biodiversity loss refers to the decline of species variety on Earth, driven by habitat destruction, pollution, overexploitation, invasive species, and climate change. It threatens ecosystem stability."}} +{"put": "id:doc:doc::39", "fields": {"doc_id": "39", "title": "Recommender Systems", "text": "Recommender systems suggest relevant items to users based on preferences and behaviour. Collaborative filtering and content-based filtering are the two primary approaches."}} +{"put": "id:doc:doc::40", "fields": {"doc_id": "40", "title": "The Circulatory System", "text": "The circulatory system transports blood, oxygen, and nutrients throughout the body via the heart, arteries, veins, and capillaries. It also removes metabolic waste products."}} +{"put": "id:doc:doc::41", "fields": {"doc_id": "41", "title": "Cryptography Fundamentals", "text": "Cryptography secures communication through mathematical techniques. Modern methods include symmetric encryption like AES, asymmetric encryption like RSA, and hash functions like SHA-256."}} +{"put": "id:doc:doc::42", "fields": {"doc_id": "42", "title": "The Mongol Empire", "text": "The Mongol Empire, founded by Genghis Khan in 1206, became the largest contiguous land empire in history. It facilitated trade, communication, and cultural exchange across Eurasia."}} +{"put": "id:doc:doc::43", "fields": {"doc_id": "43", "title": "Gravitational Waves", "text": "Gravitational waves are ripples in spacetime caused by accelerating massive objects. First detected by LIGO in 2015, they opened a new window for observing the universe."}} +{"put": "id:doc:doc::44", "fields": {"doc_id": "44", "title": "Deforestation", "text": "Deforestation is the large-scale removal of forest cover, primarily for agriculture, logging, and urban expansion. It contributes to carbon emissions, soil erosion, and loss of habitat."}} +{"put": "id:doc:doc::45", "fields": {"doc_id": "45", "title": "Information Retrieval", "text": "Information retrieval is the science of finding relevant documents from a large collection given a user query. It underpins search engines and uses techniques like inverted indexes, TF-IDF, and neural ranking."}} +{"put": "id:doc:doc::46", "fields": {"doc_id": "46", "title": "The Nervous System", "text": "The nervous system coordinates the body's actions by transmitting electrical signals between the brain, spinal cord, and peripheral nerves. Neurons communicate via synapses using neurotransmitters."}} +{"put": "id:doc:doc::47", "fields": {"doc_id": "47", "title": "Cloud Computing", "text": "Cloud computing delivers computing services over the internet, including servers, storage, databases, networking, and software. Major providers include AWS, Azure, and Google Cloud Platform."}} +{"put": "id:doc:doc::48", "fields": {"doc_id": "48", "title": "The Ottoman Empire", "text": "The Ottoman Empire was a vast state founded in 1299 that at its height controlled Southeast Europe, Western Asia, and North Africa. It lasted until 1922, influencing art, law, and governance."}} +{"put": "id:doc:doc::49", "fields": {"doc_id": "49", "title": "Dark Matter and Dark Energy", "text": "Dark matter and dark energy make up about 95 percent of the universe's total mass-energy content. Dark matter provides gravitational scaffolding for galaxies, while dark energy drives cosmic expansion."}} +{"put": "id:doc:doc::50", "fields": {"doc_id": "50", "title": "Soil Science", "text": "Soil is a complex mixture of minerals, organic matter, water, and air that supports plant growth. Soil health affects agricultural productivity, water filtration, and carbon storage."}} +{"put": "id:doc:doc::51", "fields": {"doc_id": "51", "title": "Graph Neural Networks", "text": "Graph neural networks extend deep learning to graph-structured data, learning representations of nodes, edges, and entire graphs. Applications include molecular property prediction and social network analysis."}} +{"put": "id:doc:doc::52", "fields": {"doc_id": "52", "title": "The American Civil War", "text": "The American Civil War from 1861 to 1865 was fought between the Union and the Confederacy over slavery and states' rights. It resulted in the abolition of slavery and the preservation of the Union."}} +{"put": "id:doc:doc::53", "fields": {"doc_id": "53", "title": "Semiconductor Physics", "text": "Semiconductors are materials with electrical conductivity between conductors and insulators. Silicon-based semiconductors are the foundation of modern electronics, from transistors to integrated circuits."}} +{"put": "id:doc:doc::54", "fields": {"doc_id": "54", "title": "Pollinator Decline", "text": "Pollinators like bees, butterflies, and bats are declining due to pesticides, habitat loss, and disease. Their loss threatens food production since many crops depend on animal pollination."}} +{"put": "id:doc:doc::55", "fields": {"doc_id": "55", "title": "Transformer Models", "text": "The transformer architecture uses self-attention mechanisms to process sequences in parallel, replacing recurrent models. It powers modern language models like BERT, GPT, and ColBERT for retrieval."}} +{"put": "id:doc:doc::56", "fields": {"doc_id": "56", "title": "The Respiratory System", "text": "The respiratory system facilitates gas exchange, bringing oxygen into the body and removing carbon dioxide. Air passes through the trachea to the bronchi and into the alveoli in the lungs."}} +{"put": "id:doc:doc::57", "fields": {"doc_id": "57", "title": "Distributed Systems", "text": "Distributed systems are networked computers that coordinate to achieve a common goal. Challenges include consistency, availability, partition tolerance, and the CAP theorem constrains their design."}} +{"put": "id:doc:doc::58", "fields": {"doc_id": "58", "title": "The Ming Dynasty", "text": "The Ming Dynasty ruled China from 1368 to 1644, known for its strong centralised government, the construction of the Forbidden City, and maritime expeditions led by Admiral Zheng He."}} +{"put": "id:doc:doc::59", "fields": {"doc_id": "59", "title": "Exoplanet Discovery", "text": "Exoplanets are planets orbiting stars outside our solar system. Thousands have been discovered using transit photometry and radial velocity methods, with some in habitable zones."}} +{"put": "id:doc:doc::60", "fields": {"doc_id": "60", "title": "Wetland Ecosystems", "text": "Wetlands are transitional areas between land and water that provide critical ecosystem services. They filter pollutants, buffer floods, store carbon, and support diverse wildlife populations."}} +{"put": "id:doc:doc::61", "fields": {"doc_id": "61", "title": "Attention Mechanisms", "text": "Attention mechanisms allow neural networks to focus on relevant parts of the input when producing output. Self-attention computes relevance scores between all positions in a sequence."}} +{"put": "id:doc:doc::62", "fields": {"doc_id": "62", "title": "The Digestive System", "text": "The digestive system breaks down food into nutrients that the body uses for energy, growth, and repair. It includes the mouth, oesophagus, stomach, small intestine, and large intestine."}} +{"put": "id:doc:doc::63", "fields": {"doc_id": "63", "title": "Cybersecurity Threats", "text": "Common cybersecurity threats include malware, phishing, ransomware, denial-of-service attacks, and supply chain compromises. Defence requires layered security, patching, and user awareness."}} +{"put": "id:doc:doc::64", "fields": {"doc_id": "64", "title": "The Inca Empire", "text": "The Inca Empire was the largest pre-Columbian empire in the Americas, stretching along the Andes mountains. They built Machu Picchu, developed an extensive road system, and used quipu for record keeping."}} +{"put": "id:doc:doc::65", "fields": {"doc_id": "65", "title": "Stellar Evolution", "text": "Stars form from collapsing clouds of gas and dust, undergo nuclear fusion on the main sequence, and end their lives as white dwarfs, neutron stars, or black holes depending on their mass."}} +{"put": "id:doc:doc::66", "fields": {"doc_id": "66", "title": "Freshwater Scarcity", "text": "Freshwater scarcity affects billions of people worldwide due to population growth, pollution, and climate change. Only about three percent of Earth's water is fresh, and most is locked in ice."}} +{"put": "id:doc:doc::67", "fields": {"doc_id": "67", "title": "Vector Databases", "text": "Vector databases are specialised systems for storing and querying high-dimensional vector embeddings. They use approximate nearest neighbour algorithms like HNSW for fast similarity search."}} +{"put": "id:doc:doc::68", "fields": {"doc_id": "68", "title": "The Endocrine System", "text": "The endocrine system regulates body functions through hormones secreted by glands such as the pituitary, thyroid, adrenal glands, and pancreas. It controls metabolism, growth, and reproduction."}} +{"put": "id:doc:doc::69", "fields": {"doc_id": "69", "title": "Containerisation Technology", "text": "Containers package applications with their dependencies for consistent deployment across environments. Docker popularised containerisation, while Kubernetes orchestrates container workloads at scale."}} +{"put": "id:doc:doc::70", "fields": {"doc_id": "70", "title": "The Viking Age", "text": "The Viking Age spanned from roughly 793 to 1066 AD, during which Norse seafarers from Scandinavia explored, traded, and settled across Europe, the North Atlantic, and even North America."}} +{"put": "id:doc:doc::71", "fields": {"doc_id": "71", "title": "Neutrino Physics", "text": "Neutrinos are nearly massless subatomic particles that interact very weakly with matter. Billions pass through your body every second. Their study reveals information about nuclear reactions in stars."}} +{"put": "id:doc:doc::72", "fields": {"doc_id": "72", "title": "Urban Heat Islands", "text": "Urban heat islands are metropolitan areas significantly warmer than surrounding rural areas due to human activities, dark surfaces absorbing heat, and reduced vegetation and evapotranspiration."}} +{"put": "id:doc:doc::73", "fields": {"doc_id": "73", "title": "Embedding Models", "text": "Embedding models map discrete objects like words or documents to continuous vector spaces where semantic similarity is preserved as geometric distance. ColBERT produces multi-vector embeddings per token."}} +{"put": "id:doc:doc::74", "fields": {"doc_id": "74", "title": "The Skeletal System", "text": "The skeletal system provides structural support, protects organs, enables movement, stores minerals, and produces blood cells in bone marrow. Adults have 206 bones connected by joints and ligaments."}} +{"put": "id:doc:doc::75", "fields": {"doc_id": "75", "title": "API Design Principles", "text": "Good API design follows principles of consistency, discoverability, and minimal surprise. RESTful APIs use HTTP verbs and resource-based URLs, while GraphQL offers flexible query capabilities."}} +{"put": "id:doc:doc::76", "fields": {"doc_id": "76", "title": "The Mughal Empire", "text": "The Mughal Empire ruled much of the Indian subcontinent from 1526 to 1857. It was known for cultural achievements including the Taj Mahal, miniature painting, and a sophisticated administrative system."}} +{"put": "id:doc:doc::77", "fields": {"doc_id": "77", "title": "Cosmic Microwave Background", "text": "The cosmic microwave background is the residual radiation from the Big Bang, filling the universe as a nearly uniform glow at about 2.7 Kelvin. Its tiny fluctuations map the early universe's density variations."}} +{"put": "id:doc:doc::78", "fields": {"doc_id": "78", "title": "Glacier Retreat", "text": "Glaciers worldwide are retreating due to rising temperatures. Their loss contributes to sea level rise, alters freshwater supplies for millions of people, and affects ecosystems that depend on glacial meltwater."}} +{"put": "id:doc:doc::79", "fields": {"doc_id": "79", "title": "Late Interaction Retrieval", "text": "Late interaction models like ColBERT represent queries and documents as sets of token embeddings and score relevance via MaxSim. This balances the effectiveness of cross-encoders with the efficiency of bi-encoders."}} +{"put": "id:doc:doc::80", "fields": {"doc_id": "80", "title": "The Musculoskeletal System", "text": "Muscles work with bones and joints to produce movement. Skeletal muscles are voluntary and contract in response to nerve signals. Tendons attach muscles to bones, while ligaments connect bones to each other."}} +{"put": "id:doc:doc::81", "fields": {"doc_id": "81", "title": "DevOps Practices", "text": "DevOps combines software development and IT operations to shorten the development lifecycle. Key practices include continuous integration, continuous delivery, infrastructure as code, and monitoring."}} +{"put": "id:doc:doc::82", "fields": {"doc_id": "82", "title": "The Aztec Empire", "text": "The Aztec Empire flourished in central Mexico from the 14th to 16th century. Tenochtitlan, their capital built on an island in Lake Texcoco, was one of the largest cities in the world at the time."}} +{"put": "id:doc:doc::83", "fields": {"doc_id": "83", "title": "Black Hole Physics", "text": "Black holes are regions of spacetime where gravity is so strong that nothing, not even light, can escape. They form from the collapse of massive stars and are detected through gravitational effects."}} +{"put": "id:doc:doc::84", "fields": {"doc_id": "84", "title": "Air Quality and Health", "text": "Poor air quality from particulate matter, ozone, and nitrogen oxides causes respiratory and cardiovascular diseases. Indoor and outdoor air pollution contributes to millions of premature deaths annually."}} +{"put": "id:doc:doc::85", "fields": {"doc_id": "85", "title": "Retrieval Augmented Generation", "text": "Retrieval augmented generation combines information retrieval with language model generation. The system first retrieves relevant documents, then conditions the language model's output on the retrieved context."}} +{"put": "id:doc:doc::86", "fields": {"doc_id": "86", "title": "The Lymphatic System", "text": "The lymphatic system is a network of vessels, nodes, and organs that maintains fluid balance, absorbs fats from digestion, and plays a key role in immune defence by filtering pathogens."}} +{"put": "id:doc:doc::87", "fields": {"doc_id": "87", "title": "Database Indexing", "text": "Database indexes are data structures that improve the speed of data retrieval operations. B-tree indexes handle range queries efficiently, while hash indexes excel at exact-match lookups."}} +{"put": "id:doc:doc::88", "fields": {"doc_id": "88", "title": "The Han Dynasty", "text": "The Han Dynasty ruled China from 206 BC to 220 AD and is considered a golden age. It established the Silk Road trade, advanced paper making, and created a civil service examination system."}} +{"put": "id:doc:doc::89", "fields": {"doc_id": "89", "title": "Magnetar Stars", "text": "Magnetars are neutron stars with extremely powerful magnetic fields, a trillion times stronger than Earth's. They produce intense bursts of X-rays and gamma rays and are the strongest magnets known in the universe."}} +{"put": "id:doc:doc::90", "fields": {"doc_id": "90", "title": "Plastic Pollution", "text": "Plastic pollution accumulates in oceans, rivers, and soil, harming wildlife through ingestion and entanglement. Microplastics enter food chains and have been found in human blood and organs."}} +{"put": "id:doc:doc::91", "fields": {"doc_id": "91", "title": "Semantic Search", "text": "Semantic search goes beyond keyword matching to understand the meaning and intent behind queries. It uses dense vector representations and neural models to find conceptually relevant results."}} +{"put": "id:doc:doc::92", "fields": {"doc_id": "92", "title": "The Renal System", "text": "The kidneys filter blood to remove waste products and excess fluid, producing urine. They also regulate electrolyte balance, blood pressure through the renin-angiotensin system, and red blood cell production."}} +{"put": "id:doc:doc::93", "fields": {"doc_id": "93", "title": "Microservices Architecture", "text": "Microservices decompose applications into small, independently deployable services that communicate via APIs. This approach improves scalability and allows teams to develop and deploy services independently."}} +{"put": "id:doc:doc::94", "fields": {"doc_id": "94", "title": "The Gupta Empire", "text": "The Gupta Empire in India from 320 to 550 AD is called the Golden Age of India. It saw advances in mathematics including the concept of zero, astronomy, literature, and art."}} +{"put": "id:doc:doc::95", "fields": {"doc_id": "95", "title": "Gamma-Ray Bursts", "text": "Gamma-ray bursts are the most energetic electromagnetic events in the universe, lasting from milliseconds to several hours. They are associated with supernovae and neutron star mergers."}} +{"put": "id:doc:doc::96", "fields": {"doc_id": "96", "title": "Ocean Acidification", "text": "Ocean acidification occurs as seawater absorbs excess carbon dioxide from the atmosphere, lowering its pH. This threatens shell-forming organisms, coral reefs, and entire marine food webs."}} +{"put": "id:doc:doc::97", "fields": {"doc_id": "97", "title": "Dense Retrieval", "text": "Dense retrieval uses learned vector representations to match queries and documents. Unlike sparse methods such as BM25, dense models capture semantic relationships beyond exact term overlap."}} +{"put": "id:doc:doc::98", "fields": {"doc_id": "98", "title": "The Reproductive System", "text": "The reproductive system enables organisms to produce offspring. In humans, it involves the ovaries and uterus in females and the testes in males, with hormones coordinating reproductive cycles."}} +{"put": "id:doc:doc::99", "fields": {"doc_id": "99", "title": "Observability in Software", "text": "Observability encompasses logging, metrics, and tracing to understand system behaviour in production. It helps engineers diagnose issues, track performance, and improve reliability of distributed systems."}} diff --git a/colbert-pooling/feed.py b/colbert-pooling/feed.py new file mode 100644 index 000000000..d094969a7 --- /dev/null +++ b/colbert-pooling/feed.py @@ -0,0 +1,174 @@ +#!/usr/bin/env python3 +"""Feed 100 sample documents to Vespa. + +All embedding (non-pooled, pooled, pooled-binary) is done inside Vespa by +the PoolingColBertEmbedder. This script only sends text. + +Usage: + uv run python feed.py [--endpoint http://localhost:8080] + # Or generate JSONL for vespa-cli: + uv run python feed.py --dry-run +""" + +from __future__ import annotations + +import argparse +import json +import time + +import requests + +# --------------------------------------------------------------------------- +# 100 sample documents – diverse topics for meaningful retrieval testing +# --------------------------------------------------------------------------- + +DOCUMENTS: list[dict[str, str]] = [ + {"title": "Solar System Overview", "text": "The solar system consists of the Sun and the celestial bodies that orbit it, including eight planets, their moons, dwarf planets, asteroids, and comets. The four inner planets are rocky worlds, while the outer planets are gas and ice giants."}, + {"title": "Photosynthesis", "text": "Photosynthesis is the biological process by which green plants and certain organisms convert light energy into chemical energy stored in glucose. It takes place primarily in chloroplasts using chlorophyll pigments."}, + {"title": "Machine Learning Basics", "text": "Machine learning is a branch of artificial intelligence where computers learn patterns from data without being explicitly programmed. Common approaches include supervised learning, unsupervised learning, and reinforcement learning."}, + {"title": "The Roman Empire", "text": "The Roman Empire was one of the largest empires in ancient history, spanning from Britain to Mesopotamia at its peak. It profoundly influenced Western civilisation through its legal system, architecture, engineering, and language."}, + {"title": "DNA Structure", "text": "Deoxyribonucleic acid is a double helix molecule that carries the genetic instructions for the development and functioning of all known living organisms. Watson and Crick described its structure in 1953."}, + {"title": "Ocean Currents", "text": "Ocean currents are continuous movements of seawater driven by wind, temperature differences, salinity, and the Earth's rotation. The Gulf Stream carries warm water from the tropics toward northern Europe."}, + {"title": "Neural Networks", "text": "Artificial neural networks are computational models inspired by the human brain. They consist of layers of interconnected nodes that process information using weighted connections adjusted during training."}, + {"title": "The French Revolution", "text": "The French Revolution began in 1789 with the storming of the Bastille and resulted in profound social and political changes. It ended the absolute monarchy and established principles of citizenship and rights."}, + {"title": "Quantum Mechanics", "text": "Quantum mechanics describes the behaviour of matter and energy at the atomic and subatomic level. Key principles include wave-particle duality, the uncertainty principle, and quantum entanglement."}, + {"title": "Climate Change", "text": "Climate change refers to long-term shifts in global temperatures and weather patterns, largely driven by human activities since the Industrial Revolution, primarily the burning of fossil fuels."}, + {"title": "Protein Folding", "text": "Protein folding is the physical process by which a polypeptide chain acquires its functional three-dimensional structure. Misfolded proteins are associated with diseases like Alzheimer's and Parkinson's."}, + {"title": "The Silk Road", "text": "The Silk Road was an ancient network of trade routes connecting China to the Mediterranean, facilitating the exchange of goods, ideas, technologies, and cultural practices across civilisations for centuries."}, + {"title": "Blockchain Technology", "text": "Blockchain is a distributed ledger technology that records transactions across many computers so that records cannot be altered retroactively. It underpins cryptocurrencies like Bitcoin."}, + {"title": "Plate Tectonics", "text": "Plate tectonics is the theory that Earth's outer shell is divided into large plates that float on a semi-fluid layer of the mantle. Movements at plate boundaries cause earthquakes, volcanic activity, and mountain building."}, + {"title": "The Human Immune System", "text": "The immune system is a complex network of cells, tissues, and organs that work together to defend the body against pathogens. It includes innate immunity and adaptive immunity with memory cells."}, + {"title": "Renewable Energy Sources", "text": "Renewable energy comes from naturally replenished sources such as solar, wind, hydroelectric, geothermal, and biomass. These sources produce far fewer greenhouse gas emissions than fossil fuels."}, + {"title": "The Industrial Revolution", "text": "The Industrial Revolution began in Britain in the late 18th century and transformed manufacturing through mechanisation, steam power, and later electricity. It caused massive urbanisation and social change."}, + {"title": "General Relativity", "text": "Einstein's theory of general relativity describes gravity as the curvature of spacetime caused by mass and energy. It predicts phenomena such as gravitational lensing, time dilation, and black holes."}, + {"title": "Antibiotic Resistance", "text": "Antibiotic resistance occurs when bacteria evolve mechanisms to survive exposure to antibiotics. Overuse and misuse of antibiotics in medicine and agriculture accelerate this global health threat."}, + {"title": "The Renaissance", "text": "The Renaissance was a cultural movement from the 14th to 17th century originating in Italy. It marked a renewed interest in classical art, literature, science, and humanist philosophy."}, + {"title": "Deep Learning Architectures", "text": "Deep learning uses multi-layered neural networks to learn hierarchical representations of data. Key architectures include convolutional networks for images, recurrent networks for sequences, and transformers for language."}, + {"title": "The Water Cycle", "text": "The water cycle describes the continuous movement of water through evaporation, condensation, precipitation, and collection. It is driven by solar energy and gravity, redistributing water across the planet."}, + {"title": "CRISPR Gene Editing", "text": "CRISPR-Cas9 is a revolutionary gene editing technology that allows scientists to precisely modify DNA sequences. It has applications in treating genetic diseases, improving crops, and basic research."}, + {"title": "Ancient Egyptian Civilisation", "text": "Ancient Egypt flourished along the Nile River for over three thousand years. The Egyptians built pyramids, developed hieroglyphic writing, and made advances in medicine, engineering, and astronomy."}, + {"title": "Natural Language Processing", "text": "Natural language processing enables computers to understand, interpret, and generate human language. Modern NLP relies on transformer models pre-trained on large text corpora."}, + {"title": "Volcanic Eruptions", "text": "Volcanic eruptions occur when magma from beneath the Earth's crust reaches the surface. They can cause lava flows, ash clouds, pyroclastic flows, and global temperature changes from aerosol emissions."}, + {"title": "The Cold War", "text": "The Cold War was a geopolitical tension between the United States and the Soviet Union from 1947 to 1991. It involved proxy wars, nuclear arms races, and ideological competition without direct military conflict."}, + {"title": "Superconductivity", "text": "Superconductivity is a phenomenon where certain materials exhibit zero electrical resistance below a critical temperature. Superconductors are used in MRI machines, particle accelerators, and maglev trains."}, + {"title": "Coral Reef Ecosystems", "text": "Coral reefs are diverse underwater ecosystems built by colonies of coral polyps. They support roughly 25 percent of all marine species despite covering less than one percent of the ocean floor."}, + {"title": "Autonomous Vehicles", "text": "Self-driving cars use sensors, cameras, lidar, and artificial intelligence to navigate without human input. They promise safer roads but face challenges in edge cases, regulation, and public trust."}, + {"title": "The Periodic Table", "text": "The periodic table organises chemical elements by atomic number and electron configuration. Mendeleev's original 1869 table predicted the existence and properties of elements not yet discovered."}, + {"title": "Space Exploration", "text": "Space exploration began with the launch of Sputnik in 1957 and has included crewed Moon landings, robotic Mars rovers, and the International Space Station. Private companies now contribute significantly."}, + {"title": "Microbiome Research", "text": "The human microbiome comprises trillions of microorganisms living in and on the body. Research links gut microbiome composition to digestion, immunity, mental health, and chronic diseases."}, + {"title": "The Byzantine Empire", "text": "The Byzantine Empire was the continuation of the Eastern Roman Empire, lasting from the fall of Rome in 476 AD until 1453. Constantinople was its capital and a centre of trade, art, and scholarship."}, + {"title": "Computer Vision", "text": "Computer vision is the field of AI that enables machines to interpret visual information from images and videos. Applications include facial recognition, medical imaging analysis, and autonomous navigation."}, + {"title": "Earthquake Seismology", "text": "Seismology studies earthquakes and the propagation of seismic waves through the Earth. Seismographs measure ground motion and help scientists understand Earth's internal structure and predict hazards."}, + {"title": "The Enlightenment", "text": "The Enlightenment was an intellectual movement of the 17th and 18th centuries emphasising reason, science, and individual rights. Thinkers like Locke, Voltaire, and Kant shaped modern democratic thought."}, + {"title": "Fusion Energy", "text": "Nuclear fusion is the process that powers the Sun, combining light atomic nuclei to release vast amounts of energy. Achieving controlled fusion on Earth would provide a nearly limitless clean energy source."}, + {"title": "Biodiversity Loss", "text": "Biodiversity loss refers to the decline of species variety on Earth, driven by habitat destruction, pollution, overexploitation, invasive species, and climate change. It threatens ecosystem stability."}, + {"title": "Recommender Systems", "text": "Recommender systems suggest relevant items to users based on preferences and behaviour. Collaborative filtering and content-based filtering are the two primary approaches."}, + {"title": "The Circulatory System", "text": "The circulatory system transports blood, oxygen, and nutrients throughout the body via the heart, arteries, veins, and capillaries. It also removes metabolic waste products."}, + {"title": "Cryptography Fundamentals", "text": "Cryptography secures communication through mathematical techniques. Modern methods include symmetric encryption like AES, asymmetric encryption like RSA, and hash functions like SHA-256."}, + {"title": "The Mongol Empire", "text": "The Mongol Empire, founded by Genghis Khan in 1206, became the largest contiguous land empire in history. It facilitated trade, communication, and cultural exchange across Eurasia."}, + {"title": "Gravitational Waves", "text": "Gravitational waves are ripples in spacetime caused by accelerating massive objects. First detected by LIGO in 2015, they opened a new window for observing the universe."}, + {"title": "Deforestation", "text": "Deforestation is the large-scale removal of forest cover, primarily for agriculture, logging, and urban expansion. It contributes to carbon emissions, soil erosion, and loss of habitat."}, + {"title": "Information Retrieval", "text": "Information retrieval is the science of finding relevant documents from a large collection given a user query. It underpins search engines and uses techniques like inverted indexes, TF-IDF, and neural ranking."}, + {"title": "The Nervous System", "text": "The nervous system coordinates the body's actions by transmitting electrical signals between the brain, spinal cord, and peripheral nerves. Neurons communicate via synapses using neurotransmitters."}, + {"title": "Cloud Computing", "text": "Cloud computing delivers computing services over the internet, including servers, storage, databases, networking, and software. Major providers include AWS, Azure, and Google Cloud Platform."}, + {"title": "The Ottoman Empire", "text": "The Ottoman Empire was a vast state founded in 1299 that at its height controlled Southeast Europe, Western Asia, and North Africa. It lasted until 1922, influencing art, law, and governance."}, + {"title": "Dark Matter and Dark Energy", "text": "Dark matter and dark energy make up about 95 percent of the universe's total mass-energy content. Dark matter provides gravitational scaffolding for galaxies, while dark energy drives cosmic expansion."}, + {"title": "Soil Science", "text": "Soil is a complex mixture of minerals, organic matter, water, and air that supports plant growth. Soil health affects agricultural productivity, water filtration, and carbon storage."}, + {"title": "Graph Neural Networks", "text": "Graph neural networks extend deep learning to graph-structured data, learning representations of nodes, edges, and entire graphs. Applications include molecular property prediction and social network analysis."}, + {"title": "The American Civil War", "text": "The American Civil War from 1861 to 1865 was fought between the Union and the Confederacy over slavery and states' rights. It resulted in the abolition of slavery and the preservation of the Union."}, + {"title": "Semiconductor Physics", "text": "Semiconductors are materials with electrical conductivity between conductors and insulators. Silicon-based semiconductors are the foundation of modern electronics, from transistors to integrated circuits."}, + {"title": "Pollinator Decline", "text": "Pollinators like bees, butterflies, and bats are declining due to pesticides, habitat loss, and disease. Their loss threatens food production since many crops depend on animal pollination."}, + {"title": "Transformer Models", "text": "The transformer architecture uses self-attention mechanisms to process sequences in parallel, replacing recurrent models. It powers modern language models like BERT, GPT, and ColBERT for retrieval."}, + {"title": "The Respiratory System", "text": "The respiratory system facilitates gas exchange, bringing oxygen into the body and removing carbon dioxide. Air passes through the trachea to the bronchi and into the alveoli in the lungs."}, + {"title": "Distributed Systems", "text": "Distributed systems are networked computers that coordinate to achieve a common goal. Challenges include consistency, availability, partition tolerance, and the CAP theorem constrains their design."}, + {"title": "The Ming Dynasty", "text": "The Ming Dynasty ruled China from 1368 to 1644, known for its strong centralised government, the construction of the Forbidden City, and maritime expeditions led by Admiral Zheng He."}, + {"title": "Exoplanet Discovery", "text": "Exoplanets are planets orbiting stars outside our solar system. Thousands have been discovered using transit photometry and radial velocity methods, with some in habitable zones."}, + {"title": "Wetland Ecosystems", "text": "Wetlands are transitional areas between land and water that provide critical ecosystem services. They filter pollutants, buffer floods, store carbon, and support diverse wildlife populations."}, + {"title": "Attention Mechanisms", "text": "Attention mechanisms allow neural networks to focus on relevant parts of the input when producing output. Self-attention computes relevance scores between all positions in a sequence."}, + {"title": "The Digestive System", "text": "The digestive system breaks down food into nutrients that the body uses for energy, growth, and repair. It includes the mouth, oesophagus, stomach, small intestine, and large intestine."}, + {"title": "Cybersecurity Threats", "text": "Common cybersecurity threats include malware, phishing, ransomware, denial-of-service attacks, and supply chain compromises. Defence requires layered security, patching, and user awareness."}, + {"title": "The Inca Empire", "text": "The Inca Empire was the largest pre-Columbian empire in the Americas, stretching along the Andes mountains. They built Machu Picchu, developed an extensive road system, and used quipu for record keeping."}, + {"title": "Stellar Evolution", "text": "Stars form from collapsing clouds of gas and dust, undergo nuclear fusion on the main sequence, and end their lives as white dwarfs, neutron stars, or black holes depending on their mass."}, + {"title": "Freshwater Scarcity", "text": "Freshwater scarcity affects billions of people worldwide due to population growth, pollution, and climate change. Only about three percent of Earth's water is fresh, and most is locked in ice."}, + {"title": "Vector Databases", "text": "Vector databases are specialised systems for storing and querying high-dimensional vector embeddings. They use approximate nearest neighbour algorithms like HNSW for fast similarity search."}, + {"title": "The Endocrine System", "text": "The endocrine system regulates body functions through hormones secreted by glands such as the pituitary, thyroid, adrenal glands, and pancreas. It controls metabolism, growth, and reproduction."}, + {"title": "Containerisation Technology", "text": "Containers package applications with their dependencies for consistent deployment across environments. Docker popularised containerisation, while Kubernetes orchestrates container workloads at scale."}, + {"title": "The Viking Age", "text": "The Viking Age spanned from roughly 793 to 1066 AD, during which Norse seafarers from Scandinavia explored, traded, and settled across Europe, the North Atlantic, and even North America."}, + {"title": "Neutrino Physics", "text": "Neutrinos are nearly massless subatomic particles that interact very weakly with matter. Billions pass through your body every second. Their study reveals information about nuclear reactions in stars."}, + {"title": "Urban Heat Islands", "text": "Urban heat islands are metropolitan areas significantly warmer than surrounding rural areas due to human activities, dark surfaces absorbing heat, and reduced vegetation and evapotranspiration."}, + {"title": "Embedding Models", "text": "Embedding models map discrete objects like words or documents to continuous vector spaces where semantic similarity is preserved as geometric distance. ColBERT produces multi-vector embeddings per token."}, + {"title": "The Skeletal System", "text": "The skeletal system provides structural support, protects organs, enables movement, stores minerals, and produces blood cells in bone marrow. Adults have 206 bones connected by joints and ligaments."}, + {"title": "API Design Principles", "text": "Good API design follows principles of consistency, discoverability, and minimal surprise. RESTful APIs use HTTP verbs and resource-based URLs, while GraphQL offers flexible query capabilities."}, + {"title": "The Mughal Empire", "text": "The Mughal Empire ruled much of the Indian subcontinent from 1526 to 1857. It was known for cultural achievements including the Taj Mahal, miniature painting, and a sophisticated administrative system."}, + {"title": "Cosmic Microwave Background", "text": "The cosmic microwave background is the residual radiation from the Big Bang, filling the universe as a nearly uniform glow at about 2.7 Kelvin. Its tiny fluctuations map the early universe's density variations."}, + {"title": "Glacier Retreat", "text": "Glaciers worldwide are retreating due to rising temperatures. Their loss contributes to sea level rise, alters freshwater supplies for millions of people, and affects ecosystems that depend on glacial meltwater."}, + {"title": "Late Interaction Retrieval", "text": "Late interaction models like ColBERT represent queries and documents as sets of token embeddings and score relevance via MaxSim. This balances the effectiveness of cross-encoders with the efficiency of bi-encoders."}, + {"title": "The Musculoskeletal System", "text": "Muscles work with bones and joints to produce movement. Skeletal muscles are voluntary and contract in response to nerve signals. Tendons attach muscles to bones, while ligaments connect bones to each other."}, + {"title": "DevOps Practices", "text": "DevOps combines software development and IT operations to shorten the development lifecycle. Key practices include continuous integration, continuous delivery, infrastructure as code, and monitoring."}, + {"title": "The Aztec Empire", "text": "The Aztec Empire flourished in central Mexico from the 14th to 16th century. Tenochtitlan, their capital built on an island in Lake Texcoco, was one of the largest cities in the world at the time."}, + {"title": "Black Hole Physics", "text": "Black holes are regions of spacetime where gravity is so strong that nothing, not even light, can escape. They form from the collapse of massive stars and are detected through gravitational effects."}, + {"title": "Air Quality and Health", "text": "Poor air quality from particulate matter, ozone, and nitrogen oxides causes respiratory and cardiovascular diseases. Indoor and outdoor air pollution contributes to millions of premature deaths annually."}, + {"title": "Retrieval Augmented Generation", "text": "Retrieval augmented generation combines information retrieval with language model generation. The system first retrieves relevant documents, then conditions the language model's output on the retrieved context."}, + {"title": "The Lymphatic System", "text": "The lymphatic system is a network of vessels, nodes, and organs that maintains fluid balance, absorbs fats from digestion, and plays a key role in immune defence by filtering pathogens."}, + {"title": "Database Indexing", "text": "Database indexes are data structures that improve the speed of data retrieval operations. B-tree indexes handle range queries efficiently, while hash indexes excel at exact-match lookups."}, + {"title": "The Han Dynasty", "text": "The Han Dynasty ruled China from 206 BC to 220 AD and is considered a golden age. It established the Silk Road trade, advanced paper making, and created a civil service examination system."}, + {"title": "Magnetar Stars", "text": "Magnetars are neutron stars with extremely powerful magnetic fields, a trillion times stronger than Earth's. They produce intense bursts of X-rays and gamma rays and are the strongest magnets known in the universe."}, + {"title": "Plastic Pollution", "text": "Plastic pollution accumulates in oceans, rivers, and soil, harming wildlife through ingestion and entanglement. Microplastics enter food chains and have been found in human blood and organs."}, + {"title": "Semantic Search", "text": "Semantic search goes beyond keyword matching to understand the meaning and intent behind queries. It uses dense vector representations and neural models to find conceptually relevant results."}, + {"title": "The Renal System", "text": "The kidneys filter blood to remove waste products and excess fluid, producing urine. They also regulate electrolyte balance, blood pressure through the renin-angiotensin system, and red blood cell production."}, + {"title": "Microservices Architecture", "text": "Microservices decompose applications into small, independently deployable services that communicate via APIs. This approach improves scalability and allows teams to develop and deploy services independently."}, + {"title": "The Gupta Empire", "text": "The Gupta Empire in India from 320 to 550 AD is called the Golden Age of India. It saw advances in mathematics including the concept of zero, astronomy, literature, and art."}, + {"title": "Gamma-Ray Bursts", "text": "Gamma-ray bursts are the most energetic electromagnetic events in the universe, lasting from milliseconds to several hours. They are associated with supernovae and neutron star mergers."}, + {"title": "Ocean Acidification", "text": "Ocean acidification occurs as seawater absorbs excess carbon dioxide from the atmosphere, lowering its pH. This threatens shell-forming organisms, coral reefs, and entire marine food webs."}, + {"title": "Dense Retrieval", "text": "Dense retrieval uses learned vector representations to match queries and documents. Unlike sparse methods such as BM25, dense models capture semantic relationships beyond exact term overlap."}, + {"title": "The Reproductive System", "text": "The reproductive system enables organisms to produce offspring. In humans, it involves the ovaries and uterus in females and the testes in males, with hormones coordinating reproductive cycles."}, + {"title": "Observability in Software", "text": "Observability encompasses logging, metrics, and tracing to understand system behaviour in production. It helps engineers diagnose issues, track performance, and improve reliability of distributed systems."}, +] + +assert len(DOCUMENTS) == 100, f"Expected 100 documents, got {len(DOCUMENTS)}" + + +def main() -> None: + import argparse + parser = argparse.ArgumentParser(description="Feed 100 documents to Vespa") + parser.add_argument("--endpoint", default="http://localhost:8080") + parser.add_argument("--dry-run", action="store_true", + help="Write JSONL instead of feeding to Vespa") + args = parser.parse_args() + + feed_records = [] + for i, doc in enumerate(DOCUMENTS): + rec = { + "put": f"id:doc:doc::{i}", + "fields": { + "doc_id": str(i), + "title": doc["title"], + "text": doc["text"], + }, + } + feed_records.append(rec) + + if args.dry_run: + from pathlib import Path + out = Path("ext") / "feed.jsonl" + out.parent.mkdir(exist_ok=True) + with open(out, "w") as f: + for rec in feed_records: + f.write(json.dumps(rec) + "\n") + print(f"Wrote {len(feed_records)} documents to {out}") + print("Feed with: vespa feed ext/feed.jsonl") + else: + t0 = time.time() + for i, rec in enumerate(feed_records): + url = f"{args.endpoint}/document/v1/doc/doc/docid/{i}" + resp = requests.post(url, json={"fields": rec["fields"]}, timeout=60) + if resp.status_code not in (200, 201): + print(f" WARN doc {i}: {resp.status_code} {resp.text[:200]}") + if (i + 1) % 25 == 0: + print(f" [{i+1:3d}/100] ({time.time()-t0:.1f}s)") + print(f"Fed {len(feed_records)} documents in {time.time()-t0:.1f}s") + print("Vespa computes all ColBERT embeddings (non-pooled, pooled, binary) at indexing time.") + + +if __name__ == "__main__": + main() diff --git a/colbert-pooling/pom.xml b/colbert-pooling/pom.xml new file mode 100644 index 000000000..a1e2f4def --- /dev/null +++ b/colbert-pooling/pom.xml @@ -0,0 +1,39 @@ + + + + 4.0.0 + ai.vespa.examples + colbert-pooling + 1.0.0 + container-plugin + + + com.yahoo.vespa + cloud-tenant-base + [8,9) + + + + + UTF-8 + true + + + + + com.yahoo.vespa + model-integration + [8,9) + provided + + + com.yahoo.vespa + linguistics-components + [8,9) + provided + + + diff --git a/colbert-pooling/pyproject.toml b/colbert-pooling/pyproject.toml new file mode 100644 index 000000000..215c151fc --- /dev/null +++ b/colbert-pooling/pyproject.toml @@ -0,0 +1,32 @@ +[project] +name = "colbert-pooling" +version = "0.1.0" +description = "ColBERT with hierarchical token pooling for Vespa" +requires-python = ">=3.10" +dependencies = [ + "numpy>=1.24.0", + "onnxscript>=0.6.2", + "requests>=2.28.0", +] + +[project.optional-dependencies] +export = [ + "pylate-onnx-export>=0.1.0", +] +scipy = [ + "scipy>=1.10.0", +] +test = [ + "pytest>=7.0.0", + "scipy>=1.10.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["python/colbert_pooling"] + +[tool.pytest.ini_options] +testpaths = ["python/tests"] diff --git a/colbert-pooling/src/main/application/schemas/doc.sd b/colbert-pooling/src/main/application/schemas/doc.sd new file mode 100644 index 000000000..5f5de4ad4 --- /dev/null +++ b/colbert-pooling/src/main/application/schemas/doc.sd @@ -0,0 +1,142 @@ +schema doc { + + document doc { + + field doc_id type string { + indexing: summary | attribute + } + + field title type string { + indexing: summary | index + index: enable-bm25 + } + + field text type string { + indexing: summary | index + index: enable-bm25 + } + } + + # Full non-pooled ColBERT embeddings (bfloat16, paged to disk). + # Used for exact MaxSim reranking in the second phase. + field colbert type tensor(dt{}, x[128]) { + indexing: input text | embed colbert | summary | attribute + attribute: paged + } + + # Hierarchically pooled + binarised ColBERT embeddings (in memory). + # HNSW-indexed for fast approximate nearest-neighbour retrieval. + field colbert_pooled_binary type tensor(dt{}, x[16]) { + indexing: input text | embed colbert-pooled | summary | attribute | index + attribute { + distance-metric: hamming + } + index { + hnsw { + max-links-per-node: 32 + neighbors-to-explore-at-insert: 400 + } + } + } + + # Summary class for retrieving tensor fields (not included by default) + document-summary tensors { + summary doc_id {} + summary colbert {} + summary colbert_pooled_binary {} + } + + fieldset default { + fields: title, text + } + + # ------------------------------------------------------------------ + # Rank profile: HNSW ANN on binary pooled → rerank with full MaxSim + # ------------------------------------------------------------------ + + rank-profile colbert-pooled-ann { + + # All query tensors are set by ColBertSearcher at query time. + # The rq inputs are required for nearestNeighbor resolution on the content node. + inputs { + query(qt) tensor(qt{}, x[128]) + query(qtb) tensor(qt{}, x[16]) + query(rq0) tensor(x[16]) + query(rq1) tensor(x[16]) + query(rq2) tensor(x[16]) + query(rq3) tensor(x[16]) + query(rq4) tensor(x[16]) + query(rq5) tensor(x[16]) + query(rq6) tensor(x[16]) + query(rq7) tensor(x[16]) + query(rq8) tensor(x[16]) + query(rq9) tensor(x[16]) + query(rq10) tensor(x[16]) + query(rq11) tensor(x[16]) + query(rq12) tensor(x[16]) + query(rq13) tensor(x[16]) + query(rq14) tensor(x[16]) + query(rq15) tensor(x[16]) + query(rq16) tensor(x[16]) + query(rq17) tensor(x[16]) + query(rq18) tensor(x[16]) + query(rq19) tensor(x[16]) + query(rq20) tensor(x[16]) + query(rq21) tensor(x[16]) + query(rq22) tensor(x[16]) + query(rq23) tensor(x[16]) + query(rq24) tensor(x[16]) + query(rq25) tensor(x[16]) + query(rq26) tensor(x[16]) + query(rq27) tensor(x[16]) + query(rq28) tensor(x[16]) + query(rq29) tensor(x[16]) + query(rq30) tensor(x[16]) + query(rq31) tensor(x[16]) + } + + # First phase: hamming-based MaxSim on binary pooled embeddings. + # Documents are retrieved via nearestNeighbor operators in YQL; + # this scores them cheaply using the full binary query tensor. + function max_sim_binary() { + expression { + sum( + reduce( + 1 - sum( + hamming(query(qtb), attribute(colbert_pooled_binary)), x + ) / (8 * 16), + max, dt + ), + qt + ) + } + } + + first-phase { + expression: max_sim_binary() + } + + # Second phase: exact MaxSim on full non-pooled bfloat16 embeddings. + function max_sim() { + expression { + sum( + reduce( + sum(query(qt) * attribute(colbert), x), + max, dt + ), + qt + ) + } + } + + second-phase { + rerank-count: 100 + expression: max_sim() + } + + match-features { + max_sim_binary + max_sim + } + } +} diff --git a/colbert-pooling/src/main/application/search/query-profiles/default.xml b/colbert-pooling/src/main/application/search/query-profiles/default.xml new file mode 100644 index 000000000..db0bf2b5a --- /dev/null +++ b/colbert-pooling/src/main/application/search/query-profiles/default.xml @@ -0,0 +1,4 @@ + + 10 + colbert-pooled-ann + diff --git a/colbert-pooling/src/main/application/search/query-profiles/types/root.xml b/colbert-pooling/src/main/application/search/query-profiles/types/root.xml new file mode 100644 index 000000000..66cd6a91b --- /dev/null +++ b/colbert-pooling/src/main/application/search/query-profiles/types/root.xml @@ -0,0 +1,4 @@ + + + + diff --git a/colbert-pooling/src/main/application/services.xml b/colbert-pooling/src/main/application/services.xml new file mode 100644 index 000000000..2a3e4c70c --- /dev/null +++ b/colbert-pooling/src/main/application/services.xml @@ -0,0 +1,49 @@ + + + + + + + + + + + + + + + models/model_int8.onnx + models/tokenizer.json + 0 + + + + + + + models/model_int8.onnx + models/tokenizer.json + 2 + + + + + + + + + + 1 + + + + + + + + + diff --git a/colbert-pooling/src/main/application/validation-overrides.xml b/colbert-pooling/src/main/application/validation-overrides.xml new file mode 100644 index 000000000..7f47a7a55 --- /dev/null +++ b/colbert-pooling/src/main/application/validation-overrides.xml @@ -0,0 +1,3 @@ + + paged-setting-removal + diff --git a/colbert-pooling/src/main/java/ai/vespa/colbert/ColBertSearcher.java b/colbert-pooling/src/main/java/ai/vespa/colbert/ColBertSearcher.java new file mode 100644 index 000000000..37b69e3d0 --- /dev/null +++ b/colbert-pooling/src/main/java/ai/vespa/colbert/ColBertSearcher.java @@ -0,0 +1,125 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.colbert; + +import com.yahoo.component.annotation.Inject; +import com.yahoo.component.provider.ComponentRegistry; +import com.yahoo.language.process.Embedder; +import com.yahoo.prelude.query.NearestNeighborItem; +import com.yahoo.prelude.query.OrItem; +import com.yahoo.search.Query; +import com.yahoo.search.Result; +import com.yahoo.search.Searcher; +import com.yahoo.search.searchchain.Execution; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; + +import java.util.BitSet; + +/** + * Searcher that embeds the user query with ColBERT, binarises the per-token + * embeddings, and rewrites the query to use {@code nearestNeighbor} operators + * on the binary HNSW-indexed field. + *

+ * Also sets the float multi-vector query tensor {@code qt} for the second-phase + * exact MaxSim reranking, and the binary multi-vector {@code qtb} for the + * first-phase hamming MaxSim scoring. + *

+ * Just send text: + *

+ *   /search/?query=machine+learning
+ * 
+ */ +public class ColBertSearcher extends Searcher { + + private static final String BINARY_FIELD = "colbert_pooled_binary"; + private static final int TARGET_HITS_PER_TOKEN = 10; + private static final int MAX_QUERY_TOKENS = 32; + + private static final TensorType FLOAT_QUERY_TYPE = + TensorType.fromSpec("tensor(qt{},x[128])"); + private static final TensorType BINARY_QUERY_TYPE = + TensorType.fromSpec("tensor(qt{},x[16])"); + private static final TensorType SINGLE_BINARY_TYPE = + TensorType.fromSpec("tensor(x[16])"); + + private final Embedder embedder; + + @Inject + public ColBertSearcher(ComponentRegistry embedders) { + this.embedder = embedders.getComponent("colbert"); + } + + @Override + public Result search(Query query, Execution execution) { + String queryText = query.getModel().getQueryString(); + if (queryText == null || queryText.isBlank()) { + return execution.search(query); + } + + // Skip if query tensors are already set (client-side embedding) + if (query.getRanking().getFeatures().getTensor("query(qt)").isPresent()) { + return execution.search(query); + } + + // 1. Embed query → float multi-vector + var context = new Embedder.Context("query(qt)"); + Tensor floatTensor = embedder.embed(queryText, context, FLOAT_QUERY_TYPE); + query.getRanking().getFeatures().put("query(qt)", floatTensor); + + // 2. Count tokens + int nTokens = 0; + for (var iter = floatTensor.cellIterator(); iter.hasNext(); ) { + var cell = iter.next(); + int mapped = (int) cell.getKey().numericLabel(0); + nTokens = Math.max(nTokens, mapped + 1); + } + nTokens = Math.min(nTokens, MAX_QUERY_TOKENS); + int dim = 128; + int packedDim = dim / 8; + + // 3. Binarise tokens → qtb (multi-vector) + rqN (single-vector per token) + Tensor.Builder qtbBuilder = Tensor.Builder.of(BINARY_QUERY_TYPE); + Tensor[] rqTensors = new Tensor[nTokens]; + + for (int t = 0; t < nTokens; t++) { + Tensor.Builder rqBuilder = Tensor.Builder.of(SINGLE_BINARY_TYPE); + BitSet bits = new BitSet(8); + int key = 0; + for (int d = 0; d < dim; d++) { + int bitIndex = 7 - (d % 8); + double value = floatTensor.get(TensorAddress.of(t, d)); + if (value > 0.0) bits.set(bitIndex); + else bits.clear(bitIndex); + + if ((d + 1) % 8 == 0) { + byte[] bytes = bits.toByteArray(); + byte packed = bytes.length == 0 ? 0 : bytes[0]; + qtbBuilder.cell(TensorAddress.of(t, key), packed); + rqBuilder.cell(TensorAddress.of(key), packed); + key++; + bits = new BitSet(8); + } + } + rqTensors[t] = rqBuilder.build(); + } + query.getRanking().getFeatures().put("query(qtb)", qtbBuilder.build()); + + // 4. Build nearestNeighbor OR query + OrItem nnOr = new OrItem(); + for (int t = 0; t < nTokens; t++) { + String rqName = "rq" + t; + query.getRanking().getFeatures().put("query(" + rqName + ")", rqTensors[t]); + + NearestNeighborItem nn = new NearestNeighborItem(BINARY_FIELD, rqName); + nn.setTargetHits(TARGET_HITS_PER_TOKEN); + nn.setAllowApproximate(true); + nnOr.addItem(nn); + } + + // 5. Set as query root (nearestNeighbor is the sole retrieval mechanism) + query.getModel().getQueryTree().setRoot(nnOr); + + return execution.search(query); + } +} diff --git a/colbert-pooling/src/main/java/ai/vespa/colbert/HierarchicalTokenPooling.java b/colbert-pooling/src/main/java/ai/vespa/colbert/HierarchicalTokenPooling.java new file mode 100644 index 000000000..f100cb84c --- /dev/null +++ b/colbert-pooling/src/main/java/ai/vespa/colbert/HierarchicalTokenPooling.java @@ -0,0 +1,314 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.colbert; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.Map; + +/** + * Hierarchical token pooling for ColBERT multi-vector embeddings. + *

+ * Implements Ward's agglomerative clustering using the Nearest-Neighbor Chain + * (NNC) algorithm in O(n²), matching the behaviour of the Rust reference at + * https://github.com/lightonai/next-plaid/blob/main/next-plaid-onnx/src/hierarchy.rs + *

+ * The pooling pipeline: + *

    + *
  1. Compute pairwise cosine distances between token embeddings.
  2. + *
  3. Build a Ward linkage dendrogram via NNC.
  4. + *
  5. Cut the dendrogram to ceil(n / poolFactor) clusters.
  6. + *
  7. Replace each cluster with its L2-normalised centroid.
  8. + *
  9. Preserve the CLS token unchanged.
  10. + *
+ */ +public final class HierarchicalTokenPooling { + + private HierarchicalTokenPooling() {} + + // --------------------------------------------------------------- + // Pairwise cosine distance – condensed upper-triangular form + // --------------------------------------------------------------- + + /** + * Compute pairwise cosine distances between rows of a flat row-major embedding matrix. + * + * @param embeddings row-major float array of shape (n, dim) + * @param n number of embeddings (rows) + * @param dim embedding dimension + * @return condensed distance vector of length n*(n-1)/2 + */ + public static double[] pdistCosine(double[] embeddings, int n, int dim) { + double[] norms = new double[n]; + for (int i = 0; i < n; i++) { + double s = 0; + int off = i * dim; + for (int d = 0; d < dim; d++) { + double v = embeddings[off + d]; + s += v * v; + } + norms[i] = Math.sqrt(s); + if (norms[i] < 1e-10) norms[i] = 1e-10; + } + + int len = n * (n - 1) / 2; + double[] dists = new double[len]; + int idx = 0; + for (int i = 0; i < n; i++) { + int offI = i * dim; + for (int j = i + 1; j < n; j++) { + double dot = 0; + int offJ = j * dim; + for (int d = 0; d < dim; d++) { + dot += embeddings[offI + d] * embeddings[offJ + d]; + } + double sim = dot / (norms[i] * norms[j]); + dists[idx++] = Math.max(0.0, Math.min(2.0, 1.0 - sim)); + } + } + return dists; + } + + // --------------------------------------------------------------- + // Ward linkage – Nearest-Neighbor Chain algorithm + // --------------------------------------------------------------- + + /** + * Ward's method hierarchical clustering using the NNC algorithm. + * + * @param dists condensed cosine distance vector from {@link #pdistCosine} + * @param n number of observations + * @return linkage matrix of shape (n-1, 4): [id_a, id_b, distance, size] + */ + public static double[][] linkageWard(double[] dists, int n) { + if (n <= 1) return new double[0][4]; + + int total = 2 * n - 1; + + // Full symmetric squared-distance matrix for all clusters (original + merged) + double[][] D = new double[total][total]; + for (double[] row : D) Arrays.fill(row, Double.MAX_VALUE); + + int idx = 0; + for (int i = 0; i < n; i++) { + for (int j = i + 1; j < n; j++) { + double d2 = dists[idx] * dists[idx]; + D[i][j] = d2; + D[j][i] = d2; + idx++; + } + } + + int[] sizes = new int[total]; + Arrays.fill(sizes, 0, n, 1); + + boolean[] active = new boolean[total]; + Arrays.fill(active, 0, n, true); + + double[][] Z = new double[n - 1][4]; + int[] chain = new int[n + 1]; + int chainSize = 0; + + for (int step = 0; step < n - 1; step++) { + if (chainSize == 0) { + for (int i = 0; i < total; i++) { + if (active[i]) { + chain[chainSize++] = i; + break; + } + } + } + + int a, b; + while (true) { + int cur = chain[chainSize - 1]; + // Find nearest active neighbour + double minD = Double.MAX_VALUE; + int nn = -1; + for (int k = 0; k < total; k++) { + if (k != cur && active[k] && D[cur][k] < minD) { + minD = D[cur][k]; + nn = k; + } + } + if (chainSize >= 2 && nn == chain[chainSize - 2]) { + b = chain[--chainSize]; + a = chain[--chainSize]; + break; + } + chain[chainSize++] = nn; + } + + if (a > b) { int t = a; a = b; b = t; } + + double mergeD2 = D[a][b]; + int newId = n + step; + Z[step][0] = a; + Z[step][1] = b; + Z[step][2] = Math.sqrt(Math.max(mergeD2, 0.0)); + Z[step][3] = sizes[a] + sizes[b]; + sizes[newId] = sizes[a] + sizes[b]; + + active[a] = false; + active[b] = false; + + // Ward distance update – vectorised over active clusters + double na = sizes[a], nb = sizes[b]; + for (int k = 0; k < total; k++) { + if (active[k]) { + double nk = sizes[k]; + double dNew = ((na + nk) * D[a][k] + (nb + nk) * D[b][k] - nk * mergeD2) / (na + nb + nk); + D[newId][k] = dNew; + D[k][newId] = dNew; + } + } + active[newId] = true; + } + return Z; + } + + // --------------------------------------------------------------- + // Flat cluster assignment – maxclust criterion + // --------------------------------------------------------------- + + private static int find(int[] parent, int x) { + while (parent[x] != x) { + parent[x] = parent[parent[x]]; + x = parent[x]; + } + return x; + } + + /** + * Cut a Ward dendrogram to form exactly {@code maxClusters} flat clusters. + *

+ * Selects the merges with the smallest distances (robust to non-monotonic + * merge orderings from the NNC algorithm). + * + * @param Z linkage matrix from {@link #linkageWard} + * @param n number of original observations + * @param maxClusters desired number of clusters + * @return 1-indexed cluster labels of length n + */ + public static int[] fclusterMaxclust(double[][] Z, int n, int maxClusters) { + maxClusters = Math.max(1, Math.min(maxClusters, n)); + int nMerges = n - maxClusters; + + if (nMerges == 0) { + int[] labels = new int[n]; + for (int i = 0; i < n; i++) labels[i] = i + 1; + return labels; + } + + // Sort merge indices by ascending distance + Integer[] sortedIdx = new Integer[Z.length]; + for (int i = 0; i < sortedIdx.length; i++) sortedIdx[i] = i; + Arrays.sort(sortedIdx, (x, y) -> Double.compare(Z[x][2], Z[y][2])); + + int[] parent = new int[2 * n - 1]; + for (int i = 0; i < parent.length; i++) parent[i] = i; + + for (int m = 0; m < nMerges; m++) { + int mi = sortedIdx[m]; + int a = (int) Z[mi][0], b = (int) Z[mi][1]; + int newId = n + mi; + parent[find(parent, a)] = newId; + parent[find(parent, b)] = newId; + } + + int[] labels = new int[n]; + Map clusterMap = new HashMap<>(); + int nextLabel = 1; + for (int i = 0; i < n; i++) { + int root = find(parent, i); + Integer label = clusterMap.get(root); + if (label == null) { + label = nextLabel++; + clusterMap.put(root, label); + } + labels[i] = label; + } + return labels; + } + + // --------------------------------------------------------------- + // Token pooling entry-point + // --------------------------------------------------------------- + + /** + * Pool ColBERT token embeddings via hierarchical clustering. + * + * @param embeddings row-major array of shape (nTokens, dim) + * @param nTokens number of tokens + * @param dim embedding dimension + * @param poolFactor reduction factor (2 = halve tokens) + * @param skipFirst preserve first token (CLS) unchanged + * @return row-major array of shape (nPooled, dim) with L2-normalised centroids + */ + public static double[] poolTokens(double[] embeddings, int nTokens, int dim, + int poolFactor, boolean skipFirst) { + if (poolFactor <= 1 || nTokens <= 1) { + return Arrays.copyOf(embeddings, nTokens * dim); + } + + int startIdx = skipFirst ? 1 : 0; + int nToPool = nTokens - startIdx; + + if (nToPool <= 1) { + return Arrays.copyOf(embeddings, nTokens * dim); + } + + int nClusters = Math.max(1, (int) Math.ceil((double) nToPool / poolFactor)); + if (nClusters >= nToPool) { + return Arrays.copyOf(embeddings, nTokens * dim); + } + + // Extract tokens to pool (skip CLS if requested) + double[] tokens = new double[nToPool * dim]; + System.arraycopy(embeddings, startIdx * dim, tokens, 0, nToPool * dim); + + // Cluster + double[] dists = pdistCosine(tokens, nToPool, dim); + double[][] Z = linkageWard(dists, nToPool); + int[] labels = fclusterMaxclust(Z, nToPool, nClusters); + + // Compute centroids + int nActualClusters = 0; + for (int l : labels) nActualClusters = Math.max(nActualClusters, l); + + double[] centroids = new double[nActualClusters * dim]; + int[] counts = new int[nActualClusters]; + for (int i = 0; i < nToPool; i++) { + int c = labels[i] - 1; // 1-indexed to 0-indexed + counts[c]++; + int cOff = c * dim; + int tOff = i * dim; + for (int d = 0; d < dim; d++) { + centroids[cOff + d] += tokens[tOff + d]; + } + } + for (int c = 0; c < nActualClusters; c++) { + int cOff = c * dim; + double norm = 0; + for (int d = 0; d < dim; d++) { + centroids[cOff + d] /= counts[c]; + norm += centroids[cOff + d] * centroids[cOff + d]; + } + norm = Math.sqrt(norm); + if (norm < 1e-10) norm = 1e-10; + for (int d = 0; d < dim; d++) { + centroids[cOff + d] /= norm; + } + } + + // Build result: CLS (if skipFirst) + pooled centroids + int nPooled = (skipFirst ? 1 : 0) + nActualClusters; + double[] result = new double[nPooled * dim]; + + if (skipFirst) { + System.arraycopy(embeddings, 0, result, 0, dim); // CLS token + } + System.arraycopy(centroids, 0, result, (skipFirst ? 1 : 0) * dim, nActualClusters * dim); + + return result; + } +} diff --git a/colbert-pooling/src/main/java/ai/vespa/colbert/PoolingColBertEmbedder.java b/colbert-pooling/src/main/java/ai/vespa/colbert/PoolingColBertEmbedder.java new file mode 100644 index 000000000..011aabbf2 --- /dev/null +++ b/colbert-pooling/src/main/java/ai/vespa/colbert/PoolingColBertEmbedder.java @@ -0,0 +1,352 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.colbert; + +import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; +import com.yahoo.ai.vespa.colbert.PoolingColbertEmbedderConfig; +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.annotation.Inject; +import com.yahoo.language.huggingface.HuggingFaceTokenizer; +import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; + +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.BitSet; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST; + +/** + * A ColBERT embedder with optional hierarchical token pooling. + *

+ * Constructs input sequences matching pylate's + * ColBERT encoding: {@code [CLS] [D] search_document: [SEP]} for documents and + * {@code [CLS] [Q] search_query: [MASK]...} for queries. The {@code search_document:} / + * {@code search_query:} prefixes are configurable via {@code prependDocument} / {@code prependQuery}. + *

+ * When {@code poolFactor > 1}, applies Ward's agglomerative clustering + * (see {@link HierarchicalTokenPooling}) to merge semantically similar tokens, + * reducing the multi-vector representation size while preserving retrieval quality. + *

+ * Configured via {@code pooling-colbert-embedder.def}: + *

    + *
  • {@code poolFactor=0} — standard ColBERT (no pooling)
  • + *
  • {@code poolFactor=2} — merge semantically similar tokens, keep ~half the vectors
  • + *
  • {@code poolFactor=3} — keep roughly a third, etc.
  • + *
+ */ +public class PoolingColBertEmbedder extends AbstractComponent implements Embedder { + + private static final String PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"; + + private final Embedder.Runtime runtime; + private final String inputIdsName; + private final String attentionMaskName; + private final String outputName; + private final HuggingFaceTokenizer tokenizer; + private final OnnxEvaluator evaluator; + private final int maxTransformerTokens; + private final int maxQueryTokens; + private final int maxDocumentTokens; + private final long startSequenceToken; + private final long endSequenceToken; + private final long maskSequenceToken; + private final long padSequenceToken; + private final long querySequenceToken; + private final long documentSequenceToken; + private final Set skipTokens; + private final int poolFactor; + private final String prependQuery; + private final String prependDocument; + + record TransformerInput(List inputIds, List attentionMask) {} + record EmbedderCacheKey(String embedderId, Object embeddedValue) {} + record EmbeddingResult(int inputIdSize, Map outputs) {} + + @Inject + public PoolingColBertEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, PoolingColbertEmbedderConfig config) { + this.runtime = runtime; + this.inputIdsName = config.transformerInputIds(); + this.attentionMaskName = config.transformerAttentionMask(); + this.outputName = config.transformerOutput(); + this.maxTransformerTokens = config.transformerMaxTokens(); + this.maxDocumentTokens = Math.min(config.maxDocumentTokens(), maxTransformerTokens); + this.maxQueryTokens = Math.min(config.maxQueryTokens(), maxTransformerTokens); + this.startSequenceToken = config.transformerStartSequenceToken(); + this.endSequenceToken = config.transformerEndSequenceToken(); + this.maskSequenceToken = config.transformerMaskToken(); + this.padSequenceToken = config.transformerPadToken(); + this.querySequenceToken = config.queryTokenId(); + this.documentSequenceToken = config.documentTokenId(); + this.poolFactor = config.poolFactor(); + this.prependQuery = config.prependQuery(); + this.prependDocument = config.prependDocument(); + + Path tokenizerPath = config.tokenizerPath(); + var builder = new HuggingFaceTokenizer.Builder() + .addSpecialTokens(false) + .addDefaultModel(tokenizerPath) + .setPadding(false); + var info = HuggingFaceTokenizer.getModelInfo(tokenizerPath); + if (info.maxLength() == -1 || info.truncation() != LONGEST_FIRST) { + int maxLength = info.maxLength() > 0 && info.maxLength() <= maxTransformerTokens + ? info.maxLength() + : maxTransformerTokens; + builder.setTruncation(true).setMaxLength(maxLength); + } + this.tokenizer = builder.build(); + + this.skipTokens = new HashSet<>(); + PUNCTUATION.chars().forEach(c -> + this.skipTokens.addAll(tokenizer.encode(Character.toString((char) c), null).ids())); + + var onnxOpts = OnnxEvaluatorOptions.createDefault(); + this.evaluator = onnx.evaluatorOf(config.transformerModelPath().toString(), onnxOpts); + validateModel(); + } + + private void validateModel() { + Map inputs = evaluator.getInputInfo(); + validateName(inputs, inputIdsName, "input"); + validateName(inputs, attentionMaskName, "input"); + Map outputs = evaluator.getOutputInfo(); + validateName(outputs, outputName, "output"); + } + + private static void validateName(Map types, String name, String type) { + if (!types.containsKey(name)) { + throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. " + + "Model contains: " + String.join(",", types.keySet())); + } + } + + @Override + public List embed(String text, Context context) { + throw new UnsupportedOperationException("This embedder only supports embed with tensor type"); + } + + @Override + public Tensor embed(String text, Context context, TensorType tensorType) { + if (!validTensorType(tensorType)) + throw new IllegalArgumentException("Invalid tensor target. Expected a mixed 2-d mapped-indexed tensor, got " + tensorType); + text = prependInstruction(text, context); + if (context.getDestinationType() == Context.DestinationType.QUERY) + return embedQuery(text, context, tensorType); + else + return embedDocument(text, context, tensorType); + } + + String prependInstruction(String text, Context context) { + if (prependQuery != null && !prependQuery.isEmpty() + && context.getDestinationType() == Context.DestinationType.QUERY) { + return prependQuery + text; + } + if (prependDocument != null && !prependDocument.isEmpty()) { + return prependDocument + text; + } + return text; + } + + @Override + public void deconstruct() { + evaluator.close(); + tokenizer.close(); + } + + // ----- Query embedding (no pooling, pad with MASK) ----- + + private Tensor embedQuery(String text, Context context, TensorType tensorType) { + if (tensorType.valueType() == TensorType.Value.INT8) + throw new IllegalArgumentException("ColBERT query embedding does not support int8 tensor type"); + + EmbeddingResult result = lookupOrEvaluate(context, text, true); + var output = (IndexedTensor) result.outputs.get(outputName); + return toFloatTensor(output, tensorType, result.inputIdSize); + } + + // ----- Document embedding (with optional pooling) ----- + + private Tensor embedDocument(String text, Context context, TensorType tensorType) { + EmbeddingResult result = lookupOrEvaluate(context, text, false); + var modelOutput = (IndexedTensor) result.outputs.get(outputName); + + if (poolFactor <= 1) { + // No pooling – behave like standard ColBertEmbedder + if (tensorType.valueType() == TensorType.Value.INT8) + return toBitTensor(modelOutput, tensorType, result.inputIdSize); + else + return toFloatTensor(modelOutput, tensorType, result.inputIdSize); + } + + // Extract float embeddings from model output + int nTokens = result.inputIdSize; + int dim = (int) modelOutput.shape()[2]; + double[] embeddings = new double[nTokens * dim]; + for (int t = 0; t < nTokens; t++) { + for (int d = 0; d < dim; d++) { + embeddings[t * dim + d] = modelOutput.get(0, t, d); + } + } + + // Apply hierarchical token pooling + double[] pooled = HierarchicalTokenPooling.poolTokens(embeddings, nTokens, dim, poolFactor, true); + int nPooled = pooled.length / dim; + + // Build output tensor + if (tensorType.valueType() == TensorType.Value.INT8) { + return toBitTensorFromPooled(pooled, nPooled, dim, tensorType); + } else { + return toFloatTensorFromPooled(pooled, nPooled, dim, tensorType); + } + } + + // ----- Model inference ----- + + TransformerInput buildTransformerInput(List tokens, int maxTokens, boolean isQuery) { + if (!isQuery) + tokens = tokens.stream().filter(t -> !skipTokens.contains(t)).toList(); + + List inputIds = new ArrayList<>(maxTokens); + List attentionMask = new ArrayList<>(maxTokens); + + if (tokens.size() > maxTokens - 3) + tokens = tokens.subList(0, maxTokens - 3); + + inputIds.add(startSequenceToken); + inputIds.add(isQuery ? querySequenceToken : documentSequenceToken); + inputIds.addAll(tokens); + inputIds.add(endSequenceToken); + + int inputLength = inputIds.size(); + long padTokenId = isQuery ? maskSequenceToken : padSequenceToken; + int padding = isQuery ? maxTokens - inputLength : 0; + + for (int i = 0; i < padding; i++) inputIds.add(padTokenId); + for (int i = 0; i < inputLength; i++) attentionMask.add(1L); + for (int i = 0; i < padding; i++) attentionMask.add(0L); + + return new TransformerInput(inputIds, attentionMask); + } + + EmbeddingResult lookupOrEvaluate(Context context, String text, boolean isQuery) { + var key = new EmbedderCacheKey(context.getEmbedderId(), text); + return context.computeCachedValueIfAbsent(key, () -> evaluate(context, text, isQuery)); + } + + private EmbeddingResult evaluate(Context context, String text, boolean isQuery) { + var start = System.nanoTime(); + var encoding = tokenizer.encode(text, context.getLanguage()); + runtime.sampleSequenceLength(encoding.ids().size(), context); + + TransformerInput input = buildTransformerInput( + encoding.ids(), isQuery ? maxQueryTokens : maxDocumentTokens, isQuery); + + Tensor inputIdsTensor = createTensorRepresentation(input.inputIds, "d1"); + Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1"); + + Map outputs = evaluator.evaluate(Map.of( + inputIdsName, inputIdsTensor.expand("d0"), + attentionMaskName, attentionMaskTensor.expand("d0"))); + + runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); + return new EmbeddingResult(input.inputIds.size(), outputs); + } + + // ----- Tensor construction ----- + + static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) { + int wantedDim = type.indexedSubtype().dimensions().get(0).size().get().intValue(); + int resultDim = (int) result.shape()[2]; + if (wantedDim > resultDim) + throw new IllegalArgumentException("Cannot map " + resultDim + " dims into " + wantedDim); + + Tensor.Builder builder = Tensor.Builder.of(type); + for (int t = 0; t < nTokens; t++) + for (int d = 0; d < wantedDim; d++) + builder.cell(TensorAddress.of(t, d), result.get(0, t, d)); + return builder.build(); + } + + static Tensor toBitTensor(IndexedTensor result, TensorType type, int nTokens) { + int wantedDim = type.indexedSubtype().dimensions().get(0).size().get().intValue(); + int floatDim = 8 * wantedDim; + if (floatDim > (int) result.shape()[2]) + throw new IllegalArgumentException("Cannot pack " + result.shape()[2] + " dims into " + wantedDim); + + Tensor.Builder builder = Tensor.Builder.of(type); + for (int t = 0; t < nTokens; t++) { + BitSet bits = new BitSet(8); + int key = 0; + for (int d = 0; d < floatDim; d++) { + int bitIndex = 7 - (d % 8); + if (result.get(0, t, d) > 0.0) bits.set(bitIndex); else bits.clear(bitIndex); + if ((d + 1) % 8 == 0) { + byte[] bytes = bits.toByteArray(); + builder.cell(TensorAddress.of(t, key), bytes.length == 0 ? 0 : bytes[0]); + key++; + bits = new BitSet(8); + } + } + } + return builder.build(); + } + + /** Build a float tensor from pooled double[] embeddings. */ + private static Tensor toFloatTensorFromPooled(double[] pooled, int nPooled, int dim, TensorType type) { + int wantedDim = type.indexedSubtype().dimensions().get(0).size().get().intValue(); + if (wantedDim > dim) + throw new IllegalArgumentException("Cannot map " + dim + " dims into " + wantedDim); + + Tensor.Builder builder = Tensor.Builder.of(type); + for (int t = 0; t < nPooled; t++) + for (int d = 0; d < wantedDim; d++) + builder.cell(TensorAddress.of(t, d), pooled[t * dim + d]); + return builder.build(); + } + + /** Build an int8 bit-packed tensor from pooled double[] embeddings. */ + private static Tensor toBitTensorFromPooled(double[] pooled, int nPooled, int dim, TensorType type) { + int wantedDim = type.indexedSubtype().dimensions().get(0).size().get().intValue(); + int floatDim = 8 * wantedDim; + if (floatDim > dim) + throw new IllegalArgumentException("Cannot pack " + dim + " dims into " + wantedDim); + + Tensor.Builder builder = Tensor.Builder.of(type); + for (int t = 0; t < nPooled; t++) { + BitSet bits = new BitSet(8); + int key = 0; + int tOff = t * dim; + for (int d = 0; d < floatDim; d++) { + int bitIndex = 7 - (d % 8); + if (pooled[tOff + d] > 0.0) bits.set(bitIndex); else bits.clear(bitIndex); + if ((d + 1) % 8 == 0) { + byte[] bytes = bits.toByteArray(); + builder.cell(TensorAddress.of(t, key), bytes.length == 0 ? 0 : bytes[0]); + key++; + bits = new BitSet(8); + } + } + } + return builder.build(); + } + + private boolean validTensorType(TensorType target) { + return target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1; + } + + private IndexedTensor createTensorRepresentation(List input, String dimension) { + int size = input.size(); + TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build(); + IndexedTensor.Builder builder = IndexedTensor.Builder.of(type); + for (int i = 0; i < size; i++) builder.cell(input.get(i), i); + return builder.build(); + } +} diff --git a/colbert-pooling/src/main/resources/configdefinitions/pooling-colbert-embedder.def b/colbert-pooling/src/main/resources/configdefinitions/pooling-colbert-embedder.def new file mode 100644 index 000000000..eb6bedbba --- /dev/null +++ b/colbert-pooling/src/main/resources/configdefinitions/pooling-colbert-embedder.def @@ -0,0 +1,33 @@ +namespace=ai.vespa.colbert + +# Model files – deployed with the application package +transformerModelPath path +tokenizerPath path + +# Token limits +maxQueryTokens int default=39 +maxDocumentTokens int default=512 +transformerMaxTokens int default=518 + +# Special token IDs (defaults for lightonai/ColBERT-Zero) +transformerStartSequenceToken int default=50281 +transformerEndSequenceToken int default=50282 +transformerMaskToken int default=50284 +transformerPadToken int default=50284 +queryTokenId int default=50368 +documentTokenId int default=50369 + +# Text to prepend to the input in query context (e.g. "search_query: ") +prependQuery string default="search_query: " + +# Text to prepend to the input in document context (e.g. "search_document: ") +prependDocument string default="search_document: " + +# ONNX input/output tensor names +transformerInputIds string default="input_ids" +transformerAttentionMask string default="attention_mask" +transformerOutput string default="output" + +# Hierarchical token pooling factor. +# 0 = no pooling (standard ColBERT); 2 = halve token count; 3 = keep ~1/3, etc. +poolFactor int default=2 diff --git a/colbert-pooling/src/test/application/tests/system-test/feed-and-search-test.json b/colbert-pooling/src/test/application/tests/system-test/feed-and-search-test.json new file mode 100644 index 000000000..fe11e8d07 --- /dev/null +++ b/colbert-pooling/src/test/application/tests/system-test/feed-and-search-test.json @@ -0,0 +1,84 @@ +[ + { + "name": "colbert-pooling system test", + "defaults": { + "cluster": "default", + "parameters": { + "timeout": "10s" + } + }, + "steps": [ + { + "name": "Feed document 0", + "request": { + "method": "POST", + "uri": "/document/v1/doc/doc/docid/0", + "body": { + "fields": { + "doc_id": "0", + "title": "Solar System Overview", + "text": "The solar system consists of the Sun and the celestial bodies that orbit it, including eight planets, their moons, dwarf planets, asteroids, and comets." + } + } + }, + "response": { "code": 200 } + }, + { + "name": "Feed document 1", + "request": { + "method": "POST", + "uri": "/document/v1/doc/doc/docid/1", + "body": { + "fields": { + "doc_id": "1", + "title": "Machine Learning Basics", + "text": "Machine learning is a branch of artificial intelligence where computers learn patterns from data without being explicitly programmed." + } + } + }, + "response": { "code": 200 } + }, + { + "name": "Feed document 2", + "request": { + "method": "POST", + "uri": "/document/v1/doc/doc/docid/2", + "body": { + "fields": { + "doc_id": "2", + "title": "Ocean Currents", + "text": "Ocean currents are continuous movements of seawater driven by wind, temperature differences, salinity, and the Earth rotation." + } + } + }, + "response": { "code": 200 } + }, + { + "name": "BM25 query", + "request": { + "method": "POST", + "uri": "/search/", + "body": { + "yql": "select * from doc where userQuery()", + "query": "machine learning artificial intelligence", + "hits": 3 + } + }, + "response": { + "code": 200, + "body": { + "root": { + "children": [ + { + "fields": { + "doc_id": "1" + } + } + ] + } + } + } + } + ] + } +] diff --git a/colbert-pooling/src/test/java/ai/vespa/colbert/HierarchicalTokenPoolingTest.java b/colbert-pooling/src/test/java/ai/vespa/colbert/HierarchicalTokenPoolingTest.java new file mode 100644 index 000000000..bb0e16cb0 --- /dev/null +++ b/colbert-pooling/src/test/java/ai/vespa/colbert/HierarchicalTokenPoolingTest.java @@ -0,0 +1,244 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.colbert; + +import org.junit.jupiter.api.Test; + +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.*; + +/** + * Tests for {@link HierarchicalTokenPooling}. + *

+ * Covers: pdistCosine, linkageWard, fclusterMaxclust, poolTokens correctness + * and performance for realistic ColBERT document sizes. + */ +class HierarchicalTokenPoolingTest { + + private static final int DIM = 128; + + /** Generate L2-normalised random embeddings in row-major flat array. */ + private static double[] randomEmbeddings(int n, int dim, long seed) { + Random rng = new Random(seed); + double[] emb = new double[n * dim]; + for (int i = 0; i < n; i++) { + double norm = 0; + for (int d = 0; d < dim; d++) { + emb[i * dim + d] = rng.nextGaussian(); + norm += emb[i * dim + d] * emb[i * dim + d]; + } + norm = Math.sqrt(norm); + for (int d = 0; d < dim; d++) emb[i * dim + d] /= norm; + } + return emb; + } + + // --------------------------------------------------------------- + // pdistCosine + // --------------------------------------------------------------- + + @Test + void pdistCosine_identicalVectors() { + double[] emb = randomEmbeddings(1, 4, 42); + double[] triple = new double[3 * 4]; + System.arraycopy(emb, 0, triple, 0, 4); + System.arraycopy(emb, 0, triple, 4, 4); + System.arraycopy(emb, 0, triple, 8, 4); + + double[] dists = HierarchicalTokenPooling.pdistCosine(triple, 3, 4); + assertEquals(3, dists.length); + for (double d : dists) assertEquals(0.0, d, 1e-10); + } + + @Test + void pdistCosine_orthogonalVectors() { + // Three orthogonal unit vectors + double[] emb = new double[3 * 3]; + emb[0] = 1; emb[4] = 1; emb[8] = 1; + double[] dists = HierarchicalTokenPooling.pdistCosine(emb, 3, 3); + assertEquals(3, dists.length); + for (double d : dists) assertEquals(1.0, d, 1e-10); + } + + @Test + void pdistCosine_outputLength() { + int n = 20; + double[] emb = randomEmbeddings(n, DIM, 42); + double[] dists = HierarchicalTokenPooling.pdistCosine(emb, n, DIM); + assertEquals(n * (n - 1) / 2, dists.length); + for (double d : dists) { + assertTrue(d >= 0.0 && d <= 2.0, "Distance out of range: " + d); + } + } + + // --------------------------------------------------------------- + // linkageWard + // --------------------------------------------------------------- + + @Test + void linkageWard_threePoints() { + // Two close points + one far point + double[] emb = {1.0, 0.0, 0.9, 0.1, 0.0, 1.0}; + // Normalise + for (int i = 0; i < 3; i++) { + double n = Math.sqrt(emb[i*2]*emb[i*2] + emb[i*2+1]*emb[i*2+1]); + emb[i*2] /= n; emb[i*2+1] /= n; + } + double[] dists = HierarchicalTokenPooling.pdistCosine(emb, 3, 2); + double[][] Z = HierarchicalTokenPooling.linkageWard(dists, 3); + + assertEquals(2, Z.length); + // First merge: the two closest (0 and 1) + assertTrue((int) Z[0][0] == 0 || (int) Z[0][0] == 1); + assertTrue((int) Z[0][1] == 0 || (int) Z[0][1] == 1); + assertEquals(2, (int) Z[0][3]); + // All distances non-negative + for (double[] row : Z) assertTrue(row[2] >= 0); + } + + @Test + void linkageWard_clusterSizes() { + int n = 30; + double[] emb = randomEmbeddings(n, DIM, 123); + double[] dists = HierarchicalTokenPooling.pdistCosine(emb, n, DIM); + double[][] Z = HierarchicalTokenPooling.linkageWard(dists, n); + + assertEquals(n - 1, Z.length); + for (double[] row : Z) assertTrue(row[3] >= 2, "Merge size must be >= 2"); + assertEquals(n, (int) Z[n - 2][3], "Final merge must contain all observations"); + } + + // --------------------------------------------------------------- + // fclusterMaxclust + // --------------------------------------------------------------- + + @Test + void fcluster_correctNumberOfClusters() { + int n = 50; + double[] emb = randomEmbeddings(n, DIM, 42); + double[] dists = HierarchicalTokenPooling.pdistCosine(emb, n, DIM); + double[][] Z = HierarchicalTokenPooling.linkageWard(dists, n); + + for (int k : new int[]{1, 2, 5, 10, 25, 50}) { + int[] labels = HierarchicalTokenPooling.fclusterMaxclust(Z, n, k); + assertEquals(n, labels.length); + int maxLabel = 0; + for (int l : labels) maxLabel = Math.max(maxLabel, l); + assertEquals(k, maxLabel, "Expected " + k + " clusters, got " + maxLabel); + for (int l : labels) assertTrue(l >= 1, "Labels must be 1-indexed"); + } + } + + // --------------------------------------------------------------- + // poolTokens + // --------------------------------------------------------------- + + @Test + void poolTokens_outputShape() { + int nTokens = 50; + double[] emb = randomEmbeddings(nTokens, DIM, 42); + double[] pooled = HierarchicalTokenPooling.poolTokens(emb, nTokens, DIM, 2, true); + // CLS (1) + ceil(49/2) = 1 + 25 = 26 + assertEquals(26 * DIM, pooled.length); + } + + @Test + void poolTokens_clsPreserved() { + int nTokens = 20; + double[] emb = randomEmbeddings(nTokens, DIM, 42); + double[] pooled = HierarchicalTokenPooling.poolTokens(emb, nTokens, DIM, 2, true); + for (int d = 0; d < DIM; d++) { + assertEquals(emb[d], pooled[d], 1e-12, "CLS token dimension " + d + " not preserved"); + } + } + + @Test + void poolTokens_normalised() { + int nTokens = 40; + double[] emb = randomEmbeddings(nTokens, DIM, 42); + double[] pooled = HierarchicalTokenPooling.poolTokens(emb, nTokens, DIM, 2, true); + int nPooled = pooled.length / DIM; + for (int t = 0; t < nPooled; t++) { + double norm = 0; + for (int d = 0; d < DIM; d++) { + norm += pooled[t * DIM + d] * pooled[t * DIM + d]; + } + assertEquals(1.0, Math.sqrt(norm), 1e-8, "Token " + t + " not normalised"); + } + } + + @Test + void poolTokens_poolFactor3() { + int nTokens = 31; + double[] emb = randomEmbeddings(nTokens, DIM, 42); + double[] pooled = HierarchicalTokenPooling.poolTokens(emb, nTokens, DIM, 3, true); + // CLS (1) + ceil(30/3) = 1 + 10 = 11 + assertEquals(11 * DIM, pooled.length); + } + + @Test + void poolTokens_noPooling() { + int nTokens = 20; + double[] emb = randomEmbeddings(nTokens, DIM, 42); + double[] pooled = HierarchicalTokenPooling.poolTokens(emb, nTokens, DIM, 0, true); + assertEquals(nTokens * DIM, pooled.length); + assertArrayEquals(emb, pooled, 1e-15); + } + + @Test + void poolTokens_singleToken() { + double[] emb = randomEmbeddings(1, DIM, 42); + double[] pooled = HierarchicalTokenPooling.poolTokens(emb, 1, DIM, 2, false); + assertArrayEquals(emb, pooled, 1e-15); + } + + // --------------------------------------------------------------- + // Performance + // --------------------------------------------------------------- + + @Test + void performance_pooling512tokens() { + double[] emb = randomEmbeddings(512, DIM, 42); + // Warm-up + HierarchicalTokenPooling.poolTokens(emb, 512, DIM, 2, true); + + int runs = 3; + long t0 = System.nanoTime(); + for (int i = 0; i < runs; i++) { + HierarchicalTokenPooling.poolTokens(emb, 512, DIM, 2, true); + } + double ms = (System.nanoTime() - t0) / 1e6 / runs; + System.out.printf(" poolTokens n=512: %.1f ms%n", ms); + assertTrue(ms < 2000, "Pooling 512 tokens took " + ms + " ms, expected < 2000"); + } + + @Test + void performance_pooling128tokens() { + double[] emb = randomEmbeddings(128, DIM, 42); + HierarchicalTokenPooling.poolTokens(emb, 128, DIM, 2, true); + + int runs = 5; + long t0 = System.nanoTime(); + for (int i = 0; i < runs; i++) { + HierarchicalTokenPooling.poolTokens(emb, 128, DIM, 2, true); + } + double ms = (System.nanoTime() - t0) / 1e6 / runs; + System.out.printf(" poolTokens n=128: %.1f ms%n", ms); + assertTrue(ms < 500, "Pooling 128 tokens took " + ms + " ms, expected < 500"); + } + + @Test + void performance_pdist512() { + double[] emb = randomEmbeddings(512, DIM, 42); + HierarchicalTokenPooling.pdistCosine(emb, 512, DIM); + + int runs = 5; + long t0 = System.nanoTime(); + for (int i = 0; i < runs; i++) { + HierarchicalTokenPooling.pdistCosine(emb, 512, DIM); + } + double ms = (System.nanoTime() - t0) / 1e6 / runs; + System.out.printf(" pdistCosine n=512: %.1f ms%n", ms); + assertTrue(ms < 500, "pdist 512 took " + ms + " ms, expected < 500"); + } +} diff --git a/colbert-pooling/verify_correctness.py b/colbert-pooling/verify_correctness.py new file mode 100644 index 000000000..058d28029 --- /dev/null +++ b/colbert-pooling/verify_correctness.py @@ -0,0 +1,275 @@ +#!/usr/bin/env python3 +"""End-to-end correctness verification: Vespa vs pylate (model authors' reference). + +Both produce [CLS] [D] search_document: [SEP] sequences. pylate runs +FP32 PyTorch; Vespa runs the INT8 ONNX model. The expected cos_sim floor is +~0.94 per token (INT8 quantisation). + +Prerequisites: + uv pip install pylate onnxruntime tokenizers requests numpy + +Usage: + uv run python verify_correctness.py # full comparison + uv run python verify_correctness.py --no-vespa # pylate vs ONNX INT8 only +""" + +from __future__ import annotations + +import argparse +import string +import sys +import time +from pathlib import Path + +import numpy as np + +MODEL_DIR = Path("src/main/application/models") +VESPA_ENDPOINT = "http://localhost:8080" + +CLS, SEP, D_TOKEN = 50281, 50282, 50369 +PREPEND_DOC = "search_document: " + +TEST_DOC = ( + "The solar system consists of the Sun and the celestial bodies that " + "orbit it, including eight planets, their moons, dwarf planets, " + "asteroids, and comets. The four inner planets are rocky worlds, " + "while the outer planets are gas and ice giants." +) + +POOL_FACTORS = [2, 3, 4] + + +def cos(a, b): + na, nb = np.linalg.norm(a), np.linalg.norm(b) + return float(np.dot(a, b) / (na * nb)) if min(na, nb) > 1e-10 else 0.0 + + +def build_skip(tok_nospc): + skip = set() + for ch in string.punctuation: + skip.update(tok_nospc.encode(ch).ids) + return skip + + +def onnx_embed(text, tok_nospc, session, skip): + """Same sequence as Vespa PoolingColBertEmbedder: [CLS, D, filtered_tokens, SEP].""" + enc = tok_nospc.encode(PREPEND_DOC + text) + tids = [t for t in enc.ids if t not in skip][:512 - 3] + ids = [CLS, D_TOKEN] + tids + [SEP] + mask = [1] * len(ids) + (out,) = session.run(None, { + "input_ids": np.array([ids], dtype=np.int64), + "attention_mask": np.array([mask], dtype=np.int64), + }) + return out[0, :len(ids)].astype(np.float32) + + +def compare(name, ref, test, min_cos): + if ref.shape != test.shape: + print(f" FAIL {name}: shape {ref.shape} vs {test.shape}") + return False + sims = np.array([cos(ref[i], test[i]) for i in range(len(ref))]) + ok = sims.min() >= min_cos + print(f" {'PASS' if ok else 'FAIL'} {name}: {len(ref)} tokens, " + f"cos mean={sims.mean():.6f} min={sims.min():.6f}") + return ok + + +# ── Pooling (matches Java HierarchicalTokenPooling) ───────────────────────── + +def _pdist(emb): + norms = np.maximum(np.linalg.norm(emb, axis=1, keepdims=True), 1e-10) + s = (emb / norms) @ (emb / norms).T + r, c = np.triu_indices(len(emb), k=1) + return np.clip(1.0 - s[r, c], 0.0, 2.0).astype(np.float64) + +def _ward(d, n): + tot = 2*n-1; D = np.full((tot,tot), np.inf); idx=0 + for i in range(n): + for j in range(i+1,n): d2=float(d[idx])**2; D[i,j]=d2; D[j,i]=d2; idx+=1 + sz=np.ones(tot,dtype=np.int64); act=np.zeros(tot,dtype=bool); act[:n]=True + Z=np.empty((n-1,4)); ch=[] + for s in range(n-1): + if not ch: ch.append(int(np.argmax(act))) + while True: + r=D[ch[-1]].copy(); r[~act]=np.inf; r[ch[-1]]=np.inf; nn=int(np.argmin(r)) + if len(ch)>=2 and nn==ch[-2]: b=ch.pop(); a=ch.pop(); break + ch.append(nn) + if a>b: a,b=b,a + md=D[a,b]; nid=n+s; Z[s]=[a,b,np.sqrt(max(md,0)),sz[a]+sz[b]]; sz[nid]=sz[a]+sz[b] + act[a]=False; act[b]=False; ai=np.where(act)[0] + if len(ai): + na,nb=float(sz[a]),float(sz[b]); nk=sz[ai].astype(np.float64) + dn=((na+nk)*D[a,ai]+(nb+nk)*D[b,ai]-nk*md)/(na+nb+nk) + D[nid,ai]=dn; D[ai,nid]=dn + act[nid]=True + return Z + +def _fclust(Z,n,k): + k=max(1,min(k,n)); nm=n-k + if nm==0: return np.arange(1,n+1,dtype=np.int64) + p=list(range(2*n-1)) + def f(x): + while p[x]!=x: p[x]=p[p[x]]; x=p[x] + return x + si=np.argsort(Z[:,2]) + for m in range(nm): i=int(si[m]); a,b=int(Z[i,0]),int(Z[i,1]); p[f(a)]=n+i; p[f(b)]=n+i + lb=np.empty(n,dtype=np.int64); cm={}; nl=1 + for i in range(n): + r=f(i) + if r not in cm: cm[r]=nl; nl+=1 + lb[i]=cm[r] + return lb + +def pool(emb, pf): + e=emb.astype(np.float64); cls=e[0:1]; t=e[1:]; n=len(t) + if n<=1 or pf<=1: return e.astype(np.float32) + nc=max(1,int(np.ceil(n/pf))) + if nc>=n: return e.astype(np.float32) + Z=_ward(_pdist(t),n); lb=_fclust(Z,n,nc) + po=np.empty((len(np.unique(lb)),t.shape[1]),dtype=np.float64) + for i,l in enumerate(np.unique(lb)): + c=t[lb==l].mean(axis=0); po[i]=c/max(np.linalg.norm(c),1e-10) + return np.vstack([cls,po]).astype(np.float32) + +def binarize(emb): + e=emb.astype(np.float64); n,d=e.shape + b=(e>0).astype(np.uint8).reshape(n,d//8,8) + return np.sum(b*np.array([128,64,32,16,8,4,2,1],dtype=np.uint8),axis=2).astype(np.uint8).view(np.int8) + + +# ── Vespa ──────────────────────────────────────────────────────────────────── + +def vespa_feed(ep, did, text): + import requests + requests.post(f"{ep}/document/v1/doc/doc/docid/{did}", + json={"fields":{"doc_id":did,"title":"test","text":text}}, timeout=60).raise_for_status() + +def vespa_tensors(ep, did): + import requests + r = requests.post(f"{ep}/search/", json={ + "yql": f'select * from doc where doc_id contains "{did}"', + "hits":1, "summary":"tensors"}, timeout=30) + r.raise_for_status() + return r.json()["root"]["children"][0]["fields"] + +def parse_tensor(tj): + if "blocks" in tj: + ks=sorted(tj["blocks"],key=int); return np.array([tj["blocks"][k] for k in ks],dtype=np.float64) + cells=tj["cells"] + mt=max(int(c["address"]["dt"]) for c in cells)+1; mx=max(int(c["address"]["x"]) for c in cells)+1 + a=np.zeros((mt,mx),dtype=np.float64) + for c in cells: a[int(c["address"]["dt"]),int(c["address"]["x"])]=c["value"] + return a + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--endpoint", default=VESPA_ENDPOINT) + p.add_argument("--model-dir", type=Path, default=MODEL_DIR) + p.add_argument("--no-vespa", action="store_true") + args = p.parse_args() + + model_path = args.model_dir / "model_int8.onnx" + tok_path = args.model_dir / "tokenizer.json" + if not model_path.exists(): + sys.exit(f"Model not found: {model_path}") + + ok = True + + # ── pylate FP32 reference ──────────────────────────────────────────── + print("Loading pylate (FP32 PyTorch) ...") + from pylate import models as pm + pylate_model = pm.ColBERT("lightonai/ColBERT-Zero", device="cpu") + pylate_emb = np.array(pylate_model.encode([TEST_DOC], is_query=False, prompt_name="document")[0], dtype=np.float32) + print(f" pylate: {pylate_emb.shape}") + + # ── ONNX INT8 reference (same model file as Vespa) ─────────────────── + print("Loading ONNX INT8 model ...") + import onnxruntime as ort + from tokenizers import Tokenizer + session = ort.InferenceSession(str(model_path), providers=["CPUExecutionProvider"]) + tok_nospc = Tokenizer.from_file(str(tok_path)) + tok_nospc.no_padding(); tok_nospc.no_truncation(); tok_nospc.post_processor = None + skip = build_skip(tok_nospc) + + onnx_emb = onnx_embed(TEST_DOC, tok_nospc, session, skip) + print(f" ONNX INT8: {onnx_emb.shape}") + + # ── Step 1: pylate vs ONNX INT8 (quantisation floor) ───────────────── + print(f"\n{'='*72}") + print("Step 1: pylate FP32 vs ONNX INT8 (quantisation floor)") + print(f"{'='*72}") + ok &= compare("pylate vs ONNX INT8", pylate_emb, onnx_emb, min_cos=0.90) + + # ── Step 2: Pooling (Python reference, various factors) ────────────── + print(f"\n{'='*72}") + print("Step 2: Pooling correctness (Python, on ONNX INT8 embeddings)") + print(f"{'='*72}") + pooled_ref = {} + for pf in POOL_FACTORS: + po = pool(onnx_emb, pf) + n_in, n_out = onnx_emb.shape[0], po.shape[0] + exp = 1 + int(np.ceil((n_in - 1) / pf)) + ok_s = n_out == exp + ok_c = np.allclose(po[0], onnx_emb[0], atol=1e-6) + ok_n = np.allclose(np.linalg.norm(po, axis=1), 1.0, atol=1e-6) + stat = "PASS" if (ok_s and ok_c and ok_n) else "FAIL" + if stat == "FAIL": ok = False + print(f" {stat} factor={pf}: {n_in} -> {n_out} (exp {exp}), CLS={'ok' if ok_c else 'FAIL'}, norms ok={ok_n}") + pooled_ref[pf] = po + bin_ref = binarize(pooled_ref[2]) + + # ── Step 3: Vespa vs ONNX INT8 + Python pooling ───────────────────── + if args.no_vespa: + print(f"\n{'='*72}") + print("Step 3: Vespa comparison SKIPPED (--no-vespa)") + print(f"{'='*72}") + else: + print(f"\n{'='*72}") + print(f"Step 3: Vespa vs ONNX INT8 reference ({args.endpoint})") + print(f"{'='*72}") + try: + import requests + did = "verify_test_0" + vespa_feed(args.endpoint, did, TEST_DOC) + time.sleep(1) + f = vespa_tensors(args.endpoint, did) + + vc = parse_tensor(f["colbert"]) + print(f" Vespa colbert: {vc.shape}") + ok &= compare("non-pooled (ONNX INT8 vs Vespa)", onnx_emb.astype(np.float64), vc, min_cos=0.999) + + vp = parse_tensor(f["colbert_pooled"]) + rp = pooled_ref[2] + print(f" Vespa colbert_pooled: {vp.shape}, ref: {rp.shape}") + if vp.shape == rp.shape: + ok &= compare("pooled factor=2 (Python vs Vespa)", rp.astype(np.float64), vp, min_cos=0.999) + else: + print(f" FAIL shape mismatch"); ok = False + + vb = parse_tensor(f["colbert_pooled_binary"]).astype(np.int8) + m = int(np.sum(vb == bin_ref)); t = vb.size; p = 100*m/t + stat = "PASS" if p == 100 else "FAIL" + if stat == "FAIL": ok = False + print(f" {stat} binary: {m}/{t} bytes match ({p:.1f}%)") + + # Also compare Vespa directly against pylate + print() + ok &= compare("non-pooled (pylate FP32 vs Vespa)", pylate_emb.astype(np.float64), vc, min_cos=0.90) + + requests.delete(f"{args.endpoint}/document/v1/doc/doc/docid/{did}", timeout=10) + except Exception as e: + if "Connection" in str(type(e).__name__) or "Connection" in str(e): + print(f" SKIP: cannot connect to {args.endpoint}") + else: + print(f" ERROR: {e}"); ok = False + + print(f"\n{'='*72}") + print(f"RESULT: {'ALL CHECKS PASSED' if ok else 'SOME CHECKS FAILED'}") + print(f"{'='*72}") + sys.exit(0 if ok else 1) + + +if __name__ == "__main__": + main()