From 5872f3f464064bf3f77c66017a1d5c336a32c9b4 Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Thu, 2 Nov 2023 18:12:39 -0700 Subject: [PATCH 01/22] set task if the passed metadata does not include task --- llmfoundry/callbacks/hf_checkpointer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 3050529a5a..44f2b1348d 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -74,12 +74,12 @@ def __init__( if self.mlflow_registered_model_name is not None: # Both the metadata and the task are needed in order for mlflow # and databricks optimized model serving to work - if 'metadata' not in mlflow_logging_config: - mlflow_logging_config['metadata'] = { - 'task': 'llm/v1/completions' - } - if 'task' not in mlflow_logging_config: - mlflow_logging_config['task'] = 'text-generation' + default_metadata = { + 'task': 'llm/v1/completions' + } + passed_metadata = mlflow_logging_config.get('metadata', {}) + mlflow_logging_config['metadata'] = {**default_metadata, **passed_metadata} + mlflow_logging_config.setdefault('task', 'text-generation') self.mlflow_logging_config = mlflow_logging_config self.huggingface_folder_name_fstr = os.path.join( From 3e17fd0639cf9acfe3aba4f8c780a05a1bad4fed Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Tue, 7 Nov 2023 13:09:40 -0800 Subject: [PATCH 02/22] test --- llmfoundry/callbacks/hf_checkpointer.py | 9 +++++---- tests/test_hf_conversion_script.py | 24 +++++++++++++++++++----- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 44f2b1348d..b8990f67df 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -74,11 +74,12 @@ def __init__( if self.mlflow_registered_model_name is not None: # Both the metadata and the task are needed in order for mlflow # and databricks optimized model serving to work - default_metadata = { - 'task': 'llm/v1/completions' - } + default_metadata = {'task': 'llm/v1/completions'} passed_metadata = mlflow_logging_config.get('metadata', {}) - mlflow_logging_config['metadata'] = {**default_metadata, **passed_metadata} + mlflow_logging_config['metadata'] = { + **default_metadata, + **passed_metadata + } mlflow_logging_config.setdefault('task', 'text-generation') self.mlflow_logging_config = mlflow_logging_config diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index d2c2a9e1c9..a2de0ea854 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -5,7 +5,7 @@ import os import pathlib import sys -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock, patch, ANY from composer import Trainer from composer.loggers import MLFlowLogger @@ -421,10 +421,24 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, trainer.fit() if dist.get_global_rank() == 0: - assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow - else 0) - assert mlflow_logger_mock.register_model.call_count == ( - 1 if log_to_mlflow else 0) + if log_to_mlflow: + # assert mlflow_logger_mock.save_model.call_count == 1 + mlflow_logger_mock.save_model.assert_called_once_with( + flavor='transformers', + transformers_model=ANY, + path=ANY, + task='text-generation', + metatdata={'task': 'llm/v1/completions'}) + assert mlflow_logger_mock.register_model.call_count == 1 + # mlflow_logger.save_model( + # flavor='transformers', + # transformers_model=components, + # path=local_save_path, + # **self.mlflow_logging_config, + # ) + else: + assert mlflow_logger_mock.save_model.call_count == 0 + assert mlflow_logger_mock.register_model.call_count == 0 else: assert mlflow_logger_mock.log_model.call_count == 0 assert mlflow_logger_mock.register_model.call_count == 0 From bedfd1d5066bb4ca3b739e8a78708eae36b05e25 Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Tue, 7 Nov 2023 13:18:19 -0800 Subject: [PATCH 03/22] isort --- tests/test_hf_conversion_script.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index a2de0ea854..d5ee86a406 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -5,7 +5,7 @@ import os import pathlib import sys -from unittest.mock import MagicMock, patch, ANY +from unittest.mock import ANY, MagicMock, patch from composer import Trainer from composer.loggers import MLFlowLogger From 94e3515effca72293cb977e7c0cfa9546656ffef Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Tue, 7 Nov 2023 16:07:04 -0800 Subject: [PATCH 04/22] test --- tests/test_hf_conversion_script.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index d5ee86a406..955be64ef7 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -422,20 +422,14 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, if dist.get_global_rank() == 0: if log_to_mlflow: - # assert mlflow_logger_mock.save_model.call_count == 1 - mlflow_logger_mock.save_model.assert_called_once_with( + mlflow_logger_mock.save_model.assert_called_with( flavor='transformers', transformers_model=ANY, path=ANY, task='text-generation', - metatdata={'task': 'llm/v1/completions'}) + metatdata={'task': 'llm/v1/completions'} + ) assert mlflow_logger_mock.register_model.call_count == 1 - # mlflow_logger.save_model( - # flavor='transformers', - # transformers_model=components, - # path=local_save_path, - # **self.mlflow_logging_config, - # ) else: assert mlflow_logger_mock.save_model.call_count == 0 assert mlflow_logger_mock.register_model.call_count == 0 From c1a80e62992df5a48fbe90912a1ece2d0842d993 Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Wed, 8 Nov 2023 14:15:48 -0800 Subject: [PATCH 05/22] add print --- llmfoundry/callbacks/hf_checkpointer.py | 2 ++ tests/test_hf_conversion_script.py | 14 +++++++------- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index b8990f67df..3ae692057a 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -265,6 +265,8 @@ def _save_checkpoint(self, state: State, logger: Logger): # TODO: Remove after mlflow fixes the bug that makes this necessary import mlflow mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: '' + print(f"_save_checkpoint::self.mlflow_logging_config={self.mlflow_logger}\n[{type(self.mlflow_logger)}]") + print(f"_save_checkpoint::self.mlflow_logging_config['metadata']={self.mlflow_logging_config['metadata']}\n{type(self.mlflow_logging_config['metadata'])}]") mlflow_logger.save_model( flavor='transformers', transformers_model=components, diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index 955be64ef7..fb047e0bef 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -422,13 +422,13 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, if dist.get_global_rank() == 0: if log_to_mlflow: - mlflow_logger_mock.save_model.assert_called_with( - flavor='transformers', - transformers_model=ANY, - path=ANY, - task='text-generation', - metatdata={'task': 'llm/v1/completions'} - ) + # mlflow_logger_mock.save_model.assert_called_with( + # flavor='transformers', + # transformers_model=ANY, + # path=ANY, + # task='text-generation', + # metatdata={'task': 'llm/v1/completions'} + # ) assert mlflow_logger_mock.register_model.call_count == 1 else: assert mlflow_logger_mock.save_model.call_count == 0 From d97e4afe002791275e1f079708fb6ae9e6404863 Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Wed, 8 Nov 2023 14:39:26 -0800 Subject: [PATCH 06/22] a hacky fix for UCObjectStore? --- llmfoundry/callbacks/hf_checkpointer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 3ae692057a..fe50c950be 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -98,6 +98,7 @@ def __init__( if self.upload_to_object_store: self.remote_ud = RemoteUploaderDownloader( bucket_uri=f'{self.backend}://{self.bucket_name}', + backend_kwargs={'path': self.save_dir_format_str}, num_concurrent_uploads=4) else: self.remote_ud = None From b143df3926dec89d0969faf561431b1425729fda Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Thu, 9 Nov 2023 09:11:09 -0800 Subject: [PATCH 07/22] debug --- llmfoundry/callbacks/hf_checkpointer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index fe50c950be..112bc07571 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -266,7 +266,7 @@ def _save_checkpoint(self, state: State, logger: Logger): # TODO: Remove after mlflow fixes the bug that makes this necessary import mlflow mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: '' - print(f"_save_checkpoint::self.mlflow_logging_config={self.mlflow_logger}\n[{type(self.mlflow_logger)}]") + print(f"_save_checkpoint::self.mlflow_logging_config={self.mlflow_logging_config}\n[{type(self.mlflow_logging_config)}]") print(f"_save_checkpoint::self.mlflow_logging_config['metadata']={self.mlflow_logging_config['metadata']}\n{type(self.mlflow_logging_config['metadata'])}]") mlflow_logger.save_model( flavor='transformers', From 18e6f4871c3dfd447f2fc02e7c33360361572dcf Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Thu, 9 Nov 2023 11:08:48 -0800 Subject: [PATCH 08/22] convert mlflow_logging_config from omegaconf to dict --- llmfoundry/callbacks/hf_checkpointer.py | 1 + llmfoundry/utils/builders.py | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 112bc07571..fe82f2359e 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -69,6 +69,7 @@ def __init__( # mlflow config setup self.mlflow_registered_model_name = mlflow_registered_model_name + print(f"__init__::mlflow_logging_config={mlflow_logging_config}\n[{type(mlflow_logging_config)}]") if mlflow_logging_config is None: mlflow_logging_config = {} if self.mlflow_registered_model_name is not None: diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index f027afb0ce..ed3a3336bf 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -117,6 +117,10 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: elif name == 'early_stopper': return EarlyStopper(**kwargs) elif name == 'hf_checkpointer': + mlflow_logging_config = kwargs.pop("mlflow_logging_config", None) + if isinstance(mlflow_logging_config, omegaconf.dictconfig.DictConfig): + mlflow_logging_config = om.to_object(mlflow_logging_config) + kwargs["mlflow_logging_config"] = mlflow_logging_config return HuggingFaceCheckpointer(**kwargs) else: raise ValueError(f'Not sure how to build callback: {name}') From e416b96c8b7d68dc08a4a562292c73323334b25d Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Thu, 9 Nov 2023 11:41:39 -0800 Subject: [PATCH 09/22] pre-commit --- llmfoundry/callbacks/hf_checkpointer.py | 12 +++++++++--- llmfoundry/utils/builders.py | 4 ++-- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 882f42dc1d..4db918aebe 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -69,7 +69,9 @@ def __init__( # mlflow config setup self.mlflow_registered_model_name = mlflow_registered_model_name - print(f"__init__::mlflow_logging_config={mlflow_logging_config}\n[{type(mlflow_logging_config)}]") + print( + f'__init__::mlflow_logging_config={mlflow_logging_config}\n[{type(mlflow_logging_config)}]' + ) if mlflow_logging_config is None: mlflow_logging_config = {} if self.mlflow_registered_model_name is not None: @@ -260,8 +262,12 @@ def _save_checkpoint(self, state: State, logger: Logger): # TODO: Remove after mlflow fixes the bug that makes this necessary import mlflow mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: '' - print(f"_save_checkpoint::self.mlflow_logging_config={self.mlflow_logging_config}\n[{type(self.mlflow_logging_config)}]") - print(f"_save_checkpoint::self.mlflow_logging_config['metadata']={self.mlflow_logging_config['metadata']}\n{type(self.mlflow_logging_config['metadata'])}]") + print( + f'_save_checkpoint::self.mlflow_logging_config={self.mlflow_logging_config}\n[{type(self.mlflow_logging_config)}]' + ) + print( + f"_save_checkpoint::self.mlflow_logging_config['metadata']={self.mlflow_logging_config['metadata']}\n{type(self.mlflow_logging_config['metadata'])}]" + ) mlflow_logger.save_model( flavor='transformers', transformers_model=components, diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 6b0ad0614c..8889e50b80 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -117,10 +117,10 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: elif name == 'early_stopper': return EarlyStopper(**kwargs) elif name == 'hf_checkpointer': - mlflow_logging_config = kwargs.pop("mlflow_logging_config", None) + mlflow_logging_config = kwargs.pop('mlflow_logging_config', None) if isinstance(mlflow_logging_config, omegaconf.dictconfig.DictConfig): mlflow_logging_config = om.to_object(mlflow_logging_config) - kwargs["mlflow_logging_config"] = mlflow_logging_config + kwargs['mlflow_logging_config'] = mlflow_logging_config return HuggingFaceCheckpointer(**kwargs) else: raise ValueError(f'Not sure how to build callback: {name}') From d47cf4d25ee622d55f2a6bb000f7f653526fc28a Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Thu, 9 Nov 2023 11:53:26 -0800 Subject: [PATCH 10/22] fix --- llmfoundry/utils/builders.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 8889e50b80..a4fa52139b 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -118,7 +118,7 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: return EarlyStopper(**kwargs) elif name == 'hf_checkpointer': mlflow_logging_config = kwargs.pop('mlflow_logging_config', None) - if isinstance(mlflow_logging_config, omegaconf.dictconfig.DictConfig): + if isinstance(mlflow_logging_config, DictConfig): mlflow_logging_config = om.to_object(mlflow_logging_config) kwargs['mlflow_logging_config'] = mlflow_logging_config return HuggingFaceCheckpointer(**kwargs) From e8274a11b5c7296ed83ce98ea985a80a7ebacc92 Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Thu, 9 Nov 2023 13:24:28 -0800 Subject: [PATCH 11/22] more prints --- llmfoundry/utils/builders.py | 3 +++ tests/test_hf_conversion_script.py | 14 +++++++------- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index a4fa52139b..b70313a2dc 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -118,9 +118,12 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: return EarlyStopper(**kwargs) elif name == 'hf_checkpointer': mlflow_logging_config = kwargs.pop('mlflow_logging_config', None) + print(f"build_callback::mlflow_logging_config={mlflow_logging_config}") + print(f"build_callback::isinstance(mlflow_logging_config, DictConfig)={isinstance(mlflow_logging_config, DictConfig)}") if isinstance(mlflow_logging_config, DictConfig): mlflow_logging_config = om.to_object(mlflow_logging_config) kwargs['mlflow_logging_config'] = mlflow_logging_config + print(f"build_callback::kwargs['mlflow_logging_config']={kwargs['mlflow_logging_config']}") return HuggingFaceCheckpointer(**kwargs) else: raise ValueError(f'Not sure how to build callback: {name}') diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index fb047e0bef..955be64ef7 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -422,13 +422,13 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, if dist.get_global_rank() == 0: if log_to_mlflow: - # mlflow_logger_mock.save_model.assert_called_with( - # flavor='transformers', - # transformers_model=ANY, - # path=ANY, - # task='text-generation', - # metatdata={'task': 'llm/v1/completions'} - # ) + mlflow_logger_mock.save_model.assert_called_with( + flavor='transformers', + transformers_model=ANY, + path=ANY, + task='text-generation', + metatdata={'task': 'llm/v1/completions'} + ) assert mlflow_logger_mock.register_model.call_count == 1 else: assert mlflow_logger_mock.save_model.call_count == 0 From fb0cd3eb8b635a209871440572ea299f2cb34fe9 Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Thu, 9 Nov 2023 15:03:31 -0800 Subject: [PATCH 12/22] more... debug --- tests/test_builders.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_builders.py b/tests/test_builders.py index 0d24d2154f..d5286d521e 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -7,6 +7,8 @@ import pytest from composer.callbacks import Generate from transformers import PreTrainedTokenizerBase +from omegaconf import DictConfig +from omegaconf import OmegaConf as om from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper from llmfoundry.utils.builders import build_callback, build_tokenizer @@ -78,3 +80,13 @@ def test_build_generate_callback_unspecified_interval(): 'foo': 'bar', 'something': 'else', }) + +def test_build_hf_checkpointer_callback(): + hfc = build_callback( + 'hf_checkpointer', + mlflow_logging_config=om.create({"metadata": {'task': 'llm/v1/completions'}}) + ) + print(hfs) + + + From 5be055d59f37a9bf24c29038268b81a7dae5c3d7 Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Thu, 9 Nov 2023 16:06:40 -0800 Subject: [PATCH 13/22] ??? --- llmfoundry/utils/builders.py | 4 ++++ tests/test_builders.py | 25 ++++++++++++++++++------- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index b70313a2dc..e4f310a989 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -121,9 +121,13 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: print(f"build_callback::mlflow_logging_config={mlflow_logging_config}") print(f"build_callback::isinstance(mlflow_logging_config, DictConfig)={isinstance(mlflow_logging_config, DictConfig)}") if isinstance(mlflow_logging_config, DictConfig): + print("converting mlflow_logging_config") mlflow_logging_config = om.to_object(mlflow_logging_config) + print(f"build_callback::mlflow_logging_config={mlflow_logging_config}") + print(f"[{type(mlflow_logging_config)}]") kwargs['mlflow_logging_config'] = mlflow_logging_config print(f"build_callback::kwargs['mlflow_logging_config']={kwargs['mlflow_logging_config']}") + print(f"[{type(kwargs['mlflow_logging_config'])}") return HuggingFaceCheckpointer(**kwargs) else: raise ValueError(f'Not sure how to build callback: {name}') diff --git a/tests/test_builders.py b/tests/test_builders.py index d5286d521e..8e6f0d000c 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -12,6 +12,7 @@ from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper from llmfoundry.utils.builders import build_callback, build_tokenizer +from llmfoundry.callbacks import HuggingFaceCheckpointer @pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [ @@ -82,11 +83,21 @@ def test_build_generate_callback_unspecified_interval(): }) def test_build_hf_checkpointer_callback(): - hfc = build_callback( - 'hf_checkpointer', - mlflow_logging_config=om.create({"metadata": {'task': 'llm/v1/completions'}}) - ) - print(hfs) - - + with mock.patch.object(HuggingFaceCheckpointer, '__init__') as mock_hf_checkpointer: + mock_hf_checkpointer.return_value = None + save_folder = "path_to_save_folder" + save_interval = 1 + mlflow_logging_config_dict = {"metadata": {'task': 'llm/v1/completions'}} + build_callback( + name='hf_checkpointer', + kwargs={ + "save_folder": save_folder, + "save_interval": save_interval, + "mlflow_logging_config": om.create(mlflow_logging_config_dict) + }) + assert mock_hf_checkpointer.call_count == 1 + _, _, kwargs = mock_hf_checkpointer.mock_calls[0] + assert kwargs['save_folder'] == save_folder + assert kwargs['save_interval'] == save_interval + assert kwargs['mlflow_logging_config'] == mlflow_logging_config_dict From 7f52da9ee965a370289b3411c7f78e0e8fd8d730 Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Thu, 9 Nov 2023 16:48:04 -0800 Subject: [PATCH 14/22] try copy --- llmfoundry/utils/builders.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index e4f310a989..abeef4a1eb 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -5,6 +5,7 @@ import os import warnings from typing import Any, Dict, List, Optional, Tuple, Union +from copy import deepcopy import torch from composer import algorithms @@ -117,7 +118,8 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: elif name == 'early_stopper': return EarlyStopper(**kwargs) elif name == 'hf_checkpointer': - mlflow_logging_config = kwargs.pop('mlflow_logging_config', None) + kwargs_copy = deepcopy(kwargs) + mlflow_logging_config = kwargs_copy.pop('mlflow_logging_config', None) print(f"build_callback::mlflow_logging_config={mlflow_logging_config}") print(f"build_callback::isinstance(mlflow_logging_config, DictConfig)={isinstance(mlflow_logging_config, DictConfig)}") if isinstance(mlflow_logging_config, DictConfig): @@ -125,10 +127,13 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: mlflow_logging_config = om.to_object(mlflow_logging_config) print(f"build_callback::mlflow_logging_config={mlflow_logging_config}") print(f"[{type(mlflow_logging_config)}]") - kwargs['mlflow_logging_config'] = mlflow_logging_config - print(f"build_callback::kwargs['mlflow_logging_config']={kwargs['mlflow_logging_config']}") - print(f"[{type(kwargs['mlflow_logging_config'])}") - return HuggingFaceCheckpointer(**kwargs) + + + + kwargs_copy['mlflow_logging_config'] = mlflow_logging_config + print(f"build_callback::kwargs['mlflow_logging_config']={kwargs_copy['mlflow_logging_config']}") + print(f"[{type(kwargs_copy['mlflow_logging_config'])}") + return HuggingFaceCheckpointer(**kwargs_copy) else: raise ValueError(f'Not sure how to build callback: {name}') From b042ea7ce99ef71347cfcc03ada108c3599e4a16 Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Thu, 9 Nov 2023 18:23:11 -0800 Subject: [PATCH 15/22] more prints???? --- llmfoundry/utils/builders.py | 10 ++++++---- tests/test_builders.py | 4 +++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index abeef4a1eb..357c90252d 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -118,6 +118,7 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: elif name == 'early_stopper': return EarlyStopper(**kwargs) elif name == 'hf_checkpointer': + print(type(kwargs)) kwargs_copy = deepcopy(kwargs) mlflow_logging_config = kwargs_copy.pop('mlflow_logging_config', None) print(f"build_callback::mlflow_logging_config={mlflow_logging_config}") @@ -127,11 +128,12 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: mlflow_logging_config = om.to_object(mlflow_logging_config) print(f"build_callback::mlflow_logging_config={mlflow_logging_config}") print(f"[{type(mlflow_logging_config)}]") - - - + print(f"after if statement: build_callback::mlflow_logging_config={mlflow_logging_config}") + print(f"[{type(mlflow_logging_config)}]") + print(f"before reassign: kwargs_copy.get('mlflow_logging_config', None)={kwargs_copy.get('mlflow_logging_config', None)}") + print(f"[{type(kwargs_copy.get('mlflow_logging_config', None))}]") kwargs_copy['mlflow_logging_config'] = mlflow_logging_config - print(f"build_callback::kwargs['mlflow_logging_config']={kwargs_copy['mlflow_logging_config']}") + print(f"after reassign - build_callback::kwargs_copy['mlflow_logging_config']={kwargs_copy['mlflow_logging_config']}") print(f"[{type(kwargs_copy['mlflow_logging_config'])}") return HuggingFaceCheckpointer(**kwargs_copy) else: diff --git a/tests/test_builders.py b/tests/test_builders.py index 8e6f0d000c..eb6e271eeb 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -87,7 +87,7 @@ def test_build_hf_checkpointer_callback(): mock_hf_checkpointer.return_value = None save_folder = "path_to_save_folder" save_interval = 1 - mlflow_logging_config_dict = {"metadata": {'task': 'llm/v1/completions'}} + mlflow_logging_config_dict = {'metadata': {'databricks_model_family': 'MptForCausalLM', 'databricks_model_size_parameters': '7b', 'databricks_model_source': 'mosaic-fine-tuning', 'task': 'llm/v1/completions'}} build_callback( name='hf_checkpointer', kwargs={ @@ -100,4 +100,6 @@ def test_build_hf_checkpointer_callback(): _, _, kwargs = mock_hf_checkpointer.mock_calls[0] assert kwargs['save_folder'] == save_folder assert kwargs['save_interval'] == save_interval + assert isinstance(kwargs['mlflow_logging_config'], dict) + assert isinstance(kwargs['mlflow_logging_config']['metadata'], dict) assert kwargs['mlflow_logging_config'] == mlflow_logging_config_dict From fff9d07ce738149fa0159ea3ca9b3ec945dedb91 Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Thu, 9 Nov 2023 19:08:52 -0800 Subject: [PATCH 16/22] convert kwargs directly --- llmfoundry/utils/builders.py | 38 +++++++++++++++++++----------------- tests/test_builders.py | 6 +++--- 2 files changed, 23 insertions(+), 21 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 357c90252d..2feddb92b7 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -118,24 +118,26 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: elif name == 'early_stopper': return EarlyStopper(**kwargs) elif name == 'hf_checkpointer': - print(type(kwargs)) - kwargs_copy = deepcopy(kwargs) - mlflow_logging_config = kwargs_copy.pop('mlflow_logging_config', None) - print(f"build_callback::mlflow_logging_config={mlflow_logging_config}") - print(f"build_callback::isinstance(mlflow_logging_config, DictConfig)={isinstance(mlflow_logging_config, DictConfig)}") - if isinstance(mlflow_logging_config, DictConfig): - print("converting mlflow_logging_config") - mlflow_logging_config = om.to_object(mlflow_logging_config) - print(f"build_callback::mlflow_logging_config={mlflow_logging_config}") - print(f"[{type(mlflow_logging_config)}]") - print(f"after if statement: build_callback::mlflow_logging_config={mlflow_logging_config}") - print(f"[{type(mlflow_logging_config)}]") - print(f"before reassign: kwargs_copy.get('mlflow_logging_config', None)={kwargs_copy.get('mlflow_logging_config', None)}") - print(f"[{type(kwargs_copy.get('mlflow_logging_config', None))}]") - kwargs_copy['mlflow_logging_config'] = mlflow_logging_config - print(f"after reassign - build_callback::kwargs_copy['mlflow_logging_config']={kwargs_copy['mlflow_logging_config']}") - print(f"[{type(kwargs_copy['mlflow_logging_config'])}") - return HuggingFaceCheckpointer(**kwargs_copy) + if isinstance(kwargs, DictConfig): + kwargs = om.to_object(kwargs) + # print(type(kwargs)) + # kwargs_copy = deepcopy(kwargs) + # mlflow_logging_config = kwargs_copy.pop('mlflow_logging_config', None) + # print(f"build_callback::mlflow_logging_config={mlflow_logging_config}") + # print(f"build_callback::isinstance(mlflow_logging_config, DictConfig)={isinstance(mlflow_logging_config, DictConfig)}") + # if isinstance(mlflow_logging_config, DictConfig): + # print("converting mlflow_logging_config") + # mlflow_logging_config = om.to_object(mlflow_logging_config) + # print(f"build_callback::mlflow_logging_config={mlflow_logging_config}") + # print(f"[{type(mlflow_logging_config)}]") + # print(f"after if statement: build_callback::mlflow_logging_config={mlflow_logging_config}") + # print(f"[{type(mlflow_logging_config)}]") + # print(f"before reassign: kwargs_copy.get('mlflow_logging_config', None)={kwargs_copy.get('mlflow_logging_config', None)}") + # print(f"[{type(kwargs_copy.get('mlflow_logging_config', None))}]") + # kwargs_copy['mlflow_logging_config'] = mlflow_logging_config + # print(f"after reassign - build_callback::kwargs_copy['mlflow_logging_config']={kwargs_copy['mlflow_logging_config']}") + # print(f"[{type(kwargs_copy['mlflow_logging_config'])}") + return HuggingFaceCheckpointer(**kwargs) else: raise ValueError(f'Not sure how to build callback: {name}') diff --git a/tests/test_builders.py b/tests/test_builders.py index eb6e271eeb..487eb0ffe1 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -90,11 +90,11 @@ def test_build_hf_checkpointer_callback(): mlflow_logging_config_dict = {'metadata': {'databricks_model_family': 'MptForCausalLM', 'databricks_model_size_parameters': '7b', 'databricks_model_source': 'mosaic-fine-tuning', 'task': 'llm/v1/completions'}} build_callback( name='hf_checkpointer', - kwargs={ + kwargs=om.create({ "save_folder": save_folder, "save_interval": save_interval, - "mlflow_logging_config": om.create(mlflow_logging_config_dict) - }) + "mlflow_logging_config": mlflow_logging_config_dict + })) assert mock_hf_checkpointer.call_count == 1 _, _, kwargs = mock_hf_checkpointer.mock_calls[0] From a0626e270fc0cdf1e34a8d473345c105708030af Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Thu, 9 Nov 2023 21:50:53 -0800 Subject: [PATCH 17/22] clean up --- llmfoundry/callbacks/hf_checkpointer.py | 9 -------- llmfoundry/utils/builders.py | 17 --------------- tests/test_hf_conversion_script.py | 28 ++++++++++++++----------- 3 files changed, 16 insertions(+), 38 deletions(-) diff --git a/llmfoundry/callbacks/hf_checkpointer.py b/llmfoundry/callbacks/hf_checkpointer.py index 4db918aebe..fb0ccabaf5 100644 --- a/llmfoundry/callbacks/hf_checkpointer.py +++ b/llmfoundry/callbacks/hf_checkpointer.py @@ -69,9 +69,6 @@ def __init__( # mlflow config setup self.mlflow_registered_model_name = mlflow_registered_model_name - print( - f'__init__::mlflow_logging_config={mlflow_logging_config}\n[{type(mlflow_logging_config)}]' - ) if mlflow_logging_config is None: mlflow_logging_config = {} if self.mlflow_registered_model_name is not None: @@ -262,12 +259,6 @@ def _save_checkpoint(self, state: State, logger: Logger): # TODO: Remove after mlflow fixes the bug that makes this necessary import mlflow mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: '' - print( - f'_save_checkpoint::self.mlflow_logging_config={self.mlflow_logging_config}\n[{type(self.mlflow_logging_config)}]' - ) - print( - f"_save_checkpoint::self.mlflow_logging_config['metadata']={self.mlflow_logging_config['metadata']}\n{type(self.mlflow_logging_config['metadata'])}]" - ) mlflow_logger.save_model( flavor='transformers', transformers_model=components, diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 2feddb92b7..ea693a4105 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -120,23 +120,6 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: elif name == 'hf_checkpointer': if isinstance(kwargs, DictConfig): kwargs = om.to_object(kwargs) - # print(type(kwargs)) - # kwargs_copy = deepcopy(kwargs) - # mlflow_logging_config = kwargs_copy.pop('mlflow_logging_config', None) - # print(f"build_callback::mlflow_logging_config={mlflow_logging_config}") - # print(f"build_callback::isinstance(mlflow_logging_config, DictConfig)={isinstance(mlflow_logging_config, DictConfig)}") - # if isinstance(mlflow_logging_config, DictConfig): - # print("converting mlflow_logging_config") - # mlflow_logging_config = om.to_object(mlflow_logging_config) - # print(f"build_callback::mlflow_logging_config={mlflow_logging_config}") - # print(f"[{type(mlflow_logging_config)}]") - # print(f"after if statement: build_callback::mlflow_logging_config={mlflow_logging_config}") - # print(f"[{type(mlflow_logging_config)}]") - # print(f"before reassign: kwargs_copy.get('mlflow_logging_config', None)={kwargs_copy.get('mlflow_logging_config', None)}") - # print(f"[{type(kwargs_copy.get('mlflow_logging_config', None))}]") - # kwargs_copy['mlflow_logging_config'] = mlflow_logging_config - # print(f"after reassign - build_callback::kwargs_copy['mlflow_logging_config']={kwargs_copy['mlflow_logging_config']}") - # print(f"[{type(kwargs_copy['mlflow_logging_config'])}") return HuggingFaceCheckpointer(**kwargs) else: raise ValueError(f'Not sure how to build callback: {name}') diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index 955be64ef7..47f8408dce 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -421,18 +421,22 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, trainer.fit() if dist.get_global_rank() == 0: - if log_to_mlflow: - mlflow_logger_mock.save_model.assert_called_with( - flavor='transformers', - transformers_model=ANY, - path=ANY, - task='text-generation', - metatdata={'task': 'llm/v1/completions'} - ) - assert mlflow_logger_mock.register_model.call_count == 1 - else: - assert mlflow_logger_mock.save_model.call_count == 0 - assert mlflow_logger_mock.register_model.call_count == 0 + assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow + else 0) + assert mlflow_logger_mock.register_model.call_count == ( + 1 if log_to_mlflow else 0) + # if log_to_mlflow: + # # mlflow_logger_mock.save_model.assert_called_with( + # # flavor='transformers', + # # transformers_model=ANY, + # # path=ANY, + # # task='text-generation', + # # metatdata={'task': 'llm/v1/completions'} + # # ) + # assert mlflow_logger_mock.register_model.call_count == 1 + # else: + # assert mlflow_logger_mock.save_model.call_count == 0 + # assert mlflow_logger_mock.register_model.call_count == 0 else: assert mlflow_logger_mock.log_model.call_count == 0 assert mlflow_logger_mock.register_model.call_count == 0 From f3ee2ae5ba875d1d59668e8104049be012b4cfd3 Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Thu, 9 Nov 2023 22:02:31 -0800 Subject: [PATCH 18/22] pre-commit --- llmfoundry/utils/builders.py | 1 - tests/test_builders.py | 32 ++++++++++++++++++++------------ 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index ea693a4105..96be6ad45d 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -5,7 +5,6 @@ import os import warnings from typing import Any, Dict, List, Optional, Tuple, Union -from copy import deepcopy import torch from composer import algorithms diff --git a/tests/test_builders.py b/tests/test_builders.py index 487eb0ffe1..a8b484bb24 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -6,13 +6,13 @@ import pytest from composer.callbacks import Generate -from transformers import PreTrainedTokenizerBase from omegaconf import DictConfig from omegaconf import OmegaConf as om +from transformers import PreTrainedTokenizerBase +from llmfoundry.callbacks import HuggingFaceCheckpointer from llmfoundry.tokenizers.tiktoken import TiktokenTokenizerWrapper from llmfoundry.utils.builders import build_callback, build_tokenizer -from llmfoundry.callbacks import HuggingFaceCheckpointer @pytest.mark.parametrize('tokenizer_name,tokenizer_kwargs', [ @@ -82,19 +82,27 @@ def test_build_generate_callback_unspecified_interval(): 'something': 'else', }) + def test_build_hf_checkpointer_callback(): - with mock.patch.object(HuggingFaceCheckpointer, '__init__') as mock_hf_checkpointer: + with mock.patch.object(HuggingFaceCheckpointer, + '__init__') as mock_hf_checkpointer: mock_hf_checkpointer.return_value = None - save_folder = "path_to_save_folder" + save_folder = 'path_to_save_folder' save_interval = 1 - mlflow_logging_config_dict = {'metadata': {'databricks_model_family': 'MptForCausalLM', 'databricks_model_size_parameters': '7b', 'databricks_model_source': 'mosaic-fine-tuning', 'task': 'llm/v1/completions'}} - build_callback( - name='hf_checkpointer', - kwargs=om.create({ - "save_folder": save_folder, - "save_interval": save_interval, - "mlflow_logging_config": mlflow_logging_config_dict - })) + mlflow_logging_config_dict = { + 'metadata': { + 'databricks_model_family': 'MptForCausalLM', + 'databricks_model_size_parameters': '7b', + 'databricks_model_source': 'mosaic-fine-tuning', + 'task': 'llm/v1/completions' + } + } + build_callback(name='hf_checkpointer', + kwargs=om.create({ + 'save_folder': save_folder, + 'save_interval': save_interval, + 'mlflow_logging_config': mlflow_logging_config_dict + })) assert mock_hf_checkpointer.call_count == 1 _, _, kwargs = mock_hf_checkpointer.mock_calls[0] From 5e0d31bb7c470f47aa654a176dc9b7b6024c125f Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Thu, 9 Nov 2023 22:34:31 -0800 Subject: [PATCH 19/22] pyright --- llmfoundry/utils/builders.py | 5 +++-- tests/test_builders.py | 1 - tests/test_hf_conversion_script.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index 96be6ad45d..349dd9c017 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -73,7 +73,8 @@ def build_icl_data_and_gauntlet( return icl_evaluators, logger_keys, eval_gauntlet_cb -def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: +def build_callback(name: str, kwargs: Union[DictConfig, Dict[str, + Any]]) -> Callback: if name == 'lr_monitor': return LRMonitor() elif name == 'memory_monitor': @@ -118,7 +119,7 @@ def build_callback(name: str, kwargs: Dict[str, Any]) -> Callback: return EarlyStopper(**kwargs) elif name == 'hf_checkpointer': if isinstance(kwargs, DictConfig): - kwargs = om.to_object(kwargs) + kwargs = om.to_object(kwargs) # pyright: ignore return HuggingFaceCheckpointer(**kwargs) else: raise ValueError(f'Not sure how to build callback: {name}') diff --git a/tests/test_builders.py b/tests/test_builders.py index a8b484bb24..237e27b52b 100644 --- a/tests/test_builders.py +++ b/tests/test_builders.py @@ -6,7 +6,6 @@ import pytest from composer.callbacks import Generate -from omegaconf import DictConfig from omegaconf import OmegaConf as om from transformers import PreTrainedTokenizerBase diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index 47f8408dce..73a027704c 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -5,7 +5,7 @@ import os import pathlib import sys -from unittest.mock import ANY, MagicMock, patch +from unittest.mock import MagicMock, patch from composer import Trainer from composer.loggers import MLFlowLogger From 3adaf11555b42a64026ed0bfae0b89cebca3e210 Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Tue, 14 Nov 2023 13:19:20 -0800 Subject: [PATCH 20/22] clean --- tests/test_hf_conversion_script.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index 73a027704c..d2c2a9e1c9 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -425,18 +425,6 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, else 0) assert mlflow_logger_mock.register_model.call_count == ( 1 if log_to_mlflow else 0) - # if log_to_mlflow: - # # mlflow_logger_mock.save_model.assert_called_with( - # # flavor='transformers', - # # transformers_model=ANY, - # # path=ANY, - # # task='text-generation', - # # metatdata={'task': 'llm/v1/completions'} - # # ) - # assert mlflow_logger_mock.register_model.call_count == 1 - # else: - # assert mlflow_logger_mock.save_model.call_count == 0 - # assert mlflow_logger_mock.register_model.call_count == 0 else: assert mlflow_logger_mock.log_model.call_count == 0 assert mlflow_logger_mock.register_model.call_count == 0 From 157e059e4150e5dae7cd15e0067b77fded64bea2 Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Tue, 14 Nov 2023 13:33:34 -0800 Subject: [PATCH 21/22] tests --- tests/test_hf_conversion_script.py | 21 +++++++++++++++++---- 1 file changed, 17 insertions(+), 4 deletions(-) diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index d2c2a9e1c9..7ba2559de2 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -421,10 +421,23 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, trainer.fit() if dist.get_global_rank() == 0: - assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow - else 0) - assert mlflow_logger_mock.register_model.call_count == ( - 1 if log_to_mlflow else 0) + # assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow + # else 0) + # assert mlflow_logger_mock.register_model.call_count == ( + # 1 if log_to_mlflow else 0) + if log_to_mlflow: + assert mlflow_logger_mock.save_model.call_count == 1 + mlflow_logger_mock.save_model.assert_called_with( + flavor='transformers', + transformers_model=ANY, + path=ANY, + task='text-generation', + metatdata={'task': 'llm/v1/completions'} + ) + assert mlflow_logger_mock.register_model.call_count == 1 + else: + assert mlflow_logger_mock.save_model.call_count == 0 + assert mlflow_logger_mock.register_model.call_count == 0 else: assert mlflow_logger_mock.log_model.call_count == 0 assert mlflow_logger_mock.register_model.call_count == 0 From e801bbdb6e74784566f849acc509467720d584e7 Mon Sep 17 00:00:00 2001 From: wenfeiy-db Date: Tue, 14 Nov 2023 17:24:55 -0800 Subject: [PATCH 22/22] fix tests --- tests/test_hf_conversion_script.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/tests/test_hf_conversion_script.py b/tests/test_hf_conversion_script.py index 7ba2559de2..07b951a382 100644 --- a/tests/test_hf_conversion_script.py +++ b/tests/test_hf_conversion_script.py @@ -5,7 +5,7 @@ import os import pathlib import sys -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch from composer import Trainer from composer.loggers import MLFlowLogger @@ -242,9 +242,22 @@ def get_config( return cast(DictConfig, test_cfg) -def test_callback_inits_with_defaults(): +def test_callback_inits(): + # test with defaults _ = HuggingFaceCheckpointer(save_folder='test', save_interval='1ba') + # test default metatdata when mlflow registered name is given + hf_checkpointer = HuggingFaceCheckpointer( + save_folder='test', + save_interval='1ba', + mlflow_registered_model_name='test_model_name') + assert hf_checkpointer.mlflow_logging_config == { + 'task': 'text-generation', + 'metadata': { + 'task': 'llm/v1/completions' + } + } + @pytest.mark.world_size(2) @pytest.mark.gpu @@ -421,10 +434,6 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, trainer.fit() if dist.get_global_rank() == 0: - # assert mlflow_logger_mock.save_model.call_count == (1 if log_to_mlflow - # else 0) - # assert mlflow_logger_mock.register_model.call_count == ( - # 1 if log_to_mlflow else 0) if log_to_mlflow: assert mlflow_logger_mock.save_model.call_count == 1 mlflow_logger_mock.save_model.assert_called_with( @@ -432,8 +441,7 @@ def test_huggingface_conversion_callback(model: str, tmp_path: pathlib.Path, transformers_model=ANY, path=ANY, task='text-generation', - metatdata={'task': 'llm/v1/completions'} - ) + metadata={'task': 'llm/v1/completions'}) assert mlflow_logger_mock.register_model.call_count == 1 else: assert mlflow_logger_mock.save_model.call_count == 0