Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sam tov jax implementation #38

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion CI/unit_tests/data/test_double_well_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
Description: Test the double_well_potential module.
"""
import unittest

from symsuite.data.double_well_potential import DoubleWellPotential


Expand Down Expand Up @@ -52,5 +53,5 @@ def test_double_well(self):
self.assertEqual(len(self.generator.image), 500)


if __name__ == '__main__':
if __name__ == "__main__":
unittest.main()
72 changes: 72 additions & 0 deletions CI/unit_tests/distance_metrics/test_angular_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
"""
ZnRND: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
Test the angular distance module.
"""
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import jax.numpy as np
from numpy.testing import assert_array_almost_equal

from symsuite.distance_metrics.angular_distance import AngularDistance


class TestAngularDistance:
"""
Class to test the cosine distance measure module.
"""

def test_angular_distance(self):
"""
Test the cosine similarity measure.

Returns
-------
Assert the correct answer is returned for orthogonal, parallel, and
somewhere in between.
"""
metric = AngularDistance()

# Test orthogonal vectors
point_1 = np.array([[1, 0]])
point_2 = np.array([[0, 1]])
assert_array_almost_equal(metric(point_1, point_2), [0.5])

# Test parallel vectors
point_1 = np.array([[1, 0]])
point_2 = np.array([[1, 1]])
assert_array_almost_equal(metric(point_1, point_2), [0.25])

def test_multiple_distances(self):
"""
Test the cosine similarity measure.

Returns
-------
Assert the correct answer is returned for orthogonal, parallel, and
somewhere in between.
"""
metric = AngularDistance()

# Test orthogonal vectors
point_1 = np.array([[1, 0], [1, 0]])
point_2 = np.array([[0, 1], [1, 1]])
assert_array_almost_equal(metric(point_1, point_2), [0.5, 0.25])
75 changes: 75 additions & 0 deletions CI/unit_tests/distance_metrics/test_cosine_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
ZnRND: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:
Summary
-------
Test the cosine distance module.
"""
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import jax.numpy as np
from numpy.testing import assert_array_almost_equal

from symsuite.distance_metrics.cosine_distance import CosineDistance


class TestCosineDistance:
"""
Class to test the cosine distance measure module.
"""

def test_cosine_distance(self):
"""
Test the cosine similarity measure.

Returns
-------
Assert the correct answer is returned for orthogonal, parallel, and
somewhere in between.
"""
metric = CosineDistance()

# Test orthogonal vectors
point_1 = np.array([[1, 0, 0, 0]])
point_2 = np.array([[0, 1, 0, 0]])
assert_array_almost_equal(metric(point_1, point_2), [1])

# Test parallel vectors
assert_array_almost_equal(metric(point_1, point_1), [0])

# Somewhere in between
point_1 = np.array([[1.0, 0, 0, 0]])
point_2 = np.array([[0.5, 1.0, 0, 3.0]])
assert_array_almost_equal(metric(point_1, point_2), [0.84382623])

def test_multiple_distances(self):
"""
Test the cosine similarity measure.

Returns
-------
Assert the correct answer is returned for orthogonal, parallel, and
somewhere in between.
"""
metric = CosineDistance()

# Test orthogonal vectors
point_1 = np.array([[1, 0, 0, 0], [1, 0, 0, 0], [1.0, 0, 0, 0]])
point_2 = np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0.5, 1.0, 0, 3.0]])
assert_array_almost_equal(metric(point_1, point_2), [1, 0, 0.843826], decimal=6)
87 changes: 87 additions & 0 deletions CI/unit_tests/distance_metrics/test_hyper_sphere_distance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""
ZnRND: A Zincwarecode package.

License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.

Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/

Citation
--------
If you use this module please cite us with:

Summary
-------
Test the hyper sphere distance module.
"""
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import jax.numpy as np
from numpy.testing import assert_array_almost_equal

from symsuite.distance_metrics.hyper_sphere_distance import HyperSphere


class TestCosineDistance:
"""
Class to test the cosine distance measure module.
"""

def test_hyper_sphere_distance(self):
"""
Test the hyper sphere distance.

Returns
-------
Assert the correct answer is returned for orthogonal, parallel, and
somewhere in between.
"""
metric = HyperSphere(order=2)

# Test orthogonal vectors
point_1 = np.array([[1, 0, 0, 0]])
point_2 = np.array([[0, 1, 0, 0]])
assert_array_almost_equal(metric(point_1, point_2), [1.41421356])

# Test parallel vectors
point_1 = np.array([[1, 0, 0, 0]])
point_2 = np.array([[1, 0, 0, 0]])
assert_array_almost_equal(metric(point_1, point_2), [0])

# Somewhere in between
point_1 = np.array([[1.0, 0, 0, 0]])
point_2 = np.array([[0.5, 1.0, 0, 3.0]])
assert_array_almost_equal(
metric(point_1, point_2), [0.84382623 * np.sqrt(10.25)]
)

def test_multiple_distances(self):
"""
Test the hyper sphere distance.

Returns
-------
Assert the correct answer is returned for orthogonal, parallel, and
somewhere in between.
"""
metric = HyperSphere(order=2)

# Test orthogonal vectors
point_1 = np.array([[1, 0, 0, 0], [1, 0, 0, 0], [1.0, 0, 0, 0]])
point_2 = np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0.5, 1.0, 0, 3.0]])
assert_array_almost_equal(
metric(point_1, point_2),
[np.sqrt(2), 0, 0.84382623 * np.sqrt(10.25)],
decimal=6,
)
85 changes: 85 additions & 0 deletions CI/unit_tests/distance_metrics/test_l_p_norm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
ZnRND: A Zincwarecode package.
License
-------
This program and the accompanying materials are made available under the terms
of the Eclipse Public License v2.0 which accompanies this distribution, and is
available at https://www.eclipse.org/legal/epl-v20.html
SPDX-License-Identifier: EPL-2.0
Copyright Contributors to the Zincwarecode Project.
Contact Information
-------------------
email: [email protected]
github: https://github.com/zincware
web: https://zincwarecode.com/
Citation
--------
If you use this module please cite us with:

Summary
-------
Test the l_p norm metric.
"""
import os

os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

import jax.numpy as np
from numpy.testing import assert_almost_equal, assert_array_almost_equal

from symsuite.distance_metrics.l_p_norm import LPNorm


class TestLPNorm:
"""
Class to test the cosine distance measure module.
"""

def test_l_2_distance(self):
"""
Test the cosine similarity measure.

Returns
-------
Assert the correct answer is returned for orthogonal, parallel, and
somewhere in between.
"""
metric = LPNorm(order=2)

# Test orthogonal vectors
point_1 = np.array([[1.0, 7.0, 0.0, 0.0]])
point_2 = np.array([[1.0, 1.0, 0.0, 0.0]])

assert_array_almost_equal(metric(point_1, point_2), [6.0])

def test_l_3_distance(self):
"""
Test the cosine similarity measure.

Returns
-------
Assert the correct answer is returned for orthogonal, parallel, and
somewhere in between.
"""
metric = LPNorm(order=3)

# Test orthogonal vectors
point_1 = np.array([[1.0, 7.0, 0.0, 0.0]])
point_2 = np.array([[1.0, 1.0, 0.0, 0.0]])
assert_almost_equal(metric(point_1, point_2), [6.0], decimal=4)

def test_multi_distance(self):
"""
Test the cosine similarity measure.

Returns
-------
Assert the correct answer is returned for orthogonal, parallel, and
somewhere in between.
"""
metric = LPNorm(order=1)

# Test orthogonal vectors
point_1 = np.array([[1.0, 7.0, 0.0, 0.0], [4, 7, 2, 1]])
point_2 = np.array([[1.0, 1.0, 0.0, 0.0], [6, 3, 1, 8]])
assert_array_almost_equal(metric(point_1, point_2), [6.0, 14.0], decimal=4)
Loading