Skip to content

Commit

Permalink
Refactor of 2 functions used in ORTModel (#551)
Browse files Browse the repository at this point in the history
* Fix issues

* Fix issues

* Apply suggestion

Co-authored-by: Michael Benayoun <[email protected]>
  • Loading branch information
michaelbenayoun and michaelbenayoun authored Dec 7, 2022
1 parent 037467d commit f6eb417
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 94 deletions.
69 changes: 23 additions & 46 deletions optimum/onnxruntime/modeling_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
"""Classes handling causal-lm related architectures in ONNX Runtime."""

import logging
import re
import shutil
from pathlib import Path
from tempfile import TemporaryDirectory
Expand All @@ -27,11 +26,12 @@
from transformers.modeling_outputs import CausalLMOutputWithCrossAttentions

import onnxruntime
from huggingface_hub import HfApi, HfFolder, get_hf_file_metadata, hf_hub_download, hf_hub_url
from huggingface_hub import hf_hub_download

from ..exporters import TasksManager
from ..exporters.onnx import export
from ..utils import NormalizedConfigManager, check_if_transformers_greater
from ..utils.file_utils import validate_file_exists
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .io_binding import TypeHelper
from .modeling_ort import ORTModel
Expand Down Expand Up @@ -110,6 +110,9 @@
```
"""

DECODER_ONNX_FILE_PATTERN = r"(.*)?decoder((?!with_past).)*?\.onnx"
DECODER_WITH_PAST_ONNX_FILE_PATTERN = r"(.*)?decoder(.*)?with_past(.*)?\.onnx"


class ORTDecoder:
"""
Expand Down Expand Up @@ -502,50 +505,24 @@ def _from_pretrained(
):
model_path = Path(model_id)

def validate_file_exists(filename):
if model_path.is_dir():
return (model_path / subfolder / filename).is_file()
succeeded = True
try:
get_hf_file_metadata(hf_hub_url(model_id, filename, subfolder=subfolder, revision=revision))
except Exception:
succeeded = False
return succeeded

def infer_filename(pattern: str, argument_name: str, fail_if_not_found: bool = True) -> str:
pattern = re.compile(f"{subfolder}/{pattern}" if subfolder != "" else pattern)
if model_path.is_dir():
path = model_path
files = model_path.glob("**/*.onnx")
onnx_files = [p for p in files if re.search(pattern, str(p))]
else:
path = model_id
if isinstance(use_auth_token, bool):
token = HfFolder().get_token()
else:
token = use_auth_token
repo_files = map(Path, HfApi().list_repo_files(model_id, revision=revision, token=token))
if subfolder != "":
path = f"{path}/{subfolder}"
onnx_files = [p for p in repo_files if re.match(pattern, str(p))]

if len(onnx_files) == 0:
if fail_if_not_found:
raise FileNotFoundError(f"Could not find any ONNX model file in {path}")
return None
elif len(onnx_files) > 1:
if argument_name is not None:
raise RuntimeError(
f"Too many ONNX model files were found in {path}, specify which one to load by using the "
f"{argument_name} argument."
)
return onnx_files[0].name

if not validate_file_exists(decoder_file_name):
decoder_file_name = infer_filename(r"(.*)?decoder((?!with_past).)*?\.onnx", "decoder_file_name")
if not validate_file_exists(decoder_with_past_file_name):
decoder_with_past_file_name = infer_filename(
r"(.*)?decoder(.*)?with_past(.*)?\.onnx", "decoder_with_past_file_name", fail_if_not_found=False
if not validate_file_exists(model_id, decoder_file_name, subfolder=subfolder, revision=revision):
decoder_file_name = ORTModelDecoder.infer_onnx_filename(
model_id,
DECODER_ONNX_FILE_PATTERN,
"decoder_file_name",
subfolder=subfolder,
use_auth_token=use_auth_token,
revision=revision,
)
if not validate_file_exists(model_id, decoder_with_past_file_name, subfolder=subfolder, revision=revision):
decoder_with_past_file_name = ORTModelDecoder.infer_onnx_filename(
model_id,
DECODER_WITH_PAST_ONNX_FILE_PATTERN,
"decoder_with_past_file_name",
subfolder=subfolder,
use_auth_token=use_auth_token,
revision=revision,
fail_if_not_found=use_cache,
)

decoder_regular_onnx_filenames = ORTModelDecoder._generate_regular_names_for_filename(ONNX_DECODER_NAME)
Expand Down
36 changes: 36 additions & 0 deletions optimum/onnxruntime/modeling_ort.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from ..exporters import TasksManager
from ..exporters.onnx import export
from ..modeling_base import FROM_PRETRAINED_START_DOCSTRING, OptimizedModel
from ..utils.file_utils import find_files_matching_pattern
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .io_binding import IOBindingHelper, TypeHelper
from .utils import (
Expand Down Expand Up @@ -309,6 +310,41 @@ def _generate_regular_names_for_filename(filename: str):
name, extension = filename.rsplit(".", maxsplit=1)
return [filename, f"{name}_quantized.{extension}", f"{name}_optimized.{extension}"]

@staticmethod
def infer_onnx_filename(
model_name_or_path: Union[str, Path],
pattern: str,
argument_name: str,
subfolder: str = "",
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
fail_if_not_found: bool = True,
) -> str:
onnx_files = find_files_matching_pattern(
model_name_or_path,
pattern,
glob_pattern="**/*.onnx",
subfolder=subfolder,
use_auth_token=use_auth_token,
revision=revision,
)

path = model_name_or_path
if subfolder != "":
path = f"{path}/{subfolder}"

if len(onnx_files) == 0:
if fail_if_not_found:
raise FileNotFoundError(f"Could not find any ONNX model file in {path}")
return None
elif len(onnx_files) > 1:
if argument_name is not None:
raise RuntimeError(
f"Too many ONNX model files were found in {path}, specify which one to load by using the "
f"{argument_name} argument."
)
return onnx_files[0].name

@classmethod
def _from_pretrained(
cls,
Expand Down
74 changes: 28 additions & 46 deletions optimum/onnxruntime/modeling_seq2seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from ..exporters.onnx.convert import export_encoder_decoder_model as export
from ..exporters.tasks import TasksManager
from ..utils import NormalizedConfigManager, check_if_transformers_greater
from ..utils.file_utils import validate_file_exists
from ..utils.save_utils import maybe_load_preprocessors, maybe_save_preprocessors
from .io_binding import TypeHelper
from .modeling_decoder import ORTDecoder
Expand Down Expand Up @@ -932,52 +933,33 @@ def _from_pretrained(
):
model_path = Path(model_id)

def validate_file_exists(filename):
if model_path.is_dir():
return (model_path / subfolder / filename).is_file()
succeeded = True
try:
get_hf_file_metadata(hf_hub_url(model_id, filename, subfolder=subfolder, revision=revision))
except Exception:
succeeded = False
return succeeded

def infer_filename(pattern: str, argument_name: str, fail_if_not_found: bool = True) -> str:
pattern = re.compile(f"{subfolder}/{pattern}" if subfolder != "" else pattern)
if model_path.is_dir():
path = model_path
files = model_path.glob("**/*.onnx")
onnx_files = [p for p in files if re.search(pattern, str(p))]
else:
path = model_id
if isinstance(use_auth_token, bool):
token = HfFolder().get_token()
else:
token = use_auth_token
repo_files = map(Path, HfApi().list_repo_files(model_id, revision=revision, token=token))
if subfolder != "":
path = f"{path}/{subfolder}"
onnx_files = [p for p in repo_files if re.match(pattern, str(p))]

if len(onnx_files) == 0:
if fail_if_not_found:
raise FileNotFoundError(f"Could not find any ONNX model file in {path}")
return None
elif len(onnx_files) > 1:
if argument_name is not None:
raise RuntimeError(
f"Too many ONNX model files were found in {path}, specify which one to load by using the "
f"{argument_name} argument."
)
return onnx_files[0].name

if not validate_file_exists(encoder_file_name):
encoder_file_name = infer_filename(ENCODER_ONNX_FILE_PATTERN, "encoder_file_name")
if not validate_file_exists(decoder_file_name):
decoder_file_name = infer_filename(DECODER_ONNX_FILE_PATTERN, "decoder_file_name")
if not validate_file_exists(decoder_with_past_file_name):
decoder_with_past_file_name = infer_filename(
DECODER_WITH_PAST_ONNX_FILE_PATTERN, "decoder_with_past_file_name", fail_if_not_found=use_cache
if not validate_file_exists(model_id, encoder_file_name, subfolder=subfolder, revision=revision):
encoder_file_name = ORTModelForConditionalGeneration.infer_onnx_filename(
model_id,
ENCODER_ONNX_FILE_PATTERN,
"encoder_file_name",
subfolder=subfolder,
use_auth_token=use_auth_token,
revision=revision,
)
if not validate_file_exists(model_id, decoder_file_name, subfolder=subfolder, revision=revision):
decoder_file_name = ORTModelForConditionalGeneration.infer_onnx_filename(
model_id,
DECODER_ONNX_FILE_PATTERN,
"decoder_file_name",
subfolder=subfolder,
use_auth_token=use_auth_token,
revision=revision,
)
if not validate_file_exists(model_id, decoder_with_past_file_name, subfolder=subfolder, revision=revision):
decoder_with_past_file_name = ORTModelForConditionalGeneration.infer_onnx_filename(
model_id,
DECODER_WITH_PAST_ONNX_FILE_PATTERN,
"decoder_with_past_file_name",
subfolder=subfolder,
use_auth_token=use_auth_token,
revision=revision,
fail_if_not_found=use_cache,
)

encoder_regular_onnx_filenames = ORTModelForConditionalGeneration._generate_regular_names_for_filename(
Expand Down
3 changes: 1 addition & 2 deletions optimum/onnxruntime/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import importlib.util
import os
from enum import Enum
from typing import Dict, Tuple, Type, Union
from typing import Dict, Tuple, Union

import torch
from transformers.onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
Expand All @@ -27,7 +27,6 @@
import pkg_resources

from ..onnx import OnnxConfigWithLoss, OnnxConfigWithPastAndLoss, OnnxSeq2SeqConfigWithPastAndLoss
from ..utils import NormalizedTextConfig


logger = logging.get_logger(__name__)
Expand Down
88 changes: 88 additions & 0 deletions optimum/utils/file_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utility functions related to both local files and files on the Hugging Face Hub."""

import re
from pathlib import Path
from typing import List, Optional, Union

from huggingface_hub import HfApi, HfFolder, get_hf_file_metadata, hf_hub_url


def validate_file_exists(
model_name_or_path: Union[str, Path], filename: str, subfolder: str = "", revision: Optional[str] = None
) -> bool:
"""
Checks that the file called `filename` exists in the `model_name_or_path` directory or model repo.
"""
model_path = Path(model_name_or_path) if isinstance(model_name_or_path, str) else model_name_or_path
if model_path.is_dir():
return (model_path / subfolder / filename).is_file()
succeeded = True
try:
get_hf_file_metadata(hf_hub_url(model_name_or_path, filename, subfolder=subfolder, revision=revision))
except Exception:
succeeded = False
return succeeded


def find_files_matching_pattern(
model_name_or_path: Union[str, Path],
pattern: str,
glob_pattern: str = "**/*",
subfolder: str = "",
use_auth_token: Optional[Union[bool, str]] = None,
revision: Optional[str] = None,
) -> List[Path]:
"""
Scans either a model repo or a local directory to find filenames matching the pattern.
Args:
model_name_or_path (`Union[str, Path]`):
The name of the model repo on the Hugging Face Hub or the path to a local directory.
pattern (`str`):
The pattern to use to look for files.
glob_pattern (`str`, defaults to `"**/*"`):
The pattern to use to list all the files that need to be checked.
subfolder (`str`, defaults to `""`):
In case the model files are located inside a subfolder of the model directory / repo on the Hugging
Face Hub, you can specify the subfolder name here.
use_auth_token (`Optional[bool, str]`, *optional*):
The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
when running `transformers-cli login` (stored in `~/.huggingface`).
revision (`Optional[str]`, defaults to `None`):
Revision is the specific model version to use. It can be a branch name, a tag name, or a commit id.
Returns:
`List[Path]`
"""
model_path = Path(model_name_or_path) if isinstance(model_name_or_path, str) else model_name_or_path
pattern = re.compile(f"{subfolder}/{pattern}" if subfolder != "" else pattern)
if model_path.is_dir():
path = model_path
files = model_path.glob("**/*.onnx")
files = [p for p in files if re.search(pattern, str(p))]
else:
path = model_name_or_path
if isinstance(use_auth_token, bool):
token = HfFolder().get_token()
else:
token = use_auth_token
repo_files = map(Path, HfApi().list_repo_files(model_name_or_path, revision=revision, token=token))
if subfolder != "":
path = f"{path}/{subfolder}"
files = [Path(p) for p in repo_files if re.match(pattern, str(p))]

return files

0 comments on commit f6eb417

Please sign in to comment.