From 4cf083542a24b322282c135df9bc91fb430aad91 Mon Sep 17 00:00:00 2001
From: shauray8
Date: Thu, 22 Jun 2023 15:18:00 +0530
Subject: [PATCH 01/19] testing

---
 src/ausio.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)
 create mode 100644 src/ausio.py

diff --git a/src/ausio.py b/src/ausio.py
new file mode 100644
index 000000000000..e1308e63288f
--- /dev/null
+++ b/src/ausio.py
@@ -0,0 +1,14 @@
+from datasets import load_dataset, Audio
+from transformers import EncodecModel, AutoProcessor
+librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
+
+model = EncodecModel.from_pretrained("facebook/encodec_24khz")
+processor = AutoProcessor.from_pretrained("facebook/encodec_24khz")
+librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate))
+audio_sample = librispeech_dummy[-1]["audio"]["array"]
+inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
+
+encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"])
+audio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs["padding_mask"])[0]
+# or the equivalent with a forward pass
+audio_values = model(inputs["input_values"], inputs["padding_mask"]).audio_values

From 7e2f313b2c3524431f3f572f9185e0a2aecd44b5 Mon Sep 17 00:00:00 2001
From: shauray8
Date: Sun, 25 Jun 2023 22:33:01 +0530
Subject: [PATCH 02/19] example script

---
 src/ausio.py | 23 +++++++++++++++++++----
 1 file changed, 19 insertions(+), 4 deletions(-)

diff --git a/src/ausio.py b/src/ausio.py
index e1308e63288f..5bc057cb0fcc 100644
--- a/src/ausio.py
+++ b/src/ausio.py
@@ -1,5 +1,6 @@
 from datasets import load_dataset, Audio
 from transformers import EncodecModel, AutoProcessor
+from transformers import Trainer
 librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
 
 model = EncodecModel.from_pretrained("facebook/encodec_24khz")
@@ -8,7 +9,21 @@
 audio_sample = librispeech_dummy[-1]["audio"]["array"]
 inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt")
 
-encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"])
-audio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs["padding_mask"])[0]
-# or the equivalent with a forward pass
-audio_values = model(inputs["input_values"], inputs["padding_mask"]).audio_values
+class CustomTrainer(Trainer):
+    def compute_loss(self, model, inputs, return_outputs=False):
+        labels = inputs.get("labels")
+        # forward pass
+        outputs = model(inputs)
+        logits = outputs.get("logits")
+        # compute custom loss (suppose one has 3 labels with different weights)
+        loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device))
+        loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1))
+        return (loss, outputs) if return_outputs else loss
+
+#encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"])
+#audio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs["padding_mask"])[0]
+## or the equivalent with a forward pass
+#audio_values = model(inputs["input_values"], inputs["padding_mask"]).audio_values
+
+a = CustomTrainer()
+a.compute_loss(model, inputs)

From 2a2b232d3d88a99601bd053b3b8b1c2e706b2705 Mon Sep 17 00:00:00 2001
From: shauray8
Date: Wed, 28 Jun 2023 18:53:06 +0530
Subject: [PATCH 03/19] fix
typehinting --- src/ausio.py | 29 ----------------------------- src/test.py | 14 ++++++++++++++ src/transformers/training_args.py | 6 +++--- 3 files changed, 17 insertions(+), 32 deletions(-) delete mode 100644 src/ausio.py create mode 100644 src/test.py diff --git a/src/ausio.py b/src/ausio.py deleted file mode 100644 index 5bc057cb0fcc..000000000000 --- a/src/ausio.py +++ /dev/null @@ -1,29 +0,0 @@ -from datasets import load_dataset, Audio -from transformers import EncodecModel, AutoProcessor -from transformers import Trainer -librispeech_dummy = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - -model = EncodecModel.from_pretrained("facebook/encodec_24khz") -processor = AutoProcessor.from_pretrained("facebook/encodec_24khz") -librispeech_dummy = librispeech_dummy.cast_column("audio", Audio(sampling_rate=processor.sampling_rate)) -audio_sample = librispeech_dummy[-1]["audio"]["array"] -inputs = processor(raw_audio=audio_sample, sampling_rate=processor.sampling_rate, return_tensors="pt") - -class CustomTrainer(Trainer): - def compute_loss(self, model, inputs, return_outputs=False): - labels = inputs.get("labels") - # forward pass - outputs = model(inputs) - logits = outputs.get("logits") - # compute custom loss (suppose one has 3 labels with different weights) - loss_fct = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 3.0], device=model.device)) - loss = loss_fct(logits.view(-1, self.model.config.num_labels), labels.view(-1)) - return (loss, outputs) if return_outputs else loss - -#encoder_outputs = model.encode(inputs["input_values"], inputs["padding_mask"]) -#audio_values = model.decode(encoder_outputs.audio_codes, encoder_outputs.audio_scales, inputs["padding_mask"])[0] -## or the equivalent with a forward pass -#audio_values = model(inputs["input_values"], inputs["padding_mask"]).audio_values - -a = CustomTrainer() -a.compute_loss(model, inputs) diff --git a/src/test.py b/src/test.py new file mode 100644 index 000000000000..97f28a417df9 --- /dev/null +++ b/src/test.py @@ -0,0 +1,14 @@ +from pydantic import BaseModel +from transformers.training_args import TrainingArguments + +class MyTrainingArguments(TrainingArguments): + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.my_arg = "my_arg" + + +class MyModel(BaseModel): + training_args: MyTrainingArguments + + +model = MyModel(training_args=MyTrainingArguments(output_dir="")) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index e8c2823f3793..59d9d0ca70fe 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -944,7 +944,7 @@ class TrainingArguments: ) }, ) - sharded_ddp: str = field( + sharded_ddp: Union[Optional[str], bool, List[ShardedDDPOption]] = field( default="", metadata={ "help": ( @@ -955,7 +955,7 @@ class TrainingArguments: ), }, ) - fsdp: str = field( + fsdp: Union[Optional[str], bool, List[FSDPOption]] = field( default="", metadata={ "help": ( @@ -976,7 +976,7 @@ class TrainingArguments: ) }, ) - fsdp_config: Optional[str] = field( + fsdp_config: Union[Optional[str], Dict] = field( default=None, metadata={ "help": ( From bb237350533cfbc8cef5bfd6561be3308e1feefc Mon Sep 17 00:00:00 2001 From: shauray8 Date: Wed, 28 Jun 2023 19:03:50 +0530 Subject: [PATCH 04/19] some tests --- src/test.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/test.py b/src/test.py index 97f28a417df9..59db8898f5a3 100644 --- a/src/test.py +++ b/src/test.py @@ -6,9 +6,7 @@ def __init__(self, 
**kwargs): super().__init__(**kwargs) self.my_arg = "my_arg" - class MyModel(BaseModel): training_args: MyTrainingArguments - -model = MyModel(training_args=MyTrainingArguments(output_dir="")) +model = MyModel(training_args=MyTrainingArguments(output_dir="./")) From 1569775a647f5b06b7739b727dcb67a45e4e113d Mon Sep 17 00:00:00 2001 From: shauray8 Date: Wed, 28 Jun 2023 19:28:55 +0530 Subject: [PATCH 05/19] make test --- src/test.py | 12 ------------ 1 file changed, 12 deletions(-) delete mode 100644 src/test.py diff --git a/src/test.py b/src/test.py deleted file mode 100644 index 59db8898f5a3..000000000000 --- a/src/test.py +++ /dev/null @@ -1,12 +0,0 @@ -from pydantic import BaseModel -from transformers.training_args import TrainingArguments - -class MyTrainingArguments(TrainingArguments): - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.my_arg = "my_arg" - -class MyModel(BaseModel): - training_args: MyTrainingArguments - -model = MyModel(training_args=MyTrainingArguments(output_dir="./")) From 266d5b0e620fdc438e893729c007e0f5d058ca57 Mon Sep 17 00:00:00 2001 From: shauray8 Date: Wed, 28 Jun 2023 21:07:23 +0530 Subject: [PATCH 06/19] optional update --- src/transformers/training_args.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 59d9d0ca70fe..e8c2cf95418e 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -944,7 +944,7 @@ class TrainingArguments: ) }, ) - sharded_ddp: Union[Optional[str], bool, List[ShardedDDPOption]] = field( + sharded_ddp: Optional[Union[str], bool, List[ShardedDDPOption]] = field( default="", metadata={ "help": ( @@ -955,7 +955,7 @@ class TrainingArguments: ), }, ) - fsdp: Union[Optional[str], bool, List[FSDPOption]] = field( + fsdp: Optional[Union[str], bool, List[FSDPOption]] = field( default="", metadata={ "help": ( From 8f6e3828a49b53e3e8fb3b82bcd0d50d5ea61933 Mon Sep 17 00:00:00 2001 From: shauray8 Date: Thu, 29 Jun 2023 01:05:14 +0530 Subject: [PATCH 07/19] Union of arguments --- src/transformers/training_args.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index e8c2cf95418e..c507561c8172 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -944,7 +944,7 @@ class TrainingArguments: ) }, ) - sharded_ddp: Optional[Union[str], bool, List[ShardedDDPOption]] = field( + sharded_ddp: Optional[Union[str, bool, List[ShardedDDPOption]]] = field( default="", metadata={ "help": ( @@ -955,7 +955,7 @@ class TrainingArguments: ), }, ) - fsdp: Optional[Union[str], bool, List[FSDPOption]] = field( + fsdp: Optional[Union[str, bool, List[FSDPOption]]] = field( default="", metadata={ "help": ( @@ -976,7 +976,7 @@ class TrainingArguments: ) }, ) - fsdp_config: Union[Optional[str], Dict] = field( + fsdp_config: Optional[Union[str, Dict]] = field( default=None, metadata={ "help": ( From 056fd3d3ab6b710153be180a99c8c95a5b675621 Mon Sep 17 00:00:00 2001 From: shauray8 Date: Thu, 29 Jun 2023 01:24:12 +0530 Subject: [PATCH 08/19] does this fix the issue --- src/reports/examples_flax/errors.txt | 0 src/reports/examples_flax/failures_line.txt | 0 src/reports/examples_flax/failures_long.txt | 0 src/reports/examples_flax/failures_short.txt | 0 src/reports/examples_flax/stats.txt | 1 + src/reports/examples_flax/summary_short.txt | 0 src/reports/examples_flax/warnings.txt | 39 ++++++++++++++++++++ 
src/transformers/training_args.py | 2 +- 8 files changed, 41 insertions(+), 1 deletion(-) create mode 100644 src/reports/examples_flax/errors.txt create mode 100644 src/reports/examples_flax/failures_line.txt create mode 100644 src/reports/examples_flax/failures_long.txt create mode 100644 src/reports/examples_flax/failures_short.txt create mode 100644 src/reports/examples_flax/stats.txt create mode 100644 src/reports/examples_flax/summary_short.txt create mode 100644 src/reports/examples_flax/warnings.txt diff --git a/src/reports/examples_flax/errors.txt b/src/reports/examples_flax/errors.txt new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/reports/examples_flax/failures_line.txt b/src/reports/examples_flax/failures_line.txt new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/reports/examples_flax/failures_long.txt b/src/reports/examples_flax/failures_long.txt new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/reports/examples_flax/failures_short.txt b/src/reports/examples_flax/failures_short.txt new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/reports/examples_flax/stats.txt b/src/reports/examples_flax/stats.txt new file mode 100644 index 000000000000..723469e1a9d3 --- /dev/null +++ b/src/reports/examples_flax/stats.txt @@ -0,0 +1 @@ +======================================================= 27 warnings in 8.68s ======================================================= diff --git a/src/reports/examples_flax/summary_short.txt b/src/reports/examples_flax/summary_short.txt new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/src/reports/examples_flax/warnings.txt b/src/reports/examples_flax/warnings.txt new file mode 100644 index 000000000000..4657c785b29c --- /dev/null +++ b/src/reports/examples_flax/warnings.txt @@ -0,0 +1,39 @@ +===================================================== warnings summary (final) ===================================================== +../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 +../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 +../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 +../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 +../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 +../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 +../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 +../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 +../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 + /usr/lib/python3/dist-packages/requests/__init__.py:87: RequestsDependencyWarning: urllib3 (2.0.3) or chardet (4.0.0) doesn't match a supported version! 
+ warnings.warn("urllib3 ({}) or chardet ({}) doesn't match a supported " + +../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 +../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 +../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 +../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 +../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 +../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 +../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 +../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 +../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 + /home/taylo/.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29: DeprecationWarning: The distutils package is deprecated and slated for removal in Python 3.12. Use setuptools or check PEP 632 for potential alternatives + from distutils.util import strtobool + +../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 +../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 +../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 +../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 +../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 +../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 +../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 +../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 +../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 + /home/taylo/.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302: PytestConfigWarning: Unknown config option: doctest_glob + + self._warn_or_fail_if_strict(f"Unknown config option: {key}\n") + +-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index c507561c8172..e62a680cef25 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -945,7 +945,7 @@ class TrainingArguments: }, ) sharded_ddp: Optional[Union[str, bool, List[ShardedDDPOption]]] = field( - default="", + default=False, metadata={ "help": ( "Whether or not to use sharded DDP training (in distributed training only). 
The base option should be" From 71049cb8852d4656629ae4eb05be664e1fc1e8a8 Mon Sep 17 00:00:00 2001 From: shauray8 Date: Thu, 29 Jun 2023 01:24:37 +0530 Subject: [PATCH 09/19] remove reports --- src/reports/examples_flax/errors.txt | 0 src/reports/examples_flax/failures_line.txt | 0 src/reports/examples_flax/failures_long.txt | 0 src/reports/examples_flax/failures_short.txt | 0 src/reports/examples_flax/stats.txt | 1 - src/reports/examples_flax/summary_short.txt | 0 src/reports/examples_flax/warnings.txt | 39 -------------------- 7 files changed, 40 deletions(-) delete mode 100644 src/reports/examples_flax/errors.txt delete mode 100644 src/reports/examples_flax/failures_line.txt delete mode 100644 src/reports/examples_flax/failures_long.txt delete mode 100644 src/reports/examples_flax/failures_short.txt delete mode 100644 src/reports/examples_flax/stats.txt delete mode 100644 src/reports/examples_flax/summary_short.txt delete mode 100644 src/reports/examples_flax/warnings.txt diff --git a/src/reports/examples_flax/errors.txt b/src/reports/examples_flax/errors.txt deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/src/reports/examples_flax/failures_line.txt b/src/reports/examples_flax/failures_line.txt deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/src/reports/examples_flax/failures_long.txt b/src/reports/examples_flax/failures_long.txt deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/src/reports/examples_flax/failures_short.txt b/src/reports/examples_flax/failures_short.txt deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/src/reports/examples_flax/stats.txt b/src/reports/examples_flax/stats.txt deleted file mode 100644 index 723469e1a9d3..000000000000 --- a/src/reports/examples_flax/stats.txt +++ /dev/null @@ -1 +0,0 @@ -======================================================= 27 warnings in 8.68s ======================================================= diff --git a/src/reports/examples_flax/summary_short.txt b/src/reports/examples_flax/summary_short.txt deleted file mode 100644 index e69de29bb2d1..000000000000 diff --git a/src/reports/examples_flax/warnings.txt b/src/reports/examples_flax/warnings.txt deleted file mode 100644 index 4657c785b29c..000000000000 --- a/src/reports/examples_flax/warnings.txt +++ /dev/null @@ -1,39 +0,0 @@ -===================================================== warnings summary (final) ===================================================== -../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 -../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 -../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 -../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 -../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 -../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 -../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 -../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 -../../../../../usr/lib/python3/dist-packages/requests/__init__.py:87 - /usr/lib/python3/dist-packages/requests/__init__.py:87: RequestsDependencyWarning: urllib3 (2.0.3) or chardet (4.0.0) doesn't match a supported version! 
- warnings.warn("urllib3 ({}) or chardet ({}) doesn't match a supported " - -../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 -../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 -../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 -../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 -../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 -../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 -../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 -../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 -../../../.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29 - /home/taylo/.local/lib/python3.10/site-packages/accelerate/utils/dataclasses.py:29: DeprecationWarning: The distutils package is deprecated and slated for removal in Python 3.12. Use setuptools or check PEP 632 for potential alternatives - from distutils.util import strtobool - -../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 -../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 -../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 -../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 -../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 -../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 -../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 -../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 -../../../.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302 - /home/taylo/.local/lib/python3.10/site-packages/_pytest/config/__init__.py:1302: PytestConfigWarning: Unknown config option: doctest_glob - - self._warn_or_fail_if_strict(f"Unknown config option: {key}\n") - --- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html From 7f9b4cbffd1a047f9ad0eee079e37b05e812b8ab Mon Sep 17 00:00:00 2001 From: shauray8 Date: Thu, 29 Jun 2023 01:30:19 +0530 Subject: [PATCH 10/19] set default to False --- src/transformers/training_args.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index e62a680cef25..902898bc534d 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -956,7 +956,7 @@ class TrainingArguments: }, ) fsdp: Optional[Union[str, bool, List[FSDPOption]]] = field( - default="", + default=False, metadata={ "help": ( "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training" From 878eb9bc096b63af6d0834a46ab018c24c05208e Mon Sep 17 00:00:00 2001 From: shauray8 Date: Thu, 29 Jun 2023 18:21:32 +0530 Subject: [PATCH 11/19] documentation change --- src/transformers/training_args.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 902898bc534d..ab2247b362e1 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -397,7 +397,7 @@ class TrainingArguments: When resuming training, whether or not to skip the epochs and batches to get the data loading at the same stage as in the previous training. 
If set to `True`, the training will begin faster (as that skipping step can take a long time) but will not yield the same results as the interrupted training would have. - sharded_ddp (`bool`, `str` or list of [`~trainer_utils.ShardedDDPOption`], *optional*, defaults to `False`): + sharded_ddp (`bool`, `str` or list of [`~trainer_utils.ShardedDDPOption`], *optional*, defaults to `""`): Use Sharded DDP training from [FairScale](https://github.com/facebookresearch/fairscale) (in distributed training only). This is an experimental feature. @@ -412,7 +412,7 @@ class TrainingArguments: If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty list for `False` and `["simple"]` for `True`. - fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `False`): + fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `""`): Use PyTorch Distributed Parallel Training (in distributed training only). A list of options along the following: @@ -945,7 +945,7 @@ class TrainingArguments: }, ) sharded_ddp: Optional[Union[str, bool, List[ShardedDDPOption]]] = field( - default=False, + default="", metadata={ "help": ( "Whether or not to use sharded DDP training (in distributed training only). The base option should be" @@ -956,7 +956,7 @@ class TrainingArguments: }, ) fsdp: Optional[Union[str, bool, List[FSDPOption]]] = field( - default=False, + default="", metadata={ "help": ( "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training" From 4443814419d269d21bb6f25909204a12cb626d7d Mon Sep 17 00:00:00 2001 From: shauray8 Date: Thu, 29 Jun 2023 18:30:22 +0530 Subject: [PATCH 12/19] None support --- src/transformers/training_args.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index ab2247b362e1..7ce1340965b9 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1457,6 +1457,8 @@ def __post_init__(self): raise ValueError("`--sharded_ddp simple` is not compatible with any other option.") elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp: raise ValueError("`--sharded_ddp zero_dp_2` is not compatible with `--sharded_ddp zero_dp_3`.") + if self.sharded_ddp is None: + self.sharded_ddp = [] if isinstance(self.fsdp, bool): self.fsdp = "full_shard" if self.fsdp else "" @@ -1469,6 +1471,8 @@ def __post_init__(self): ) elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp: raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.") + if self.fsdp is None: + self.fsdp = [] if self.fsdp_config is None: self.fsdp_config = {} From 07760beef0c1ee557cd61d91c6db5afa5f8bfc70 Mon Sep 17 00:00:00 2001 From: shauray8 Date: Thu, 29 Jun 2023 19:17:15 +0530 Subject: [PATCH 13/19] does not need None --- src/transformers/training_args.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 7ce1340965b9..ab2247b362e1 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -1457,8 +1457,6 @@ def __post_init__(self): raise ValueError("`--sharded_ddp simple` is not compatible with any other option.") elif ShardedDDPOption.ZERO_DP_2 in self.sharded_ddp and ShardedDDPOption.ZERO_DP_3 in self.sharded_ddp: raise ValueError("`--sharded_ddp zero_dp_2` is not compatible 
with `--sharded_ddp zero_dp_3`.") - if self.sharded_ddp is None: - self.sharded_ddp = [] if isinstance(self.fsdp, bool): self.fsdp = "full_shard" if self.fsdp else "" @@ -1471,8 +1469,6 @@ def __post_init__(self): ) elif FSDPOption.FULL_SHARD in self.fsdp and FSDPOption.SHARD_GRAD_OP in self.fsdp: raise ValueError("`--fsdp full_shard` is not compatible with `--fsdp shard_grad_op`.") - if self.fsdp is None: - self.fsdp = [] if self.fsdp_config is None: self.fsdp_config = {} From fbac8a75baf1852a5a21411c9106206e57d8dc48 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Wed, 28 Jun 2023 17:36:17 +0300 Subject: [PATCH 14/19] Fix typing annotations for FSDP and DeepSpeed in TrainingArguments (#24549) * Fix typing annotations for FSDP and DeepSpeed in TrainingArguments * Change dict to Dict --- src/transformers/training_args.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index ab2247b362e1..72a4221879c6 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -980,8 +980,8 @@ class TrainingArguments: default=None, metadata={ "help": ( - "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a" - "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`." + "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a" + "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`." ) }, ) @@ -994,11 +994,11 @@ class TrainingArguments: ) }, ) - deepspeed: Optional[str] = field( + deepspeed: Optional[Union[str, Dict]] = field( default=None, metadata={ "help": ( - "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json) or an already" + "Enable deepspeed and pass the path to deepspeed json config file (e.g. `ds_config.json`) or an already" " loaded json file as a dict" ) }, From ea1b7141082aca50aba718852943f93a59b10d5e Mon Sep 17 00:00:00 2001 From: Sylvain Gugger <35901082+sgugger@users.noreply.github.com> Date: Thu, 29 Jun 2023 08:14:43 -0400 Subject: [PATCH 15/19] Revert "Fix typing annotations for FSDP and DeepSpeed in TrainingArguments" (#24574) Revert "Fix typing annotations for FSDP and DeepSpeed in TrainingArguments (#24549)" This reverts commit c5e29d4381d4b9739e6cb427adbca87fbb43a3ad. --- src/transformers/training_args.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 72a4221879c6..cfc37ccc62ae 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -976,12 +976,12 @@ class TrainingArguments: ) }, ) - fsdp_config: Optional[Union[str, Dict]] = field( + fsdp_config: Optional[str] = field( default=None, metadata={ "help": ( - "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a" - "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`." + "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a" + "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`." ) }, ) @@ -994,11 +994,11 @@ class TrainingArguments: ) }, ) - deepspeed: Optional[Union[str, Dict]] = field( + deepspeed: Optional[str] = field( default=None, metadata={ "help": ( - "Enable deepspeed and pass the path to deepspeed json config file (e.g. 
`ds_config.json`) or an already" + "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json) or an already" " loaded json file as a dict" ) }, From 05cc09b7aab8e267585ce73a4894f27260cae334 Mon Sep 17 00:00:00 2001 From: Max Ryabinin Date: Wed, 28 Jun 2023 17:36:17 +0300 Subject: [PATCH 16/19] Fix typing annotations for FSDP and DeepSpeed in TrainingArguments (#24549) * Fix typing annotations for FSDP and DeepSpeed in TrainingArguments * Change dict to Dict --- src/transformers/training_args.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index cfc37ccc62ae..c003e9fdedbd 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -980,8 +980,8 @@ class TrainingArguments: default=None, metadata={ "help": ( - "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a" - "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`." + "Config to be used with FSDP (Pytorch Fully Sharded Data Parallel). The value is either a" + "fsdp json config file (e.g., `fsdp_config.json`) or an already loaded json file as `dict`." ) }, ) @@ -994,11 +994,11 @@ class TrainingArguments: ) }, ) - deepspeed: Optional[str] = field( + deepspeed: Optional[Union[str, Dict]] = field( default=None, metadata={ "help": ( - "Enable deepspeed and pass the path to deepspeed json config file (e.g. ds_config.json) or an already" + "Enable deepspeed and pass the path to deepspeed json config file (e.g. `ds_config.json`) or an already" " loaded json file as a dict" ) }, From 760f89c2ab4354bde26221b318d0f691d5d4d476 Mon Sep 17 00:00:00 2001 From: shauray8 Date: Sun, 9 Jul 2023 22:34:28 +0530 Subject: [PATCH 17/19] merge --- src/transformers/training_args.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index c003e9fdedbd..ed2f60e64fbe 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -397,7 +397,7 @@ class TrainingArguments: When resuming training, whether or not to skip the epochs and batches to get the data loading at the same stage as in the previous training. If set to `True`, the training will begin faster (as that skipping step can take a long time) but will not yield the same results as the interrupted training would have. - sharded_ddp (`bool`, `str` or list of [`~trainer_utils.ShardedDDPOption`], *optional*, defaults to `""`): + sharded_ddp (`bool`, `str` or list of [`~trainer_utils.ShardedDDPOption`], *optional*, defaults to `''`): Use Sharded DDP training from [FairScale](https://github.com/facebookresearch/fairscale) (in distributed training only). This is an experimental feature. @@ -412,7 +412,7 @@ class TrainingArguments: If a string is passed, it will be split on space. If a bool is passed, it will be converted to an empty list for `False` and `["simple"]` for `True`. - fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `""`): + fsdp (`bool`, `str` or list of [`~trainer_utils.FSDPOption`], *optional*, defaults to `''`): Use PyTorch Distributed Parallel Training (in distributed training only). 
A list of options along the following: From 20d6b84613984f2497587a62774704882ccbeee6 Mon Sep 17 00:00:00 2001 From: shauray8 Date: Mon, 10 Jul 2023 16:40:40 +0530 Subject: [PATCH 18/19] hacky fix --- src/transformers/training_args.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index ed2f60e64fbe..46a26361135b 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -944,8 +944,8 @@ class TrainingArguments: ) }, ) - sharded_ddp: Optional[Union[str, bool, List[ShardedDDPOption]]] = field( - default="", + sharded_ddp: Optional[Union[List[ShardedDDPOption], str]] = field( + default='', metadata={ "help": ( "Whether or not to use sharded DDP training (in distributed training only). The base option should be" @@ -955,8 +955,8 @@ class TrainingArguments: ), }, ) - fsdp: Optional[Union[str, bool, List[FSDPOption]]] = field( - default="", + fsdp: Optional[Union[List[FSDPOption], str]] = field( + default='', metadata={ "help": ( "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training" From 272f6545f8c47ef85f57464dad59310ecb7d7c02 Mon Sep 17 00:00:00 2001 From: shauray8 Date: Thu, 20 Jul 2023 18:41:37 +0530 Subject: [PATCH 19/19] fixup --- src/transformers/training_args.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/transformers/training_args.py b/src/transformers/training_args.py index 46a26361135b..40f321fcb294 100644 --- a/src/transformers/training_args.py +++ b/src/transformers/training_args.py @@ -945,7 +945,7 @@ class TrainingArguments: }, ) sharded_ddp: Optional[Union[List[ShardedDDPOption], str]] = field( - default='', + default="", metadata={ "help": ( "Whether or not to use sharded DDP training (in distributed training only). The base option should be" @@ -956,7 +956,7 @@ class TrainingArguments: }, ) fsdp: Optional[Union[List[FSDPOption], str]] = field( - default='', + default="", metadata={ "help": ( "Whether or not to use PyTorch Fully Sharded Data Parallel (FSDP) training (in distributed training"