[Lang] Matrix/Vector refactor: support basic matrix ops #6077

Merged · 21 commits · Sep 24, 2022
Changes from 3 commits
198 changes: 152 additions & 46 deletions taichi/codegen/llvm/codegen_llvm.cpp
@@ -355,6 +355,57 @@ TaskCodeGenLLVM::TaskCodeGenLLVM(Kernel *kernel,
void TaskCodeGenLLVM::visit(DecorationStmt *stmt) {
}

+void TaskCodeGenLLVM::create_value_cast(
+    UnaryOpStmt *stmt,
+    std::function<llvm::Value *(llvm::Value *, llvm::Type *)> cast_fn,
+    DataType to_ty) {
+  if (!to_ty->is<TensorType>()) {
+    llvm_val[stmt] =
+        cast_fn(llvm_val[stmt->operand], tlctx->get_data_type(to_ty));
+  } else {
+    auto from_ty = stmt->operand->ret_type->cast<TensorType>();
+    TI_ASSERT_INFO(from_ty, "Cannot cast non-tensor type {} to {}",
+                   stmt->operand->ret_type->to_string(), to_ty->to_string());
+    auto tensor_type = to_ty->cast<TensorType>();
+    llvm::Value *vec =
+        llvm::UndefValue::get(tlctx->get_data_type(tensor_type));
+    for (int i = 0; i < tensor_type->get_num_elements(); ++i) {
+      auto elem = builder->CreateExtractElement(llvm_val[stmt->operand], i);
+      auto cast_input =
+          cast_fn(elem, tlctx->get_data_type(tensor_type->get_element_type()));
+      vec = builder->CreateInsertElement(vec, cast_input, i);
+    }
+    llvm_val[stmt] = vec;
+  }
+}
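
Note: when the destination is a TensorType, create_value_cast applies cast_fn one lane at a time via an extract/insert loop. Below is a minimal standalone sketch of the same LLVM pattern (hypothetical helper name, fixed 4-lane i32 to float cast, independent of Taichi's codegen state):

#include "llvm/IR/IRBuilder.h"

// Sketch only: element-wise SIToFP over a <4 x i32> value, mirroring the
// extract/insert loop in create_value_cast.
llvm::Value *cast_vec_si_to_fp(llvm::IRBuilder<> &builder,
                               llvm::LLVMContext &ctx, llvm::Value *input) {
  auto *elem_ty = llvm::Type::getFloatTy(ctx);
  auto *vec_ty = llvm::FixedVectorType::get(elem_ty, 4);
  llvm::Value *vec = llvm::UndefValue::get(vec_ty);  // start from undef
  for (int i = 0; i < 4; ++i) {
    auto *elem = builder.CreateExtractElement(input, i);  // read lane i
    auto *cast = builder.CreateSIToFP(elem, elem_ty);     // scalar cast
    vec = builder.CreateInsertElement(vec, cast, i);      // write lane i
  }
  return vec;
}
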

+void TaskCodeGenLLVM::create_fp_trunc(
+    UnaryOpStmt *stmt,
+    std::function<llvm::Value *(llvm::Value *, llvm::Type *)> trunc_fn,
+    llvm::Type *to_ty,
+    bool is_tensor,
+    bool trunc_self) {
+  if (!is_tensor) {
+    llvm_val[stmt] =
+        trunc_fn(trunc_self ? llvm_val[stmt] : llvm_val[stmt->operand], to_ty);
+  } else {
+    auto from_ty = stmt->operand->ret_type->cast<TensorType>();
+    TI_ASSERT_INFO(from_ty,
+                   "Cannot truncate non-tensor type {} to a tensor type",
+                   stmt->operand->ret_type->to_string());
+    llvm::Value *vec = llvm::UndefValue::get(llvm::VectorType::get(
+        to_ty, from_ty->get_num_elements(), /*scalable=*/false));
+    // A cast never changes the number of elements in a tensor value, so the
+    // source and destination vectors have the same length.
+    for (int i = 0; i < from_ty->get_num_elements(); ++i) {
+      auto elem = builder->CreateExtractElement(
+          trunc_self ? llvm_val[stmt] : llvm_val[stmt->operand], i);
+      auto trunc_value = trunc_fn(elem, to_ty);
+      vec = builder->CreateInsertElement(vec, trunc_value, i);
+    }
+    llvm_val[stmt] = vec;
+  }
+}
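
Note: the trunc_self flag selects the truncation source. With trunc_self = true the helper narrows the value already stored in llvm_val[stmt] (used by the f16 path below, where the cast result first lands in f32), while the default false narrows the original operand llvm_val[stmt->operand].
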

void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) {
  auto input = llvm_val[stmt->operand];
  auto input_type = input->getType();
@@ -369,47 +420,102 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) {
    llvm::CastInst::CastOps cast_op;
    auto from = stmt->operand->ret_type;
    auto to = stmt->cast_type;
+    TI_ASSERT_INFO(
+        from->is<TensorType>() == to->is<TensorType>(),
+        "Cannot cast between tensor type and non-tensor type: {} v.s. {}",
+        from->to_string(), to->to_string());
    if (from == to) {
      llvm_val[stmt] = llvm_val[stmt->operand];
-    } else if (is_real(from) != is_real(to)) {
-      if (is_real(from) && is_integral(to)) {
-        cast_op = is_signed(to) ? llvm::Instruction::CastOps::FPToSI
-                                : llvm::Instruction::CastOps::FPToUI;
-      } else if (is_integral(from) && is_real(to)) {
-        cast_op = is_signed(from) ? llvm::Instruction::CastOps::SIToFP
-                                  : llvm::Instruction::CastOps::UIToFP;
+    } else if (is_real(from) != is_real(to) ||
+               is_real_tensor(from) != is_real_tensor(to)) {
+      if ((is_real(from) || is_real_tensor(from)) &&
+          (is_integral(to) || is_integral_tensor(to))) {
+        cast_op = (is_signed_tensor(to) || is_signed(to))
+                      ? llvm::Instruction::CastOps::FPToSI
+                      : llvm::Instruction::CastOps::FPToUI;
+      } else if ((is_integral(from) || is_integral_tensor(from)) &&
+                 (is_real(to) || is_real_tensor(to))) {
+        cast_op = (is_signed_tensor(from) || is_signed(from))
+                      ? llvm::Instruction::CastOps::SIToFP
+                      : llvm::Instruction::CastOps::UIToFP;
      } else {
        TI_P(data_type_name(from));
        TI_P(data_type_name(to));
        TI_NOT_IMPLEMENTED;
      }
-      auto cast_type = to->is_primitive(PrimitiveTypeID::f16)
-                           ? PrimitiveType::f32
-                           : stmt->cast_type;
-
-      llvm_val[stmt] = builder->CreateCast(cast_op, llvm_val[stmt->operand],
-                                           tlctx->get_data_type(cast_type));
-
-      if (to->is_primitive(PrimitiveTypeID::f16)) {
-        llvm_val[stmt] = builder->CreateFPTrunc(
-            llvm_val[stmt], llvm::Type::getHalfTy(*llvm_context));
+      bool use_f16 =
+          to->is_primitive(PrimitiveTypeID::f16) ||
+          (to->is<TensorType>() &&
+           to->cast<TensorType>()->get_element_type()->is_primitive(
+               PrimitiveTypeID::f16));
+      auto cast_type =
+          use_f16 ? (to->is<TensorType>()
+                         ? TypeFactory::create_tensor_type(
+                               to->cast<TensorType>()->get_shape(),
+                               PrimitiveType::f32)
+                         : PrimitiveType::f32)
+                  : stmt->cast_type;
+
+      auto cast_func = [this, cast_op](llvm::Value *value, llvm::Type *type) {
+        return this->builder->CreateCast(cast_op, value, type);
+      };
+      create_value_cast(stmt, cast_func, cast_type);
+
+      if (use_f16) {
+        auto trunc_func = [this](llvm::Value *value, llvm::Type *type) {
+          return this->builder->CreateFPTrunc(value, type);
+        };
+        create_fp_trunc(stmt, trunc_func, llvm::Type::getHalfTy(*llvm_context),
+                        cast_type->is<TensorType>(), /*trunc_self=*/true);
      }
-    } else if (is_real(from) && is_real(to)) {
-      if (data_type_size(from) < data_type_size(to)) {
-        llvm_val[stmt] = builder->CreateFPExt(
-            llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
+    } else if ((is_real(from) || is_real_tensor(from)) &&
+               (is_real(to) || is_real_tensor(to))) {
+      auto t1 = from->is<TensorType>()
+                    ? from->cast<TensorType>()->get_element_type()
+                    : from.operator->();
+      auto t2 = to->is<TensorType>()
+                    ? to->cast<TensorType>()->get_element_type()
+                    : to.operator->();
+      if (data_type_size(t1) < data_type_size(t2)) {
+        auto cast_func = [this](llvm::Value *value, llvm::Type *type) {
+          return this->builder->CreateFPExt(value, type);
+        };
+        create_value_cast(stmt, cast_func, stmt->cast_type);
      } else {
-        if (to->is_primitive(PrimitiveTypeID::f16)) {
-          llvm_val[stmt] = builder->CreateFPTrunc(
-              builder->CreateFPTrunc(llvm_val[stmt->operand],
-                                     llvm::Type::getFloatTy(*llvm_context)),
-              llvm::Type::getHalfTy(*llvm_context));
+        if (to->is_primitive(PrimitiveTypeID::f16) ||
+            (to->is<TensorType>() &&
+             to->cast<TensorType>()->get_element_type()->is_primitive(
+                 PrimitiveTypeID::f16))) {
+          if (!to->is<TensorType>()) {
+            llvm_val[stmt] = builder->CreateFPTrunc(
+                builder->CreateFPTrunc(llvm_val[stmt->operand],
+                                       llvm::Type::getFloatTy(*llvm_context)),
+                llvm::Type::getHalfTy(*llvm_context));
+          } else {
+            auto tensor_type = to->cast<TensorType>();
+            llvm::Value *vec = llvm::UndefValue::get(tlctx->get_data_type(to));
+            for (int i = 0; i < tensor_type->get_num_elements(); ++i) {
+              // Truncate each lane of the operand, not of the undef vector.
+              auto elem =
+                  builder->CreateExtractElement(llvm_val[stmt->operand], i);
+              auto double_trunced = builder->CreateFPTrunc(
+                  builder->CreateFPTrunc(elem,
+                                         llvm::Type::getFloatTy(*llvm_context)),
+                  llvm::Type::getHalfTy(*llvm_context));
+              vec = builder->CreateInsertElement(vec, double_trunced, i);
+            }
+            llvm_val[stmt] = vec;
+          }
        } else {
-          llvm_val[stmt] = builder->CreateFPTrunc(
-              llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
+          auto trunc_fn = [this](llvm::Value *value, llvm::Type *type) {
+            return this->builder->CreateFPTrunc(value, type);
+          };
+          auto cast_type =
+              stmt->cast_type->is<TensorType>()
+                  ? stmt->cast_type->cast<TensorType>()->get_element_type()
+                  : stmt->cast_type.operator->();
+          create_fp_trunc(stmt, trunc_fn, tlctx->get_data_type(cast_type),
+                          stmt->cast_type->is<TensorType>());
        }
      }
-    } else if (!is_real(from) && !is_real(to)) {
+    } else if (!(is_real(from) || is_real_tensor(from)) &&
+               !(is_real(to) || is_real_tensor(to))) {
      llvm_val[stmt] = builder->CreateIntCast(
          llvm_val[stmt->operand], tlctx->get_data_type(to), is_signed(from));
    }
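
The is_real_tensor / is_integral_tensor / is_signed_tensor predicates used throughout this diff are defined elsewhere in the type system and are not shown in these hunks. Judging from the call sites, they appear to combine a TensorType check with the scalar predicate on the element type; the following is a hypothetical sketch, not the actual Taichi implementation:

// Assumed shape of the tensor-aware predicates (sketch only).
inline bool is_real_tensor(DataType dt) {
  auto *tensor_ty = dt->cast<TensorType>();
  return tensor_ty && is_real(tensor_ty->get_element_type());
}

inline bool is_integral_tensor(DataType dt) {
  auto *tensor_ty = dt->cast<TensorType>();
  return tensor_ty && is_integral(tensor_ty->get_element_type());
}

inline bool is_signed_tensor(DataType dt) {
  auto *tensor_ty = dt->cast<TensorType>();
  return tensor_ty && is_signed(tensor_ty->get_element_type());
}
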
@@ -453,31 +559,31 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
  auto ret_type = stmt->ret_type;

  if (op == BinaryOpType::add) {
-    if (is_real(stmt->ret_type)) {
+    if (is_real_tensor(stmt->ret_type) || is_real(stmt->ret_type)) {
      llvm_val[stmt] =
          builder->CreateFAdd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    } else {
      llvm_val[stmt] =
          builder->CreateAdd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    }
  } else if (op == BinaryOpType::sub) {
-    if (is_real(stmt->ret_type)) {
+    if (is_real_tensor(stmt->ret_type) || is_real(stmt->ret_type)) {
      llvm_val[stmt] =
          builder->CreateFSub(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    } else {
      llvm_val[stmt] =
          builder->CreateSub(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    }
  } else if (op == BinaryOpType::mul) {
-    if (is_real(stmt->ret_type)) {
+    if (is_real_tensor(stmt->ret_type) || is_real(stmt->ret_type)) {
      llvm_val[stmt] =
          builder->CreateFMul(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    } else {
      llvm_val[stmt] =
          builder->CreateMul(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    }
  } else if (op == BinaryOpType::floordiv) {
-    if (is_integral(ret_type))
+    if (is_integral_tensor(ret_type) || is_integral(ret_type))
      llvm_val[stmt] =
          create_call(fmt::format("floordiv_{}", data_type_name(ret_type)),
                      {llvm_val[stmt->lhs], llvm_val[stmt->rhs]});
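
Note on the floordiv branch above: the integral case defers to a runtime floordiv_<type> function rather than emitting a plain division instruction, since C-style integer division truncates toward zero while floordiv must round toward negative infinity. A reference implementation of the semantics the runtime helper is assumed to provide (a sketch, not the actual runtime source):

// Reference semantics for integer floor division (hypothetical sketch).
// C++ `/` truncates toward zero; floordiv rounds toward negative infinity,
// so the quotient is adjusted when the signs differ and a remainder exists.
int floordiv_i32(int a, int b) {
  int q = a / b;
  if ((a % b != 0) && ((a < 0) != (b < 0))) {
    --q;
  }
  return q;
}
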
@@ -487,7 +593,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
          llvm::Intrinsic::floor, {tlctx->get_data_type(ret_type)}, {div});
    }
  } else if (op == BinaryOpType::div) {
-    if (is_real(stmt->ret_type)) {
+    if (is_real_tensor(stmt->ret_type) || is_real(stmt->ret_type)) {
      llvm_val[stmt] =
          builder->CreateFDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
    } else {
@@ -524,7 +630,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
create_call("max_" #x, {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); \
}

if (is_real(ret_type)) {
if (is_real_tensor(stmt->ret_type) || is_real(ret_type)) {
llvm_val[stmt] =
builder->CreateMaxNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
}
@@ -545,7 +651,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
create_call("min_" #x, {llvm_val[stmt->lhs], llvm_val[stmt->rhs]}); \
}

if (is_real(ret_type)) {
if (is_real_tensor(stmt->ret_type) || is_real(ret_type)) {
llvm_val[stmt] =
builder->CreateMinNum(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
}
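
Note: CreateMaxNum and CreateMinNum emit the llvm.maxnum / llvm.minnum intrinsics, which follow IEEE-754 maxNum/minNum semantics: if exactly one operand is NaN, the other operand is returned. These intrinsics are also overloaded for vector types and operate lane-wise, which is presumably why this branch needs only the widened is_real_tensor condition to handle tensor values.
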
@@ -563,16 +669,16 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
    llvm::Value *cmp = nullptr;
    auto input_type = stmt->lhs->ret_type;
    if (op == BinaryOpType::cmp_eq) {
-      if (is_real(input_type)) {
+      if (is_real_tensor(input_type) || is_real(input_type)) {
        cmp = builder->CreateFCmpOEQ(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      } else {
        cmp = builder->CreateICmpEQ(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      }
    } else if (op == BinaryOpType::cmp_le) {
-      if (is_real(input_type)) {
+      if (is_real_tensor(input_type) || is_real(input_type)) {
        cmp = builder->CreateFCmpOLE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      } else {
-        if (is_signed(input_type)) {
+        if (is_signed_tensor(input_type) || is_signed(input_type)) {
          cmp =
              builder->CreateICmpSLE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
        } else {
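
Note: the FCmpO* predicates are ordered floating-point comparisons and yield false whenever either operand is NaN. On the integer side, the signed predicate (e.g. ICmpSLE) is chosen when is_signed_tensor or is_signed holds, and the unsigned variant (presumably ICmpULE, in the elided else branch) otherwise. For tensor operands these comparisons produce a vector of i1 values, one lane per element.
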
@@ -581,10 +687,10 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
        }
      }
    } else if (op == BinaryOpType::cmp_ge) {
-      if (is_real(input_type)) {
+      if (is_real_tensor(input_type) || is_real(input_type)) {
        cmp = builder->CreateFCmpOGE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      } else {
-        if (is_signed(input_type)) {
+        if (is_signed_tensor(input_type) || is_signed(input_type)) {
          cmp =
              builder->CreateICmpSGE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
        } else {
@@ -593,10 +699,10 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
        }
      }
    } else if (op == BinaryOpType::cmp_lt) {
-      if (is_real(input_type)) {
+      if (is_real_tensor(input_type) || is_real(input_type)) {
        cmp = builder->CreateFCmpOLT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      } else {
-        if (is_signed(input_type)) {
+        if (is_signed_tensor(input_type) || is_signed(input_type)) {
          cmp =
              builder->CreateICmpSLT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
        } else {
@@ -605,10 +711,10 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
        }
      }
    } else if (op == BinaryOpType::cmp_gt) {
-      if (is_real(input_type)) {
+      if (is_real_tensor(input_type) || is_real(input_type)) {
        cmp = builder->CreateFCmpOGT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      } else {
-        if (is_signed(input_type)) {
+        if (is_signed_tensor(input_type) || is_signed(input_type)) {
          cmp =
              builder->CreateICmpSGT(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
        } else {
@@ -617,7 +723,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
        }
      }
    } else if (op == BinaryOpType::cmp_ne) {
-      if (is_real(input_type)) {
+      if (is_real_tensor(input_type) || is_real(input_type)) {
        cmp = builder->CreateFCmpONE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
      } else {
        cmp = builder->CreateICmpNE(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
12 changes: 12 additions & 0 deletions taichi/codegen/llvm/codegen_llvm.h
@@ -102,6 +102,18 @@ class TaskCodeGenLLVM : public IRVisitor, public LLVMModuleBuilder {
                                 llvm::Value *node_meta,
                                 SNode *snode);

+  void create_value_cast(
+      UnaryOpStmt *stmt,
+      std::function<llvm::Value *(llvm::Value *, llvm::Type *)> cast_fn,
+      DataType to_ty);
+
+  void create_fp_trunc(
+      UnaryOpStmt *stmt,
+      std::function<llvm::Value *(llvm::Value *, llvm::Type *)> trunc_fn,
+      llvm::Type *to_ty,
+      bool is_tensor,
+      bool trunc_self = false);
+
  std::unique_ptr<RuntimeObject> emit_struct_meta_object(SNode *snode);

  llvm::Value *emit_struct_meta(SNode *snode);