Skip to content

Commit

Permalink
[Lang] Remove the real_matrix switch (#6885)
Browse files Browse the repository at this point in the history
Issue: #5819

### Brief Summary

We no longer need the switch after #6801.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
strongoier and pre-commit-ci[bot] authored Dec 13, 2022
1 parent fd6cbe9 commit 4ff41bb
Show file tree
Hide file tree
Showing 21 changed files with 77 additions and 191 deletions.
5 changes: 1 addition & 4 deletions python/taichi/_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,10 +162,7 @@ def _svd3d(A, dt, iters=None):
Decomposed 3x3 matrices `U`, `S` and `V`.
"""
assert A.n == 3 and A.m == 3
if impl.current_cfg().real_matrix:
inputs = get_runtime().prog.current_ast_builder().expand_expr([A.ptr])
else:
inputs = tuple([e.ptr for e in A.entries])
inputs = get_runtime().prog.current_ast_builder().expand_expr([A.ptr])
assert dt in [f32, f64]
if iters is None:
if dt == f32:
Expand Down
115 changes: 42 additions & 73 deletions python/taichi/lang/ast/ast_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
TaichiTypeError)
from taichi.lang.expr import Expr, make_expr_group
from taichi.lang.field import Field
from taichi.lang.impl import current_cfg
from taichi.lang.matrix import Matrix, MatrixType, Vector, is_vector
from taichi.lang.snode import append, deactivate, length
from taichi.lang.struct import Struct, StructType
Expand Down Expand Up @@ -300,8 +299,7 @@ def process_generators(ctx, node, now_comp, func, result):
with ctx.static_scope_guard():
_iter = build_stmt(ctx, node.generators[now_comp].iter)

if impl.current_cfg().real_matrix and isinstance(
_iter, impl.Expr) and _iter.ptr.is_tensor():
if isinstance(_iter, impl.Expr) and _iter.ptr.is_tensor():
shape = _iter.ptr.get_shape()
flattened = [
Expr(x) for x in ctx.ast_builder.expand_expr([_iter.ptr])
Expand Down Expand Up @@ -505,8 +503,7 @@ def build_Call(ctx, node):
for arg in node.args:
if isinstance(arg, ast.Starred):
arg_list = arg.ptr
if impl.current_cfg().real_matrix and isinstance(
arg_list, Expr):
if isinstance(arg_list, Expr) and arg_list.is_tensor():
# Expand Expr with Matrix-type return into list of Exprs
arg_list = [
Expr(x)
Expand All @@ -529,8 +526,7 @@ def build_Call(ctx, node):
node.ptr = impl.ti_format(*args, **keywords)
return node.ptr

if ((id(func) == id(Matrix)
or id(func) == id(Vector))) and impl.current_cfg().real_matrix:
if id(func) == id(Matrix) or id(func) == id(Vector):
node.ptr = matrix.make_matrix(*args, **keywords)
return node.ptr

Expand Down Expand Up @@ -654,57 +650,40 @@ def transform_as_kernel():
if isinstance(ctx.func.arguments[i].annotation,
(MatrixType)):

if current_cfg().real_matrix:
# with real_matrix=True, "data" is expected to be an Expr here
# Therefore we simply call "impl.expr_init_func(data)" to perform:
#
# TensorType* t = alloca()
# assign(t, data)
#
# We created local variable "t" - a copy of the passed-in argument "data"
if not isinstance(
data,
expr.Expr) or not data.ptr.is_tensor():
raise TaichiSyntaxError(
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix, but got {type(data)}."
)

element_shape = data.ptr.get_ret_type().shape()
if len(element_shape
) != ctx.func.arguments[i].annotation.ndim:
raise TaichiSyntaxError(
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with ndim {ctx.func.arguments[i].annotation.ndim}, but got {len(element_shape)}."
)

assert ctx.func.arguments[i].annotation.ndim > 0
if element_shape[0] != ctx.func.arguments[
i].annotation.n:
raise TaichiSyntaxError(
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with n {ctx.func.arguments[i].annotation.n}, but got {element_shape[0]}."
)

if ctx.func.arguments[
i].annotation.ndim == 2 and element_shape[
1] != ctx.func.arguments[
i].annotation.m:
raise TaichiSyntaxError(
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with m {ctx.func.arguments[i].annotation.m}, but got {element_shape[0]}."
)
else:
if not isinstance(data, Matrix):
raise TaichiSyntaxError(
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix, but got {type(data)}."
)

if data.m != ctx.func.arguments[i].annotation.m:
raise TaichiSyntaxError(
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with m {ctx.func.arguments[i].annotation.m}, but got {data.m}."
)

if data.n != ctx.func.arguments[i].annotation.n:
raise TaichiSyntaxError(
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with n {ctx.func.arguments[i].annotation.n}, but got {data.n}."
)
# "data" is expected to be an Expr here,
# so we simply call "impl.expr_init_func(data)" to perform:
#
# TensorType* t = alloca()
# assign(t, data)
#
# We created local variable "t" - a copy of the passed-in argument "data"
if not isinstance(
data, expr.Expr) or not data.ptr.is_tensor():
raise TaichiSyntaxError(
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix, but got {type(data)}."
)

element_shape = data.ptr.get_ret_type().shape()
if len(element_shape
) != ctx.func.arguments[i].annotation.ndim:
raise TaichiSyntaxError(
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with ndim {ctx.func.arguments[i].annotation.ndim}, but got {len(element_shape)}."
)

assert ctx.func.arguments[i].annotation.ndim > 0
if element_shape[0] != ctx.func.arguments[
i].annotation.n:
raise TaichiSyntaxError(
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with n {ctx.func.arguments[i].annotation.n}, but got {element_shape[0]}."
)

if ctx.func.arguments[
i].annotation.ndim == 2 and element_shape[
1] != ctx.func.arguments[i].annotation.m:
raise TaichiSyntaxError(
f"Argument {arg.arg} of type {ctx.func.arguments[i].annotation} is expected to be a Matrix with m {ctx.func.arguments[i].annotation.m}, but got {element_shape[0]}."
)

ctx.create_variable(arg.arg, impl.expr_init_func(data))
continue

Expand Down Expand Up @@ -1189,12 +1168,8 @@ def build_grouped_ndrange_for(ctx, node):
f"Group for should have 1 loop target, found {len(targets)}"
)
target = targets[0]
if current_cfg().real_matrix:
mat = matrix.make_matrix([0] * len(ndrange_var.dimensions),
dt=primitive_types.i32)
else:
mat = matrix.Vector([0] * len(ndrange_var.dimensions),
dt=primitive_types.i32)
mat = matrix.make_matrix([0] * len(ndrange_var.dimensions),
dt=primitive_types.i32)
target_var = impl.expr_init(mat)

ctx.create_variable(target, target_var)
Expand Down Expand Up @@ -1236,15 +1211,9 @@ def build_struct_for(ctx, node, is_grouped):
expr_group = expr.make_expr_group(loop_indices)
impl.begin_frontend_struct_for(ctx.ast_builder, expr_group,
loop_var)
if impl.current_cfg().real_matrix:
ctx.create_variable(
target,
matrix.make_matrix(loop_indices,
dt=primitive_types.i32))
else:
ctx.create_variable(
target,
matrix.Vector(loop_indices, dt=primitive_types.i32))
ctx.create_variable(
target,
matrix.make_matrix(loop_indices, dt=primitive_types.i32))
build_stmts(ctx, node.body)
ctx.ast_builder.end_frontend_struct_for()
else:
Expand Down
4 changes: 0 additions & 4 deletions python/taichi/lang/ast/ast_transformer_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from taichi.lang.exception import (TaichiCompilationError, TaichiNameError,
TaichiSyntaxError,
handle_exception_from_cpp)
from taichi.lang.matrix import Matrix


class Builder:
Expand Down Expand Up @@ -246,9 +245,6 @@ def get_var_by_name(self, name):
if name in s:
return s[name]
if name in self.global_vars:
if isinstance(self.global_vars[name],
Matrix) and impl.current_cfg().real_matrix:
return impl.expr_init(self.global_vars[name])
return self.global_vars[name]
try:
return getattr(builtins, name)
Expand Down
4 changes: 1 addition & 3 deletions python/taichi/lang/expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ def __init__(self, *args, tb=None, dtype=None):
'Cannot initialize scalar expression from '
f'taichi class: {type(args[0])}')
elif isinstance(args[0], (list, tuple)):
assert impl.current_cfg().real_matrix
self.ptr = make_matrix(args[0]).ptr
else:
# assume to be constant
Expand Down Expand Up @@ -173,8 +172,7 @@ def _get_flattened_ptrs(val):
for item in val._members:
ptrs.extend(_get_flattened_ptrs(item))
return ptrs
if impl.current_cfg().real_matrix and isinstance(
val, Expr) and val.ptr.is_tensor():
if isinstance(val, Expr) and val.ptr.is_tensor():
return impl.get_runtime().prog.current_ast_builder().expand_expr(
[val.ptr])
return [Expr(val).ptr]
Expand Down
49 changes: 13 additions & 36 deletions python/taichi/lang/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from taichi.lang._ndarray import ScalarNdarray
from taichi.lang._ndrange import GroupedNDRange, _Ndrange
from taichi.lang._texture import RWTextureAccessor
from taichi.lang.any_array import AnyArray, AnyArrayAccess
from taichi.lang.any_array import AnyArray
from taichi.lang.enums import SNodeGradType
from taichi.lang.exception import (TaichiCompilationError, TaichiIndexError,
TaichiRuntimeError, TaichiSyntaxError,
Expand All @@ -17,8 +17,7 @@
from taichi.lang.field import Field, ScalarField
from taichi.lang.kernel_arguments import SparseMatrixProxy
from taichi.lang.matrix import (Matrix, MatrixField, MatrixNdarray, MatrixType,
Vector, VectorNdarray, _IntermediateMatrix,
_MatrixFieldElement, make_matrix)
VectorNdarray, make_matrix)
from taichi.lang.mesh import (ConvType, MeshElementFieldProxy, MeshInstance,
MeshRelationAccessProxy,
MeshReorderedMatrixFieldProxy,
Expand Down Expand Up @@ -58,18 +57,11 @@ def expr_init(rhs):
if isinstance(rhs, Matrix) and (hasattr(rhs, "_DIM")):
return Matrix(*rhs.to_list(), ndim=rhs.ndim)
if isinstance(rhs, Matrix):
if current_cfg().real_matrix:
if rhs.ndim == 1:
entries = [rhs(i) for i in range(rhs.n)]
else:
entries = [[rhs(i, j) for j in range(rhs.m)]
for i in range(rhs.n)]
return make_matrix(entries)
if (isinstance(rhs, Vector)
or getattr(rhs, "ndim", None) == 1) and rhs.m == 1:
# _IntermediateMatrix may reach here
return Vector(rhs.to_list(), ndim=rhs.ndim)
return Matrix(rhs.to_list(), ndim=rhs.ndim)
if rhs.ndim == 1:
entries = [rhs(i) for i in range(rhs.n)]
else:
entries = [[rhs(i, j) for j in range(rhs.m)] for i in range(rhs.n)]
return make_matrix(entries)
if isinstance(rhs, SharedArray):
return rhs
if isinstance(rhs, Struct):
Expand Down Expand Up @@ -230,11 +222,9 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
f'Field with dim {field_dim} accessed with indices of dim {index_dim}'
)
if isinstance(value, MatrixField):
if current_cfg().real_matrix:
return Expr(
_ti_core.subscript(value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))
return _MatrixFieldElement(value, indices_expr_group)
return Expr(
_ti_core.subscript(value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))
if isinstance(value, StructField):
entries = {
k: subscript(ast_builder, v, *indices)
Expand All @@ -252,25 +242,12 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
raise IndexError(
f'Field with dim {dim - element_dim} accessed with indices of dim {index_dim}'
)
if element_dim == 0 or current_cfg().real_matrix:
return Expr(
_ti_core.subscript(value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))
n = value.element_shape()[0]
m = 1 if element_dim == 1 else value.element_shape()[1]
any_array_access = AnyArrayAccess(value, indices)
ret = _IntermediateMatrix(n,
m, [
any_array_access.subscript(i, j)
for i in range(n) for j in range(m)
],
ndim=element_dim)
ret.any_array_access = any_array_access
return ret
return Expr(
_ti_core.subscript(value.ptr, indices_expr_group,
get_runtime().get_current_src_info()))
if isinstance(value, Expr):
# Index into TensorType
# value: IndexExpression with ret_type = TensorType
assert current_cfg().real_matrix
assert value.is_tensor()

if has_slice:
Expand Down
4 changes: 2 additions & 2 deletions python/taichi/lang/kernel_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,8 @@ def func_call_rvalue(self, key, args):
elif isinstance(anno, primitive_types.RefType):
non_template_args.append(
_ti_core.make_reference(args[i].ptr))
elif impl.current_cfg().real_matrix and isinstance(
args[i], impl.Expr) and args[i].ptr.is_tensor():
elif isinstance(args[i],
impl.Expr) and args[i].ptr.is_tensor():
non_template_args.extend([
Expr(x) for x in impl.get_runtime().prog.
current_ast_builder().expand_expr([args[i].ptr])
Expand Down
7 changes: 2 additions & 5 deletions python/taichi/lang/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,9 +666,6 @@ def _subscript(self, *indices):
is_global_mat = isinstance(self, _MatrixFieldElement)
return self._impl._subscript(is_global_mat, *indices)

def _make_matrix(self):
return make_matrix(self._impl.entries)

def to_list(self):
"""Return this matrix as a 1D `list`.
Expand Down Expand Up @@ -1797,7 +1794,7 @@ def cast(self, mat):
if isinstance(mat, impl.Expr) and mat.ptr.is_tensor():
return ops_mod.cast(mat, self.dtype)

if isinstance(mat, Matrix) and impl.current_cfg().real_matrix:
if isinstance(mat, Matrix):
arr = [[mat(i, j) for j in range(self.m)] for i in range(self.n)]
return ops_mod.cast(make_matrix(arr), self.dtype)

Expand Down Expand Up @@ -1898,7 +1895,7 @@ def cast(self, vec):
if isinstance(vec, impl.Expr) and vec.ptr.is_tensor():
return ops_mod.cast(vec, self.dtype)

if isinstance(vec, Matrix) and impl.current_cfg().real_matrix:
if isinstance(vec, Matrix):
arr = vec.entries
return ops_mod.cast(make_matrix(arr), self.dtype)

Expand Down
2 changes: 1 addition & 1 deletion python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def uniform_matrix_inputs(*args):
results = []
for arg in args:
if has_real_matrix and is_matrix_class(arg):
results.append(arg._make_matrix())
results.append(impl.expr_init(arg))
else:
results.append(arg)

Expand Down
1 change: 0 additions & 1 deletion taichi/analysis/offline_cache_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ static std::vector<std::uint8_t> get_offline_cache_key_of_compile_config(
serializer(config->experimental_auto_mesh_local);
serializer(config->auto_mesh_local_default_occupacy);
serializer(config->dynamic_index);
serializer(config->real_matrix);
serializer(config->real_matrix_scalarize);
serializer.finalize();

Expand Down
6 changes: 1 addition & 5 deletions taichi/codegen/codegen_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,7 @@
namespace taichi::lang {

inline bool codegen_vector_type(CompileConfig *config) {
if (config->real_matrix && !config->real_matrix_scalarize) {
return true;
}

return false;
return !config->real_matrix_scalarize;
}

} // namespace taichi::lang
8 changes: 1 addition & 7 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,7 @@ void UnaryOpExpression::type_check(CompileConfig *config) {
auto operand_primitive_type = operand->ret_type.get_element_type();
auto ret_primitive_type = ret_type;

if (config->real_matrix) {
TI_ASSERT(operand_primitive_type->is<PrimitiveType>());

} else if (!operand->ret_type->is<PrimitiveType>()) {
if (!operand_primitive_type->is<PrimitiveType>()) {
throw TaichiTypeError(fmt::format(
"unsupported operand type(s) for '{}': '{}'", unary_op_type_name(type),
operand_primitive_type->to_string()));
Expand Down Expand Up @@ -539,9 +536,6 @@ void ExternalTensorExpression::flatten(FlattenContext *ctx) {
// The scalarization should happen after
// irpass::lower_access()
auto prim_dt = dt;
if (!get_compile_config()->real_matrix) {
prim_dt = dt.get_element_type();
}
auto ptr = Stmt::make<ArgLoadStmt>(arg_id, prim_dt, /*is_ptr=*/true);

int external_dims = dim - std::abs(element_dim);
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ class MatrixFieldExpression : public Expression {

/**
* Creating a local matrix;
* lowered from ti.Matrix with real_matrix=True
* lowered from ti.Matrix
*/
class MatrixExpression : public Expression {
public:
Expand Down
Loading

0 comments on commit 4ff41bb

Please sign in to comment.