From 48f192ab4fb7b17e0f86a786594b8a3d8f004e86 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 25 Oct 2023 14:51:29 -0400 Subject: [PATCH] Updated validatoes --- .../src/raft-ann-bench/run/conf/algos/raft_cagra.yaml | 2 ++ .../src/raft-ann-bench/run/conf/algos/raft_ivf_pq.yaml | 3 ++- .../src/raft-ann-bench/validators/__init__.py | 8 ++++++-- 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_cagra.yaml b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_cagra.yaml index 095f77e66a..32ed400b1a 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_cagra.yaml +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_cagra.yaml @@ -1,4 +1,6 @@ name: raft_cagra +validators: + search: raft-ann-bench.validators.raft_cagra_build_validator base: build: graph_degree: [32, 64] diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ivf_pq.yaml b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ivf_pq.yaml index 0acc7ae694..0a1853f582 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ivf_pq.yaml +++ b/python/raft-ann-bench/src/raft-ann-bench/run/conf/algos/raft_ivf_pq.yaml @@ -1,5 +1,6 @@ name: raft_ivf_pq -validator: raft-ann-bench.validators.raft_ivf_pq_validator +validators: + search: raft-ann-bench.validators.raft_cagra_build_validator base: build: nlist: [1024] diff --git a/python/raft-ann-bench/src/raft-ann-bench/validators/__init__.py b/python/raft-ann-bench/src/raft-ann-bench/validators/__init__.py index b415d972c5..04448425ee 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/validators/__init__.py +++ b/python/raft-ann-bench/src/raft-ann-bench/validators/__init__.py @@ -13,13 +13,17 @@ # See the License for the specific language governing permissions and # limitations under the License. - DTYPE_SIZES = {"float": 4, "half": 2, "fp8": 1} -def ivf_pq_validator(params): +def raft_ivf_pq_search_validator(params, k, batch_size): if "internalDistanceDtype" in params and "smemLutDtype" in params: return ( DTYPE_SIZES[params["smemLutDtype"]] >= DTYPE_SIZES[params["internalDistanceDtype"]] ) + + +def raft_cagra_search_validator(params, k, batch_size): + if "itopk" in params: + return params["itopk"] >= k