From 27b23a24f19904c7fe38d2cd61694db109f2d75a Mon Sep 17 00:00:00 2001 From: Divye Gala Date: Wed, 8 Nov 2023 17:43:25 -0500 Subject: [PATCH] Allow `raft-ann-bench/run` to continue after encountering bad YAML configs (#1980) Authors: - Divye Gala (https://github.com/divyegala) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1980 --- .../src/raft-ann-bench/data_export/__main__.py | 4 ++++ .../generate_groundtruth/__main__.py | 4 ++++ .../src/raft-ann-bench/get_dataset/__main__.py | 5 +++++ .../src/raft-ann-bench/plot/__main__.py | 4 ++++ .../src/raft-ann-bench/run/__main__.py | 14 +++++++++++++- .../raft-ann-bench/split_groundtruth/__main__.py | 5 +++++ 6 files changed, 35 insertions(+), 1 deletion(-) diff --git a/python/raft-ann-bench/src/raft-ann-bench/data_export/__main__.py b/python/raft-ann-bench/src/raft-ann-bench/data_export/__main__.py index e19ada2934..47da9f39fa 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/data_export/__main__.py +++ b/python/raft-ann-bench/src/raft-ann-bench/data_export/__main__.py @@ -17,6 +17,7 @@ import argparse import json import os +import sys import warnings import pandas as pd @@ -147,6 +148,9 @@ def main(): default=default_dataset_path, ) + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) args = parser.parse_args() convert_json_to_csv_build(args.dataset, args.dataset_path) diff --git a/python/raft-ann-bench/src/raft-ann-bench/generate_groundtruth/__main__.py b/python/raft-ann-bench/src/raft-ann-bench/generate_groundtruth/__main__.py index 77a930f81e..f4d97edea5 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/generate_groundtruth/__main__.py +++ b/python/raft-ann-bench/src/raft-ann-bench/generate_groundtruth/__main__.py @@ -16,6 +16,7 @@ # import argparse import os +import sys import cupy as cp import numpy as np @@ -178,6 +179,9 @@ def main(): " commonly used with RAFT ANN are 'sqeuclidean' and 'inner_product'", ) + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) args = parser.parse_args() if args.rows is not None: diff --git a/python/raft-ann-bench/src/raft-ann-bench/get_dataset/__main__.py b/python/raft-ann-bench/src/raft-ann-bench/get_dataset/__main__.py index 4e6a0119b4..0a6c37aabc 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/get_dataset/__main__.py +++ b/python/raft-ann-bench/src/raft-ann-bench/get_dataset/__main__.py @@ -16,6 +16,7 @@ import argparse import os import subprocess +import sys from urllib.request import urlretrieve @@ -101,6 +102,10 @@ def main(): help="normalize cosine distance to inner product", action="store_true", ) + + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) args = parser.parse_args() download(args.dataset, args.normalize, args.dataset_path) diff --git a/python/raft-ann-bench/src/raft-ann-bench/plot/__main__.py b/python/raft-ann-bench/src/raft-ann-bench/plot/__main__.py index 78f8aea8b8..c45ff5b14e 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/plot/__main__.py +++ b/python/raft-ann-bench/src/raft-ann-bench/plot/__main__.py @@ -22,6 +22,7 @@ import argparse import itertools import os +import sys from collections import OrderedDict import matplotlib as mpl @@ -486,6 +487,9 @@ def main(): action="store_true", ) + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) args = parser.parse_args() if args.algorithms: diff --git a/python/raft-ann-bench/src/raft-ann-bench/run/__main__.py b/python/raft-ann-bench/src/raft-ann-bench/run/__main__.py index 6b01263c27..c9fde6dd7e 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/run/__main__.py +++ b/python/raft-ann-bench/src/raft-ann-bench/run/__main__.py @@ -18,7 +18,9 @@ import json import os import subprocess +import sys import uuid +import warnings from importlib import import_module import yaml @@ -292,6 +294,9 @@ def main(): action="store_true", ) + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) args = parser.parse_args() # If both build and search are not provided, @@ -368,7 +373,14 @@ def main(): algos_conf = dict() for algo_f in algos_conf_fs: with open(algo_f, "r") as f: - algo = yaml.safe_load(f) + try: + algo = yaml.safe_load(f) + except Exception as e: + warnings.warn( + f"Could not load YAML config {algo_f} due to " + + e.with_traceback() + ) + continue insert_algo = True insert_algo_group = False if filter_algos: diff --git a/python/raft-ann-bench/src/raft-ann-bench/split_groundtruth/__main__.py b/python/raft-ann-bench/src/raft-ann-bench/split_groundtruth/__main__.py index b886d40ea7..c65360ebb0 100644 --- a/python/raft-ann-bench/src/raft-ann-bench/split_groundtruth/__main__.py +++ b/python/raft-ann-bench/src/raft-ann-bench/split_groundtruth/__main__.py @@ -16,6 +16,7 @@ import argparse import os import subprocess +import sys def split_groundtruth(groundtruth_filepath): @@ -43,6 +44,10 @@ def main(): help="Path to billion-scale dataset groundtruth file", required=True, ) + + if len(sys.argv) == 1: + parser.print_help() + sys.exit(1) args = parser.parse_args() split_groundtruth(args.groundtruth)