Skip to content

Commit

Permalink
Merge pull request #858 from helmholtz-analytics/feature/855-normaldist
Browse files Browse the repository at this point in the history
add normal, standard_normal
  • Loading branch information
coquelin77 authored Sep 16, 2021
2 parents 45d4a00 + 2e406e2 commit 631f113
Show file tree
Hide file tree
Showing 3 changed files with 172 additions and 2 deletions.
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
- [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll`
- [#853](https://github.com/helmholtz-analytics/heat/pull/853) New Feature: `swapaxes`
- [#854](https://github.com/helmholtz-analytics/heat/pull/854) New Feature: `moveaxis`
### Random
- [#858](https://github.com/helmholtz-analytics/heat/pull/858) New Feature: `standard_normal`, `normal`

# v1.1.1
- [#864](https://github.com/helmholtz-analytics/heat/pull/864) Dependencies: constrain `torchvision` version range to match supported `pytorch` version range.

# v1.1.0

## Highlights
- Slicing/indexing overhaul for a more NumPy-like user experience. Warning for distributed arrays: [breaking change!](#breaking-changes) Indexing one element along the distribution axis now implies the indexed element is communicated to all processes.
- More flexibility in handling non-load-balanced distributed arrays.
Expand Down
117 changes: 117 additions & 0 deletions heat/core/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from . import communication
from . import devices
from . import factories
from . import logical
from . import stride_tricks
from . import types

Expand All @@ -19,6 +20,7 @@

__all__ = [
"get_state",
"normal",
"permutation",
"rand",
"ranf",
Expand All @@ -31,6 +33,7 @@
"sample",
"seed",
"set_state",
"standard_normal",
]

# introduce the global random state variables, will be correctly initialized at the end of file
Expand Down Expand Up @@ -262,6 +265,64 @@ def __kundu_transform(values: torch.Tensor) -> torch.Tensor:
return (torch.log(-torch.log(inner + tiny) + tiny) - 1.0821) * __KUNDU_INVERSE


def normal(
    mean: Union[float, DNDarray] = 0.0,
    std: Union[float, DNDarray] = 1.0,
    shape: Optional[Tuple[int, ...]] = None,
    dtype: Type[datatype] = types.float32,
    split: Optional[int] = None,
    device: Optional[str] = None,
    comm: Optional[Communication] = None,
) -> DNDarray:
    """
    Returns an array filled with random numbers from a normal distribution whose mean and standard deviation are given.

    If `std` and `mean` are DNDarrays, they have to match `shape`.

    Parameters
    ----------
    mean : float or DNDarray
        The mean of the distribution.
    std : float or DNDarray
        The standard deviation of the distribution. Must be non-negative.
    shape : tuple[int]
        The shape of the returned array, should all be positive. If no argument is given a single random sample is
        generated.
    dtype : Type[datatype], optional
        The datatype of the returned values. Has to be one of :class:`~heat.core.types.float32` or
        :class:`~heat.core.types.float64`.
    split : int, optional
        The axis along which the array is split and distributed, defaults to no distribution.
    device : str, optional
        Specifies the :class:`~heat.core.devices.Device` the array shall be allocated on, defaults to globally
        set default device.
    comm : Communication, optional
        Handle to the nodes holding distributed parts or copies of this array.

    Raises
    ------
    TypeError
        If `mean` or `std` is neither an int/float nor a DNDarray.
    ValueError
        If `std` is negative (scalar) or contains any negative element (DNDarray).

    See Also
    --------
    randn
        Uses the standard normal distribution
    standard_normal
        Uses the standard normal distribution

    Examples
    --------
    >>> ht.random.normal(ht.array([-1,2]), ht.array([0.5, 2]), (2,))
    DNDarray([-1.4669,  1.6596], dtype=ht.float64, device=cpu:0, split=None)
    """
    # int is accepted alongside float for convenience (e.g. mean=0, std=2).
    if not isinstance(mean, (int, float)) and not isinstance(mean, DNDarray):
        raise TypeError("'mean' must be float or DNDarray")
    if not isinstance(std, (int, float)) and not isinstance(std, DNDarray):
        # Fixed: the message previously named 'mean' instead of 'std'.
        raise TypeError("'std' must be float or DNDarray")

    # A scalar std must be >= 0; a DNDarray std must be element-wise >= 0.
    if (isinstance(std, (int, float)) and std < 0) or (
        isinstance(std, DNDarray) and logical.any(std < 0)
    ):
        raise ValueError("'std' must be non-negative")

    # Shift and scale a standard-normal sample: X = mean + std * Z.
    return mean + std * standard_normal(shape, dtype, split, device, comm)


def permutation(x: Union[int, DNDarray]) -> DNDarray:
"""
Randomly permute a sequence, or return a permuted range. If ``x`` is a multi-dimensional array, it is only shuffled
Expand Down Expand Up @@ -541,6 +602,13 @@ def randn(
comm : Communication, optional
Handle to the nodes holding distributed parts or copies of this array.
See Also
--------
normal
Similar, but accepts arguments for the mean and standard deviation.
standard_normal
Similar, but takes a tuple as its argument.
Raises
-------
TypeError
Expand Down Expand Up @@ -744,6 +812,55 @@ def set_state(state: Tuple[str, int, int, int, float]):
__counter = int(state[2])


def standard_normal(
    shape: Optional[Tuple[int, ...]] = None,
    dtype: Type[datatype] = types.float32,
    split: Optional[int] = None,
    device: Optional[str] = None,
    comm: Optional[Communication] = None,
) -> DNDarray:
    """
    Returns an array filled with random numbers from a standard normal distribution with zero mean and variance of one.

    Parameters
    ----------
    shape : tuple[int]
        The shape of the returned array, should all be positive. If no argument is given a single random sample is
        generated.
    dtype : Type[datatype], optional
        The datatype of the returned values. Has to be one of :class:`~heat.core.types.float32` or
        :class:`~heat.core.types.float64`.
    split : int, optional
        The axis along which the array is split and distributed, defaults to no distribution.
    device : str, optional
        Specifies the :class:`~heat.core.devices.Device` the array shall be allocated on, defaults to globally
        set default device.
    comm : Communication, optional
        Handle to the nodes holding distributed parts or copies of this array.

    See Also
    --------
    randn
        Similar, but accepts separate arguments for the shape dimensions.
    normal
        Equivalent function with arguments for the mean and standard deviation.

    Examples
    --------
    >>> ht.random.standard_normal((3,))
    DNDarray([ 0.1921, -0.9635,  0.5047], dtype=ht.float32, device=cpu:0, split=None)
    >>> ht.random.standard_normal((4, 4))
    DNDarray([[-1.1261,  0.5971,  0.2851,  0.9998],
              [-1.8548, -1.2574,  0.2391, -0.3302],
              [ 1.3365, -1.5212,  1.4159, -0.1671],
              [ 0.1260,  1.2126, -0.0804,  0.0907]], dtype=ht.float32, device=cpu:0, split=None)
    """
    # A missing/empty shape means a single sample, exactly like the original behavior.
    sane_shape = stride_tricks.sanitize_shape(shape if shape else (1,))
    # Delegate the actual sampling to randn, which draws from N(0, 1).
    return randn(*sane_shape, dtype=dtype, split=split, device=device, comm=comm)


def __threefry32(
x0: torch.Tensor, x1: torch.Tensor, seed: int
) -> Tuple[torch.Tensor, torch.Tensor]:
Expand Down
53 changes: 53 additions & 0 deletions heat/core/tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,41 @@


class TestRandom(TestCase):
def test_normal(self):
    """Verify normal() against randn() under identical seeds, plus argument validation."""
    # With default mean=0 and std=1, normal() must be bitwise-identical to randn()
    # drawn from the same seed, including dtype, shape, device and split.
    shape = (3, 4, 6)
    ht.random.seed(2)
    gnormal = ht.random.normal(shape=shape, split=2)
    ht.random.seed(2)
    snormal = ht.random.randn(*shape, split=2)

    self.assertEqual(gnormal.dtype, snormal.dtype)
    self.assertEqual(gnormal.shape, snormal.shape)
    self.assertEqual(gnormal.device, snormal.device)
    self.assertTrue(ht.equal(gnormal, snormal))

    # DNDarray mean/std: normal(mu, sigma) must equal mu + sigma * randn(...)
    # computed element-wise from the same seeded standard-normal draw.
    shape = (2, 2)
    mu = ht.array([[-1, -0.5], [0, 5]])
    sigma = ht.array([[0, 0.5], [1, 2.5]])

    ht.random.seed(22)
    gnormal = ht.random.normal(mu, sigma, shape)
    ht.random.seed(22)
    snormal = ht.random.randn(*shape)

    compare = mu + sigma * snormal

    self.assertEqual(gnormal.dtype, compare.dtype)
    self.assertEqual(gnormal.shape, compare.shape)
    self.assertEqual(gnormal.device, compare.device)
    self.assertTrue(ht.equal(gnormal, compare))

    # Invalid arguments: non-float/DNDarray mean, non-float/DNDarray std,
    # and a negative std must raise.
    with self.assertRaises(TypeError):
        ht.random.normal([4, 5], 1, shape)
    with self.assertRaises(TypeError):
        ht.random.normal(0, "r", shape)
    with self.assertRaises(ValueError):
        ht.random.normal(0, -1, shape)

def test_permutation(self):
# Reset RNG
ht.random.seed()
Expand Down Expand Up @@ -418,3 +453,21 @@ def test_set_state(self):
ht.random.set_state(("Thrfry", 12, 0xF))
with self.assertRaises(TypeError):
ht.random.set_state(("Threefry", 12345))

def test_standard_normal(self):
    """Verify standard_normal() defaults and its equivalence to randn() under identical seeds."""
    # empty input: no shape argument yields a single float32 sample of shape (1,)
    stdn = ht.random.standard_normal()
    self.assertEqual(stdn.dtype, ht.float32)
    self.assertEqual(stdn.shape, (1,))

    # simple test: for the same seed, standard_normal(shape) must match
    # randn(*shape) exactly, including dtype, device and split.
    shape = (3, 4, 6)
    ht.random.seed(11235)
    stdn = ht.random.standard_normal(shape, split=2)
    ht.random.seed(11235)
    rndn = ht.random.randn(*shape, split=2)

    self.assertEqual(stdn.shape, rndn.shape)
    self.assertEqual(stdn.dtype, rndn.dtype)
    self.assertEqual(stdn.device, rndn.device)
    self.assertTrue(ht.equal(stdn, rndn))

0 comments on commit 631f113

Please sign in to comment.