Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Do not expose internal function in field, exception, expr, any_array , _ndrange, _ndarray #4137

Merged
merged 11 commits into from
Jan 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 8 additions & 9 deletions python/taichi/lang/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,9 @@
from taichi._lib import core as _ti_core
from taichi._lib.utils import locale_encode
from taichi.lang import impl
from taichi.lang._ndarray import ScalarNdarray
from taichi.lang._ndrange import GroupedNDRange, ndrange
from taichi.lang.any_array import AnyArray, AnyArrayAccess
from taichi.lang._ndrange import ndrange
from taichi.lang.enums import Layout
from taichi.lang.exception import (InvalidOperationError,
TaichiCompilationError, TaichiNameError,
from taichi.lang.exception import (TaichiCompilationError, TaichiNameError,
TaichiSyntaxError, TaichiTypeError)
from taichi.lang.expr import Expr, make_expr_group
from taichi.lang.field import Field, ScalarField
Expand Down Expand Up @@ -716,17 +713,19 @@ def cache_read_only(*args):

def assume_in_range(val, base, low, high):
return _ti_core.expr_assume_in_range(
Expr(val).ptr,
Expr(base).ptr, low, high)
impl.Expr(val).ptr,
impl.Expr(base).ptr, low, high)


def loop_unique(val, covers=None):
if covers is None:
covers = []
if not isinstance(covers, (list, tuple)):
covers = [covers]
covers = [x.snode.ptr if isinstance(x, Expr) else x.ptr for x in covers]
return _ti_core.expr_loop_unique(Expr(val).ptr, covers)
covers = [
x.snode.ptr if isinstance(x, impl.Expr) else x.ptr for x in covers
]
return _ti_core.expr_loop_unique(impl.Expr(val).ptr, covers)


parallelize = _ti_core.parallelize
Expand Down
3 changes: 3 additions & 0 deletions python/taichi/lang/_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -354,3 +354,6 @@ def setter(value):

self.getter = getter
self.setter = setter


__all__ = []
9 changes: 8 additions & 1 deletion python/taichi/lang/_ndrange.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from taichi.lang.matrix import _IntermediateMatrix


class ndrange:
class _Ndrange:
def __init__(self, *args):
args = list(args)
for i, arg in enumerate(args):
Expand Down Expand Up @@ -42,10 +42,17 @@ def grouped(self):
return GroupedNDRange(self)


def ndrange(*args):
return _Ndrange(*args)


class GroupedNDRange:
def __init__(self, r):
self.r = r

def __iter__(self):
for ind in self.r:
yield _IntermediateMatrix(len(ind), 1, list(ind))


__all__ = ['ndrange']
k-ye marked this conversation as resolved.
Show resolved Hide resolved
3 changes: 3 additions & 0 deletions python/taichi/lang/any_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,6 @@ def subscript(self, i, j):
indices = self.indices_first + indices_second
return Expr(_ti_core.subscript(self.arr.ptr,
make_expr_group(*indices)))


__all__ = []
4 changes: 2 additions & 2 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from taichi._lib import core as _ti_core
from taichi.lang import expr, impl, kernel_arguments, kernel_impl, matrix, mesh
from taichi.lang import ops as ti_ops
from taichi.lang._ndrange import ndrange
from taichi.lang._ndrange import _Ndrange, ndrange
from taichi.lang.ast.ast_transformer_utils import Builder, LoopStatus
from taichi.lang.ast.symbol_resolver import ASTResolver
from taichi.lang.exception import TaichiSyntaxError
Expand Down Expand Up @@ -731,7 +731,7 @@ def build_static_for(ctx, node, is_grouped):
if is_grouped:
assert len(node.iter.args[0].args) == 1
ndrange_arg = build_stmt(ctx, node.iter.args[0].args[0])
if not isinstance(ndrange_arg, ndrange):
if not isinstance(ndrange_arg, _Ndrange):
raise TaichiSyntaxError(
"Only 'ti.ndrange' is allowed in 'ti.static(ti.grouped(...))'."
)
Expand Down
3 changes: 3 additions & 0 deletions python/taichi/lang/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,3 +9,6 @@ class Layout(Enum):
"""
AOS = 1
SOA = 2


__all__ = ['Layout']
6 changes: 6 additions & 0 deletions python/taichi/lang/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,9 @@ def handle_exception_from_cpp(exc):
if isinstance(exc, core.TaichiTypeError):
return TaichiTypeError(str(exc))
return exc


__all__ = [
'TaichiSyntaxError', 'TaichiTypeError', 'TaichiCompilationError',
'TaichiNameError'
]
3 changes: 3 additions & 0 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,6 @@ def make_expr_group(*exprs):
else:
expr_group.push_back(Expr(i).ptr)
return expr_group


__all__ = []
3 changes: 3 additions & 0 deletions python/taichi/lang/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,6 @@ class SNodeHostAccess:
def __init__(self, accessor, key):
self.accessor = accessor
self.key = key


__all__ = []
11 changes: 6 additions & 5 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from taichi._lib import core as _ti_core
from taichi._logging import warn
from taichi.lang._ndarray import ScalarNdarray
from taichi.lang._ndrange import GroupedNDRange, ndrange
from taichi.lang._ndrange import GroupedNDRange, _Ndrange
from taichi.lang.any_array import AnyArray, AnyArrayAccess
from taichi.lang.exception import InvalidOperationError, TaichiTypeError
from taichi.lang.expr import Expr, make_expr_group
Expand Down Expand Up @@ -51,7 +51,7 @@ def expr_init(rhs):
return rhs
if isinstance(rhs, _ti_core.Arch):
return rhs
if isinstance(rhs, ndrange):
if isinstance(rhs, _Ndrange):
return rhs
if isinstance(rhs, MeshElementFieldProxy):
return rhs
Expand Down Expand Up @@ -848,8 +848,9 @@ def static(x, *xs):
if len(xs): # for python-ish pointer assign: x, y = ti.static(y, x)
return [static(x)] + [static(x) for x in xs]

if isinstance(x, (bool, int, float, range, list, tuple, enumerate, ndrange,
GroupedNDRange, zip, filter, map)) or x is None:
if isinstance(x,
(bool, int, float, range, list, tuple, enumerate, _Ndrange,
GroupedNDRange, zip, filter, map)) or x is None:
return x
if isinstance(x, AnyArray):
return x
Expand All @@ -874,7 +875,7 @@ def grouped(x):
>>> for I in ti.grouped(ndrange(8, 16)):
>>> print(I[0] + I[1])
"""
if isinstance(x, ndrange):
if isinstance(x, _Ndrange):
return x.grouped()
return x

Expand Down
3 changes: 2 additions & 1 deletion python/taichi/linalg/sparse_solver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import taichi.lang
from taichi._lib import core as _ti_core
from taichi.lang.field import Field
from taichi.linalg import SparseMatrix
from taichi.types.primitive_types import f32

Expand Down Expand Up @@ -70,7 +71,7 @@ def solve(self, b):
Returns:
numpy.array: The solution of linear systems.
"""
if isinstance(b, taichi.lang.Field):
if isinstance(b, Field):
return self.solver.solve(b.to_numpy())
if isinstance(b, np.ndarray):
return self.solver.solve(b)
Expand Down