From ef2cc356d76a3cbc4273d37af9b4ad9f12ef9deb Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Fri, 29 Jul 2022 18:08:10 +0800 Subject: [PATCH] [Lang] [type] Refine SNode with quant 9/n: Rename some parameters in quant APIs (#5566) * [Lang] [type] Rename some parameters in quant APIs * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- python/taichi/_snode/fields_builder.py | 4 ++-- python/taichi/lang/snode.py | 6 +++--- python/taichi/types/quantized_types.py | 16 ++++++++-------- tests/python/test_bitpacked_fields.py | 2 +- tests/python/test_matrix_different_type.py | 4 ++-- tests/python/test_quant_array.py | 12 ++++++------ tests/python/test_quant_array_vectorization.py | 16 ++++++++-------- tests/python/test_quant_atomics.py | 6 +++--- tests/python/test_quant_fixed.py | 8 ++++---- tests/python/test_quant_time_integration.py | 2 +- 10 files changed, 38 insertions(+), 38 deletions(-) diff --git a/python/taichi/_snode/fields_builder.py b/python/taichi/_snode/fields_builder.py index d07ad68f6c624..282ef825ef813 100644 --- a/python/taichi/_snode/fields_builder.py +++ b/python/taichi/_snode/fields_builder.py @@ -99,11 +99,11 @@ def bitmasked(self, indices: Union[Sequence[_Axis], _Axis], return self.root.bitmasked(indices, dimensions) def quant_array(self, indices: Union[Sequence[_Axis], _Axis], - dimensions: Union[Sequence[int], int], num_bits: int): + dimensions: Union[Sequence[int], int], max_num_bits: int): """Same as :func:`taichi.lang.snode.SNode.quant_array`""" self._check_not_finalized() self.empty = False - return self.root.quant_array(indices, dimensions, num_bits) + return self.root.quant_array(indices, dimensions, max_num_bits) def place(self, *args: Any, diff --git a/python/taichi/lang/snode.py b/python/taichi/lang/snode.py index 2f8344ba9f8df..9c7bf1857ecd4 100644 --- a/python/taichi/lang/snode.py +++ b/python/taichi/lang/snode.py @@ -96,13 +96,13 @@ def bitmasked(self, axes, dimensions): self.ptr.bitmasked(axes, dimensions, impl.current_cfg().packed)) - def quant_array(self, axes, dimensions, num_bits): + def quant_array(self, axes, dimensions, max_num_bits): """Adds a quant_array SNode as a child component of `self`. Args: axes (List[Axis]): Axes to activate. dimensions (Union[List[int], int]): Shape of each axis. - num_bits (int): Number of bits to use. + max_num_bits (int): Maximum number of bits it can hold. Returns: The added :class:`~taichi.lang.SNode` instance. @@ -110,7 +110,7 @@ def quant_array(self, axes, dimensions, num_bits): if isinstance(dimensions, int): dimensions = [dimensions] * len(axes) return SNode( - self.ptr.quant_array(axes, dimensions, num_bits, + self.ptr.quant_array(axes, dimensions, max_num_bits, impl.current_cfg().packed)) def place(self, *args, offset=None): diff --git a/python/taichi/types/quantized_types.py b/python/taichi/types/quantized_types.py index 69ff8a344fe0e..06a4f34f5f9e7 100644 --- a/python/taichi/types/quantized_types.py +++ b/python/taichi/types/quantized_types.py @@ -27,13 +27,13 @@ def int(bits, signed=True, compute=None): # pylint: disable=W0622 return _type_factory.get_quant_int_type(bits, signed, compute) -def fixed(frac, signed=True, range=1.0, compute=None, scale=None): # pylint: disable=W0622 +def fixed(bits, signed=True, max_value=1.0, compute=None, scale=None): """Generates a quantized type for fixed-point real numbers. Args: - frac (int): Number of bits. + bits (int): Number of bits. signed (bool): Signed or unsigned. - range (float): Range of the number. + max_value (float): Maximum value of the number. compute (DataType): Type for computation. scale (float): Scaling factor. The argument is prioritized over range. @@ -44,14 +44,14 @@ def fixed(frac, signed=True, range=1.0, compute=None, scale=None): # pylint: di compute = impl.get_runtime().default_fp if isinstance(compute, _ti_python_core.DataType): compute = compute.get_ptr() - # TODO: handle cases with frac > 32 - frac_type = int(bits=frac, signed=signed, compute=i32) + # TODO: handle cases with bits > 32 + underlying_type = int(bits=bits, signed=signed, compute=i32) if scale is None: if signed: - scale = range / 2**(frac - 1) + scale = max_value / 2**(bits - 1) else: - scale = range / 2**frac - return _type_factory.get_quant_fixed_type(frac_type, compute, scale) + scale = max_value / 2**bits + return _type_factory.get_quant_fixed_type(underlying_type, compute, scale) def float(exp, frac, signed=True, compute=None): # pylint: disable=W0622 diff --git a/tests/python/test_bitpacked_fields.py b/tests/python/test_bitpacked_fields.py index 2a1c8294628f1..2e3ca5f5c0251 100644 --- a/tests/python/test_bitpacked_fields.py +++ b/tests/python/test_bitpacked_fields.py @@ -157,7 +157,7 @@ def test_bitpacked_fields_struct_for(): block_size = 16 N = 64 cell = ti.root.pointer(ti.i, N // block_size) - fixed32 = ti.types.quant.fixed(frac=32, range=1024) + fixed32 = ti.types.quant.fixed(bits=32, max_value=1024) x = ti.field(dtype=fixed32) bitpack = ti.BitpackedFields(max_num_bits=32) diff --git a/tests/python/test_matrix_different_type.py b/tests/python/test_matrix_different_type.py index d9478f76e9da9..03d9ac1db2aca 100644 --- a/tests/python/test_matrix_different_type.py +++ b/tests/python/test_matrix_different_type.py @@ -71,9 +71,9 @@ def verify(): @test_utils.test(require=ti.extension.quant_basic) def test_quant_type(): qit1 = ti.types.quant.int(bits=10, signed=True) - qfxt1 = ti.types.quant.fixed(frac=10, signed=True, scale=0.1) + qfxt1 = ti.types.quant.fixed(bits=10, signed=True, scale=0.1) qit2 = ti.types.quant.int(bits=22, signed=False) - qfxt2 = ti.types.quant.fixed(frac=22, signed=False, scale=0.1) + qfxt2 = ti.types.quant.fixed(bits=22, signed=False, scale=0.1) type_list = [[qit1, qfxt2], [qfxt1, qit2]] a = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list) b = ti.Matrix.field(len(type_list), len(type_list[0]), dtype=type_list) diff --git a/tests/python/test_quant_array.py b/tests/python/test_quant_array.py index 75486153e1d51..114a9abed3c6d 100644 --- a/tests/python/test_quant_array.py +++ b/tests/python/test_quant_array.py @@ -10,7 +10,7 @@ def test_1D_quant_array(): N = 32 - ti.root.quant_array(ti.i, N, num_bits=32).place(x) + ti.root.quant_array(ti.i, N, max_num_bits=32).place(x) @ti.kernel def set_val(): @@ -31,7 +31,7 @@ def test_1D_quant_array_negative(): N = 4 qi7 = ti.types.quant.int(7) x = ti.field(dtype=qi7) - ti.root.quant_array(ti.i, N, num_bits=32).place(x) + ti.root.quant_array(ti.i, N, max_num_bits=32).place(x) @ti.kernel def assign(): @@ -45,13 +45,13 @@ def assign(): @test_utils.test(require=ti.extension.quant, debug=True) def test_1D_quant_array_fixed(): - qfxt = ti.types.quant.fixed(frac=8, range=2) + qfxt = ti.types.quant.fixed(bits=8, max_value=2) x = ti.field(dtype=qfxt) N = 4 - ti.root.quant_array(ti.i, N, num_bits=32).place(x) + ti.root.quant_array(ti.i, N, max_num_bits=32).place(x) @ti.kernel def set_val(): @@ -75,7 +75,7 @@ def test_2D_quant_array(): M, N = 4, 8 - ti.root.quant_array(ti.ij, (M, N), num_bits=32).place(x) + ti.root.quant_array(ti.ij, (M, N), max_num_bits=32).place(x) @ti.kernel def set_val(): @@ -102,7 +102,7 @@ def test_quant_array_struct_for(): x = ti.field(dtype=qi7) cell.dense(ti.i, block_size // 4).quant_array(ti.i, 4, - num_bits=32).place(x) + max_num_bits=32).place(x) @ti.kernel def activate(): diff --git a/tests/python/test_quant_array_vectorization.py b/tests/python/test_quant_array_vectorization.py index b1f6578dafa25..6044da5c83128 100644 --- a/tests/python/test_quant_array_vectorization.py +++ b/tests/python/test_quant_array_vectorization.py @@ -18,9 +18,9 @@ def test_vectorized_struct_for(): block = ti.root.pointer(ti.ij, (n_blocks, n_blocks)) block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks))).quant_array( - ti.j, bits, num_bits=bits).place(x) + ti.j, bits, max_num_bits=bits).place(x) block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks))).quant_array( - ti.j, bits, num_bits=bits).place(y) + ti.j, bits, max_num_bits=bits).place(y) @ti.kernel def init(): @@ -61,11 +61,11 @@ def test_offset_load(): block = ti.root.pointer(ti.ij, (n_blocks, n_blocks)) block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks))).quant_array( - ti.j, bits, num_bits=bits).place(x) + ti.j, bits, max_num_bits=bits).place(x) block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks))).quant_array( - ti.j, bits, num_bits=bits).place(y) + ti.j, bits, max_num_bits=bits).place(y) block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks))).quant_array( - ti.j, bits, num_bits=bits).place(z) + ti.j, bits, max_num_bits=bits).place(z) @ti.kernel def init(): @@ -121,11 +121,11 @@ def test_evolve(): block = ti.root.pointer(ti.ij, (n_blocks, n_blocks)) block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks))).quant_array( - ti.j, bits, num_bits=bits).place(x) + ti.j, bits, max_num_bits=bits).place(x) block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks))).quant_array( - ti.j, bits, num_bits=bits).place(y) + ti.j, bits, max_num_bits=bits).place(y) block.dense(ti.ij, (N // n_blocks, N // (bits * n_blocks))).quant_array( - ti.j, bits, num_bits=bits).place(z) + ti.j, bits, max_num_bits=bits).place(z) @ti.kernel def init(): diff --git a/tests/python/test_quant_atomics.py b/tests/python/test_quant_atomics.py index 8336fb5645d32..fb98939f0b9fc 100644 --- a/tests/python/test_quant_atomics.py +++ b/tests/python/test_quant_atomics.py @@ -50,7 +50,7 @@ def test_quant_int_atomics_b64(): x = ti.field(dtype=qi13) - ti.root.quant_array(ti.i, 4, num_bits=64).place(x) + ti.root.quant_array(ti.i, 4, max_num_bits=64).place(x) x[0] = 100 x[1] = 200 @@ -70,8 +70,8 @@ def foo(): @test_utils.test(require=ti.extension.quant_basic, debug=True) def test_quant_fixed_atomics(): - qfxt13 = ti.types.quant.fixed(frac=13, signed=True, scale=0.1) - qfxt19 = ti.types.quant.fixed(frac=19, signed=False, scale=0.1) + qfxt13 = ti.types.quant.fixed(bits=13, signed=True, scale=0.1) + qfxt19 = ti.types.quant.fixed(bits=19, signed=False, scale=0.1) x = ti.field(dtype=qfxt13) y = ti.field(dtype=qfxt19) diff --git a/tests/python/test_quant_fixed.py b/tests/python/test_quant_fixed.py index 82275482fd782..f6e6a0b0c00af 100644 --- a/tests/python/test_quant_fixed.py +++ b/tests/python/test_quant_fixed.py @@ -8,7 +8,7 @@ @test_utils.test(require=ti.extension.quant_basic) def test_quant_fixed(): - qfxt = ti.types.quant.fixed(frac=32, range=2) + qfxt = ti.types.quant.fixed(bits=32, max_value=2) x = ti.field(dtype=qfxt) bitpack = ti.BitpackedFields(max_num_bits=32) @@ -31,7 +31,7 @@ def foo(): @test_utils.test(require=ti.extension.quant_basic) def test_quant_fixed_matrix_rotation(): - qfxt = ti.types.quant.fixed(frac=16, range=1.2) + qfxt = ti.types.quant.fixed(bits=16, max_value=1.2) x = ti.Matrix.field(2, 2, dtype=qfxt) @@ -61,7 +61,7 @@ def rotate_18_degrees(): @test_utils.test(require=ti.extension.quant_basic) def test_quant_fixed_implicit_cast(): - qfxt = ti.types.quant.fixed(frac=13, scale=0.1) + qfxt = ti.types.quant.fixed(bits=13, scale=0.1) x = ti.field(dtype=qfxt) bitpack = ti.BitpackedFields(max_num_bits=32) @@ -78,7 +78,7 @@ def foo(): @test_utils.test(require=ti.extension.quant_basic) def test_quant_fixed_cache_read_only(): - qfxt = ti.types.quant.fixed(frac=15, scale=0.1) + qfxt = ti.types.quant.fixed(bits=15, scale=0.1) x = ti.field(dtype=qfxt) bitpack = ti.BitpackedFields(max_num_bits=32) diff --git a/tests/python/test_quant_time_integration.py b/tests/python/test_quant_time_integration.py index 0c8d11a70b6e7..166d03a9efaf1 100644 --- a/tests/python/test_quant_time_integration.py +++ b/tests/python/test_quant_time_integration.py @@ -28,7 +28,7 @@ def test_quant_time_integration(use_quant, use_exponent, use_shared_exp): bitpack.place(x.get_scalar_field(1)) ti.root.place(bitpack) else: - qfxt = ti.types.quant.fixed(frac=16, range=2) + qfxt = ti.types.quant.fixed(bits=16, max_value=2) x = ti.Vector.field(2, dtype=qfxt) bitpack = ti.BitpackedFields(max_num_bits=32) bitpack.place(x)