diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py
index 1797c3b5b4..a186f67f14 100644
--- a/llmfoundry/callbacks/hf_checkpointer.py
+++ b/llmfoundry/callbacks/hf_checkpointer.py
@@ -613,6 +613,12 @@ def tensor_hook(
                 ) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
                 )
                 with context_manager:
+                    # Add the pip requirements directly to avoid mlflow
+                    # attempting to run inference on the model
+                    model_saving_kwargs['pip_requirements'] = [
+                        'transformers',
+                        'torch',
+                    ]
                     mlflow_logger.save_model(**model_saving_kwargs)
 
                 # Upload the license file generated by mlflow during the model saving.
diff --git a/tests/a_scripts/inference/test_convert_composer_to_hf.py b/tests/a_scripts/inference/test_convert_composer_to_hf.py
index 9eb214e83d..cd47b2df7c 100644
--- a/tests/a_scripts/inference/test_convert_composer_to_hf.py
+++ b/tests/a_scripts/inference/test_convert_composer_to_hf.py
@@ -414,6 +414,7 @@ def test_huggingface_conversion_callback_interval(
         task='llm/v1/completions',
         input_example=ANY,
         metadata={},
+        pip_requirements=ANY,
     )
     assert checkpointer_callback.transform_model_pre_registration.call_count == 1
     assert checkpointer_callback.pre_register_edit.call_count == 1
@@ -594,6 +595,7 @@ def _assert_mlflow_logger_calls(
         'task': 'llm/v1/completions',
         'input_example': default_input_example,
         'metadata': {},
+        'pip_requirements': ANY,
     }
     mlflow_logger_mock.save_model.assert_called_with(**expectation)
     assert mlflow_logger_mock.register_model_with_run_id.call_count == 1
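
Note (illustration, not part of the patch): mlflow's save_model APIs accept a pip_requirements argument, and when it is supplied mlflow records those dependencies verbatim instead of inferring them by loading the model and running a sample prediction, which is the behavior the callback change above avoids. A minimal standalone sketch of the same idea, assuming the transformers flavor; the tiny model name and output path are placeholders, not from the patch:

    import mlflow
    import transformers

    # A small text-generation pipeline stands in for the converted checkpoint.
    pipe = transformers.pipeline('text-generation', model='sshleifer/tiny-gpt2')

    # With pip_requirements set explicitly, mlflow writes these dependencies
    # as-is and skips its default requirement inference, which would otherwise
    # run the model once to work out what it needs.
    mlflow.transformers.save_model(
        transformers_model=pipe,
        path='converted_model',
        task='llm/v1/completions',
        pip_requirements=['transformers', 'torch'],
    )

In the callback itself this is the same mechanism: the keys added to model_saving_kwargs are forwarded to save_model, so the explicit list short-circuits the inference step.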