diff --git a/scripts/convert.py b/scripts/convert.py index d0012d51a..757445570 100644 --- a/scripts/convert.py +++ b/scripts/convert.py @@ -13,8 +13,10 @@ # limitations under the License. import argparse +import sys from pathlib import Path from typing import List, Optional +from huggingface_hub import snapshot_download from neural_speed.convert import convert_model def main(args_in: Optional[List[str]] = None) -> None: @@ -25,6 +27,11 @@ def main(args_in: Optional[List[str]] = None) -> None: help="output format, default: f32", default="f32", ) + parser.add_argument( + "--token", + type=str, + help="Access token ID for models that require it (LLaMa2, etc..)", + ) parser.add_argument("--outfile", type=Path, required=True, help="path to write to") parser.add_argument("model", type=Path, help="directory containing model file or model id") args = parser.parse_args(args_in) @@ -32,7 +39,12 @@ def main(args_in: Optional[List[str]] = None) -> None: if args.model.exists(): dir_model = args.model.as_posix() else: - dir_model = args.model + try: + dir_model = snapshot_download(repo_id=str(args.model), resume_download=True, token=args.token) + except Exception as e: + if getattr(getattr(e, "response", None), "status_code", None) == 401: + print("You are required to input an access token ID for {}, please add it in option --token or download model weights locally".format(args.model)) + sys.exit(f"{e}") convert_model(dir_model, args.outfile, args.outtype)