Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Remove exposure of internal functions in taichi.lang.ops #4101

Merged
merged 7 commits into from
Jan 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,12 +318,12 @@ def build_JoinedStr(ctx, node):
return node.ptr

@staticmethod
def build_call_if_is_builtin(node, args, keywords):
def build_call_if_is_builtin(ctx, node, args, keywords):
func = node.func.ptr
replace_func = {
id(print): impl.ti_print,
id(min): ti_ops.ti_min,
id(max): ti_ops.ti_max,
id(min): ti_ops.min,
id(max): ti_ops.max,
id(int): impl.ti_int,
id(float): impl.ti_float,
id(any): ti_ops.ti_any,
Expand All @@ -333,6 +333,12 @@ def build_call_if_is_builtin(node, args, keywords):
}
if id(func) in replace_func:
node.ptr = replace_func[id(func)](*args, **keywords)
if func is min or func is max:
name = "min" if func is min else "max"
warnings.warn_explicit(
f'Calling builtin function "{name}" in Taichi scope is deprecated. '
f'Please use "ti.{name}" instead.', DeprecationWarning,
ctx.file, node.lineno + ctx.lineno_offset)
return True
return False

Expand Down Expand Up @@ -383,7 +389,7 @@ def build_Call(ctx, node):
node.ptr = impl.ti_format(*args, **keywords)
return node.ptr

if ASTTransformer.build_call_if_is_builtin(node, args, keywords):
if ASTTransformer.build_call_if_is_builtin(ctx, node, args, keywords):
return node.ptr

node.ptr = func(*args, **keywords)
Expand Down Expand Up @@ -1074,6 +1080,11 @@ def build_IfExp(ctx, node):
node.body.ptr) or is_taichi_class(node.orelse.ptr):
node.ptr = ti_ops.select(node.test.ptr, node.body.ptr,
node.orelse.ptr)
warnings.warn_explicit(
'Using conditional expression for element-wise select operation on '
'Taichi vectors/matrices is deprecated. '
'Please use "ti.select" instead.', DeprecationWarning,
ctx.file, node.lineno + ctx.lineno_offset)
return node.ptr

is_static_if = (ASTTransformer.get_decorator(ctx,
Expand Down
4 changes: 2 additions & 2 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,11 +614,11 @@ def norm_sqr(self):

def max(self):
"""Return the maximum element value."""
return ops_mod.ti_max(*self.entries)
return ops_mod.max(*self.entries)

def min(self):
"""Return the minimum element value."""
return ops_mod.ti_min(*self.entries)
return ops_mod.min(*self.entries)

def any(self):
"""Test whether any element not equal zero.
Expand Down
25 changes: 17 additions & 8 deletions python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -504,7 +504,7 @@ def truediv(a, b):


@binary
def max(a, b): # pylint: disable=W0622
def max_impl(a, b):
"""The maxnimum function.
Args:
Expand All @@ -518,7 +518,7 @@ def max(a, b): # pylint: disable=W0622


@binary
def min(a, b): # pylint: disable=W0622
def min_impl(a, b):
"""The minimum function.
Args:
Expand Down Expand Up @@ -832,24 +832,24 @@ def assign(a, b):
return a


def ti_max(*args):
def max(*args): # pylint: disable=W0622
num_args = len(args)
assert num_args >= 1
if num_args == 1:
return args[0]
if num_args == 2:
return max(args[0], args[1])
return max(args[0], ti_max(*args[1:]))
return max_impl(args[0], args[1])
return max_impl(args[0], max(*args[1:]))


def ti_min(*args):
def min(*args): # pylint: disable=W0622
num_args = len(args)
assert num_args >= 1
if num_args == 1:
return args[0]
if num_args == 2:
return min(args[0], args[1])
return min(args[0], ti_min(*args[1:]))
return min_impl(args[0], args[1])
return min_impl(args[0], min(*args[1:]))


def ti_any(a):
Expand All @@ -858,3 +858,12 @@ def ti_any(a):

def ti_all(a):
return a.all()


__all__ = [
"acos", "asin", "atan2", "atomic_and", "atomic_or", "atomic_xor",
"atomic_max", "atomic_sub", "atomic_min", "atomic_add", "bit_cast",
"bit_shr", "cast", "ceil", "cos", "exp", "floor", "log", "random",
"raw_mod", "raw_div", "round", "rsqrt", "sin", "sqrt", "tan", "tanh",
"max", "min", "select"
]
2 changes: 1 addition & 1 deletion tests/python/test_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_abs():
@ti.kernel
def func():
for i in range(N):
x[i] = ti.abs(y[i])
x[i] = abs(y[i])

for i in range(N):
y[i] = i - 10
Expand Down
6 changes: 3 additions & 3 deletions tests/python/test_element_wise.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,10 +256,10 @@ def test_unary():
def func():
xi[0] = -yi[None]
xi[1] = ~yi[None]
xi[2] = ti.logical_not(yi[None])
xi[3] = ti.abs(yi[None])
xi[2] = not yi[None]
xi[3] = abs(yi[None])
xf[0] = -yf[None]
xf[1] = ti.abs(yf[None])
xf[1] = abs(yf[None])
xf[2] = ti.sqrt(yf[None])
xf[3] = ti.sin(yf[None])
xf[4] = ti.cos(yf[None])
Expand Down
4 changes: 2 additions & 2 deletions tests/python/test_f16.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def test_unary_op():

@ti.kernel
def foo():
x[None] = ti.neg(y[None])
x[None] = -y[None]
x[None] = ti.floor(x[None])
y[None] = ti.ceil(y[None])

Expand All @@ -159,7 +159,7 @@ def test_extra_unary_promote():

@ti.kernel
def foo():
x[None] = ti.abs(y[None])
x[None] = abs(y[None])

y[None] = -0.3
foo()
Expand Down
4 changes: 2 additions & 2 deletions tests/python/test_loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ def test_loops():
@ti.kernel
def func():
for i in range(ti.static(N // 2 + 3), N):
x[i] = ti.abs(y[i])
x[i] = abs(y[i])

func()

Expand Down Expand Up @@ -50,7 +50,7 @@ def test_numpy_loops():
@ti.kernel
def func():
for i in range(begin, end):
x[i] = ti.abs(y[i])
x[i] = abs(y[i])

func()

Expand Down
3 changes: 2 additions & 1 deletion tests/python/test_matrix.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import operator

import numpy as np
Expand Down Expand Up @@ -73,7 +74,7 @@ def test_python_scope_matrix_field(ops):

@ti.test(arch=ti.get_host_arch_list())
def test_constant_matrices():
assert ti.cos(ti.math.pi / 3) == approx(0.5)
assert ti.cos(math.pi / 3) == approx(0.5)
assert np.allclose((-ti.Vector([2, 3])).to_numpy(), np.array([-2, -3]))
assert ti.cos(ti.Vector([2,
3])).to_numpy() == approx(np.cos(np.array([2,
Expand Down
8 changes: 4 additions & 4 deletions tests/python/test_scalar_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@
unary_func_table = [
(ops.neg, ) * 2,
(ops.invert, ) * 2,
(ti.logical_not, np.logical_not),
(ti.abs, np.abs),
(ti.lang.ops.logical_not, np.logical_not),
(ti.lang.ops.abs, np.abs),
(ti.exp, np.exp),
(ti.log, np.log),
(ti.sin, np.sin),
Expand Down Expand Up @@ -64,10 +64,10 @@ def test_python_scope_vector_binary(ti_func, np_func):
def test_python_scope_vector_unary(ti_func, np_func):
ti.init()
x = ti.Vector([2, 3] if ti_func in
[ops.invert, ti.logical_not] else [0.2, 0.3])
[ops.invert, ti.lang.ops.logical_not] else [0.2, 0.3])

result = ti_func(x).to_numpy()
if ti_func in [ti.logical_not]:
if ti_func in [ti.lang.ops.logical_not]:
result = result.astype(bool)
expected = np_func(x.to_numpy())
assert allclose(result, expected)
Expand Down