diff --git a/python/taichi/lang/impl.py b/python/taichi/lang/impl.py index eb8eaf61618ec..30d58f65c9d89 100644 --- a/python/taichi/lang/impl.py +++ b/python/taichi/lang/impl.py @@ -217,15 +217,16 @@ def get_runtime(): @taichi_scope def make_constant_expr(val): + import numpy as np _taichi_skip_traceback = 1 - if isinstance(val, int): + if isinstance(val, (int, np.integer)): if pytaichi.default_ip == i32: return Expr(taichi_lang_core.make_const_expr_i32(val)) elif pytaichi.default_ip == i64: return Expr(taichi_lang_core.make_const_expr_i64(val)) else: assert False - elif isinstance(val, float): + elif isinstance(val, (float, np.floating, np.ndarray)): if pytaichi.default_fp == f32: return Expr(taichi_lang_core.make_const_expr_f32(val)) elif pytaichi.default_fp == f64: diff --git a/tests/python/test_numpy.py b/tests/python/test_numpy.py index f8157b8144db0..3d45afac8db49 100644 --- a/tests/python/test_numpy.py +++ b/tests/python/test_numpy.py @@ -134,8 +134,8 @@ def test_numpy(arr: ti.ext_arr()): assert a[i, j, k] == i * j * (k + 1) + i + j + k * 2 -@ti.must_throw(AssertionError) -def test_numpy_3d(): +@ti.must_throw(IndexError) +def test_numpy_3d_error(): val = ti.var(ti.i32) n = 4