From 40d2c84f0dca26a6cc6d63734975cd4f1043aa16 Mon Sep 17 00:00:00 2001 From: Anas Ahouzi <112881240+aahouzi@users.noreply.github.com> Date: Wed, 24 Jan 2024 01:46:57 -0800 Subject: [PATCH] Add support for convert.py --- scripts/convert.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/scripts/convert.py b/scripts/convert.py index d0012d51a..757445570 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -13,8 +13,10 @@ # limitations under the License. import argparse +import sys from pathlib import Path from typing import List, Optional +from huggingface_hub import snapshot_download from neural_speed.convert import convert_model def main(args_in: Optional[List[str]] = None) -> None: @@ -25,6 +27,11 @@ def main(args_in: Optional[List[str]] = None) -> None: help="output format, default: f32", default="f32", ) + parser.add_argument( + "--token", + type=str, + help="Access token ID for models that require it (LLaMa2, etc..)", + ) parser.add_argument("--outfile", type=Path, required=True, help="path to write to") parser.add_argument("model", type=Path, help="directory containing model file or model id") args = parser.parse_args(args_in) @@ -32,7 +39,12 @@ def main(args_in: Optional[List[str]] = None) -> None: if args.model.exists(): dir_model = args.model.as_posix() else: - dir_model = args.model + try: + dir_model = snapshot_download(repo_id=str(args.model), resume_download=True, token=args.token) + except Exception as e: + if e.response.status_code == 401: + print("You are required to input an acccess token ID for {}, please add it in option --token or download model weights locally".format(args.model)) + sys.exit(f"{e}") convert_model(dir_model, args.outfile, args.outtype)