Skip to content

Commit

Permalink
fix recompute (#42128)
Browse files Browse the repository at this point in the history
* fix recompute

* modify return
  • Loading branch information
sljlp committed Apr 25, 2022
1 parent 2616796 commit 08ff007
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions python/paddle/incubate/distributed/models/moe/moe_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,7 @@ def forward(self, inp):
def experts_fwd(x, fwd_expert_count, experts):

if x.shape[0] == 0:
return paddle.empty(x.shape, x.dtype)
return x
y = []
last_index = 0
assert isinstance(fwd_expert_count, np.ndarray)
Expand All @@ -411,7 +411,7 @@ def experts_fwd(x, fwd_expert_count, experts):
last_index = expert_count + last_index
return paddle.concat(y, axis=0)

if self.recompute_interval <= 0:
if self.recompute_interval <= 0 or x.shape[0] == 0:
x = experts_fwd(x, fwd_expert_count.numpy(), self.experts)
else:
x = _hp_recompute(experts_fwd, x,
Expand Down

1 comment on commit 08ff007

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.