Skip to content

Commit

Permalink
Merge pull request #792 from helmholtz-analytics/features/783-api-rel…
Browse files Browse the repository at this point in the history
…ational

Features/783 api relational
  • Loading branch information
lenablind authored Jun 14, 2021
2 parents 3338b9d + afd0506 commit e3ac1c1
Show file tree
Hide file tree
Showing 2 changed files with 37 additions and 1 deletion.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ Example on 2 processes:
- [#728](https://github.com/helmholtz-analytics/heat/pull/728) New feature: `nn.DataParallelMultiGPU` which uses `torch.distributed` for local communication (for use with `optim.DASO`)
- [#728](https://github.com/helmholtz-analytics/heat/pull/728) New feature: `optim.DetectMetricPlateau` detects when a given metric plateaus.

### Relational
- [#792](https://github.com/helmholtz-analytics/heat/pull/792) API extension (aliases): `greater`,`greater_equal`, `less`, `less_equal`, `not_equal`

### Statistical Functions
- [#679](https://github.com/helmholtz-analytics/heat/pull/679) New feature: ``histc()`` and ``histogram()``

Expand Down
35 changes: 34 additions & 1 deletion heat/core/relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,20 @@
from . import dndarray
from . import types

__all__ = ["eq", "equal", "ge", "gt", "le", "lt", "ne"]
__all__ = [
"eq",
"equal",
"ge",
"greater",
"greater_equal",
"gt",
"le",
"less",
"less_equal",
"lt",
"ne",
"not_equal",
]


def eq(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
Expand Down Expand Up @@ -139,6 +152,10 @@ def ge(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarr
DNDarray.__ge__ = lambda self, other: ge(self, other)
DNDarray.__ge__.__doc__ = ge.__doc__

# alias
greater_equal = ge
greater_equal.__doc__ = ge.__doc__


def gt(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
"""
Expand Down Expand Up @@ -184,6 +201,10 @@ def gt(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarr
DNDarray.__gt__ = lambda self, other: gt(self, other)
DNDarray.__gt__.__doc__ = gt.__doc__

# alias
greater = gt
greater.__doc__ = gt.__doc__


def le(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
"""
Expand Down Expand Up @@ -229,6 +250,10 @@ def le(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarr
DNDarray.__le__ = lambda self, other: le(self, other)
DNDarray.__le__.__doc__ = le.__doc__

# alias
less_equal = le
less_equal.__doc__ = le.__doc__


def lt(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
"""
Expand Down Expand Up @@ -274,6 +299,10 @@ def lt(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarr
DNDarray.__lt__ = lambda self, other: lt(self, other)
DNDarray.__lt__.__doc__ = lt.__doc__

# alias
less = lt
less.__doc__ = lt.__doc__


def ne(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarray:
"""
Expand Down Expand Up @@ -318,3 +347,7 @@ def ne(x: Union[DNDarray, float, int], y: Union[DNDarray, float, int]) -> DNDarr

DNDarray.__ne__ = lambda self, other: ne(self, other)
DNDarray.__ne__.__doc__ = ne.__doc__

# alias
not_equal = ne
not_equal.__doc__ = ne.__doc__

0 comments on commit e3ac1c1

Please sign in to comment.