Adding a conversion function for other packages #139

Closed
wants to merge 2 commits into from
6 changes: 5 additions & 1 deletion docs/src/get-started/rascaline.rst
@@ -8,7 +8,7 @@ ML potentials, visualization or similarity analysis.

There exist several libraries able to compute such structural representations,
such as `DScribe`_, `QUIP`_, and many more. Rascaline tries to distinguish
itself by focussing on speed and memory efficiency of the calculations, with the
itself by focusing on speed and memory efficiency of the calculations, with the
explicit goal of running molecular simulations with ML potentials. In
particular, memory efficiency is achieved by using the `equistore`_ to store the
structural representation. Additionally, rascaline is not limited to a single
@@ -18,6 +18,10 @@ representation but supports several:
:start-after: inclusion-marker-representations-start
:end-before: inclusion-marker-representations-end

To help users familiar with these other libraries, `rascaline.utils` provides
a function, `convert_old_hyperparameter_names`, that translates existing
hyperparameter dictionaries into their rascaline equivalents. Because rascaline
takes a different approach to computing descriptors, not all options are supported.

.. _DScribe: https://singroup.github.io/dscribe/
.. _QUIP: https://www.libatoms.org
224 changes: 224 additions & 0 deletions python/rascaline/utils.py
@@ -4,6 +4,8 @@
from ._c_api import RASCAL_BUFFER_SIZE_ERROR
from .status import RascalError

import warnings


def _call_with_growing_buffer(callback, initial=1024):
bufflen = initial
@@ -20,3 +22,225 @@ def _call_with_growing_buffer(callback, initial=1024):
else:
raise
return buffer.value.decode("utf8")


def convert_old_hyperparameter_names(hyperparameters, mode):
"""
Function to convert old hyperparameter names to those
used in rascaline. This function is meant to be
deprecated as rascaline becomes more mainstream, but will
serve to help users convert existing workflows.

Notes
-----
- This function does not validate the values in the hyperparameter
dictionary, and it is up to the user to check that they pass
valid entries to `rascaline`.
- Not all of these parameters are supported in rascaline,
and some will raise warnings.

Parameters
----------

mode: string in ["librascal", "dscribe"]
We anticipate future support for mode=="quip" as well

hyperparameters: dictionary of hyperparameter keys and values.
For mode = `librascal`, the anticipated keys are:
- `coefficient_subselection`
- `compute_gradients`
- `covariant_lambda`
- `cutoff_function_parameters`
- `cutoff_function_type`
- `cutoff_smooth_width`
- `expansion_by_species_method`
- `gaussian_sigma_constant`
- `gaussian_sigma_type`
- `global_species`
- `interaction_cutoff`
- `inversion_symmetry`
- `max_angular`
- `max_radial`
- `normalize`
- `optimization_args`
- `optimization`
- `radial_basis`
- `soap_type`
For mode = `dscribe`, the anticipated keys are:
- `average`
- `crossover`
- `dtype`
- `nmax`
- `lmax`
- `periodic`
- `rbf`
- `rcut`
- `sigma`
- `sparse`
- `species`

"""
new_hypers = {}

if mode == "librascal":
Member

for readability & extensibility, I would have separate functions for each mode: _translate_soap_hyperparameters_from_librascal, _translate_soap_hyperparameters_from_dscribe and potentially later _translate_soap_hyperparameters_from_quip
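
A minimal sketch of how this suggestion could look, assuming one private helper per mode and a dispatch table. The helper names come from the comment above; `_TRANSLATORS` and the reduced rename tables are illustrative only and cover just the four keys that this PR renames directly, not the warnings or the cutoff handling.

```python
def _translate_soap_hyperparameters_from_librascal(hyperparameters):
    # Illustrative subset of the librascal -> rascaline renames in this PR.
    renames = {
        "gaussian_sigma_constant": "atomic_gaussian_width",
        "interaction_cutoff": "cutoff",
        "max_angular": "max_angular",
        "max_radial": "max_radial",
    }
    return {
        new: hyperparameters[old]
        for old, new in renames.items()
        if old in hyperparameters
    }


def _translate_soap_hyperparameters_from_dscribe(hyperparameters):
    # Illustrative subset of the dscribe -> rascaline renames in this PR.
    renames = {
        "sigma": "atomic_gaussian_width",
        "rcut": "cutoff",
        "lmax": "max_angular",
        "nmax": "max_radial",
    }
    return {
        new: hyperparameters[old]
        for old, new in renames.items()
        if old in hyperparameters
    }


# Hypothetical dispatch table: adding a "quip" mode later would only
# require one more entry here and one more helper above.
_TRANSLATORS = {
    "librascal": _translate_soap_hyperparameters_from_librascal,
    "dscribe": _translate_soap_hyperparameters_from_dscribe,
}


def convert_old_hyperparameter_names(hyperparameters, mode):
    try:
        translate = _TRANSLATORS[mode]
    except KeyError:
        raise ValueError(
            f"Mode {mode} is not supported and must be either "
            "`librascal` or `dscribe`."
        ) from None
    return translate(hyperparameters)
```

With this layout the public function only validates `mode`, and each helper can be tested on its own.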

anticipated_hypers = [
"coefficient_subselection",
"compute_gradients",
"covariant_lambda",
"cutoff_function_parameters",
"cutoff_function_type",
"cutoff_smooth_width",
"expansion_by_species_method",
"gaussian_sigma_constant",
"gaussian_sigma_type",
"global_species",
"interaction_cutoff",
"inversion_symmetry",
"max_angular",
"max_radial",
"normalize",
"optimization_args",
"optimization",
"radial_basis",
"soap_type",
]

if any([key not in anticipated_hypers for key in hyperparameters]):
raise ValueError(
"I do not know what to do with the following hyperparameter entries:\n\t{}".format(
"\n\t".join(
[
key
for key in hyperparameters
if key not in anticipated_hypers
]
)
)
)

new_hypers["atomic_gaussian_width"] = hyperparameters.pop(
"gaussian_sigma_constant", None
Member

librascal has a lot of default values: gaussian_sigma_constant defaults to 0.3, cutoff_smooth_width defaults to 0.5, …

Should this code use the librascal defaults when translating to rascaline?
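
One possible way to do this, as a sketch only: fill in the librascal defaults before translating, so that keys the user omitted still reach rascaline with the values librascal would have used. `LIBRASCAL_DEFAULTS` and `with_librascal_defaults` are hypothetical names, and the table holds only the two defaults quoted above; the remaining ones would have to be checked against librascal.

```python
# Only the two defaults quoted in the comment above; any further
# librascal defaults would need to be verified against librascal itself.
LIBRASCAL_DEFAULTS = {
    "gaussian_sigma_constant": 0.3,
    "cutoff_smooth_width": 0.5,
}


def with_librascal_defaults(hyperparameters):
    """Return a copy of `hyperparameters` with missing defaults filled in."""
    merged = dict(LIBRASCAL_DEFAULTS)
    merged.update(hyperparameters)  # user-supplied values take precedence
    return merged
```

The input dictionary is copied rather than modified, unlike the `pop` calls in the current code, which empty the caller's dictionary as a side effect.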

)
new_hypers["max_angular"] = hyperparameters.pop("max_angular", None)
new_hypers["max_radial"] = hyperparameters.pop("max_radial", None)
new_hypers["cutoff"] = hyperparameters.pop("interaction_cutoff", None)

if 'radial_basis' in hyperparameters:
new_hypers["radial_basis"] = {hyperparameters.pop("radial_basis").title(): {}}
if "Gto" not in new_hypers["radial_basis"]:
warnings.warn("WARNING: rascaline currently only supports a Gto basis.")

if hyperparameters.get("cutoff_function_type", None) == "ShiftedCosine":
new_hypers["cutoff_function"] = {
hyperparameters.pop("cutoff_function_type", None): {
"width": hyperparameters.pop("cutoff_smooth_width", None)
}
}
else:
new_hypers["cutoff_function"] = {"Step": {}}
if hyperparameters.get("cutoff_function_type", None) == "RadialScaling":
params = hyperparameters.pop("cutoff_function_parameters", None)
new_hypers["radial_scaling"] = {
"Willatt2018": {
"exponent": int(params.get("exponent", None)),
"rate": params.get("rate", None),
"scale": params.get("scale", None),
}
}

deprecated_params = [
"global_species",
"expansion_by_species_method",
"soap_type",
"compute_gradients",
Member

compute_gradients could point people to the gradients argument of the compute function

]
if any([d in hyperparameters for d in deprecated_params]):
warnings.warn(
"{} are not required parameters in the rascaline software infrastructure".format(
",".join(
[f"`{d}`" for d in deprecated_params if d in hyperparameters]
)
)
)

not_supported = [
"coefficient_subselection",
Member

coefficient_subselection could point people to the selected_properties argument of the compute function

"covariant_lambda",
"gaussian_sigma_type",
Member

gaussian_sigma_type="Constant" (which is the only one implemented by librascal) is also what rascaline is doing by default

"inversion_symmetry",
"normalize",
"optimization_args",
"optimization",
Comment on lines +172 to +173
Member

Some of these could be translated to "Gto" radial basis options in rascaline

]
if any([d in hyperparameters for d in not_supported]):
warnings.warn(
"{} are not currently supported in rascaline".format(
",".join([f"`{d}`" for d in not_supported if d in hyperparameters])
)
)

return {k: v for k,v in new_hypers.items() if v is not None}
elif mode == "dscribe":
anticipated_hypers = [
"rcut",
"nmax",
"lmax",
"species",
"sigma",
"rbf",
"periodic",
"crossover",
"average",
"sparse",
"dtype",
Member

there is also weighting which can be translated to either radial_scaling (weighting/function/pow in dscribe) or central_atom_weight (weighting/w0 in dscribe)

]

if any([key not in anticipated_hypers for key in hyperparameters]):
raise ValueError(
"I do not know what to do with the following hyperparameter entries:\n\t{}".format(
"\n\t".join(
[
key
for key in hyperparameters
if key not in anticipated_hypers
]
)
)
)

new_hypers["atomic_gaussian_width"] = hyperparameters.pop("sigma", None)
new_hypers["max_angular"] = hyperparameters.pop("lmax", None)
new_hypers["max_radial"] = hyperparameters.pop("nmax", None)
new_hypers["cutoff"] = hyperparameters.pop("rcut", None)

if 'rbf' in hyperparameters:
new_hypers["radial_basis"] = {hyperparameters.pop("rbf").title(): {}}
if "Gto" not in new_hypers["radial_basis"]:
warnings.warn("WARNING: rascaline currently only supports a Gto basis.")

deprecated_params = ["average", "sparse", "dtype"]
Member

average="outer" could point to equistore.operations.mean_over_samples, and average="off" is the default

if any([d in hyperparameters for d in deprecated_params]):
warnings.warn(
"{} are not required parameters in the rascaline software infrastructure".format(
",".join(
[f"`{d}`" for d in deprecated_params if d in hyperparameters]
)
)
)

not_supported = [
"periodic",
"crossover",
Member

crossover=True is what rascaline does, crossover=False could be achieved once #134 is merged with a selected_keys parameter to the compute function

]
if any([d in hyperparameters for d in not_supported]):
warnings.warn(
"{} are not currently supported in rascaline".format(
",".join([f"`{d}`" for d in not_supported if d in hyperparameters])
)
)

return {k: v for k,v in new_hypers.items() if v is not None}
else:
raise ValueError(
f"Mode {mode} is not supported and must be either `librascal` or `dscribe`."
)
88 changes: 88 additions & 0 deletions python/tests/misc.py
@@ -1,8 +1,10 @@
# -*- coding: utf-8 -*-
import os
import unittest
import warnings

import rascaline
from rascaline.utils import convert_old_hyperparameter_names


class TestCMakePrefixPath(unittest.TestCase):
@@ -18,5 +20,91 @@ def test_cmake_files_exists(self):
)


class TestConverter(unittest.TestCase):
"""
Tests the hyperparameter conversions in `rascaline.utils`
"""

def test_mode(self):
from rascaline.utils import convert_old_hyperparameter_names

with self.assertRaises(ValueError):
convert_old_hyperparameter_names({}, mode="BadMode")
Comment on lines +31 to +32
Member

please also check the error message when checking that exceptions are raised, otherwise it is really easy to get the tests to pass with a different error than the expected one being raised.

Suggested change
with self.assertRaises(ValueError):
convert_old_hyperparameter_names({}, mode="BadMode")
with self.assertRaises(ValueError) as cm:
convert_old_hyperparameter_names({}, mode="BadMode")
self.assertEqual(str(cm.exception), "some error message")


def test_errant_params(self):
with self.assertRaises(ValueError):
convert_old_hyperparameter_names({"bad_param": 0}, mode="librascal")
with self.assertRaises(ValueError):
convert_old_hyperparameter_names({"bad_param": 0}, mode="dscribe")

def test_not_gto(self):
with warnings.catch_warnings(record=True) as w:
Member

that works, but I've been using self.assertWarns everywhere else, so it might be better to use it here as well for consistency
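
For reference, a self-contained sketch of the `self.assertWarns` form. `_warn_if_not_gto` is a hypothetical stand-in for the converter so that the snippet runs without rascaline installed, and `UserWarning` is assumed because that is the default category of `warnings.warn`.

```python
import unittest
import warnings


def _warn_if_not_gto(radial_basis):
    # Hypothetical stand-in for the radial-basis check in the converter.
    if radial_basis.title() != "Gto":
        warnings.warn("WARNING: rascaline currently only supports a Gto basis.")


class TestNotGto(unittest.TestCase):
    def test_not_gto(self):
        # assertWarns fails the test if no UserWarning is emitted,
        # and stores the caught warning on the context manager.
        with self.assertWarns(UserWarning) as cm:
            _warn_if_not_gto("NOT_GTO")

        self.assertEqual(
            str(cm.warning),
            "WARNING: rascaline currently only supports a Gto basis.",
        )
```

Unlike `warnings.catch_warnings(record=True)`, this form also fails when no warning is raised at all, instead of hitting an `IndexError` on `w[-1]`.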

convert_old_hyperparameter_names(
{"radial_basis": "NOT_GTO"}, mode="librascal"
)
self.assertEqual(
str(w[-1].message),
"WARNING: rascaline currently only supports a Gto basis.",
)
with warnings.catch_warnings(record=True) as w:
convert_old_hyperparameter_names({"rbf": "NOT_GTO"}, mode="dscribe")
self.assertEqual(
str(w[-1].message),
"WARNING: rascaline currently only supports a Gto basis.",
)

def test_param_warnings(self):
with warnings.catch_warnings(record=True) as w:
convert_old_hyperparameter_names({"global_species": [0]}, mode="librascal")
self.assertEqual(
str(w[-1].message),
"`global_species` are not required parameters in the rascaline software infrastructure",
)
with warnings.catch_warnings(record=True) as w:
convert_old_hyperparameter_names({"average": 0}, mode="dscribe")
self.assertEqual(
str(w[-1].message),
"`average` are not required parameters in the rascaline software infrastructure",
)
with warnings.catch_warnings(record=True) as w:
convert_old_hyperparameter_names({"coefficient_subselection": [0]}, mode="librascal")
self.assertEqual(
str(w[-1].message),
"`coefficient_subselection` are not currently supported in rascaline"
)
with warnings.catch_warnings(record=True) as w:
convert_old_hyperparameter_names({"periodic": 0}, mode="dscribe")
self.assertEqual(
str(w[-1].message),
"`periodic` are not currently supported in rascaline"
)

def test_radial_scaling(self):
new_hypers = convert_old_hyperparameter_names(
{
"cutoff_function_type": "RadialScaling",
"cutoff_function_parameters": {
"exponent": 3,
"rate": 1.0,
"scale": 1.5,
},
},
mode="librascal",
)
self.assertEqual(new_hypers["radial_scaling"]['Willatt2018']["exponent"], 3)
self.assertEqual(new_hypers["radial_scaling"]['Willatt2018']["scale"], 1.5)
self.assertEqual(new_hypers["radial_scaling"]['Willatt2018']["rate"], 1.0)


def test_shifted_cosine(self):
new_hypers = convert_old_hyperparameter_names(
{
"cutoff_function_type": "ShiftedCosine",
"cutoff_smooth_width": 0.5
},
mode="librascal",
)
self.assertEqual(new_hypers["cutoff_function"]['ShiftedCosine']["width"], 0.5)

if __name__ == "__main__":
unittest.main()