Skip to content

Commit

Permalink
[Error] Shorten the length of traceback of TaichiCompilationError (#3965
Browse files Browse the repository at this point in the history
)

* [Error] Shorten the length of traceback of TaichiCompilationError

* format

* fix test
  • Loading branch information
lin-hitonami authored Jan 8, 2022
1 parent d5f4951 commit b53523c
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 11 deletions.
7 changes: 5 additions & 2 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from taichi.lang.ast import (ASTTransformerContext, KernelSimplicityASTChecker,
transform_tree)
from taichi.lang.enums import Layout
from taichi.lang.exception import TaichiSyntaxError
from taichi.lang.exception import TaichiCompilationError, TaichiSyntaxError
from taichi.lang.expr import Expr
from taichi.lang.matrix import MatrixType
from taichi.lang.shell import _shell_pop_print, oinspect
Expand Down Expand Up @@ -718,7 +718,10 @@ def wrapped(*args, **kwargs):

@functools.wraps(_func)
def wrapped(*args, **kwargs):
return primal(*args, **kwargs)
try:
return primal(*args, **kwargs)
except TaichiCompilationError as e:
raise type(e)('\n' + str(e)) from None

wrapped.grad = adjoint

Expand Down
18 changes: 9 additions & 9 deletions tests/python/test_exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ def foo():
# yapf: enable

if version_info < (3, 8):
msg = f"""\
msg = f"""
On line {frameinfo.lineno + 5} of file "{frameinfo.filename}":
aaaa(111,"""
else:
msg = f"""\
msg = f"""
On line {frameinfo.lineno + 5} of file "{frameinfo.filename}":
aaaa(111,
^^^^"""
Expand Down Expand Up @@ -54,15 +54,15 @@ def foo():
lineno = frameinfo.lineno
file = frameinfo.filename
if version_info < (3, 8):
msg = f"""\
msg = f"""
On line {lineno + 13} of file "{file}":
bar()
On line {lineno + 9} of file "{file}":
baz()
On line {lineno + 5} of file "{file}":
t()"""
else:
msg = f"""\
msg = f"""
On line {lineno + 13} of file "{file}":
bar()
^^^^^
Expand All @@ -89,11 +89,11 @@ def foo():
lineno = frameinfo.lineno
file = frameinfo.filename
if version_info < (3, 8):
msg = f"""\
msg = f"""
On line {lineno + 5} of file "{file}":
a(11, 22, 3)"""
else:
msg = f"""\
msg = f"""
On line {lineno + 5} of file "{file}":
a(11, 22, 3)
^"""
Expand All @@ -114,12 +114,12 @@ def foo():
lineno = frameinfo.lineno
file = frameinfo.filename
if version_info < (3, 8):
msg = f"""\
msg = f"""
On line {lineno + 5} of file "{file}":
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa(111)
"""
else:
msg = f"""\
msg = f"""
On line {lineno + 5} of file "{file}":
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbaaaaaa
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down Expand Up @@ -147,7 +147,7 @@ def foo():
foo()
lineno = frameinfo.lineno
file = frameinfo.filename
msg = f"""\
msg = f"""
On line {lineno + 3} of file "{file}":
for i in range(1, 2, 3):
^^^^^^^^^^^^^^^^^^^^^^^^
Expand Down

0 comments on commit b53523c

Please sign in to comment.