From c057898ddb34ced9c6d016cec2bca1bd5b180e35 Mon Sep 17 00:00:00 2001 From: David Meikle Date: Wed, 29 Nov 2023 17:57:25 +0000 Subject: [PATCH] Tokenizers: Updated huggingface_models.py to support Safetensors models as well as pytorch (#2880) * Updated huggingface_models.py to support Safetensors models as well as pytorch --------- Co-authored-by: Frank Liu --- .../src/main/python/huggingface_models.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/extensions/tokenizers/src/main/python/huggingface_models.py b/extensions/tokenizers/src/main/python/huggingface_models.py index 12db9cf2c0c..5b1c6debe5d 100644 --- a/extensions/tokenizers/src/main/python/huggingface_models.py +++ b/extensions/tokenizers/src/main/python/huggingface_models.py @@ -60,17 +60,20 @@ def list_models(self, args: Namespace) -> List[dict]: api = HfApi() if args.model_name: - models = api.list_models(filter="pytorch", - search=args.model_name, - sort="downloads", - direction=-1, - limit=args.limit) + all_models = api.list_models(search=args.model_name, + sort="downloads", + direction=-1, + limit=args.limit) import_all = True else: - models = api.list_models(filter=f"{args.category},pytorch", - sort="downloads", - direction=-1, - limit=args.limit) + all_models = api.list_models(filter=args.category, + sort="downloads", + direction=-1, + limit=args.limit) + models = [ + model for model in all_models + if 'pytorch' in model.tags or 'safetensors' in model.tags + ] if not models: if args.model_name: logging.warning(f"no model found: {args.model_name}.")