Skip to content

Commit

Permalink
Update conversion script to copy certain JSON files to destination (h…
Browse files Browse the repository at this point in the history
  • Loading branch information
xenova committed Mar 10, 2023
1 parent 62a08ff commit 1f14ca0
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions scripts/convert.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@

import os
import shutil
from dataclasses import dataclass, field
from typing import Optional
from pathlib import Path

from transformers import AutoTokenizer, HfArgumentParser
from transformers.utils import cached_file

from optimum.utils import DEFAULT_DUMMY_SHAPES
from optimum.exporters.tasks import TasksManager
Expand Down Expand Up @@ -145,6 +147,13 @@ def quantize(models_name_or_path, model_type):
) # op_types_to_quantize=['MatMul', 'Relu', 'Add', 'Mul' ],


def copy_if_exists(model_path, file_name, destination):
file = cached_file(model_path, file_name,
_raise_exceptions_for_missing_entries=False)
if file is not None:
shutil.copy(file, destination)


def main():

# Helper script to fix inconsistencies between optimum exporter and other exporters.
Expand Down Expand Up @@ -200,7 +209,12 @@ def main():
f"At least {onnx_config.DEFAULT_ONNX_OPSET} is required."
)

# TODO copy all .json files
# Create output folder
os.makedirs(output_model_folder, exist_ok=True)

# Copy certain JSON files, which save_pretrained doesn't handle
copy_if_exists(model_path, 'tokenizer.json', output_model_folder)
copy_if_exists(model_path, 'preprocessor_config.json', output_model_folder)

# Saving the model config
model.config.save_pretrained(output_model_folder)
Expand All @@ -209,9 +223,6 @@ def main():
tokenizer = AutoTokenizer.from_pretrained(model_path)
tokenizer.save_pretrained(output_model_folder)

# Create output folder
os.makedirs(output_model_folder, exist_ok=True)

# Specify output paths
OUTPUT_WEIGHTS_PATH = os.path.join(output_model_folder, ONNX_WEIGHTS_NAME)
OUTPUT_ENCODER_PATH = os.path.join(output_model_folder, ONNX_ENCODER_NAME)
Expand Down

0 comments on commit 1f14ca0

Please sign in to comment.