From e15b97b00973dcaea325777e0302fb992be575a8 Mon Sep 17 00:00:00 2001
From: Jinzhe Zeng
Date: Tue, 17 Sep 2024 14:45:21 -0400
Subject: [PATCH] fix(pt): fix zero inputs for LayerNorm

Fix #4064.

Signed-off-by: Jinzhe Zeng
---
 deepmd/pt/model/network/layernorm.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/deepmd/pt/model/network/layernorm.py b/deepmd/pt/model/network/layernorm.py
index f5cd6b965f..c1c2c29c87 100644
--- a/deepmd/pt/model/network/layernorm.py
+++ b/deepmd/pt/model/network/layernorm.py
@@ -96,8 +96,11 @@ def forward(
         # variance = xx.var(dim=-1, unbiased=False, keepdim=True)
         # The following operation is the same as above, but will not raise error when using jit model to inference.
         # See https://github.com/pytorch/pytorch/issues/85792
-        variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)
-        yy = (xx - mean) / torch.sqrt(variance + self.eps)
+        if xx.numel() > 0:
+            variance, mean = torch.var_mean(xx, dim=-1, unbiased=False, keepdim=True)
+            yy = (xx - mean) / torch.sqrt(variance + self.eps)
+        else:
+            yy = xx
         if self.matrix is not None and self.bias is not None:
             yy = yy * self.matrix + self.bias
         return yy