Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
move expr unification to type check
Browse files Browse the repository at this point in the history
AD1024 committed Aug 18, 2022

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent f3ff3d2 commit 39ed915
Showing 2 changed files with 39 additions and 40 deletions.
52 changes: 39 additions & 13 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
@@ -181,7 +181,37 @@ void UnaryOpExpression::flatten(FlattenContext *ctx) {
ctx->push_back(std::move(unary));
}

Expr to_broadcast_tensor(const Expr &elt, const DataType &dt) {
TI_ASSERT(dt->is<TensorType>());
auto tensor_type = dt->as<TensorType>();
auto elt_type = tensor_type->get_element_type();
TI_ASSERT_INFO(elt_type->is<PrimitiveType>(),
"Only primitive types are supported in Tensors, got {}",
elt_type->to_string());
std::vector<Expr> broadcast_values(tensor_type->get_num_elements(), elt);
return Expr::make<MatrixExpression>(broadcast_values,
tensor_type->get_shape(), elt_type);
}

std::tuple<Expr, Expr> unify_binop_operands(const Expr &e1, const Expr &e2) {
if ((!e1->ret_type->is<TensorType>() && !e2->ret_type->is<TensorType>()) ||
(e1->ret_type->is<TensorType>() && e2->ret_type->is<TensorType>())) {
return std::tuple(e1, e2);
}
if (!e1->ret_type->is<TensorType>()) {
return std::tuple(to_broadcast_tensor(e1, e2->ret_type), e2);
}
return std::tuple(e1, to_broadcast_tensor(e2, e1->ret_type));
}

void BinaryOpExpression::type_check(CompileConfig *config) {
auto [unified_l, unified_r] = unify_binop_operands(lhs, rhs);
lhs = unified_l;
rhs = unified_r;
if (lhs->ret_type == PrimitiveType::unknown)
lhs.type_check(config);
if (rhs->ret_type == PrimitiveType::unknown)
rhs.type_check(config);
TI_ASSERT_TYPE_CHECKED(lhs);
TI_ASSERT_TYPE_CHECKED(rhs);
auto lhs_type = lhs->ret_type;
@@ -194,22 +224,18 @@ void BinaryOpExpression::type_check(CompileConfig *config) {
};

if (lhs_type->is<TensorType>()) {
TI_ASSERT(rhs_type->is<TensorType>());
auto rhs_tensor_type = rhs_type->cast<TensorType>();
auto dtype = lhs_type->as<TensorType>()->get_element_type();
if (rhs_type->is<PrimitiveType>()) {
ret_type = promoted_type(dtype, rhs_type);
} else {
TI_ASSERT(rhs_type->is<TensorType>());
auto rhs_tensor_type = rhs_type->cast<TensorType>();
if (rhs_tensor_type->get_shape() !=
lhs_type->cast<TensorType>()->get_shape())
error();
auto rhs_elem_type = rhs_type->as<TensorType>()->get_element_type();
if (rhs_elem_type != PrimitiveType::unknown)
ret_type = promoted_type(dtype, rhs_elem_type);
}
if (rhs_tensor_type->get_shape() !=
lhs_type->cast<TensorType>()->get_shape())
error();
auto rhs_elem_type = rhs_type->as<TensorType>()->get_element_type();
if (rhs_elem_type != PrimitiveType::unknown)
dtype = promoted_type(dtype, rhs_elem_type);
// TODO: shape check!
ret_type = TypeFactory::create_tensor_type(
lhs_type->cast<TensorType>()->get_shape(), ret_type);
lhs_type->cast<TensorType>()->get_shape(), dtype);
return;
}

27 changes: 0 additions & 27 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
@@ -379,33 +379,6 @@ class BinaryOpExpression : public Expression {

BinaryOpExpression(const BinaryOpType &type, const Expr &lhs, const Expr &rhs)
: type(type), lhs(lhs), rhs(rhs) {
auto to_broadcast_tensor = [](const Expr &elt, const DataType &dt) -> Expr {
TI_ASSERT(dt->is<TensorType>());
auto tensor_type = dt->as<TensorType>();
auto elt_type = tensor_type->get_element_type();
TI_ASSERT_INFO(elt_type->is<PrimitiveType>(),
"Only primitive types are supported in Tensors, got {}",
elt_type->to_string());
std::vector<Expr> broadcast_values(tensor_type->get_num_elements(), elt);
return Expr::make<MatrixExpression>(broadcast_values,
tensor_type->get_shape(), elt_type);
};

auto unify_expr = [&](const Expr &e1, const Expr &e2) {
if ((!e1->ret_type->is<TensorType>() &&
!e2->ret_type->is<TensorType>()) ||
(e1->ret_type->is<TensorType>() && e2->ret_type->is<TensorType>())) {
return std::tuple(e1, e2);
}
if (!e1->ret_type->is<TensorType>()) {
return std::tuple(to_broadcast_tensor(e1, e2->ret_type), e2);
}
return std::tuple(e1, to_broadcast_tensor(e2, e1->ret_type));
};

auto [unified_l, unified_r] = unify_expr(lhs, rhs);
this->lhs = unified_l;
this->rhs = unified_r;
}

void type_check(CompileConfig *config) override;

0 comments on commit 39ed915

Please sign in to comment.