diff --git a/docs/source/tutorials/huggingface_model_optimization.md b/docs/source/tutorials/huggingface_model_optimization.md
index 3435485c7..e7b4b9551 100644
--- a/docs/source/tutorials/huggingface_model_optimization.md
+++ b/docs/source/tutorials/huggingface_model_optimization.md
@@ -46,5 +46,52 @@ Please refer to [hf_config](../overview/options.md#hf_config) for more details.
 ```
 Please refer to [metrics](../overview/options.md#metrics) for more details.
 
+### Custom components config
+You can use your own custom component functions for your model. You will need to define the details of your components as functions in your script.
+```json
+{
+    "input_model": {
+        "type": "PyTorchModel",
+        "config": {
+            "model_script": "code/user_script.py",
+            "script_dir": "code",
+            "hf_config": {
+                "model_class": "WhisperForConditionalGeneration",
+                "model_name": "openai/whisper-medium",
+                "components": [
+                    {
+                        "name": "encoder_decoder_init",
+                        "io_config": "get_encdec_io_config",
+                        "component_func": "get_encoder_decoder_init",
+                        "dummy_inputs_func": "encoder_decoder_init_dummy_inputs"
+                    },
+                    {
+                        "name": "decoder",
+                        "io_config": "get_dec_io_config",
+                        "component_func": "get_decoder",
+                        "dummy_inputs_func": "decoder_dummy_inputs"
+                    }
+                ]
+            }
+        }
+    }
+}
+```
+#### Script example
+```python
+# user_script.py
+def get_dec_io_config(model_name: str):
+    # return your io dict
+    ...
+
+def get_decoder(model_name: str):
+    # your component implementation
+    ...
+
+def decoder_dummy_inputs():
+    # return the dummy input for your component
+    ...
+```
+
 ### E2E example
 For the complete example, please refer to [Bert Optimization with PTQ on CPU](https://github.com/microsoft/Olive/tree/main/examples/bert#bert-optimization-with-ptq-on-cpu).
 
diff --git a/examples/whisper/README.md b/examples/whisper/README.md
index 3c4e3bcdf..528c54c18 100644
--- a/examples/whisper/README.md
+++ b/examples/whisper/README.md
@@ -25,9 +25,13 @@ python -m pip install -r requirements.txt
 ### Prepare workflow config json
 ```
-python prepare_whisper_configs.py [--no_audio_decoder] [--multilingual]
+python prepare_whisper_configs.py [--model_name MODEL_NAME] [--no_audio_decoder] [--multilingual]
+
+# For example, using the whisper tiny English model
+python prepare_whisper_configs.py --model_name openai/whisper-tiny.en
 ```
+`--model_name MODEL_NAME` is the name or path of the whisper model. The default value is `openai/whisper-tiny.en`.
 `--no_audio_decoder` is optional. If not provided, will use audio decoder in the preprocessing ops.
 
 **Note:** If `--no_audio_decoder` is provided, you need to install `librosa` package before running the optimization steps below.
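Once the config JSONs have been generated, the optimization itself is launched with Olive's workflow runner. Below is a minimal sketch of that run step; the file name `whisper_cpu_int8.json` is an assumption following the `whisper_{device}_{precision}.json` pattern written by `prepare_whisper_configs.py`, and `olive.workflows.run` is the entry point used across the Olive examples:

```python
# run_whisper.py -- minimal sketch, assuming whisper_cpu_int8.json was
# produced by prepare_whisper_configs.py in the current directory.
from olive.workflows import run as olive_run

if __name__ == "__main__":
    # Execute the full optimization workflow described by the prepared config.
    olive_run("whisper_cpu_int8.json")
```

The same entry point can also be invoked from the shell as `python -m olive.workflows.run --config whisper_cpu_int8.json`.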
diff --git a/examples/whisper/code/user_script.py b/examples/whisper/code/user_script.py index e15b5bc7f..4c08e1a0d 100644 --- a/examples/whisper/code/user_script.py +++ b/examples/whisper/code/user_script.py @@ -9,8 +9,8 @@ from whisper_encoder_decoder_init import WhisperEncoderDecoderInit, WhisperEncoderDecoderInitInputs -def get_encoder_decoder_init(): - model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") +def get_encoder_decoder_init(model_name): + model = WhisperForConditionalGeneration.from_pretrained(model_name) return WhisperEncoderDecoderInit( model, model, @@ -19,13 +19,13 @@ def get_encoder_decoder_init(): ) -def get_decoder(): - model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny") +def get_decoder(model_name): + model = WhisperForConditionalGeneration.from_pretrained(model_name) return WhisperDecoder(model, model.config) -def get_encdec_io_config(): - model = get_encoder_decoder_init() +def get_encdec_io_config(model_name): + model = get_encoder_decoder_init(model_name) use_decoder_input_ids = True inputs = WhisperEncoderDecoderInitInputs.create_dummy( @@ -96,10 +96,10 @@ def get_encdec_io_config(): } -def get_dec_io_config(): +def get_dec_io_config(model_name): # Fix past disappearing bug - duplicate first past entry # input_list.insert(2, input_list[2]) - model = get_decoder() + model = get_decoder(model_name) past_names = PastKeyValuesHelper.get_past_names(model.config.decoder_layers, present=False) present_names = PastKeyValuesHelper.get_past_names(model.config.decoder_layers, present=True) present_self_names = present_names[: 2 * model.config.decoder_layers] diff --git a/examples/whisper/prepare_whisper_configs.py b/examples/whisper/prepare_whisper_configs.py index e53676c10..4f3622783 100644 --- a/examples/whisper/prepare_whisper_configs.py +++ b/examples/whisper/prepare_whisper_configs.py @@ -29,6 +29,7 @@ def get_args(raw_args): parser = argparse.ArgumentParser(description="Prepare config file for Whisper") + parser.add_argument("--model_name", type=str, default="openai/whisper-tiny.en", help="Model name") parser.add_argument( "--no_audio_decoder", action="store_true", @@ -54,8 +55,12 @@ def main(raw_args=None): # load template template_json = json.load(open("whisper_template.json", "r")) + model_name = args.model_name - whisper_config = WhisperConfig.from_pretrained(template_json["input_model"]["config"]["hf_config"]["model_name"]) + whisper_config = WhisperConfig.from_pretrained(model_name) + + # update model name + template_json["input_model"]["config"]["hf_config"]["model_name"] = model_name # set dataloader template_json["evaluators"]["common_evaluator"]["metrics"][0]["user_config"]["dataloader_func"] = ( @@ -69,6 +74,9 @@ def main(raw_args=None): # update multi-lingual support template_json["passes"]["insert_beam_search"]["config"]["use_forced_decoder_ids"] = args.multilingual + # set model name in prepost + template_json["passes"]["prepost"]["config"]["tool_command_args"]["model_name"] = model_name + # download audio test data test_audio_path = download_audio_test_data() template_json["passes"]["prepost"]["config"]["tool_command_args"]["testdata_filepath"] = str(test_audio_path) @@ -97,6 +105,10 @@ def main(raw_args=None): # dump config json.dump(config, open(f"whisper_{device}_{precision}.json", "w"), indent=4) + # update user script + user_script_path = Path(__file__).parent / "code" / "user_script.py" + update_user_script(user_script_path, model_name) + def download_audio_test_data(): cur_dir = 
Path(__file__).parent
@@ -113,5 +125,19 @@ def download_audio_test_data():
     return test_audio_path.relative_to(cur_dir)
 
 
+def update_user_script(file_path, model_name):
+    with open(file_path, "r") as file:
+        lines = file.readlines()
+
+    new_lines = []
+    for line in lines:
+        if "<model_name>" in line:
+            line = line.replace("<model_name>", model_name)
+        new_lines.append(line)
+
+    with open(file_path, "w") as file:
+        file.writelines(new_lines)
+
+
 if __name__ == "__main__":
     main()
diff --git a/examples/whisper/whisper_template.json b/examples/whisper/whisper_template.json
index 28d0433ad..1f625c664 100644
--- a/examples/whisper/whisper_template.json
+++ b/examples/whisper/whisper_template.json
@@ -6,7 +6,7 @@
             "script_dir": "code",
             "hf_config": {
                 "model_class" : "WhisperForConditionalGeneration",
-                "model_name" : "openai/whisper-tiny",
+                "model_name" : "<model_name>",
                 "components" : [
                     {
                         "name": "encoder_decoder_init",
@@ -90,7 +90,7 @@
             "config": {
                 "tool_command": "whisper",
                 "tool_command_args": {
-                    "model_name": "openai/whisper-tiny",
+                    "model_name" : "<model_name>",
                     "testdata_filepath": "<testdata_filepath>",
                     "use_audio_decoder" : "<use_audio_decoder>"
                 }
diff --git a/olive/model/__init__.py b/olive/model/__init__.py
index e05f6b5d1..28228edfe 100644
--- a/olive/model/__init__.py
+++ b/olive/model/__init__.py
@@ -635,12 +635,12 @@ def get_component(self, component_name: str) -> "PyTorchModel":
         hf_component = components_dict[component_name]
 
         user_module_loader = UserModuleLoader(self.model_script, self.script_dir)
-        model_component = user_module_loader.call_object(hf_component.component_func)
+        model_component = user_module_loader.call_object(hf_component.component_func, self.hf_config.model_name)
 
         io_config = hf_component.io_config
         if isinstance(io_config, str):
             user_module_loader = UserModuleLoader(self.model_script, self.script_dir)
-            io_config = user_module_loader.call_object(hf_component.io_config)
+            io_config = user_module_loader.call_object(hf_component.io_config, self.hf_config.model_name)
         io_config = validate_config(io_config, IOConfig)
 
         def model_loader(_):
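The `get_component` change in `olive/model/__init__.py` above relies on `UserModuleLoader.call_object` forwarding extra positional arguments to the user's functions. Here is a minimal sketch of that mechanism, under the assumption that `call_object` resolves a name in the loaded `model_script` module and calls it with the given arguments; this is an illustrative stand-in, not Olive's actual implementation:

```python
import importlib.util


class UserModuleLoaderSketch:
    """Illustrative stand-in for UserModuleLoader (assumed semantics)."""

    def __init__(self, user_script, script_dir=None):
        # Load the user's model_script (e.g. code/user_script.py) as a module
        # so functions named in hf_config can be resolved by name.
        # script_dir is accepted only for signature parity; unused here.
        spec = importlib.util.spec_from_file_location("user_module", user_script)
        self.module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(self.module)

    def call_object(self, obj, *args, **kwargs):
        # Accept either a function name or a callable, then forward the extra
        # arguments -- mirroring call_object(component_func, model_name) above.
        if isinstance(obj, str):
            obj = getattr(self.module, obj)
        return obj(*args, **kwargs)


# Usage sketch:
#   loader = UserModuleLoaderSketch("code/user_script.py")
#   decoder = loader.call_object("get_decoder", "openai/whisper-tiny.en")
```

Under these assumed semantics, passing `self.hf_config.model_name` as a second argument lines up with the new `model_name` parameter on `get_decoder`, `get_encoder_decoder_init`, and the io_config functions in `user_script.py`.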