Skip to content

Commit

Permalink
Add pip requirements directly for mlflow save (#1400)
Browse files Browse the repository at this point in the history
  • Loading branch information
dakinggg authored Jul 27, 2024
1 parent 0c93331 commit 799279b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 0 deletions.
6 changes: 6 additions & 0 deletions llmfoundry/callbacks/hf_checkpointer.py
Original file line number Diff line number Diff line change
Expand Up @@ -613,6 +613,12 @@ def tensor_hook(
) if is_te_imported and state.precision == Precision.AMP_FP8 else contextlib.nullcontext(
)
with context_manager:
# Add the pip requirements directly to avoid mlflow
# attempting to run inference on the model
model_saving_kwargs['pip_requirements'] = [
'transformers',
'torch',
]
mlflow_logger.save_model(**model_saving_kwargs)

# Upload the license file generated by mlflow during the model saving.
Expand Down
2 changes: 2 additions & 0 deletions tests/a_scripts/inference/test_convert_composer_to_hf.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ def test_huggingface_conversion_callback_interval(
task='llm/v1/completions',
input_example=ANY,
metadata={},
pip_requirements=ANY,
)
assert checkpointer_callback.transform_model_pre_registration.call_count == 1
assert checkpointer_callback.pre_register_edit.call_count == 1
Expand Down Expand Up @@ -594,6 +595,7 @@ def _assert_mlflow_logger_calls(
'task': 'llm/v1/completions',
'input_example': default_input_example,
'metadata': {},
'pip_requirements': ANY,
}
mlflow_logger_mock.save_model.assert_called_with(**expectation)
assert mlflow_logger_mock.register_model_with_run_id.call_count == 1
Expand Down

0 comments on commit 799279b

Please sign in to comment.