Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pt: avoid torch.tensor(constant) during forward #3421

Merged
merged 7 commits into from
Mar 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions deepmd/pt/loss/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
coord_mask = label["coord_mask"]
type_mask = label["type_mask"]

loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]

Check warning on line 55 in deepmd/pt/loss/denoise.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/denoise.py#L55

Added line #L55 was not covered by tests
more_loss = {}
if self.has_coord:
if self.mask_loss_coord:
Expand All @@ -66,9 +66,9 @@
beta=self.beta,
)
else:
coord_loss = torch.tensor(
0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
coord_loss = torch.zeros(

Check warning on line 69 in deepmd/pt/loss/denoise.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/denoise.py#L69

Added line #L69 was not covered by tests
1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)[0]
else:
coord_loss = F.smooth_l1_loss(
updated_coord.view(-1, 3),
Expand All @@ -89,9 +89,9 @@
reduction="mean",
)
else:
token_loss = torch.tensor(
0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)
token_loss = torch.zeros(

Check warning on line 92 in deepmd/pt/loss/denoise.py

View check run for this annotation

Codecov / codecov/patch

deepmd/pt/loss/denoise.py#L92

Added line #L92 was not covered by tests
1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE
)[0]
else:
token_loss = F.nll_loss(
F.log_softmax(logits.view(-1, self.ntypes - 1), dim=-1),
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/loss/ener.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def forward(self, model_pred, label, natoms, learning_rate, mae=False):
pref_e = self.limit_pref_e + (self.start_pref_e - self.limit_pref_e) * coef
pref_f = self.limit_pref_f + (self.start_pref_f - self.limit_pref_f) * coef
pref_v = self.limit_pref_v + (self.start_pref_v - self.limit_pref_v) * coef
loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
# more_loss['log_keys'] = [] # showed when validation on the fly
# more_loss['test_keys'] = [] # showed when doing dp test
Expand Down
2 changes: 1 addition & 1 deletion deepmd/pt/loss/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def forward(self, model_pred, label, natoms, learning_rate=0.0, mae=False):
Other losses for display.
"""
del learning_rate, mae
loss = torch.tensor(0.0, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)
loss = torch.zeros(1, dtype=env.GLOBAL_PT_FLOAT_PRECISION, device=env.DEVICE)[0]
more_loss = {}
if (
self.has_local_weight
Expand Down
14 changes: 8 additions & 6 deletions deepmd/pt/model/atomic_model/linear_atomic_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ def __init__(

self.atomic_bias = None
self.mixed_types_list = [model.mixed_types() for model in self.models]
self.rcuts = torch.tensor(
self.get_model_rcuts(), dtype=torch.float64, device=env.DEVICE
)
self.nsels = torch.tensor(self.get_model_nsels(), device=env.DEVICE)
BaseAtomicModel.__init__(self, **kwargs)

def mixed_types(self) -> bool:
Expand Down Expand Up @@ -117,14 +121,12 @@ def get_model_sels(self) -> List[List[int]]:
"""Get the sels for each individual models."""
return [model.get_sel() for model in self.models]

def _sort_rcuts_sels(self, device: torch.device) -> Tuple[List[float], List[int]]:
def _sort_rcuts_sels(self) -> Tuple[List[float], List[int]]:
# sort the pair of rcut and sels in ascending order, first based on sel, then on rcut.
rcuts = torch.tensor(self.get_model_rcuts(), dtype=torch.float64, device=device)
nsels = torch.tensor(self.get_model_nsels(), device=device)
zipped = torch.stack(
[
rcuts,
nsels,
self.rcuts,
self.nsels,
],
dim=0,
).T
Expand Down Expand Up @@ -171,7 +173,7 @@ def forward_atomic(
if self.do_grad_r() or self.do_grad_c():
extended_coord.requires_grad_(True)
extended_coord = extended_coord.view(nframes, -1, 3)
sorted_rcuts, sorted_sels = self._sort_rcuts_sels(device=extended_coord.device)
sorted_rcuts, sorted_sels = self._sort_rcuts_sels()
nlists = build_multiple_neighbor_list(
extended_coord,
nlist,
Expand Down