-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Improve additive models #427
base: main
Are you sure you want to change the base?
Conversation
024e4f5
to
9277267
Compare
53197a0
to
aded37d
Compare
@@ -122,7 +122,7 @@ def test_regression_train(): | |||
) | |||
|
|||
# if you need to change the hardcoded values: | |||
torch.set_printoptions(precision=12) | |||
print(output["mtt::U0"].block().values) | |||
# torch.set_printoptions(precision=12) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is a leftover I suppose?
fixed_weights: Optional[Dict[str, Dict[int, str]]] = None, | ||
) -> None: | ||
"""Train/fit the composition weights for the datasets. | ||
|
||
:param datasets: Dataset(s) to calculate the composition weights for. | ||
:param fixed_weights: Optional fixed weights to use for the composition model, | ||
for one or more target quantities. | ||
:param additive_models: Additive models to be removed from the targets |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why this removal should happen implicitly in the composition model class? Maybe it is better to have a separate utility function that removes additives?
systems, targets = systems_and_targets_to_device( | ||
systems, targets, device | ||
) | ||
for additive_model in additive_models: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I mean, I would move this for-loop out of the composition model class, it feels like this functionality goes out of the scope of the CompositionModel
f"Composition model does not support target quantity " | ||
f"{target_info.quantity}. This is an architecture bug. " | ||
f"Composition model does not support target " | ||
f"{target_name}. This is an architecture bug. " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why the lack of support of the target is an architecture bug? If would not call it a bug, in fact, but maybe it’s just a copy-paste glitch
"ZBL only supports eV units, but a " | ||
f"{target.unit} output was provided." | ||
f"ZBL model does not support target " | ||
f"{target_name}. This is an architecture bug. " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same question here
if not self.is_valid_target(target_name, target_info): | ||
raise ValueError( | ||
f"ZBL model does not support target " | ||
f"{target_name}. This is an architecture bug. " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And here
@@ -261,6 +268,38 @@ def requested_neighbor_lists(self) -> List[NeighborListOptions]: | |||
) | |||
] | |||
|
|||
@staticmethod | |||
def is_valid_target(target_name: str, target_info: TargetInfo) -> bool: | |||
"""Finds if a ``TargetInfo`` object is compatible with a composition model. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Compatible with ZBL model I guess?
The rest of the code looks good to me, my main concern is the CompositionModel |
Improves the additive models and closes #355. In particular:
Contributor (creator of pull-request) checklist
Documentation updated (for new features)?📚 Documentation preview 📚: https://metatrain--427.org.readthedocs.build/en/427/