From 2d3041078fb008fcf2c2879a521a8b60d2edd8ad Mon Sep 17 00:00:00 2001 From: Zhao Liang Date: Thu, 19 Jan 2023 22:28:21 +0800 Subject: [PATCH] [Lang] Simplify the swizzle generator (#7216) This PR replaces the `SwizzleGenerator.generate()` with a simpler function to generate all swizzling patterns. Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- python/taichi/lang/matrix.py | 32 +++++-- python/taichi/lang/swizzle_generator.py | 109 ------------------------ tests/python/test_api.py | 13 ++- tests/python/test_swizzle_generator.py | 10 --- 4 files changed, 30 insertions(+), 134 deletions(-) delete mode 100644 python/taichi/lang/swizzle_generator.py delete mode 100644 tests/python/test_swizzle_generator.py diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 38164e85aea19..2cc5c75c352ac 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -1,6 +1,7 @@ import functools import numbers from collections.abc import Iterable +from itertools import product import numpy as np from taichi._lib import core as ti_python_core @@ -13,7 +14,6 @@ from taichi.lang.exception import (TaichiRuntimeError, TaichiSyntaxError, TaichiTypeError) from taichi.lang.field import Field, ScalarField, SNodeHostAccess -from taichi.lang.swizzle_generator import SwizzleGenerator from taichi.lang.util import (cook_dtype, in_python_scope, python_scope, taichi_scope, to_numpy_type, to_paddle_type, to_pytorch_type, warning) @@ -21,8 +21,25 @@ from taichi.types.compound_types import CompoundType, TensorType +def _generate_swizzle_patterns(key_group: str, required_length=4): + """Generate vector swizzle patterns from a given set of characters. + + Example: + + For `key_group=xyzw` and `required_length=4`, this function will return a + list consists of all possible strings (no repeats) in characters + `x`, `y`, `z`, `w` and of length<=4: + [`x`, `y`, `z`, `w`, `xx`, `xy`, `yx`, ..., `xxxx`, `xxxy`, `xyzw`, ...] + The length of the list will be 4 + 4x4 + 4x4x4 + 4x4x4x4 = 340. + """ + result = [] + for k in range(1, required_length + 1): + result.extend(product(key_group, repeat=k)) + result = [''.join(pat) for pat in result] + return result + + def _gen_swizzles(cls): - swizzle_gen = SwizzleGenerator() # https://www.khronos.org/opengl/wiki/Data_Type_(GLSL)#Swizzling KEYGROUP_SET = ['xyzw', 'rgba', 'stpq'] cls._swizzle_to_keygroup = {} @@ -66,14 +83,13 @@ def prop_setter(instance, value): cls._swizzle_to_keygroup[attr] = key_group for key_group in KEYGROUP_SET: - sw_patterns = swizzle_gen.generate(key_group, required_length=4) + sw_patterns = _generate_swizzle_patterns(key_group, required_length=4) # len=1 accessors are handled specially above sw_patterns = filter(lambda p: len(p) > 1, sw_patterns) - for pat in sw_patterns: + for prop_key in sw_patterns: # Create a function for value capturing def gen_property(pattern, key_group): checker = cls._keygroup_to_checker[key_group] - prop_key = ''.join(pattern) def prop_getter(instance): checker(instance, pattern) @@ -86,16 +102,16 @@ def prop_getter(instance): def prop_setter(instance, value): if len(pattern) != len(value): raise TaichiRuntimeError( - f'value len does not match the swizzle pattern={prop_key}' + f'value len does not match the swizzle pattern={pattern}' ) checker(instance, pattern) for ch, val in zip(pattern, value): instance[key_group.index(ch)] = val prop = property(prop_getter, prop_setter) - return prop_key, prop + return prop - prop_key, prop = gen_property(pat, key_group) + prop = gen_property(prop_key, key_group) setattr(cls, prop_key, prop) cls._swizzle_to_keygroup[prop_key] = key_group return cls diff --git a/python/taichi/lang/swizzle_generator.py b/python/taichi/lang/swizzle_generator.py deleted file mode 100644 index 219be47aa2507..0000000000000 --- a/python/taichi/lang/swizzle_generator.py +++ /dev/null @@ -1,109 +0,0 @@ -from collections import namedtuple -from itertools import combinations, permutations -from typing import Iterable, List, Tuple - -NRE_LEN = namedtuple('NRE_LEN', ['num_required_elems', 'length']) - - -def generate_num_required_elems_required_len_map(max_unique_elems=5): - ''' - For example, if we want a sequence of length 4 to be composed by {'x', 'y'}, - we have the following options: - - * 'xxxx': 4 'x', 0 'y' - * 'xyyy': 1 'x', 3 'y' - * 'xxyy': 2 'x', 2 'y' - * 'xxxy': 3 'x', 1 'y' - * 'yyyy': 0 'x', 4 'y' - - Each of these pattern is a seed. We can then do a permutation on it to get - all the patterns for this seed. - - NRE_LEN(2, 4) maps to [(4, 0), (1, 3), (2, 2), (3, 1), (0, 4)] - ''' - class InvalidPattern(Exception): - pass - - m = {} - m[(0, 0)] = () - - def _gen_impl(num_required_elems: int, required_len: int): - mkey = NRE_LEN(num_required_elems, required_len) - try: - return m[mkey] - except KeyError: - pass - invalid_pat = InvalidPattern(f'{num_required_elems} {required_len}') - if num_required_elems > required_len: - raise invalid_pat - if num_required_elems == 0: - if required_len == 0: - return [] - raise invalid_pat - if num_required_elems == 1: - if required_len > 0: - m[mkey] = ((required_len, ), ) - return m[mkey] - raise invalid_pat - - res = [] - for n in range(1, required_len + 1): - try: - cur = _gen_impl(num_required_elems - 1, required_len - n) - res += [(n, ) + t for t in cur] - except InvalidPattern: - pass - res = tuple(res) - m[mkey] = res - return res - - upperbound = max_unique_elems + 1 - for num_req in range(1, upperbound): - for required_len in range(num_req, upperbound): - _gen_impl(num_req, required_len) - - return m - - -class SwizzleGenerator: - def __init__(self, max_unique_elems=4): - self._nrel_map = generate_num_required_elems_required_len_map( - max_unique_elems) - - def generate(self, accessors: Iterable[str], - required_length: int) -> List[Tuple[int, ...]]: - res = [] - for l in range(required_length): - res += self._gen_for_length(accessors, l + 1) - return res - - def _gen_for_length(self, accessors, required_length): - acc_list = list(accessors) - res = [] - for num_required_elems in range(1, required_length + 1): - cur_len_patterns = set() - for subacc in combinations(acc_list, num_required_elems): - nrel_key = NRE_LEN(num_required_elems, required_length) - nrel_vals = self._nrel_map[nrel_key] - seed_patterns = self._generate_seed_patterns(subacc, nrel_vals) - for sp in seed_patterns: - for p in permutations(sp): - cur_len_patterns.add(p) - res += sorted(list(cur_len_patterns)) - return res - - @staticmethod - def _generate_seed_patterns(acc, nrel_vals): - res = [] - for val in nrel_vals: - assert len(acc) == len(val) - seed = [] - for char, vi in zip(acc, val): - seed += [char] * vi - res.append(tuple(seed)) - return res - - -__all__ = [ - 'SwizzleGenerator', -] diff --git a/tests/python/test_api.py b/tests/python/test_api.py index 9aaa9d5328c0a..6fdf9fec85a95 100644 --- a/tests/python/test_api.py +++ b/tests/python/test_api.py @@ -7,12 +7,11 @@ def _get_matrix_swizzle_apis(): - swizzle_gen = ti.lang.swizzle_generator.SwizzleGenerator() + swizzle_gen = ti.lang.matrix._generate_swizzle_patterns KEMAP_SET = ['xyzw', 'rgba', 'stpq'] res = [] for key_group in KEMAP_SET: - sw_patterns = swizzle_gen.generate(key_group, required_length=4) - sw_patterns = map(lambda p: ''.join(p), sw_patterns) + sw_patterns = swizzle_gen(key_group, required_length=4) res += sw_patterns return sorted(res) @@ -81,10 +80,10 @@ def _get_expected_matrix_apis(): 'randn', 'random', 'raw_div', 'raw_mod', 'ref', 'rescale_index', 'reset', 'rgb_to_hex', 'root', 'round', 'rsqrt', 'select', 'set_logging_level', 'simt', 'sin', 'solve', 'sparse_matrix_builder', 'sqrt', 'static', - 'static_assert', 'static_print', 'stop_grad', 'svd', 'swizzle_generator', - 'sym_eig', 'sync', 'tan', 'tanh', 'template', 'tools', 'types', 'u16', - 'u32', 'u64', 'u8', 'ui', 'uint16', 'uint32', 'uint64', 'uint8', 'vulkan', - 'wasm', 'x64', 'x86_64', 'zero' + 'static_assert', 'static_print', 'stop_grad', 'svd', 'sym_eig', 'sync', + 'tan', 'tanh', 'template', 'tools', 'types', 'u16', 'u32', 'u64', 'u8', + 'ui', 'uint16', 'uint32', 'uint64', 'uint8', 'vulkan', 'wasm', 'x64', + 'x86_64', 'zero' ] user_api[ti.ad] = [ 'FwdMode', 'Tape', 'clear_all_gradients', 'grad_for', 'grad_replaced', diff --git a/tests/python/test_swizzle_generator.py b/tests/python/test_swizzle_generator.py deleted file mode 100644 index f908a0671c114..0000000000000 --- a/tests/python/test_swizzle_generator.py +++ /dev/null @@ -1,10 +0,0 @@ -import pytest -from taichi.lang.swizzle_generator import SwizzleGenerator - - -def test_swizzle_gen(): - sg = SwizzleGenerator(max_unique_elems=4) - pats = sg.generate('xyzw', 4) - uniq_pats = set(pats) - # https://jojendersie.de/performance-optimal-vector-swizzling-in-c/ - assert len(uniq_pats) == 340