[Lang] MatrixType refactor: Support matrix slice (#6430)
Issue: #5819

### Brief Summary

A slice actually corresponds to several indices. Let's take a look at
the following example:
```python
import taichi as ti

ti.init(real_matrix=True, real_matrix_scalarize=True, print_ir=True)

@ti.kernel
def foo():
    c = ti.Vector([1, 2, 3, 4, 5])
    print(c[:4:2])

foo()
```

`c[:4:2]` is a "partially selected vector" of length 2 that contains references to `c[0]` and `c[2]`. Let's represent it as `c[(0), (2), shape=2]`.
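
For intuition, the expansion follows ordinary Python slice semantics. A minimal pure-Python sketch (using the built-in `slice.indices`; the Taichi-side helper is `_calc_slice` in `impl.py`):
```python
# Plain Python, for illustration only: expand ":4:2" against a vector of length 5.
length = 5
s = slice(None, 4, 2)                      # the ":4:2" in c[:4:2]
expanded = list(range(*s.indices(length)))
print(expanded)                            # [0, 2]  ->  c[(0), (2), shape=2]
```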

This PR does the following:
- expands slices into groups of indices in Python (see the sketch after this list);
- extends `IndexExpression` to accept a group of indices;
- creates `MatrixOfMatrixPtrStmt` to represent a "partially selected
matrix" in CHI IR;
- eliminates `MatrixOfMatrixPtrStmt` in the `lower_matrix_ptr()` pass.
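
For the 2-D case, the expanded per-axis indices are combined with a row-major Cartesian product, and the result shape is recorded alongside. A pure-Python sketch (illustrative only; in the real implementation each tuple below becomes one `ExprGroup` of the `IndexExpression`):
```python
# m[:2, ::2] on a 2x3 matrix: expand each slice, then pair up indices row-major.
rows = list(range(*slice(None, 2).indices(2)))        # [0, 1]
cols = list(range(*slice(None, None, 2).indices(3)))  # [0, 2]
indices_group = [(i, j) for i in rows for j in cols]
ret_shape = (len(rows), len(cols))
print(indices_group)  # [(0, 0), (0, 2), (1, 0), (1, 2)]
print(ret_shape)      # (2, 2)
```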

Frontend IR and CHI IR of the above example:
```
[I 10/25/22 19:54:51.885 816603] [compile_to_offloads.cpp:operator()@22] [foo_c80_0] Initial IR:
kernel {
  $0 = alloca @tmp0
  @tmp0 = [1, 2, 3, 4, 5] (dt=[Tensor (5) i32])
  $2 = alloca @tmp1
  @tmp1 = @tmp0
  print @tmp1[(0), (2), shape=(2)], "\n"
}
[I 10/25/22 19:54:51.885 816603] [compile_to_offloads.cpp:operator()@22] [foo_c80_0] Lowered:
kernel {
  <[Tensor (5) i32]> $0 = alloca
  <i32> $1 = const 1
  <i32> $2 = const 2
  <i32> $3 = const 3
  <i32> $4 = const 4
  <i32> $5 = const 5
  <[Tensor (5) i32]> $6 = [$1, $2, $3, $4, $5]
  $7 : local store [$0 <- $6]
  <[Tensor (5) i32]> $8 = alloca
  $9 = local load [$0]
  $10 : local store [$8 <- $9]
  <i32> $11 = const 0
  <*i32> $12 = shift ptr [$8 + $11]
  <i32> $13 = const 2
  <*i32> $14 = shift ptr [$8 + $13]
  <[Tensor (2) i32]> $15 = matrix of matrix ptr [$12, $14]
  $16 = local load [$15]
  print $16, "\n"
}

```

Some tests of matrix slice haven't been enabled yet because they require
`x.transpose()`, which depends on #6425.
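
For completeness, the `dim == 2` branch added in `impl.py` makes two-axis slicing work as well. A sketch using the same init flags as the example above:
```python
import taichi as ti

ti.init(real_matrix=True, real_matrix_scalarize=True)

@ti.kernel
def bar():
    m = ti.Matrix([[1, 2, 3], [4, 5, 6]])
    # Rows 0..1, columns 0 and 2: a 2x2 "partially selected" matrix, i.e.
    # m[(0, 0), (0, 2), (1, 0), (1, 2), shape=(2, 2)].
    print(m[:, ::2])

bar()
```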

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Zhanlue Yang <[email protected]>
3 people authored Oct 26, 2022
1 parent b1a3583 commit c273503
Showing 12 changed files with 213 additions and 50 deletions.
33 changes: 28 additions & 5 deletions python/taichi/lang/impl.py
```diff
@@ -29,7 +29,6 @@
                               python_scope, taichi_scope, warning)
 from taichi.types.primitive_types import (all_types, f16, f32, f64, i32, i64,
                                           u8, u32, u64)
-from taichi.types.utils import is_tensor
 
 
 @taichi_scope
@@ -182,7 +181,8 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
         indices = ()
 
     if has_slice:
-        if not isinstance(value, Matrix):
+        if not isinstance(value, Matrix) and not (isinstance(value, Expr)
+                                                  and value.is_tensor()):
             raise SyntaxError(
                 f"The type {type(value)} do not support index of slice type")
     else:
@@ -269,9 +269,32 @@ def subscript(ast_builder, value, *_indices, skip_reordered=False):
     if isinstance(value, Expr):
         # Index into TensorType
         # value: IndexExpression with ret_type = TensorType
-        assert current_cfg().real_matrix is True
-        assert is_tensor(value.ptr.get_ret_type())
+        assert current_cfg().real_matrix
+        assert value.is_tensor()
 
+        if has_slice:
+            shape = value.get_shape()
+            dim = len(shape)
+            assert dim == len(indices)
+            indices = [
+                _calc_slice(index, shape[i])
+                if isinstance(index, slice) else [index]
+                for i, index in enumerate(indices)
+            ]
+            if dim == 1:
+                multiple_indices = [make_expr_group(i) for i in indices[0]]
+                return_shape = (len(indices[0]), )
+            else:
+                assert dim == 2
+                multiple_indices = [
+                    make_expr_group(i, j) for i in indices[0]
+                    for j in indices[1]
+                ]
+                return_shape = (len(indices[0]), len(indices[1]))
+            return Expr(
+                _ti_core.subscript_with_multiple_indices(
+                    value.ptr, multiple_indices, return_shape,
+                    get_runtime().get_current_src_info()))
         return Expr(
             _ti_core.subscript(value.ptr, indices_expr_group,
                                get_runtime().get_current_src_info()))
```
5 changes: 4 additions & 1 deletion taichi/analysis/gen_offline_cache_key.cpp
```diff
@@ -169,7 +169,10 @@ class ASTSerializer : public IRVisitor, public ExpressionVisitor {
   void visit(IndexExpression *expr) override {
     emit(ExprOpCode::IndexExpression);
     emit(expr->var);
-    emit(expr->indices.exprs);
+    for (auto &indices : expr->indices_group) {
+      emit(indices.exprs);
+    }
+    emit(expr->ret_shape);
   }
 
   void visit(MatrixExpression *expr) override {
```
1 change: 1 addition & 0 deletions taichi/inc/statements.inc.h
```diff
@@ -60,6 +60,7 @@ PER_STATEMENT(GetChStmt)
 PER_STATEMENT(LocalLoadStmt)
 PER_STATEMENT(GlobalPtrStmt)
 PER_STATEMENT(MatrixOfGlobalPtrStmt)
+PER_STATEMENT(MatrixOfMatrixPtrStmt)
 
 // Offloaded
 PER_STATEMENT(OffloadedStmt)
```
13 changes: 12 additions & 1 deletion taichi/ir/expression_printer.h
```diff
@@ -130,7 +130,18 @@ class ExpressionHumanFriendlyPrinter : public ExpressionPrinter {
   void visit(IndexExpression *expr) override {
     expr->var->accept(this);
     emit('[');
-    emit_vector(expr->indices.exprs);
+    if (expr->ret_shape.empty()) {
+      emit_vector(expr->indices_group[0].exprs);
+    } else {
+      for (auto &indices : expr->indices_group) {
+        emit('(');
+        emit_vector(indices.exprs);
+        emit("), ");
+      }
+      emit("shape=(");
+      emit_vector(expr->ret_shape);
+      emit(')');
+    }
     emit(']');
   }
 
```
92 changes: 60 additions & 32 deletions taichi/ir/frontend_ir.cpp
```diff
@@ -5,6 +5,8 @@
 #include "taichi/program/program.h"
 #include "taichi/common/exceptions.h"
 
+#include <numeric>
+
 namespace taichi::lang {
 
 #define TI_ASSERT_TYPE_CHECKED(x) \
@@ -572,17 +574,11 @@ Stmt *make_ndarray_access(Expression::FlattenContext *ctx,
   return ctx->push_back(std::move(external_ptr_stmt));
 }
 
-Stmt *make_tensor_access(Expression::FlattenContext *ctx,
-                         Expr var,
-                         ExprGroup indices,
-                         std::vector<int> shape,
-                         int stride) {
-  flatten_lvalue(var, ctx);
-  if (!var->is_lvalue()) {
-    auto alloca_stmt = ctx->push_back<AllocaStmt>(var->ret_type);
-    ctx->push_back<LocalStoreStmt>(alloca_stmt, var->stmt);
-    var->stmt = alloca_stmt;
-  }
+Stmt *make_tensor_access_single_element(Expression::FlattenContext *ctx,
+                                        const Expr &var,
+                                        const ExprGroup &indices,
+                                        const std::vector<int> &shape,
+                                        int stride) {
   bool needs_dynamic_index = false;
   for (int i = 0; i < (int)indices.size(); ++i) {
     if (!indices[i].is<ConstExpression>()) {
@@ -616,6 +612,30 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx,
   return ctx->push_back<MatrixPtrStmt>(var->stmt, offset_stmt);
 }
 
+Stmt *make_tensor_access(Expression::FlattenContext *ctx,
+                         Expr var,
+                         const std::vector<ExprGroup> &indices_group,
+                         DataType ret_type,
+                         std::vector<int> shape,
+                         int stride) {
+  flatten_lvalue(var, ctx);
+  if (!var->is_lvalue()) {
+    auto alloca_stmt = ctx->push_back<AllocaStmt>(var->ret_type);
+    ctx->push_back<LocalStoreStmt>(alloca_stmt, var->stmt);
+    var->stmt = alloca_stmt;
+  }
+  if (is_tensor(ret_type)) {
+    std::vector<Stmt *> stmts;
+    for (auto &indices : indices_group) {
+      stmts.push_back(
+          make_tensor_access_single_element(ctx, var, indices, shape, stride));
+    }
+    return ctx->push_back<MatrixOfMatrixPtrStmt>(stmts, ret_type);
+  }
+  return make_tensor_access_single_element(ctx, var, indices_group[0], shape,
+                                           stride);
+}
+
 void MatrixExpression::type_check(CompileConfig *config) {
   // TODO: typecheck matrix
   for (auto &arg : elements) {
@@ -671,7 +691,14 @@ bool IndexExpression::is_global() const {
 void IndexExpression::type_check(CompileConfig *) {
   // TODO: Change to type-based solution
   // Currently, dimension compatibility check happens in Python
-  if (is_field()) {  // field
+  TI_ASSERT(indices_group.size() == std::accumulate(begin(ret_shape),
+                                                    end(ret_shape), 1,
+                                                    std::multiplies<>()));
+  if (!ret_shape.empty()) {
+    TI_ASSERT_INFO(is_tensor(), "Slice or swizzle can only apply on matrices");
+    auto element_type = var->ret_type->as<TensorType>()->get_element_type();
+    ret_type = TypeFactory::create_tensor_type(ret_shape, element_type);
+  } else if (is_field()) {  // field
     ret_type = var.cast<FieldExpression>()->dt->get_compute_type();
   } else if (is_matrix_field()) {
     auto matrix_field_expr = var.cast<MatrixFieldExpression>();
@@ -682,7 +709,7 @@ void IndexExpression::type_check(CompileConfig *) {
   } else if (is_ndarray()) {  // ndarray
     auto external_tensor_expr = var.cast<ExternalTensorExpression>();
     int total_dim = external_tensor_expr->dim;
-    int index_dim = indices.exprs.size();
+    int index_dim = indices_group[0].exprs.size();
 
     if (index_dim == total_dim) {
       // Access all the way to a single element
@@ -693,9 +720,9 @@ void IndexExpression::type_check(CompileConfig *) {
     }
   } else if (is_tensor()) {  // local tensor
     auto shape = var->ret_type->as<TensorType>()->get_shape();
-    if (indices.size() != shape.size()) {
+    if (indices_group[0].size() != shape.size()) {
       TI_ERROR("Expected {} indices, but got {}.", shape.size(),
-               indices.size());
+               indices_group[0].size());
     }
     ret_type = var->ret_type->cast<TensorType>()->get_element_type();
   } else {
@@ -704,28 +731,32 @@ void IndexExpression::type_check(CompileConfig *) {
         "local tensor");
   }
 
-  for (int i = 0; i < indices.exprs.size(); i++) {
-    auto &expr = indices.exprs[i];
-    TI_ASSERT_TYPE_CHECKED(expr);
-    if (!is_integral(expr->ret_type))
-      throw TaichiTypeError(
-          fmt::format("indices must be integers, however '{}' is "
-                      "provided as index {}",
-                      expr->ret_type->to_string(), i));
+  for (auto &indices : indices_group) {
+    for (int i = 0; i < indices.exprs.size(); i++) {
+      auto &expr = indices.exprs[i];
+      TI_ASSERT_TYPE_CHECKED(expr);
+      if (!is_integral(expr->ret_type))
+        throw TaichiTypeError(
+            fmt::format("indices must be integers, however '{}' is "
+                        "provided as index {}",
+                        expr->ret_type->to_string(), i));
+    }
   }
 }
 
 void IndexExpression::flatten(FlattenContext *ctx) {
   if (is_field()) {
-    stmt = make_field_access(ctx, *var.cast<FieldExpression>(), indices);
+    stmt =
+        make_field_access(ctx, *var.cast<FieldExpression>(), indices_group[0]);
   } else if (is_matrix_field()) {
     stmt = make_matrix_field_access(ctx, *var.cast<MatrixFieldExpression>(),
-                                    indices, ret_type);
+                                    indices_group[0], ret_type);
   } else if (is_ndarray()) {
-    stmt = make_ndarray_access(ctx, var, indices);
+    stmt = make_ndarray_access(ctx, var, indices_group[0]);
   } else if (is_tensor()) {
-    stmt = make_tensor_access(
-        ctx, var, indices, var->ret_type->cast<TensorType>()->get_shape(), 1);
+    stmt =
+        make_tensor_access(ctx, var, indices_group, ret_type,
+                           var->ret_type->cast<TensorType>()->get_shape(), 1);
   } else {
     throw TaichiTypeError(
         "Invalid IndexExpression: the source is not among field, ndarray or "
@@ -746,7 +777,7 @@ void StrideExpression::type_check(CompileConfig *) {
 }
 
 void StrideExpression::flatten(FlattenContext *ctx) {
-  stmt = make_tensor_access(ctx, var, indices, shape, stride);
+  stmt = make_tensor_access(ctx, var, {indices}, ret_type, shape, stride);
}
 
 void RangeAssumptionExpression::type_check(CompileConfig *) {
@@ -1505,9 +1536,6 @@ void flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) {
     }
   } else if (ptr.is<StrideExpression>()) {
     flatten_global_load(ptr, ctx);
-  } else if (ptr.is<FieldExpression>()) {
-    TI_ASSERT(ptr.cast<FieldExpression>()->snode->num_active_indices == 0);
-    flatten_global_load(ptr[ExprGroup()], ctx);
   } else if (ptr.is<ArgLoadExpression>() &&
              ptr.cast<ArgLoadExpression>()->is_ptr) {
     flatten_global_load(ptr, ctx);
```
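
To make the new `make_tensor_access` flow concrete, here is a runnable pure-Python model (all names illustrative, not Taichi API): each index group is resolved to one element pointer, modeled here as a flat row-major offset, and a tensor-typed result bundles those pointers the way `MatrixOfMatrixPtrStmt` does.
```python
def element_offset(indices, shape):
    # Row-major flattening: the offset each MatrixPtrStmt ("shift ptr") encodes.
    offset = 0
    for idx, dim in zip(indices, shape):
        offset = offset * dim + idx
    return offset

def make_tensor_access(indices_group, shape, ret_is_tensor):
    ptrs = [element_offset(indices, shape) for indices in indices_group]
    if ret_is_tensor:
        return ptrs   # models MatrixOfMatrixPtrStmt bundling several pointers
    return ptrs[0]    # single-element access: the pre-existing path

c = [1, 2, 3, 4, 5]
ptrs = make_tensor_access([(0,), (2,)], (5,), ret_is_tensor=True)
print([c[p] for p in ptrs])   # [1, 3], the values of c[0] and c[2]
```
Note also the invariant now asserted at the top of `IndexExpression::type_check`: the number of index groups must equal the product of `ret_shape` (2 groups for `shape=(2)` above; 4 for a `shape=(2, 2)` slice).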
16 changes: 14 additions & 2 deletions taichi/ir/frontend_ir.h
```diff
@@ -579,12 +579,24 @@ class IndexExpression : public Expression {
   // `var` is one of FieldExpression, MatrixFieldExpression,
   // ExternalTensorExpression, IdExpression
   Expr var;
-  ExprGroup indices;
+  // In the cases of matrix slice and vector swizzle, there can be multiple
+  // indices, and the corresponding ret_shape should also be recorded. In
+  // normal index expressions ret_shape will be left empty.
+  std::vector<ExprGroup> indices_group;
+  std::vector<int> ret_shape;
 
   IndexExpression(const Expr &var,
                   const ExprGroup &indices,
                   std::string tb = "")
-      : var(var), indices(indices) {
+      : var(var), indices_group({indices}) {
     this->tb = tb;
   }
 
+  IndexExpression(const Expr &var,
+                  const std::vector<ExprGroup> &indices_group,
+                  const std::vector<int> &ret_shape,
+                  std::string tb = "")
+      : var(var), indices_group(indices_group), ret_shape(ret_shape) {
+    this->tb = tb;
+  }
+
```
10 changes: 9 additions & 1 deletion taichi/ir/statements.cpp
```diff
@@ -78,11 +78,19 @@ MatrixOfGlobalPtrStmt::MatrixOfGlobalPtrStmt(const std::vector<SNode *> &snodes,
   TI_STMT_REG_FIELDS;
 }
 
+MatrixOfMatrixPtrStmt::MatrixOfMatrixPtrStmt(const std::vector<Stmt *> &stmts,
+                                             DataType dt)
+    : stmts(stmts) {
+  ret_type = dt;
+  TI_STMT_REG_FIELDS;
+}
+
 MatrixPtrStmt::MatrixPtrStmt(Stmt *origin_input, Stmt *offset_input) {
   origin = origin_input;
   offset = offset_input;
   if (origin->is<AllocaStmt>() || origin->is<GlobalTemporaryStmt>() ||
-      origin->is<ExternalPtrStmt>() || origin->is<MatrixOfGlobalPtrStmt>()) {
+      origin->is<ExternalPtrStmt>() || origin->is<MatrixOfGlobalPtrStmt>() ||
+      origin->is<MatrixOfMatrixPtrStmt>()) {
     auto tensor_type = origin->ret_type.ptr_removed()->cast<TensorType>();
     TI_ASSERT(tensor_type != nullptr);
     element_type() = tensor_type->get_element_type();
```
18 changes: 18 additions & 0 deletions taichi/ir/statements.h
```diff
@@ -406,6 +406,24 @@ class MatrixOfGlobalPtrStmt : public Stmt {
   TI_DEFINE_ACCEPT_AND_CLONE
 };
 
+/**
+ * A matrix of MatrixPtrStmts. The purpose of this stmt is to handle matrix
+ * slice and vector swizzle. This stmt will be eliminated after the
+ * lower_matrix_ptr pass.
+ *
+ * TODO(yi/zhanlue): Keep scalarization pass alive for MatrixOfMatrixPtrStmt
+ * operations even with real_matrix_scalarize=False
+ */
+class MatrixOfMatrixPtrStmt : public Stmt {
+ public:
+  std::vector<Stmt *> stmts;
+
+  MatrixOfMatrixPtrStmt(const std::vector<Stmt *> &stmts, DataType dt);
+
+  TI_STMT_DEF_FIELDS(ret_type, stmts);
+  TI_DEFINE_ACCEPT_AND_CLONE
+};
+
 /**
  * A pointer to an element of a matrix.
  */
```
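
The doc comment above also mentions vector swizzles: a swizzle is just a fixed group of indices with a recorded shape, so it rides on the same stmt. A pure-Python sketch of the idea (illustrative only):
```python
# v.zxy as multiple indices: v[(2), (0), (1), shape=3].
v = [10, 20, 30]
component = {"x": 0, "y": 1, "z": 2}
indices_group = [(component[ch],) for ch in "zxy"]
print([v[i] for (i,) in indices_group])   # [30, 10, 20]
```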
5 changes: 5 additions & 0 deletions taichi/python/export_lang.cpp
```diff
@@ -1006,6 +1006,11 @@ void export_lang(py::module &m) {
         return idx_expr;
       });
 
+  m.def(
+      "subscript_with_multiple_indices",
+      Expr::make<IndexExpression, const Expr &, const std::vector<ExprGroup> &,
+                 const std::vector<int> &, std::string>);
+
   m.def("make_stride_expr",
         Expr::make<StrideExpression, const Expr &, const ExprGroup &,
                    const std::vector<int> &, int>);
```
13 changes: 13 additions & 0 deletions taichi/transforms/ir_printer.cpp
```diff
@@ -426,6 +426,19 @@ class IRPrinter : public IRVisitor {
     print_raw(s);
   }
 
+  void visit(MatrixOfMatrixPtrStmt *stmt) override {
+    std::string s = fmt::format("{}{} = matrix of matrix ptr [",
+                                stmt->type_hint(), stmt->name());
+    for (int i = 0; i < (int)stmt->stmts.size(); i++) {
+      s += fmt::format("{}", stmt->stmts[i]->name());
+      if (i + 1 < (int)stmt->stmts.size()) {
+        s += ", ";
+      }
+    }
+    s += "]";
+    print_raw(s);
+  }
+
   void visit(MatrixPtrStmt *stmt) override {
     std::string s =
         fmt::format("{}{} = shift ptr [{} + {}]", stmt->type_hint(),
```