
Commit

cherry-pick prs
sneaxiy committed Oct 11, 2024
1 parent 156182e commit 8c8859d
Showing 2 changed files with 37 additions and 3 deletions.
14 changes: 13 additions & 1 deletion paddlenlp/trainer/trainer.py
@@ -197,6 +197,17 @@ def in_auto_parallel_align_mode():
     return False
 
 
+try:
+    from paddle.framework.recall_error import LOSS_NAN_ERROR
+except ImportError:
+    LOSS_NAN_ERROR = "PaddleRecall error(102): LossNan"
+
+try:
+    from paddle.framework.recall_error import LOSS_INF_ERROR
+except ImportError:
+    LOSS_INF_ERROR = "PaddleRecall error(104): LossInf"
+
+
 __all__ = ["Trainer"]
 
 
@@ -1368,7 +1379,8 @@ def _get_item_from_loss(self, loss):
         loss_value = loss.item()
         if not self.args.fp16:
             if not np.isfinite(loss_value).all():
-                raise ValueError(f"Loss contains inf or nan values, its value is {loss_value}")
+                err_msg = LOSS_NAN_ERROR if np.isnan(loss_value).any() else LOSS_INF_ERROR
+                raise ValueError(f"{err_msg}. Loss contains inf or nan values, its value is {loss_value}")
         return loss_value
 
     def _maybe_log_save_evaluate(self, tr_loss, model, epoch, ignore_keys_for_eval, **kwargs):
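Below is a minimal standalone sketch (not part of the commit) of how the new error prefixes behave: LOSS_NAN_ERROR and LOSS_INF_ERROR fall back to literal strings when paddle.framework.recall_error is unavailable (the commit uses two separate try blocks; they are combined here for brevity), and the loss check now selects the NaN or Inf prefix before raising. The check_loss_value helper is hypothetical and only mirrors the updated logic in _get_item_from_loss.

import numpy as np

try:
    from paddle.framework.recall_error import LOSS_NAN_ERROR, LOSS_INF_ERROR
except ImportError:
    # Fallback strings, matching the values used in the diff above.
    LOSS_NAN_ERROR = "PaddleRecall error(102): LossNan"
    LOSS_INF_ERROR = "PaddleRecall error(104): LossInf"


def check_loss_value(loss_value):
    """Mirror of the updated check: prefix the error with the matching recall-error marker."""
    if not np.isfinite(loss_value).all():
        err_msg = LOSS_NAN_ERROR if np.isnan(loss_value).any() else LOSS_INF_ERROR
        raise ValueError(f"{err_msg}. Loss contains inf or nan values, its value is {loss_value}")
    return loss_value


check_loss_value(0.5)             # returns 0.5
# check_loss_value(float("nan"))  # raises ValueError starting with "PaddleRecall error(102): LossNan"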
26 changes: 24 additions & 2 deletions paddlenlp/trainer/utils/sharding_io.py
@@ -38,6 +38,7 @@
 )
 from paddlenlp.transformers.utils import paddlenlp_load
 from paddlenlp.utils.log import logger
+from paddlenlp.utils.tools import get_env_device
 
 from . import reshard as reshard_util
 from .reshard import SHARDING_STRATEGY_V1, SHARDING_STRATEGY_V2, pp_reshard
@@ -53,6 +54,22 @@
 SHARDING_META_NAME = "shard_meta.json"
 
 
+def to_device(tensor, place=None):
+    if place is None:
+        place = get_env_device()
+
+    if isinstance(place, str):
+        place = paddle.device._convert_to_place(place)
+
+    if not tensor.place._equals(place):
+        new_t = tensor._copy_to(place, True)
+        dst_tensor = tensor.value().get_tensor()
+        src_tensor = new_t.value().get_tensor()
+        dst_tensor._share_data_with(src_tensor)
+
+    return tensor
+
+
 def filter_sharded_params(state_dict, optimizer, sharding_group):
 
     sharding_rank = sharding_group.rank
@@ -239,6 +256,9 @@ def _need_reshard(self, checkpoint):
             param2rank = sharding_meta["param2rank"]
             optimizer = unwrap_optimizer(self.optimizer, DygraphShardingOptimizer)
             assert optimizer
+            if len(param2rank) == 0:
+                logger.warning("The param2rank is empty. Force reshard would be performed.")
+                return True
             assert len(param2rank) == len(optimizer._param2rank)
             for (k, v) in param2rank.items():
                 assert k in optimizer._param2rank
@@ -460,7 +480,7 @@ def _recover_params_from_master_weights(self, state_dict, opt_state_dict=None):
         # cast to before
         for (k, v) in tmp.items():
             name = v.name
-            master_weights[k] = paddle.cast(v.cuda(), paddle.bfloat16).cpu()
+            master_weights[k] = paddle.cast(to_device(v), paddle.bfloat16).cpu()
             master_weights[k].name = name
 
         structure_name_map = {k: v.name for (k, v) in self.model.state_dict().items()}
@@ -491,7 +511,9 @@ def filter_func(name):
         for key, param in model_state_dict.items():
             if param.name in master_weights:
                 assert param.shape == master_weights[param.name].shape
-                paddle.assign(paddle.cast(master_weights[param.name].cuda(), paddle.bfloat16), model_state_dict[key])
+                paddle.assign(
+                    paddle.cast(to_device(master_weights[param.name]), paddle.bfloat16), model_state_dict[key]
+                )
             elif key in state_dict:
                 logger.info(f"key: {key} is in state_dict, but not in master_weights")
                 paddle.assign(state_dict[key], model_state_dict[key])
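Below is a minimal usage sketch (not part of the commit) of the new to_device helper, which replaces the hard-coded .cuda() calls so master-weight recovery also works on non-CUDA devices; get_env_device resolves the current device string (e.g. "cpu", "gpu", "xpu") when no place is given. The import path assumes the helper remains a module-level function in sharding_io.py as shown above.

import paddle

from paddlenlp.trainer.utils.sharding_io import to_device

x = paddle.ones([2, 2], dtype="float32")

# No explicit place: move the tensor to whatever get_env_device() reports.
to_device(x)

# An explicit place can be given either as a string or as a paddle Place;
# if the tensor already lives there, it is returned unchanged.
to_device(x, "cpu")
to_device(x, paddle.CPUPlace())

Because the helper shares the copied storage back into the original tensor object, call sites such as _recover_params_from_master_weights can keep using the same Python variable after the move.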
