Skip to content

Commit

Permalink
[lang] Migrate TensorType expansion for subscription indices from Pyt…
Browse files Browse the repository at this point in the history
…hon to Frontend IR (taichi-dev#6942)

Issue: taichi-dev#5819

### Brief Summary
For indices of TensorType, instead of scalarizing them at Python level,
it is up to the Frontend IR's consumer to decide whether TensorType'd
indices are acceptable and if we should have it scalarized.

This PR removes `expand_expr` in Expression subscription and migrate the
scalarization logics to the following constructors:

1. MeshIndexConversionExpression::MeshIndexConversionExpression
2. IndexExpression::IndexExpression
  • Loading branch information
jim19930609 authored and quadpixels committed May 13, 2023
1 parent ef1f060 commit 49d7d07
Show file tree
Hide file tree
Showing 21 changed files with 271 additions and 178 deletions.
3 changes: 2 additions & 1 deletion python/taichi/lang/_texture.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def _get_entries(mat):
if isinstance(mat, Matrix):
return mat.entries
assert isinstance(mat, Expr) and mat.is_tensor()
return impl.get_runtime().prog.current_ast_builder().expand_expr([mat.ptr])
return impl.get_runtime().prog.current_ast_builder().expand_exprs(
[mat.ptr])


class TextureSampler:
Expand Down
7 changes: 5 additions & 2 deletions python/taichi/lang/any_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,15 +78,18 @@ def __init__(self, arr, indices_first):

@taichi_scope
def subscript(self, i, j):
ast_builder = impl.get_runtime().prog.current_ast_builder()

indices_second = (i, ) if len(self.arr.element_shape()) == 1 else (i,
j)
if self.arr.layout() == Layout.SOA:
indices = indices_second + self.indices_first
else:
indices = self.indices_first + indices_second
return Expr(
_ti_core.subscript(self.arr.ptr, make_expr_group(*indices),
impl.get_runtime().get_current_src_info()))
ast_builder.expr_subscript(
self.arr.ptr, make_expr_group(*indices),
impl.get_runtime().get_current_src_info()))


__all__ = []
33 changes: 18 additions & 15 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
ReturnStatus)
from taichi.lang.ast.symbol_resolver import ASTResolver
from taichi.lang.exception import (TaichiIndexError, TaichiSyntaxError,
TaichiTypeError)
TaichiTypeError, handle_exception_from_cpp)
from taichi.lang.expr import Expr, make_expr_group
from taichi.lang.field import Field
from taichi.lang.matrix import Matrix, MatrixType, Vector, is_vector
Expand Down Expand Up @@ -156,7 +156,7 @@ def build_assign_unpack(ctx, node_target, values, is_static_assign):
raise ValueError(
'Matrices with more than one columns cannot be unpacked')

values = ctx.ast_builder.expand_expr([values.ptr])
values = ctx.ast_builder.expand_exprs([values.ptr])
if len(values) == 1:
values = values[0]

Expand Down Expand Up @@ -302,7 +302,7 @@ def process_generators(ctx, node, now_comp, func, result):
if isinstance(_iter, impl.Expr) and _iter.ptr.is_tensor():
shape = _iter.ptr.get_shape()
flattened = [
Expr(x) for x in ctx.ast_builder.expand_expr([_iter.ptr])
Expr(x) for x in ctx.ast_builder.expand_exprs([_iter.ptr])
]
_iter = reshape_list(flattened, shape)

Expand Down Expand Up @@ -514,7 +514,7 @@ def build_Call(ctx, node):
# Expand Expr with Matrix-type return into list of Exprs
arg_list = [
Expr(x)
for x in ctx.ast_builder.expand_expr([arg_list.ptr])
for x in ctx.ast_builder.expand_exprs([arg_list.ptr])
]

for i in arg_list:
Expand Down Expand Up @@ -730,7 +730,7 @@ def build_Return(ctx, node):
elif isinstance(ctx.func.return_type, MatrixType):
values = node.value.ptr
if isinstance(values, Expr) and values.ptr.is_tensor():
values = ctx.ast_builder.expand_expr([values.ptr])
values = ctx.ast_builder.expand_exprs([values.ptr])
else:
assert isinstance(values, Matrix)
values = itertools.chain.from_iterable(values.to_list()) if\
Expand Down Expand Up @@ -819,12 +819,15 @@ def build_Attribute(ctx, node):
# we continue to process it as a normal attribute node.
try:
build_stmt(ctx, node.value)
except TaichiIndexError as e:
node.value.ptr = None
if ASTTransformer.build_attribute_if_is_dynamic_snode_method(
ctx, node):
return node.ptr
except Exception as e:
e = handle_exception_from_cpp(e)
if isinstance(e, TaichiIndexError):
node.value.ptr = None
if ASTTransformer.build_attribute_if_is_dynamic_snode_method(
ctx, node):
return node.ptr
raise e

if ASTTransformer.build_attribute_if_is_dynamic_snode_method(
ctx, node):
return node.ptr
Expand All @@ -837,11 +840,11 @@ def build_Attribute(ctx, node):
node.attr)
attr_len = len(node.attr)
if attr_len == 1:
node.ptr = Expr(
_ti_core.subscript(
node.value.ptr.ptr,
make_expr_group(keygroup.index(node.attr)),
impl.get_runtime().get_current_src_info()))
node.ptr = Expr(impl.get_runtime(
).prog.current_ast_builder().expr_subscript(
node.value.ptr.ptr,
make_expr_group(keygroup.index(node.attr)),
impl.get_runtime().get_current_src_info()))
else:
node.ptr = Expr(
_ti_core.subscript_with_multiple_indices(
Expand Down
2 changes: 2 additions & 0 deletions python/taichi/lang/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ def handle_exception_from_cpp(exc):
return TaichiTypeError(str(exc))
if isinstance(exc, core.TaichiSyntaxError):
return TaichiSyntaxError(str(exc))
if isinstance(exc, core.TaichiIndexError):
return TaichiIndexError(str(exc))
if isinstance(exc, core.TaichiAssertionError):
return TaichiAssertionError(str(exc))
return exc
Expand Down
2 changes: 1 addition & 1 deletion python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ def _get_flattened_ptrs(val):
ptrs.extend(_get_flattened_ptrs(item))
return ptrs
if isinstance(val, Expr) and val.ptr.is_tensor():
return impl.get_runtime().prog.current_ast_builder().expand_expr(
return impl.get_runtime().prog.current_ast_builder().expand_exprs(
[val.ptr])
return [Expr(val).ptr]

Expand Down
55 changes: 22 additions & 33 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,8 @@
from taichi.lang._texture import RWTextureAccessor
from taichi.lang.any_array import AnyArray
from taichi.lang.enums import SNodeGradType
from taichi.lang.exception import (TaichiCompilationError, TaichiIndexError,
TaichiRuntimeError, TaichiSyntaxError,
TaichiTypeError)
from taichi.lang.exception import (TaichiCompilationError, TaichiRuntimeError,
TaichiSyntaxError, TaichiTypeError)
from taichi.lang.expr import Expr, make_expr_group
from taichi.lang.field import Field, ScalarField
from taichi.lang.kernel_arguments import SparseMatrixProxy
Expand Down Expand Up @@ -132,6 +131,7 @@ def check_validity(x):

@taichi_scope
def subscript(ast_builder, value, *_indices, skip_reordered=False):
ast_builder = get_runtime().prog.current_ast_builder()
# Directly evaluate in Python for non-Taichi types
if not isinstance(
value,
Expand All @@ -150,9 +150,6 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
elif isinstance(_index, slice):
ind = [_index]
has_slice = True
elif isinstance(_index, Expr) and _index.is_tensor():
# Expand Expr with TensorType return
ind = [Expr(e) for e in ast_builder.expand_expr([_index.ptr])]
else:
ind = [_index]
flattened_indices += ind
Expand All @@ -167,7 +164,6 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
f"The type {type(value)} do not support index of slice type")
else:
indices_expr_group = make_expr_group(*indices)
index_dim = indices_expr_group.size()

if isinstance(value, SharedArray):
return value.subscript(*indices)
Expand All @@ -178,13 +174,13 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
if isinstance(value,
(MeshReorderedScalarFieldProxy,
MeshReorderedMatrixFieldProxy)) and not skip_reordered:
assert index_dim == 1

reordered_index = tuple([
Expr(
_ti_core.get_index_conversion(value.mesh_ptr,
value.element_type,
Expr(indices[0]).ptr,
ConvType.g2r))
ast_builder.mesh_index_conversion(value.mesh_ptr,
value.element_type,
Expr(indices[0]).ptr,
ConvType.g2r))
])
return subscript(ast_builder,
value,
Expand All @@ -203,29 +199,26 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
raise RuntimeError(
f"Gradient {_var.get_expr_name()} has not been placed, check whether `needs_grad=True`"
)
field_dim = snode.num_active_indices()
if field_dim != index_dim:
raise TaichiIndexError(
f'Field with dim {field_dim} accessed with indices of dim {index_dim}'
)

if isinstance(value, MatrixField):
return make_index_expr(value.ptr, indices_expr_group)
return Expr(
ast_builder.expr_subscript(
value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))
if isinstance(value, StructField):
entries = {
k: subscript(ast_builder, v, *indices)
for k, v in value._items
}
entries['__struct_methods'] = value.struct_methods
return _IntermediateStruct(entries)
return make_index_expr(_var, indices_expr_group)
return Expr(
ast_builder.expr_subscript(_var, indices_expr_group,
get_runtime().get_current_src_info()))
if isinstance(value, AnyArray):
dim = _ti_core.get_external_tensor_dim(value.ptr)
element_dim = len(value.element_shape())
if dim != index_dim + element_dim:
raise IndexError(
f'Field with dim {dim - element_dim} accessed with indices of dim {index_dim}'
)
return make_index_expr(value.ptr, indices_expr_group)
return Expr(
ast_builder.expr_subscript(value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))
assert isinstance(value, Expr)
# Index into TensorType
# value: IndexExpression with ret_type = TensorType
Expand All @@ -249,18 +242,14 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
make_expr_group(i, j) for i in indices[0] for j in indices[1]
]
return_shape = (len(indices[0]), len(indices[1]))

return Expr(
_ti_core.subscript_with_multiple_indices(
value.ptr, multiple_indices, return_shape,
get_runtime().get_current_src_info()))
return make_index_expr(value.ptr, indices_expr_group)


@taichi_scope
def make_index_expr(_var, indices_expr_group):
return Expr(
_ti_core.subscript(_var, indices_expr_group,
get_runtime().get_current_src_info()))
ast_builder.expr_subscript(value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))


class SrcInfoGuard:
Expand Down
2 changes: 1 addition & 1 deletion python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def func_call_rvalue(self, key, args):
impl.Expr) and args[i].ptr.is_tensor():
non_template_args.extend([
Expr(x) for x in impl.get_runtime().prog.
current_ast_builder().expand_expr([args[i].ptr])
current_ast_builder().expand_exprs([args[i].ptr])
])
else:
non_template_args.append(args[i])
Expand Down
4 changes: 2 additions & 2 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1503,7 +1503,7 @@ def __call__(self, *args):
elif isinstance(x, impl.Expr) and x.ptr.is_tensor():
entries += [
impl.Expr(e) for e in impl.get_runtime().prog.
current_ast_builder().expand_expr([x.ptr])
current_ast_builder().expand_exprs([x.ptr])
]
elif isinstance(x, Matrix):
entries += x.entries
Expand Down Expand Up @@ -1616,7 +1616,7 @@ def __call__(self, *args):
elif isinstance(x, impl.Expr) and x.ptr.is_tensor():
entries += [
impl.Expr(e) for e in impl.get_runtime().prog.
current_ast_builder().expand_expr([x.ptr])
current_ast_builder().expand_exprs([x.ptr])
]
else:
entries.append(x)
Expand Down
16 changes: 10 additions & 6 deletions python/taichi/lang/mesh.py
Original file line number Diff line number Diff line change
Expand Up @@ -605,14 +605,17 @@ def _TetMesh():
class MeshElementFieldProxy:
def __init__(self, mesh: MeshInstance, element_type: MeshElementType,
entry_expr: impl.Expr):
ast_builder = impl.get_runtime().prog.current_ast_builder()

self.mesh = mesh
self.element_type = element_type
self.entry_expr = entry_expr

element_field = self.mesh.fields[self.element_type]
for key, attr in element_field.field_dict.items():

global_entry_expr = impl.Expr(
_ti_core.get_index_conversion(
ast_builder.mesh_index_conversion(
self.mesh.mesh_ptr, element_type, entry_expr,
ConvType.l2r if element_field.attr_dict[key].reorder else
ConvType.l2g)) # transform index space
Expand All @@ -622,7 +625,7 @@ def __init__(self, mesh: MeshInstance, element_type: MeshElementType,
setattr(
self, key,
impl.Expr(
_ti_core.subscript(
ast_builder.expr_subscript(
attr.ptr, global_entry_expr_group,
impl.get_runtime().get_current_src_info())))
elif isinstance(attr, StructField):
Expand All @@ -633,7 +636,7 @@ def __init__(self, mesh: MeshInstance, element_type: MeshElementType,
setattr(
self, key,
impl.Expr(
_ti_core.subscript(
ast_builder.expr_subscript(
var, global_entry_expr_group,
impl.get_runtime().get_current_src_info())))

Expand All @@ -650,10 +653,11 @@ def ptr(self):

@property
def id(self): # return the global non-reordered index
ast_builder = impl.get_runtime().prog.current_ast_builder()
l2g_expr = impl.Expr(
_ti_core.get_index_conversion(self.mesh.mesh_ptr,
self.element_type, self.entry_expr,
ConvType.l2g))
ast_builder.mesh_index_conversion(self.mesh.mesh_ptr,
self.element_type,
self.entry_expr, ConvType.l2g))
return l2g_expr


Expand Down
7 changes: 5 additions & 2 deletions python/taichi/lang/simt/block.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,8 @@ def __init__(self, shape, dtype):

@taichi_scope
def subscript(self, *indices):
return impl.make_index_expr(self.shared_array_proxy,
make_expr_group(*indices))
ast_builder = impl.get_runtime().prog.current_ast_builder()
return impl.Expr(
ast_builder.expr_subscript(
self.shared_array_proxy, make_expr_group(*indices),
impl.get_runtime().get_current_src_info()))
4 changes: 4 additions & 0 deletions taichi/common/exceptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ class TaichiSyntaxError : public TaichiExceptionImpl {
using TaichiExceptionImpl::TaichiExceptionImpl;
};

class TaichiIndexError : public TaichiExceptionImpl {
using TaichiExceptionImpl::TaichiExceptionImpl;
};

class TaichiRuntimeError : public TaichiExceptionImpl {
using TaichiExceptionImpl::TaichiExceptionImpl;
};
Expand Down
6 changes: 0 additions & 6 deletions taichi/ir/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,6 @@ Expr bit_cast(const Expr &input, DataType dt) {
return Expr::make<UnaryOpExpression>(UnaryOpType::cast_bits, input, dt);
}

Expr Expr::operator[](const ExprGroup &indices) const {
TI_ASSERT(is<FieldExpression>() || is<MatrixFieldExpression>() ||
is<ExternalTensorExpression>() || is_tensor(expr->ret_type));
return Expr::make<IndexExpression>(*this, indices);
}

Expr &Expr::operator=(const Expr &o) {
set(o);
return *this;
Expand Down
2 changes: 0 additions & 2 deletions taichi/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,6 @@ class Expr {
// std::variant<Expr, std::string> in FrontendPrintStmt.
Expr &operator=(const Expr &o);

Expr operator[](const ExprGroup &indices) const;

template <typename T, typename... Args>
static Expr make(Args &&...args) {
return Expr(std::make_shared<T>(std::forward<Args>(args)...));
Expand Down
Loading

0 comments on commit 49d7d07

Please sign in to comment.