diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py index eebb635e3ead7..ba22ffee3e4d6 100644 --- a/python/paddle/incubate/distributed/models/moe/moe_layer.py +++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py @@ -399,7 +399,7 @@ def forward(self, inp): def experts_fwd(x, fwd_expert_count, experts): if x.shape[0] == 0: - return paddle.empty(x.shape, x.dtype) + return x y = [] last_index = 0 assert isinstance(fwd_expert_count, np.ndarray) @@ -411,7 +411,7 @@ def experts_fwd(x, fwd_expert_count, experts): last_index = expert_count + last_index return paddle.concat(y, axis=0) - if self.recompute_interval <= 0: + if self.recompute_interval <= 0 or x.shape[0] == 0: x = experts_fwd(x, fwd_expert_count.numpy(), self.experts) else: x = _hp_recompute(experts_fwd, x,