diff --git a/ann_benchmarks/algorithms/zeelin/Dockerfile b/ann_benchmarks/algorithms/zeelin/Dockerfile new file mode 100644 index 000000000..23ab52527 --- /dev/null +++ b/ann_benchmarks/algorithms/zeelin/Dockerfile @@ -0,0 +1,4 @@ +FROM ann-benchmarks + +RUN pip install scann +RUN python -c 'import scann' diff --git a/ann_benchmarks/algorithms/zeelin/config.yml b/ann_benchmarks/algorithms/zeelin/config.yml new file mode 100644 index 000000000..d68aba92b --- /dev/null +++ b/ann_benchmarks/algorithms/zeelin/config.yml @@ -0,0 +1,42 @@ +float: + angular: + - base_args: ['@metric'] + constructor: Zeelin + disabled: false + docker_tag: ann-benchmarks-zeelin + module: ann_benchmarks.algorithms.zeelin + name: zeelin-scann + run_groups: + scann-2500: + arg_groups: + - num_leaves: [2500] + anisotropic_quantization_threshold: [0.2] + dims_per_block: [2] + query_args: + - [[60, 80], [80, 100], [100, 120], [120, 140], [128, 145], + [130, 146], [130, 150], [132, 150], [150, 180], [200, 300], + [300, 500], [500, 700], [800, 1000]] + scann-2600: + arg_groups: + - num_leaves: [2600] + anisotropic_quantization_threshold: [0.2] + dims_per_block: [2] + query_args: + - [[118, 138], [124, 142], [128, 145]] + scann-2000: + arg_groups: + - num_leaves: [2000] + anisotropic_quantization_threshold: [0.2] + dims_per_block: [2] + query_args: + - [[90, 110], [108, 125], [110, 120], [115, 130], [130, 150], + [170, 200], [250, 500], [400, 500], [800, 1000]] + scann-soar-high-recall: + arg_groups: + - num_leaves: [2500] + anisotropic_quantization_threshold: [0.2] + dims_per_block: [2] + soar_lambda: [1.5] + overretrieve_factor: [2.0] + query_args: + - [[180, 220], [220, 300], [260, 400], [320, 500], [400, 600]] diff --git a/ann_benchmarks/algorithms/zeelin/install.sh b/ann_benchmarks/algorithms/zeelin/install.sh new file mode 100755 index 000000000..653b97b7b --- /dev/null +++ b/ann_benchmarks/algorithms/zeelin/install.sh @@ -0,0 +1,2 @@ +#!/bin/bash +pip install scann diff --git a/ann_benchmarks/algorithms/zeelin/module.py b/ann_benchmarks/algorithms/zeelin/module.py new file mode 100644 index 000000000..725c85043 --- /dev/null +++ b/ann_benchmarks/algorithms/zeelin/module.py @@ -0,0 +1,72 @@ +import time + +import numpy as np +import scann +from sklearn import preprocessing + +from ..base.module import BaseANN + + +class Zeelin(BaseANN): + def __init__(self, metric, index_param): + if metric != "angular": + raise ValueError("zeelin-scann is tuned for angular distance") + self.index_param = index_param + self.name = "zeelin-scann" + + def fit(self, X): + start = time.time() + X = np.asarray(X, dtype=np.float32) + X[np.linalg.norm(X, axis=1) == 0] = 1.0 / np.sqrt(X.shape[1]) + X = preprocessing.normalize(X, norm="l2", axis=1) + + self.searcher = ( + scann.scann_ops_pybind.builder(X, 10, "dot_product") + .tree( + self.index_param["num_leaves"], + 1, + training_sample_size=len(X), + spherical=True, + quantize_centroids=True, + soar_lambda=self.index_param.get("soar_lambda"), + overretrieve_factor=self.index_param.get("overretrieve_factor"), + ) + .score_ah( + self.index_param["dims_per_block"], + anisotropic_quantization_threshold=self.index_param["anisotropic_quantization_threshold"], + ) + .reorder(1) + .build() + ) + self.build_time = time.time() - start + + def set_query_arguments(self, leaves_to_search, reorder_k=None): + if reorder_k is None: + leaves_to_search, reorder_k = leaves_to_search + self.leaves_to_search = leaves_to_search + self.reorder_k = reorder_k + self.name = "zeelin-scann (%s, leaves_to_search=%s, reorder_k=%s)" % ( + self.index_param, + leaves_to_search, + reorder_k, + ) + + def query(self, v, n): + v = np.asarray(v, dtype=np.float32) + norm = np.linalg.norm(v) + if norm != 0: + v = v / norm + return self.searcher.search(v, n, self.reorder_k, self.leaves_to_search)[0] + + def batch_query(self, X, n): + X = np.asarray(X, dtype=np.float32) + X = preprocessing.normalize(X, norm="l2", axis=1) + self.res = self.searcher.search_batched( + X, + final_num_neighbors=n, + pre_reorder_num_neighbors=self.reorder_k, + leaves_to_search=self.leaves_to_search, + )[0] + + def get_batch_results(self): + return self.res