Skip to content

Commit

Permalink
pt: fix multitask stuck on multiple-gpu (#3411)
Browse files Browse the repository at this point in the history
  • Loading branch information
iProzd authored Mar 5, 2024
1 parent 4454811 commit c8c941a
Showing 1 changed file with 12 additions and 10 deletions.
22 changes: 12 additions & 10 deletions deepmd/pt/model/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,18 @@ def share_params(self, base_class, shared_level, resume=False):
), "Only descriptors of the same type can share params!"
if shared_level == 0:
# link buffers
if hasattr(self, "mean") and not resume:
# in case of change params during resume
base_env = EnvMatStatSe(base_class)
base_env.stats = base_class.stats
for kk in base_class.get_stats():
base_env.stats[kk] += self.get_stats()[kk]
mean, stddev = base_env()
if not base_class.set_davg_zero:
base_class.mean.copy_(torch.tensor(mean, device=env.DEVICE))
base_class.stddev.copy_(torch.tensor(stddev, device=env.DEVICE))
if hasattr(self, "mean"):
if not resume:
# in case of change params during resume
base_env = EnvMatStatSe(base_class)
base_env.stats = base_class.stats
for kk in base_class.get_stats():
base_env.stats[kk] += self.get_stats()[kk]
mean, stddev = base_env()
if not base_class.set_davg_zero:
base_class.mean.copy_(torch.tensor(mean, device=env.DEVICE))
base_class.stddev.copy_(torch.tensor(stddev, device=env.DEVICE))
# must share, even if not do stat
self.mean = base_class.mean
self.stddev = base_class.stddev
# self.load_state_dict(base_class.state_dict()) # this does not work, because it only inits the model
Expand Down

0 comments on commit c8c941a

Please sign in to comment.