-
Notifications
You must be signed in to change notification settings - Fork 15
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding a conversion function for other packages #139
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -4,6 +4,8 @@ | |
from ._c_api import RASCAL_BUFFER_SIZE_ERROR | ||
from .status import RascalError | ||
|
||
import warnings | ||
|
||
|
||
def _call_with_growing_buffer(callback, initial=1024): | ||
bufflen = initial | ||
|
@@ -20,3 +22,225 @@ def _call_with_growing_buffer(callback, initial=1024): | |
else: | ||
raise | ||
return buffer.value.decode("utf8") | ||
|
||
|
||
def convert_old_hyperparameter_names(hyperparameters, mode): | ||
""" | ||
Function to convert old hyperparameter names to those | ||
used in rascaline. This function is meant to be dep- | ||
recated as rascaline becomes more mainstream, but will | ||
serve to help users convert existing workflows. | ||
|
||
Notes | ||
----- | ||
- This function does validate the values in the hyperparameter | ||
dictionary, and it is up to the user to check that they pass | ||
valid entries to `rascaline`. | ||
- Not all of these parameters are supported in rascaline, | ||
and some will raise warnings. | ||
|
||
Parameters | ||
---------- | ||
|
||
mode: string in ["librascal", "dscribe"] | ||
We anticipate future support for mode=="quip" as well | ||
|
||
hyperparameters: dictionary of hyperparameter keys and values. | ||
For mode = `librascal`, the anticipated values are: | ||
- `coefficient_subselection` | ||
- `compute_gradients` | ||
- `covariant_lambda` | ||
- `cutoff_function_parameters` | ||
- `cutoff_function_type` | ||
- `cutoff_smooth_width` | ||
- `expansion_by_species_method` | ||
- `gaussian_sigma_constant` | ||
- `gaussian_sigma_type` | ||
- `global_species` | ||
- `interaction_cutoff` | ||
- `inversion_symmetry` | ||
- `max_angular` | ||
- `max_radial` | ||
- `normalize` | ||
- `optimization_args` | ||
- `optimization` | ||
- `radial_basis` | ||
- `soap_type` | ||
For mode = `dscribe`, the anticipated values are: | ||
- `average` | ||
- `crossover` | ||
- `dtype` | ||
- `nmax` | ||
- `lmax` | ||
- `periodic` | ||
- `rbf` | ||
- `rcut` | ||
- `sigma` | ||
- `sparse` | ||
- `species` | ||
|
||
""" | ||
new_hypers = {} | ||
|
||
if mode == "librascal": | ||
anticipated_hypers = [ | ||
"coefficient_subselection", | ||
"compute_gradients", | ||
"covariant_lambda", | ||
"cutoff_function_parameters", | ||
"cutoff_function_type", | ||
"cutoff_smooth_width", | ||
"expansion_by_species_method", | ||
"gaussian_sigma_constant", | ||
"gaussian_sigma_type", | ||
"global_species", | ||
"interaction_cutoff", | ||
"inversion_symmetry", | ||
"max_angular", | ||
"max_radial", | ||
"normalize", | ||
"optimization_args", | ||
"optimization", | ||
"radial_basis", | ||
"soap_type", | ||
] | ||
|
||
if any([key not in anticipated_hypers for key in hyperparameters]): | ||
raise ValueError( | ||
"I do not know what to do with the following hyperparameter entries:\n\t".format( | ||
"\n\t".join( | ||
[ | ||
key | ||
for key in hyperparameters | ||
if key not in anticipated_hypers | ||
] | ||
) | ||
) | ||
) | ||
|
||
new_hypers["atomic_gaussian_width"] = hyperparameters.pop( | ||
"gaussian_sigma_constant", None | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. librascal has a lot of default values: Should this code use the librascal defaults when translating to rascaline? |
||
) | ||
new_hypers["max_angular"] = hyperparameters.pop("max_angular", None) | ||
new_hypers["max_radial"] = hyperparameters.pop("max_radial", None) | ||
new_hypers["cutoff"] = hyperparameters.pop("interaction_cutoff", None) | ||
|
||
if 'radial_basis' in hyperparameters: | ||
new_hypers["radial_basis"] = {hyperparameters.pop("radial_basis").title(): {}} | ||
if new_hypers["radial_basis"] != "Gto": | ||
warnings.warn("WARNING: rascaline currently only supports a Gto basis.") | ||
|
||
if hyperparameters.get("cutoff_function_type", None) == "ShiftedCosine": | ||
new_hypers["cutoff_function"] = { | ||
hyperparameters.pop("cutoff_function_type", None): { | ||
"width": hyperparameters.pop("cutoff_smooth_width", None) | ||
} | ||
} | ||
else: | ||
new_hypers["cutoff_function"] = {"Step": {}} | ||
if hyperparameters.get("cutoff_function_type", None) == "RadialScaling": | ||
params = hyperparameters.pop("cutoff_function_parameters", None) | ||
new_hypers["radial_scaling"] = { | ||
"Willatt2018": { | ||
"exponent": int(params.get("exponent", None)), | ||
"rate": params.get("rate", None), | ||
"scale": params.get("scale", None), | ||
} | ||
} | ||
|
||
deprecated_params = [ | ||
"global_species", | ||
"expansion_by_species_method", | ||
"soap_type", | ||
"compute_gradients", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
] | ||
if any([d in hyperparameters for d in deprecated_params]): | ||
warnings.warn( | ||
"{} are not required parameters in the rascaline software infrastructure".format( | ||
",".join( | ||
[f"`{d}`" for d in deprecated_params if d in hyperparameters] | ||
) | ||
) | ||
) | ||
|
||
not_supported = [ | ||
"coefficient_subselection", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
"covariant_lambda", | ||
"gaussian_sigma_type", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
"inversion_symmetry", | ||
"normalize", | ||
"optimization_args", | ||
"optimization", | ||
Comment on lines
+172
to
+173
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some of these could be translated to "Gto" radial basis options in rascaline |
||
] | ||
if any([d in hyperparameters for d in not_supported]): | ||
warnings.warn( | ||
"{} are not currently supported in rascaline".format( | ||
",".join([f"`{d}`" for d in not_supported if d in hyperparameters]) | ||
) | ||
) | ||
|
||
return {k: v for k,v in new_hypers.items() if v is not None} | ||
elif mode == "dscribe": | ||
anticipated_hypers = [ | ||
"rcut", | ||
"nmax", | ||
"lmax", | ||
"species", | ||
"sigma", | ||
"rbf", | ||
"periodic", | ||
"crossover", | ||
"average", | ||
"sparse", | ||
"dtype", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there is also |
||
] | ||
|
||
if any([key not in anticipated_hypers for key in hyperparameters]): | ||
raise ValueError( | ||
"I do not know what to do with the following hyperparameter entries:\n\t".format( | ||
"\n\t".join( | ||
[ | ||
key | ||
for key in hyperparameters | ||
if key not in anticipated_hypers | ||
] | ||
) | ||
) | ||
) | ||
|
||
new_hypers["atomic_gaussian_width"] = hyperparameters.pop("sigma", None) | ||
new_hypers["max_angular"] = hyperparameters.pop("lmax", None) | ||
new_hypers["max_radial"] = hyperparameters.pop("nmax", None) | ||
new_hypers["cutoff"] = hyperparameters.pop("rcut", None) | ||
|
||
if 'rbf' in hyperparameters: | ||
new_hypers["radial_basis"] = {hyperparameters.pop("rbf").title(): {}} | ||
if new_hypers["radial_basis"] != "Gto": | ||
warnings.warn("WARNING: rascaline currently only supports a Gto basis.") | ||
|
||
deprecated_params = ["average", "sparse", "dtype"] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
if any([d in hyperparameters for d in deprecated_params]): | ||
warnings.warn( | ||
"{} are not required parameters in the rascaline software infrastructure".format( | ||
",".join( | ||
[f"`{d}`" for d in deprecated_params if d in hyperparameters] | ||
) | ||
) | ||
) | ||
|
||
not_supported = [ | ||
"periodic", | ||
"crossover", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
] | ||
if any([d in hyperparameters for d in not_supported]): | ||
warnings.warn( | ||
"{} are not currently supported in rascaline".format( | ||
",".join([f"`{d}`" for d in not_supported if d in hyperparameters]) | ||
) | ||
) | ||
|
||
return {k: v for k,v in new_hypers.items() if v is not None} | ||
else: | ||
raise ValueError( | ||
f"Mode {mode} is not supported and must be either `librascal` or `dscribe`." | ||
) |
Original file line number | Diff line number | Diff line change | ||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
@@ -1,8 +1,10 @@ | ||||||||||||||
# -*- coding: utf-8 -*- | ||||||||||||||
import os | ||||||||||||||
import unittest | ||||||||||||||
import warnings | ||||||||||||||
|
||||||||||||||
import rascaline | ||||||||||||||
from rascaline.utils import convert_old_hyperparameter_names | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class TestCMakePrefixPath(unittest.TestCase): | ||||||||||||||
|
@@ -18,5 +20,91 @@ def test_cmake_files_exists(self): | |||||||||||||
) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
class TestConverter(unittest.TestCase): | ||||||||||||||
""" | ||||||||||||||
Tests the hyperparameter conversions in `rascaline.utils` | ||||||||||||||
""" | ||||||||||||||
|
||||||||||||||
def test_mode(self): | ||||||||||||||
from rascaline.utils import convert_old_hyperparameter_names | ||||||||||||||
|
||||||||||||||
with self.assertRaises(ValueError): | ||||||||||||||
convert_old_hyperparameter_names({}, mode="BadMode") | ||||||||||||||
Comment on lines
+31
to
+32
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. please also check the error message when checking that exceptions are raised, otherwise it is really easy to get the tests to pass with a different error than the expected one beening raised.
Suggested change
|
||||||||||||||
|
||||||||||||||
def test_errant_params(self): | ||||||||||||||
with self.assertRaises(ValueError): | ||||||||||||||
convert_old_hyperparameter_names({"bad_param": 0}, mode="librascal") | ||||||||||||||
with self.assertRaises(ValueError): | ||||||||||||||
convert_old_hyperparameter_names({"bad_param": 0}, mode="dscribe") | ||||||||||||||
|
||||||||||||||
def test_not_gto(self): | ||||||||||||||
with warnings.catch_warnings(record=True) as w: | ||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. that works, but I've been using |
||||||||||||||
convert_old_hyperparameter_names( | ||||||||||||||
{"radial_basis": "NOT_GTO"}, mode="librascal" | ||||||||||||||
) | ||||||||||||||
self.assertEquals( | ||||||||||||||
str(w[-1].message), | ||||||||||||||
"WARNING: rascaline currently only supports a Gto basis.", | ||||||||||||||
) | ||||||||||||||
with warnings.catch_warnings(record=True) as w: | ||||||||||||||
convert_old_hyperparameter_names({"rbf": "NOT_GTO"}, mode="dscribe") | ||||||||||||||
self.assertEquals( | ||||||||||||||
str(w[-1].message), | ||||||||||||||
"WARNING: rascaline currently only supports a Gto basis.", | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
def test_param_warnings(self): | ||||||||||||||
with warnings.catch_warnings(record=True) as w: | ||||||||||||||
convert_old_hyperparameter_names({"global_species": [0]}, mode="librascal") | ||||||||||||||
self.assertEquals( | ||||||||||||||
str(w[-1].message), | ||||||||||||||
"`global_species` are not required parameters in the rascaline software infrastructure", | ||||||||||||||
) | ||||||||||||||
with warnings.catch_warnings(record=True) as w: | ||||||||||||||
convert_old_hyperparameter_names({"average": 0}, mode="dscribe") | ||||||||||||||
self.assertEquals( | ||||||||||||||
str(w[-1].message), | ||||||||||||||
"`average` are not required parameters in the rascaline software infrastructure", | ||||||||||||||
) | ||||||||||||||
with warnings.catch_warnings(record=True) as w: | ||||||||||||||
convert_old_hyperparameter_names({"coefficient_subselection": [0]}, mode="librascal") | ||||||||||||||
self.assertEquals( | ||||||||||||||
str(w[-1].message), | ||||||||||||||
"`coefficient_subselection` are not currently supported in rascaline" | ||||||||||||||
) | ||||||||||||||
with warnings.catch_warnings(record=True) as w: | ||||||||||||||
convert_old_hyperparameter_names({"periodic": 0}, mode="dscribe") | ||||||||||||||
self.assertEquals( | ||||||||||||||
str(w[-1].message), | ||||||||||||||
"`periodic` are not currently supported in rascaline" | ||||||||||||||
) | ||||||||||||||
|
||||||||||||||
def test_radial_scaling(self): | ||||||||||||||
new_hypers = convert_old_hyperparameter_names( | ||||||||||||||
{ | ||||||||||||||
"cutoff_function_type": "RadialScaling", | ||||||||||||||
"cutoff_function_parameters": { | ||||||||||||||
"exponent": 3, | ||||||||||||||
"rate": 1.0, | ||||||||||||||
"scale": 1.5, | ||||||||||||||
}, | ||||||||||||||
}, | ||||||||||||||
mode="librascal", | ||||||||||||||
) | ||||||||||||||
self.assertEqual(new_hypers["radial_scaling"]['Willatt2018']["exponent"], 3) | ||||||||||||||
self.assertEqual(new_hypers["radial_scaling"]['Willatt2018']["scale"], 1.5) | ||||||||||||||
self.assertEqual(new_hypers["radial_scaling"]['Willatt2018']["rate"], 1.0) | ||||||||||||||
|
||||||||||||||
|
||||||||||||||
def test_shifted_cosine(self): | ||||||||||||||
new_hypers = convert_old_hyperparameter_names( | ||||||||||||||
{ | ||||||||||||||
"cutoff_function_type": "ShiftedCosine", | ||||||||||||||
"cutoff_smooth_width": 0.5 | ||||||||||||||
}, | ||||||||||||||
mode="librascal", | ||||||||||||||
) | ||||||||||||||
self.assertEqual(new_hypers["cutoff_function"]['ShiftedCosine']["width"], 0.5) | ||||||||||||||
|
||||||||||||||
if __name__ == "__main__": | ||||||||||||||
unittest.main() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for readability & extensibility, I would have separate functions for each mode:
_translate_soap_hyperparameters_from_librascal
,_translate_soap_hyperparameters_from_dscribe
and potentially later_translate_soap_hyperparameters_from_quip