diff --git a/paddlenlp/transformers/nezha/modeling.py b/paddlenlp/transformers/nezha/modeling.py index ef2b2572f37d..9643aff6af98 100644 --- a/paddlenlp/transformers/nezha/modeling.py +++ b/paddlenlp/transformers/nezha/modeling.py @@ -820,7 +820,7 @@ def forward( # If we are on multi-GPU, split add a dimension if start_positions.ndim > 1: start_positions = start_positions.squeeze(-1) - if start_positions.ndim > 1: + if end_positions.ndim > 1: end_positions = end_positions.squeeze(-1) # sometimes the start/end positions are outside our model inputs, we ignore these terms ignored_index = paddle.shape(start_logits)[1]