Skip to content

Commit

Permalink
[vulkan] Support texture type args in aot add_kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
Ailing Zhang committed Dec 2, 2022
1 parent 8bd91a7 commit cc2bb67
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 11 deletions.
16 changes: 5 additions & 11 deletions python/taichi/aot/utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from taichi.lang._ndarray import ScalarNdarray
from taichi.lang._texture import FORMAT2TY_CH, Texture
from taichi.lang._texture import Texture
from taichi.lang.exception import TaichiCompilationError
from taichi.lang.matrix import Matrix, MatrixNdarray, MatrixType, VectorNdarray
from taichi.lang.util import cook_dtype
from taichi.types.annotations import template
from taichi.types.ndarray_type import NdarrayType
from taichi.types.texture_type import RWTextureType, TextureType
from taichi.types.texture_type import TY_CH2FORMAT, RWTextureType, TextureType

template_types = (NdarrayType, template)
template_types = (NdarrayType, TextureType, template)


def check_type_match(lhs, rhs):
Expand Down Expand Up @@ -99,14 +99,8 @@ def produce_injected_args(kernel, symbolic_args=None):
'Texture type annotation doesn\'t have enough information for aot. Please either specify the channel_format, shape and num_channels in the graph arg declaration.'
)
texture_shape = tuple(symbolic_args[i].texture_shape)
# FIXME: (penguinliong) dtype + num_channels -> texel format.
fmt = None
for (fmt2, (channel_format2,
num_channels2)) in FORMAT2TY_CH.items():
if channel_format2 == symbolic_args[i].channel_format(
) and num_channels2 == symbolic_args[i].num_channels:
fmt = fmt2
break
fmt = TY_CH2FORMAT[(symbolic_args[i].channel_format,
symbolic_args[i].num_channels)]
injected_args.append(Texture(fmt, texture_shape))
elif isinstance(anno, MatrixType):
if not isinstance(symbolic_args[i], list):
Expand Down
17 changes: 17 additions & 0 deletions tests/python/test_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -665,3 +665,20 @@ def write(tex: ti.types.rw_texture(num_dimensions=2,
m.add_kernel(write)
with tempfile.TemporaryDirectory() as tmpdir:
m.save(tmpdir)


@test_utils.test(arch=[ti.vulkan])
def test_read_kernel_with_texture():
@ti.kernel
def read(tex: ti.types.texture(num_dimensions=2), arr: ti.types.ndarray()):
for i, j in arr:
arr[i, j] = tex.fetch(ti.Vector([i, j]), 0).x

res = (128, 128)
tex = ti.Texture(ti.Format.r32f, res)
arr = ti.ndarray(ti.f32, res)

m = ti.aot.Module()
m.add_kernel(read, template_args={"tex": tex, "arr": arr})
with tempfile.TemporaryDirectory() as tmpdir:
m.save(tmpdir)

0 comments on commit cc2bb67

Please sign in to comment.