Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pin numpy version in dev requirements, only register type transformer if installed #2485

Merged
merged 14 commits into from
Jun 18, 2024
8 changes: 7 additions & 1 deletion .github/workflows/pythonbuild.yml
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,12 @@ jobs:
os: [ubuntu-latest]
python-version: ${{fromJson(needs.detect-python-versions.outputs.python-versions)}}
pandas: ["pandas<2.0.0", "pandas>=2.0.0"]
include:
- numpy: "numpy<2.0.0"
pandas: "pandas<2.0.0"
- numpy: "numpy>=2.0.0"
pandas: "pandas>=2.0.0"

steps:
- uses: actions/checkout@v4
- name: 'Clear action cache'
Expand All @@ -141,7 +147,7 @@ jobs:
run: |
pip install uv
make setup-global-uv
uv pip install --system --force-reinstall "${{ matrix.pandas }}"
uv pip install --system --force-reinstall "${{ matrix.pandas }}" "${{ matrix.numpy }}"
eapolinario marked this conversation as resolved.
Show resolved Hide resolved
uv pip freeze
- name: Test with coverage
run: |
Expand Down
16 changes: 16 additions & 0 deletions flytekit/types/numpy/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,17 @@
from flytekit.loggers import logger

from .ndarray import NumpyArrayTransformer

try:
# isolate the exception to the numpy import
import numpy

_numpy_installed = True
except ImportError:
_numpy_installed = False


if _numpy_installed:
from .ndarray import NumpyArrayTransformer
else:
logger.info("We won't register NumpyArrayTransformer because numpy is not installed.")
53 changes: 32 additions & 21 deletions flytekit/types/schema/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from pathlib import Path
from typing import Type

import numpy as _np
from dataclasses_json import config
from marshmallow import fields
from mashumaro.mixins.json import DataClassJSONMixin
Expand All @@ -19,10 +18,41 @@
from flytekit.loggers import logger
from flytekit.models.literals import Literal, Scalar, Schema
from flytekit.models.types import LiteralType, SchemaType
from flytekit.types.numpy import _numpy_installed

T = typing.TypeVar("T")


SUPPORTED_SCHEMA_TYPES = {
int: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
float: SchemaType.SchemaColumn.SchemaColumnType.FLOAT,
bool: SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN,
datetime.datetime: SchemaType.SchemaColumn.SchemaColumnType.DATETIME,
datetime.timedelta: SchemaType.SchemaColumn.SchemaColumnType.DURATION,
str: SchemaType.SchemaColumn.SchemaColumnType.STRING,
}

if _numpy_installed:
import numpy as np

SUPPORTED_SCHEMA_TYPES.update(
{
np.int32: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
np.int64: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
np.uint32: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
np.uint64: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
np.float32: SchemaType.SchemaColumn.SchemaColumnType.FLOAT,
np.float64: SchemaType.SchemaColumn.SchemaColumnType.FLOAT,
np.bool_: SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN, # type: ignore
np.datetime64: SchemaType.SchemaColumn.SchemaColumnType.DATETIME,
np.timedelta64: SchemaType.SchemaColumn.SchemaColumnType.DURATION,
np.bytes_: SchemaType.SchemaColumn.SchemaColumnType.STRING,
np.str_: SchemaType.SchemaColumn.SchemaColumnType.STRING,
np.object_: SchemaType.SchemaColumn.SchemaColumnType.STRING,
}
)


class SchemaFormat(Enum):
"""
Represents the schema storage format (at rest).
Expand Down Expand Up @@ -319,26 +349,7 @@ def as_readonly(self) -> FlyteSchema:


class FlyteSchemaTransformer(TypeTransformer[FlyteSchema]):
_SUPPORTED_TYPES: typing.Dict[Type, SchemaType.SchemaColumn.SchemaColumnType] = {
_np.int32: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
_np.int64: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
_np.uint32: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
_np.uint64: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
int: SchemaType.SchemaColumn.SchemaColumnType.INTEGER,
_np.float32: SchemaType.SchemaColumn.SchemaColumnType.FLOAT,
_np.float64: SchemaType.SchemaColumn.SchemaColumnType.FLOAT,
float: SchemaType.SchemaColumn.SchemaColumnType.FLOAT,
_np.bool_: SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN, # type: ignore
bool: SchemaType.SchemaColumn.SchemaColumnType.BOOLEAN,
_np.datetime64: SchemaType.SchemaColumn.SchemaColumnType.DATETIME,
datetime.datetime: SchemaType.SchemaColumn.SchemaColumnType.DATETIME,
_np.timedelta64: SchemaType.SchemaColumn.SchemaColumnType.DURATION,
datetime.timedelta: SchemaType.SchemaColumn.SchemaColumnType.DURATION,
_np.bytes_: SchemaType.SchemaColumn.SchemaColumnType.STRING,
_np.str_: SchemaType.SchemaColumn.SchemaColumnType.STRING,
_np.object_: SchemaType.SchemaColumn.SchemaColumnType.STRING,
str: SchemaType.SchemaColumn.SchemaColumnType.STRING,
}
_SUPPORTED_TYPES: typing.Dict[Type, SchemaType.SchemaColumn.SchemaColumnType] = SUPPORTED_SCHEMA_TYPES

def __init__(self):
super().__init__("FlyteSchema Transformer", FlyteSchema)
Expand Down
35 changes: 22 additions & 13 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,28 +323,37 @@ def convert_schema_type_to_structured_dataset_type(


def get_supported_types():
import numpy as _np
from flytekit.types.numpy import _numpy_installed

_SUPPORTED_TYPES: typing.Dict[Type, LiteralType] = { # type: ignore
_np.int32: type_models.LiteralType(simple=type_models.SimpleType.INTEGER),
_np.int64: type_models.LiteralType(simple=type_models.SimpleType.INTEGER),
_np.uint32: type_models.LiteralType(simple=type_models.SimpleType.INTEGER),
_np.uint64: type_models.LiteralType(simple=type_models.SimpleType.INTEGER),
int: type_models.LiteralType(simple=type_models.SimpleType.INTEGER),
_np.float32: type_models.LiteralType(simple=type_models.SimpleType.FLOAT),
_np.float64: type_models.LiteralType(simple=type_models.SimpleType.FLOAT),
float: type_models.LiteralType(simple=type_models.SimpleType.FLOAT),
_np.bool_: type_models.LiteralType(simple=type_models.SimpleType.BOOLEAN), # type: ignore
bool: type_models.LiteralType(simple=type_models.SimpleType.BOOLEAN),
_np.datetime64: type_models.LiteralType(simple=type_models.SimpleType.DATETIME),
_datetime.datetime: type_models.LiteralType(simple=type_models.SimpleType.DATETIME),
_np.timedelta64: type_models.LiteralType(simple=type_models.SimpleType.DURATION),
_datetime.timedelta: type_models.LiteralType(simple=type_models.SimpleType.DURATION),
_np.bytes_: type_models.LiteralType(simple=type_models.SimpleType.STRING),
_np.str_: type_models.LiteralType(simple=type_models.SimpleType.STRING),
_np.object_: type_models.LiteralType(simple=type_models.SimpleType.STRING),
str: type_models.LiteralType(simple=type_models.SimpleType.STRING),
}

if _numpy_installed:
import numpy as _np

_SUPPORTED_TYPES.update(
{ # type: ignore
_np.int32: type_models.LiteralType(simple=type_models.SimpleType.INTEGER),
_np.int64: type_models.LiteralType(simple=type_models.SimpleType.INTEGER),
_np.uint32: type_models.LiteralType(simple=type_models.SimpleType.INTEGER),
_np.uint64: type_models.LiteralType(simple=type_models.SimpleType.INTEGER),
_np.float32: type_models.LiteralType(simple=type_models.SimpleType.FLOAT),
_np.float64: type_models.LiteralType(simple=type_models.SimpleType.FLOAT),
_np.bool_: type_models.LiteralType(simple=type_models.SimpleType.BOOLEAN), # type: ignore
_np.datetime64: type_models.LiteralType(simple=type_models.SimpleType.DATETIME),
_np.timedelta64: type_models.LiteralType(simple=type_models.SimpleType.DURATION),
_np.bytes_: type_models.LiteralType(simple=type_models.SimpleType.STRING),
_np.str_: type_models.LiteralType(simple=type_models.SimpleType.STRING),
_np.object_: type_models.LiteralType(simple=type_models.SimpleType.STRING),
}
)

return _SUPPORTED_TYPES


Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ dependencies = [
"marshmallow-enum",
"marshmallow-jsonschema>=0.12.0",
"mashumaro>=3.11",
"numpy<2",
"protobuf!=4.25.0",
"pyarrow",
"pygments",
Expand Down
Loading