-
Notifications
You must be signed in to change notification settings - Fork 54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Features/125 modf #402
Merged
Merged
Features/125 modf #402
Changes from 42 commits
Commits
Show all changes
50 commits
Select commit
Hold shift + click to select a range
ec6c508
BROKEN. First implementation of modf
lenablind f70fb61
Merge remote-tracking branch 'origin/master' into features/125-modf
lenablind b3d1c71
BROKEN. Implementation of modf (and round)
lenablind 71b8832
Modf and round modified.
lenablind 828cb50
Implemented test_round().
ClaudiaComito 9a8d6f3
Implementation of round in dndArray
lenablind 4ae4be5
Functions in alphabetical order
lenablind 2e57691
Test function modf
lenablind b045297
Added modf to dndArray
lenablind 3d23b9d
Added option of user-defined output buffer for ht.modf().
ClaudiaComito b99d531
Merge branch 'master' into features/125-modf
ClaudiaComito 11aa1f1
modf with out
lenablind bbe33a6
Expanded test_modf for out
lenablind e123b3c
Adaptation to pre-commit
lenablind 0ce6890
Merge branch 'master' into features/125-modf
lenablind f25a448
Merge branch 'master' into features/125-modf
coquelin77 5d088b9
Implementation of requested changes
lenablind db796cc
Code reformatted via black
lenablind dac92bd
Merge branch 'master' into features/125-modf
lenablind efcefbd
Tests with split tensors
lenablind 46b001c
Tests correction
lenablind 4c9e548
Merge branch 'master' into features/125-modf
lenablind f456d7b
Integration of ht.equal for tests
lenablind 2921012
Merge branch 'master' into features/125-modf
lenablind e4ed000
Merge branch 'master' into features/125-modf
coquelin77 1ecdeba
Debugging attempts.
ClaudiaComito 6a82e9c
Fixing debugging attempts.
ClaudiaComito 342e661
test_modf(), test_round(): defining test tensors so that they are alw…
ClaudiaComito 0f3d6e4
More debugging attempts.
ClaudiaComito 7ad1fb3
Removed print/debugging statements
ClaudiaComito e5aa64e
Merge branch 'master' into features/125-modf
ClaudiaComito 387d1f4
Debugging attempts.
ClaudiaComito 0af3803
Debugging
ClaudiaComito 9ef2614
Debugging. Removed test_modf() and test_round()
ClaudiaComito 865b868
In assert_array_equal(), Allreduce running on self._comm, not on sel…
ClaudiaComito d0613dc
Debugging. Replacing failing array comparison with BasicTest.assert_a…
ClaudiaComito bdc5be2
rest_round(), replacing all assertTrue(ht.equal(...)) with BasicTest.…
ClaudiaComito 30464c3
Debugging. test_round(), remiving distributed tests
ClaudiaComito c5810a6
Debugging. Adding back distributed test_round one bit at a time.
ClaudiaComito 3ecd850
Small changes after pre-commit failed.
ClaudiaComito c383c78
Debugging. Adding back distributed tests for test_round(). Replaced h…
ClaudiaComito 7f5e34b
Replaced ht.arange with ht.array for non-distribution case
lenablind 567c2e4
Changed inheritance hierarchy of test_rounding to BasicTest
lenablind 05b319c
Integration of assert_array_equal within test_round
lenablind 58ac9db
Replacing ht.arange with ht.array(npArray)
ClaudiaComito e85e5c3
Merge branch 'master' into features/125-modf
coquelin77 9356e81
extending coverage
8bd1ae7
corrected indent in docs for modf
cfa1f74
corrected indent in docs for modf, minor formatting
2b52855
self replacing x/a in modf and round
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,6 +3,8 @@ | |
import numpy as np | ||
import heat as ht | ||
|
||
from heat.core.tests.test_suites.basic_test import BasicTest | ||
|
||
|
||
class TestRounding(unittest.TestCase): | ||
def test_abs(self): | ||
|
@@ -162,6 +164,139 @@ def test_floor(self): | |
with self.assertRaises(TypeError): | ||
ht.floor(object()) | ||
|
||
# def test_modf(self): | ||
# size = ht.communication.MPI_WORLD.size | ||
# start, end = -5.0, 5.0 | ||
# step = (end - start) / (2 * size) | ||
# comparison = np.modf(np.arange(start, end, step, np.float32)) | ||
# | ||
# # exponential of float32 | ||
# float32_tensor = ht.arange(start, end, step, dtype=ht.float32) | ||
# float32_modf = float32_tensor.modf() | ||
# self.assertIsInstance(float32_modf[0], ht.DNDarray) | ||
# self.assertIsInstance(float32_modf[1], ht.DNDarray) | ||
# self.assertEqual(float32_modf[0].dtype, ht.float32) | ||
# self.assertEqual(float32_modf[1].dtype, ht.float32) | ||
# | ||
# self.assertAlmostEqual(float32_modf[0].numpy().all(), comparison[0].all()) | ||
# self.assertAlmostEqual(float32_modf[1].numpy().all(), comparison[1].all()) | ||
# | ||
# self.assertAlmostEqual(ht.all(float32_modf[0]), ht.all(ht.array(comparison[0]))) | ||
# self.assertAlmostEqual(ht.all(float32_modf[1]), ht.all(ht.array(comparison[1]))) | ||
# | ||
# # exponential of float64 | ||
# comparison = np.modf(np.arange(start, end, step, np.float64)) | ||
# | ||
# float64_tensor = ht.arange(start, end, step, dtype=ht.float64) | ||
# float64_modf = float64_tensor.modf() | ||
# self.assertIsInstance(float64_modf[0], ht.DNDarray) | ||
# self.assertIsInstance(float64_modf[1], ht.DNDarray) | ||
# self.assertEqual(float64_modf[0].dtype, ht.float64) | ||
# self.assertEqual(float64_modf[1].dtype, ht.float64) | ||
# | ||
# self.assertAlmostEqual(float64_modf[0].numpy().all(), comparison[0].all()) | ||
# self.assertAlmostEqual(float64_modf[1].numpy().all(), comparison[1].all()) | ||
# | ||
# self.assertAlmostEqual(ht.all(float64_modf[0]), ht.all(ht.array(comparison[0]))) | ||
# self.assertAlmostEqual(ht.all(float64_modf[1]), ht.all(ht.array(comparison[1]))) | ||
# | ||
# # check exceptions | ||
# with self.assertRaises(TypeError): | ||
# ht.modf([0, 1, 2, 3]) | ||
# with self.assertRaises(TypeError): | ||
# ht.modf(object()) | ||
# with self.assertRaises(TypeError): | ||
# ht.modf(float32_tensor, 1) | ||
# with self.assertRaises(ValueError): | ||
# ht.modf(float32_tensor, (float32_tensor, float32_tensor, float64_tensor)) | ||
# with self.assertRaises(TypeError): | ||
# ht.modf(float32_tensor, (float32_tensor, 2)) | ||
# | ||
# # # with split tensors | ||
# | ||
# # exponential of float32 | ||
# float32_tensor_distrbd = ht.arange(start, end, step, dtype=ht.float32, split=0) | ||
# float32_modf_distrbd = float32_tensor_distrbd.modf() | ||
# | ||
# self.assertIsInstance(float32_modf_distrbd[0], ht.DNDarray) | ||
# self.assertIsInstance(float32_modf_distrbd[1], ht.DNDarray) | ||
# self.assertEqual(float32_modf_distrbd[0].dtype, ht.float32) | ||
# self.assertEqual(float32_modf_distrbd[1].dtype, ht.float32) | ||
# | ||
# self.assertAlmostEqual(float32_modf_distrbd[0].numpy().all(), comparison[0].all()) | ||
# self.assertAlmostEqual(float32_modf_distrbd[1].numpy().all(), comparison[1].all()) | ||
# | ||
# self.assertAlmostEqual(ht.all(float32_modf_distrbd[0]), ht.all(ht.array(comparison[0]))) | ||
# self.assertAlmostEqual(ht.all(float32_modf_distrbd[1]), ht.all(ht.array(comparison[1]))) | ||
# | ||
# # exponential of float64 | ||
# comparison = np.modf(np.arange(start, end, step, np.float64)) | ||
# | ||
# float64_tensor_distrbd = ht.arange(start, end, step, dtype=ht.float64, split=0) | ||
# float64_modf_distrbd = float64_tensor_distrbd.modf() | ||
# self.assertIsInstance(float64_modf_distrbd[0], ht.DNDarray) | ||
# self.assertIsInstance(float64_modf_distrbd[1], ht.DNDarray) | ||
# self.assertEqual(float64_modf_distrbd[0].dtype, ht.float64) | ||
# self.assertEqual(float64_modf_distrbd[1].dtype, ht.float64) | ||
# | ||
# self.assertAlmostEqual(float64_modf_distrbd[0].numpy().all(), comparison[0].all()) | ||
# self.assertAlmostEqual(float64_modf_distrbd[1].numpy().all(), comparison[1].all()) | ||
# | ||
# self.assertAlmostEqual(ht.all(float64_modf_distrbd[0]), ht.all(ht.array(comparison[0]))) | ||
# self.assertAlmostEqual(ht.all(float64_modf_distrbd[1]), ht.all(ht.array(comparison[1]))) | ||
|
||
def test_round(self): | ||
size = ht.communication.MPI_WORLD.size | ||
start, end = -5.0, 5.0 | ||
step = (end - start) / (2 * size) | ||
comparison = torch.arange(start, end, step, dtype=torch.float32).round() | ||
|
||
# exponential of float32 | ||
float32_tensor = ht.array(comparison, dtype=ht.float32) | ||
float32_round = float32_tensor.round() | ||
self.assertIsInstance(float32_round, ht.DNDarray) | ||
self.assertEqual(float32_round.dtype, ht.float32) | ||
self.assertEqual(float32_round.dtype, ht.float32) | ||
BasicTest.assert_array_equal(self, float32_round, comparison) | ||
|
||
# exponential of float64 | ||
comparison = torch.arange(start, end, step, dtype=torch.float64).round() | ||
float64_tensor = ht.array(comparison, dtype=ht.float64) | ||
float64_round = float64_tensor.round() | ||
self.assertIsInstance(float64_round, ht.DNDarray) | ||
self.assertEqual(float64_round.dtype, ht.float64) | ||
self.assertEqual(float64_round.dtype, ht.float64) | ||
BasicTest.assert_array_equal(self, float64_round, comparison) | ||
|
||
# check exceptions | ||
with self.assertRaises(TypeError): | ||
ht.round([0, 1, 2, 3]) | ||
with self.assertRaises(TypeError): | ||
ht.round(object()) | ||
with self.assertRaises(TypeError): | ||
ht.round(float32_tensor, 1, 1) | ||
|
||
# with split tensors | ||
|
||
# exponential of float32 | ||
comparison = torch.arange(start, end, step, dtype=torch.float32) # .round() | ||
float32_tensor_distrbd = ht.array(comparison, split=0) | ||
comparison = comparison.round() | ||
float32_round_distrbd = float32_tensor_distrbd.round() | ||
self.assertIsInstance(float32_round_distrbd, ht.DNDarray) | ||
self.assertEqual(float32_round_distrbd.dtype, ht.float32) | ||
BasicTest.assert_array_equal(self, float32_round_distrbd, comparison) | ||
|
||
# exponential of float64 | ||
comparison = torch.arange(start, end, step, dtype=torch.float64) # .round() | ||
float64_tensor_distrbd = ht.array(comparison, split=0) | ||
comparison = comparison.round() | ||
float64_round_distrbd = float64_tensor_distrbd.round() | ||
self.assertIsInstance(float64_round_distrbd, ht.DNDarray) | ||
self.assertEqual(float64_round_distrbd.dtype, ht.float64) | ||
self.assertEqual(float64_round_distrbd.dtype, ht.float64) | ||
BasicTest.assert_array_equal(self, float64_round_distrbd, comparison) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's better to extend the BasicTest class and then use |
||
|
||
def test_trunc(self): | ||
base_array = np.random.randn(20) | ||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The BasicTest class is meant as an extension to the
unittest.TestCase
class and should therefore be extended when you want to use it.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you for your feedback! I tried to implement your requested changes.