Skip to content

Commit

Permalink
implement digitize/bucketize
Browse files Browse the repository at this point in the history
Fixes #926
  • Loading branch information
mtar committed Mar 4, 2022
1 parent 0c89691 commit c27df91
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 0 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@
- [#858](https://github.com/helmholtz-analytics/heat/pull/858) New Feature: `standard_normal`, `normal`
### Rounding
- [#827](https://github.com/helmholtz-analytics/heat/pull/827) New feature: `sign`, `sgn`
### Statistics
- [#928](https://github.com/helmholtz-analytics/heat/pull/928) New feature: `bucketize`, `digitize`

# v1.1.1
- [#864](https://github.com/helmholtz-analytics/heat/pull/864) Dependencies: constrain `torchvision` version range to match supported `pytorch` version range.
Expand Down
150 changes: 150 additions & 0 deletions heat/core/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
"argmin",
"average",
"bincount",
"bucketize",
"cov",
"digitize",
"histc",
"histogram",
"kurtosis",
Expand Down Expand Up @@ -388,6 +390,79 @@ def bincount(x: DNDarray, weights: Optional[DNDarray] = None, minlength: int = 0
)


def bucketize(
    input: DNDarray,
    boundaries: Union[DNDarray, torch.Tensor],
    out_int32: bool = False,
    right: bool = False,
    out: DNDarray = None,
) -> DNDarray:
    """
    Map every element of ``input`` to the index of the bucket it falls into. The buckets are
    delimited by the values in ``boundaries``.

    Parameters
    ----------
    input : DNDarray
        The input array.
    boundaries : DNDarray or torch.Tensor
        1-dimensional, monotonically increasing sequence of bucket boundaries. Must not be
        distributed.
    out_int32 : bool, optional
        Select the dtype of the result: ``ht.int32`` if `True`, ``ht.int64`` if `False`.
    right : bool, optional
        Selects which boundary of a bucket is inclusive, see Notes.
    out : DNDarray, optional
        The output array, must have the same shape and split as the input array.

    Notes
    -----
    ``right`` follows the PyTorch convention:

    ===== ====================================
    right returned index `i` satisfies
    ===== ====================================
    False boundaries[i-1] < x <= boundaries[i]
    True  boundaries[i-1] <= x < boundaries[i]
    ===== ====================================

    Raises
    ------
    RuntimeError
        If `boundaries` is distributed.

    See Also
    --------
    digitize
        NumPy-like version of this function.

    Examples
    --------
    >>> boundaries = ht.array([1, 3, 5, 7, 9])
    >>> v = ht.array([[3, 6, 9], [3, 6, 9]])
    >>> ht.bucketize(v, boundaries)
    DNDarray([[1, 3, 4],
              [1, 3, 4]], dtype=ht.int64, device=cpu:0, split=None)
    >>> ht.bucketize(v, boundaries, right=True)
    DNDarray([[2, 3, 5],
              [2, 3, 5]], dtype=ht.int64, device=cpu:0, split=None)
    """
    # Reduce ``boundaries`` to a plain torch tensor that is handed to the local operation.
    if not isinstance(boundaries, DNDarray):
        edges = torch.as_tensor(boundaries)
    elif boundaries.is_distributed():
        raise RuntimeError("'boundaries' must be undistributed.")
    else:
        edges = boundaries.larray

    torch_kwargs = {"boundaries": edges, "out_int32": out_int32, "right": right}
    return _operations.__local_op(torch.bucketize, input, out, no_cast=True, **torch_kwargs)


def cov(
m: DNDarray,
y: Optional[DNDarray] = None,
Expand Down Expand Up @@ -463,6 +538,81 @@ def cov(
return c


def digitize(x: DNDarray, bins: Union[DNDarray, torch.Tensor], right: bool = False) -> DNDarray:
    """
    Return the indices of the bins to which each value in the input array `x` belongs.
    If values in `x` are beyond the bounds of bins, 0 or len(bins) is returned as appropriate.

    Parameters
    ----------
    x : DNDarray
        The input array
    bins : DNDarray or torch.Tensor
        A 1-dimensional array containing a monotonic sequence describing the bin boundaries.
        Must not be distributed.
    right : bool, optional
        Indicating whether the intervals include the right or the left bin edge, see Notes.

    Notes
    -----
    This function uses NumPy's setting for ``right``:

    ===== ============= ============================
    right order of bins returned index `i` satisfies
    ===== ============= ============================
    False increasing    bins[i-1] <= x < bins[i]
    True  increasing    bins[i-1] < x <= bins[i]
    False decreasing    bins[i-1] > x >= bins[i]
    True  decreasing    bins[i-1] >= x > bins[i]
    ===== ============= ============================

    The ordering of `bins` is detected by comparing its first and last element only.

    Raises
    ------
    RuntimeError
        If `bins` is distributed.

    See Also
    --------
    bucketize
        PyTorch-like version of this function.

    Examples
    --------
    >>> x = ht.array([1.2, 10.0, 12.4, 15.5, 20.])
    >>> bins = ht.array([0, 5, 10, 15, 20])
    >>> ht.digitize(x,bins,right=True)
    DNDarray([1, 2, 3, 4, 4], dtype=ht.int64, device=cpu:0, split=None)
    >>> ht.digitize(x,bins,right=False)
    DNDarray([1, 3, 3, 4, 5], dtype=ht.int64, device=cpu:0, split=None)
    """
    if isinstance(bins, DNDarray):
        if bins.is_distributed():
            raise RuntimeError("'bins' must be undistributed.")
        bins = bins.larray
    else:
        bins = torch.as_tensor(bins)

    # Probe the ordering only when there are at least two edges: indexing an empty tensor
    # raises IndexError, and a single edge can never be "decreasing".
    reverse = bool(bins.numel() > 1 and bins[0] > bins[-1])
    if reverse:
        # the boundaries handed to torch.bucketize have to be increasing
        bins = torch.flipud(bins)

    result = _operations.__local_op(
        torch.bucketize,
        x,
        out=None,
        no_cast=True,
        boundaries=bins,
        out_int32=False,
        # NumPy's ``right`` is the opposite of PyTorch's ``right``
        right=not right,
    )

    if reverse:
        # indices were computed against the flipped edges; mirror them back
        result = bins.numel() - result

    return result


def histc(
input: DNDarray, bins: int = 100, min: int = 0, max: int = 0, out: Optional[DNDarray] = None
) -> DNDarray:
Expand Down
55 changes: 55 additions & 0 deletions heat/core/tests/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,33 @@ def test_bincount(self):
with self.assertRaises(ValueError):
ht.bincount(ht.array([0, 1, 2, 3], split=0), weights=ht.array([1, 2, 3, 4]))

def test_bucketize(self):
    """Check ``ht.bucketize`` against fixed expectations and against ``torch.bucketize``."""
    boundaries = ht.array([1, 3, 5, 7, 9])
    v = ht.array([[3, 6, 9], [3, 6, 9]])
    a = ht.bucketize(v, boundaries)

    self.assertTrue(ht.equal(a, ht.array([[1, 3, 4], [1, 3, 4]])))
    # assertEqual is required here: assertTrue(x, y) only checks the truthiness of x
    # and uses y as the failure message, so it would never compare the two values.
    self.assertEqual(a.dtype, ht.int64)
    self.assertEqual(a.shape, v.shape)

    a = ht.bucketize(v, boundaries, right=True)
    self.assertTrue(ht.equal(a, ht.array([[2, 3, 5], [2, 3, 5]])))
    self.assertEqual(a.dtype, ht.int64)
    self.assertEqual(a.shape, v.shape)

    # random data, split input, torch boundaries, int32 output
    boundaries, _ = torch.sort(torch.rand(5))
    v = torch.rand(6)
    t = torch.bucketize(v, boundaries, out_int32=True)

    v = ht.array(v, split=0)
    a = ht.bucketize(v, boundaries, out_int32=True)
    self.assertTrue(ht.equal(ht.resplit(a, None), ht.asarray(t)))
    self.assertEqual(a.dtype, ht.int32)

    # distributed boundaries are rejected; they can only be distributed with more than one process
    if ht.MPI_WORLD.size > 1:
        with self.assertRaises(RuntimeError):
            ht.bucketize(a, ht.array([0.0, 0.5, 1.0], split=0))

def test_cov(self):
x = ht.array([[0, 2], [1, 1], [2, 0]], dtype=ht.float, split=1).T
if x.comm.size < 3:
Expand Down Expand Up @@ -442,6 +469,34 @@ def test_cov(self):
with self.assertRaises(ValueError):
ht.cov(htdata, ddof=10000)

def test_digitize(self):
    """Check ``ht.digitize`` against fixed expectations and against ``np.digitize``."""
    x = ht.array([1.2, 10.0, 12.4, 15.5, 20.0])
    bins = ht.array([0, 5, 10, 15, 20])
    a = ht.digitize(x, bins, right=True)

    self.assertTrue(ht.equal(a, ht.array([1, 2, 3, 4, 4])))
    # assertEqual is required here: assertTrue(x, y) only checks the truthiness of x
    # and uses y as the failure message, so it would never compare the two values.
    self.assertEqual(a.dtype, ht.int64)
    self.assertEqual(a.shape, x.shape)

    a = ht.digitize(x, bins, right=False)
    self.assertTrue(ht.equal(a, ht.array([1, 3, 3, 4, 5])))
    self.assertEqual(a.dtype, ht.int64)
    self.assertEqual(a.shape, x.shape)

    # random data, split input, NumPy bins
    bins = np.sort(np.random.rand(5))
    x = np.random.rand(6)
    t = np.digitize(x, bins)

    x = ht.array(x, split=0)
    a = ht.digitize(x, bins)
    self.assertTrue(ht.equal(ht.resplit(a, None), ht.asarray(t)))
    self.assertEqual(a.dtype, ht.int64)
    self.assertEqual(a.shape, x.shape)

    # distributed bins are rejected; exercise digitize itself rather than bucketize
    if ht.MPI_WORLD.size > 1:
        with self.assertRaises(RuntimeError):
            ht.digitize(a, ht.array([0.0, 0.5, 1.0], split=0))

def test_histc(self):
# few entries and float64
c = torch.arange(4, dtype=torch.float64, device=self.device.torch_device)
Expand Down

0 comments on commit c27df91

Please sign in to comment.