Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add normal, standard_normal #858

Merged
merged 4 commits into from
Sep 16, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@
- [#829](https://github.com/helmholtz-analytics/heat/pull/829) New feature: `roll`
- [#853](https://github.com/helmholtz-analytics/heat/pull/853) New Feature: `swapaxes`
- [#854](https://github.com/helmholtz-analytics/heat/pull/854) New Feature: `moveaxis`
### Random
- [#858](https://github.com/helmholtz-analytics/heat/pull/858) New Feature: `standard_normal`, `normal`

# v1.1.1
- [#864](https://github.com/helmholtz-analytics/heat/pull/864) Dependencies: constrain `torchvision` version range to match supported `pytorch` version range.

# v1.1.0

## Highlights
- Slicing/indexing overhaul for a more NumPy-like user experience. Warning for distributed arrays: [breaking change!](#breaking-changes) Indexing one element along the distribution axis now implies the indexed element is communicated to all processes.
- More flexibility in handling non-load-balanced distributed arrays.
Expand Down
117 changes: 117 additions & 0 deletions heat/core/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from . import communication
from . import devices
from . import factories
from . import logical
from . import stride_tricks
from . import types

Expand All @@ -19,6 +20,7 @@

__all__ = [
"get_state",
"normal",
"permutation",
"rand",
"ranf",
Expand All @@ -31,6 +33,7 @@
"sample",
"seed",
"set_state",
"standard_normal",
]

# introduce the global random state variables, will be correctly initialized at the end of file
Expand Down Expand Up @@ -262,6 +265,64 @@ def __kundu_transform(values: torch.Tensor) -> torch.Tensor:
return (torch.log(-torch.log(inner + tiny) + tiny) - 1.0821) * __KUNDU_INVERSE


def normal(
    mean: Union[float, DNDarray] = 0.0,
    std: Union[float, DNDarray] = 1.0,
    shape: Optional[Tuple[int, ...]] = None,
    dtype: Type[datatype] = types.float32,
    split: Optional[int] = None,
    device: Optional[str] = None,
    comm: Optional[Communication] = None,
) -> DNDarray:
    """
    Returns an array filled with random numbers from a normal distribution whose mean and standard deviation are given.
    If `std` and `mean` are DNDarrays, they have to match `shape`.

    Parameters
    ----------
    mean : float or DNDarray
        The mean of the distribution.
    std : float or DNDarray
        The standard deviation of the distribution. Must be non-negative.
    shape : tuple[int]
        The shape of the returned array, should all be positive. If no argument is given a single random sample is
        generated.
    dtype : Type[datatype], optional
        The datatype of the returned values. Has to be one of :class:`~heat.core.types.float32` or
        :class:`~heat.core.types.float64`.
    split : int, optional
        The axis along which the array is split and distributed, defaults to no distribution.
    device : str, optional
        Specifies the :class:`~heat.core.devices.Device` the array shall be allocated on, defaults to globally
        set default device.
    comm : Communication, optional
        Handle to the nodes holding distributed parts or copies of this array.

    Raises
    ------
    TypeError
        If `mean` or `std` is neither a scalar (float or int) nor a DNDarray.
    ValueError
        If `std` is negative, or contains a negative entry when given as a DNDarray.

    See Also
    --------
    randn
        Uses the standard normal distribution and accepts separate arguments for the shape dimensions.
    standard_normal
        Uses the standard normal distribution and takes the shape as a tuple.

    Examples
    --------
    >>> ht.random.normal(ht.array([-1,2]), ht.array([0.5, 2]), (2,))
    DNDarray([-1.4669,  1.6596], dtype=ht.float64, device=cpu:0, split=None)
    """
    if not isinstance(mean, (float, int, DNDarray)):
        raise TypeError("'mean' must be float or DNDarray")
    if not isinstance(std, (float, int, DNDarray)):
        # the message must name the offending argument, i.e. 'std' and not 'mean'
        raise TypeError("'std' must be float or DNDarray")

    # after the type checks above, std is either a DNDarray or a plain scalar
    if isinstance(std, DNDarray):
        has_negative_std = logical.any(std < 0)
    else:
        has_negative_std = std < 0
    if has_negative_std:
        raise ValueError("'std' must be non-negative")

    # shift and scale standard normal samples: X = mean + std * Z
    return mean + std * standard_normal(shape, dtype, split, device, comm)


def permutation(x: Union[int, DNDarray]) -> DNDarray:
"""
Randomly permute a sequence, or return a permuted range. If ``x`` is a multi-dimensional array, it is only shuffled
Expand Down Expand Up @@ -541,6 +602,13 @@ def randn(
comm : Communication, optional
Handle to the nodes holding distributed parts or copies of this array.

See Also
--------
normal
Accepts arguments for mean and standard deviation.
standard_normal
Similar, but takes a tuple as its argument.

Raises
-------
TypeError
Expand Down Expand Up @@ -744,6 +812,55 @@ def set_state(state: Tuple[str, int, int, int, float]):
__counter = int(state[2])


def standard_normal(
    shape: Optional[Tuple[int, ...]] = None,
    dtype: Type[datatype] = types.float32,
    split: Optional[int] = None,
    device: Optional[str] = None,
    comm: Optional[Communication] = None,
) -> DNDarray:
    """
    Draws samples from the standard normal distribution (zero mean, unit variance) and returns them as an array.

    This is a thin wrapper that forwards to :func:`randn`, taking the shape as a single tuple.

    Parameters
    ----------
    shape : tuple[int]
        The shape of the returned array, should all be positive. When omitted (or empty), the shape ``(1,)`` is
        used, i.e. a single random sample is generated.
    dtype : Type[datatype], optional
        The datatype of the returned values. Has to be one of :class:`~heat.core.types.float32` or
        :class:`~heat.core.types.float64`.
    split : int, optional
        The axis along which the array is split and distributed, defaults to no distribution.
    device : str, optional
        Specifies the :class:`~heat.core.devices.Device` the array shall be allocated on, defaults to globally
        set default device.
    comm : Communication, optional
        Handle to the nodes holding distributed parts or copies of this array.

    See Also
    --------
    randn
        Similar, but accepts separate arguments for the shape dimensions.
    normal
        Equivalent function with arguments for the mean and standard deviation.

    Examples
    --------
    >>> ht.random.standard_normal((3,))
    DNDarray([ 0.1921, -0.9635,  0.5047], dtype=ht.float32, device=cpu:0, split=None)
    >>> ht.random.standard_normal((4, 4))
    DNDarray([[-1.1261,  0.5971,  0.2851,  0.9998],
              [-1.8548, -1.2574,  0.2391, -0.3302],
              [ 1.3365, -1.5212,  1.4159, -0.1671],
              [ 0.1260,  1.2126, -0.0804,  0.0907]], dtype=ht.float32, device=cpu:0, split=None)
    """
    # a missing or empty shape falls back to a single sample
    dims = stride_tricks.sanitize_shape(shape if shape else (1,))
    return randn(*dims, dtype=dtype, split=split, device=device, comm=comm)


def __threefry32(
x0: torch.Tensor, x1: torch.Tensor, seed: int
) -> Tuple[torch.Tensor, torch.Tensor]:
Expand Down
53 changes: 53 additions & 0 deletions heat/core/tests/test_random.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,41 @@


class TestRandom(TestCase):
def test_normal(self):
    # default mean/std: for the same seed, normal() is compared against randn()
    dims = (3, 4, 6)
    ht.random.seed(2)
    from_normal = ht.random.normal(shape=dims, split=2)
    ht.random.seed(2)
    from_randn = ht.random.randn(*dims, split=2)

    self.assertEqual(from_normal.dtype, from_randn.dtype)
    self.assertEqual(from_normal.shape, from_randn.shape)
    self.assertEqual(from_normal.device, from_randn.device)
    self.assertTrue(ht.equal(from_normal, from_randn))

    # array-valued mean/std: result is compared against mean + std * randn()
    dims = (2, 2)
    means = ht.array([[-1, -0.5], [0, 5]])
    stds = ht.array([[0, 0.5], [1, 2.5]])

    ht.random.seed(22)
    from_normal = ht.random.normal(means, stds, dims)
    ht.random.seed(22)
    from_randn = ht.random.randn(*dims)

    expected = means + stds * from_randn

    self.assertEqual(from_normal.dtype, expected.dtype)
    self.assertEqual(from_normal.shape, expected.shape)
    self.assertEqual(from_normal.device, expected.device)
    self.assertTrue(ht.equal(from_normal, expected))

    # invalid argument types and values
    with self.assertRaises(TypeError):
        ht.random.normal([4, 5], 1, dims)
    with self.assertRaises(TypeError):
        ht.random.normal(0, "r", dims)
    with self.assertRaises(ValueError):
        ht.random.normal(0, -1, dims)

def test_permutation(self):
# Reset RNG
ht.random.seed()
Expand Down Expand Up @@ -418,3 +453,21 @@ def test_set_state(self):
ht.random.set_state(("Thrfry", 12, 0xF))
with self.assertRaises(TypeError):
ht.random.set_state(("Threefry", 12345))

def test_standard_normal(self):
    # no arguments: a single float32 sample is expected
    single = ht.random.standard_normal()
    self.assertEqual(single.dtype, ht.float32)
    self.assertEqual(single.shape, (1,))

    # same seed: standard_normal(shape) is compared against randn(*shape)
    dims = (3, 4, 6)
    ht.random.seed(11235)
    from_standard_normal = ht.random.standard_normal(dims, split=2)
    ht.random.seed(11235)
    from_randn = ht.random.randn(*dims, split=2)

    self.assertEqual(from_standard_normal.shape, from_randn.shape)
    self.assertEqual(from_standard_normal.dtype, from_randn.dtype)
    self.assertEqual(from_standard_normal.device, from_randn.device)
    self.assertTrue(ht.equal(from_standard_normal, from_randn))