From 1d532eb16975be74fb4fe5cb134d0a73b9fa73b0 Mon Sep 17 00:00:00 2001 From: Duo <50307526+iProzd@users.noreply.github.com> Date: Mon, 4 Mar 2024 22:45:09 +0800 Subject: [PATCH] pt: fix multitask stuck on multiple-gpu --- deepmd/pt/model/descriptor/descriptor.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/deepmd/pt/model/descriptor/descriptor.py b/deepmd/pt/model/descriptor/descriptor.py index 24c1ef4dab..5aae848aa4 100644 --- a/deepmd/pt/model/descriptor/descriptor.py +++ b/deepmd/pt/model/descriptor/descriptor.py @@ -126,16 +126,18 @@ def share_params(self, base_class, shared_level, resume=False): ), "Only descriptors of the same type can share params!" if shared_level == 0: # link buffers - if hasattr(self, "mean") and not resume: - # in case of change params during resume - base_env = EnvMatStatSe(base_class) - base_env.stats = base_class.stats - for kk in base_class.get_stats(): - base_env.stats[kk] += self.get_stats()[kk] - mean, stddev = base_env() - if not base_class.set_davg_zero: - base_class.mean.copy_(torch.tensor(mean, device=env.DEVICE)) - base_class.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) + if hasattr(self, "mean"): + if not resume: + # in case of change params during resume + base_env = EnvMatStatSe(base_class) + base_env.stats = base_class.stats + for kk in base_class.get_stats(): + base_env.stats[kk] += self.get_stats()[kk] + mean, stddev = base_env() + if not base_class.set_davg_zero: + base_class.mean.copy_(torch.tensor(mean, device=env.DEVICE)) + base_class.stddev.copy_(torch.tensor(stddev, device=env.DEVICE)) + # must share, even if not do stat self.mean = base_class.mean self.stddev = base_class.stddev # self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model