From 39c3fbdc2c75996beb4c28bca17cc1cc1b933203 Mon Sep 17 00:00:00 2001
From: dumpmemory <64742282+dumpmemory@users.noreply.github.com>
Date: Wed, 30 Aug 2023 22:18:52 +0800
Subject: [PATCH 1/9] fix loss inconsistent after resume #25340

---
 src/transformers/trainer.py | 29 ++++++++++++++++++++++++++---
 1 file changed, 26 insertions(+), 3 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 14487f128f58..486679201cb8 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1782,9 +1782,32 @@ def _inner_training_loop(
         # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
         if not args.ignore_data_skip:
             for epoch in range(epochs_trained):
-                for _ in train_dataloader:
-                    break
-
+                is_random_sampler = (hasattr(train_dataloader, "sampler") and isinstance(
+                    train_dataloader.sampler, RandomSampler
+                )) or (hasattr(train_dataloader, "batch_sampler") and isinstance(
+                    train_dataloader.batch_sampler.sampler, RandomSampler
+                )) or (
+                    hasattr(train_dataloader, "batch_sampler") and \
+                    hasattr(train_dataloader.batch_sampler, "batch_sampler") \
+                    and isinstance(train_dataloader.batch_sampler.batch_sampler.sampler, RandomSampler
+                ))
+                if not is_random_sampler:
+                    # We just need to begin an iteration to create the randomization of the sampler.
+                    for _ in train_dataloader:
+                        break
+                else:
+                    # Otherwise we need to call the whooooole sampler cause there is some random operation added
+                    # AT THE VERY END!
+                    sampler = []
+                    if hasattr(train_dataloader, "sampler") and isinstance(train_dataloader.sampler, RandomSampler):
+                        sampler = train_dataloader.sampler
+                    elif hasattr(train_dataloader, "batch_sampler") and isinstance(train_dataloader.batch_sampler.sampler, RandomSampler):
+                        sampler = train_dataloader.batch_sampler.sampler
+                    elif hasattr(train_dataloader, "batch_sampler") and hasattr(train_dataloader.batch_sampler, "batch_sampler") \
+                        and isinstance(train_dataloader.batch_sampler.batch_sampler.sampler, RandomSampler):
+                        sampler = train_dataloader.batch_sampler.batch_sampler.sampler
+                    _ = list(sampler)
+
         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_iterator = train_dataloader
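A minimal sketch, outside the patch series, of the behaviour this first commit targets: with a seeded RandomSampler every full pass draws from the sampler's generator (including, on the PyTorch versions the code comment alludes to, one extra draw at the very end of iteration), so replaying the skipped epochs has to exhaust the sampler rather than merely start an iteration. Dataset size and seed below are illustrative only.

    import torch
    from torch.utils.data import RandomSampler

    data = list(range(8))

    # An uninterrupted run: one persistent generator, two epochs.
    gen = torch.Generator().manual_seed(42)
    sampler = RandomSampler(data, generator=gen)
    epoch0 = list(sampler)
    epoch1 = list(sampler)

    # A resumed run that replays epoch 0 by exhausting the sampler once.
    gen_resumed = torch.Generator().manual_seed(42)
    resumed = RandomSampler(data, generator=gen_resumed)
    _ = list(resumed)                # full pass, including any trailing draw
    assert list(resumed) == epoch1   # epoch 1 now matches the uninterrupted run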
From 2823a53aa31dd151402aa6b2720e816ed268e60f Mon Sep 17 00:00:00 2001
From: dumpmemory <64742282+dumpmemory@users.noreply.github.com>
Date: Wed, 30 Aug 2023 22:22:37 +0800
Subject: [PATCH 2/9] fix typo

---
 src/transformers/trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 486679201cb8..a367d8cbd0ad 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1805,7 +1805,7 @@ def _inner_training_loop(
                         sampler = train_dataloader.batch_sampler.sampler
                     elif hasattr(train_dataloader, "batch_sampler") and hasattr(train_dataloader.batch_sampler, "batch_sampler") \
                         and isinstance(train_dataloader.batch_sampler.batch_sampler.sampler, RandomSampler):
-                        sampler = train_dataloader.batch_sampler.batch_sampler.sampler
+                        sampler = train_dataloader.batch_sampler.batch_sampler.sampler
                     _ = list(sampler)
 
         total_batched_samples = 0

From 71010cb3f37fcc2a786e4dae786c384ed6d00b55 Mon Sep 17 00:00:00 2001
From: dumpmemory <64742282+dumpmemory@users.noreply.github.com>
Date: Wed, 30 Aug 2023 22:32:39 +0800
Subject: [PATCH 3/9] clean code

---
 src/transformers/trainer.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index a367d8cbd0ad..ad6dd02e1aea 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1799,13 +1799,12 @@ def _inner_training_loop(
                     # Otherwise we need to call the whooooole sampler cause there is some random operation added
                     # AT THE VERY END!
                     sampler = []
+
                     if hasattr(train_dataloader, "sampler") and isinstance(train_dataloader.sampler, RandomSampler):
                         sampler = train_dataloader.sampler
-                    elif hasattr(train_dataloader, "batch_sampler") and isinstance(train_dataloader.batch_sampler.sampler, RandomSampler):
-                        sampler = train_dataloader.batch_sampler.sampler
-                    elif hasattr(train_dataloader, "batch_sampler") and hasattr(train_dataloader.batch_sampler, "batch_sampler") \
-                        and isinstance(train_dataloader.batch_sampler.batch_sampler.sampler, RandomSampler):
-                        sampler = train_dataloader.batch_sampler.batch_sampler.sampler
+                    else:
+                        sampler = train_dataloader.batch_sampler
+
                     _ = list(sampler)
 
         total_batched_samples = 0
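A second sketch, also outside the patch, of why the simplification in PATCH 3/9 can fall back to train_dataloader.batch_sampler: iterating a BatchSampler drains the RandomSampler it wraps, so the shared generator ends up in the same state as if the sampler itself had been listed. Sizes and seeds are illustrative only.

    import torch
    from torch.utils.data import BatchSampler, RandomSampler

    data = list(range(10))

    gen_a = torch.Generator().manual_seed(7)
    sampler_a = RandomSampler(data, generator=gen_a)
    _ = list(BatchSampler(sampler_a, batch_size=4, drop_last=False))  # one epoch, batched

    gen_b = torch.Generator().manual_seed(7)
    sampler_b = RandomSampler(data, generator=gen_b)
    _ = list(sampler_b)                                               # one epoch, unbatched

    # Both generators advanced identically, so the next epoch's order matches.
    assert list(sampler_a) == list(sampler_b)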
From b4c38a35d44a6fe1822c42f0200a1db97a1e9247 Mon Sep 17 00:00:00 2001
From: dumpmemory <64742282+dumpmemory@users.noreply.github.com>
Date: Wed, 30 Aug 2023 15:52:22 +0000
Subject: [PATCH 4/9] reformatted code

---
 src/transformers/trainer.py | 27 +++++++++++++++------------
 1 file changed, 15 insertions(+), 12 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index ad6dd02e1aea..efcf859d20ec 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -1782,15 +1782,18 @@ def _inner_training_loop(
         # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
         if not args.ignore_data_skip:
             for epoch in range(epochs_trained):
-                is_random_sampler = (hasattr(train_dataloader, "sampler") and isinstance(
-                    train_dataloader.sampler, RandomSampler
-                )) or (hasattr(train_dataloader, "batch_sampler") and isinstance(
-                    train_dataloader.batch_sampler.sampler, RandomSampler
-                )) or (
-                    hasattr(train_dataloader, "batch_sampler") and \
-                    hasattr(train_dataloader.batch_sampler, "batch_sampler") \
-                    and isinstance(train_dataloader.batch_sampler.batch_sampler.sampler, RandomSampler
-                ))
+                is_random_sampler = (
+                    (hasattr(train_dataloader, "sampler") and isinstance(train_dataloader.sampler, RandomSampler))
+                    or (
+                        hasattr(train_dataloader, "batch_sampler")
+                        and isinstance(train_dataloader.batch_sampler.sampler, RandomSampler)
+                    )
+                    or (
+                        hasattr(train_dataloader, "batch_sampler")
+                        and hasattr(train_dataloader.batch_sampler, "batch_sampler")
+                        and isinstance(train_dataloader.batch_sampler.batch_sampler.sampler, RandomSampler)
+                    )
+                )
                 if not is_random_sampler:
                     # We just need to begin an iteration to create the randomization of the sampler.
                     for _ in train_dataloader:
                         break
@@ -1799,14 +1802,14 @@ def _inner_training_loop(
                     # Otherwise we need to call the whooooole sampler cause there is some random operation added
                     # AT THE VERY END!
                     sampler = []
-
+
                     if hasattr(train_dataloader, "sampler") and isinstance(train_dataloader.sampler, RandomSampler):
                         sampler = train_dataloader.sampler
                     else:
                         sampler = train_dataloader.batch_sampler
-
+
                     _ = list(sampler)
-
+
         total_batched_samples = 0
         for epoch in range(epochs_trained, num_train_epochs):
             epoch_iterator = train_dataloader

From 67223ae4fd01bd1866c529ffe1a23f6b61769043 Mon Sep 17 00:00:00 2001
From: dumpmemory <64742282+dumpmemory@users.noreply.github.com>
Date: Thu, 31 Aug 2023 02:02:28 +0000
Subject: [PATCH 5/9] adjust code according to comments

---
 src/transformers/trainer.py | 35 ++++++++++++++---------------------
 1 file changed, 14 insertions(+), 21 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index efcf859d20ec..6073e9a138b6 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -65,7 +65,7 @@
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
 from .optimization import Adafactor, get_scheduler
-from .pytorch_utils import ALL_LAYERNORM_LAYERS
+from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_less_than_1_11
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
     CallbackHandler,
@@ -215,6 +215,15 @@
 if TYPE_CHECKING:
     import optuna
 
+
+def get_dataloader_sampler(dataloader):
+    if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, RandomSampler):
+        return dataloader.sampler, True
+    if hasattr(dataloader, "batch_sampler"):
+        return get_dataloader_sampler(dataloader.batch_sampler)
+    return dataloader.sampler, False
+
+
 logger = logging.get_logger(__name__)
 
 
@@ -1782,32 +1791,16 @@ def _inner_training_loop(
         # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
         if not args.ignore_data_skip:
             for epoch in range(epochs_trained):
-                is_random_sampler = (
-                    (hasattr(train_dataloader, "sampler") and isinstance(train_dataloader.sampler, RandomSampler))
-                    or (
-                        hasattr(train_dataloader, "batch_sampler")
-                        and isinstance(train_dataloader.batch_sampler.sampler, RandomSampler)
-                    )
-                    or (
-                        hasattr(train_dataloader, "batch_sampler")
-                        and hasattr(train_dataloader.batch_sampler, "batch_sampler")
-                        and isinstance(train_dataloader.batch_sampler.batch_sampler.sampler, RandomSampler)
-                    )
-                )
-                if not is_random_sampler:
+                sampler, is_random_sampler = get_dataloader_sampler(train_dataloader)
+
+                if is_torch_less_than_1_11 or not is_random_sampler:
                     # We just need to begin an iteration to create the randomization of the sampler.
                     for _ in train_dataloader:
                         break
                 else:
                     # Otherwise we need to call the whooooole sampler cause there is some random operation added
                     # AT THE VERY END!
-                    sampler = []
-
-                    if hasattr(train_dataloader, "sampler") and isinstance(train_dataloader.sampler, RandomSampler):
-                        sampler = train_dataloader.sampler
-                    else:
-                        sampler = train_dataloader.batch_sampler
-
+                    sampler = sampler if sampler is not None else []
                     _ = list(sampler)
 
         total_batched_samples = 0
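The recursive helper added in PATCH 5/9 replaces the three hard-coded attribute chains (dataloader.sampler, dataloader.batch_sampler.sampler, dataloader.batch_sampler.batch_sampler.sampler). Below is a hedged sketch of that helper resolving all three shapes; the SimpleNamespace objects merely fake the attribute layout and are not real dataloader classes.

    from types import SimpleNamespace
    from torch.utils.data import BatchSampler, RandomSampler, SequentialSampler

    def get_dataloader_sampler(dataloader):  # the helper as added in this commit
        if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, RandomSampler):
            return dataloader.sampler, True
        if hasattr(dataloader, "batch_sampler"):
            return get_dataloader_sampler(dataloader.batch_sampler)
        return dataloader.sampler, False

    data = list(range(12))
    inner = BatchSampler(RandomSampler(data), batch_size=4, drop_last=False)

    # The shapes the old checks had to enumerate, faked with SimpleNamespace:
    flat = SimpleNamespace(sampler=RandomSampler(data))
    batched = SimpleNamespace(sampler=SequentialSampler(data), batch_sampler=inner)
    nested = SimpleNamespace(sampler=SequentialSampler(data), batch_sampler=SimpleNamespace(batch_sampler=inner))

    for dl in (flat, batched, nested):
        sampler, is_random = get_dataloader_sampler(dl)
        assert is_random and isinstance(sampler, RandomSampler)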
From 401d44359d53d447cb69d62d05ed15b7a4698b48 Mon Sep 17 00:00:00 2001
From: dumpmemory <64742282+dumpmemory@users.noreply.github.com>
Date: Thu, 31 Aug 2023 02:08:38 +0000
Subject: [PATCH 6/9] adjust check_dataloader_randomsampler location

---
 src/transformers/trainer.py          | 11 ++---------
 src/transformers/trainer_pt_utils.py |  8 ++++++++
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 6073e9a138b6..1392b2b1cdf1 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -82,6 +82,7 @@
     LabelSmoother,
     LengthGroupedSampler,
     SequentialDistributedSampler,
+    check_dataloader_randomsampler,
     distributed_broadcast_scalars,
     distributed_concat,
     find_batch_size,
@@ -216,14 +217,6 @@
     import optuna
 
 
-def get_dataloader_sampler(dataloader):
-    if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, RandomSampler):
-        return dataloader.sampler, True
-    if hasattr(dataloader, "batch_sampler"):
-        return get_dataloader_sampler(dataloader.batch_sampler)
-    return dataloader.sampler, False
-
-
 logger = logging.get_logger(__name__)
 
 
@@ -1791,7 +1784,7 @@ def _inner_training_loop(
         # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
         if not args.ignore_data_skip:
             for epoch in range(epochs_trained):
-                sampler, is_random_sampler = get_dataloader_sampler(train_dataloader)
+                sampler, is_random_sampler = check_dataloader_randomsampler(train_dataloader)
 
                 if is_torch_less_than_1_11 or not is_random_sampler:
                     # We just need to begin an iteration to create the randomization of the sampler.

diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 88e27e3c4dc7..7d2237f25207 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -55,6 +55,14 @@
 logger = logging.get_logger(__name__)
 
 
+def check_dataloader_randomsampler(dataloader):
+    if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, RandomSampler):
+        return dataloader.sampler, True
+    if hasattr(dataloader, "batch_sampler"):
+        return check_dataloader_randomsampler(dataloader.batch_sampler)
+    return dataloader.sampler, False
+
+
 def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):
     if isinstance(tensor_or_array, torch.Tensor):
         if hasattr(torch, "atleast_1d"):
From dfb06a53cf031c77d5cbda482d2c1444890551b3 Mon Sep 17 00:00:00 2001
From: dumpmemory <64742282+dumpmemory@users.noreply.github.com>
Date: Fri, 1 Sep 2023 02:45:12 +0000
Subject: [PATCH 7/9] return sampler only

---
 src/transformers/trainer.py          |  6 +++---
 src/transformers/trainer_pt_utils.py | 11 ++++++-----
 2 files changed, 9 insertions(+), 8 deletions(-)

diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index 1392b2b1cdf1..cd74a8760208 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -82,10 +82,10 @@
     LabelSmoother,
     LengthGroupedSampler,
    SequentialDistributedSampler,
-    check_dataloader_randomsampler,
     distributed_broadcast_scalars,
     distributed_concat,
     find_batch_size,
+    get_dataloader_sampler,
     get_model_param_count,
     get_module_class_from_name,
     get_parameter_names,
@@ -1784,8 +1784,8 @@ def _inner_training_loop(
         # Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
         if not args.ignore_data_skip:
             for epoch in range(epochs_trained):
-                sampler, is_random_sampler = check_dataloader_randomsampler(train_dataloader)
-
+                sampler = get_dataloader_sampler(train_dataloader)
+                is_random_sampler = isinstance(sampler, RandomSampler)
                 if is_torch_less_than_1_11 or not is_random_sampler:
                     # We just need to begin an iteration to create the randomization of the sampler.
                     for _ in train_dataloader:

diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 7d2237f25207..56c6e71a393d 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -55,12 +55,13 @@
 logger = logging.get_logger(__name__)
 
 
-def check_dataloader_randomsampler(dataloader):
+def get_dataloader_sampler(dataloader):
+    # after accelerate.prepare function the wraped dataloader.sampler will be SequentialSampler instead of RandomSampler
     if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, RandomSampler):
-        return dataloader.sampler, True
-    if hasattr(dataloader, "batch_sampler"):
-        return check_dataloader_randomsampler(dataloader.batch_sampler)
-    return dataloader.sampler, False
+        return dataloader.sampler
+    if hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None:
+        return get_dataloader_sampler(dataloader.batch_sampler)
+    return dataloader.sampler
 
 
 def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):
From 807512c46f822a532c732f589dde58c0ce1ca676 Mon Sep 17 00:00:00 2001
From: dumpmemory <64742282+dumpmemory@users.noreply.github.com>
Date: Wed, 6 Sep 2023 08:40:21 +0000
Subject: [PATCH 8/9] handle sampler is None

---
 src/transformers/trainer_pt_utils.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 56c6e71a393d..2998c5635b57 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -59,9 +59,11 @@ def get_dataloader_sampler(dataloader):
     # after accelerate.prepare function the wraped dataloader.sampler will be SequentialSampler instead of RandomSampler
     if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, RandomSampler):
         return dataloader.sampler
-    if hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None:
+    elif hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None:
         return get_dataloader_sampler(dataloader.batch_sampler)
-    return dataloader.sampler
+    elif hasattr(dataloader, "sampler"):
+        return dataloader.sampler
+    return None
 
 
 def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):

From 49146c8cdc7177fce932410eccd6ce4aaf019c85 Mon Sep 17 00:00:00 2001
From: dumpmemory <64742282+dumpmemory@users.noreply.github.com>
Date: Wed, 6 Sep 2023 19:27:03 +0800
Subject: [PATCH 9/9] Update src/transformers/trainer_pt_utils.py

thanks @amyeroberts

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>
---
 src/transformers/trainer_pt_utils.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/src/transformers/trainer_pt_utils.py b/src/transformers/trainer_pt_utils.py
index 2998c5635b57..62d38ffc7593 100644
--- a/src/transformers/trainer_pt_utils.py
+++ b/src/transformers/trainer_pt_utils.py
@@ -56,14 +56,10 @@
 
 
 def get_dataloader_sampler(dataloader):
-    # after accelerate.prepare function the wraped dataloader.sampler will be SequentialSampler instead of RandomSampler
-    if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, RandomSampler):
-        return dataloader.sampler
-    elif hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None:
+    if hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None:
         return get_dataloader_sampler(dataloader.batch_sampler)
     elif hasattr(dataloader, "sampler"):
         return dataloader.sampler
-    return None
 
 
 def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):
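For reference, a usage sketch, outside the patch series, of the final helper from PATCH 9/9 combined with the resume-time skip: the recursion descends into batch_sampler first, so both a plain shuffled DataLoader and one built around an explicit BatchSampler resolve to the underlying RandomSampler, which the training loop can then exhaust once per skipped epoch. Dataset and batch sizes are illustrative only.

    from torch.utils.data import BatchSampler, DataLoader, RandomSampler

    def get_dataloader_sampler(dataloader):  # final form from PATCH 9/9
        if hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None:
            return get_dataloader_sampler(dataloader.batch_sampler)
        elif hasattr(dataloader, "sampler"):
            return dataloader.sampler

    data = list(range(16))
    plain = DataLoader(data, batch_size=4, shuffle=True)
    batched = DataLoader(data, batch_sampler=BatchSampler(RandomSampler(data), batch_size=4, drop_last=False))

    for dl in (plain, batched):
        sampler = get_dataloader_sampler(dl)
        assert isinstance(sampler, RandomSampler)
        _ = list(sampler)  # one skipped epoch: exhaust the sampler to advance its RNG state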