Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try to fix training Loss inconsistent after resume from old checkpoint #25872

Merged
merged 10 commits into from
Sep 7, 2023
17 changes: 14 additions & 3 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES, MODEL_MAPPING_NAMES
from .optimization import Adafactor, get_scheduler
from .pytorch_utils import ALL_LAYERNORM_LAYERS
from .pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_less_than_1_11
from .tokenization_utils_base import PreTrainedTokenizerBase
from .trainer_callback import (
CallbackHandler,
Expand All @@ -85,6 +85,7 @@
distributed_broadcast_scalars,
distributed_concat,
find_batch_size,
get_dataloader_sampler,
get_model_param_count,
get_module_class_from_name,
get_parameter_names,
Expand Down Expand Up @@ -215,6 +216,7 @@
if TYPE_CHECKING:
import optuna


logger = logging.get_logger(__name__)


Expand Down Expand Up @@ -1782,8 +1784,17 @@ def _inner_training_loop(
# Skip the first epochs_trained epochs to get the random state of the dataloader at the right point.
if not args.ignore_data_skip:
for epoch in range(epochs_trained):
for _ in train_dataloader:
break
sampler = get_dataloader_sampler(train_dataloader)
is_random_sampler = isinstance(sampler, RandomSampler)
if is_torch_less_than_1_11 or not is_random_sampler:
# We just need to begin an iteration to create the randomization of the sampler.
for _ in train_dataloader:
break
else:
# Otherwise we need to call the whooooole sampler cause there is some random operation added
# AT THE VERY END!
sampler = sampler if sampler is not None else []
_ = list(sampler)

total_batched_samples = 0
for epoch in range(epochs_trained, num_train_epochs):
Expand Down
9 changes: 9 additions & 0 deletions src/transformers/trainer_pt_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,15 @@
logger = logging.get_logger(__name__)


def get_dataloader_sampler(dataloader):
# after accelerate.prepare function the wraped dataloader.sampler will be SequentialSampler instead of RandomSampler
if hasattr(dataloader, "sampler") and isinstance(dataloader.sampler, RandomSampler):
return dataloader.sampler
if hasattr(dataloader, "batch_sampler") and dataloader.batch_sampler is not None:
return get_dataloader_sampler(dataloader.batch_sampler)
return dataloader.sampler
dumpmemory marked this conversation as resolved.
Show resolved Hide resolved


def atleast_1d(tensor_or_array: Union[torch.Tensor, np.ndarray]):
if isinstance(tensor_or_array, torch.Tensor):
if hasattr(torch, "atleast_1d"):
Expand Down