Skip to content

Commit

Permalink
[djl-import] Includes requires version when importing model (#3431)
Browse files Browse the repository at this point in the history
  • Loading branch information
frankfliu authored Aug 19, 2024
1 parent 2a8ff9e commit bb86c00
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ def importer_args():
" repository. This option should only be set for repositories you trust and in which"
" you have read the code, as it will execute on your local machine arbitrary code"
" present in the model repository.")
parser.add_argument(
"--min-version",
help="Requires a specific version of DJL to load the model.")

args = parser.parse_args()
if args.output_dir is None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import logging
import os
from argparse import Namespace
from typing import List
from typing import List, Optional

from huggingface_hub import HfApi
from huggingface_hub import hf_hub_download
Expand Down Expand Up @@ -150,7 +150,8 @@ def list_models(self, args: Namespace) -> List[dict]:
return ret

def update_progress(self, model_info: ModelInfo, application: str,
result: bool, reason: str, size: int, cpu_only: bool):
result: bool, reason: str, size: int, cpu_only: bool,
min_version: Optional[str]):
status = {
"result": "success" if result else "failed",
"application": application,
Expand All @@ -162,6 +163,8 @@ def update_progress(self, model_info: ModelInfo, application: str,
status["reason"] = reason
if cpu_only:
status["cpu_only"] = True
if result and min_version:
status["requires"] = min_version

self.processed_models[model_info.modelId] = status

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ def main():
size = -1

huggingface_models.update_progress(model_info, converter.application,
result, reason, size, args.cpu_only)
result, reason, size, args.cpu_only,
args.min_version)
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)

Expand Down

0 comments on commit bb86c00

Please sign in to comment.