Skip to content

Commit

Permalink
fix issues with q + enforce graph-level deup-dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexDuvalinho committed May 17, 2024
1 parent 175567e commit 03f3038
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 19 deletions.
2 changes: 2 additions & 0 deletions ocpmodels/datasets/deup_dataset_creator.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,6 +329,7 @@ def create_deup_dataset(
pred_mean, batch.y_relaxed.to(pred_mean.device)
)
# Store deup samples
assert len(preds["q"]) == len(batch)
deup_samples += [
{
"energy_target": batch.y_relaxed.clone(),
Expand Down Expand Up @@ -481,6 +482,7 @@ def parse_args():
# base_config = make_config_from_conf_str("faenet-is2re-all")
# base_datasets_config = base_config["dataset"]

# Load deup dataset
deup_dataset = DeupDataset(
{
**base_datasets_config,
Expand Down
2 changes: 1 addition & 1 deletion ocpmodels/models/depfaenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,6 @@ def __init__(self, **kwargs):
def energy_forward(self, data, q=None):
    """Run the parent energy forward pass after exposing per-atom tags.

    The output block needs the batch's atom tags before the parent
    forward executes, hence the ``tags_saver`` call first.

    Args:
        data: batch object carrying at least a ``tags`` attribute; it is
            forwarded unchanged to the parent implementation.
        q: optional precomputed hidden representation, forwarded to the
            parent so it is not recomputed there. Defaults to ``None``.

    Returns:
        The parent ``energy_forward`` result (prediction dict/tensor —
        exact type depends on the parent class, not visible here).
    """
    # We need to save the tags so this step is necessary.
    self.output_block.tags_saver(data.tags)
    # Single parent call, forwarding q. The scraped diff showed a stale
    # duplicate call without q just before this line; keeping both would
    # run the forward pass twice and silently drop q on the first call.
    pred = super().energy_forward(data, q)

    return pred
20 changes: 11 additions & 9 deletions ocpmodels/models/deup_depfaenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ def __init__(
)

def forward(self, h, edge_index, edge_weight, batch, alpha, data=None):
# If sample density is used as feature, we need to add the extra dimension
if self._set_q_dim:
assert data is not None
assert "deup_q" in data.to_dict().keys()
Expand Down Expand Up @@ -58,13 +59,14 @@ def forward(self, h, edge_index, edge_weight, batch, alpha, data=None):
}:
h = h * alpha

# Global pooling -- get final graph rep
out = scatter(
h,
batch,
dim=0,
reduce="mean" if self.deup_extra_dim > 0 else "add",
)
# Pool into a graph rep if necessary
if len(h) > len(batch):
h = scatter(
h,
batch,
dim=0,
reduce="mean" if self.deup_extra_dim > 0 else "add",
)

# Concat graph representation with deup features (s, kde(q), std)
# and apply MLPs
Expand All @@ -76,7 +78,7 @@ def forward(self, h, edge_index, edge_weight, batch, alpha, data=None):
+ f" from the data dict ({data_keys})"
)
out = torch.cat(
[out]
[h]
+ [data[f"deup_{k}"][:, None].float() for k in self.deup_features],
dim=-1,
)
Expand All @@ -87,7 +89,7 @@ def forward(self, h, edge_index, edge_weight, batch, alpha, data=None):
return out

@registry.register_model("deup_depfaenet")
class DeupFAENet(DepFAENet):
class DeupDepFAENet(DepFAENet):
def __init__(self, *args, **kwargs):
kwargs["dropout_edge"] = 0
super().__init__(*args, **kwargs)
Expand Down
15 changes: 8 additions & 7 deletions ocpmodels/models/deup_faenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,13 @@ def forward(self, h, edge_index, edge_weight, batch, alpha, data=None):
h = h * alpha

# Global pooling -- get final graph rep
out = scatter(
h,
batch,
dim=0,
reduce="mean" if self.deup_extra_dim > 0 else "add",
)
if len(h) > len(batch):
h = scatter(
h,
batch,
dim=0,
reduce="mean" if self.deup_extra_dim > 0 else "add",
)

# Concat graph representation with deup features (s, kde(q), std)
# and apply MLPs
Expand All @@ -75,7 +76,7 @@ def forward(self, h, edge_index, edge_weight, batch, alpha, data=None):
+ f" from the data dict ({data_keys})"
)
out = torch.cat(
[out]
[h]
+ [data[f"deup_{k}"][:, None].float() for k in self.deup_features],
dim=-1,
)
Expand Down
6 changes: 4 additions & 2 deletions ocpmodels/models/faenet.py
Original file line number Diff line number Diff line change
Expand Up @@ -711,7 +711,7 @@ def energy_forward(self, data, q=None):
edge_attr = edge_attr[edge_mask]
rel_pos = rel_pos[edge_mask]

if q is None:
if not hasattr(data, "deup_q"):
# Embedding block
h, e = self.embed_block(z, rel_pos, edge_attr, data.tags)

Expand Down Expand Up @@ -754,6 +754,7 @@ def energy_forward(self, data, q=None):
# WARNING
# q which is NOT the hidden state h if it was stored as a scattered
# version of h. This works for GPs, NOT for MC-dropout
q = data.deup_q # No need to clone # TODO: check that it's not a problem (move to deup models)
h = q
alpha = None

Expand All @@ -766,7 +767,8 @@ def energy_forward(self, data, q=None):
elif self.skip_co == "add":
energy = sum(energy_skip_co)

if q and len(q) > len(energy):
# Store graph-level representation. # TODO: maybe want node-level rep
if q is not None and len(q) > len(energy): # N_atoms x hidden_channels
q = scatter(q, batch, dim=0, reduce="mean") # N_graphs x hidden_channels

preds = {
Expand Down

0 comments on commit 03f3038

Please sign in to comment.