diff --git a/build_support/x86_profiles.py b/build_support/x86_profiles.py index cb0f1904a..e9511769c 100644 --- a/build_support/x86_profiles.py +++ b/build_support/x86_profiles.py @@ -4,7 +4,7 @@ from dataclasses import dataclass from typing import Iterable -DEFAULT_X86_VARIANTS = ("sse3", "avx2", "avx512") +DEFAULT_X86_VARIANTS = ("sse3", "avx2", "avx512", "amx") KNOWN_X86_VARIANTS = frozenset(DEFAULT_X86_VARIANTS) X86_ARCHITECTURES = ("x86_64", "amd64", "x64", "i386", "i686") diff --git a/openviking/retrieve/multi_aspect_retriever.py b/openviking/retrieve/multi_aspect_retriever.py new file mode 100644 index 000000000..fdd871e6f --- /dev/null +++ b/openviking/retrieve/multi_aspect_retriever.py @@ -0,0 +1,332 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: AGPL-3.0 +""" +Multi-aspect retriever for OpenClaw recall. + +Embeds a user query with N different instruction prompts to capture different +semantic perspectives (semantic similarity, entity matching, temporal events, +procedural knowledge, etc.), then batch-searches all N vectors simultaneously +using search_batch() — which dispatches to AMX INT8 tile computation on +supported hardware. + +Architecture:: + + User query: "How does the auth module work?" + │ + ┌────┴──────────────────────────────────────────┐ + │ Multi-prompt Embedding (N aspects) │ + │ │ + │ "Find semantically similar: ..." → v_sem │ + │ "Find entities related to: ..." → v_ent │ + │ "Find procedures about: ..." → v_proc │ + │ "Find events related to: ..." → v_temp │ + └────┬──────────────────────────────────────────┘ + │ N vectors + ▼ + ┌────────────────────────────────────────────────┐ + │ search_batch([v_sem, v_ent, v_proc, v_temp]) │ + │ → AMX INT8 tiles process N queries in 1 pass │ + └────┬──────────────────────────────────────────┘ + │ N result sets + ▼ + ┌────────────────────────────────────────────────┐ + │ Reciprocal Rank Fusion (RRF) │ + │ → Diverse, multi-perspective ranked results │ + └────────────────────────────────────────────────┘ + +Usage:: + + from openviking.retrieve.multi_aspect_retriever import MultiAspectRetriever + + retriever = MultiAspectRetriever(embedder=my_embedder) + results = retriever.retrieve(query_text, engine=idx, topk=10) +""" +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass, field +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple + +from openviking.models.embedder.base import EmbedderBase, EmbedResult + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Default aspect definitions for OpenClaw memory recall +# --------------------------------------------------------------------------- + +@dataclass(frozen=True) +class AspectPrompt: + """An instruction prefix that steers the embedder toward a specific + semantic perspective.""" + name: str + instruction: str + + +#: Built-in aspects tuned for OpenClaw's memory/resource/skill recall. +#: These mirror the facets a user implicitly cares about when querying +#: a personal knowledge base. +DEFAULT_ASPECTS: Tuple[AspectPrompt, ...] = ( + AspectPrompt("semantic", "Find memories semantically similar to: "), + AspectPrompt("entity", "Find memories mentioning entities in: "), + AspectPrompt("temporal", "Find memories about events related to: "), + AspectPrompt("procedural", "Find memories about procedures for: "), +) + + +# --------------------------------------------------------------------------- +# Multi-aspect embedding helper +# --------------------------------------------------------------------------- + +@dataclass +class MultiAspectEmbedResult: + """Result of embedding a single text with N aspect instructions.""" + text: str + aspects: List[AspectPrompt] + vectors: List[List[float]] # one dense vector per aspect + embed_time_us: float = 0.0 # total embedding wall-clock µs + + @property + def n_aspects(self) -> int: + return len(self.vectors) + + +def embed_multi_aspect( + embedder: EmbedderBase, + text: str, + aspects: Sequence[AspectPrompt] = DEFAULT_ASPECTS, +) -> MultiAspectEmbedResult: + """Embed *text* once per aspect instruction. + + Each aspect prepends its instruction to the raw text before calling + ``embedder.embed()``. This is the standard approach for + instruction-following embedding models (E5-instruct, BGE-en-ICL, …). + + Returns a :class:`MultiAspectEmbedResult` containing all N vectors. + """ + vectors: List[List[float]] = [] + t0 = time.perf_counter() + for asp in aspects: + prefixed = asp.instruction + text + result: EmbedResult = embedder.embed(prefixed, is_query=True) + if result.dense_vector is not None: + vectors.append(result.dense_vector) + else: + raise RuntimeError( + f"Embedder returned no dense vector for aspect '{asp.name}'" + ) + elapsed_us = (time.perf_counter() - t0) * 1e6 + return MultiAspectEmbedResult( + text=text, + aspects=list(aspects), + vectors=vectors, + embed_time_us=elapsed_us, + ) + + +# --------------------------------------------------------------------------- +# Reciprocal Rank Fusion (RRF) +# --------------------------------------------------------------------------- + +@dataclass +class FusedResult: + """A single item after RRF fusion across multiple aspect result sets.""" + label: int + rrf_score: float + contributing_aspects: List[str] # which aspects contributed this label + per_aspect_rank: Dict[str, int] # aspect_name → 0-based rank (if present) + + +def reciprocal_rank_fusion( + aspect_names: List[str], + label_lists: List[List[int]], + score_lists: List[List[float]], + topk: int = 10, + k: int = 60, +) -> List[FusedResult]: + """Fuse N ranked result lists using Reciprocal Rank Fusion. + + RRF score for document d: ``sum_over_aspects( 1 / (k + rank_i(d)) )`` + + Args: + aspect_names: name of each aspect (length N) + label_lists: per-aspect label arrays (length N) + score_lists: per-aspect score arrays (length N, unused by RRF but + available for tie-breaking) + topk: how many fused results to return + k: RRF constant (default 60, as in the original paper) + + Returns: + Top-k :class:`FusedResult` sorted by RRF score descending. + """ + rrf_scores: Dict[int, float] = {} + contributors: Dict[int, List[str]] = {} + ranks: Dict[int, Dict[str, int]] = {} + + for asp_name, labels in zip(aspect_names, label_lists): + for rank, label in enumerate(labels): + rrf_scores[label] = rrf_scores.get(label, 0.0) + 1.0 / (k + rank + 1) + contributors.setdefault(label, []).append(asp_name) + ranks.setdefault(label, {})[asp_name] = rank + + sorted_items = sorted(rrf_scores.items(), key=lambda x: x[1], reverse=True) + + results = [] + for label, score in sorted_items[:topk]: + results.append(FusedResult( + label=label, + rrf_score=score, + contributing_aspects=contributors[label], + per_aspect_rank=ranks[label], + )) + return results + + +# --------------------------------------------------------------------------- +# Main retriever class +# --------------------------------------------------------------------------- + +class MultiAspectRetriever: + """Retriever that embeds a query from multiple semantic perspectives + and batch-searches all vectors in one engine call. + + This is designed to plug into the OpenClaw recall path as an alternative + to (or enhancement of) :class:`HierarchicalRetriever`. While the + hierarchical retriever searches with one query vector across a directory + tree, the multi-aspect retriever searches with N vectors across a flat + scope — ideal for brute-force indexes where AMX batch acceleration + provides significant speedup. + + Example:: + + from openviking.retrieve.multi_aspect_retriever import ( + MultiAspectRetriever, DEFAULT_ASPECTS, + ) + import openviking.storage.vectordb.engine as engine + + idx = engine.IndexEngine(config_json) + # ... add data ... + + retriever = MultiAspectRetriever(embedder=my_embedder) + + # Serial mode (N × search) + results = retriever.retrieve( + "How does auth work?", engine=idx, topk=10, mode="serial", + ) + + # Batch mode (1 × search_batch, AMX accelerated) + results = retriever.retrieve( + "How does auth work?", engine=idx, topk=10, mode="batch", + ) + """ + + def __init__( + self, + embedder: EmbedderBase, + aspects: Sequence[AspectPrompt] = DEFAULT_ASPECTS, + rrf_k: int = 60, + ): + self.embedder = embedder + self.aspects = list(aspects) + self.rrf_k = rrf_k + + # -- public API --------------------------------------------------------- + + def retrieve( + self, + query: str, + engine, # engine.IndexEngine + topk: int = 10, + dsl: str = "{}", + mode: str = "batch", + ) -> RetrieveResult: + """Run multi-aspect retrieval. + + Args: + query: raw query text + engine: an ``IndexEngine`` instance with ``search()`` and + ``search_batch()`` methods + topk: per-aspect top-k (RRF will re-rank the union) + dsl: DSL filter string (applied identically to all aspects) + mode: ``"batch"`` (1 × search_batch) or ``"serial"`` + (N × search) + + Returns: + A :class:`RetrieveResult` with fused results and timing. + """ + # Step 1: Multi-aspect embedding + multi = embed_multi_aspect(self.embedder, query, self.aspects) + + # Step 2: Vector search (batch or serial) + import openviking.storage.vectordb.engine as eng + + search_results = [] + t0 = time.perf_counter() + + if mode == "batch": + reqs = [] + for vec in multi.vectors: + sq = eng.SearchRequest() + sq.query = vec + sq.topk = topk + sq.dsl = dsl + reqs.append(sq) + search_results = engine.search_batch(reqs) + else: + for vec in multi.vectors: + sq = eng.SearchRequest() + sq.query = vec + sq.topk = topk + sq.dsl = dsl + search_results.append(engine.search(sq)) + + search_time_us = (time.perf_counter() - t0) * 1e6 + + # Step 3: RRF fusion + t1 = time.perf_counter() + aspect_names = [a.name for a in self.aspects] + label_lists = [r.labels for r in search_results] + score_lists = [r.scores for r in search_results] + fused = reciprocal_rank_fusion( + aspect_names, label_lists, score_lists, + topk=topk, k=self.rrf_k, + ) + fusion_time_us = (time.perf_counter() - t1) * 1e6 + + # Diversity metric: unique labels / (N × topk) + all_labels = set() + for labels in label_lists: + all_labels.update(labels) + total_possible = len(self.aspects) * topk + diversity = len(all_labels) / total_possible if total_possible else 0.0 + + return RetrieveResult( + query=query, + mode=mode, + n_aspects=len(self.aspects), + fused_results=fused, + per_aspect_results=search_results, + embed_time_us=multi.embed_time_us, + search_time_us=search_time_us, + fusion_time_us=fusion_time_us, + diversity=diversity, + ) + + +@dataclass +class RetrieveResult: + """Complete result of a multi-aspect retrieval.""" + query: str + mode: str + n_aspects: int + fused_results: List[FusedResult] + per_aspect_results: list # List[SearchResult] + embed_time_us: float + search_time_us: float + fusion_time_us: float + diversity: float + + @property + def total_time_us(self) -> float: + return self.embed_time_us + self.search_time_us + self.fusion_time_us diff --git a/openviking/server/app.py b/openviking/server/app.py index 5e1a0e287..213b60ccc 100644 --- a/openviking/server/app.py +++ b/openviking/server/app.py @@ -104,9 +104,9 @@ async def lifespan(app: FastAPI): app.state.api_key_manager = None logger.warning( "Dev mode: no root_api_key configured, authentication disabled. " - "This is allowed because the server is bound to localhost (%s). " - "Do NOT expose this server to the network without configuring " - "server.root_api_key in ov.conf.", + "Server is bound to %s. Do NOT expose this server to the network " + "unless unauthenticated ROOT access is intentional or until " + "server.root_api_key is configured in ov.conf.", config.host, ) diff --git a/openviking/storage/vectordb/engine/__init__.py b/openviking/storage/vectordb/engine/__init__.py index 564015652..e1db82e32 100644 --- a/openviking/storage/vectordb/engine/__init__.py +++ b/openviking/storage/vectordb/engine/__init__.py @@ -18,14 +18,16 @@ "x86_sse3": "_x86_sse3", "x86_avx2": "_x86_avx2", "x86_avx512": "_x86_avx512", + "x86_amx": "_x86_amx", "native": "_native", } -_X86_DISPLAY_ORDER = ("x86_sse3", "x86_avx2", "x86_avx512") -_X86_PRIORITY = ("x86_avx512", "x86_avx2", "x86_sse3") +_X86_DISPLAY_ORDER = ("x86_sse3", "x86_avx2", "x86_avx512", "x86_amx") +_X86_PRIORITY = ("x86_amx", "x86_avx512", "x86_avx2", "x86_sse3") _REQUEST_ALIASES = { "sse3": "x86_sse3", "avx2": "x86_avx2", "avx512": "x86_avx512", + "amx": "x86_amx", } _WINDOWS_DLL_DIR_HANDLES = [] diff --git a/openviking/storage/vectordb/engine/_python_api.py b/openviking/storage/vectordb/engine/_python_api.py index 549aec11d..c72b6f86b 100644 --- a/openviking/storage/vectordb/engine/_python_api.py +++ b/openviking/storage/vectordb/engine/_python_api.py @@ -439,6 +439,12 @@ def delete_data(self, data_list: list[DeleteDataRequest]) -> int: def search(self, req: SearchRequest) -> SearchResult: return SearchResult.from_backend(self._backend._index_engine_search(self._handle, req)) + def search_batch(self, reqs: list[SearchRequest]) -> list[SearchResult]: + results = self._backend._index_engine_search_batch( + self._handle, list(reqs) + ) + return [SearchResult.from_backend(r) for r in results] + def dump(self, path: str) -> int: return int(self._backend._index_engine_dump(self._handle, path)) diff --git a/openviking/telemetry/operation.py b/openviking/telemetry/operation.py index 08fdd32f3..497cef03a 100644 --- a/openviking/telemetry/operation.py +++ b/openviking/telemetry/operation.py @@ -142,6 +142,14 @@ def build( llm_output_tokens = cls._i(counters.get("tokens.llm.output"), 0) llm_total_tokens = cls._i(counters.get("tokens.llm.total"), 0) embedding_total_tokens = cls._i(counters.get("tokens.embedding.total"), 0) + embedding_duration_ms = round(float(counters.get("embedding.duration_ms", 0.0) or 0.0), 3) + embedding_requests = cls._i(counters.get("embedding.requests"), 0) + embedding_error_count = cls._i(counters.get("embedding.error_count"), 0) + embedding_wall_duration_ms = gauges.get("embedding.wall_duration_ms") + if embedding_wall_duration_ms is None: + embedding_wall_duration_ms = embedding_duration_ms + else: + embedding_wall_duration_ms = round(float(embedding_wall_duration_ms or 0.0), 3) vector_candidates_scored = cls._i(counters.get("vector.scored"), 0) vectors_scanned = gauges.get("vector.scanned") if vectors_scanned is None: @@ -165,15 +173,43 @@ def build( }, } + if embedding_total_tokens or embedding_duration_ms or embedding_requests or embedding_error_count: + effective_embedding_ms = embedding_wall_duration_ms or embedding_duration_ms + summary["embedding"] = { + "duration_ms": embedding_duration_ms, + "wall_duration_ms": effective_embedding_ms, + "requests": embedding_requests, + "error_count": embedding_error_count, + "avg_duration_ms": round(embedding_duration_ms / embedding_requests, 3) + if embedding_requests + else 0.0, + "share_of_total_pct": round((effective_embedding_ms / duration_ms) * 100.0, 3) + if duration_ms > 0 + else 0.0, + } + if cls._has_metric_prefix("queue", counters, gauges): summary["queue"] = { + "wait_duration_ms": round(float(gauges.get("queue.wait.duration_ms") or 0.0), 3), "semantic": { "processed": cls._i(gauges.get("queue.semantic.processed"), 0), "error_count": cls._i(gauges.get("queue.semantic.error_count"), 0), + "duration_ms": round( + float(gauges.get("queue.semantic.duration_ms") or 0.0), 3 + ), + "wall_duration_ms": round( + float(gauges.get("queue.semantic.wall_duration_ms") or 0.0), 3 + ), }, "embedding": { "processed": cls._i(gauges.get("queue.embedding.processed"), 0), "error_count": cls._i(gauges.get("queue.embedding.error_count"), 0), + "duration_ms": round( + float(gauges.get("queue.embedding.duration_ms") or 0.0), 3 + ), + "wall_duration_ms": round( + float(gauges.get("queue.embedding.wall_duration_ms") or 0.0), 3 + ), }, } diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1a79b8a43..df8a5d741 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -5,7 +5,7 @@ project(openviking_cpp) include(CheckCXXCompilerFlag) include(CMakeParseArguments) -set(OV_X86_BUILD_VARIANTS "sse3;avx2;avx512" CACHE STRING "x86 engine variants to build") +set(OV_X86_BUILD_VARIANTS "sse3;avx2;avx512;amx" CACHE STRING "x86 engine variants to build") set(OV_PY_OUTPUT_DIR "" CACHE PATH "Output directory for Python extension modules") set(OV_PY_EXT_SUFFIX ".so" CACHE STRING "Python extension suffix, including ABI tag if needed") set(OV_PYTHON_SABI_LIBRARY "" CACHE FILEPATH "Stable-ABI Python import library or DLL for Windows abi3 modules") @@ -178,6 +178,17 @@ function(ov_get_x86_variant_flags variant out_flags out_defs out_supported) list(APPEND OV_FLAGS ${FLAG}) endif() endforeach() + elseif(OV_VARIANT STREQUAL "amx") + foreach(FLAG -mavx512f -mavx512bw -mavx512dq -mavx512vl -mavx512vnni -mamx-tile -mamx-int8) + string(REPLACE "-" "_" FLAG_VAR_SUFFIX "${FLAG}") + set(FLAG_VAR "HAVE_${FLAG_VAR_SUFFIX}") + check_cxx_compiler_flag("${FLAG}" ${FLAG_VAR}) + if(NOT ${FLAG_VAR}) + set(OV_SUPPORTED FALSE) + else() + list(APPEND OV_FLAGS ${FLAG}) + endif() + endforeach() else() set(OV_SUPPORTED FALSE) endif() diff --git a/src/abi3_engine_backend.cpp b/src/abi3_engine_backend.cpp index 68bd0f049..c6a6788fe 100644 --- a/src/abi3_engine_backend.cpp +++ b/src/abi3_engine_backend.cpp @@ -1222,6 +1222,57 @@ PyObject* py_index_engine_search(PyObject*, PyObject* args) { } } +PyObject* py_index_engine_search_batch(PyObject*, PyObject* args) { + PyObject* capsule = nullptr; + PyObject* request_list = nullptr; + if (!PyArg_ParseTuple(args, "OO", &capsule, &request_list)) { + return nullptr; + } + + auto* engine = capsule_to_ptr(capsule, kIndexCapsuleName); + if (engine == nullptr) { + return nullptr; + } + + if (!PyList_Check(request_list)) { + raise_runtime_error("search_batch requires a list of SearchRequest objects"); + return nullptr; + } + + Py_ssize_t nq = PyList_Size(request_list); + std::vector requests(static_cast(nq)); + for (Py_ssize_t i = 0; i < nq; ++i) { + PyObject* req_obj = PyList_GetItem(request_list, i); + if (!parse_search_request(req_obj, &requests[static_cast(i)])) { + return nullptr; + } + } + + try { + auto results = call_without_gil([&]() { + return engine->search_batch(requests); + }); + + PyObject* result_list = PyList_New(static_cast(results.size())); + if (result_list == nullptr) { + return nullptr; + } + + for (size_t i = 0; i < results.size(); ++i) { + PyObject* result = build_search_result(results[i]); + if (result == nullptr) { + Py_DECREF(result_list); + return nullptr; + } + PyList_SetItem(result_list, static_cast(i), result); + } + return result_list; + } catch (const std::exception& exc) { + raise_runtime_error(exc.what()); + return nullptr; + } +} + PyObject* py_index_engine_dump(PyObject*, PyObject* args) { PyObject* capsule = nullptr; const char* path = nullptr; @@ -1498,6 +1549,8 @@ PyMethodDef kModuleMethods[] = { "Delete data from the index engine."}, {"_index_engine_search", py_index_engine_search, METH_VARARGS, "Search the index engine."}, + {"_index_engine_search_batch", py_index_engine_search_batch, METH_VARARGS, + "Batch search the index engine with multiple queries."}, {"_index_engine_dump", py_index_engine_dump, METH_VARARGS, "Dump index state to disk."}, {"_index_engine_get_state", py_index_engine_get_state, METH_VARARGS, "Read index engine state."}, diff --git a/src/abi3_x86_caps.cpp b/src/abi3_x86_caps.cpp index f6aff1a3d..e03dc1e96 100644 --- a/src/abi3_x86_caps.cpp +++ b/src/abi3_x86_caps.cpp @@ -26,6 +26,9 @@ struct CpuFeatures { bool avx512dq = false; bool avx512bw = false; bool avx512vl = false; + bool avx512_vnni = false; + bool amx_tile = false; + bool amx_int8 = false; }; #if defined(__x86_64__) || defined(_M_X64) || defined(__i386__) || defined(_M_IX86) @@ -78,12 +81,27 @@ CpuFeatures detect_cpu_features() { features.avx512bw = (regs[1] & (1 << 30)) != 0; features.avx512vl = (regs[1] & (1u << 31)) != 0; + // AVX512-VNNI: CPUID leaf 7, subleaf 0, ECX bit 11 + features.avx512_vnni = (regs[2] & (1 << 11)) != 0; + + // AMX: CPUID leaf 7, subleaf 0, EDX bit 24 (AMX-TILE), bit 25 (AMX-INT8) + features.amx_tile = (regs[3] & (1 << 24)) != 0; + features.amx_int8 = (regs[3] & (1 << 25)) != 0; + const bool avx512_os = (xcr0 & 0xe6) == 0xe6; if (!avx512_os) { features.avx512f = false; features.avx512dq = false; features.avx512bw = false; features.avx512vl = false; + features.avx512_vnni = false; + } + + // AMX requires XCR0 bits 17 (TILECFG) and 18 (TILEDATA) + const bool amx_os = (xcr0 & 0x60000) == 0x60000; + if (!amx_os) { + features.amx_tile = false; + features.amx_int8 = false; } return features; @@ -104,6 +122,11 @@ std::vector get_supported_variants_impl() { if (features.avx && features.avx512f && features.avx512dq && features.avx512bw && features.avx512vl) { variants.emplace_back("x86_avx512"); + if (features.avx512_vnni) { + if (features.amx_tile && features.amx_int8) { + variants.emplace_back("x86_amx"); + } + } } return variants; } diff --git a/src/index/detail/index_manager_impl.cpp b/src/index/detail/index_manager_impl.cpp index 4d165d10a..4ddcb47eb 100644 --- a/src/index/detail/index_manager_impl.cpp +++ b/src/index/detail/index_manager_impl.cpp @@ -279,6 +279,61 @@ int IndexManagerImpl::perform_vector_recall(const SearchRequest& req, return 0; } +int IndexManagerImpl::search_batch(const std::vector& reqs, + std::vector& results) { + if (reqs.empty()) return 0; + + // Use first request's DSL for all queries (batch assumes shared filter) + const auto& dsl = reqs[0].dsl; + SearchContext ctx; + if (!dsl.empty()) { + if (int ret = parse_dsl_query(dsl, ctx); ret != 0) { + SPDLOG_ERROR("IndexManagerImpl::search_batch DSL parse fail: {}", dsl); + return ret; + } + } + + // Sorter queries are not supported in batch mode + if (ctx.sorter_op) { + return IndexManager::search_batch(reqs, results); + } + + std::shared_lock lock(rw_mutex_); + + BitmapPtr bitmap = nullptr; + if (ctx.filter_op) { + bitmap = calculate_filter_bitmap(ctx, dsl); + if (!bitmap) { + SPDLOG_DEBUG("search_batch: calculate_filter_bitmap returned null"); + return -1; + } + } + + // Extract query vectors + std::vector query_ptrs(reqs.size()); + for (size_t i = 0; i < reqs.size(); ++i) { + query_ptrs[i] = reqs[i].query.data(); + } + uint32_t topk = reqs[0].topk; + + std::vector recall_results; + int ret = vector_index_->recall_batch(query_ptrs, topk, bitmap.get(), + recall_results); + if (ret != 0) { + SPDLOG_ERROR("search_batch: recall_batch failed, ret={}", ret); + return ret; + } + + results.resize(reqs.size()); + for (size_t i = 0; i < reqs.size(); ++i) { + std::swap(results[i].labels, recall_results[i].labels); + std::swap(results[i].scores, recall_results[i].scores); + results[i].result_num = results[i].labels.size(); + } + + return 0; +} + int IndexManagerImpl::add_data(const std::vector& data_list) { auto start = std::chrono::high_resolution_clock::now(); std::vector parsed_fields_list(data_list.size()); diff --git a/src/index/detail/index_manager_impl.h b/src/index/detail/index_manager_impl.h index d7b64019e..6c40ca309 100644 --- a/src/index/detail/index_manager_impl.h +++ b/src/index/detail/index_manager_impl.h @@ -29,6 +29,9 @@ class IndexManagerImpl : public IndexManager { int search(const SearchRequest& req, SearchResult& result) override; + int search_batch(const std::vector& reqs, + std::vector& results) override; + int add_data(const std::vector& data_list) override; int delete_data(const std::vector& data_list) override; diff --git a/src/index/detail/vector/common/bruteforce.h b/src/index/detail/vector/common/bruteforce.h index cb3380691..d164b8dd8 100644 --- a/src/index/detail/vector/common/bruteforce.h +++ b/src/index/detail/vector/common/bruteforce.h @@ -205,20 +205,80 @@ class BruteforceSearch { void* dist_params = space_->get_metric_params(); if (!filter_bitmap) { - for (size_t i = 0; i < current_count_; ++i) { - char* ptr = data_buffer_ + (i * element_byte_size_); - - float dist = compute_score(encoded_query.data(), ptr, query_sparse_view, - i, dist_func, dist_params); - - uint64_t label; - std::memcpy(&label, ptr + vector_byte_size_, sizeof(uint64_t)); - - if (pq.size() < k) { - pq.emplace(dist, label); - } else if (dist > pq.top().first) { - pq.pop(); - pq.emplace(dist, label); +#if defined(OV_SIMD_AMX) + // AMX batch path: process int8 vectors in blocks of 16 using tiles + if (meta_->quantization_type == "int8") { + const int8_t* query_int8 = + reinterpret_cast(encoded_query.data()); + const float* query_meta_ptr = + reinterpret_cast( + encoded_query.data() + meta_->dimension); + float query_scale = query_meta_ptr[0]; + bool is_l2 = reverse_query_score_; + float query_norm_sq = is_l2 ? query_meta_ptr[1] : 0.0f; + + constexpr size_t kAmxBlock = 16; + int32_t ip_results[kAmxBlock]; + + for (size_t blk = 0; blk < current_count_; blk += kAmxBlock) { + size_t n = std::min(kAmxBlock, current_count_ - blk); + const char* blk_base = + data_buffer_ + blk * element_byte_size_; + + batch_inner_product_int8_amx( + blk_base, element_byte_size_, + query_int8, ip_results, n, meta_->dimension); + + for (size_t j = 0; j < n; ++j) { + size_t idx = blk + j; + char* ptr = data_buffer_ + (idx * element_byte_size_); + const float* db_meta_ptr = + reinterpret_cast(ptr + meta_->dimension); + float db_scale = db_meta_ptr[0]; + + float real_ip = + static_cast(ip_results[j]) * query_scale * db_scale; + float raw_dist; + if (is_l2) { + float db_norm_sq = db_meta_ptr[1]; + raw_dist = std::max(0.0f, + query_norm_sq + db_norm_sq - 2.0f * real_ip); + } else { + raw_dist = real_ip; + } + + float score = + finalize_score(raw_dist, query_sparse_view, idx); + + uint64_t label; + std::memcpy(&label, ptr + vector_byte_size_, sizeof(uint64_t)); + + if (pq.size() < k) { + pq.emplace(score, label); + } else if (score > pq.top().first) { + pq.pop(); + pq.emplace(score, label); + } + } + } + } else +#endif + { + for (size_t i = 0; i < current_count_; ++i) { + char* ptr = data_buffer_ + (i * element_byte_size_); + + float dist = compute_score(encoded_query.data(), ptr, query_sparse_view, + i, dist_func, dist_params); + + uint64_t label; + std::memcpy(&label, ptr + vector_byte_size_, sizeof(uint64_t)); + + if (pq.size() < k) { + pq.emplace(dist, label); + } else if (dist > pq.top().first) { + pq.pop(); + pq.emplace(dist, label); + } } } } else { @@ -265,6 +325,132 @@ class BruteforceSearch { } } + // Batch search: compute top-k for nq queries simultaneously. + // Uses AMX multi-query tiles to process up to 16 queries per tile operation. + void search_knn_batch( + const float* const* query_data_array, + size_t nq, + size_t k, + const Bitmap* filter_bitmap, + std::vector>& labels, + std::vector>& scores) const { + + labels.resize(nq); + scores.resize(nq); + + if (nq == 0 || k == 0 || current_count_ == 0) return; + +#if defined(OV_SIMD_AMX) + // AMX multi-query batch path: int8 quantization, no filter + if (meta_->quantization_type == "int8" && !filter_bitmap) { + // Encode all queries + std::vector> encoded_queries(nq); + std::vector query_int8_ptrs(nq); + std::vector query_scales(nq); + std::vector query_norm_sqs(nq); + bool is_l2 = reverse_query_score_; + + for (size_t q = 0; q < nq; ++q) { + encoded_queries[q].resize(vector_byte_size_); + quantizer_->encode(query_data_array[q], meta_->dimension, + encoded_queries[q].data()); + query_int8_ptrs[q] = + reinterpret_cast(encoded_queries[q].data()); + const float* meta_ptr = reinterpret_cast( + encoded_queries[q].data() + meta_->dimension); + query_scales[q] = meta_ptr[0]; + query_norm_sqs[q] = is_l2 ? meta_ptr[1] : 0.0f; + } + + using ResultPair = std::pair; + std::vector, + std::greater>> + pqs(nq); + + constexpr size_t kAmxBlock = 16; + constexpr size_t kMaxTileN = 16; + + // Process queries in groups of ≤16 (tile N dimension limit) + for (size_t q_start = 0; q_start < nq; q_start += kMaxTileN) { + size_t q_count = std::min(kMaxTileN, nq - q_start); + + const int8_t* batch_q[kMaxTileN]; + for (size_t q = 0; q < q_count; ++q) { + batch_q[q] = query_int8_ptrs[q_start + q]; + } + + int32_t ip_results[kAmxBlock * kMaxTileN]; + + for (size_t blk = 0; blk < current_count_; blk += kAmxBlock) { + size_t n = std::min(kAmxBlock, current_count_ - blk); + const char* blk_base = data_buffer_ + blk * element_byte_size_; + + batch_inner_product_int8_amx_multi_query( + blk_base, element_byte_size_, batch_q, ip_results, + n, q_count, meta_->dimension); + + for (size_t j = 0; j < n; ++j) { + size_t idx = blk + j; + char* ptr = data_buffer_ + (idx * element_byte_size_); + const float* db_meta_ptr = + reinterpret_cast(ptr + meta_->dimension); + float db_scale = db_meta_ptr[0]; + float db_norm_sq = is_l2 ? db_meta_ptr[1] : 0.0f; + + uint64_t label; + std::memcpy(&label, ptr + vector_byte_size_, sizeof(uint64_t)); + + for (size_t q = 0; q < q_count; ++q) { + size_t qi = q_start + q; + float real_ip = + static_cast(ip_results[j * q_count + q]) * + query_scales[qi] * db_scale; + float raw_dist; + if (is_l2) { + raw_dist = std::max( + 0.0f, query_norm_sqs[qi] + db_norm_sq - 2.0f * real_ip); + } else { + raw_dist = real_ip; + } + + float score = + reverse_query_score_ ? (1.0f - raw_dist) : raw_dist; + + auto& pq = pqs[qi]; + if (pq.size() < k) { + pq.emplace(score, label); + } else if (score > pq.top().first) { + pq.pop(); + pq.emplace(score, label); + } + } + } + } + } + + // Extract results + for (size_t q = 0; q < nq; ++q) { + auto& pq = pqs[q]; + size_t result_size = pq.size(); + labels[q].resize(result_size); + scores[q].resize(result_size); + for (int i = static_cast(result_size) - 1; i >= 0; --i) { + scores[q][i] = pq.top().first; + labels[q][i] = pq.top().second; + pq.pop(); + } + } + return; + } +#endif + + // Fallback: single-query loop + for (size_t q = 0; q < nq; ++q) { + search_knn(query_data_array[q], k, filter_bitmap, nullptr, + labels[q], scores[q]); + } + } + void save(const std::filesystem::path& dir) { if (meta_) { meta_->element_count = current_count_; @@ -416,6 +602,15 @@ class BruteforceSearch { size_t idx, MetricFunc dist_func, void* dist_params) const { float dense_raw = dist_func(encoded_query, data_ptr, dist_params); + return finalize_score(dense_raw, query_sparse_view, idx); + } + + // Convert a raw distance value (from metric function or AMX batch kernel) + // to a final score, including sparse index blending. + float finalize_score( + float dense_raw, + const std::shared_ptr& query_sparse_view, + size_t idx) const { float dense_score = reverse_query_score_ ? (1.0f - dense_raw) : dense_raw; if (!sparse_index_ || !query_sparse_view || diff --git a/src/index/detail/vector/common/space_int8.h b/src/index/detail/vector/common/space_int8.h index f51d65fde..363493af0 100644 --- a/src/index/detail/vector/common/space_int8.h +++ b/src/index/detail/vector/common/space_int8.h @@ -8,12 +8,34 @@ #include #include -#if defined(OV_SIMD_AVX) +#if defined(OV_SIMD_AVX) || defined(OV_SIMD_AMX) #include #endif +#if defined(OV_SIMD_AMX) +#include +#if defined(__linux__) +#include +#include +#endif +#endif namespace vectordb { +#if defined(OV_SIMD_AMX) +// Linux requires explicit permission for AMX tile data via arch_prctl. +// This must be called once per thread before any AMX tile instruction. +inline void ensure_amx_permission() { + static bool requested = false; + if (!requested) { +#if defined(__linux__) + // ARCH_REQ_XCOMP_PERM = 0x1023, XFEATURE_XTILEDATA = 18 + syscall(SYS_arch_prctl, 0x1023, 18); +#endif + requested = true; + } +} +#endif + static int32_t inner_product_int8_scalar(const void* v1, const void* v2, const void* params) { const int8_t* pv1 = static_cast(v1); @@ -92,6 +114,166 @@ static int32_t inner_product_int8_avx(const void* v1, const void* v2, } #endif +#if defined(OV_SIMD_AMX) +// AMX tile configuration structure (64 bytes, must be 64-byte aligned) +struct OV_ALIGN_64 AmxTileCfg { + uint8_t palette_id; + uint8_t start_row; + uint8_t reserved_0[14]; + uint16_t colsb[16]; + uint8_t rows[16]; +}; + +// Batch inner product using AMX TDPBSSD (signed int8 × signed int8 → int32). +// Computes dot(db[i], query) for i=0..num_vecs-1 simultaneously using tiles. +// db_base: pointer to int8 data of the first vector in the block +// stride: bytes between consecutive vectors (element_byte_size_) +// query: query int8 data (contiguous) +// results: output int32 dot products [num_vecs] +// num_vecs: vectors in this block (1-16) +// dim: vector dimension +static void batch_inner_product_int8_amx( + const char* db_base, size_t stride, + const int8_t* query, + int32_t* results, + size_t num_vecs, + size_t dim) { + + ensure_amx_permission(); + + const size_t dim64 = (dim / 64) * 64; + + if (dim64 > 0) { + // Configure tiles: + // Tile 0 (C/dst): num_vecs rows × 4 bytes (1 int32 per row) + // Tile 1 (A/src1): num_vecs rows × 64 bytes (64 int8s per row) + // Tile 2 (B/src2): 16 rows × 4 bytes (K/4 × N*4) + AmxTileCfg cfg = {}; + cfg.palette_id = 1; + cfg.rows[0] = static_cast(num_vecs); + cfg.colsb[0] = 4; + cfg.rows[1] = static_cast(num_vecs); + cfg.colsb[1] = 64; + cfg.rows[2] = 16; + cfg.colsb[2] = 4; + + _tile_loadconfig(&cfg); + _tile_zero(0); + + for (size_t k = 0; k < dim64; k += 64) { + // A: db vectors block, each row is a 64-byte chunk of one vector + _tile_loadd(1, db_base + k, stride); + // B: query chunk, 16 rows × 4 bytes (natural byte layout, stride=4) + _tile_loadd(2, query + k, 4); + // C += A × B (signed int8 × signed int8 → int32) + _tile_dpbssd(0, 1, 2); + } + + // Store tile C to buffer + int32_t OV_ALIGN_64 c_buf[16] = {}; + _tile_stored(0, c_buf, 4); + _tile_release(); + + for (size_t i = 0; i < num_vecs; ++i) { + results[i] = c_buf[i]; + } + } else { + std::memset(results, 0, num_vecs * sizeof(int32_t)); + } + + // Handle remaining dimensions (dim64..dim) with scalar + for (size_t k = dim64; k < dim; ++k) { + for (size_t i = 0; i < num_vecs; ++i) { + const int8_t* db_vec = + reinterpret_cast(db_base + i * stride); + results[i] += + static_cast(db_vec[k]) * static_cast(query[k]); + } + } +} + +// Multi-query AMX batch inner product. +// Computes dot(db[i], queries[q]) for i=0..num_vecs-1, q=0..nq-1 +// simultaneously using one TDPBSSD per 64-dim chunk. +// results layout (row-major): results[i * nq + q] = dot(db[i], queries[q]) +// num_vecs: DB vectors in this block (1-16) +// nq: number of queries (1-16) +static void batch_inner_product_int8_amx_multi_query( + const char* db_base, size_t stride, + const int8_t* const* queries, + int32_t* results, + size_t num_vecs, + size_t nq, + size_t dim) { + + ensure_amx_permission(); + + const size_t dim64 = (dim / 64) * 64; + const size_t b_stride = nq * 4; // bytes per row of B tile + const size_t c_stride = nq * 4; // bytes per row of C tile + + if (dim64 > 0) { + AmxTileCfg cfg = {}; + cfg.palette_id = 1; + // Tile 0 (C): num_vecs rows × nq int32 cols + cfg.rows[0] = static_cast(num_vecs); + cfg.colsb[0] = static_cast(c_stride); + // Tile 1 (A): num_vecs rows × 64 bytes + cfg.rows[1] = static_cast(num_vecs); + cfg.colsb[1] = 64; + // Tile 2 (B): 16 rows × nq*4 bytes (VNNI interleaved) + cfg.rows[2] = 16; + cfg.colsb[2] = static_cast(b_stride); + + // B tile buffer: 16 rows × max 16 queries × 4 bytes = 1024 bytes max + int8_t OV_ALIGN_64 b_buf[16 * 16 * 4]; + int32_t OV_ALIGN_64 c_buf[16 * 16] = {}; + + _tile_loadconfig(&cfg); + _tile_zero(0); + + for (size_t k = 0; k < dim64; k += 64) { + // Prepare B tile: interleave query data into VNNI layout + // b_buf[row * b_stride + q * 4 + j] = queries[q][k + row * 4 + j] + for (size_t row = 0; row < 16; ++row) { + for (size_t q = 0; q < nq; ++q) { + std::memcpy(b_buf + row * b_stride + q * 4, + queries[q] + k + row * 4, 4); + } + } + + _tile_loadd(1, db_base + k, stride); + _tile_loadd(2, b_buf, b_stride); + _tile_dpbssd(0, 1, 2); + } + + _tile_stored(0, c_buf, c_stride); + _tile_release(); + + for (size_t i = 0; i < num_vecs; ++i) { + for (size_t q = 0; q < nq; ++q) { + results[i * nq + q] = c_buf[i * nq + q]; + } + } + } else { + std::memset(results, 0, num_vecs * nq * sizeof(int32_t)); + } + + // Handle remaining dimensions with scalar + for (size_t k = dim64; k < dim; ++k) { + for (size_t i = 0; i < num_vecs; ++i) { + const int8_t* db_vec = + reinterpret_cast(db_base + i * stride); + for (size_t q = 0; q < nq; ++q) { + results[i * nq + q] += + static_cast(db_vec[k]) * + static_cast(queries[q][k]); + } + } + } +} +#endif + // Distance functions static float inner_product_distance_int8(const void* v1, const void* v2, const void* params) { @@ -108,7 +290,12 @@ static float inner_product_distance_int8(const void* v1, const void* v2, float scale2 = *scale2_ptr; int32_t ip; -#if defined(OV_SIMD_AVX) +#if defined(OV_SIMD_AMX) + batch_inner_product_int8_amx( + static_cast(v1), dim, + static_cast(v2), + &ip, 1, dim); +#elif defined(OV_SIMD_AVX) if (dim >= 32) { ip = inner_product_int8_avx(v1, v2, params); } else { @@ -140,7 +327,12 @@ static float l2_distance_int8(const void* v1, const void* v2, float norm_sq2 = meta2[1]; int32_t ip; -#if defined(OV_SIMD_AVX) +#if defined(OV_SIMD_AMX) + batch_inner_product_int8_amx( + static_cast(v1), dim, + static_cast(v2), + &ip, 1, dim); +#elif defined(OV_SIMD_AVX) if (dim >= 32) { ip = inner_product_int8_avx(v1, v2, params); } else { diff --git a/src/index/detail/vector/common/vector_base.h b/src/index/detail/vector/common/vector_base.h index 8bac1319a..cc0146c5c 100644 --- a/src/index/detail/vector/common/vector_base.h +++ b/src/index/detail/vector/common/vector_base.h @@ -14,6 +14,9 @@ #if defined(__AVX512F__) && !defined(OV_DISABLE_AVX512) #define OV_SIMD_AVX512 #endif +#if defined(__AMX_INT8__) && defined(__AMX_TILE__) +#define OV_SIMD_AMX +#endif #if defined(__AVX__) #define OV_SIMD_AVX #endif diff --git a/src/index/detail/vector/vector_index_adapter.h b/src/index/detail/vector/vector_index_adapter.h index accf73540..301889f09 100644 --- a/src/index/detail/vector/vector_index_adapter.h +++ b/src/index/detail/vector/vector_index_adapter.h @@ -41,6 +41,24 @@ class VectorIndexAdapter { virtual uint64_t get_label_by_offset(const int& offset) { return 0; } + + // Batch recall: all queries share the same filter. + // Default implementation loops over single-query recall. + virtual int recall_batch( + const std::vector& dense_vectors, + uint32_t topk, + const Bitmap* filter_bitmap, + std::vector& results) { + results.resize(dense_vectors.size()); + for (size_t i = 0; i < dense_vectors.size(); ++i) { + VectorRecallRequest req; + req.dense_vector = dense_vectors[i]; + req.topk = topk; + req.bitmap = filter_bitmap; + recall(req, results[i]); + } + return 0; + } }; class BruteForceIndex : public VectorIndexAdapter { @@ -104,6 +122,24 @@ class BruteForceIndex : public VectorIndexAdapter { return index_->get_label_by_offset(offset); } + int recall_batch( + const std::vector& dense_vectors, + uint32_t topk, + const Bitmap* filter_bitmap, + std::vector& results) override { + size_t nq = dense_vectors.size(); + std::vector> batch_labels; + std::vector> batch_scores; + index_->search_knn_batch(dense_vectors.data(), nq, topk, + filter_bitmap, batch_labels, batch_scores); + results.resize(nq); + for (size_t i = 0; i < nq; ++i) { + std::swap(results[i].labels, batch_labels[i]); + std::swap(results[i].scores, batch_scores[i]); + } + return 0; + } + private: std::shared_ptr meta_; std::shared_ptr index_; diff --git a/src/index/index_engine.cpp b/src/index/index_engine.cpp index 7a5fd8c2c..c1d360b2f 100644 --- a/src/index/index_engine.cpp +++ b/src/index/index_engine.cpp @@ -17,6 +17,13 @@ SearchResult IndexEngine::search(const SearchRequest& req) { return result; } +std::vector IndexEngine::search_batch( + const std::vector& reqs) { + std::vector results; + impl_->search_batch(reqs, results); + return results; +} + int IndexEngine::add_data(const std::vector& data_list) { return impl_->add_data(data_list); } diff --git a/src/index/index_engine.h b/src/index/index_engine.h index 55d130db0..9aaea216c 100644 --- a/src/index/index_engine.h +++ b/src/index/index_engine.h @@ -24,6 +24,8 @@ class IndexEngine { SearchResult search(const SearchRequest& req); + std::vector search_batch(const std::vector& reqs); + int64_t dump(const std::string& dir); StateResult get_state(); diff --git a/src/index/index_manager.h b/src/index/index_manager.h index 3da13ad99..2e49663fc 100644 --- a/src/index/index_manager.h +++ b/src/index/index_manager.h @@ -14,6 +14,18 @@ class IndexManager { virtual int search(const SearchRequest& req, SearchResult& result) = 0; + // Batch search: process multiple queries sharing the same filter. + // Default implementation loops over single-query search. + virtual int search_batch(const std::vector& reqs, + std::vector& results) { + results.resize(reqs.size()); + for (size_t i = 0; i < reqs.size(); ++i) { + int ret = search(reqs[i], results[i]); + if (ret != 0) return ret; + } + return 0; + } + virtual int add_data(const std::vector& data_list) = 0; virtual int delete_data(const std::vector& data_list) = 0; diff --git a/tests/unit/test_embedding_config_dimension.py b/tests/unit/test_embedding_config_dimension.py new file mode 100644 index 000000000..8535c5298 --- /dev/null +++ b/tests/unit/test_embedding_config_dimension.py @@ -0,0 +1,45 @@ +from unittest.mock import MagicMock + +from openviking_cli.utils.config.embedding_config import EmbeddingConfig, EmbeddingModelConfig + + +def test_embedding_config_dimension_prefers_explicit_dimension(monkeypatch): + get_embedder = MagicMock() + monkeypatch.setattr(EmbeddingConfig, "get_embedder", get_embedder) + + config = EmbeddingConfig( + dense=EmbeddingModelConfig( + provider="openai", + model="text-embedding-3-small", + api_key="test-key", + dimension=768, + ) + ) + + assert config.get_dimension() == 768 + assert config.dimension == 768 + get_embedder.assert_not_called() + + +def test_embedding_config_dimension_detects_embedder_dimension_without_mutating_config(monkeypatch): + fake_embedder = MagicMock() + fake_embedder.get_dimension.return_value = 1024 + get_embedder = MagicMock(return_value=fake_embedder) + monkeypatch.setattr(EmbeddingConfig, "get_embedder", get_embedder) + + config = EmbeddingConfig( + dense=EmbeddingModelConfig( + provider="openai", + model="Qwen/Qwen3-Embedding-0.6B", + api_key="test-key", + ) + ) + + assert config.dense is not None + assert config.dense.dimension is None + assert config.get_dimension() == 1024 + assert config.dimension == 1024 + assert config.dense.dimension is None + + get_embedder.assert_called_once_with() + fake_embedder.get_dimension.assert_called_once_with() \ No newline at end of file