diff --git a/python/taichi/aot/utils.py b/python/taichi/aot/utils.py index 1920b65ad5f4d4..4195e7ba3bd9a1 100644 --- a/python/taichi/aot/utils.py +++ b/python/taichi/aot/utils.py @@ -1,13 +1,13 @@ from taichi.lang._ndarray import ScalarNdarray -from taichi.lang._texture import FORMAT2TY_CH, Texture +from taichi.lang._texture import Texture from taichi.lang.exception import TaichiCompilationError from taichi.lang.matrix import Matrix, MatrixNdarray, MatrixType, VectorNdarray from taichi.lang.util import cook_dtype from taichi.types.annotations import template from taichi.types.ndarray_type import NdarrayType -from taichi.types.texture_type import RWTextureType, TextureType +from taichi.types.texture_type import TY_CH2FORMAT, RWTextureType, TextureType -template_types = (NdarrayType, template) +template_types = (NdarrayType, TextureType, template) def check_type_match(lhs, rhs): @@ -99,14 +99,8 @@ def produce_injected_args(kernel, symbolic_args=None): 'Texture type annotation doesn\'t have enough information for aot. Please either specify the channel_format, shape and num_channels in the graph arg declaration.' ) texture_shape = tuple(symbolic_args[i].texture_shape) - # FIXME: (penguinliong) dtype + num_channels -> texel format. - fmt = None - for (fmt2, (channel_format2, - num_channels2)) in FORMAT2TY_CH.items(): - if channel_format2 == symbolic_args[i].channel_format( - ) and num_channels2 == symbolic_args[i].num_channels: - fmt = fmt2 - break + fmt = TY_CH2FORMAT[(symbolic_args[i].channel_format, + symbolic_args[i].num_channels)] injected_args.append(Texture(fmt, texture_shape)) elif isinstance(anno, MatrixType): if not isinstance(symbolic_args[i], list): diff --git a/tests/python/test_aot.py b/tests/python/test_aot.py index f25914e8ca4ddc..6dbc91c729f06f 100644 --- a/tests/python/test_aot.py +++ b/tests/python/test_aot.py @@ -665,3 +665,20 @@ def write(tex: ti.types.rw_texture(num_dimensions=2, m.add_kernel(write) with tempfile.TemporaryDirectory() as tmpdir: m.save(tmpdir) + + +@test_utils.test(arch=[ti.vulkan]) +def test_read_kernel_with_texture(): + @ti.kernel + def read(tex: ti.types.texture(num_dimensions=2), arr: ti.types.ndarray()): + for i, j in arr: + arr[i, j] = tex.fetch(ti.Vector([i, j]), 0).x + + res = (128, 128) + tex = ti.Texture(ti.Format.r32f, res) + arr = ti.ndarray(ti.f32, res) + + m = ti.aot.Module() + m.add_kernel(read, template_args={"tex": tex, "arr": arr}) + with tempfile.TemporaryDirectory() as tmpdir: + m.save(tmpdir)