From 57a831d0d1e8ef5f5a29ebb14e5684179efda90f Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Mon, 24 Jan 2022 17:33:50 +0800 Subject: [PATCH 1/7] [refactor] Remove exposure of internal functions in taichi.lang.ops --- python/taichi/lang/ops.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index 9ecd6eef828c7..3542b83846fdc 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -858,3 +858,11 @@ def ti_any(a): def ti_all(a): return a.all() + + +__all__ = [ + "acos", "asin", "atan2", "atomic_and", "atomic_or", "atomic_xor", + "atomic_max", "atomic_sub", "atomic_min", "atomic_add", "bit_cast", + "bit_shr", "cast", "ceil", "cos", "exp", "floor", "log", "random", + "raw_mod", "raw_div", "round", "rsqrt", "sin", "sqrt", "tan", "tanh" +] From 5a6ed4888a1d22c685d898a720f7ca4569e815ca Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Tue, 25 Jan 2022 11:38:22 +0800 Subject: [PATCH 2/7] fix --- python/taichi/lang/ast/ast_transformer.py | 4 ++-- python/taichi/lang/matrix.py | 4 ++-- python/taichi/lang/ops.py | 18 +++++++++--------- tests/python/test_abs.py | 2 +- tests/python/test_element_wise.py | 6 +++--- tests/python/test_f16.py | 4 ++-- tests/python/test_loops.py | 4 ++-- tests/python/test_scalar_op.py | 8 ++++---- 8 files changed, 25 insertions(+), 25 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index f15cf569587ed..9efb4ad690418 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -322,8 +322,8 @@ def build_call_if_is_builtin(node, args, keywords): func = node.func.ptr replace_func = { id(print): impl.ti_print, - id(min): ti_ops.ti_min, - id(max): ti_ops.ti_max, + id(min): ti_ops.min, + id(max): ti_ops.max, id(int): impl.ti_int, id(float): impl.ti_float, id(any): ti_ops.ti_any, diff --git a/python/taichi/lang/matrix.py b/python/taichi/lang/matrix.py index 306d12da21b53..da697fc1d7cda 100644 --- a/python/taichi/lang/matrix.py +++ b/python/taichi/lang/matrix.py @@ -614,11 +614,11 @@ def norm_sqr(self): def max(self): """Return the maximum element value.""" - return ops_mod.ti_max(*self.entries) + return ops_mod.max(*self.entries) def min(self): """Return the minimum element value.""" - return ops_mod.ti_min(*self.entries) + return ops_mod.min(*self.entries) def any(self): """Test whether any element not equal zero. diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index 3542b83846fdc..97ab9f5f37c81 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -504,7 +504,7 @@ def truediv(a, b): @binary -def max(a, b): # pylint: disable=W0622 +def max_impl(a, b): # pylint: disable=W0622 """The maxnimum function. Args: @@ -518,7 +518,7 @@ def max(a, b): # pylint: disable=W0622 @binary -def min(a, b): # pylint: disable=W0622 +def min_impl(a, b): # pylint: disable=W0622 """The minimum function. Args: @@ -832,24 +832,24 @@ def assign(a, b): return a -def ti_max(*args): +def max(*args): num_args = len(args) assert num_args >= 1 if num_args == 1: return args[0] if num_args == 2: - return max(args[0], args[1]) - return max(args[0], ti_max(*args[1:])) + return max_impl(args[0], args[1]) + return max_impl(args[0], max(*args[1:])) -def ti_min(*args): +def min(*args): num_args = len(args) assert num_args >= 1 if num_args == 1: return args[0] if num_args == 2: - return min(args[0], args[1]) - return min(args[0], ti_min(*args[1:])) + return min_impl(args[0], args[1]) + return min_impl(args[0], min(*args[1:])) def ti_any(a): @@ -864,5 +864,5 @@ def ti_all(a): "acos", "asin", "atan2", "atomic_and", "atomic_or", "atomic_xor", "atomic_max", "atomic_sub", "atomic_min", "atomic_add", "bit_cast", "bit_shr", "cast", "ceil", "cos", "exp", "floor", "log", "random", - "raw_mod", "raw_div", "round", "rsqrt", "sin", "sqrt", "tan", "tanh" + "raw_mod", "raw_div", "round", "rsqrt", "sin", "sqrt", "tan", "tanh", "max", "min" ] diff --git a/tests/python/test_abs.py b/tests/python/test_abs.py index 7ca262e58259e..7e1f6db780979 100644 --- a/tests/python/test_abs.py +++ b/tests/python/test_abs.py @@ -15,7 +15,7 @@ def test_abs(): @ti.kernel def func(): for i in range(N): - x[i] = ti.abs(y[i]) + x[i] = abs(y[i]) for i in range(N): y[i] = i - 10 diff --git a/tests/python/test_element_wise.py b/tests/python/test_element_wise.py index 2bfc3bc00c9f3..891b58f7b835f 100644 --- a/tests/python/test_element_wise.py +++ b/tests/python/test_element_wise.py @@ -256,10 +256,10 @@ def test_unary(): def func(): xi[0] = -yi[None] xi[1] = ~yi[None] - xi[2] = ti.logical_not(yi[None]) - xi[3] = ti.abs(yi[None]) + xi[2] = not yi[None] + xi[3] = abs(yi[None]) xf[0] = -yf[None] - xf[1] = ti.abs(yf[None]) + xf[1] = abs(yf[None]) xf[2] = ti.sqrt(yf[None]) xf[3] = ti.sin(yf[None]) xf[4] = ti.cos(yf[None]) diff --git a/tests/python/test_f16.py b/tests/python/test_f16.py index 37839a1317663..0eff41c502ae8 100644 --- a/tests/python/test_f16.py +++ b/tests/python/test_f16.py @@ -141,7 +141,7 @@ def test_unary_op(): @ti.kernel def foo(): - x[None] = ti.neg(y[None]) + x[None] = -y[None] x[None] = ti.floor(x[None]) y[None] = ti.ceil(y[None]) @@ -159,7 +159,7 @@ def test_extra_unary_promote(): @ti.kernel def foo(): - x[None] = ti.abs(y[None]) + x[None] = abs(y[None]) y[None] = -0.3 foo() diff --git a/tests/python/test_loops.py b/tests/python/test_loops.py index b58f6e9a93ceb..d63b56037df25 100644 --- a/tests/python/test_loops.py +++ b/tests/python/test_loops.py @@ -18,7 +18,7 @@ def test_loops(): @ti.kernel def func(): for i in range(ti.static(N // 2 + 3), N): - x[i] = ti.abs(y[i]) + x[i] = abs(y[i]) func() @@ -50,7 +50,7 @@ def test_numpy_loops(): @ti.kernel def func(): for i in range(begin, end): - x[i] = ti.abs(y[i]) + x[i] = abs(y[i]) func() diff --git a/tests/python/test_scalar_op.py b/tests/python/test_scalar_op.py index 32b87852b7962..b200d410a994d 100644 --- a/tests/python/test_scalar_op.py +++ b/tests/python/test_scalar_op.py @@ -31,8 +31,8 @@ unary_func_table = [ (ops.neg, ) * 2, (ops.invert, ) * 2, - (ti.logical_not, np.logical_not), - (ti.abs, np.abs), + (ti.lang.ops.logical_not, np.logical_not), + (ti.lang.ops.abs, np.abs), (ti.exp, np.exp), (ti.log, np.log), (ti.sin, np.sin), @@ -64,10 +64,10 @@ def test_python_scope_vector_binary(ti_func, np_func): def test_python_scope_vector_unary(ti_func, np_func): ti.init() x = ti.Vector([2, 3] if ti_func in - [ops.invert, ti.logical_not] else [0.2, 0.3]) + [ops.invert, ti.lang.ops.logical_not] else [0.2, 0.3]) result = ti_func(x).to_numpy() - if ti_func in [ti.logical_not]: + if ti_func in [ti.lang.ops.logical_not]: result = result.astype(bool) expected = np_func(x.to_numpy()) assert allclose(result, expected) From 061087d5a1ce93cc6eee22ac569ff3c0cea10423 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Tue, 25 Jan 2022 11:50:22 +0800 Subject: [PATCH 3/7] add warning --- python/taichi/lang/ast/ast_transformer.py | 10 ++++++++-- python/taichi/lang/ops.py | 3 ++- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 9efb4ad690418..5fafa7fcecdd0 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -318,7 +318,7 @@ def build_JoinedStr(ctx, node): return node.ptr @staticmethod - def build_call_if_is_builtin(node, args, keywords): + def build_call_if_is_builtin(ctx, node, args, keywords): func = node.func.ptr replace_func = { id(print): impl.ti_print, @@ -333,6 +333,12 @@ def build_call_if_is_builtin(node, args, keywords): } if id(func) in replace_func: node.ptr = replace_func[id(func)](*args, **keywords) + if func is min or func is max: + name = "min" if func is min else "max" + warnings.warn_explicit( + f'Calling builtin function "{name}" in Taichi scope is deprecated. ' + f'Please use "ti.{name}" instead.', UserWarning, ctx.file, + node.lineno + ctx.lineno_offset) return True return False @@ -383,7 +389,7 @@ def build_Call(ctx, node): node.ptr = impl.ti_format(*args, **keywords) return node.ptr - if ASTTransformer.build_call_if_is_builtin(node, args, keywords): + if ASTTransformer.build_call_if_is_builtin(ctx, node, args, keywords): return node.ptr node.ptr = func(*args, **keywords) diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index 97ab9f5f37c81..bcbe622cb5caf 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -864,5 +864,6 @@ def ti_all(a): "acos", "asin", "atan2", "atomic_and", "atomic_or", "atomic_xor", "atomic_max", "atomic_sub", "atomic_min", "atomic_add", "bit_cast", "bit_shr", "cast", "ceil", "cos", "exp", "floor", "log", "random", - "raw_mod", "raw_div", "round", "rsqrt", "sin", "sqrt", "tan", "tanh", "max", "min" + "raw_mod", "raw_div", "round", "rsqrt", "sin", "sqrt", "tan", "tanh", + "max", "min" ] From 291b3c5cf2c49d17e88270b6103cae20fae1108d Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Tue, 25 Jan 2022 11:56:08 +0800 Subject: [PATCH 4/7] change warning type --- python/taichi/lang/ast/ast_transformer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 5fafa7fcecdd0..f4c17aa5b3390 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -337,7 +337,7 @@ def build_call_if_is_builtin(ctx, node, args, keywords): name = "min" if func is min else "max" warnings.warn_explicit( f'Calling builtin function "{name}" in Taichi scope is deprecated. ' - f'Please use "ti.{name}" instead.', UserWarning, ctx.file, + f'Please use "ti.{name}" instead.', DeprecationWarning, ctx.file, node.lineno + ctx.lineno_offset) return True return False From e5aceb423a086399886bba3c4dae1d9b525eda2f Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Tue, 25 Jan 2022 11:58:13 +0800 Subject: [PATCH 5/7] fix pylint --- python/taichi/lang/ast/ast_transformer.py | 4 ++-- python/taichi/lang/ops.py | 8 ++++---- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index f4c17aa5b3390..27a7ca6274651 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -337,8 +337,8 @@ def build_call_if_is_builtin(ctx, node, args, keywords): name = "min" if func is min else "max" warnings.warn_explicit( f'Calling builtin function "{name}" in Taichi scope is deprecated. ' - f'Please use "ti.{name}" instead.', DeprecationWarning, ctx.file, - node.lineno + ctx.lineno_offset) + f'Please use "ti.{name}" instead.', DeprecationWarning, + ctx.file, node.lineno + ctx.lineno_offset) return True return False diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index bcbe622cb5caf..fe43c831ef041 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -504,7 +504,7 @@ def truediv(a, b): @binary -def max_impl(a, b): # pylint: disable=W0622 +def max_impl(a, b): """The maxnimum function. Args: @@ -518,7 +518,7 @@ def max_impl(a, b): # pylint: disable=W0622 @binary -def min_impl(a, b): # pylint: disable=W0622 +def min_impl(a, b): """The minimum function. Args: @@ -832,7 +832,7 @@ def assign(a, b): return a -def max(*args): +def max(*args): # pylint: disable=W0622 num_args = len(args) assert num_args >= 1 if num_args == 1: @@ -842,7 +842,7 @@ def max(*args): return max_impl(args[0], max(*args[1:])) -def min(*args): +def min(*args): # pylint: disable=W0622 num_args = len(args) assert num_args >= 1 if num_args == 1: From 65ce4861fd0ec2e04b3bdc96ef0aba0984e95f0d Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Tue, 25 Jan 2022 13:51:35 +0800 Subject: [PATCH 6/7] fix select --- python/taichi/lang/ast/ast_transformer.py | 5 +++++ python/taichi/lang/ops.py | 2 +- tests/python/test_matrix.py | 3 ++- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 27a7ca6274651..24ae05fa329f6 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -1080,6 +1080,11 @@ def build_IfExp(ctx, node): node.body.ptr) or is_taichi_class(node.orelse.ptr): node.ptr = ti_ops.select(node.test.ptr, node.body.ptr, node.orelse.ptr) + warnings.warn_explicit( + f'Using conditional expression for element-wise select operation on ' + f'Taichi vectors/matrices is deprecated. ' + f'Please use "ti.select" instead.', DeprecationWarning, + ctx.file, node.lineno + ctx.lineno_offset) return node.ptr is_static_if = (ASTTransformer.get_decorator(ctx, diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index fe43c831ef041..0b54dad6b8986 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -865,5 +865,5 @@ def ti_all(a): "atomic_max", "atomic_sub", "atomic_min", "atomic_add", "bit_cast", "bit_shr", "cast", "ceil", "cos", "exp", "floor", "log", "random", "raw_mod", "raw_div", "round", "rsqrt", "sin", "sqrt", "tan", "tanh", - "max", "min" + "max", "min", "select" ] diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index 6df0bcb4f4cac..8160bdfa7c213 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -1,3 +1,4 @@ +import math import operator import numpy as np @@ -73,7 +74,7 @@ def test_python_scope_matrix_field(ops): @ti.test(arch=ti.get_host_arch_list()) def test_constant_matrices(): - assert ti.cos(ti.math.pi / 3) == approx(0.5) + assert ti.cos(math.pi / 3) == approx(0.5) assert np.allclose((-ti.Vector([2, 3])).to_numpy(), np.array([-2, -3])) assert ti.cos(ti.Vector([2, 3])).to_numpy() == approx(np.cos(np.array([2, From 76bb2a5036127081561ffda8d9535ab65bb670e6 Mon Sep 17 00:00:00 2001 From: lin-hitonami Date: Tue, 25 Jan 2022 14:57:49 +0800 Subject: [PATCH 7/7] fix pylint --- python/taichi/lang/ast/ast_transformer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 24ae05fa329f6..9c68ee8f53145 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -1081,9 +1081,9 @@ def build_IfExp(ctx, node): node.ptr = ti_ops.select(node.test.ptr, node.body.ptr, node.orelse.ptr) warnings.warn_explicit( - f'Using conditional expression for element-wise select operation on ' - f'Taichi vectors/matrices is deprecated. ' - f'Please use "ti.select" instead.', DeprecationWarning, + 'Using conditional expression for element-wise select operation on ' + 'Taichi vectors/matrices is deprecated. ' + 'Please use "ti.select" instead.', DeprecationWarning, ctx.file, node.lineno + ctx.lineno_offset) return node.ptr