Skip to content

Commit

Permalink
Add model_name arg to whisper example (#404)
Browse files Browse the repository at this point in the history
## Describe your changes
Add model_name arg to whisper example
## Checklist before requesting a review
- [ ] Add unit tests for this change.
- [ ] Make sure all tests can pass.
- [ ] Update documents if necessary.
- [ ] Format your code by running `pre-commit run --all-files`
- [ ] Is this a user-facing change? If yes, give a description of this
change to be included in the release notes.

## (Optional) Issue link
  • Loading branch information
xiaoyu-work authored Jul 13, 2023
1 parent 7daca73 commit ca6dfb3
Show file tree
Hide file tree
Showing 6 changed files with 91 additions and 14 deletions.
47 changes: 47 additions & 0 deletions docs/source/tutorials/huggingface_model_optimization.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,5 +46,52 @@ Please refer to [hf_config](../overview/options.md#hf_config) for more details.
```
Please refer to [metrics](../overview/options.md#metrics) for more details.

### Custom components config
You can use your own custom component functions for your model. You will need to define the details of your components in your script as functions.
```json
{
"input_model": {
"type": "PyTorchModel",
"config": {
"model_script": "code/user_script.py",
"script_dir": "code",
"hf_config": {
"model_class": "WhisperForConditionalGeneration",
"model_name": "openai/whisper-medium",
"components": [
{
"name": "encoder_decoder_init",
"io_config": "get_encdec_io_config",
"component_func": "get_encoder_decoder_init",
"dummy_inputs_func": "encoder_decoder_init_dummy_inputs"
},
{
"name": "decoder",
"io_config": "get_dec_io_config",
"component_func": "get_decoder",
"dummy_inputs_func": "decoder_dummy_inputs"
}
]
}
}
}
}
```
#### Script example
```
# code/user_script.py
def get_dec_io_config(model_name: str):
# return your io dict
...
def get_decoder(model_name: str):
# your component implementation
...
def dummy_inputs_func():
# return the dummy input for your component
...
```

### E2E example
For the complete example, please refer to [Bert Optimization with PTQ on CPU](https://github.com/microsoft/Olive/tree/main/examples/bert#bert-optimization-with-ptq-on-cpu).
6 changes: 5 additions & 1 deletion examples/whisper/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,13 @@ python -m pip install -r requirements.txt

### Prepare workflow config json
```
python prepare_whisper_configs.py [--no_audio_decoder] [--multilingual]
python prepare_whisper_configs.py [--model_name MODEL_NAME] [--no_audio_decoder] [--multilingual]
# For example, using whisper tiny model
python prepare_whisper_configs.py --model_name openai/whisper-tiny.en
```

`--model_name MODEL_NAME` is the name or path of the whisper model. The default value is `openai/whisper-tiny.en`.
`--no_audio_decoder` is optional. If not provided, will use audio decoder in the preprocessing ops.

**Note:** If `--no_audio_decoder` is provided, you need to install `librosa` package before running the optimization steps below.
Expand Down
16 changes: 8 additions & 8 deletions examples/whisper/code/user_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitInputs


def get_encoder_decoder_init():
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
def get_encoder_decoder_init(model_name):
model = WhisperForConditionalGeneration.from_pretrained(model_name)
return WhisperEncoderDecoderInit(
model,
model,
Expand All @@ -19,13 +19,13 @@ def get_encoder_decoder_init():
)


def get_decoder():
model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
def get_decoder(model_name):
model = WhisperForConditionalGeneration.from_pretrained(model_name)
return WhisperDecoder(model, model.config)


def get_encdec_io_config():
model = get_encoder_decoder_init()
def get_encdec_io_config(model_name):
model = get_encoder_decoder_init(model_name)
use_decoder_input_ids = True

inputs = WhisperEncoderDecoderInitInputs.create_dummy(
Expand Down Expand Up @@ -96,10 +96,10 @@ def get_encdec_io_config():
}


def get_dec_io_config():
def get_dec_io_config(model_name):
# Fix past disappearing bug - duplicate first past entry
# input_list.insert(2, input_list[2])
model = get_decoder()
model = get_decoder(model_name)
past_names = PastKeyValuesHelper.get_past_names(model.config.decoder_layers, present=False)
present_names = PastKeyValuesHelper.get_past_names(model.config.decoder_layers, present=True)
present_self_names = present_names[: 2 * model.config.decoder_layers]
Expand Down
28 changes: 27 additions & 1 deletion examples/whisper/prepare_whisper_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@

def get_args(raw_args):
parser = argparse.ArgumentParser(description="Prepare config file for Whisper")
parser.add_argument("--model_name", type=str, default="openai/whisper-tiny.en", help="Model name")
parser.add_argument(
"--no_audio_decoder",
action="store_true",
Expand All @@ -54,8 +55,12 @@ def main(raw_args=None):

# load template
template_json = json.load(open("whisper_template.json", "r"))
model_name = args.model_name

whisper_config = WhisperConfig.from_pretrained(template_json["input_model"]["config"]["hf_config"]["model_name"])
whisper_config = WhisperConfig.from_pretrained(model_name)

# update model name
template_json["input_model"]["config"]["hf_config"]["model_name"] = model_name

# set dataloader
template_json["evaluators"]["common_evaluator"]["metrics"][0]["user_config"]["dataloader_func"] = (
Expand All @@ -69,6 +74,9 @@ def main(raw_args=None):
# update multi-lingual support
template_json["passes"]["insert_beam_search"]["config"]["use_forced_decoder_ids"] = args.multilingual

# set model name in prepost
template_json["passes"]["prepost"]["config"]["tool_command_args"]["model_name"] = model_name

# download audio test data
test_audio_path = download_audio_test_data()
template_json["passes"]["prepost"]["config"]["tool_command_args"]["testdata_filepath"] = str(test_audio_path)
Expand Down Expand Up @@ -97,6 +105,10 @@ def main(raw_args=None):
# dump config
json.dump(config, open(f"whisper_{device}_{precision}.json", "w"), indent=4)

# update user script
user_script_path = Path(__file__).parent / "code" / "user_script.py"
update_user_script(user_script_path, model_name)


def download_audio_test_data():
cur_dir = Path(__file__).parent
Expand All @@ -113,5 +125,19 @@ def download_audio_test_data():
return test_audio_path.relative_to(cur_dir)


def update_user_script(file_path, model_name):
    """Substitute the model-name placeholder in a user script, in place.

    Reads the file at ``file_path``, replaces every occurrence of the
    literal marker ``<model_name>`` with ``model_name``, and writes the
    updated text back to the same path.

    :param file_path: path to the script file to rewrite.
    :param model_name: Hugging Face model name or path to splice in.
    """
    with open(file_path, "r") as fp:
        contents = fp.read()

    # A single whole-text replace is equivalent to scanning line by line.
    contents = contents.replace("<model_name>", model_name)

    with open(file_path, "w") as fp:
        fp.write(contents)


# CLI entry point: generate the whisper workflow config files from the template.
if __name__ == "__main__":
    main()
4 changes: 2 additions & 2 deletions examples/whisper/whisper_template.json
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
"script_dir": "code",
"hf_config": {
"model_class" : "WhisperForConditionalGeneration",
"model_name" : "openai/whisper-tiny",
"model_name" : "<place_holder>",
"components" : [
{
"name": "encoder_decoder_init",
Expand Down Expand Up @@ -90,7 +90,7 @@
"config": {
"tool_command": "whisper",
"tool_command_args": {
"model_name": "openai/whisper-tiny",
"model_name" : "<place_holder>",
"testdata_filepath": "<place_holder>",
"use_audio_decoder" : "<place_holder>"
}
Expand Down
4 changes: 2 additions & 2 deletions olive/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -635,12 +635,12 @@ def get_component(self, component_name: str) -> "PyTorchModel":
hf_component = components_dict[component_name]

user_module_loader = UserModuleLoader(self.model_script, self.script_dir)
model_component = user_module_loader.call_object(hf_component.component_func)
model_component = user_module_loader.call_object(hf_component.component_func, self.hf_config.model_name)

io_config = hf_component.io_config
if isinstance(io_config, str):
user_module_loader = UserModuleLoader(self.model_script, self.script_dir)
io_config = user_module_loader.call_object(hf_component.io_config)
io_config = user_module_loader.call_object(hf_component.io_config, self.hf_config.model_name)
io_config = validate_config(io_config, IOConfig)

def model_loader(_):
Expand Down

0 comments on commit ca6dfb3

Please sign in to comment.