Skip to content

Commit

Permalink
[lang] MatrixType refactor: Support vector swizzle (#6506)
Browse files Browse the repository at this point in the history
Issue: #5819

### Brief Summary

This PR only contains the valid case and a corresponding test. More
comprehensive tests will be enabled once #6425 is in.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Nov 3, 2022
1 parent 72d7b1f commit 10dd56b
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 5 deletions.
28 changes: 23 additions & 5 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
ReturnStatus)
from taichi.lang.ast.symbol_resolver import ASTResolver
from taichi.lang.exception import TaichiSyntaxError, TaichiTypeError
from taichi.lang.expr import Expr
from taichi.lang.expr import Expr, make_expr_group
from taichi.lang.field import Field
from taichi.lang.impl import current_cfg
from taichi.lang.matrix import Matrix, MatrixType, Vector, is_vector
Expand Down Expand Up @@ -776,10 +776,28 @@ def build_Attribute(ctx, node):
build_stmt(ctx, node.value)
if isinstance(node.value.ptr,
Expr) and not hasattr(node.value.ptr, node.attr):
# pylint: disable-msg=C0415
from taichi.lang import matrix_ops as tensor_ops
node.ptr = getattr(tensor_ops, node.attr)
setattr(node, 'caller', node.value.ptr)
if node.attr in Matrix._swizzle_to_keygroup:
keygroup = Matrix._swizzle_to_keygroup[node.attr]
attr_len = len(node.attr)
if attr_len == 1:
node.ptr = Expr(
_ti_core.subscript(
node.value.ptr.ptr,
make_expr_group(keygroup.index(node.attr)),
impl.get_runtime().get_current_src_info()))
else:
node.ptr = Expr(
_ti_core.subscript_with_multiple_indices(
node.value.ptr.ptr, [
make_expr_group(keygroup.index(ch))
for ch in node.attr
], (attr_len, ),
impl.get_runtime().get_current_src_info()))
else:
from taichi.lang import \
matrix_ops as tensor_ops # pylint: disable=C0415
node.ptr = getattr(tensor_ops, node.attr)
setattr(node, 'caller', node.value.ptr)
else:
node.ptr = getattr(node.value.ptr, node.attr)
return node.ptr
Expand Down
3 changes: 3 additions & 0 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def _gen_swizzles(cls):
swizzle_gen = SwizzleGenerator()
# https://www.khronos.org/opengl/wiki/Data_Type_(GLSL)#Swizzling
KEYGROUP_SET = ['xyzw', 'rgba', 'stpq']
cls._swizzle_to_keygroup = {}

def make_valid_attribs_checker(key_group):
def check(instance, pattern):
Expand Down Expand Up @@ -59,6 +60,7 @@ def prop_setter(instance, value):

prop = gen_property(attr, index, key_group)
setattr(cls, attr, prop)
cls._swizzle_to_keygroup[attr] = key_group

for key_group in KEYGROUP_SET:
sw_patterns = swizzle_gen.generate(key_group, required_length=4)
Expand Down Expand Up @@ -93,6 +95,7 @@ def prop_setter(instance, value):

prop_key, prop = gen_property(pat, key_group)
setattr(cls, prop_key, prop)
cls._swizzle_to_keygroup[prop_key] = key_group
return cls


Expand Down
11 changes: 11 additions & 0 deletions tests/python/test_vector_swizzle.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,14 @@ def test_vector_invalid_swizzle_patterns():
match=re.escape(
"value len does not match the swizzle pattern=xy")):
a.xy = [1, 2, 3]


@test_utils.test(real_matrix=True, real_matrix_scalarize=True)
def test_vector_swizzle_real_matrix_scalarize():
@ti.kernel
def foo() -> ti.types.vector(3, ti.i32):
v = ti.Vector([1, 2, 3])
v.zxy += [v.z, v.y, v.x]
return v

assert (foo() == ti.Vector([3, 3, 6])).all()

0 comments on commit 10dd56b

Please sign in to comment.