Skip to content

Commit

Permalink
[Lang] Deprecate field_dim in ndarray annotation (taichi-dev#6687)
Browse files Browse the repository at this point in the history
Issue: taichi-dev#6572 

* Deprecate `field_dim` in ndarray type annotation, replace with `ndim`.
* Change tests accordingly. Examples are left unchanged.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
turbo0628 and pre-commit-ci[bot] authored Nov 22, 2022
1 parent 2d76386 commit 369e0c7
Show file tree
Hide file tree
Showing 12 changed files with 89 additions and 68 deletions.
5 changes: 2 additions & 3 deletions python/taichi/_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,8 +259,7 @@ def load_texture_from_numpy(tex: texture_type.rw_texture(num_dimensions=2,
num_channels=4,
channel_format=u8,
lod=0),
img: ndarray_type.ndarray(dtype=vec3,
field_dim=2)):
img: ndarray_type.ndarray(dtype=vec3, ndim=2)):
for i, j in img:
tex.store(
vector(2, i32)([i, j]),
Expand All @@ -273,7 +272,7 @@ def save_texture_to_numpy(tex: texture_type.rw_texture(num_dimensions=2,
num_channels=4,
channel_format=u8,
lod=0),
img: ndarray_type.ndarray(dtype=vec3, field_dim=2)):
img: ndarray_type.ndarray(dtype=vec3, ndim=2)):
for i, j in img:
img[i, j] = ops.round(tex.load(vector(2, i32)([i, j])).rgb * 255)

Expand Down
18 changes: 9 additions & 9 deletions python/taichi/aot/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,22 +47,22 @@ def produce_injected_args(kernel, symbolic_args=None):
if symbolic_args is not None:
element_shape = tuple(symbolic_args[i].element_shape)
element_dim = len(element_shape)
field_dim = symbolic_args[i].field_dim
ndim = symbolic_args[i].field_dim
dtype = symbolic_args[i].dtype()
else:
element_shape = anno.dtype.get_shape()
element_dim = anno.dtype.ndim
field_dim = anno.field_dim
ndim = anno.ndim
dtype = anno.dtype

if element_shape is None or field_dim is None:
if element_shape is None or ndim is None:
raise TaichiCompilationError(
'Please either specify both `element_shape` and `field_dim` '
'Please either specify both `element_shape` and `ndim` '
'in the param annotation, or provide an example '
f'ndarray for param={arg.name}')
if anno.field_dim is not None and field_dim != anno.field_dim:
if anno.ndim is not None and ndim != anno.ndim:
raise TaichiCompilationError(
f'{field_dim} from Arg {arg.name} doesn\'t match kernel\'s annotated field_dim={anno.field_dim}'
f'{ndim} from Arg {arg.name} doesn\'t match kernel\'s annotated ndim={anno.ndim}'
)
anno_dtype = anno.dtype
if isinstance(anno_dtype, MatrixType):
Expand All @@ -75,18 +75,18 @@ def produce_injected_args(kernel, symbolic_args=None):

if element_dim is None or element_dim == 0 or element_shape == (
1, ):
injected_args.append(ScalarNdarray(dtype, (2, ) * field_dim))
injected_args.append(ScalarNdarray(dtype, (2, ) * ndim))
elif element_dim == 1:
injected_args.append(
VectorNdarray(element_shape[0],
dtype=dtype,
shape=(2, ) * field_dim))
shape=(2, ) * ndim))
elif element_dim == 2:
injected_args.append(
MatrixNdarray(element_shape[0],
element_shape[1],
dtype=dtype,
shape=(2, ) * field_dim))
shape=(2, ) * ndim))
else:
raise RuntimeError('')
elif isinstance(anno, (TextureType, RWTextureType)):
Expand Down
28 changes: 20 additions & 8 deletions python/taichi/types/ndarray_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,14 @@ class NdarrayType:
Args:
dtype (Union[PrimitiveType, VectorType, MatrixType, NoneType], optional): None if not speicified.
[DERPRECATED] element_dim (Union[Int, NoneType], optional): None if not specified (will be treated as 0 for external arrays), 0 if scalar elements, 1 if vector elements, and 2 if matrix elements.
[DERPRECATED] element_shape (Union[Tuple[Int], NoneType]): None if not specified, shapes of each element. For example, element_shape must be 1d for vector and 2d tuple for matrix. This argument is ignored for external arrays for now.
field_dim (Union[Int, NoneType]): None if not specified, number of field dimensions. This argument is ignored for external arrays for now.
ndim (Union[Int, NoneType]): None if not specified, number of field dimensions. This argument is ignored for external arrays for now.
[DEPRECATED] element_dim (Union[Int, NoneType], optional): None if not specified (will be treated as 0 for external arrays), 0 if scalar elements, 1 if vector elements, and 2 if matrix elements.
[DEPRECATED] element_shape (Union[Tuple[Int], NoneType]): None if not specified, shapes of each element. For example, element_shape must be 1d for vector and 2d tuple for matrix. This argument is ignored for external arrays for now.
[DEPRECATED] field_dim (Union[Int, NoneType]): None if not specified, number of field dimensions. This argument is ignored for external arrays for now.
"""
def __init__(self,
dtype=None,
ndim=None,
element_dim=None,
element_shape=None,
field_dim=None):
Expand All @@ -83,7 +85,17 @@ def __init__(self,
else:
self.dtype = dtype

self.field_dim = field_dim
if field_dim is not None:
warnings.warn(
"The field_dim argument for ndarray will be deprecated in v1.4.0, use ndim instead.",
DeprecationWarning)
if ndim is not None:
raise ValueError(
"Cannot specify ndim and field_dim at the same time. The field_dim is going to be deprecated."
)
ndim = field_dim

self.ndim = ndim
self.layout = Layout.AOS

def check_matched(self, ndarray_type: NdarrayTypeMetadata):
Expand Down Expand Up @@ -121,12 +133,12 @@ def check_matched(self, ndarray_type: NdarrayTypeMetadata):
f"Expect element type {self.dtype} for Ndarray, but get {ndarray_type.element_type}"
)

# Check field dim shape match
if self.field_dim is not None and \
# Check ndim match
if self.ndim is not None and \
ndarray_type.shape is not None and \
self.field_dim != len(ndarray_type.shape):
self.ndim != len(ndarray_type.shape):
raise ValueError(
f"Invalid argument into ti.types.ndarray() - required field_dim={self.field_dim}, but {ndarray_type.element_type} is provided"
f"Invalid argument into ti.types.ndarray() - required ndim={self.ndim}, but {ndarray_type.element_type} is provided"
)


Expand Down
6 changes: 3 additions & 3 deletions tests/cpp/aot/python_scripts/graph_aot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ def compile_graph_aot(arch):
return

@ti.kernel
def run0(base: int, arr: ti.types.ndarray(field_dim=1, dtype=ti.i32)):
def run0(base: int, arr: ti.types.ndarray(ndim=1, dtype=ti.i32)):
for i in arr:
arr[i] += base + i

@ti.kernel
def run1(base: int, arr: ti.types.ndarray(field_dim=1, dtype=ti.i32)):
def run1(base: int, arr: ti.types.ndarray(ndim=1, dtype=ti.i32)):
for i in arr:
arr[i] += base + i

@ti.kernel
def run2(base: int, arr: ti.types.ndarray(field_dim=1, dtype=ti.i32)):
def run2(base: int, arr: ti.types.ndarray(ndim=1, dtype=ti.i32)):
for i in arr:
arr[i] += base + i

Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def test_aot_bind_id():
density1 = ti.ndarray(dtype=ti.f32, shape=(8, 8))

@ti.kernel
def init(x: ti.f32, density1: ti.types.ndarray(field_dim=2)):
def init(x: ti.f32, density1: ti.types.ndarray(ndim=2)):
for i, j in density1:
density[i, j] = x
density1[i, j] = x + 1
Expand Down
22 changes: 11 additions & 11 deletions tests/python/test_argument.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,17 +208,17 @@ def test_args_with_many_ndarrays():
def ti_import_cluster_data(
center: ti.types.vector(3,
ti.f32), particle_num: int, cluster_num: int,
permu_num: int, particlePosition: ti.types.ndarray(field_dim=1),
outClusterPosition: ti.types.ndarray(field_dim=1),
outClusterOffsets: ti.types.ndarray(field_dim=1),
outClusterSizes: ti.types.ndarray(field_dim=1),
outClusterIndices: ti.types.ndarray(field_dim=1),
particle_pos: ti.types.ndarray(field_dim=1),
particle_prev_pos: ti.types.ndarray(field_dim=1),
particle_rest_pos: ti.types.ndarray(field_dim=1),
cluster_rest_mass_center: ti.types.ndarray(field_dim=1),
cluster_begin: ti.types.ndarray(field_dim=1),
particle_index: ti.types.ndarray(field_dim=1)):
permu_num: int, particlePosition: ti.types.ndarray(ndim=1),
outClusterPosition: ti.types.ndarray(ndim=1),
outClusterOffsets: ti.types.ndarray(ndim=1),
outClusterSizes: ti.types.ndarray(ndim=1),
outClusterIndices: ti.types.ndarray(ndim=1),
particle_pos: ti.types.ndarray(ndim=1),
particle_prev_pos: ti.types.ndarray(ndim=1),
particle_rest_pos: ti.types.ndarray(ndim=1),
cluster_rest_mass_center: ti.types.ndarray(ndim=1),
cluster_begin: ti.types.ndarray(ndim=1),
particle_index: ti.types.ndarray(ndim=1)):

added_permu_num = outClusterIndices.shape[0]

Expand Down
13 changes: 13 additions & 0 deletions tests/python/test_deprecation.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,16 @@ def test_deprecate_element_dim_ndarray_annotation():
@ti.kernel
def func(x: ti.types.ndarray(element_dim=2)):
pass


@test_utils.test()
def test_deprecate_field_dim_ndarray_annotation():
with pytest.warns(
DeprecationWarning,
match=
"The field_dim argument for ndarray will be deprecated in v1.4.0, use ndim instead."
):

@ti.kernel
def func(x: ti.types.ndarray(field_dim=(16, 16))):
pass
8 changes: 4 additions & 4 deletions tests/python/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,19 +365,19 @@ def _test_func_ndarray_arg():
vec3 = ti.types.vector(3, ti.f32)

@ti.func
def test(a: ti.types.ndarray(field_dim=1)):
def test(a: ti.types.ndarray(ndim=1)):
a[0] = [100, 100, 100]

@ti.kernel
def test_k(x: ti.types.ndarray(field_dim=1)):
def test_k(x: ti.types.ndarray(ndim=1)):
test(x)

@ti.func
def test_error_func(a: ti.types.ndarray(dtype=ti.math.vec2, field_dim=1)):
def test_error_func(a: ti.types.ndarray(dtype=ti.math.vec2, ndim=1)):
a[0] = [100, 100]

@ti.kernel
def test_error(x: ti.types.ndarray(field_dim=1)):
def test_error(x: ti.types.ndarray(ndim=1)):
test_error_func(x)

arr = ti.ndarray(vec3, shape=(4))
Expand Down
34 changes: 17 additions & 17 deletions tests/python/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_ndarray_int():
n = 4

@ti.kernel
def test(pos: ti.types.ndarray(dtype=ti.i32, field_dim=1)):
def test(pos: ti.types.ndarray(dtype=ti.i32, ndim=1)):
for i in range(n):
pos[i] = 1

Expand All @@ -39,7 +39,7 @@ def test(pos: ti.types.ndarray(dtype=ti.i32, field_dim=1)):
@test_utils.test(arch=supported_archs_cgraph)
def test_ndarray_1dim_scalar():
@ti.kernel
def ti_test_debug(arr: ti.types.ndarray(field_dim=1)):
def ti_test_debug(arr: ti.types.ndarray(ndim=1)):
arr[0] = 0

debug_arr = ti.ndarray(ti.i32, shape=5)
Expand All @@ -56,7 +56,7 @@ def ti_test_debug(arr: ti.types.ndarray(field_dim=1)):
@test_utils.test(arch=supported_archs_cgraph)
def test_ndarray_0dim():
@ti.kernel
def test(pos: ti.types.ndarray(dtype=ti.i32, field_dim=0)):
def test(pos: ti.types.ndarray(dtype=ti.i32, ndim=0)):
pos[None] = 1

sym_pos = ti.graph.Arg(ti.graph.ArgKind.NDARRAY,
Expand All @@ -77,7 +77,7 @@ def test_ndarray_float():
n = 4

@ti.kernel
def test(pos: ti.types.ndarray(field_dim=1)):
def test(pos: ti.types.ndarray(ndim=1)):
for i in range(n):
pos[i] = 2.5

Expand All @@ -99,7 +99,7 @@ def test_arg_mismatched_field_dim():
n = 4

@ti.kernel
def test(pos: ti.types.ndarray(field_dim=1)):
def test(pos: ti.types.ndarray(ndim=1)):
for i in range(n):
pos[i] = 2.5

Expand All @@ -109,7 +109,7 @@ def test(pos: ti.types.ndarray(field_dim=1)):
field_dim=2)
g_init = ti.graph.GraphBuilder()
with pytest.raises(TaichiCompilationError,
match="doesn't match kernel's annotated field_dim"):
match="doesn't match kernel's annotated ndim"):
g_init.dispatch(test, sym_pos)


Expand All @@ -118,7 +118,7 @@ def test_arg_mismatched_field_dim_ndarray():
n = 4

@ti.kernel
def test(pos: ti.types.ndarray(field_dim=1)):
def test(pos: ti.types.ndarray(ndim=1)):
for i in range(n):
pos[i] = 2.5

Expand All @@ -137,7 +137,7 @@ def test_repeated_arg_name():
n = 4

@ti.kernel
def test1(pos: ti.types.ndarray(field_dim=1)):
def test1(pos: ti.types.ndarray(ndim=1)):
for i in range(n):
pos[i] = 2.5

Expand All @@ -163,7 +163,7 @@ def test_arg_mismatched_scalar_dtype():
n = 4

@ti.kernel
def test(pos: ti.types.ndarray(field_dim=1), val: ti.f32):
def test(pos: ti.types.ndarray(ndim=1), val: ti.f32):
for i in range(n):
pos[i] = val

Expand All @@ -180,7 +180,7 @@ def test_arg_mismatched_ndarray_dtype():
n = 4

@ti.kernel
def test(pos: ti.types.ndarray(dtype=ti.f32, field_dim=1)):
def test(pos: ti.types.ndarray(dtype=ti.f32, ndim=1)):
for i in range(n):
pos[i] = 2.5

Expand All @@ -196,7 +196,7 @@ def test_ndarray_dtype_mismatch_runtime():
n = 4

@ti.kernel
def test(pos: ti.types.ndarray(field_dim=1)):
def test(pos: ti.types.ndarray(ndim=1)):
for i in range(n):
pos[i] = 2.5

Expand All @@ -216,7 +216,7 @@ def test(pos: ti.types.ndarray(field_dim=1)):
def build_graph_vector(N, dtype):
@ti.kernel
def vector_sum(mat: ti.types.vector(N, dtype),
res: ti.types.ndarray(dtype=dtype, field_dim=1)):
res: ti.types.ndarray(dtype=dtype, ndim=1)):
res[0] = mat.sum() + mat[2]

sym_A = ti.graph.Arg(ti.graph.ArgKind.MATRIX, 'mat',
Expand All @@ -231,7 +231,7 @@ def vector_sum(mat: ti.types.vector(N, dtype),
def build_graph_matrix(N, dtype):
@ti.kernel
def matrix_sum(mat: ti.types.matrix(N, 2, dtype),
res: ti.types.ndarray(dtype=dtype, field_dim=1)):
res: ti.types.ndarray(dtype=dtype, ndim=1)):
res[0] = mat.sum()

sym_A = ti.graph.Arg(ti.graph.ArgKind.MATRIX, 'mat',
Expand Down Expand Up @@ -287,7 +287,7 @@ def test_vector_float():
@test_utils.test(arch=supported_archs_cgraph)
def test_arg_float(dt):
@ti.kernel
def foo(a: dt, b: ti.types.ndarray(dtype=dt, field_dim=1)):
def foo(a: dt, b: ti.types.ndarray(dtype=dt, ndim=1)):
b[0] = a

k = ti.ndarray(dt, shape=(1, ))
Expand All @@ -309,7 +309,7 @@ def foo(a: dt, b: ti.types.ndarray(dtype=dt, field_dim=1)):
@test_utils.test(arch=supported_archs_cgraph, exclude=[(ti.vulkan, "Darwin")])
def test_arg_int(dt):
@ti.kernel
def foo(a: dt, b: ti.types.ndarray(dtype=dt, field_dim=1)):
def foo(a: dt, b: ti.types.ndarray(dtype=dt, ndim=1)):
b[0] = a

k = ti.ndarray(dt, shape=(1, ))
Expand All @@ -331,7 +331,7 @@ def foo(a: dt, b: ti.types.ndarray(dtype=dt, field_dim=1)):
@test_utils.test(arch=ti.vulkan)
def test_arg_short(dt):
@ti.kernel
def foo(a: dt, b: ti.types.ndarray(dtype=dt, field_dim=1)):
def foo(a: dt, b: ti.types.ndarray(dtype=dt, ndim=1)):
b[0] = a

k = ti.ndarray(dt, shape=(1, ))
Expand Down Expand Up @@ -362,7 +362,7 @@ def make_texture(tex: ti.types.rw_texture(num_dimensions=2,
tex.store(ti.Vector([i, j]), ti.Vector([0.1, 0.0, 0.0, 0.0]))

@ti.kernel
def paint(t: ti.f32, pixels: ti.types.ndarray(field_dim=2),
def paint(t: ti.f32, pixels: ti.types.ndarray(ndim=2),
tex: ti.types.texture(num_dimensions=2)):
for i, j in pixels:
uv = ti.Vector([i / res[0], j / res[1]])
Expand Down
Loading

0 comments on commit 369e0c7

Please sign in to comment.