Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat[next] Enable embedded field view in ffront_tests #1361

Merged
merged 19 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -343,7 +343,11 @@ markers = [
'uses_strided_neighbor_offset: tests that require backend support for strided neighbor offset',
'uses_tuple_args: tests that require backend support for tuple arguments',
'uses_tuple_returns: tests that require backend support for tuple results',
'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields'
'uses_zero_dimensional_fields: tests that require backend support for zero-dimensional fields',
'uses_cartesian_shift: tests that use a Cartesian connectivity',
'uses_unstructured_shift: tests that use a unstructured connectivity',
'uses_scan: tests that uses scan',
'checks_specific_error: tests that rely on the backend to produce a specific error message'
]
norecursedirs = ['dist', 'build', 'cpp_backend_tests/build*', '_local/*', '.*']
testpaths = 'tests'
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/_core/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,6 +446,9 @@ def shape(self) -> tuple[int, ...]:
def dtype(self) -> Any:
...

def astype(self, dtype: npt.DTypeLike) -> NDArrayObject:
...

def __getitem__(self, item: Any) -> NDArrayObject:
...

Expand Down
20 changes: 19 additions & 1 deletion src/gt4py/next/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,14 +133,24 @@ def __getitem__(self, index: int | slice) -> int | UnitRange:
else:
raise IndexError("UnitRange index out of range")

def __and__(self, other: Set[Any]) -> UnitRange:
def __and__(self, other: Set[int]) -> UnitRange:
if isinstance(other, UnitRange):
start = max(self.start, other.start)
stop = min(self.stop, other.stop)
return UnitRange(start, stop)
else:
raise NotImplementedError("Can only find the intersection between UnitRange instances.")

def __le__(self, other: Set[int]):
if isinstance(other, UnitRange):
return self.start >= other.start and self.stop <= other.stop
elif len(self) == Infinity.positive():
return False
else:
return Set.__le__(self, other)

__ge__ = __lt__ = __gt__ = lambda self, other: NotImplemented

def __str__(self) -> str:
return f"({self.start}:{self.stop})"

Expand Down Expand Up @@ -486,6 +496,14 @@ def __neg__(self) -> Field:
def __invert__(self) -> Field:
"""Only defined for `Field` of value type `bool`."""

@abc.abstractmethod
def __eq__(self, other: Any) -> Field: # type: ignore[override] # mypy wants return `bool`
...

@abc.abstractmethod
def __ne__(self, other: Any) -> Field: # type: ignore[override] # mypy wants return `bool`
...

@abc.abstractmethod
def __add__(self, other: Field | core_defs.ScalarT) -> Field:
...
Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/next/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def empty(
(3, 3)
"""
dtype = core_defs.dtype(dtype)
if allocator is None and device is None:
device = core_defs.Device(core_defs.DeviceType.CPU, device_id=0)
buffer = next_allocators.allocate(
domain, dtype, aligned_index=aligned_index, allocator=allocator, device=device
)
Expand Down
30 changes: 19 additions & 11 deletions src/gt4py/next/embedded/nd_array_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,25 +135,22 @@ def from_array(
/,
*,
domain: common.DomainLike,
dtype_like: Optional[core_defs.DType] = None, # TODO define DTypeLike
dtype: Optional[core_defs.DTypeLike] = None,
) -> NdArrayField:
domain = common.domain(domain)
xp = cls.array_ns

xp_dtype = None if dtype_like is None else xp.dtype(core_defs.dtype(dtype_like).scalar_type)
xp_dtype = None if dtype is None else xp.dtype(core_defs.dtype(dtype).scalar_type)
array = xp.asarray(data, dtype=xp_dtype)

if dtype_like is not None:
assert array.dtype.type == core_defs.dtype(dtype_like).scalar_type
if dtype is not None:
assert array.dtype.type == core_defs.dtype(dtype).scalar_type

assert issubclass(array.dtype.type, core_defs.SCALAR_TYPES)

assert all(isinstance(d, common.Dimension) for d in domain.dims), domain
assert len(domain) == array.ndim
assert all(
len(r) == s or (s == 1 and r == common.UnitRange.infinity())
for r, s in zip(domain.ranges, array.shape)
)
assert all(len(r) == s or s == 1 for r, s in zip(domain.ranges, array.shape))

return cls(domain, array)

Expand Down Expand Up @@ -194,6 +191,10 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Scala

__mod__ = __rmod__ = _make_builtin("mod", "mod")

__ne__ = _make_builtin("not_equal", "not_equal") # type: ignore[assignment] # mypy wants return `bool`
tehrengruber marked this conversation as resolved.
Show resolved Hide resolved

__eq__ = _make_builtin("equal", "equal") # type: ignore[assignment] # mypy wants return `bool`

def __and__(self, other: common.Field | core_defs.ScalarT) -> NdArrayField:
if self.dtype == core_defs.BoolDType():
return _make_builtin("logical_and", "logical_and")(self, other)
Expand Down Expand Up @@ -285,7 +286,7 @@ def _np_cp_setitem(
_nd_array_implementations = [np]


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, eq=False)
class NumPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = np

Expand All @@ -298,7 +299,7 @@ class NumPyArrayField(NdArrayField):
if cp:
_nd_array_implementations.append(cp)

@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, eq=False)
class CuPyArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = cp

Expand All @@ -310,7 +311,7 @@ class CuPyArrayField(NdArrayField):
if jnp:
_nd_array_implementations.append(jnp)

@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, eq=False)
class JaxArrayField(NdArrayField):
array_ns: ClassVar[ModuleType] = jnp

Expand Down Expand Up @@ -351,6 +352,13 @@ def _builtins_broadcast(
NdArrayField.register_builtin_func(fbuiltins.broadcast, _builtins_broadcast)


def _astype(field: NdArrayField, type_: type) -> NdArrayField:
return field.__class__.from_array(field.ndarray.astype(type_), domain=field.domain)


NdArrayField.register_builtin_func(fbuiltins.astype, _astype) # type: ignore[arg-type] # TODO(havogt) the registry should not be for any Field


def _get_slices_from_domain_slice(
domain: common.Domain,
domain_slice: common.Domain | Sequence[common.NamedRange | common.NamedIndex | Any],
Expand Down
55 changes: 36 additions & 19 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from gt4py._core import definitions as core_defs
from gt4py.eve import utils as eve_utils
from gt4py.eve.extended_typing import Any, Optional
from gt4py.next import allocators as next_allocators
from gt4py.next import allocators as next_allocators, common
from gt4py.next.common import Dimension, DimensionKind, GridType
from gt4py.next.ffront import (
dialect_ast_enums,
Expand Down Expand Up @@ -171,14 +171,14 @@ class Program:
past_node: past.Program
closure_vars: dict[str, Any]
definition: Optional[types.FunctionType] = None
backend: Optional[ppi.ProgramExecutor] = None
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND
grid_type: Optional[GridType] = None

@classmethod
def from_function(
cls,
definition: types.FunctionType,
backend: Optional[ppi.ProgramExecutor] = None,
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND,
grid_type: Optional[GridType] = None,
) -> Program:
source_def = SourceDefinition.from_function(definition)
Expand Down Expand Up @@ -282,27 +282,23 @@ def itir(self) -> itir.FencilDefinition:
)

def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs) -> None:
if (
self.backend is None and DEFAULT_BACKEND is None
): # TODO(havogt): for now enable embedded execution by setting DEFAULT_BACKEND to None
self.definition(*args, **kwargs)
return

rewritten_args, size_args, kwargs = self._process_args(args, kwargs)

if not self.backend:
if self.backend is None:
warnings.warn(
UserWarning(
f"Field View Program '{self.itir.id}': Using default ({DEFAULT_BACKEND}) backend."
f"Field View Program '{self.itir.id}': Using Python execution, consider selecting a perfomance backend."
)
)
backend = self.backend or DEFAULT_BACKEND

ppi.ensure_processor_kind(backend, ppi.ProgramExecutor)
self.definition(*rewritten_args, **kwargs)
return

ppi.ensure_processor_kind(self.backend, ppi.ProgramExecutor)
if "debug" in kwargs:
debug(self.itir)

backend(
self.backend(
self.itir,
*rewritten_args,
*size_args,
Expand Down Expand Up @@ -547,14 +543,14 @@ class FieldOperator(GTCallable, Generic[OperatorNodeT]):
foast_node: OperatorNodeT
closure_vars: dict[str, Any]
definition: Optional[types.FunctionType] = None
backend: Optional[ppi.ProgramExecutor] = None
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND
grid_type: Optional[GridType] = None

@classmethod
def from_function(
cls,
definition: types.FunctionType,
backend: Optional[ppi.ProgramExecutor] = None,
backend: Optional[ppi.ProgramExecutor] = DEFAULT_BACKEND,
grid_type: Optional[GridType] = None,
*,
operator_node_cls: type[OperatorNodeT] = foast.FieldOperator,
Expand Down Expand Up @@ -687,9 +683,9 @@ def __call__(
# if we are reaching this from a program call.
if "out" in kwargs:
out = kwargs.pop("out")
if "offset_provider" in kwargs:
offset_provider = kwargs.pop("offset_provider", None)
if self.backend is not None:
# "out" and "offset_provider" -> field_operator as program
offset_provider = kwargs.pop("offset_provider")
args, kwargs = type_info.canonicalize_arguments(self.foast_node.type, args, kwargs)
# TODO(tehrengruber): check all offset providers are given
# deduce argument types
Expand All @@ -705,13 +701,34 @@ def __call__(
)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
)
self.as_program(arg_types, kwarg_types)(
*args, out, offset_provider=offset_provider, **kwargs
)
return out

Just noticed this bug.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should field_operator as program return the out param? Maybe yes...

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would say no as programs don't return anything.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

than it was correct, right?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, the program does not return something, but the field operator which we are calling does.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, we return self.as_program()() i.e. whatever program does which is currently nothing therefore still correct and in case (for whatever reason) program() returns something, we will do the right thing implicitly.

else:
# "out" -> field_operator called from program in embedded execution
out.ndarray[:] = self.definition(*args, **kwargs).ndarray[:]
# TODO(egparedes): put offset_provider in ctxt var here when implementing remap
domain = kwargs.pop("domain", None)
res = self.definition(*args, **kwargs)
_tuple_assign_field(
out, res, domain=None if domain is None else common.domain(domain)
)
return
else:
# field_operator called from other field_operator in embedded execution
assert self.backend is None
return self.definition(*args, **kwargs)
havogt marked this conversation as resolved.
Show resolved Hide resolved


def _tuple_assign_field(
havogt marked this conversation as resolved.
Show resolved Hide resolved
target: tuple[common.Field | tuple, ...] | common.Field,
source: tuple[common.Field | tuple, ...] | common.Field,
domain: Optional[common.Domain],
):
if isinstance(target, tuple):
if not isinstance(source, tuple):
raise RuntimeError(f"Cannot assign {source} to {target}.")
for t, s in zip(target, source):
_tuple_assign_field(t, s, domain)
else:
domain = domain or target.domain
target[domain] = source[domain]


@typing.overload
def field_operator(
definition: types.FunctionType, *, backend: Optional[ppi.ProgramExecutor]
Expand Down
Loading