diff --git a/optimum/onnxruntime/modeling_decoder.py b/optimum/onnxruntime/modeling_decoder.py index c3b37bb7ad0..084776c1aee 100644 --- a/optimum/onnxruntime/modeling_decoder.py +++ b/optimum/onnxruntime/modeling_decoder.py @@ -14,7 +14,6 @@ """Classes handling causal-lm related architectures in ONNX Runtime.""" import logging -import re import shutil from pathlib import Path from tempfile import TemporaryDirectory @@ -27,11 +26,12 @@ from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions import onnxruntime -from huggingface_hub import HfApi, HfFolder, get_hf_file_metadata, hf_hub_download, hf_hub_url +from huggingface_hub import hf_hub_download from ..exporters import TasksManager from ..exporters.onnx import export from ..utils import NormalizedConfigManager, check_if_transformers_greater +from ..utils.file_utils import validate_file_exists from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors from .io_binding import TypeHelper from .modeling_ort import ORTModel @@ -110,6 +110,9 @@ ``` """ +DECODER_ONNX_FILE_PATTERN = r"(.*)?decoder((?!with_past).)*?\.onnx" +DECODER_WITH_PAST_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?with_past(.*)?\.onnx" + class ORTDecoder: """ @@ -502,50 +505,24 @@ def _from_pretrained( ): model_path = Path(model_id) - def validate_file_exists(filename): - if model_path.is_dir(): - return (model_path / subfolder / filename).is_file() - succeeded = True - try: - get_hf_file_metadata(hf_hub_url(model_id, filename, subfolder=subfolder, revision=revision)) - except Exception: - succeeded = False - return succeeded - - def infer_filename(pattern: str, argument_name: str, fail_if_not_found: bool = True) -> str: - pattern = re.compile(f"{subfolder}/{pattern}" if subfolder != "" else pattern) - if model_path.is_dir(): - path = model_path - files = model_path.glob("**/*.onnx") - onnx_files = [p for p in files if re.search(pattern, str(p))] - else: - path = model_id - if isinstance(use_auth_token, bool): - 
token = HfFolder().get_token() - else: - token = use_auth_token - repo_files = map(Path, HfApi().list_repo_files(model_id, revision=revision, token=token)) - if subfolder != "": - path = f"{path}/{subfolder}" - onnx_files = [p for p in repo_files if re.match(pattern, str(p))] - - if len(onnx_files) == 0: - if fail_if_not_found: - raise FileNotFoundError(f"Could not find any ONNX model file in {path}") - return None - elif len(onnx_files) > 1: - if argument_name is not None: - raise RuntimeError( - f"Too many ONNX model files were found in {path}, specify which one to load by using the " - f"{argument_name} argument." - ) - return onnx_files[0].name - - if not validate_file_exists(decoder_file_name): - decoder_file_name = infer_filename(r"(.*)?decoder((?!with_past).)*?\.onnx", "decoder_file_name") - if not validate_file_exists(decoder_with_past_file_name): - decoder_with_past_file_name = infer_filename( - r"(.*)?decoder(.*)?with_past(.*)?\.onnx", "decoder_with_past_file_name", fail_if_not_found=False + if not validate_file_exists(model_id, decoder_file_name, subfolder=subfolder, revision=revision): + decoder_file_name = ORTModelDecoder.infer_onnx_filename( + model_id, + DECODER_ONNX_FILE_PATTERN, + "decoder_file_name", + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + ) + if not validate_file_exists(model_id, decoder_with_past_file_name, subfolder=subfolder, revision=revision): + decoder_with_past_file_name = ORTModelDecoder.infer_onnx_filename( + model_id, + DECODER_WITH_PAST_ONNX_FILE_PATTERN, + "decoder_with_past_file_name", + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + fail_if_not_found=use_cache, ) decoder_regular_onnx_filenames = ORTModelDecoder._generate_regular_names_for_filename(ONNX_DECODER_NAME) diff --git a/optimum/onnxruntime/modeling_ort.py b/optimum/onnxruntime/modeling_ort.py index 144a0c3f76f..f696346d1d8 100644 --- a/optimum/onnxruntime/modeling_ort.py +++ 
@staticmethod
def infer_onnx_filename(
    model_name_or_path: Union[str, Path],
    pattern: str,
    argument_name: Optional[str],
    subfolder: str = "",
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
    fail_if_not_found: bool = True,
) -> Optional[str]:
    """
    Infers the name of the ONNX file to load by looking for files matching `pattern`.

    The lookup itself is delegated to `find_files_matching_pattern`, restricted to `**/*.onnx`.

    Args:
        model_name_or_path (`Union[str, Path]`):
            The model repo name or the local directory to scan.
        pattern (`str`):
            The regular expression the ONNX file path must match.
        argument_name (`Optional[str]`):
            The name of the argument the user can set to pick a file explicitly. It is only used in the
            error message raised when several files match. If `None`, no error is raised in that case and
            the first match is returned.
        subfolder (`str`, defaults to `""`):
            The subfolder of the directory / repo in which the files are located.
        use_auth_token (`Optional[Union[bool, str]]`, *optional*):
            Forwarded as is to `find_files_matching_pattern`.
        revision (`Optional[str]`, *optional*):
            Forwarded as is to `find_files_matching_pattern`.
        fail_if_not_found (`bool`, defaults to `True`):
            Whether to raise an error when no file matches instead of returning `None`.

    Returns:
        `Optional[str]`: The name (last path component) of the first matching file, or `None` if no file
        matches and `fail_if_not_found` is `False`.

    Raises:
        FileNotFoundError: If no file matches and `fail_if_not_found` is `True`.
        RuntimeError: If more than one file matches and `argument_name` is not `None`.
    """
    onnx_files = find_files_matching_pattern(
        model_name_or_path,
        pattern,
        glob_pattern="**/*.onnx",
        subfolder=subfolder,
        use_auth_token=use_auth_token,
        revision=revision,
    )

    # Only used to build the error messages.
    path = f"{model_name_or_path}/{subfolder}" if subfolder != "" else model_name_or_path

    if not onnx_files:
        if fail_if_not_found:
            raise FileNotFoundError(f"Could not find any ONNX model file in {path}")
        return None

    if len(onnx_files) > 1 and argument_name is not None:
        raise RuntimeError(
            f"Too many ONNX model files were found in {path}, specify which one to load by using the "
            f"{argument_name} argument."
        )

    # Callers join the name with the directory / subfolder themselves, so only the file name is returned.
    return onnx_files[0].name
many ONNX model files were found in {path}, specify which one to load by using the " - f"{argument_name} argument." - ) - return onnx_files[0].name - - if not validate_file_exists(encoder_file_name): - encoder_file_name = infer_filename(ENCODER_ONNX_FILE_PATTERN, "encoder_file_name") - if not validate_file_exists(decoder_file_name): - decoder_file_name = infer_filename(DECODER_ONNX_FILE_PATTERN, "decoder_file_name") - if not validate_file_exists(decoder_with_past_file_name): - decoder_with_past_file_name = infer_filename( - DECODER_WITH_PAST_ONNX_FILE_PATTERN, "decoder_with_past_file_name", fail_if_not_found=use_cache + if not validate_file_exists(model_id, encoder_file_name, subfolder=subfolder, revision=revision): + encoder_file_name = ORTModelForConditionalGeneration.infer_onnx_filename( + model_id, + ENCODER_ONNX_FILE_PATTERN, + "encoder_file_name", + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + ) + if not validate_file_exists(model_id, decoder_file_name, subfolder=subfolder, revision=revision): + decoder_file_name = ORTModelForConditionalGeneration.infer_onnx_filename( + model_id, + DECODER_ONNX_FILE_PATTERN, + "decoder_file_name", + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + ) + if not validate_file_exists(model_id, decoder_with_past_file_name, subfolder=subfolder, revision=revision): + decoder_with_past_file_name = ORTModelForConditionalGeneration.infer_onnx_filename( + model_id, + DECODER_WITH_PAST_ONNX_FILE_PATTERN, + "decoder_with_past_file_name", + subfolder=subfolder, + use_auth_token=use_auth_token, + revision=revision, + fail_if_not_found=use_cache, ) encoder_regular_onnx_filenames = ORTModelForConditionalGeneration._generate_regular_names_for_filename( diff --git a/optimum/onnxruntime/utils.py b/optimum/onnxruntime/utils.py index 13e404672b8..a0cfa2e2ae8 100644 --- a/optimum/onnxruntime/utils.py +++ b/optimum/onnxruntime/utils.py @@ -16,7 +16,7 @@ import importlib.util import os from 
"""Utility functions related to both local files and files on the Hugging Face Hub."""

import re
from pathlib import Path
from typing import List, Optional, Union


def validate_file_exists(
    model_name_or_path: Union[str, Path], filename: str, subfolder: str = "", revision: Optional[str] = None
) -> bool:
    """
    Checks that the file called `filename` exists in the `model_name_or_path` directory or model repo.

    Args:
        model_name_or_path (`Union[str, Path]`):
            The name of the model repo on the Hugging Face Hub or the path to a local directory.
        filename (`str`):
            The name of the file to look for.
        subfolder (`str`, defaults to `""`):
            The subfolder of the directory / repo in which the file is expected.
        revision (`Optional[str]`, defaults to `None`):
            The revision of the repo to check. Ignored for local directories.

    Returns:
        `bool`: `True` if the file was found, `False` otherwise.
    """
    model_path = Path(model_name_or_path)
    if model_path.is_dir():
        return (model_path / subfolder / filename).is_file()

    # Imported lazily: only the Hub code path needs `huggingface_hub`.
    from huggingface_hub import get_hf_file_metadata, hf_hub_url

    # Repo ids always use forward slashes, whatever the OS path separator is.
    repo_id = model_name_or_path if isinstance(model_name_or_path, str) else model_path.as_posix()
    try:
        get_hf_file_metadata(hf_hub_url(repo_id, filename, subfolder=subfolder, revision=revision))
    except Exception:
        # Best effort on purpose: any failure while fetching the metadata (missing file, missing repo,
        # network error, ...) is reported as "the file does not exist".
        return False
    return True


def find_files_matching_pattern(
    model_name_or_path: Union[str, Path],
    pattern: str,
    glob_pattern: str = "**/*",
    subfolder: str = "",
    use_auth_token: Optional[Union[bool, str]] = None,
    revision: Optional[str] = None,
) -> List[Path]:
    """
    Scans either a model repo or a local directory to find filenames matching the pattern.

    Args:
        model_name_or_path (`Union[str, Path]`):
            The name of the model repo on the Hugging Face Hub or the path to a local directory.
        pattern (`str`):
            The regular expression to use to look for files. It is searched anywhere in the path for local
            directories, and matched from the start of the path for files on the Hugging Face Hub.
        glob_pattern (`str`, defaults to `"**/*"`):
            The glob pattern to use to list all the files that need to be checked. It is only applied when
            scanning a local directory.
        subfolder (`str`, defaults to `""`):
            In case the model files are located inside a subfolder of the model directory / repo on the Hugging
            Face Hub, you can specify the subfolder name here. It is treated as a literal path prefix.
        use_auth_token (`Optional[Union[bool, str]]`, *optional*):
            The token to use as HTTP bearer authorization for remote files. If a boolean is given, the token
            stored locally (read through `HfFolder`) is used instead.
        revision (`Optional[str]`, defaults to `None`):
            Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id.

    Returns:
        `List[Path]`: The paths of the matching files. Local results are sorted.
    """
    model_path = Path(model_name_or_path)
    # The subfolder is a literal path component, so it must not be interpreted as a regular expression.
    prefix = f"{re.escape(subfolder)}/" if subfolder != "" else ""
    regex = re.compile(prefix + pattern)

    if model_path.is_dir():
        # Sorting makes the result independent from the file system listing order.
        candidates = sorted(model_path.glob(glob_pattern))
        # `as_posix()` keeps the "/" separators the regex was built with on every OS.
        return [p for p in candidates if p.is_file() and regex.search(p.as_posix())]

    # Imported lazily: only the Hub code path needs `huggingface_hub`.
    from huggingface_hub import HfApi, HfFolder

    # NOTE(review): any boolean, including `False`, resolves to the locally stored token. This is the
    # historical behavior and is kept as is; confirm whether `False` should mean "no token".
    token = HfFolder().get_token() if isinstance(use_auth_token, bool) else use_auth_token
    repo_id = model_name_or_path if isinstance(model_name_or_path, str) else model_path.as_posix()
    repo_files = [Path(name) for name in HfApi().list_repo_files(repo_id, revision=revision, token=token)]
    return [p for p in repo_files if regex.match(p.as_posix())]