Skip to content

Commit

Permalink
[refactor] Move literal construction to expr module (#4448)
Browse files Browse the repository at this point in the history
* [refactor] Move literal construction to expr module

* Add errer message check

* Auto Format

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
strongoier and taichi-gardener authored Mar 4, 2022
1 parent 58d3417 commit c8501f0
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 60 deletions.
40 changes: 38 additions & 2 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from taichi.lang import impl
from taichi.lang.common_ops import TaichiOperations
from taichi.lang.exception import TaichiTypeError
from taichi.lang.util import is_taichi_class
from taichi.lang.util import is_taichi_class, to_numpy_type, to_taichi_type


# Scalar, basic data type
Expand All @@ -30,7 +30,7 @@ def __init__(self, *args, tb=None, dtype=None):
"Only 0-dimensional numpy array can be used to initialize a scalar expression"
)
arg = arg.dtype.type(arg)
self.ptr = impl.make_constant_expr(arg, dtype).ptr
self.ptr = make_constant_expr(arg, dtype).ptr
else:
assert False
if self.tb:
Expand All @@ -47,6 +47,42 @@ def __repr__(self):
return '<ti.Expr>'


def _check_in_range(npty, val):
iif = np.iinfo(npty)
if not iif.min <= val <= iif.max:
# This isn't the case we want to deal with: |val| does't fall into the valid range of either
# the signed or the unsigned type.
raise TaichiTypeError(
f'Constant {val} has exceeded the range of {to_taichi_type(npty)}: [{iif.min}, {iif.max}]'
)


def _clamp_unsigned_to_range(npty, val):
# npty: np.int32 or np.int64
iif = np.iinfo(npty)
if iif.min <= val <= iif.max:
return val
cap = (1 << iif.bits)
assert 0 <= val < cap
new_val = val - cap
return new_val


def make_constant_expr(val, dtype):
if isinstance(val, (int, np.integer)):
constant_dtype = impl.get_runtime(
).default_ip if dtype is None else dtype
_check_in_range(to_numpy_type(constant_dtype), val)
return Expr(
_ti_core.make_const_expr_int(
constant_dtype, _clamp_unsigned_to_range(np.int64, val)))
if isinstance(val, (float, np.floating)):
constant_dtype = impl.get_runtime(
).default_fp if dtype is None else dtype
return Expr(_ti_core.make_const_expr_fp(constant_dtype, val))
raise TaichiTypeError(f'Invalid constant scalar data type: {type(val)}')


def make_var_list(size):
exprs = []
for _ in range(size):
Expand Down
53 changes: 2 additions & 51 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@

import numpy as np
from taichi._lib import core as _ti_core
from taichi._logging import error
from taichi._snode.fields_builder import FieldsBuilder
from taichi.lang._ndarray import ScalarNdarray
from taichi.lang._ndrange import GroupedNDRange, _Ndrange
from taichi.lang.any_array import AnyArray, AnyArrayAccess
from taichi.lang.exception import TaichiRuntimeError, TaichiTypeError
from taichi.lang.exception import TaichiRuntimeError
from taichi.lang.expr import Expr, make_expr_group
from taichi.lang.field import Field, ScalarField
from taichi.lang.kernel_arguments import SparseMatrixProxy
Expand All @@ -23,8 +22,7 @@
from taichi.lang.struct import Struct, StructField, _IntermediateStruct
from taichi.lang.tape import TapeImpl
from taichi.lang.util import (cook_dtype, get_traceback, is_taichi_class,
python_scope, taichi_scope, to_numpy_type,
to_taichi_type, warning)
python_scope, taichi_scope, warning)
from taichi.types.primitive_types import f16, f32, f64, i32, i64


Expand Down Expand Up @@ -117,12 +115,6 @@ def begin_frontend_if(ast_builder, cond):
ast_builder.begin_frontend_if(Expr(cond).ptr)


def wrap_scalar(x):
if type(x) in [int, float]:
return Expr(x)
return x


@taichi_scope
def subscript(value, *_indices, skip_reordered=False):
if isinstance(value, np.ndarray):
Expand Down Expand Up @@ -364,47 +356,6 @@ def get_runtime():
return pytaichi


def _check_in_range(npty, val):
iif = np.iinfo(npty)
if not iif.min <= val <= iif.max:
# This isn't the case we want to deal with: |val| does't fall into the valid range of either
# the signed or the unsigned type.
error(
f'Constant {val} has exceeded the range of {to_taichi_type(npty)}: [{iif.min}, {iif.max}]'
)


def _clamp_unsigned_to_range(npty, val):
# npty: np.int32 or np.int64
iif = np.iinfo(npty)
if iif.min <= val <= iif.max:
return val
cap = (1 << iif.bits)
assert 0 <= val < cap
new_val = val - cap
return new_val


@taichi_scope
def make_constant_expr_i32(val):
assert isinstance(val, (int, np.integer))
return Expr(_ti_core.make_const_expr_int(i32, val))


@taichi_scope
def make_constant_expr(val, dtype):
if isinstance(val, (int, np.integer)):
constant_dtype = pytaichi.default_ip if dtype is None else dtype
_check_in_range(to_numpy_type(constant_dtype), val)
return Expr(
_ti_core.make_const_expr_int(
constant_dtype, _clamp_unsigned_to_range(np.int64, val)))
if isinstance(val, (float, np.floating)):
constant_dtype = pytaichi.default_fp if dtype is None else dtype
return Expr(_ti_core.make_const_expr_fp(constant_dtype, val))
raise TaichiTypeError(f'Invalid constant scalar data type: {type(val)}')


def reset():
global pytaichi
old_kernels = pytaichi.kernels
Expand Down
8 changes: 4 additions & 4 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def __init__(self, n=1, m=1, dt=None, suppress_warning=False):
mat.append(
list([
impl.make_tensor_element_expr(
self.local_tensor_proxy,
(impl.make_constant_expr_i32(i), ),
self.local_tensor_proxy, (expr.Expr(
i, dtype=primitive_types.i32), ),
(len(n), ), self.dynamic_index_stride)
]))
else: # now init a Matrix
Expand Down Expand Up @@ -119,8 +119,8 @@ def __init__(self, n=1, m=1, dt=None, suppress_warning=False):
mat[i].append(
impl.make_tensor_element_expr(
self.local_tensor_proxy,
(impl.make_constant_expr_i32(i),
impl.make_constant_expr_i32(j)),
(expr.Expr(i, dtype=primitive_types.i32),
expr.Expr(j, dtype=primitive_types.i32)),
(len(n), len(n[0])),
self.dynamic_index_stride))
self.n = len(mat)
Expand Down
12 changes: 9 additions & 3 deletions tests/python/test_literal.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@ def test_literal_multi_args_error():
def multi_args_error():
a = ti.i64(1, 2)

with pytest.raises(ti.TaichiSyntaxError):
with pytest.raises(
ti.TaichiSyntaxError,
match="Type annotation can only be given to a single literal."):
multi_args_error()


Expand All @@ -33,7 +35,9 @@ def test_literal_keywords_error():
def keywords_error():
a = ti.f64(1, x=2)

with pytest.raises(ti.TaichiSyntaxError):
with pytest.raises(
ti.TaichiSyntaxError,
match="Type annotation can only be given to a single literal."):
keywords_error()


Expand All @@ -44,5 +48,7 @@ def expr_error():
a = 1
b = ti.f16(a)

with pytest.raises(ti.TaichiSyntaxError):
with pytest.raises(
ti.TaichiSyntaxError,
match="Type annotation can only be given to a single literal."):
expr_error()

0 comments on commit c8501f0

Please sign in to comment.