From 08ff007d823ea76a224d281b0658c1f565586289 Mon Sep 17 00:00:00 2001 From: Roc <30228238+sljlp@users.noreply.github.com> Date: Mon, 25 Apr 2022 10:40:27 +0800 Subject: [PATCH] fix recompute (#42128) * fix recompute * modify return --- python/paddle/incubate/distributed/models/moe/moe_layer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/incubate/distributed/models/moe/moe_layer.py b/python/paddle/incubate/distributed/models/moe/moe_layer.py index eebb635e3ead7..ba22ffee3e4d6 100644 --- a/python/paddle/incubate/distributed/models/moe/moe_layer.py +++ b/python/paddle/incubate/distributed/models/moe/moe_layer.py @@ -399,7 +399,7 @@ def forward(self, inp): def experts_fwd(x, fwd_expert_count, experts): if x.shape[0] == 0: - return paddle.empty(x.shape, x.dtype) + return x y = [] last_index = 0 assert isinstance(fwd_expert_count, np.ndarray) @@ -411,7 +411,7 @@ def experts_fwd(x, fwd_expert_count, experts): last_index = expert_count + last_index return paddle.concat(y, axis=0) - if self.recompute_interval <= 0: + if self.recompute_interval <= 0 or x.shape[0] == 0: x = experts_fwd(x, fwd_expert_count.numpy(), self.experts) else: x = _hp_recompute(experts_fwd, x,