From 2cf09c35a6d5d0c4fc55d2b8b93ea66b18a9f80a Mon Sep 17 00:00:00 2001 From: Lin Jiang Date: Mon, 31 Oct 2022 16:20:30 +0800 Subject: [PATCH] [Error] Return NotImplemented for operations between field and Expr/Matrix/Struct (#6474) Issue: #6472 ### Brief Summary Traceback of #6472 before #6474, #6475 and #6477: ``` Traceback (most recent call last): File "/home/lin/test/ant.py", line 21, in myclass.do_something() File "/home/lin/taichi/python/taichi/lang/kernel_impl.py", line 1022, in __call__ return self._primal(self._kernel_owner, *args, **kwargs) File "/home/lin/taichi/python/taichi/lang/kernel_impl.py", line 888, in __call__ key = self.ensure_compiled(*args) File "/home/lin/taichi/python/taichi/lang/kernel_impl.py", line 855, in ensure_compiled self.materialize(key=key, args=args, arg_features=arg_features) File "/home/lin/taichi/python/taichi/lang/kernel_impl.py", line 566, in materialize taichi_kernel = impl.get_runtime().prog.create_kernel( File "/home/lin/taichi/python/taichi/lang/kernel_impl.py", line 556, in taichi_ast_generator transform_tree(tree, ctx) File "/home/lin/taichi/python/taichi/lang/ast/transform.py", line 6, in transform_tree ASTTransformer()(ctx, tree) File "/home/lin/taichi/python/taichi/lang/ast/ast_transformer_utils.py", line 28, in __call__ raise e.with_traceback(None) taichi.lang.exception.TaichiCompilationError: File "/home/lin/test/ant.py", line 16, in do_something: NEE_contrib = brdf * self.light_weight * self.light_color[None] ^^^^^^^^^^^^^^^^^^^^^^^^ Traceback (most recent call last): File "/home/lin/taichi/python/taichi/lang/ast/ast_transformer_utils.py", line 25, in __call__ return method(ctx, node) File "/home/lin/taichi/python/taichi/lang/ast/ast_transformer.py", line 800, in build_BinOp node.ptr = op(node.left.ptr, node.right.ptr) File "/home/lin/taichi/python/taichi/lang/ast/ast_transformer.py", line 788, in ast.Mult: lambda l, r: l * r, File "/home/lin/taichi/python/taichi/lang/common_ops.py", line 47, in __mul__ return ops.mul(self, other) File "/home/lin/taichi/python/taichi/lang/ops.py", line 55, in wrapped return a._element_wise_binary(imp_foo, b) File "/home/lin/taichi/python/taichi/lang/matrix.py", line 491, in _element_wise_binary other = self._broadcast_copy(other) File "/home/lin/taichi/python/taichi/lang/matrix.py", line 507, in _broadcast_copy other = Vector([other for _ in range(self.n)]) File "/home/lin/taichi/python/taichi/lang/matrix.py", line 1465, in __init__ super().__init__(arr, dt=dt, **kwargs) File "/home/lin/taichi/python/taichi/lang/matrix.py", line 433, in __init__ flattened += row File "/home/lin/taichi/python/taichi/lang/field.py", line 264, in __iter__ raise NotImplementedError( NotImplementedError: Struct for is only available in Taichi scope. ``` Traceback after #6474: ``` Traceback (most recent call last): File "/home/lin/test/ant.py", line 21, in myclass.do_something() File "/home/lin/taichi/python/taichi/lang/kernel_impl.py", line 1022, in __call__ return self._primal(self._kernel_owner, *args, **kwargs) File "/home/lin/taichi/python/taichi/lang/kernel_impl.py", line 888, in __call__ key = self.ensure_compiled(*args) File "/home/lin/taichi/python/taichi/lang/kernel_impl.py", line 855, in ensure_compiled self.materialize(key=key, args=args, arg_features=arg_features) File "/home/lin/taichi/python/taichi/lang/kernel_impl.py", line 566, in materialize taichi_kernel = impl.get_runtime().prog.create_kernel( File "/home/lin/taichi/python/taichi/lang/kernel_impl.py", line 556, in taichi_ast_generator transform_tree(tree, ctx) File "/home/lin/taichi/python/taichi/lang/ast/transform.py", line 6, in transform_tree ASTTransformer()(ctx, tree) File "/home/lin/taichi/python/taichi/lang/ast/ast_transformer_utils.py", line 28, in __call__ raise e.with_traceback(None) taichi.lang.exception.TaichiCompilationError: File "/home/lin/test/ant.py", line 16, in do_something: NEE_contrib = brdf * self.light_weight * self.light_color[None] ^^^^^^^^^^^^^^^^^^^^^^^^ Traceback (most recent call last): File "/home/lin/taichi/python/taichi/lang/ast/ast_transformer_utils.py", line 25, in __call__ return method(ctx, node) File "/home/lin/taichi/python/taichi/lang/ast/ast_transformer.py", line 800, in build_BinOp node.ptr = op(node.left.ptr, node.right.ptr) File "/home/lin/taichi/python/taichi/lang/ast/ast_transformer.py", line 788, in ast.Mult: lambda l, r: l * r, TypeError: unsupported operand type(s) for *: 'Vector' and 'ScalarField' ``` Traceback after #6474 and #6475: ``` Traceback (most recent call last): File "/home/lin/test/ant.py", line 21, in myclass.do_something() File "/home/lin/taichi/python/taichi/lang/kernel_impl.py", line 1025, in __call__ raise type(e)('\n' + str(e)) from None taichi.lang.exception.TaichiCompilationError: File "/home/lin/test/ant.py", line 16, in do_something: NEE_contrib = brdf * self.light_weight * self.light_color[None] ^^^^^^^^^^^^^^^^^^^^^^^^ Traceback (most recent call last): File "/home/lin/taichi/python/taichi/lang/ast/ast_transformer_utils.py", line 25, in __call__ return method(ctx, node) File "/home/lin/taichi/python/taichi/lang/ast/ast_transformer.py", line 800, in build_BinOp node.ptr = op(node.left.ptr, node.right.ptr) File "/home/lin/taichi/python/taichi/lang/ast/ast_transformer.py", line 788, in ast.Mult: lambda l, r: l * r, TypeError: unsupported operand type(s) for *: 'Vector' and 'ScalarField' ``` Traceback after #6474, #6475 and #6477: ``` Traceback (most recent call last): File "/home/lin/test/ant.py", line 21, in myclass.do_something() File "/home/lin/taichi/python/taichi/lang/kernel_impl.py", line 1025, in __call__ raise type(e)('\n' + str(e)) from None taichi.lang.exception.TaichiTypeError: File "/home/lin/test/ant.py", line 16, in do_something: NEE_contrib = brdf * self.light_weight * self.light_color[None] ^^^^^^^^^^^^^^^^^^^^^^^^ unsupported operand type(s) for *: 'Vector' and 'ScalarField' ``` Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- python/taichi/lang/ast/ast_transformer.py | 2 +- python/taichi/lang/ops.py | 8 ++++++++ tests/python/test_field.py | 13 +++++++++++++ 3 files changed, 22 insertions(+), 1 deletion(-) diff --git a/python/taichi/lang/ast/ast_transformer.py b/python/taichi/lang/ast/ast_transformer.py index 681f28225c4d5..7b9418a02ea57 100644 --- a/python/taichi/lang/ast/ast_transformer.py +++ b/python/taichi/lang/ast/ast_transformer.py @@ -806,7 +806,7 @@ def build_BinOp(ctx, node): try: node.ptr = op(node.left.ptr, node.right.ptr) except TypeError as e: - raise TaichiTypeError(str(e)) + raise TaichiTypeError(str(e)) from None return node.ptr @staticmethod diff --git a/python/taichi/lang/ops.py b/python/taichi/lang/ops.py index ecb4b1310003d..1265a815766c8 100644 --- a/python/taichi/lang/ops.py +++ b/python/taichi/lang/ops.py @@ -6,6 +6,7 @@ from taichi._lib import core as _ti_core from taichi.lang import expr, impl from taichi.lang.exception import TaichiSyntaxError +from taichi.lang.field import Field from taichi.lang.util import cook_dtype, is_taichi_class, taichi_scope unary_ops = [] @@ -51,6 +52,8 @@ def rev_foo(x, y): @functools.wraps(foo) def wrapped(a, b): + if isinstance(a, Field) or isinstance(b, Field): + return NotImplemented if is_taichi_class(a): return a._element_wise_binary(imp_foo, b) if is_taichi_class(b): @@ -79,6 +82,9 @@ def cab_foo(c, a, b): @functools.wraps(foo) def wrapped(a, b, c): + if isinstance(a, Field) or isinstance(b, Field) or isinstance( + c, Field): + return NotImplemented if is_taichi_class(a): return a._element_wise_ternary(abc_foo, b, c) if is_taichi_class(b): @@ -101,6 +107,8 @@ def imp_foo(x, y): @functools.wraps(foo) def wrapped(a, b): + if isinstance(a, Field) or isinstance(b, Field): + return NotImplemented if is_taichi_class(a): return a._element_wise_writeback_binary(imp_foo, b) if is_taichi_class(b): diff --git a/tests/python/test_field.py b/tests/python/test_field.py index d8e934329e686..d502c7621bb51 100644 --- a/tests/python/test_field.py +++ b/tests/python/test_field.py @@ -311,3 +311,16 @@ def test_python_for_in(): match="Struct for is only available in Taichi scope"): for i in x: pass + + +@test_utils.test() +def test_matrix_mult_field(): + x = ti.field(int, shape=()) + with pytest.raises(ti.TaichiTypeError, match="unsupported operand type"): + + @ti.kernel + def foo(): + a = ti.Vector([1, 1, 1]) + b = a * x + + foo()