From b53523c19802659fddfa0b409cb7b12a3b653dc7 Mon Sep 17 00:00:00 2001 From: Lin Jiang <90667349+lin-hitonami@users.noreply.github.com> Date: Sat, 8 Jan 2022 19:42:18 +0800 Subject: [PATCH] [Error] Shorten the length of traceback of TaichiCompilationError (#3965) * [Error] Shorten the length of traceback of TaichiCompilationError * format * fix test --- python/taichi/lang/kernel_impl.py | 7 +++++-- tests/python/test_exception.py | 18 +++++++++--------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/python/taichi/lang/kernel_impl.py b/python/taichi/lang/kernel_impl.py index 59525976084b3..982c3597bbadb 100644 --- a/python/taichi/lang/kernel_impl.py +++ b/python/taichi/lang/kernel_impl.py @@ -12,7 +12,7 @@ from taichi.lang.ast import (ASTTransformerContext, KernelSimplicityASTChecker, transform_tree) from taichi.lang.enums import Layout -from taichi.lang.exception import TaichiSyntaxError +from taichi.lang.exception import TaichiCompilationError, TaichiSyntaxError from taichi.lang.expr import Expr from taichi.lang.matrix import MatrixType from taichi.lang.shell import _shell_pop_print, oinspect @@ -718,7 +718,10 @@ def wrapped(*args, **kwargs): @functools.wraps(_func) def wrapped(*args, **kwargs): - return primal(*args, **kwargs) + try: + return primal(*args, **kwargs) + except TaichiCompilationError as e: + raise type(e)('\n' + str(e)) from None wrapped.grad = adjoint diff --git a/tests/python/test_exception.py b/tests/python/test_exception.py index f2c2195a4218f..2682dda11d1a5 100644 --- a/tests/python/test_exception.py +++ b/tests/python/test_exception.py @@ -21,11 +21,11 @@ def foo(): # yapf: enable if version_info < (3, 8): - msg = f"""\ + msg = f""" On line {frameinfo.lineno + 5} of file "{frameinfo.filename}": aaaa(111,""" else: - msg = f"""\ + msg = f""" On line {frameinfo.lineno + 5} of file "{frameinfo.filename}": aaaa(111, ^^^^""" @@ -54,7 +54,7 @@ def foo(): lineno = frameinfo.lineno file = frameinfo.filename if version_info < (3, 8): - msg = f"""\ + msg = f""" On line {lineno + 13} of file "{file}": bar() On line {lineno + 9} of file "{file}": @@ -62,7 +62,7 @@ def foo(): On line {lineno + 5} of file "{file}": t()""" else: - msg = f"""\ + msg = f""" On line {lineno + 13} of file "{file}": bar() ^^^^^ @@ -89,11 +89,11 @@ def foo(): lineno = frameinfo.lineno file = frameinfo.filename if version_info < (3, 8): - msg = f"""\ + msg = f""" On line {lineno + 5} of file "{file}": a(11, 22, 3)""" else: - msg = f"""\ + msg = f""" On line {lineno + 5} of file "{file}": a(11, 22, 3) ^""" @@ -114,12 +114,12 @@ def foo(): lineno = frameinfo.lineno file = frameinfo.filename if version_info < (3, 8): - msg = f"""\ + msg = f""" On line {lineno + 5} of file "{file}": aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(111) """ else: - msg = f"""\ + msg = f""" On line {lineno + 5} of file "{file}": aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbaaaaaa ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -147,7 +147,7 @@ def foo(): foo() lineno = frameinfo.lineno file = frameinfo.filename - msg = f"""\ + msg = f""" On line {lineno + 3} of file "{file}": for i in range(1, 2, 3): ^^^^^^^^^^^^^^^^^^^^^^^^