Skip to content

Commit

Permalink
[Lang] Simplify the swizzle generator (#7216)
Browse files Browse the repository at this point in the history
This PR replaces the `SwizzleGenerator.generate()` with a simpler
function to generate all swizzling patterns.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
neozhaoliang and pre-commit-ci[bot] authored Jan 19, 2023
1 parent f760fbb commit 2d30410
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 134 deletions.
32 changes: 24 additions & 8 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import numbers
from collections.abc import Iterable
from itertools import product

import numpy as np
from taichi._lib import core as ti_python_core
Expand All @@ -13,16 +14,32 @@
from taichi.lang.exception import (TaichiRuntimeError, TaichiSyntaxError,
TaichiTypeError)
from taichi.lang.field import Field, ScalarField, SNodeHostAccess
from taichi.lang.swizzle_generator import SwizzleGenerator
from taichi.lang.util import (cook_dtype, in_python_scope, python_scope,
taichi_scope, to_numpy_type, to_paddle_type,
to_pytorch_type, warning)
from taichi.types import primitive_types
from taichi.types.compound_types import CompoundType, TensorType


def _generate_swizzle_patterns(key_group: str, required_length=4):
"""Generate vector swizzle patterns from a given set of characters.
Example:
For `key_group=xyzw` and `required_length=4`, this function will return a
list consists of all possible strings (no repeats) in characters
`x`, `y`, `z`, `w` and of length<=4:
[`x`, `y`, `z`, `w`, `xx`, `xy`, `yx`, ..., `xxxx`, `xxxy`, `xyzw`, ...]
The length of the list will be 4 + 4x4 + 4x4x4 + 4x4x4x4 = 340.
"""
result = []
for k in range(1, required_length + 1):
result.extend(product(key_group, repeat=k))
result = [''.join(pat) for pat in result]
return result


def _gen_swizzles(cls):
swizzle_gen = SwizzleGenerator()
# https://www.khronos.org/opengl/wiki/Data_Type_(GLSL)#Swizzling
KEYGROUP_SET = ['xyzw', 'rgba', 'stpq']
cls._swizzle_to_keygroup = {}
Expand Down Expand Up @@ -66,14 +83,13 @@ def prop_setter(instance, value):
cls._swizzle_to_keygroup[attr] = key_group

for key_group in KEYGROUP_SET:
sw_patterns = swizzle_gen.generate(key_group, required_length=4)
sw_patterns = _generate_swizzle_patterns(key_group, required_length=4)
# len=1 accessors are handled specially above
sw_patterns = filter(lambda p: len(p) > 1, sw_patterns)
for pat in sw_patterns:
for prop_key in sw_patterns:
# Create a function for value capturing
def gen_property(pattern, key_group):
checker = cls._keygroup_to_checker[key_group]
prop_key = ''.join(pattern)

def prop_getter(instance):
checker(instance, pattern)
Expand All @@ -86,16 +102,16 @@ def prop_getter(instance):
def prop_setter(instance, value):
if len(pattern) != len(value):
raise TaichiRuntimeError(
f'value len does not match the swizzle pattern={prop_key}'
f'value len does not match the swizzle pattern={pattern}'
)
checker(instance, pattern)
for ch, val in zip(pattern, value):
instance[key_group.index(ch)] = val

prop = property(prop_getter, prop_setter)
return prop_key, prop
return prop

prop_key, prop = gen_property(pat, key_group)
prop = gen_property(prop_key, key_group)
setattr(cls, prop_key, prop)
cls._swizzle_to_keygroup[prop_key] = key_group
return cls
Expand Down
109 changes: 0 additions & 109 deletions python/taichi/lang/swizzle_generator.py

This file was deleted.

13 changes: 6 additions & 7 deletions tests/python/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@


def _get_matrix_swizzle_apis():
swizzle_gen = ti.lang.swizzle_generator.SwizzleGenerator()
swizzle_gen = ti.lang.matrix._generate_swizzle_patterns
KEMAP_SET = ['xyzw', 'rgba', 'stpq']
res = []
for key_group in KEMAP_SET:
sw_patterns = swizzle_gen.generate(key_group, required_length=4)
sw_patterns = map(lambda p: ''.join(p), sw_patterns)
sw_patterns = swizzle_gen(key_group, required_length=4)
res += sw_patterns
return sorted(res)

Expand Down Expand Up @@ -81,10 +80,10 @@ def _get_expected_matrix_apis():
'randn', 'random', 'raw_div', 'raw_mod', 'ref', 'rescale_index', 'reset',
'rgb_to_hex', 'root', 'round', 'rsqrt', 'select', 'set_logging_level',
'simt', 'sin', 'solve', 'sparse_matrix_builder', 'sqrt', 'static',
'static_assert', 'static_print', 'stop_grad', 'svd', 'swizzle_generator',
'sym_eig', 'sync', 'tan', 'tanh', 'template', 'tools', 'types', 'u16',
'u32', 'u64', 'u8', 'ui', 'uint16', 'uint32', 'uint64', 'uint8', 'vulkan',
'wasm', 'x64', 'x86_64', 'zero'
'static_assert', 'static_print', 'stop_grad', 'svd', 'sym_eig', 'sync',
'tan', 'tanh', 'template', 'tools', 'types', 'u16', 'u32', 'u64', 'u8',
'ui', 'uint16', 'uint32', 'uint64', 'uint8', 'vulkan', 'wasm', 'x64',
'x86_64', 'zero'
]
user_api[ti.ad] = [
'FwdMode', 'Tape', 'clear_all_gradients', 'grad_for', 'grad_replaced',
Expand Down
10 changes: 0 additions & 10 deletions tests/python/test_swizzle_generator.py

This file was deleted.

0 comments on commit 2d30410

Please sign in to comment.