diff --git a/CI/unit_tests/data/test_double_well_potential.py b/CI/unit_tests/data/test_double_well_potential.py index 2967f36..bb9afbf 100644 --- a/CI/unit_tests/data/test_double_well_potential.py +++ b/CI/unit_tests/data/test_double_well_potential.py @@ -9,6 +9,7 @@ Description: Test the double_well_potential module. """ import unittest + from symsuite.data.double_well_potential import DoubleWellPotential @@ -52,5 +53,5 @@ def test_double_well(self): self.assertEqual(len(self.generator.image), 500) -if __name__ == '__main__': +if __name__ == "__main__": unittest.main() diff --git a/CI/unit_tests/distance_metrics/test_angular_distance.py b/CI/unit_tests/distance_metrics/test_angular_distance.py new file mode 100644 index 0000000..0d16ae3 --- /dev/null +++ b/CI/unit_tests/distance_metrics/test_angular_distance.py @@ -0,0 +1,72 @@ +""" +ZnRND: A Zincwarecode package. +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: +Summary +------- +Test the angular distance module. +""" +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import jax.numpy as np +from numpy.testing import assert_array_almost_equal + +from symsuite.distance_metrics.angular_distance import AngularDistance + + +class TestAngularDistance: + """ + Class to test the cosine distance measure module. + """ + + def test_angular_distance(self): + """ + Test the cosine similarity measure. + + Returns + ------- + Assert the correct answer is returned for orthogonal, parallel, and + somewhere in between. 
+ """ + metric = AngularDistance() + + # Test orthogonal vectors + point_1 = np.array([[1, 0]]) + point_2 = np.array([[0, 1]]) + assert_array_almost_equal(metric(point_1, point_2), [0.5]) + + # Test parallel vectors + point_1 = np.array([[1, 0]]) + point_2 = np.array([[1, 1]]) + assert_array_almost_equal(metric(point_1, point_2), [0.25]) + + def test_multiple_distances(self): + """ + Test the cosine similarity measure. + + Returns + ------- + Assert the correct answer is returned for orthogonal, parallel, and + somewhere in between. + """ + metric = AngularDistance() + + # Test orthogonal vectors + point_1 = np.array([[1, 0], [1, 0]]) + point_2 = np.array([[0, 1], [1, 1]]) + assert_array_almost_equal(metric(point_1, point_2), [0.5, 0.25]) diff --git a/CI/unit_tests/distance_metrics/test_cosine_distance.py b/CI/unit_tests/distance_metrics/test_cosine_distance.py new file mode 100644 index 0000000..6f479a8 --- /dev/null +++ b/CI/unit_tests/distance_metrics/test_cosine_distance.py @@ -0,0 +1,75 @@ +""" +ZnRND: A Zincwarecode package. +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: +Summary +------- +Test the cosine distance module. +""" +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import jax.numpy as np +from numpy.testing import assert_array_almost_equal + +from symsuite.distance_metrics.cosine_distance import CosineDistance + + +class TestCosineDistance: + """ + Class to test the cosine distance measure module. 
+ """ + + def test_cosine_distance(self): + """ + Test the cosine similarity measure. + + Returns + ------- + Assert the correct answer is returned for orthogonal, parallel, and + somewhere in between. + """ + metric = CosineDistance() + + # Test orthogonal vectors + point_1 = np.array([[1, 0, 0, 0]]) + point_2 = np.array([[0, 1, 0, 0]]) + assert_array_almost_equal(metric(point_1, point_2), [1]) + + # Test parallel vectors + assert_array_almost_equal(metric(point_1, point_1), [0]) + + # Somewhere in between + point_1 = np.array([[1.0, 0, 0, 0]]) + point_2 = np.array([[0.5, 1.0, 0, 3.0]]) + assert_array_almost_equal(metric(point_1, point_2), [0.84382623]) + + def test_multiple_distances(self): + """ + Test the cosine similarity measure. + + Returns + ------- + Assert the correct answer is returned for orthogonal, parallel, and + somewhere in between. + """ + metric = CosineDistance() + + # Test orthogonal vectors + point_1 = np.array([[1, 0, 0, 0], [1, 0, 0, 0], [1.0, 0, 0, 0]]) + point_2 = np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0.5, 1.0, 0, 3.0]]) + assert_array_almost_equal(metric(point_1, point_2), [1, 0, 0.843826], decimal=6) diff --git a/CI/unit_tests/distance_metrics/test_hyper_sphere_distance.py b/CI/unit_tests/distance_metrics/test_hyper_sphere_distance.py new file mode 100644 index 0000000..aac3ff8 --- /dev/null +++ b/CI/unit_tests/distance_metrics/test_hyper_sphere_distance.py @@ -0,0 +1,87 @@ +""" +ZnRND: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. 
+ +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Test the hyper sphere distance module. +""" +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import jax.numpy as np +from numpy.testing import assert_array_almost_equal + +from symsuite.distance_metrics.hyper_sphere_distance import HyperSphere + + +class TestCosineDistance: + """ + Class to test the cosine distance measure module. + """ + + def test_hyper_sphere_distance(self): + """ + Test the hyper sphere distance. + + Returns + ------- + Assert the correct answer is returned for orthogonal, parallel, and + somewhere in between. + """ + metric = HyperSphere(order=2) + + # Test orthogonal vectors + point_1 = np.array([[1, 0, 0, 0]]) + point_2 = np.array([[0, 1, 0, 0]]) + assert_array_almost_equal(metric(point_1, point_2), [1.41421356]) + + # Test parallel vectors + point_1 = np.array([[1, 0, 0, 0]]) + point_2 = np.array([[1, 0, 0, 0]]) + assert_array_almost_equal(metric(point_1, point_2), [0]) + + # Somewhere in between + point_1 = np.array([[1.0, 0, 0, 0]]) + point_2 = np.array([[0.5, 1.0, 0, 3.0]]) + assert_array_almost_equal( + metric(point_1, point_2), [0.84382623 * np.sqrt(10.25)] + ) + + def test_multiple_distances(self): + """ + Test the hyper sphere distance. + + Returns + ------- + Assert the correct answer is returned for orthogonal, parallel, and + somewhere in between. 
+ """ + metric = HyperSphere(order=2) + + # Test orthogonal vectors + point_1 = np.array([[1, 0, 0, 0], [1, 0, 0, 0], [1.0, 0, 0, 0]]) + point_2 = np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0.5, 1.0, 0, 3.0]]) + assert_array_almost_equal( + metric(point_1, point_2), + [np.sqrt(2), 0, 0.84382623 * np.sqrt(10.25)], + decimal=6, + ) diff --git a/CI/unit_tests/distance_metrics/test_l_p_norm.py b/CI/unit_tests/distance_metrics/test_l_p_norm.py new file mode 100644 index 0000000..e142006 --- /dev/null +++ b/CI/unit_tests/distance_metrics/test_l_p_norm.py @@ -0,0 +1,85 @@ +""" +ZnRND: A Zincwarecode package. +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Test the l_p norm metric. +""" +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import jax.numpy as np +from numpy.testing import assert_almost_equal, assert_array_almost_equal + +from symsuite.distance_metrics.l_p_norm import LPNorm + + +class TestLPNorm: + """ + Class to test the cosine distance measure module. + """ + + def test_l_2_distance(self): + """ + Test the cosine similarity measure. + + Returns + ------- + Assert the correct answer is returned for orthogonal, parallel, and + somewhere in between. + """ + metric = LPNorm(order=2) + + # Test orthogonal vectors + point_1 = np.array([[1.0, 7.0, 0.0, 0.0]]) + point_2 = np.array([[1.0, 1.0, 0.0, 0.0]]) + + assert_array_almost_equal(metric(point_1, point_2), [6.0]) + + def test_l_3_distance(self): + """ + Test the cosine similarity measure. 
+ + Returns + ------- + Assert the correct answer is returned for orthogonal, parallel, and + somewhere in between. + """ + metric = LPNorm(order=3) + + # Test orthogonal vectors + point_1 = np.array([[1.0, 7.0, 0.0, 0.0]]) + point_2 = np.array([[1.0, 1.0, 0.0, 0.0]]) + assert_almost_equal(metric(point_1, point_2), [6.0], decimal=4) + + def test_multi_distance(self): + """ + Test the cosine similarity measure. + + Returns + ------- + Assert the correct answer is returned for orthogonal, parallel, and + somewhere in between. + """ + metric = LPNorm(order=1) + + # Test orthogonal vectors + point_1 = np.array([[1.0, 7.0, 0.0, 0.0], [4, 7, 2, 1]]) + point_2 = np.array([[1.0, 1.0, 0.0, 0.0], [6, 3, 1, 8]]) + assert_array_almost_equal(metric(point_1, point_2), [6.0, 14.0], decimal=4) diff --git a/CI/unit_tests/distance_metrics/test_mahalanobis_distance.py b/CI/unit_tests/distance_metrics/test_mahalanobis_distance.py new file mode 100644 index 0000000..93f8e72 --- /dev/null +++ b/CI/unit_tests/distance_metrics/test_mahalanobis_distance.py @@ -0,0 +1,170 @@ +""" +ZnRND: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Test the angular distance module. 
+""" +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import jax +import jax.numpy as np +import numpy as onp +import scipy.spatial.distance +from numpy.testing import assert_almost_equal, assert_array_almost_equal + +from symsuite.distance_metrics.mahalanobis_distance import MahalanobisDistance + + +class TestMahalanobisDistance: + """ + Class to test the cosine distance measure module. + """ + + @classmethod + def setup_class(cls): + """ + Prepare the test suite. + """ + cls.key = jax.random.PRNGKey(0) + + def test_mahalanobis_distance(self): + """ + Test the Mahalanobis distance on functionality by comparing results a + test Mahalanobis distance from scipy. + + Returns + ------- + Assert if the Mahalanobis distance returns true values for sample set of + random normal distributed points in two dimensions. + """ + metric = MahalanobisDistance() + + # Create sample set + point_1, point_2 = self.create_sample_set() + + # Calculate results from distance metric + metric_results = metric(np.array(point_1), np.array(point_2)) + + # Calculate test results from numpy distance metric + test_metric_results = [] + self.calculate_numpy_mahalanobis_distance(point_1, point_2, test_metric_results) + + # Assert results + assert_almost_equal(metric_results, test_metric_results, decimal=1) + + def test_identity(self): + """ + Test the identity criterion of a metric, based on a randomly produced sample + set (used to create the covariance matrix). 
+ + Returns + ------- + Asserts if the distance of the last point of point_1 and point_2 is equal to 0 + """ + # Create Sample set + point_1, point_2 = self.create_sample_set() + + # Add point of interest + point_of_interest = np.array([[7.0, 3.0]]) + point_1 = np.concatenate([np.array(point_1), point_of_interest], axis=0) + point_2 = np.concatenate([np.array(point_2), point_of_interest], axis=0) + + # Assert identity + metric = MahalanobisDistance() + assert_array_almost_equal(metric(point_1, point_2)[-1], 0) + + def test_symmetry(self): + """ + Test the symmetry criterion of a metric, based on a randomly produced sample + set (used to create the covariance matrix). + + Returns + ------- + Asserts if the distances of the last two points of point_1 and point_2 are + identical. + """ + # Create Sample set + point_1, point_2 = self.create_sample_set() + + # Add point of interest + point_1_of_interest = np.array([[-2.0, 5.0], [7.0, 3.0]]) + point_2_of_interest = np.array([[7.0, 3.0], [-2.0, 5.0]]) + point_1 = np.concatenate( + [np.array(point_1), point_1_of_interest], + axis=0, + ) + point_2 = np.concatenate( + [np.array(point_2), point_2_of_interest], + axis=0, + ) + + # Assert identity + metric = MahalanobisDistance() + assert_array_almost_equal( + metric(point_1, point_2)[-1], (metric(point_1, point_2)[-2]) + ) + + @staticmethod + def create_sample_set(): + """ + + Returns + ------- + Creates a random normal distributed sample set + """ + point_1 = np.array( + [onp.random.normal(0, 10, 100), onp.random.normal(0, 20, 100)] + ).T + point_2 = np.array( + [onp.random.normal(0, 10, 100), onp.random.normal(0, 20, 100)] + ).T + return point_1, point_2 + + @staticmethod + def calculate_numpy_mahalanobis_distance( + point_1: np.ndarray, point_2: np.ndarray, result_list: list + ): + """ + Calculates the Mahalanobis distance based on a scipy integration. + + Parameters + ---------- + point_1 : np.ndarray + Set of points in the distance calculation. 
+ point_2 : np.ndarray + Set of points in the distance calculation. + result_list : list + Results for each point are appended to this list. + + Returns + ------- + Appends all calculated distances to the result_list. + """ + inv_cov = np.linalg.inv(np.cov(point_1.T)) + for index in range(len(point_1.T[0, :])): + result_list.append( + scipy.spatial.distance.mahalanobis( + point_1[index], point_2[index], inv_cov + ) + ) diff --git a/CI/unit_tests/distance_metrics/test_order_n_difference.py b/CI/unit_tests/distance_metrics/test_order_n_difference.py new file mode 100644 index 0000000..e7340ec --- /dev/null +++ b/CI/unit_tests/distance_metrics/test_order_n_difference.py @@ -0,0 +1,85 @@ +""" +ZnRND: A Zincwarecode package. +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Test the order n norm metric. +""" +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import jax.numpy as np +from numpy.testing import assert_almost_equal, assert_array_equal + +from symsuite.distance_metrics.order_n_difference import OrderNDifference + + +class TestOrderNDifference: + """ + Class to test the cosine distance measure module. + """ + + def test_order_2_distance(self): + """ + Test the cosine similarity measure. + + Returns + ------- + Assert the correct answer is returned for orthogonal, parallel, and + somewhere in between. 
+ """ + metric = OrderNDifference(order=2, reduce_operation="sum") + + # Test orthogonal vectors + point_1 = np.array([[1.0, 7.0, 0.0, 0.0]]) + point_2 = np.array([[1.0, 1.0, 0.0, 0.0]]) + assert_array_equal(metric(point_1, point_2), [36.0]) + + def test_order_3_distance(self): + """ + Test the cosine similarity measure. + + Returns + ------- + Assert the correct answer is returned for orthogonal, parallel, and + somewhere in between. + """ + metric = OrderNDifference(order=3, reduce_operation="sum") + + # Test orthogonal vectors + point_1 = np.array([[1.0, 1.0, 0.0, 0.0]]) + point_2 = np.array([[1.0, 7.0, 0.0, 0.0]]) + + assert_almost_equal(metric(point_1, point_2), [-216.0], decimal=4) + + def test_multi_distance(self): + """ + Test the cosine similarity measure. + + Returns + ------- + Assert the correct answer is returned for orthogonal, parallel, and + somewhere in between. + """ + metric = OrderNDifference(order=3, reduce_operation="sum") + + # Test orthogonal vectors + point_1 = np.array([[1.0, 7.0, 0.0, 0.0], [4, 7, 2, 1]]) + point_2 = np.array([[1.0, 1.0, 0.0, 0.0], [6, 3, 1, 8]]) + assert_almost_equal(metric(point_1, point_2), [216.0, -286.0], decimal=4) diff --git a/CI/unit_tests/loss_functions/test_loss_functions.py b/CI/unit_tests/loss_functions/test_loss_functions.py new file mode 100644 index 0000000..1dfc70c --- /dev/null +++ b/CI/unit_tests/loss_functions/test_loss_functions.py @@ -0,0 +1,93 @@ +""" +ZnRND: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. 
+ +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Module for testing the loss functions + +Notes +----- +As the loss functions come directly from distance metrics and the distance metrics are +heavily tested, here we test all loss functions on the same set of data and ensure that +the results are as expected. +""" +import os + +os.environ["CUDA_VISIBLE_DEVICES"] = "-1" + +import jax.numpy as np +import pytest + +from symsuite.loss_functions import ( + AngleDistanceLoss, + CosineDistanceLoss, + LPNormLoss, + MeanPowerLoss, +) + + +class TestLossFunctions: + """ + Check each ZnRND loss function against a reference value on fixed data. + """ + + @classmethod + def setup_class(cls): + """ + Build the fixed prediction/target arrays shared by every test. + """ + cls.linear_predictions = np.array([[1, 1, 2], [9, 9, 9], [0, 0, 0], [9, 1, 1]]) + cls.linear_targets = np.array([[9, 9, 9], [1, 1, 2], [9, 1, 1], [0, 0, 0]]) + + cls.angular_predictions = np.array([[0, 0, 1], [1, 0, 0], [1, 1, 0], [1, 0, 1]]) + cls.angular_targets = np.array([[1, 0, 0], [0, 0, 1], [1, 0, 1], [1, 1, 0]]) + + def test_absolute_angle(self): + """ + Test the absolute angle loss; 0.417 is rounded to 3 decimals, hence abs=1e-3. + """ + loss = AngleDistanceLoss()( + self.angular_predictions / 9, self.angular_targets / 9 + ) + assert loss == pytest.approx(0.417, abs=1e-3) + + def test_cosine_distance(self): + """ + Test the cosine_distance loss; approx() absorbs floating point round-off. + """ + loss = CosineDistanceLoss()( + self.angular_predictions / 9, self.angular_targets / 9 + ) + assert loss == pytest.approx(0.75, abs=1e-5) + + def test_l_p_norm(self): + """ + Test the l_p norm loss (order 2) to a relative tolerance of 1e-4. + """ + loss = LPNormLoss(order=2)(self.linear_predictions, self.linear_targets) + assert loss == pytest.approx(11.207, 0.0001) + + def test_mean_power(self): + """ + Test the mean_power loss (order 2) on the integer-valued linear data. + """ + loss = MeanPowerLoss(order=2)(self.linear_predictions, self.linear_targets) + assert loss == pytest.approx(130.0) diff --git
a/CI/unit_tests/symmetry_groups/test_data_clustering.py b/CI/unit_tests/symmetry_groups/test_data_clustering.py index 931ef42..280bab1 100644 --- a/CI/unit_tests/symmetry_groups/test_data_clustering.py +++ b/CI/unit_tests/symmetry_groups/test_data_clustering.py @@ -2,6 +2,7 @@ Test module for the Data Clustering module. """ import unittest + import numpy as np diff --git a/docs/source/conf.py b/docs/source/conf.py index 9fe9555..e70eb26 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -3,6 +3,7 @@ """ import os import sys + import sphinx_rtd_theme sys.path.insert(0, os.path.abspath(".")) @@ -47,35 +48,28 @@ # Material theme options (see theme.conf for more information) html_theme_options = { - # Set the name of the project to appear in the navigation. - 'nav_title': 'SymDet', - + "nav_title": "SymDet", # Set you GA account ID to enable tracking - 'google_analytics_account': 'UA-XXXXX', - + "google_analytics_account": "UA-XXXXX", # Specify a base_url used to generate sitemap.xml. If not # specified, then no sitemap will be built. - 'base_url': 'https://symdet.readthedocs.io/en/latest/', - + "base_url": "https://symdet.readthedocs.io/en/latest/", # Set the color and the accent color - 'color_primary': 'blue', - 'color_accent': 'light-blue', - + "color_primary": "blue", + "color_accent": "light-blue", # Set the repo location to get a badge with stats - 'repo_url': 'https://github.com/SamTov/SymDet', - 'repo_name': 'SymDet', - + "repo_url": "https://github.com/SamTov/SymDet", + "repo_name": "SymDet", # Visible levels of the global TOC; -1 means unlimited - 'globaltoc_depth': 3, + "globaltoc_depth": 3, # If False, expand all TOC entries - 'globaltoc_collapse': False, + "globaltoc_collapse": False, # If True, show hidden TOC entries - 'globaltoc_includehidden': False, + "globaltoc_includehidden": False, } - # Add any paths that contain templates here, relative to this directory. 
templates_path = ["_templates"] diff --git a/examples/notebooks/SO_example.ipynb b/examples/notebooks/SO_example.ipynb index 8d02ffe..798b5c8 100644 --- a/examples/notebooks/SO_example.ipynb +++ b/examples/notebooks/SO_example.ipynb @@ -344,7 +344,7 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, @@ -358,7 +358,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.10" + "version": "3.8.12" } }, "nbformat": 4, diff --git a/examples/notebooks/double_well_investigation.ipynb b/examples/notebooks/double_well_investigation.ipynb index 9e07240..4bbda77 100644 --- a/examples/notebooks/double_well_investigation.ipynb +++ b/examples/notebooks/double_well_investigation.ipynb @@ -37,9 +37,18 @@ "cell_type": "code", "execution_count": 1, "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/samueltovey/miniconda3/envs/zincware/lib/python3.8/site-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.23.1\n", + " warnings.warn(f\"A NumPy version >={np_minversion} and <{np_maxversion}\"\n" + ] + } + ], "source": [ - "import symdet" + "import symsuite" ] }, { @@ -56,7 +65,7 @@ "metadata": {}, "outputs": [], "source": [ - "double_well_potential = symdet.DoubleWellPotential(a=2.4)" + "double_well_potential = symsuite.DoubleWellPotential(a=2.4)" ] }, { @@ -71,9 +80,16 @@ "execution_count": 3, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)\n" + ] + }, { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ "
" ] @@ -124,12 +140,12 @@ "name": "stderr", "output_type": "stream", "text": [ - "100%|████████████████████████████████| 11/11 [00:00<00:00, 357.68it/s]\n" + "100%|█████████████████████████████████| 11/11 [00:00<00:00, 81.40it/s]\n" ] }, { "data": { - "image/png": "\n", + "image/png": "\n", "text/plain": [ "
" ] @@ -159,7 +175,7 @@ "metadata": {}, "outputs": [], "source": [ - "model = symdet.DenseModel(n_layers=7,\n", + "model = symsuite.DenseModel(n_layers=7,\n", " units=80,\n", " epochs=10,\n", " batch_size=64,\n", @@ -341,7 +357,7 @@ } ], "source": [ - "sym_detector = symdet.GroupDetection(model, double_well_potential.clustered_data)\n", + "sym_detector = symsuite.GroupDetection(model, double_well_potential.clustered_data)\n", "point_cloud = sym_detector.run_symmetry_detection(plot=True, save=True)" ] }, @@ -378,6 +394,527 @@ "source": [ "We see that the clustering has worked relatively well. This is where some caution is advised in terms of simply parsing this data along to the generator detection. If you see that one of these groups should likely not be in the set then it should be parsed along to the generator extraction stage. Alternatively, you can use a different approach for identifying the sets in the symmetry representation. In the next update we will parse this data to the generator extraction algorithm in order to get the generators of this detected symmetry group." 
] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [], + "source": [ + "from flax import linen as nn\n", + "import jax\n", + "import jax.numpy as jnp\n", + "import numpy as np\n", + "\n", + "class ProductionModule(nn.Module):\n", + " \"\"\"\n", + " Simple CNN module.\n", + " \"\"\"\n", + "\n", + " @nn.compact\n", + " def __call__(self, x):\n", + " x = nn.Conv(features=128, kernel_size=(3, 3))(x)\n", + " x = nn.relu(x)\n", + " x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))\n", + " x = nn.Conv(features=128, kernel_size=(3, 3))(x)\n", + " x = nn.relu(x)\n", + " x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2))\n", + " x = x.reshape((x.shape[0], -1)) # flatten\n", + " x = nn.Dense(features=300)(x)\n", + " x = nn.relu(x)\n", + " x = nn.Dense(10)(x)\n", + "\n", + " return x" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "model = ProductionModule()" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "key = jax.random.PRNGKey(452)\n", + "\n", + "data = model.init(key, jnp.ones([1, 28, 28, 1]))" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "FrozenDict({\n", + " Conv_0: {\n", + " kernel: DeviceArray([[[[ 0.37528017, 0.13753012, 0.35983738, ...,\n", + " 0.31934947, -0.10718577, 0.06134491]],\n", + " \n", + " [[-0.03669485, 0.14023171, 0.17660904, ...,\n", + " 0.17953686, -0.09349817, 0.5601775 ]],\n", + " \n", + " [[ 0.00360224, -0.23205283, -0.4305723 , ...,\n", + " -0.55263746, 0.08736772, -0.33638862]]],\n", + " \n", + " \n", + " [[[ 0.3732802 , -0.3748587 , 0.18389685, ...,\n", + " -0.31333193, 0.2105282 , 0.02409167]],\n", + " \n", + " [[-0.46691698, 0.49650607, -0.5371514 , ...,\n", + " 0.01834795, 0.56708217, -0.34570318]],\n", + " \n", + " [[ 0.5792693 , -0.35017362, 0.5757499 , ...,\n", + " -0.41792697, 
-0.20277935, 0.1124992 ]]],\n", + " \n", + " \n", + " [[[-0.5241861 , -0.08293429, -0.15216689, ...,\n", + " 0.48141024, -0.00571131, -0.00457093]],\n", + " \n", + " [[ 0.15416023, 0.04572708, -0.05879792, ...,\n", + " 0.57276475, 0.09112789, 0.07591753]],\n", + " \n", + " [[-0.21100795, -0.17283523, 0.07323457, ...,\n", + " 0.02905115, -0.47458482, 0.04670273]]]], dtype=float32),\n", + " bias: DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", + " },\n", + " Conv_1: {\n", + " kernel: DeviceArray([[[[ 3.60345989e-02, -1.80624407e-02, -9.17060301e-03, ...,\n", + " -1.57420221e-03, -3.21763046e-02, 3.03349812e-02],\n", + " [-4.88650911e-02, 2.21489817e-02, 1.55314393e-02, ...,\n", + " 1.22555168e-02, 2.84199193e-02, 2.09916160e-02],\n", + " [ 5.93110994e-02, -5.48763536e-02, -3.18756956e-03, ...,\n", + " 3.08737922e-02, -8.48920736e-03, -2.10110340e-02],\n", + " ...,\n", + " [ 4.99920659e-02, -6.42840192e-02, -2.96061370e-03, ...,\n", + " 1.10166790e-02, 1.05426693e-02, 2.54060999e-02],\n", + " [ 5.33394516e-02, 4.65421192e-02, 1.49851209e-02, ...,\n", + " 5.68490475e-03, -1.30498027e-02, -1.01780500e-02],\n", + " [ 6.59535229e-02, 9.09722038e-03, -2.24497523e-02, ...,\n", + " 7.19328318e-03, 5.78204077e-03, 5.39870039e-02]],\n", + " \n", + " [[-3.70773338e-02, -3.87179144e-02, 1.63091160e-02, ...,\n", + " 2.97876936e-03, 1.07463440e-02, -1.82296222e-04],\n", + " [-2.36669872e-02, -5.58980182e-03, -2.70987861e-02, 
...,\n", + " 3.27060483e-02, 5.49956113e-02, -3.22447494e-02],\n", + " [ 3.49065661e-02, 2.02603359e-02, -4.39962931e-03, ...,\n", + " 1.14598181e-02, -1.17074130e-02, 4.88227000e-03],\n", + " ...,\n", + " [-3.37145887e-02, 4.31741215e-02, 4.60439175e-03, ...,\n", + " -1.50344986e-02, -2.81697568e-02, -2.23013163e-02],\n", + " [ 4.66451123e-02, 2.47632357e-04, 1.16665391e-02, ...,\n", + " -4.74373903e-03, 7.42866658e-03, -5.94518706e-02],\n", + " [ 1.07875392e-02, -1.24657005e-02, -3.20048593e-02, ...,\n", + " 1.44016631e-02, -4.46076095e-02, -4.64649275e-02]],\n", + " \n", + " [[-1.87598765e-02, -1.26873795e-02, 2.22337674e-02, ...,\n", + " -1.86812971e-02, 3.24476734e-02, -2.59078629e-02],\n", + " [ 2.18403228e-02, -2.43654586e-02, -5.68524711e-02, ...,\n", + " -3.14403288e-02, 1.45379817e-02, -3.71383354e-02],\n", + " [ 5.04917093e-02, 3.70514989e-02, -2.47972971e-03, ...,\n", + " 9.21094697e-03, 6.62889564e-03, 5.43021150e-02],\n", + " ...,\n", + " [ 5.53301275e-02, 8.09799600e-03, -7.37548014e-03, ...,\n", + " -3.22527029e-02, 9.47414152e-03, 1.66119877e-02],\n", + " [ 3.12854238e-02, -9.37583914e-04, -2.83447886e-03, ...,\n", + " 6.54492676e-02, -1.27467439e-02, -2.87838411e-02],\n", + " [ 4.77980748e-02, -1.19424276e-02, 1.74422693e-02, ...,\n", + " -3.70779335e-02, 2.49283463e-02, 5.49281761e-03]]],\n", + " \n", + " \n", + " [[[ 2.17519812e-02, -3.88661511e-02, 3.92616428e-02, ...,\n", + " 5.27783893e-02, 9.30878986e-03, -1.11992378e-02],\n", + " [-4.48921323e-02, 2.66084112e-02, -3.08889486e-02, ...,\n", + " -2.99759768e-03, 2.66164411e-02, -2.56100800e-02],\n", + " [ 1.48384757e-02, 2.49346737e-02, 2.10087299e-02, ...,\n", + " 1.19575774e-02, -4.37101051e-02, 1.59911637e-03],\n", + " ...,\n", + " [ 4.79628369e-02, 2.65846997e-02, -5.52016264e-03, ...,\n", + " 8.03794805e-03, 1.57203991e-02, -3.72503996e-02],\n", + " [ 3.36065628e-02, 5.82088120e-02, -4.79097627e-02, ...,\n", + " 1.47540374e-02, -4.82356735e-02, 2.53046560e-03],\n", + " [-1.85025409e-02, 
1.19747752e-02, -2.92210910e-03, ...,\n", + " -3.48893404e-02, -4.21325751e-02, 1.10742254e-02]],\n", + " \n", + " [[-2.33020280e-02, -5.33820689e-02, 6.04530687e-05, ...,\n", + " -1.82941053e-02, 2.49706823e-02, 3.89664173e-02],\n", + " [ 3.02372547e-03, 3.97237539e-02, -4.37463261e-02, ...,\n", + " -4.79917275e-03, -3.19175012e-02, -5.08953705e-02],\n", + " [-3.39236967e-02, -6.98743481e-03, 2.17721332e-02, ...,\n", + " -4.89775054e-02, -5.73655218e-02, 3.42595093e-02],\n", + " ...,\n", + " [ 3.00649554e-02, -8.40300415e-03, -7.04725645e-03, ...,\n", + " -3.24916989e-02, 1.48367360e-02, 3.87442624e-03],\n", + " [ 1.28357522e-02, 1.16875945e-02, 2.62607671e-02, ...,\n", + " 2.08278652e-02, -2.04269513e-02, 2.65005231e-02],\n", + " [ 1.62072293e-02, 1.26523693e-04, 6.51293248e-02, ...,\n", + " -3.86682414e-02, -1.51754823e-03, -1.52847609e-02]],\n", + " \n", + " [[-1.05653480e-02, -3.87853314e-03, 3.05805751e-03, ...,\n", + " 3.76726035e-03, 4.94486131e-02, -5.60158156e-02],\n", + " [ 4.87619005e-02, -4.03809361e-03, 5.23778163e-02, ...,\n", + " -2.93569081e-02, 3.27635743e-02, 2.99357139e-02],\n", + " [-2.07389537e-02, -3.72745059e-02, 1.10834790e-02, ...,\n", + " -6.21929392e-02, 2.98373811e-02, 1.99949909e-02],\n", + " ...,\n", + " [ 3.27900052e-02, 5.98487668e-02, 4.71824668e-02, ...,\n", + " 8.78540706e-03, -6.14813250e-03, -3.14074531e-02],\n", + " [-9.78266541e-03, -5.90384342e-02, -5.13021387e-02, ...,\n", + " -1.38909575e-02, -2.51286477e-02, 5.18727489e-03],\n", + " [-3.64494771e-02, -4.27394621e-02, -3.23385722e-03, ...,\n", + " -3.18818376e-03, -5.87067306e-02, -6.36329651e-02]]],\n", + " \n", + " \n", + " [[[-8.47443007e-03, -4.90162708e-02, -1.27213998e-02, ...,\n", + " -7.46757211e-03, -2.31349524e-02, -3.39074843e-02],\n", + " [ 3.42674591e-02, -4.26242724e-02, 8.94031115e-03, ...,\n", + " 5.81473745e-02, 3.11985286e-03, -7.18289101e-03],\n", + " [-4.96962480e-02, -1.75054520e-02, 1.33445645e-02, ...,\n", + " 1.52927013e-02, 7.62744015e-03, 
1.03257475e-02],\n", + " ...,\n", + " [-2.10905168e-02, 8.50216951e-04, -2.00528335e-02, ...,\n", + " 6.31971285e-03, 2.48127226e-02, -1.87783968e-02],\n", + " [-2.23999172e-02, 1.26594771e-02, 6.04727156e-02, ...,\n", + " -6.21030182e-02, 5.47166280e-02, -1.63479019e-02],\n", + " [ 9.98661667e-03, -2.33470928e-03, -2.09867079e-02, ...,\n", + " -3.46814878e-02, 2.51958854e-02, -1.26770735e-02]],\n", + " \n", + " [[-1.97642110e-02, 4.73201126e-02, 4.43987735e-02, ...,\n", + " 2.50004977e-02, -2.66237138e-03, -2.26721596e-02],\n", + " [ 6.78458950e-03, 1.61966719e-02, 2.98016015e-02, ...,\n", + " -5.48582349e-04, -4.17552330e-02, -4.90760766e-02],\n", + " [-3.85830216e-02, 6.56447783e-02, 6.06308272e-03, ...,\n", + " 7.21618021e-03, 3.87805365e-02, -4.35139164e-02],\n", + " ...,\n", + " [ 1.90136768e-02, 5.99899096e-03, 2.55533494e-02, ...,\n", + " -4.07748334e-02, -2.88647264e-02, 3.18611637e-02],\n", + " [ 2.06444710e-02, 3.52281965e-02, -3.74659821e-02, ...,\n", + " -2.47550700e-02, 7.98308384e-03, -1.17612109e-02],\n", + " [ 1.07358827e-03, -1.51244262e-02, 4.03979719e-02, ...,\n", + " 1.38022611e-02, -1.51816355e-02, 1.29964007e-02]],\n", + " \n", + " [[ 4.76968586e-02, -1.51051348e-02, 1.60713233e-02, ...,\n", + " -1.36397481e-02, -4.37372923e-02, -6.09542839e-02],\n", + " [-7.29514472e-03, 4.64752316e-03, 7.45077617e-03, ...,\n", + " -5.41699976e-02, 1.31481951e-02, -3.93584818e-02],\n", + " [ 5.68703189e-02, -1.18301352e-02, 2.59039719e-02, ...,\n", + " 4.63015288e-02, -1.92511063e-02, -2.26892084e-02],\n", + " ...,\n", + " [ 3.30666825e-02, -1.54788038e-02, 1.28783807e-02, ...,\n", + " -4.02262770e-02, -4.79671173e-03, 3.56641621e-03],\n", + " [-1.72251593e-02, -1.34579502e-02, -4.67023253e-02, ...,\n", + " 2.73214784e-02, 3.60498987e-02, -4.09855619e-02],\n", + " [ 2.44301874e-02, -3.01268771e-02, 4.78487536e-02, ...,\n", + " -3.52564305e-02, -4.78573143e-02, -1.05750179e-02]]]], dtype=float32),\n", + " bias: DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 
0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", + " },\n", + " Dense_0: {\n", + " kernel: DeviceArray([[ 0.00046263, -0.01366332, -0.01965372, ..., -0.01099672,\n", + " 0.00353724, 0.02137254],\n", + " [ 0.00497269, -0.00472533, -0.01134661, ..., -0.00669816,\n", + " -0.00842239, -0.02093654],\n", + " [ 0.03009967, 0.00874493, 0.00541858, ..., 0.02110084,\n", + " -0.00377029, -0.02680648],\n", + " ...,\n", + " [-0.00427975, -0.00622658, -0.00404923, ..., 0.00897263,\n", + " 0.00962972, -0.00889211],\n", + " [-0.01013104, -0.00755738, 0.01790517, ..., -0.0158775 ,\n", + " -0.01318363, -0.01386497],\n", + " [ 0.02921496, -0.01328129, 0.01899061, ..., 0.03121358,\n", + " 0.02289535, 0.01520465]], dtype=float32),\n", + " bias: DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 
0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,\n", + " 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", + " },\n", + " Dense_1: {\n", + " kernel: DeviceArray([[ 0.09170318, 0.03614756, 0.02535907, ..., 0.0762569 ,\n", + " -0.0044047 , 0.0330187 ],\n", + " [-0.07894796, -0.04608489, 0.00269128, ..., 0.02883765,\n", + " 0.0087204 , 0.05653023],\n", + " [ 0.03623685, 0.00995611, -0.04651761, ..., -0.05153248,\n", + " 0.02583626, -0.01592986],\n", + " ...,\n", + " [ 0.02275847, 0.08729529, -0.0204505 , ..., 0.09371196,\n", + " 0.06301026, 0.07643232],\n", + " [-0.02012173, 0.06339496, 0.01162816, ..., 0.07186693,\n", + " -0.05107251, -0.01174682],\n", + " [ 0.0267258 , 0.12019555, -0.06517248, ..., 0.00200361,\n", + " 0.02145562, -0.10603211]], dtype=float32),\n", + " bias: DeviceArray([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32),\n", + " },\n", + "})" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data[\"params\"]" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": {}, + "outputs": [], + "source": [ + "from neural_tangents import stax" + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": {}, + "outputs": [], + "source": [ + "model = stax.serial(\n", + " stax.Dense(12),\n", + " stax.Relu(),\n", + " stax.Dense(12),\n", + " stax.Relu(),\n", + " stax.Dense(1)\n", + ")\n", + "small_model = stax.serial(\n", + " stax.Dense(12),\n", 
+ " stax.Relu(),\n", + " stax.Dense(12),\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": {}, + "outputs": [], + "source": [ + "_, params = model[0](key, (9,))" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": {}, + "outputs": [], + "source": [ + "tst = np.random.uniform(size=(9,))" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DeviceArray([-0.00011978], dtype=float32)" + ] + }, + "execution_count": 63, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model[1](params, tst)" + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[(DeviceArray([[ 0.36902574, -0.60083115, 0.44724423, -1.6919093 ,\n", + " -1.4313933 , -0.9526142 , 1.5278178 , 1.2445855 ,\n", + " -0.43463358, 0.1921436 , -0.51044935, 1.3951449 ],\n", + " [ 2.6688576 , 0.27140683, 0.7197667 , 1.7743273 ,\n", + " -0.12421109, 2.8912168 , 0.8812158 , -0.59462166,\n", + " 0.75510025, -0.82257533, -0.69549227, 0.11659083],\n", + " [ 0.7083261 , 0.0195996 , -0.96690786, -1.9975785 ,\n", + " -1.2147071 , 0.49638763, -0.882764 , 0.06965447,\n", + " 0.8740023 , -1.7946595 , -1.9639919 , 0.11853004],\n", + " [-0.04729311, -0.6084024 , -1.151995 , 1.8441046 ,\n", + " -1.0241059 , 0.29526377, -1.6498058 , 0.35885495,\n", + " -1.8048742 , -0.45532495, 0.86956114, 0.7042805 ],\n", + " [ 0.7024425 , -0.46502215, -1.8246489 , 0.03024639,\n", + " -0.09241743, 0.6734406 , -1.4146967 , 0.83981234,\n", + " 1.6875414 , 1.6478448 , -0.14617752, 0.73602235],\n", + " [ 0.11008073, 0.59370816, -0.47999948, -0.91698813,\n", + " 1.4644262 , -1.0099113 , 0.42220187, 0.57421154,\n", + " -1.6171031 , -0.4017566 , 1.6489285 , 0.18311332],\n", + " [ 0.9582781 , -1.557588 , -0.19959736, -0.15300396,\n", + " 1.4889674 , -0.253901 , 0.37354448, 1.5264944 ,\n", + " -2.0760427 , -0.5360183 , 
0.61596334, 0.24212281],\n", + " [ 1.1130695 , -0.9824064 , -2.977521 , 0.74219394,\n", + " -1.1848937 , -0.12566446, -0.52775025, 0.28832507,\n", + " 1.53694 , 0.731218 , 1.3954059 , 1.9324546 ],\n", + " [-0.2378023 , 0.3914802 , -0.62872165, 0.3234975 ,\n", + " -0.61664855, -0.37052616, 0.4133723 , 0.3731773 ,\n", + " -1.6918828 , -0.40905452, 1.9508578 , -0.66678524]], dtype=float32),\n", + " None),\n", + " (),\n", + " (DeviceArray([[ 0.31568572, -0.78239375, 1.8139715 , 0.7123382 ,\n", + " 0.42159593, 1.4614258 , -1.2938766 , -0.2708335 ,\n", + " -0.8100683 , 0.42933106, 0.34290957, 2.5893972 ],\n", + " [-1.7116116 , -1.1352941 , -0.5111037 , -1.4054971 ,\n", + " -0.32070762, -1.42568 , 0.11512792, -0.1347293 ,\n", + " -0.71804935, 0.05384207, -0.78869826, -0.5870854 ],\n", + " [-0.32054782, 1.2623926 , -0.27968442, -0.591598 ,\n", + " -0.4438807 , -0.8524836 , 1.3454165 , 1.0766404 ,\n", + " 0.45279294, 0.83987784, 0.03888967, -0.19251114],\n", + " [-0.77271 , -0.04992997, -0.3330754 , 0.3508638 ,\n", + " -0.10676403, -0.58075947, 0.02378667, -1.9671077 ,\n", + " 0.0314757 , -0.82272756, -0.00710218, -1.1613281 ],\n", + " [-1.5903914 , 1.1984488 , -0.27399755, -0.06564435,\n", + " -0.62434304, 0.18636021, -0.29711506, 1.8457091 ,\n", + " -1.0972382 , 0.4340854 , 0.5363783 , -0.8186496 ],\n", + " [-0.19632275, 0.9917772 , 0.48805034, -0.83440447,\n", + " 0.18030544, 0.5788825 , 1.2664341 , 0.9014271 ,\n", + " 0.20820503, -0.60801953, -0.1659515 , -1.5893393 ],\n", + " [-1.0448416 , -1.0183587 , 1.1974787 , -1.5628817 ,\n", + " -1.1902788 , 0.02965477, -0.11841193, 0.9383633 ,\n", + " -0.6725073 , 1.2212063 , 0.8176451 , -0.25271893],\n", + " [-1.2024895 , -0.1259256 , -1.1257274 , -0.0786797 ,\n", + " -1.5528514 , -0.13004757, 0.49707723, -0.7027365 ,\n", + " 0.48414403, -2.0060594 , 0.5135008 , -0.800257 ],\n", + " [-0.21854441, -1.3389152 , -1.1716666 , 1.6518806 ,\n", + " 0.5396923 , 1.181409 , -0.1754383 , 1.391895 ,\n", + " -0.12838942, -0.28058812, 
-0.6165152 , -1.2462329 ],\n", + " [-0.9806118 , -0.0681522 , -0.43525243, 0.08544233,\n", + " 1.0447279 , 0.49811652, -1.262155 , -0.6817887 ,\n", + " -0.80958515, 0.31914094, -1.1228462 , -0.87152374],\n", + " [ 0.51965606, 0.7304284 , 0.18613489, 0.38158152,\n", + " -0.93155473, 0.37154427, 0.32466838, 0.65689003,\n", + " 0.7206714 , -0.4563752 , -0.5785838 , 1.020707 ],\n", + " [-0.28219694, -0.88026786, 0.4212645 , -1.3743248 ,\n", + " 1.375825 , 1.9244089 , -0.40021232, 0.4462603 ,\n", + " 1.5846783 , -0.4113239 , -0.7848289 , 0.5002319 ]], dtype=float32),\n", + " None)]" + ] + }, + "execution_count": 68, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "params[:-2]" + ] + }, + { + "cell_type": "code", + "execution_count": 76, + "metadata": {}, + "outputs": [], + "source": [ + "a = small_model[1](params, tst)" + ] + }, + { + "cell_type": "code", + "execution_count": 79, + "metadata": {}, + "outputs": [], + "source": [ + "b = small_model[1](params, tst)" + ] + }, + { + "cell_type": "code", + "execution_count": 80, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "DeviceArray([ True, True, True, True, True, True, True, True,\n", + " True, True, True, True], dtype=bool)" + ] + }, + "execution_count": 80, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a == b" + ] + }, + { + "cell_type": "code", + "execution_count": 84, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(.init_fun(rng, input_shape)>,\n", + " .apply_fun(params, inputs, **kwargs)>,\n", + " .kernel_fn_any(x1_or_kernel: Union[List[jax._src.numpy.ndarray.ndarray], Tuple[jax._src.numpy.ndarray.ndarray, ...], jax._src.numpy.ndarray.ndarray, List[neural_tangents._src.utils.kernel.Kernel], Tuple[neural_tangents._src.utils.kernel.Kernel, ...], neural_tangents._src.utils.kernel.Kernel], x2: Union[List[jax._src.numpy.ndarray.ndarray], Tuple[jax._src.numpy.ndarray.ndarray, ...], jax._src.numpy.ndarray.ndarray, 
NoneType] = None, get: Union[Tuple[str, ...], str, NoneType] = None, *, pattern: Union[Tuple[Union[jax._src.numpy.ndarray.ndarray, NoneType], Union[jax._src.numpy.ndarray.ndarray, NoneType]], NoneType] = None, mask_constant: Union[float, NoneType] = None, diagonal_batch: Union[bool, NoneType] = None, diagonal_spatial: Union[bool, NoneType] = None, **kwargs)>)" + ] + }, + "execution_count": 84, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "model" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -396,7 +933,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.9.5" + "version": "3.8.12" } }, "nbformat": 4, diff --git a/examples/scripts/SO2_extraction.py b/examples/scripts/SO2_extraction.py index c75cf25..3bf488a 100644 --- a/examples/scripts/SO2_extraction.py +++ b/examples/scripts/SO2_extraction.py @@ -2,10 +2,11 @@ Python module to show generator extraction of SO(2) Lie algebra generators. """ -from symsuite.test_systems.so2_data import SO2 -from symsuite.generators.generators import GeneratorExtraction import numpy as np +from symsuite.generators.generators import GeneratorExtraction +from symsuite.test_systems.so2_data import SO2 + def generator_extraction(): """ diff --git a/examples/scripts/SO3_extraction.py b/examples/scripts/SO3_extraction.py index c7cbf90..6743e69 100644 --- a/examples/scripts/SO3_extraction.py +++ b/examples/scripts/SO3_extraction.py @@ -2,8 +2,8 @@ Python module to show generator extraction of SO(3) Lie algebra generators. 
""" -from symsuite.test_systems.so3_data import SO3 from symsuite.generators.generators import GeneratorExtraction +from symsuite.test_systems.so3_data import SO3 def generator_extraction(): diff --git a/examples/scripts/double_well_investigation.py b/examples/scripts/double_well_investigation.py index 937a9ac..5414207 100644 --- a/examples/scripts/double_well_investigation.py +++ b/examples/scripts/double_well_investigation.py @@ -8,10 +8,10 @@ network and visualizing its embedding layer using TSNE. """ -from symsuite.test_systems.double_well_potential import DoubleWellPotential from symsuite.models.dense_model import DenseModel from symsuite.symmetry_groups.data_clustering import DataCluster from symsuite.symmetry_groups.group_detection import GroupDetection +from symsuite.test_systems.double_well_potential import DoubleWellPotential def main(): diff --git a/requirements.txt b/requirements.txt index 93ef214..a9c21b8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,3 +10,4 @@ IPython pandoc numpydoc pre-commit +jax diff --git a/setup.py b/setup.py index a1c9cc8..f16041d 100644 --- a/setup.py +++ b/setup.py @@ -2,9 +2,10 @@ Setup.py file for the SymDet package. """ -import setuptools from os import path +import setuptools + here = path.abspath(path.dirname(__file__)) with open(path.join(here, "requirements.txt")) as requirements_file: # Parse requirements.txt, ignoring any commented-out lines. 
diff --git a/symsuite/__init__.py b/symsuite/__init__.py index a7d749e..8465e94 100644 --- a/symsuite/__init__.py +++ b/symsuite/__init__.py @@ -2,12 +2,13 @@ __init__ file for the symsuite package """ import os + from symsuite.data.double_well_potential import DoubleWellPotential -from symsuite.ml_models.dense_model import DenseModel -from symsuite.symmetry_group_extraction.group_detection import GroupDetection -from symsuite.generator_extraction.generators import GeneratorExtraction from symsuite.data.so2_data import SO2 from symsuite.data.so3_data import SO3 +from symsuite.generator_extraction.generators import GeneratorExtraction +from symsuite.ml_models.dense_model import DenseModel +from symsuite.symmetry_group_extraction.group_detection import GroupDetection -__all__ = ['DoubleWellPotential', 'DenseModel', 'GroupDetection', 'SO2', 'SO3'] +__all__ = ["DoubleWellPotential", "DenseModel", "GroupDetection", "SO2", "SO3"] os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3" diff --git a/symsuite/accuracy_functions/__init__.py b/symsuite/accuracy_functions/__init__.py new file mode 100644 index 0000000..b15ca28 --- /dev/null +++ b/symsuite/accuracy_functions/__init__.py @@ -0,0 +1,30 @@ +""" +SymSuite +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +init function for the accuracy functions. 
+""" +from symsuite.accuracy_functions.accuracy_function import AccuracyFunction +from symsuite.accuracy_functions.label_accuracy import LabelAccuracy + +__all__ = [AccuracyFunction.__name__, LabelAccuracy.__name__] diff --git a/symsuite/accuracy_functions/accuracy_function.py b/symsuite/accuracy_functions/accuracy_function.py new file mode 100644 index 0000000..d3a3d24 --- /dev/null +++ b/symsuite/accuracy_functions/accuracy_function.py @@ -0,0 +1,53 @@ +""" +ZnRND: A zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Parent class for the accuracy functions. +""" +import jax.numpy as np + + +class AccuracyFunction: + """ + Class for computing accuracy. + """ + + def __call__(self, predictions: np.array, targets: np.array) -> float: + """ + Accuracy function call method. + + Parameters + ---------- + predictions : np.array + First set of points to be compared. + targets : np.array + Second points to compare. This will be passed through any + pre-processing of the child classes. + + Returns + ------- + accuracy : float + Accuracy of the points. + """ + raise NotImplementedError("Implemented in child class.") diff --git a/symsuite/accuracy_functions/label_accuracy.py b/symsuite/accuracy_functions/label_accuracy.py new file mode 100644 index 0000000..f8d402e --- /dev/null +++ b/symsuite/accuracy_functions/label_accuracy.py @@ -0,0 +1,65 @@ +""" +ZnRND: A zincwarecode package. 
+ +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Compute the one hot accuracy between two points. +""" +import jax.numpy as np + +from symsuite.accuracy_functions.accuracy_function import AccuracyFunction + + +class LabelAccuracy(AccuracyFunction): + """ + Compute the one hot accuracy between two points. + """ + + def __init__(self, num_class: int): + """ + Constructor for the one hot accuracy. + + Parameters + ---------- + num_class : int + Number of classes in the one hot encoding. + """ + self.num_classes = num_class + + def __call__(self, predictions: np.array, targets: np.array) -> float: + """ + Accuracy function call method. + + Parameters + ---------- + predictions : np.array + First set of points to be compared. + targets : np.array + Second points to compare. Does not require one hot encoding. + + Returns + ------- + accuracy : float + Accuracy of the points. 
+ """ + return np.mean(np.argmax(predictions, -1) == targets) diff --git a/symsuite/analysis/model_visualization.py b/symsuite/analysis/model_visualization.py index 5dd9bde..91b58b3 100644 --- a/symsuite/analysis/model_visualization.py +++ b/symsuite/analysis/model_visualization.py @@ -11,8 +11,8 @@ Visualize the NN models in different ways """ import matplotlib.pyplot as plt -from sklearn.manifold import TSNE import numpy as np +from sklearn.manifold import TSNE class Visualizer: @@ -42,11 +42,9 @@ def __init__(self, data, colour_map): self.data = data self.colour_map = colour_map - def tsne_visualization(self, - perplexity=50, - n_components=2, - plot: bool = True, - save: bool = False) -> np.ndarray: + def tsne_visualization( + self, perplexity=50, n_components=2, plot: bool = True, save: bool = False + ) -> np.ndarray: """ Display a TSNE representation of the models embedding layer @@ -70,9 +68,9 @@ def tsne_visualization(self, See the theory documentation for a full overview of these parameters, particularly in the case of the TSNE values. """ - tsne_model = TSNE(n_components=n_components, - perplexity=perplexity, - random_state=1) + tsne_model = TSNE( + n_components=n_components, perplexity=perplexity, random_state=1 + ) tsne_representation = tsne_model.fit_transform(self.data) if plot: @@ -80,7 +78,7 @@ def tsne_visualization(self, tsne_representation[:, 0], tsne_representation[:, 1], c=self.colour_map, - marker='.', + marker=".", cmap="viridis", vmax=11, vmin=-1, diff --git a/symsuite/data/__init__.py b/symsuite/data/__init__.py index 33e5766..78a8631 100644 --- a/symsuite/data/__init__.py +++ b/symsuite/data/__init__.py @@ -7,4 +7,4 @@ Copyright Contributors to the Zincware Project. Description: __init__ file for the data package. 
-""" \ No newline at end of file +""" diff --git a/symsuite/data/data_generator.py b/symsuite/data/data_generator.py index 28b4185..a6055ae 100644 --- a/symsuite/data/data_generator.py +++ b/symsuite/data/data_generator.py @@ -11,8 +11,9 @@ """ import abc from typing import Union + +import jax.numpy as jnp import numpy as np -import tensorflow as tf class DataGenerator(metaclass=abc.ABCMeta): diff --git a/symsuite/data/double_well_potential.py b/symsuite/data/double_well_potential.py index 6599e7b..7d56350 100644 --- a/symsuite/data/double_well_potential.py +++ b/symsuite/data/double_well_potential.py @@ -8,14 +8,17 @@ Description: Example data generator for the double well potential. """ -from symsuite.data.data_generator import DataGenerator -from symsuite.utils.data_clustering import range_binning from typing import Union -import numpy as np -import tensorflow as tf + +import jax +import jax.numpy as jnp import matplotlib.pyplot as plt +import numpy as np from tqdm import tqdm +from symsuite.data.data_generator import DataGenerator +from symsuite.utils.data_clustering import range_binning + class DoubleWellPotential(DataGenerator): """ @@ -43,8 +46,8 @@ class DoubleWellPotential(DataGenerator): .. math:: V = -a \cdot (x^{2} + y^{2}) + (x^{2} + y^{2})^{2} - We will require the data to be stored as x,y coordinates in order to facilitate the generator extraction in this - part of the process. + We will require the data to be stored as x,y coordinates in order to facilitate + the generator extraction in this part of the process. 
""" def __init__(self, a: float = 2.3): @@ -66,8 +69,8 @@ def _double_well(self): ------- """ - square_radii = tf.reduce_sum(tf.math.square(self.domain), 1) - self.image = -self.a * square_radii + tf.square(square_radii) + square_radii = jnp.sum(self.domain ** 2, axis=1) + self.image = -self.a * square_radii + square_radii ** 2 def _pick_points(self, n_points: int, min_val: float = 0, max_val: float = 1.6): """ @@ -84,11 +87,13 @@ def _pick_points(self, n_points: int, min_val: float = 0, max_val: float = 1.6): Returns ------- - + Updates the class attributes. """ - self.domain = tf.random.uniform(shape=(n_points, 2), - minval=min_val, - maxval=max_val) + key = jax.random.PRNGKey(0) + key, subkey = jax.random.split(key) + self.domain = jax.random.uniform( + subkey, shape=(n_points, 2), minval=min_val, maxval=max_val + ) def load_data(self, points: Union[int, np.ndarray], save: bool = False): """ @@ -97,9 +102,10 @@ def load_data(self, points: Union[int, np.ndarray], save: bool = False): Parameters ---------- points : Union[int, np.ndarray] - Points to generate, either an np.ndarray or an integer. If an integer, N points will be generated, if - an array, it will either be treated as input to a function to generate values or those indices will be - drawn from a pool. + Points to generate, either an np.ndarray or an integer. If an integer, + N points will be generated, if an array, it will either be treated as + input to a function to generate values or those indices will be drawn + from a pool. save : bool If true, save the data after generating it. @@ -114,7 +120,7 @@ def load_data(self, points: Union[int, np.ndarray], save: bool = False): # set domain and generate image data. 
else: - self.domain = tf.convert_to_tensor(points) + self.domain = jnp.array(points) self._double_well() def plot_clusters(self, save: bool = False): @@ -131,15 +137,17 @@ def plot_clusters(self, save: bool = False): """ self.plot_data(show=False) - for i, item in tqdm(enumerate(self.clustered_data), ncols=70, total=len(self.clustered_data)): - r = tf.norm(self.clustered_data[item]['domain'], axis=1) - v = self.clustered_data[item]['image'] - plt.plot(r, v, '.', label=f"Class {i}", markersize=15) + for i, item in tqdm( + enumerate(self.clustered_data), ncols=70, total=len(self.clustered_data) + ): + r = jnp.linalg.norm(self.clustered_data[item]["domain"], axis=1) + v = self.clustered_data[item]["image"] + plt.plot(r, v, ".", label=f"Class {i}", markersize=15) - plt.legend(loc='center left', bbox_to_anchor=(1, 0.5)) + plt.legend(loc="center left", bbox_to_anchor=(1, 0.5)) if save: - plt.savefig('Clusters.png', dpi=600) + plt.savefig("Clusters.png", dpi=600) plt.show() @@ -160,18 +168,20 @@ def plot_data(self, save: bool = False, show: bool = True): if self.domain is None: self._pick_points(1000, min_val=0, max_val=1.2) self._double_well() - plt.plot(tf.norm(self.domain, axis=1), self.image, '.') - plt.xlabel('r') - plt.ylabel('V') + plt.plot(jnp.linalg.norm(self.domain, axis=1), self.image, ".") + plt.xlabel("r") + plt.ylabel("V") plt.xlim(-0.03, 1.7) plt.ylim(-1.5, 1.3) plt.grid() if save: - plt.savefig(f'Double_Well_{len(self.domain)}.svg', dpi=600, format='dpi') + plt.savefig(f"Double_Well_{len(self.domain)}.svg", dpi=600, format="dpi") if show: plt.show() - def build_clusters(self, value_range: list = None, bin_operation: list = None, representatives=1000): + def build_clusters( + self, value_range: list = None, bin_operation: list = None, representatives=1000 + ): """ Split the raw function data into classes. 
@@ -191,7 +201,8 @@ def build_clusters(self, value_range: list = None, bin_operation: list = None, r Notes ----- - In the double well potential we can simply use the range_binning clustering algorithm. + In the double well potential we can simply use the range_binning clustering + algorithm. """ # Replace None type parameters. if bin_operation is None: @@ -207,8 +218,10 @@ def build_clusters(self, value_range: list = None, bin_operation: list = None, r print("Loading additional data.") self.load_data(n_classes * representatives * 1000) - self.clustered_data = range_binning(image=self.image, - domain=self.domain, - value_range=value_range, - bin_operation=bin_operation, - representatives=representatives) + self.clustered_data = range_binning( + image=self.image, + domain=self.domain, + value_range=value_range, + bin_operation=bin_operation, + representatives=representatives, + ) diff --git a/symsuite/data/so2_data.py b/symsuite/data/so2_data.py index a5afd09..37d673e 100644 --- a/symsuite/data/so2_data.py +++ b/symsuite/data/so2_data.py @@ -8,10 +8,12 @@ Description: Module for the computation of so2 data """ -from symsuite.data.data_generator import DataGenerator from typing import Union -import numpy as np + import matplotlib.pyplot as plt +import numpy as np + +from symsuite.data.data_generator import DataGenerator class SO2(DataGenerator): @@ -74,9 +76,9 @@ def _circle(self, points: int): """ if self.noise: - self.radial_values = np.random.uniform(self.radius - self.variance, - self.radius + self.variance, - points) + self.radial_values = np.random.uniform( + self.radius - self.variance, self.radius + self.variance, points + ) else: self.radial_values = self.radius @@ -110,7 +112,9 @@ def load_data(self, points: Union[int, np.ndarray], save: bool = False): # set domain and generate image data. 
else: - raise ValueError(f"Type {type(points)} is not valid for this data generator, try an integer") + raise ValueError( + f"Type {type(points)} is not valid for this data generator, try an integer" + ) def plot_data(self, save: bool = False, show: bool = True): """ diff --git a/symsuite/data/so3_data.py b/symsuite/data/so3_data.py index 916e9a3..6670f23 100644 --- a/symsuite/data/so3_data.py +++ b/symsuite/data/so3_data.py @@ -8,10 +8,12 @@ Description: Module for the computation of so3 data """ -from symsuite.data.data_generator import DataGenerator from typing import Union -import numpy as np + import matplotlib.pyplot as plt +import numpy as np + +from symsuite.data.data_generator import DataGenerator class SO3(DataGenerator): @@ -74,9 +76,9 @@ def _sphere(self, points: int): """ if self.noise: - self.radial_values = np.random.uniform(self.radius - self.variance, - self.radius + self.variance, - points) + self.radial_values = np.random.uniform( + self.radius - self.variance, self.radius + self.variance, points + ) else: self.radial_values = self.radius @@ -112,7 +114,9 @@ def load_data(self, points: Union[int, np.ndarray], save: bool = False): # set domain and generate image data. 
else: - raise ValueError(f"Type {type(points)} is not valid for this data generator, try an integer") + raise ValueError( + f"Type {type(points)} is not valid for this data generator, try an integer" + ) def plot_data(self, save: bool = False, show: bool = True): """ @@ -133,15 +137,15 @@ def plot_data(self, save: bool = False, show: bool = True): fig = plt.figure() ax = fig.add_subplot(111, projection="3d") - ax.scatter(self.domain[:, 0], - self.domain[:, 1], - self.domain[:, 2], - marker=".", - color="k") + ax.scatter( + self.domain[:, 0], + self.domain[:, 1], + self.domain[:, 2], + marker=".", + color="k", + ) if save: - plt.savefig(f"SO(2)_{len(self.domain)}.svg", - dpi=800, - format="svg") + plt.savefig(f"SO(2)_{len(self.domain)}.svg", dpi=800, format="svg") plt.show() def build_clusters(self, **kwargs): diff --git a/symsuite/distance_metrics/__init__.py b/symsuite/distance_metrics/__init__.py new file mode 100644 index 0000000..7ac5f7a --- /dev/null +++ b/symsuite/distance_metrics/__init__.py @@ -0,0 +1,39 @@ +""" +ZnRND: A Zincwarecode package. +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. 
+Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: + +Summary +------- +distance metric module +""" +from symsuite.distance_metrics.angular_distance import AngularDistance +from symsuite.distance_metrics.cosine_distance import CosineDistance +from symsuite.distance_metrics.distance_metric import DistanceMetric +from symsuite.distance_metrics.hyper_sphere_distance import HyperSphere +from symsuite.distance_metrics.l_p_norm import LPNorm +from symsuite.distance_metrics.mahalanobis_distance import MahalanobisDistance +from symsuite.distance_metrics.order_n_difference import OrderNDifference + +__all__ = [ + DistanceMetric.__name__, + CosineDistance.__name__, + AngularDistance.__name__, + LPNorm.__name__, + OrderNDifference.__name__, + MahalanobisDistance.__name__, + HyperSphere.__name__, +] diff --git a/symsuite/distance_metrics/angular_distance.py b/symsuite/distance_metrics/angular_distance.py new file mode 100644 index 0000000..3817c16 --- /dev/null +++ b/symsuite/distance_metrics/angular_distance.py @@ -0,0 +1,78 @@ +""" +ZnRND: A Zincwarecode package. +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Compute the angular distance between two points normalized by the point density in the +circle. 
+""" +import jax.numpy as np + +from symsuite.distance_metrics.distance_metric import DistanceMetric + + +class AngularDistance(DistanceMetric): + """ + Class for the angular distance metric. + """ + + def __init__(self, points: int = None): + """ + Constructor for the angular distance metric. + + Parameters + ---------- + points : int + Number of points in the circle. If None, normalization by pi is used. + """ + if points is None: + self.normalization = np.pi + elif type(points) is int and points > 0: + self.normalization = points / np.pi + else: + raise ValueError("Invalid points input.") + + def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): + """ + Call the distance metric. + + Distance between points in the point_1 tensor will be computed between those in + the point_2 tensor element-wise. Therefore, we will have: + + point_1[i] - point_2[i] for all i. + + Parameters + ---------- + point_1 : np.ndarray (n_points, point_dimension) + First set of points in the comparison. + point_2 : np.ndarray (n_points, point_dimension) + Second set of points in the comparison. + kwargs + Miscellaneous keyword arguments for the specific metric. + + Returns + ------- + d(point_1, point_2) : tf.tensor, shape=(n_points, 1) + Array of distances for each point. + """ + numerator = np.einsum("ij, ij -> i", point_1, point_2) + denominator = np.sqrt( + np.einsum("ij, ij -> i", point_1, point_1) + * np.einsum("ij, ij -> i", point_2, point_2) + ) + return np.arccos(abs(np.divide(numerator, denominator))) / self.normalization diff --git a/symsuite/distance_metrics/cosine_distance.py b/symsuite/distance_metrics/cosine_distance.py new file mode 100644 index 0000000..89d0f57 --- /dev/null +++ b/symsuite/distance_metrics/cosine_distance.py @@ -0,0 +1,65 @@ +""" +ZnRND: A Zincwarecode package. 
+License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: +Summary +------- +Module for the ZnTrack cosine distance. +""" +import jax.numpy as np + +from symsuite.distance_metrics.distance_metric import DistanceMetric + + +class CosineDistance(DistanceMetric): + """ + Class for the cosine distance metric. + + Notes + ----- + This is not a real distance metric. + """ + + def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): + """ + Call the distance metric. + + Distance between points in the point_1 tensor will be computed between those in + the point_2 tensor element-wise. Therefore, we will have: + + point_1[i] - point_2[i] for all i. + + Parameters + ---------- + point_1 : np.ndarray (n_points, point_dimension) + First set of points in the comparison. + point_2 : np.ndarray (n_points, point_dimension) + Second set of points in the comparison. + kwargs + Miscellaneous keyword arguments for the specific metric. + + Returns + ------- + d(point_1, point_2) : tf.tensor : shape=(n_points, 1) + Array of distances for each point. 
+ """ + numerator = np.einsum("ij, ij -> i", point_1, point_2) + denominator = np.sqrt( + np.einsum("ij, ij -> i", point_1, point_1) + * np.einsum("ij, ij -> i", point_2, point_2) + ) + + return 1 - abs(np.divide(numerator, denominator)) diff --git a/symsuite/distance_metrics/distance_metric.py b/symsuite/distance_metrics/distance_metric.py new file mode 100644 index 0000000..49761a6 --- /dev/null +++ b/symsuite/distance_metrics/distance_metric.py @@ -0,0 +1,54 @@ +""" +ZnRND: A Zincwarecode package. +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Module for the parent class of a ZnRND distance metric. +""" +import jax.numpy as np + + +class DistanceMetric: + """ + Parent class for a ZnRND distance metric. + """ + + def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): + """ + Call the distance metric. + + Distance between points in the point_1 tensor will be computed between those in + the point_2 tensor element-wise. Therefore, we will have: + + point_1[i] - point_2[i] for all i. + + Parameters + ---------- + point_1 : np.ndarray (n_points, point_dimension) + First set of points in the comparison. + point_2 : np.ndarray (n_points, point_dimension) + Second set of points in the comparison. + kwargs + Miscellaneous keyword arguments for the specific metric. + + Returns + ------- + d(point_1, point_2) : tf.tensor : shape=(n_points, 1) + Array of distances for each point. 
+ """ + raise NotImplementedError("Implemented in child class.") diff --git a/symsuite/distance_metrics/hyper_sphere_distance.py b/symsuite/distance_metrics/hyper_sphere_distance.py new file mode 100644 index 0000000..2a1aa31 --- /dev/null +++ b/symsuite/distance_metrics/hyper_sphere_distance.py @@ -0,0 +1,71 @@ +""" +ZnRND: A Zincwarecode package. +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Module for a distance that combines the properties of cosine and lp-norm distance. +""" +import jax.numpy as np + +from symsuite.distance_metrics.cosine_distance import CosineDistance +from symsuite.distance_metrics.distance_metric import DistanceMetric +from symsuite.distance_metrics.l_p_norm import LPNorm + + +class HyperSphere(DistanceMetric): + """ + Compute the product of the L_p norm and the cosine distance between vectors. + """ + + def __init__(self, order: float): + """ + Constructor for the HyperSphere class. + + Parameters + ---------- + order : float + Order of the L_p norm factor. + """ + self.order = order + + def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): + """ + Call the distance metric. + + Distance between points in the point_1 tensor will be computed between those in + the point_2 tensor element-wise. Therefore, we will have: + + point_1[i] - point_2[i] for all i. + + Parameters + ---------- + point_1 : np.ndarray (n_points, point_dimension) + First set of points in the comparison. + point_2 : np.ndarray (n_points, point_dimension) + Second set of points in the comparison.
+ kwargs + Miscellaneous keyword arguments for the specific metric. + + Returns + ------- + d(point_1, point_2) : tf.tensor : shape=(n_points, 1) + Array of distances for each point. + """ + return LPNorm(order=self.order)(point_1, point_2) * CosineDistance()( + point_1, point_2 + ) diff --git a/symsuite/distance_metrics/l_p_norm.py b/symsuite/distance_metrics/l_p_norm.py new file mode 100644 index 0000000..9925644 --- /dev/null +++ b/symsuite/distance_metrics/l_p_norm.py @@ -0,0 +1,71 @@ +""" +ZnRND: A Zincwarecode package. +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Module for the L_p norm class. + + r = p1 - p2 + + d = (|r[0]|^p + |r[1]|^p + ... + |r[n]|^p)^(1/p) +""" +import jax.numpy as np + +from symsuite.distance_metrics.distance_metric import DistanceMetric + + +class LPNorm(DistanceMetric): + """ + Compute the L_p norm between vectors. + """ + + def __init__(self, order: float): + """ + Constructor for the LPNorm class. + + Parameters + ---------- + order : float + order of the space + """ + self.order = order + + def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): + """ + Call the distance metric. + + Distance between points in the point_1 tensor will be computed between those in + the point_2 tensor element-wise. Therefore, we will have: + + point_1[i] - point_2[i] for all i. + + Parameters + ---------- + point_1 : np.ndarray (n_points, point_dimension) + First set of points in the comparison. 
+ point_2 : np.ndarray (n_points, point_dimension) + Second set of points in the comparison. + kwargs + Miscellaneous keyword arguments for the specific metric. + + Returns + ------- + d(point_1, point_2) : np.ndarray : shape=(n_points, 1) + Array of distances for each point. + """ + return np.linalg.norm(point_1 - point_2, axis=1, ord=self.order) diff --git a/symsuite/distance_metrics/mahalanobis_distance.py b/symsuite/distance_metrics/mahalanobis_distance.py new file mode 100644 index 0000000..b1ef019 --- /dev/null +++ b/symsuite/distance_metrics/mahalanobis_distance.py @@ -0,0 +1,67 @@ +""" +ZnRND: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Module for the Mahalanobis distance. +""" +import jax.numpy as np +import scipy.spatial.distance + +from symsuite.distance_metrics.distance_metric import DistanceMetric + + +class MahalanobisDistance(DistanceMetric): + """ + Compute the mahalanobis distance between points. + """ + + def __call__(self, point_1: np.array, point_2: np.array, **kwargs) -> np.array: + """ + Call the distance metric. + + Mahalanobis Distance between points in the point_1 tensor will be computed + between those in the point_2 tensor element-wise. Therefore, we will have: + + point_1[i] - point_2[i] for all i. + + Parameters + ---------- + point_1 : np.ndarray (n_points, point_dimension) + First set of points in the comparison. + point_2 : np.ndarray (n_points, point_dimension) + Second set of points in the comparison.
+ kwargs + Miscellaneous keyword arguments for the specific metric. + Returns + ------- + d(point_1, point_2) : np.ndarray : shape=(n_points,) + Array of distances for each point. + """ + inverted_covariance = np.linalg.inv(np.cov(point_1.T)) + distances = [] + for i in range(len(point_1.T[0, :])): + distance = scipy.spatial.distance.mahalanobis( + point_1[i], point_2[i], inverted_covariance + ) + distances.append(distance) + + return np.array(distances) diff --git a/symsuite/distance_metrics/order_n_difference.py b/symsuite/distance_metrics/order_n_difference.py new file mode 100644 index 0000000..c36e96a --- /dev/null +++ b/symsuite/distance_metrics/order_n_difference.py @@ -0,0 +1,79 @@ +""" +ZnRND: A Zincwarecode package. +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Raise a difference to a power of order n. + +e.g. (a - b)^n +""" +import jax.numpy as np + +from symsuite.distance_metrics.distance_metric import DistanceMetric + + +class OrderNDifference(DistanceMetric): + """ + Compute the order n difference between points. + """ + + def __init__(self, order: float = 2, reduce_operation: str = "mean"): + """ + Constructor for the order n distance. + + Parameters + ---------- + order : float (default=2) + Order to which the difference should be raised. + reduce_operation : str (default = "mean") + How to reduce the order N difference, either a sum or a mean.
+ """ + self.order = order + self.reduce_operation = reduce_operation + + def __call__(self, point_1: np.ndarray, point_2: np.ndarray, **kwargs): + """ + Call the distance metric. + + Distance between points in the point_1 tensor will be computed between those in + the point_2 tensor element-wise. Therefore, we will have: + + point_1[i] - point_2[i] for all i. + + Parameters + ---------- + point_1 : np.ndarray (n_points, point_dimension) + First set of points in the comparison. + point_2 : np.ndarray (n_points, point_dimension) + Second set of points in the comparison. + kwargs + Miscellaneous keyword arguments for the specific metric. + + Returns + ------- + d(point_1, point_2) : np.ndarray : shape=(n_points, 1) + Array of distances for each point. + """ + diff = point_1 - point_2 + + if self.reduce_operation == "mean": + return np.mean(np.power(diff, self.order), axis=1) + elif self.reduce_operation == "sum": + return np.sum(np.power(diff, self.order), axis=1) + else: + raise ValueError(f"Invalid reduction operation: {self.reduce_operation}") diff --git a/symsuite/generator_extraction/generators.py b/symsuite/generator_extraction/generators.py index 60bb02f..239f6cb 100644 --- a/symsuite/generator_extraction/generators.py +++ b/symsuite/generator_extraction/generators.py @@ -10,15 +10,16 @@ ========== Python module to extract generators from data. 
""" -import tensorflow as tf -import numpy as np -from tqdm import tqdm import random -from sklearn.linear_model import LinearRegression -from sklearn.decomposition import PCA -import matplotlib.pyplot as plt from typing import Tuple +import jax.numpy as jnp +import matplotlib.pyplot as plt +import numpy as np +from sklearn.decomposition import PCA +from sklearn.linear_model import LinearRegression +from tqdm import tqdm + class GeneratorExtraction: """ @@ -52,7 +53,7 @@ class GeneratorExtraction: def __init__( self, - point_cloud: tf.Tensor, + point_cloud: jnp.array, delta: float = 0.5, epsilon: float = 0.3, candidate_runs: int = 10, @@ -77,8 +78,8 @@ def __init__( self.epsilon = epsilon self.candidate_runs = candidate_runs - self.basis: tf.Tensor - self.hyperplane_set: tf.Tensor + self.basis: jnp.ndarray + self.hyperplane_set: jnp.ndarray self.point_pairs: list self.dimension = self._get_dimension() @@ -134,7 +135,7 @@ def _generate_basis_set(self, gs_precision: int): for item in reduced_candidates: basis.append(self._perform_gs(item, basis)) - self.basis = tf.convert_to_tensor(basis) # set the class attribute + self.basis = jnp.array(basis) # set the class attribute self._gs_check(gs_precision) def _gs_check(self, gs_precision: int): @@ -153,11 +154,14 @@ def _gs_check(self, gs_precision: int): Will throw an exception if the assert fails. """ for basis in self.basis: - np.testing.assert_almost_equal(np.linalg.norm(basis), 1) # check the normalization. + # check the normalization. + np.testing.assert_almost_equal(np.linalg.norm(basis), 1) for test in self.basis: if all(test == basis): continue - np.testing.assert_almost_equal(np.dot(basis, test), 0, decimal=gs_precision) # check the orthogonality. + np.testing.assert_almost_equal( + np.dot(basis, test), 0, decimal=gs_precision + ) # check the orthogonality. 
def _perform_gs(self, vector: list, basis_set: list) -> np.ndarray: """ @@ -178,12 +182,11 @@ def _perform_gs(self, vector: list, basis_set: list) -> np.ndarray: for basis_item in basis_set: basis_vector -= self._projection_operator(basis_item, basis_vector) - return basis_vector / np.linalg.norm(basis_vector) + return jnp.array(basis_vector) / np.linalg.norm(basis_vector) def _eliminate_closest_vector( - self, - reference_vectors: list, - test_vectors: list) -> np.ndarray: + self, reference_vectors: list, test_vectors: list + ) -> np.ndarray: """ Remove the closest vectors in the theoretical basis set @@ -258,7 +261,7 @@ def _construct_hyperplane_set(self): if all(truth_table): self.hyperplane_set.append(point) - self.hyperplane_set = tf.convert_to_tensor(self.hyperplane_set) + self.hyperplane_set = jnp.array(self.hyperplane_set) def _identify_point_pairs(self): """ @@ -336,28 +339,31 @@ def _full_regression(self): def _simple_regression(self): """ - In the case where additional constraints are not needed, we simply perform regression on the problem to - extract generator candidates. + In the case where additional constraints are not needed, we simply perform + regression on the problem to extract generator candidates. Returns ------- Updates the class state. 
""" - Y = [] - X = [] + y_data = [] + x_data = [] for pair in self.point_pairs: points = [self.hyperplane_set[pair[0]], self.hyperplane_set[pair[1]]] sigma = self._compute_sigma(points) - Y.append( + y_data.append( ((points[0] - points[1]) * np.linalg.norm(points[0])) / (sigma * np.linalg.norm(points[1] - points[0])) ) - X.append(points[0]) + x_data.append(points[0]) generator = [] for i in range(self.dimension): generator = np.concatenate( - (generator, LinearRegression().fit(X, np.array(Y)[:, i]).coef_) + ( + generator, + LinearRegression().fit(x_data, np.array(y_data)[:, i]).coef_, + ) ) self.generator_candidates.append(generator) @@ -381,7 +387,7 @@ def _compute_sigma(self, pair) -> int: ) def _extract_generators( - self, pca_components: object, factor: object = True + self, pca_components: object, factor: object = True ) -> tuple: """ Perform PCA on candidates and extract true generators. @@ -407,9 +413,12 @@ def _extract_generators( pca = PCA(n_components=pca_components) pca.fit(self.generator_candidates) if factor: - return np.sqrt(self.dimension) * pca.components_, pca.explained_variance_ratio_ + return ( + np.sqrt(self.dimension) * pca.components_, + pca.explained_variance_ratio_, + ) else: - return (pca.components_, pca.explained_variance_ratio_) + return pca.components_, pca.explained_variance_ratio_ def _plot_results(self, std_values: list, save: bool = False): """ @@ -438,12 +447,13 @@ def _plot_results(self, std_values: list, save: bool = False): plt.show() def perform_generator_extraction( - self, - pca_components: int = 4, - plot: bool = False, - save: bool = False, - factor: bool = True, - gs_precision: int = 5) -> Tuple: + self, + pca_components: int = 4, + plot: bool = False, + save: bool = False, + factor: bool = True, + gs_precision: int = 5, + ) -> Tuple: """ Collect all methods and perform the generator extraction. @@ -468,9 +478,9 @@ def perform_generator_extraction( explained variance list. 
""" - for _ in tqdm(range(self.candidate_runs), - ncols=100, - desc="Producing generator candidates"): + for _ in tqdm( + range(self.candidate_runs), ncols=100, desc="Producing generator candidates" + ): try: self._remove_redundancy() self._generate_basis_set(gs_precision) @@ -480,8 +490,9 @@ def perform_generator_extraction( except ValueError: continue - generators, std_array = self._extract_generators(pca_components=pca_components, - factor=factor) + generators, std_array = self._extract_generators( + pca_components=pca_components, factor=factor + ) for i, item in enumerate(generators): print(f"Principle Component {i + 1}: Explained Variance: {std_array[i]}") print(item.reshape((self.dimension, self.dimension))) diff --git a/symsuite/loss_functions/__init__.py b/symsuite/loss_functions/__init__.py new file mode 100644 index 0000000..2460dca --- /dev/null +++ b/symsuite/loss_functions/__init__.py @@ -0,0 +1,44 @@ +""" +Symsuite + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Package containing custom loss functions. 
+""" +from symsuite.loss_functions.absolute_angle_difference import AngleDistanceLoss +from symsuite.loss_functions.cosine_distance import CosineDistanceLoss +from symsuite.loss_functions.cross_entropy_loss import CrossEntropyLoss +from symsuite.loss_functions.l_p_norm import LPNormLoss +from symsuite.loss_functions.mahalanobis import MahalanobisLoss +from symsuite.loss_functions.mean_power_error import MeanPowerLoss +from symsuite.loss_functions.loss import Loss + +__all__ = [ + AngleDistanceLoss.__name__, + CosineDistanceLoss.__name__, + LPNormLoss.__name__, + MahalanobisLoss.__name__, + MeanPowerLoss.__name__, + Loss.__name__, + CrossEntropyLoss.__name__, +] diff --git a/symsuite/loss_functions/absolute_angle_difference.py b/symsuite/loss_functions/absolute_angle_difference.py new file mode 100644 index 0000000..31d6df5 --- /dev/null +++ b/symsuite/loss_functions/absolute_angle_difference.py @@ -0,0 +1,37 @@ +""" +ZnRND: A Zincwarecode package. +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: + +Summary +------- +ZnRND absolute angle difference loss function. +""" +from symsuite.distance_metrics.angular_distance import AngularDistance +from symsuite.loss_functions.loss import Loss + + +class AngleDistanceLoss(Loss): + """ + Class for the angular distance loss + """ + + def __init__(self): + """ + Constructor for the angular distance loss class.
+ """ + super(AngleDistanceLoss, self).__init__() + self.metric = AngularDistance() diff --git a/symsuite/loss_functions/cosine_distance.py b/symsuite/loss_functions/cosine_distance.py new file mode 100644 index 0000000..73e157c --- /dev/null +++ b/symsuite/loss_functions/cosine_distance.py @@ -0,0 +1,37 @@ +""" +ZnRND: A Zincwarecode package. +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: + +Summary +------- +ZnRND cosine distance loss function. +""" +from symsuite.distance_metrics.cosine_distance import CosineDistance +from symsuite.loss_functions.loss import Loss + + +class CosineDistanceLoss(Loss): + """ + Class for the cosine distance loss + """ + + def __init__(self): + """ + Constructor for the cosine distance loss class. + """ + super(CosineDistanceLoss, self).__init__() + self.metric = CosineDistance() diff --git a/symsuite/loss_functions/cross_entropy_loss.py b/symsuite/loss_functions/cross_entropy_loss.py new file mode 100644 index 0000000..a21bccf --- /dev/null +++ b/symsuite/loss_functions/cross_entropy_loss.py @@ -0,0 +1,88 @@ +""" +ZnRND: A zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html + +SPDX-License-Identifier: EPL-2.0 + +Copyright Contributors to the zincwarecode Project.
+ +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Implement a cross entropy loss function. +""" +import jax +import optax + +from symsuite.loss_functions.loss import Loss + + +class CrossEntropyDistance: + """ + Class for the cross entropy distance + """ + + def __init__(self, classes: int, apply_softmax: bool = False): + """ + Constructor for the distance + + Parameters + ---------- + classes : int + Number of classes in the one-hot encoding. + apply_softmax : bool (default = False) + If true, softmax is applied to the prediction before computing the loss. + """ + self.classes = classes + self.apply_softmax = apply_softmax + + def __call__(self, prediction, target): + """ + + Parameters + ---------- + prediction (batch_size, n_classes) + target + + Returns + ------- + + """ + if self.apply_softmax: + prediction = jax.nn.softmax(prediction) + one_hot_labels = jax.nn.one_hot(target, num_classes=self.classes) + return optax.softmax_cross_entropy(logits=prediction, labels=one_hot_labels) + + +class CrossEntropyLoss(Loss): + """ + Class for the cross entropy loss + """ + + def __init__(self, classes: int = 10, apply_softmax: bool = False): + """ + Constructor for the cross entropy loss class. + + Parameters + ---------- + classes : int (default=10) + Number of classes in the loss. + apply_softmax : bool (default = False) + If true, softmax is applied to the prediction before computing the loss. + """ + super(CrossEntropyLoss, self).__init__() + self.metric = CrossEntropyDistance(classes=classes, apply_softmax=apply_softmax) diff --git a/symsuite/loss_functions/l_p_norm.py b/symsuite/loss_functions/l_p_norm.py new file mode 100644 index 0000000..39616bd --- /dev/null +++ b/symsuite/loss_functions/l_p_norm.py @@ -0,0 +1,42 @@ +""" +ZnRND: A Zincwarecode package.
+License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: + +Summary +------- +ZnRND L^{p} norm loss function. +""" +from symsuite.distance_metrics.l_p_norm import LPNorm +from symsuite.loss_functions.loss import Loss + + +class LPNormLoss(Loss): + """ + Class for the L_p norm loss + """ + + def __init__(self, order: float): + """ + Constructor for the L_p norm loss class. + + Parameters + ---------- + order : float + Order p of the L_p norm. + """ + super(LPNormLoss, self).__init__() + self.metric = LPNorm(order=order) diff --git a/symsuite/loss_functions/loss.py b/symsuite/loss_functions/loss.py new file mode 100644 index 0000000..d19c474 --- /dev/null +++ b/symsuite/loss_functions/loss.py @@ -0,0 +1,64 @@ +""" +ZnRND: A Zincwarecode package. + +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. + +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ + +Citation +-------- +If you use this module please cite us with: + +Summary +------- +Module for the simple loss parent class.
+""" +from abc import ABC + +import jax.numpy as np + +from symsuite.distance_metrics.distance_metric import DistanceMetric + + +class Loss(ABC): + """ + Class for the simple loss. + + Attributes + ---------- + metric : DistanceMetric + """ + + def __init__(self): + """ + Constructor for the simple loss parent class. + """ + super().__init__() + self.metric: DistanceMetric = None + + def __call__(self, point_1: np.array, point_2: np.array) -> float: + """ + Mean over the tensor of the respective distance measurement + Parameters + ---------- + point_1 : np.array + first neural network representation of the considered points + point_2 : np.array + second neural network representation of the considered points + + Returns + ------- + loss : float + mean loss over all points based on the distance metric + """ + return np.mean(self.metric(point_1, point_2), axis=0) diff --git a/symsuite/loss_functions/mahalanobis.py b/symsuite/loss_functions/mahalanobis.py new file mode 100644 index 0000000..d16e4eb --- /dev/null +++ b/symsuite/loss_functions/mahalanobis.py @@ -0,0 +1,37 @@ +""" +ZnRND: A Zincwarecode package. +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: + +Summary +------- +ZnRND Mahalanobis distance loss function.
+""" +import symsuite.distance_metrics.mahalanobis_distance as mahalanobis +from symsuite.loss_functions.loss import Loss + + +class MahalanobisLoss(Loss): + """ + Class for the Mahalanobis loss + """ + + def __init__(self): + """ + Constructor for the Mahalanobis loss class. + """ + super(MahalanobisLoss, self).__init__() + self.metric = mahalanobis.MahalanobisDistance() diff --git a/symsuite/loss_functions/mean_power_error.py b/symsuite/loss_functions/mean_power_error.py new file mode 100644 index 0000000..4700182 --- /dev/null +++ b/symsuite/loss_functions/mean_power_error.py @@ -0,0 +1,42 @@ +""" +ZnRND: A Zincwarecode package. +License +------- +This program and the accompanying materials are made available under the terms +of the Eclipse Public License v2.0 which accompanies this distribution, and is +available at https://www.eclipse.org/legal/epl-v20.html +SPDX-License-Identifier: EPL-2.0 +Copyright Contributors to the Zincwarecode Project. +Contact Information +------------------- +email: zincwarecode@gmail.com +github: https://github.com/zincware +web: https://zincwarecode.com/ +Citation +-------- +If you use this module please cite us with: + +Summary +------- +ZnRND mean power error loss function. +""" +from symsuite.distance_metrics.order_n_difference import OrderNDifference +from symsuite.loss_functions.loss import Loss + + +class MeanPowerLoss(Loss): + """ + Class for the mean power loss + """ + + def __init__(self, order: float): + """ + Constructor for the mean power loss class. + + Parameters + ---------- + order : float + Order to which the difference should be raised.
+ """ + super(MeanPowerLoss, self).__init__() + self.metric = OrderNDifference(order=order) diff --git a/symsuite/ml_models/dense_model.py b/symsuite/ml_models/dense_model.py index deba5f0..1e33e21 100644 --- a/symsuite/ml_models/dense_model.py +++ b/symsuite/ml_models/dense_model.py @@ -10,8 +10,8 @@ =========== Dense neural network model """ -import tensorflow as tf import numpy as np +import tensorflow as tf from tensorflow.keras import regularizers from tensorflow.keras.layers import Input @@ -23,8 +23,9 @@ class DenseModel: Attributes ---------- data_dict : dict - Dictionary of data where the key is the class name and the values are the coordinates belongin to that - class. This is fundamentall a classification problem! + Dictionary of data where the key is the class name and the values are the + coordinates belonging to that class. This is fundamental to a classification + problem. n_layers : int Number of hidden layers to use. units : int @@ -63,7 +64,8 @@ def __init__( lr: float = 1e-4, batch_size: int = 100, terminate_patience: int = 10, - lr_patience: int = 5): + lr_patience: int = 5, + ): """ Constructor fpr the Dense model class. 
@@ -121,7 +123,9 @@ def add_data(self, data: dict): """ self.data_dict = data - self.input_shape = len(self.data_dict[list(self.data_dict.keys())[0]]['domain'][0]) + self.input_shape = len( + self.data_dict[list(self.data_dict.keys())[0]]["domain"][0] + ) def _shuffle_and_split_data(self): """ @@ -135,12 +139,16 @@ def _shuffle_and_split_data(self): for key in self.data_dict: labels = tf.repeat( tf.convert_to_tensor(np.array(key), dtype=tf.float32), - len(self.data_dict[key]['domain']), + len(self.data_dict[key]["domain"]), ) - stacked_data = tf.concat([tf.cast(self.data_dict[key]['domain'], dtype=tf.float32), - tf.transpose([labels])], - axis=1) + stacked_data = tf.concat( + [ + tf.cast(self.data_dict[key]["domain"], dtype=tf.float32), + tf.transpose([labels]), + ], + axis=1, + ) data_volume = len(stacked_data) train, test, validate = tf.split( @@ -155,18 +163,15 @@ def _shuffle_and_split_data(self): if self.train_ds is None: self.train_ds = train else: - self.train_ds = tf.concat([self.train_ds, train], - axis=0) + self.train_ds = tf.concat([self.train_ds, train], axis=0) if self.test_ds is None: self.test_ds = test else: - self.test_ds = tf.concat([self.test_ds, test], - axis=0) + self.test_ds = tf.concat([self.test_ds, test], axis=0) if self.val_ds is None: self.val_ds = validate else: - self.val_ds = tf.concat([self.val_ds, validate], - axis=0) + self.val_ds = tf.concat([self.val_ds, validate], axis=0) self.train_ds = tf.random.shuffle(self.train_ds) self.test_ds = tf.random.shuffle(self.test_ds) @@ -268,12 +273,12 @@ def train_model(self): # Train the model for i in range(1, 6): self.model.fit( - x=self.train_ds[:, 0:self.input_shape], + x=self.train_ds[:, 0 : self.input_shape], y=tf.keras.utils.to_categorical(self.train_ds[:, -1]), batch_size=self.batch_size, shuffle=True, validation_data=( - self.test_ds[:, 0:self.input_shape], + self.test_ds[:, 0 : self.input_shape], tf.keras.utils.to_categorical(self.test_ds[:, -1]), ), verbose=1, @@ -290,7 +295,8 @@ def 
class Model:
    """
    Parent class for ZnRND Models.

    Declares the interface shared by the concrete models (initialisation,
    training, NTK computation) and forwards calls to the wrapped ``model``
    callable. The three interface methods only raise ``NotImplementedError``
    here and are expected to be overridden.

    Attributes
    ----------
    model : Callable
            A callable class or function that takes a feature
            vector and returns something from it. Typically this is a
            neural network layer stack.
    """

    model: Callable

    def init_model(
        self,
        # Quoted so the annotation is not evaluated when the class is defined.
        init_rng: "Union[Any, PRNGKeyArray]" = None,
        kernel_init: Callable = None,
        bias_init: Callable = None,
    ):
        """
        Initialize a model.

        Parameters
        ----------
        init_rng : Union[Any, PRNGKeyArray]
                Initial rng for train state that is immediately deleted.
        kernel_init : Callable
                Define the kernel initialization.
        bias_init : Callable
                Define the bias initialization.

        Raises
        ------
        NotImplementedError
                Always; the method must be provided by a child class.
        """
        raise NotImplementedError("Implemented in child class.")

    def train_model(
        self,
        train_ds: dict,
        test_ds: dict,
        epochs: int = 10,
        batch_size: int = 1,
        disable_loading_bar: bool = False,
    ):
        """
        Train the model on data.

        Parameters
        ----------
        train_ds : dict
                Train dataset with inputs and targets.
        test_ds : dict
                Test dataset with inputs and targets.
        epochs : int
                Number of epochs to train over.
        batch_size : int
                Size of the batch to use in training.
        disable_loading_bar : bool
                Disable the output visualization of the loading bar.

        Raises
        ------
        NotImplementedError
                Always; the method must be provided by a child class.
        """
        raise NotImplementedError("Implemented in child class.")

    def compute_ntk(
        self,
        x_i: jnp.ndarray,
        x_j: jnp.ndarray = None,
        normalize: bool = True,
        infinite: bool = False,
    ):
        """
        Compute the NTK matrix for the model.

        Parameters
        ----------
        x_i : jnp.ndarray
                Dataset for which to compute the NTK matrix.
        x_j : jnp.ndarray (optional)
                Dataset for which to compute the NTK matrix.
        normalize : bool (default = True)
                If true, divide each row by its max value.
        infinite : bool (default = False)
                If true, compute the infinite width limit as well.

        Returns
        -------
        NTK : dict
                The NTK matrix for both the empirical and infinite width computation.

        Raises
        ------
        NotImplementedError
                Always; the method must be provided by a child class.
        """
        raise NotImplementedError("Implemented in child class")

    def __call__(self, feature_vector: jnp.ndarray):
        """
        Call the network.

        Parameters
        ----------
        feature_vector : jnp.ndarray
                Feature vector on which to apply operation.

        Returns
        -------
        output of the model.
        """
        # The result must be returned explicitly; without the return every
        # call evaluated the model and then handed back None.
        return self.model(feature_vector)
class NTModel(Model):
    """
    Class for a neural tangents model.

    The model is built from a neural-tangents ``stax`` triple
    ``(init_fn, apply_fn, kernel_fn)``; its parameters and optimizer are held in
    a flax ``TrainState`` stored on ``model_state``.
    """

    def __init__(
        self,
        loss_fn: Loss,
        optimizer: Callable,
        input_shape: tuple,
        nt_module: serial = None,
        data_pool: jnp.ndarray = None,
        accuracy_fn: AccuracyFunction = None,
        batch_size: int = 10,
    ):
        """
        Constructor for a neural tangents model.

        Parameters
        ----------
        loss_fn : Loss
                A function to use in the loss computation.
        optimizer : Callable
                optimizer to use in the training. OpTax is used by default and
                cross-compatibility is not assured.
        input_shape : tuple
                Shape of the NN input.
        nt_module : serial
                NT stax module for training, indexed as
                ``(init_fn, apply_fn, kernel_fn)``. Required despite the
                ``None`` default.
        data_pool : jnp.ndarray
                Data pool from which TTV is built.
        accuracy_fn : AccuracyFunction
                Accuracy function to use for accuracy computation.
        batch_size : int (default=10)
                Batch size to use in the NTK computation.

        Raises
        ------
        TypeError
                If no ``nt_module`` is given.
        """
        if nt_module is None:
            # Indexing None raised an opaque "not subscriptable" TypeError;
            # keep the exception type but say what is actually missing.
            raise TypeError(
                "nt_module is required: pass the (init_fn, apply_fn, kernel_fn) "
                "triple returned by a neural_tangents stax module."
            )

        self.rng = jax.random.PRNGKey(onp.random.randint(0, 500))
        self.init_fn = nt_module[0]
        self.apply_fn = jax.jit(nt_module[1])
        self.kernel_fn = nt.batch(nt_module[2], batch_size=batch_size)
        self.empirical_ntk = nt.batch(
            nt.empirical_ntk_fn(self.apply_fn), batch_size=batch_size
        )
        self.empirical_ntk_jit = jax.jit(self.empirical_ntk)
        self.loss_fn = loss_fn
        self.accuracy_fn = accuracy_fn
        self.optimizer = optimizer
        self.input_shape = input_shape

        self.data_pool = data_pool

        # initialize the model state
        self.model_state = None
        self.init_model()

    def init_model(
        self,
        init_rng: Union[Any, PRNGKeyArray] = None,
        kernel_init: Callable = None,
        bias_init: Callable = None,
    ):
        """
        Initialize a model.

        If no rng key is given, the key will be produced randomly.

        Parameters
        ----------
        init_rng : Union[Any, PRNGKeyArray]
                Initial rng for train state that is immediately deleted.
        kernel_init : Callable
                Define the kernel initialization.
        bias_init : Callable
                Define the bias initialization.

        Raises
        ------
        NotImplementedError
                If a custom kernel or bias initialization is requested.
        """
        if kernel_init:
            raise NotImplementedError(
                "Currently, there is no option customize the weight initialization. "
            )
        if bias_init:
            raise NotImplementedError(
                "Currently, there is no option customize the bias initialization. "
            )
        if init_rng is None:
            init_rng = jax.random.PRNGKey(onp.random.randint(0, 1000000))
        self.model_state = self._create_train_state(init_rng)

    def compute_ntk(
        self,
        x_i: jnp.ndarray,
        x_j: jnp.ndarray = None,
        normalize: bool = True,
        infinite: bool = False,
    ):
        """
        Compute the NTK matrix for the model.

        Parameters
        ----------
        x_i : jnp.ndarray
                Dataset for which to compute the NTK matrix.
        x_j : jnp.ndarray (optional)
                Second dataset for the NTK computation. Defaults to ``x_i``.
        normalize : bool (default = True)
                If true, pass each computed matrix through
                ``normalize_covariance_matrix``.
        infinite : bool (default = False)
                If true, compute the infinite width limit as well.

        Returns
        -------
        NTK : dict
                ``{"empirical": ..., "infinite": ...}``. The ``"infinite"`` entry
                is ``None`` unless ``infinite`` is true.
        """
        if x_j is None:
            x_j = x_i
        empirical_ntk = self.empirical_ntk_jit(x_i, x_j, self.model_state.params)
        infinite_ntk = self.kernel_fn(x_i, x_j, "ntk") if infinite else None

        if normalize:
            empirical_ntk = normalize_covariance_matrix(empirical_ntk)
            if infinite:
                infinite_ntk = normalize_covariance_matrix(infinite_ntk)

        return {"empirical": empirical_ntk, "infinite": infinite_ntk}

    def _create_train_state(self, init_rng: Union[Any, PRNGKeyArray]):
        """
        Create a training state of the model.

        Parameters
        ----------
        init_rng : Union[Any, PRNGKeyArray]
                Initial rng for train state that is immediately deleted.

        Returns
        -------
        initial state of model to then be trained.

        Notes
        -----
        TODO: Make the TrainState class passable by the user as it can track custom
              model properties.
        """
        _, params = self.init_fn(init_rng, self.input_shape)

        return train_state.TrainState.create(
            apply_fn=self.apply_fn, params=params, tx=self.optimizer
        )

    def _compute_metrics(
        self,
        predictions: jnp.ndarray,
        targets: jnp.ndarray,
    ):
        """
        Compute the current metrics of the training.

        Parameters
        ----------
        predictions : jnp.ndarray
                Predictions made by the network.
        targets : jnp.ndarray
                Targets from the training data.

        Returns
        -------
        metrics : dict
                A dict of current training metrics. Always holds ``"loss"``;
                ``"accuracy"`` is present only when an accuracy function is set.
        """
        metrics = {"loss": self.loss_fn(predictions, targets)}
        if self.accuracy_fn is not None:
            metrics["accuracy"] = self.accuracy_fn(predictions, targets)

        return metrics

    def _train_step(self, state: train_state.TrainState, batch: dict):
        """
        Train a single step.

        Parameters
        ----------
        state : TrainState
                Current state of the neural network.
        batch : dict
                Batch of data to train on, read through the ``"inputs"`` and
                ``"targets"`` keys.

        Returns
        -------
        state : TrainState
                Updated state of the neural network.
        metrics : dict
                Metrics for the current model.
        """

        def loss_fn(params):
            """
            helper loss computation
            """
            inner_predictions = self.apply_fn(params, batch["inputs"])
            loss = self.loss_fn(inner_predictions, batch["targets"])
            return loss, inner_predictions

        grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

        (_, predictions), grads = grad_fn(state.params)

        # apply_gradients returns a new state, so it has to be rebound.
        state = state.apply_gradients(grads=grads)
        metrics = self._compute_metrics(
            predictions=predictions, targets=batch["targets"]
        )

        return state, metrics

    def _evaluate_step(self, params: dict, batch: dict):
        """
        Evaluate the model on test data.

        Parameters
        ----------
        params : dict
                Parameters of the model.
        batch : dict
                Batch of data to test on.

        Returns
        -------
        metrics : dict
                Metrics dict computed on test data.
        """
        predictions = self.apply_fn(params, batch["inputs"])

        return self._compute_metrics(predictions, batch["targets"])

    def _train_epoch(
        self, state: train_state.TrainState, train_ds: dict, batch_size: int
    ):
        """
        Train for a single epoch.

        Performs the following steps:

        * Shuffles the data
        * Runs an optimization step on each batch
        * Computes the metrics for the batch
        * Return an updated state and metrics dictionary.

        Parameters
        ----------
        state : TrainState
                Current state of the model.
        train_ds : dict
                Dataset on which to train.
        batch_size : int
                Size of each batch. Clamped to the dataset size so that at least
                one step is always taken.

        Returns
        -------
        state : TrainState
                State of the model after the epoch.
        metrics : dict
                Per-key mean of the batch metrics.

        Raises
        ------
        ValueError
                If the training dataset is empty.
        """
        train_ds_size = len(train_ds["inputs"])
        if train_ds_size == 0:
            raise ValueError("Cannot train on an empty dataset.")

        # A batch larger than the dataset used to give zero steps and an
        # IndexError when averaging the (empty) metrics.
        batch_size = min(batch_size, train_ds_size)
        steps_per_epoch = train_ds_size // batch_size

        # Split the key so each epoch draws a different permutation; reusing
        # self.rng unchanged produced the same ordering every epoch.
        self.rng, shuffle_rng = jax.random.split(self.rng)
        permutations = jax.random.permutation(shuffle_rng, train_ds_size)
        # Drop the remainder so every batch has exactly batch_size entries.
        permutations = permutations[: steps_per_epoch * batch_size]
        permutations = permutations.reshape((steps_per_epoch, batch_size))

        # Step over the batches.
        batch_metrics = []
        for permutation in permutations:
            batch = {k: v[permutation, ...] for k, v in train_ds.items()}
            state, metrics = self._train_step(state, batch)
            batch_metrics.append(metrics)

        # Get the metrics off device for printing.
        batch_metrics_np = jax.device_get(batch_metrics)
        epoch_metrics_np = {
            k: onp.mean([metrics[k] for metrics in batch_metrics_np])
            for k in batch_metrics_np[0]
        }

        return state, epoch_metrics_np

    def _evaluate_model(self, params: dict, test_ds: dict) -> dict:
        """
        Evaluate the model.

        Parameters
        ----------
        params : dict
                Current parameters of the model.
        test_ds : dict
                Dataset on which to evaluate.

        Returns
        -------
        metrics : dict
                Metrics of the model with every leaf converted through ``.item()``.
        """
        metrics = self._evaluate_step(params, test_ds)
        metrics = jax.device_get(metrics)
        # jax.tree_map is a deprecated alias; tree_util.tree_map is the stable name.
        summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)

        return summary

    def validate_model(self, dataset: dict, loss_fn: Loss):
        """
        Validate the model on some external data.

        Parameters
        ----------
        dataset : dict
                Dataset on which to validate the model.
                {"inputs": np.ndarray, "targets": np.ndarray}.
        loss_fn : Loss
                Loss function to use in the computation.

        Returns
        -------
        metrics : dict
                Metrics computed in the validation. {"loss": ..., "accuracy": ...}.
                Note, for ease of large scale experiments we always return both
                keywords; "accuracy" is None when no accuracy function is set.
        """
        predictions = self.apply_fn(self.model_state.params, dataset["inputs"])

        loss = loss_fn(predictions, dataset["targets"])

        if self.accuracy_fn is not None:
            accuracy = self.accuracy_fn(predictions, dataset["targets"])
        else:
            accuracy = None

        return {"loss": loss, "accuracy": accuracy}

    def train_model(
        self,
        train_ds: dict,
        test_ds: dict,
        epochs: int = 50,
        batch_size: int = 1,
        disable_loading_bar: bool = False,
    ):
        """
        Train the model.

        Parameters
        ----------
        train_ds : dict
                Train dataset with inputs and targets.
        test_ds : dict
                Test dataset with inputs and targets.
        epochs : int
                Number of epochs to train over.
        batch_size : int
                Size of the batch to use in training.
        disable_loading_bar : bool
                Disable the output visualization of the loading bar.

        Returns
        -------
        history : dict
                Per-epoch lists under "test_losses", "test_accuracy",
                "train_losses" and "train_accuracy". The accuracy lists stay
                empty when no accuracy function is set.
        """
        if self.model_state is None:
            self.init_model()

        state = self.model_state

        loading_bar = trange(
            1, epochs + 1, ncols=100, unit="batch", disable=disable_loading_bar
        )
        test_losses = []
        test_accuracy = []
        train_losses = []
        train_accuracy = []
        for i in loading_bar:
            loading_bar.set_description(f"Epoch: {i}")

            state, train_metrics = self._train_epoch(
                state, train_ds, batch_size=batch_size
            )
            metrics = self._evaluate_model(state.params, test_ds)

            # Collect everything first: each set_postfix call replaces the
            # previous postfix, so two calls would hide the test loss.
            postfix = {"test_loss": metrics["loss"]}
            if self.accuracy_fn is not None:
                postfix["accuracy"] = metrics["accuracy"]
                test_accuracy.append(metrics["accuracy"])
                train_accuracy.append(train_metrics["accuracy"])
            loading_bar.set_postfix(**postfix)

            test_losses.append(metrics["loss"])
            train_losses.append(train_metrics["loss"])

        # Update the final model state.
        self.model_state = state

        return {
            "test_losses": test_losses,
            "test_accuracy": test_accuracy,
            "train_losses": train_losses,
            "train_accuracy": train_accuracy,
        }

    def __call__(self, feature_vector: jnp.ndarray):
        """
        Apply the network with its current parameters.

        Parameters
        ----------
        feature_vector : jnp.ndarray
                Feature vector on which to apply operation.

        Returns
        -------
        output of the model.
        """
        state = self.model_state

        return self.apply_fn(state.params, feature_vector)
- """ - self.model.train_model() - if self.representation_set == 'train': - val_data = self.model.train_ds - predictions = self.model.model.predict(val_data[:, 0:self.model.input_shape]) - elif self.representation_set == 'test:': - val_data = self.model.test_ds - predictions = self.model.model.predict(val_data[:, 0:self.model.input_shape]) - else: - val_data = self.model.val_ds - predictions = self.model.model.predict(val_data[:, 0:self.model.input_shape]) - - return val_data, predictions - - def _run_visualization(self): - """ - Perform a visualization on the TSNE data. - - Returns - ------- - + model_representation : np.ndarray + Model representation on which to perform the symmetry connection + analysis. + data_classes : list + List of the data classes for better visualization """ - pass + self.model_representations = model_representation + self.data_classes = data_classes @staticmethod def _cluster_detection(function_data: np.ndarray, data: np.ndarray): @@ -103,50 +74,23 @@ def _cluster_detection(function_data: np.ndarray, data: np.ndarray): e.g. 
{1: [radial values], 2: [radial_values], ...} """ net_array = np.concatenate((data, function_data), 1) - sorted_data = tf.gather(net_array, tf.argsort(net_array[:, -1])).numpy() + sorted_data = jnp.take(net_array, jnp.argsort(net_array[:, -1])) class_array = np.unique(function_data[:, -1]) point_cloud = {} # loop over the class array for i, item in enumerate(class_array): - start = np.searchsorted(sorted_data[:, -1], item, side='left') - stop = np.searchsorted(sorted_data[:, -1], item, side='right') - 1 + start = np.searchsorted(sorted_data[:, -1], item, side="left") + stop = np.searchsorted(sorted_data[:, -1], item, side="right") - 1 com = compute_com(sorted_data[start:stop, 0:2]) rg = compute_radius_of_gyration(sorted_data[start:stop, 0:2], com) - #print(f"Class: {item}, COM: {com}, Rg: {rg}") + # print(f"Class: {item}, COM: {com}, Rg: {rg}") if rg > 1000: point_cloud[item] = sorted_data[start:stop, 2:-1] return point_cloud - @staticmethod - def _filter_data(predictions: tf.Tensor, targets: tf.Tensor): - """ - Check which data points are predicted well and include them in the data. - - Parameters - ---------- - targets : tf.Tensor - Target values on which predictions were made. - predictions : tf.Tensor - Network predictions. - - Returns - ------- - - """ - accepted_candidates = np.zeros(len(predictions)) - target_values = tf.keras.utils.to_categorical(targets[:, -1]) - counter = 0 - for i, item in enumerate(predictions): - if np.linalg.norm(predictions[i] - target_values[i]) <= 2e-1: - accepted_candidates[counter] = i - counter += 1 - accepted_candidates = tf.convert_to_tensor(accepted_candidates[0:counter], dtype=tf.int32) - - return tf.gather(targets, accepted_candidates) - def run_symmetry_detection(self, plot: bool = True, save: bool = False): """ Run the symmetry detection routine. 
@@ -163,7 +107,9 @@ def run_symmetry_detection(self, plot: bool = True, save: bool = False): """ validation_data, predictions = self._get_model_predictions() accepted_data = self._filter_data(predictions, validation_data) - representation = self.model.get_embedding_layer_representation(accepted_data) # get the embedding layer + representation = self.model.get_embedding_layer_representation( + accepted_data + ) # get the embedding layer visualizer = Visualizer(representation, accepted_data[:, -1]) data = visualizer.tsne_visualization(plot=plot, save=save) diff --git a/symsuite/utils/data_clustering.py b/symsuite/utils/data_clustering.py index 84265cf..5dd77a7 100644 --- a/symsuite/utils/data_clustering.py +++ b/symsuite/utils/data_clustering.py @@ -8,13 +8,13 @@ Description: Methods to help with clustering data. """ -import tensorflow as tf -import numpy as np from typing import Tuple -import sys + +import jax.numpy as jnp +import numpy as np -def _build_condlist(data: np.array, bin_values: dict) -> Tuple: +def _build_condition_list(data: np.array, bin_values: dict) -> Tuple: """ Build the condition list for the piecewise implementation. @@ -37,16 +37,14 @@ def _build_condlist(data: np.array, bin_values: dict) -> Tuple: classes = [] for key in bin_values: conditions.append( - np.logical_and( - data >= (bin_values[key][0]), data <= (bin_values[key][1]) - ) + np.logical_and(data >= (bin_values[key][0]), data <= (bin_values[key][1])) ) classes.append(key) return conditions, classes -def _function_to_bins(function_values: tf.Tensor, bin_values: dict) -> tf.Tensor: +def _function_to_bins(function_values: jnp.ndarray, bin_values: dict) -> jnp.ndarray: """ Sort function values into bins. @@ -59,29 +57,49 @@ def _function_to_bins(function_values: tf.Tensor, bin_values: dict) -> tf.Tensor Returns ------- - conditions : tf.Tensor + conditions : jnp.ndarrau Conditions from the cond list build. 
""" - conditions, functions = _build_condlist(function_values, bin_values) + conditions, functions = _build_condition_list(function_values, bin_values) + + return jnp.array(conditions) + + +def to_categorical(data: jnp.ndarray): + """ + Implementation of the keras.to_categorical function + + Parameters + ---------- + data : jnp.ndarray (n_points,) + + Returns + ------- + categorical_data : jnp.ndarray (n_points, n_classes) + Data converted into categorical format. + """ + order = int(max(data) + 1) + classes = jnp.eye(order) - return tf.convert_to_tensor(conditions) + return jnp.take(classes, data, axis=0) def range_binning( - image: tf.Tensor, - domain: tf.Tensor, - value_range: list, - bin_operation: list, - representatives: int = 100) -> dict: + image: jnp.ndarray, + domain: jnp.ndarray, + value_range: list, + bin_operation: list, + representatives: int = 100, +) -> dict: """ A method to apply simple range binning to some data. Parameters ---------- - image : tf.Tensor + image : jnp.ndarrau data to cluster. - domain : tf.Tensor + domain : jnp.ndarrau data pool to return clustered. representatives : int Number of class representatives to have for each bin. @@ -110,7 +128,7 @@ def range_binning( bin_masks = _function_to_bins(image, bin_values) # Check that there is enough data in each class. - bin_count = tf.reduce_sum(tf.cast(bin_masks, tf.int8), 1) + bin_count = jnp.sum(bin_masks.astype(int), axis=1) if any(bin_count) < representatives: print("WARNING: Not enough data! 
Some classes will be under-represented.") @@ -119,10 +137,10 @@ def range_binning( clustered_data = {} for i in range(len(class_keys)): clustered_data[class_keys[i]] = {} - filtered_domain = tf.boolean_mask(domain, bin_masks[i]) - filtered_image = tf.boolean_mask(image, bin_masks[i]) - clustered_data[class_keys[i]]['domain'] = filtered_domain[0:representatives] - clustered_data[class_keys[i]]['image'] = filtered_image[0:representatives] + filtered_domain = domain[bin_masks[i]] + filtered_image = image[bin_masks[i]] + clustered_data[class_keys[i]]["domain"] = filtered_domain[0:representatives] + clustered_data[class_keys[i]]["image"] = filtered_image[0:representatives] return clustered_data @@ -140,7 +158,7 @@ def compute_com(data: np.ndarray): ------- """ - return tf.reduce_mean(data, axis=0) + return jnp.mean(data, axis=0) def compute_radius_of_gyration(data: np.ndarray, com: np.ndarray): @@ -156,6 +174,6 @@ def compute_radius_of_gyration(data: np.ndarray, com: np.ndarray): ------- """ - rg_primitive = tf.reduce_sum((data - com)**2, axis=1) + rg_primitive = jnp.sum((data - com) ** 2, axis=1) - return tf.reduce_mean(rg_primitive, axis=0) + return jnp.mean(rg_primitive, axis=0)