Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Do not expose internal function in ti.lang.impl #4134

Merged
merged 14 commits into from
Jan 27, 2022
2 changes: 1 addition & 1 deletion docs/lang/articles/misc/internal.md
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,7 @@ a = ti.field(ti.f32, shape=(128, 32, 8))
b = ti.field(ti.f32)
ti.root.dense(ti.j, 32).dense(ti.i, 16).place(b)

ti.get_runtime().materialize()
ti.lang.impl.get_runtime().materialize() # This is an internal API for developers; we do not guarantee that it is stable for users.

mapping_a = a.snode().physical_index_position()

Expand Down
4 changes: 3 additions & 1 deletion misc/benchmark_rebuild_graph.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from taichi.lang import impl

import taichi as ti

ti.init(arch=ti.cuda, async_mode=True)
Expand All @@ -23,4 +25,4 @@ def foo():
for i in range(1000):
foo()

ti.get_runtime().prog.benchmark_rebuild_graph()
impl.get_runtime().prog.benchmark_rebuild_graph()
4 changes: 3 additions & 1 deletion python/taichi/examples/algorithm/print_offset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from taichi.lang import impl

import taichi as ti

ti.init(arch=ti.cpu, print_ir=True)
Expand All @@ -19,7 +21,7 @@ def fill():
fill()
print(a.to_numpy())

ti.get_runtime().prog.visualize_layout('layout.pdf')
impl.get_runtime().prog.visualize_layout('layout.pdf')

gui = ti.GUI('layout', res=(256, 512), background_color=0xFFFFFF)

Expand Down
16 changes: 4 additions & 12 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,9 @@
TaichiSyntaxError, TaichiTypeError)
from taichi.lang.expr import Expr, make_expr_group
from taichi.lang.field import Field, ScalarField
from taichi.lang.impl import (axes, begin_frontend_if,
begin_frontend_struct_for, call_internal,
current_cfg, deactivate_all_snodes, expr_init,
expr_init_func, expr_init_list, field,
get_runtime, grouped,
insert_expr_stmt_if_ti_func, ndarray, one, root,
static, static_assert, static_print, stop_grad,
subscript, ti_assert, ti_float, ti_format,
ti_int, ti_print, zero)
from taichi.lang.impl import (axes, deactivate_all_snodes, field, grouped,
ndarray, one, root, static, static_assert,
static_print, stop_grad, zero)
from taichi.lang.kernel_arguments import SparseMatrixProxy
from taichi.lang.kernel_impl import (KernelArgError, KernelDefError,
data_oriented, func, kernel, pyfunc)
Expand All @@ -59,8 +53,6 @@

from taichi import _logging

runtime = impl.get_runtime()

i = axes(0)
j = axes(1)
k = axes(2)
Expand Down Expand Up @@ -787,7 +779,7 @@ def Tape(loss, clear_gradients=True):
from taichi._kernels import clear_loss # pylint: disable=C0415
clear_loss(loss)

return runtime.get_tape(loss)
return impl.get_runtime().get_tape(loss)


def clear_all_gradients():
Expand Down
6 changes: 6 additions & 0 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -904,3 +904,9 @@ def mesh_relation_access(mesh, from_index, to_element_type):
if isinstance(mesh, MeshInstance):
return MeshRelationAccessProxy(mesh, from_index, to_element_type)
raise RuntimeError("Relation access should be with a mesh instance!")


__all__ = [
'axes', 'deactivate_all_snodes', 'field', 'grouped', 'ndarray', 'one',
'root', 'static', 'static_assert', 'static_print', 'stop_grad', 'zero'
]
4 changes: 3 additions & 1 deletion tests/python/test_ad_if.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from taichi.lang import impl

import taichi as ti


Expand Down Expand Up @@ -235,6 +237,6 @@ def func():
def test_stack():
@ti.kernel
def func():
ti.call_internal("test_stack")
impl.call_internal("test_stack")

func()
7 changes: 4 additions & 3 deletions tests/python/test_ast_refactor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pytest
from taichi._testing import approx
from taichi.lang import impl
from taichi.lang.util import has_pytorch

import taichi as ti
Expand All @@ -13,7 +14,7 @@ def foo(x: ti.i32, y: ti.i32, a: ti.template()):
a[0] = x + y
a[1] = x - y
a[2] = x * y
a[3] = ti.ti_float(x) / y
a[3] = impl.ti_float(x) / y
a[4] = x // y
a[5] = x % y
a[6] = x**y
Expand Down Expand Up @@ -688,7 +689,7 @@ def bar(x: ti.template()):
return ti.Matrix([[1, 0], [0, 1]])

def fibonacci(x):
return ti.subscript(bar(x), 1, 0)
return impl.subscript(bar(x), 1, 0)

@ti.kernel
def foo(x: ti.template()) -> ti.i32:
Expand Down Expand Up @@ -780,7 +781,7 @@ def bar(x: tc.template()):
return tc.Matrix([[1, 0], [0, 1]])

def fibonacci(x):
return tc.subscript(bar(x), 1, 0)
return impl.subscript(bar(x), 1, 0)

@tc.kernel
def foo(x: tc.template()) -> tc.i32:
Expand Down
6 changes: 4 additions & 2 deletions tests/python/test_clear_all_gradients.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from taichi.lang import impl

import taichi as ti


Expand All @@ -23,7 +25,7 @@ def test_clear_all_gradients():
w.grad[i, j] = 6

ti.clear_all_gradients()
assert ti.get_runtime().get_num_compiled_functions() == 3
assert impl.get_runtime().get_num_compiled_functions() == 3

assert x.grad[None] == 0
for i in range(n):
Expand All @@ -34,4 +36,4 @@ def test_clear_all_gradients():

ti.clear_all_gradients()
# No more kernel compilation
assert ti.get_runtime().get_num_compiled_functions() == 3
assert impl.get_runtime().get_num_compiled_functions() == 3
3 changes: 2 additions & 1 deletion tests/python/test_compare.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import pytest
from taichi.lang import impl

import taichi as ti

Expand Down Expand Up @@ -106,7 +107,7 @@ def why_this_foo_fail(n):
return ti.atomic_add(b[None], n)

def foo(n):
return ti.atomic_add(ti.subscript(b, None), n)
return ti.atomic_add(impl.subscript(b, None), n)

@ti.kernel
def func():
Expand Down
8 changes: 5 additions & 3 deletions tests/python/test_cuda_internals.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from taichi.lang import impl

import taichi as ti

# TODO: these are not really tests...
Expand All @@ -8,7 +10,7 @@ def test_do_nothing():
@ti.kernel
def test():
for i in range(10):
ti.call_internal("do_nothing")
impl.call_internal("do_nothing")

test()

Expand All @@ -19,7 +21,7 @@ def test_active_mask():
def test():
for i in range(48):
if i % 2 == 0:
ti.call_internal("test_active_mask")
impl.call_internal("test_active_mask")

test()

Expand All @@ -29,6 +31,6 @@ def test_shfl_down():
@ti.kernel
def test():
for i in range(32):
ti.call_internal("test_shfl")
impl.call_internal("test_shfl")

test()
4 changes: 3 additions & 1 deletion tests/python/test_div.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from taichi.lang import impl

import taichi as ti


Expand Down Expand Up @@ -58,7 +60,7 @@ def test_true_div():

@ti.test()
def test_div_default_ip():
ti.get_runtime().set_default_ip(ti.i64)
impl.get_runtime().set_default_ip(ti.i64)
z = ti.field(ti.f32, shape=())

@ti.kernel
Expand Down
5 changes: 3 additions & 2 deletions tests/python/test_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
'''

import pytest
from taichi.lang import impl

import taichi as ti

Expand Down Expand Up @@ -121,7 +122,7 @@ def test_default_fp(dtype):

x = ti.Vector.field(2, float, ())

assert x.dtype == ti.get_runtime().default_fp
assert x.dtype == impl.get_runtime().default_fp


@pytest.mark.parametrize('dtype', [ti.i32, ti.i64])
Expand All @@ -130,7 +131,7 @@ def test_default_ip(dtype):

x = ti.Vector.field(2, int, ())

assert x.dtype == ti.get_runtime().default_ip
assert x.dtype == impl.get_runtime().default_ip


@ti.test()
Expand Down
14 changes: 8 additions & 6 deletions tests/python/test_internal_func.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import time

from taichi.lang import impl

import taichi as ti


Expand All @@ -8,7 +10,7 @@ def test_basic():
@ti.kernel
def test():
for _ in range(10):
ti.call_internal("do_nothing")
impl.call_internal("do_nothing")

test()

Expand All @@ -19,7 +21,7 @@ def test_host_polling():

@ti.kernel
def test():
ti.call_internal("refresh_counter")
impl.call_internal("refresh_counter")

for i in range(10):
print('updating tail to', i)
Expand All @@ -31,7 +33,7 @@ def test():
def test_list_manager():
@ti.kernel
def test():
ti.call_internal("test_list_manager")
impl.call_internal("test_list_manager")

test()
test()
Expand All @@ -41,7 +43,7 @@ def test():
def test_node_manager():
@ti.kernel
def test():
ti.call_internal("test_node_allocator")
impl.call_internal("test_node_allocator")

test()
test()
Expand All @@ -51,7 +53,7 @@ def test():
def test_node_manager_gc():
@ti.kernel
def test_cpu():
ti.call_internal("test_node_allocator_gc_cpu")
impl.call_internal("test_node_allocator_gc_cpu")

test_cpu()

Expand All @@ -60,7 +62,7 @@ def test_cpu():
def test_return():
@ti.kernel
def test_cpu():
ret = ti.call_internal("test_internal_func_args", 1.0, 2.0, 3)
ret = impl.call_internal("test_internal_func_args", 1.0, 2.0, 3)
assert ret == 9

test_cpu()
11 changes: 6 additions & 5 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np
import pytest
from taichi._testing import approx
from taichi.lang import impl

import taichi as ti

Expand Down Expand Up @@ -426,7 +427,7 @@ def test_matrix_field_dynamic_index_different_path_length():
ti.root.dense(ti.i, 8).place(x)
ti.root.dense(ti.i, 2).dense(ti.i, 4).place(y)

ti.get_runtime().materialize()
impl.get_runtime().materialize()
assert v.dynamic_index_stride is None


Expand All @@ -439,7 +440,7 @@ def test_matrix_field_dynamic_index_not_pure_dense():
ti.root.dense(ti.i, 2).pointer(ti.i, 4).place(x)
ti.root.dense(ti.i, 2).dense(ti.i, 4).place(y)

ti.get_runtime().materialize()
impl.get_runtime().materialize()
assert v.dynamic_index_stride is None


Expand All @@ -454,7 +455,7 @@ def test_matrix_field_dynamic_index_different_cell_size_bytes():
ti.root.dense(ti.i, 8).place(x, temp)
ti.root.dense(ti.i, 8).place(y)

ti.get_runtime().materialize()
impl.get_runtime().materialize()
assert v.dynamic_index_stride is None


Expand All @@ -470,7 +471,7 @@ def test_matrix_field_dynamic_index_different_offset_bytes_in_parent_cell():
ti.root.dense(ti.i, 8).place(temp_a, x)
ti.root.dense(ti.i, 8).place(y, temp_b)

ti.get_runtime().materialize()
impl.get_runtime().materialize()
assert v.dynamic_index_stride is None


Expand All @@ -485,7 +486,7 @@ def test_matrix_field_dynamic_index_different_stride():

ti.root.dense(ti.i, 8).place(x, y, temp, z)

ti.get_runtime().materialize()
impl.get_runtime().materialize()
assert v.dynamic_index_stride is None


Expand Down
Loading