Skip to content

Commit

Permalink
FEAT-#7310: NumPy 2.0 support (#7312)
Browse files Browse the repository at this point in the history
Signed-off-by: Anatoly Myachev <[email protected]>
  • Loading branch information
anmyachev authored Jun 13, 2024
1 parent 5c1c6a2 commit 767e28e
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 53 deletions.
56 changes: 30 additions & 26 deletions modin/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,27 +11,28 @@
# ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.

import numpy
from packaging import version

from . import linalg
from .arr import array
from .array_creation import ones_like, tri, zeros_like
from .array_shaping import append, hstack, ravel, shape, split, transpose
from .constants import (
NAN,
NINF,
NZERO,
PINF,
PZERO,
Inf,
Infinity,
NaN,
e,
euler_gamma,
inf,
infty,
nan,
newaxis,
pi,
)
from .constants import e, euler_gamma, inf, nan, newaxis, pi

if version.parse(numpy.__version__) < version.parse("2.0.0b1"):
from .constants import (
NAN,
NINF,
NZERO,
PINF,
PZERO,
Inf,
Infinity,
NaN,
infty,
)

from .logic import (
all,
any,
Expand Down Expand Up @@ -151,18 +152,9 @@ def where(condition, x=None, y=None):
"amin",
"min",
"where",
"Inf",
"Infinity",
"NAN",
"NINF",
"NZERO",
"NaN",
"PINF",
"PZERO",
"e",
"euler_gamma",
"inf",
"infty",
"nan",
"newaxis",
"pi",
Expand All @@ -177,3 +169,15 @@ def where(condition, x=None, y=None):
"append",
"tri",
]
if version.parse(numpy.__version__) < version.parse("2.0.0b1"):
__all__ += [
"Inf",
"Infinity",
"NAN",
"NINF",
"NZERO",
"NaN",
"PINF",
"PZERO",
"infty",
]
46 changes: 19 additions & 27 deletions modin/numpy/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,31 @@
# ANY KIND, either express or implied. See the License for the specific language
# governing permissions and limitations under the License.

# flake8: noqa
from numpy import (
NAN,
NINF,
NZERO,
PINF,
PZERO,
Inf,
Infinity,
NaN,
e,
euler_gamma,
inf,
infty,
nan,
newaxis,
pi,
)
import numpy
from numpy import e, euler_gamma, inf, nan, newaxis, pi
from packaging import version

if version.parse(numpy.__version__) < version.parse("2.0.0b1"):
from numpy import NAN, NINF, NZERO, PINF, PZERO, Inf, Infinity, NaN, infty

__all__ = [
"Inf",
"Infinity",
"NAN",
"NINF",
"NZERO",
"NaN",
"PINF",
"PZERO",
"e",
"euler_gamma",
"inf",
"infty",
"nan",
"newaxis",
"pi",
]

if version.parse(numpy.__version__) < version.parse("2.0.0b1"):
__all__ += [
"Inf",
"Infinity",
"NAN",
"NINF",
"NZERO",
"NaN",
"PINF",
"PZERO",
"infty",
]

0 comments on commit 767e28e

Please sign in to comment.