From 8c8859d1f6e5556198d669b26f569fff7772e348 Mon Sep 17 00:00:00 2001
From: sneaxiy
Date: Fri, 11 Oct 2024 15:52:13 +0800
Subject: [PATCH] cherry-pick prs

---
 paddlenlp/trainer/trainer.py           | 14 +++++++++++++-
 paddlenlp/trainer/utils/sharding_io.py | 26 ++++++++++++++++++++++++--
 2 files changed, 37 insertions(+), 3 deletions(-)

diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 7a4c73a2e14cc1..ab377acf644a04 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -197,6 +197,17 @@ def in_auto_parallel_align_mode():
     return False
 
 
+try:
+    from paddle.framework.recall_error import LOSS_NAN_ERROR
+except ImportError:
+    LOSS_NAN_ERROR = "PaddleRecall error(102): LossNan"
+
+try:
+    from paddle.framework.recall_error import LOSS_INF_ERROR
+except ImportError:
+    LOSS_INF_ERROR = "PaddleRecall error(104): LossInf"
+
+
 __all__ = ["Trainer"]
 
 
@@ -1368,7 +1379,8 @@ def _get_item_from_loss(self, loss):
         loss_value = loss.item()
         if not self.args.fp16:
             if not np.isfinite(loss_value).all():
-                raise ValueError(f"Loss contains inf or nan values, its value is {loss_value}")
+                err_msg = LOSS_NAN_ERROR if np.isnan(loss_value).any() else LOSS_INF_ERROR
+                raise ValueError(f"{err_msg}. Loss contains inf or nan values, its value is {loss_value}")
         return loss_value
 
     def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs):
diff --git a/paddlenlp/trainer/utils/sharding_io.py b/paddlenlp/trainer/utils/sharding_io.py
index 0926988771bab8..da45fb1f810275 100644
--- a/paddlenlp/trainer/utils/sharding_io.py
+++ b/paddlenlp/trainer/utils/sharding_io.py
@@ -38,6 +38,7 @@
 )
 from paddlenlp.transformers.utils import paddlenlp_load
 from paddlenlp.utils.log import logger
+from paddlenlp.utils.tools import get_env_device
 
 from . import reshard as reshard_util
 from .reshard import SHARDING_STRATEGY_V1, SHARDING_STRATEGY_V2, pp_reshard
@@ -53,6 +54,22 @@
 SHARDING_META_NAME = "shard_meta.json"
 
 
+def to_device(tensor, place=None):
+    if place is None:
+        place = get_env_device()
+
+    if isinstance(place, str):
+        place = paddle.device._convert_to_place(place)
+
+    if not tensor.place._equals(place):
+        new_t = tensor._copy_to(place, True)
+        dst_tensor = tensor.value().get_tensor()
+        src_tensor = new_t.value().get_tensor()
+        dst_tensor._share_data_with(src_tensor)
+
+    return tensor
+
+
 def filter_sharded_params(state_dict, optimizer, sharding_group):
 
     sharding_rank = sharding_group.rank
@@ -239,6 +256,9 @@ def _need_reshard(self, checkpoint):
             param2rank = sharding_meta["param2rank"]
             optimizer = unwrap_optimizer(self.optimizer, DygraphShardingOptimizer)
             assert optimizer
+            if len(param2rank) == 0:
+                logger.warning("The param2rank is empty. Force reshard would be performed.")
+                return True
             assert len(param2rank) == len(optimizer._param2rank)
             for (k, v) in param2rank.items():
                 assert k in optimizer._param2rank
@@ -460,7 +480,7 @@ def _recover_params_from_master_weights(self, state_dict, opt_state_dict=None):
         # cast to before
         for (k, v) in tmp.items():
             name = v.name
-            master_weights[k] = paddle.cast(v.cuda(), paddle.bfloat16).cpu()
+            master_weights[k] = paddle.cast(to_device(v), paddle.bfloat16).cpu()
             master_weights[k].name = name
 
         structure_name_map = {k: v.name for (k, v) in self.model.state_dict().items()}
@@ -491,7 +511,9 @@ def filter_func(name):
         for key, param in model_state_dict.items():
             if param.name in master_weights:
                 assert param.shape == master_weights[param.name].shape
-                paddle.assign(paddle.cast(master_weights[param.name].cuda(), paddle.bfloat16), model_state_dict[key])
+                paddle.assign(
+                    paddle.cast(to_device(master_weights[param.name]), paddle.bfloat16), model_state_dict[key]
+                )
             elif key in state_dict:
                 logger.info(f"key: {key} is in state_dict, but not in master_weights")
                 paddle.assign(state_dict[key], model_state_dict[key])
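Note for reviewers: the trainer.py hunks tag non-finite losses with Paddle's recall-error markers so external watchdogs can match on them, and fall back to hard-coded strings on Paddle builds that do not ship paddle.framework.recall_error. The sketch below is a minimal, self-contained illustration of that pattern; check_loss_value is a hypothetical helper that mirrors the patched _get_item_from_loss check outside the Trainer class.

# Minimal sketch of the fallback-import pattern used in the trainer.py hunk.
# The constant values mirror the patch; they are only needed when the running
# Paddle build does not provide paddle.framework.recall_error.
import numpy as np

try:
    from paddle.framework.recall_error import LOSS_NAN_ERROR
except ImportError:
    LOSS_NAN_ERROR = "PaddleRecall error(102): LossNan"

try:
    from paddle.framework.recall_error import LOSS_INF_ERROR
except ImportError:
    LOSS_INF_ERROR = "PaddleRecall error(104): LossInf"


def check_loss_value(loss_value):
    # Same logic as the patched _get_item_from_loss: prefix the exception with
    # the NaN or Inf recall-error marker so log scrapers can distinguish them.
    if not np.isfinite(loss_value).all():
        err_msg = LOSS_NAN_ERROR if np.isnan(loss_value).any() else LOSS_INF_ERROR
        raise ValueError(f"{err_msg}. Loss contains inf or nan values, its value is {loss_value}")
    return loss_value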
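The sharding_io.py hunks replace the hard-coded .cuda() calls in master-weight recovery with the new module-level to_device helper, so checkpoints can be restored on whichever device get_env_device() reports (for example gpu, xpu, npu, or cpu) instead of assuming CUDA. A minimal usage sketch, assuming the patch is applied so the helper is importable from paddlenlp.trainer.utils.sharding_io:

# Usage sketch for the to_device helper added by this patch (assumes the patch
# has been applied and paddle plus paddlenlp are installed).
import paddle
from paddlenlp.trainer.utils.sharding_io import to_device

x = paddle.zeros([2, 3], dtype="float32")  # created on the default device
x_cpu = x.cpu()                            # force a CPU copy

# Move the CPU copy back to the device reported by get_env_device(). The move
# is done in place by sharing the storage of the device-side copy, and the
# same tensor object is returned.
y = to_device(x_cpu)
print(y.place)

# A target place can also be given explicitly, as a string or a Place object.
z = to_device(paddle.ones([4]), place="cpu")
print(z.place)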