Skip to content

Commit

Permalink
[ir] [refactor] Separate UnaryOpType::cast into cast_value and `c…
Browse files Browse the repository at this point in the history
…ast_bits` (#892)

* [skip ci] basic setup

* [skip ci] rename all

* [skip ci] fix typo

* trigger test

* [skip ci] apply reviews

* add bit_cast

* [skip ci] nit name

* nit unary_op_is_cast

* fix

* fix again

* fix name conflict
  • Loading branch information
archibate authored Apr 29, 2020
1 parent b67f4fa commit 75335e9
Show file tree
Hide file tree
Showing 18 changed files with 127 additions and 106 deletions.
6 changes: 6 additions & 0 deletions python/taichi/lang/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,12 @@ def cast(obj, type):
else:
return Expr(taichi_lang_core.value_cast(Expr(obj).ptr, type))

def bit_cast(obj, type):
if is_taichi_class(obj):
raise ValueError('Cannot apply bit_cast on Taichi classes')
else:
return Expr(taichi_lang_core.bits_cast(Expr(obj).ptr, type))


def sqr(obj):
return obj * obj
Expand Down
15 changes: 6 additions & 9 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,17 +325,11 @@ class KernelCodegen : public IRVisitor {
}

void visit(UnaryOpStmt *stmt) override {
if (stmt->op_type != UnaryOpType::cast) {
emit("const {} {} = {}({});", metal_data_type_name(stmt->element_type()),
stmt->raw_name(), metal_unary_op_type_symbol(stmt->op_type),
stmt->operand->raw_name());
} else {
// cast
if (stmt->cast_by_value) {
if (stmt->op_type == UnaryOpType::cast_value) {
emit("const {} {} = static_cast<{}>({});",
metal_data_type_name(stmt->element_type()), stmt->raw_name(),
metal_data_type_name(stmt->cast_type), stmt->operand->raw_name());
} else {
} else if (stmt->op_type == UnaryOpType::cast_bits) {
// reinterpret the bit pattern
const auto to_type = to_metal_type(stmt->cast_type);
const auto to_type_name = metal_data_type_name(to_type);
Expand All @@ -344,7 +338,10 @@ class KernelCodegen : public IRVisitor {
metal_data_type_bytes(to_type));
emit("const {} {} = union_cast<{}>({});", to_type_name,
stmt->raw_name(), to_type_name, stmt->operand->raw_name());
}
} else {
emit("const {} {} = {}({});", metal_data_type_name(stmt->element_type()),
stmt->raw_name(), metal_unary_op_type_symbol(stmt->op_type),
stmt->operand->raw_name());
}
}

Expand Down
13 changes: 6 additions & 7 deletions taichi/backends/opengl/codegen_opengl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -349,16 +349,12 @@ class KernelGen : public IRVisitor {
} else if (stmt->op_type == UnaryOpType::bit_not) {
emit("{} {} = {}(~{});", dt_name, stmt->short_name(), dt_name,
stmt->operand->short_name());
} else if (stmt->op_type != UnaryOpType::cast) {
emit("{} {} = {}({}({}));", dt_name, stmt->short_name(), dt_name,
unary_op_type_name(stmt->op_type), stmt->operand->short_name());
} else {
// cast
if (stmt->cast_by_value) {
} else if (stmt->op_type == UnaryOpType::cast_value) {
emit("{} {} = {}({});", dt_name, stmt->short_name(),
opengl_data_type_name(stmt->cast_type),
stmt->operand->short_name());
} else if (stmt->cast_type == DataType::f32 &&
} else if (stmt->op_type == UnaryOpType::cast_bits) {
if (stmt->cast_type == DataType::f32 &&
stmt->operand->element_type() == DataType::i32) {
emit("{} {} = intBitsToFloat({});", dt_name, stmt->short_name(),
stmt->operand->short_name());
Expand All @@ -369,6 +365,9 @@ class KernelGen : public IRVisitor {
} else {
TI_ERROR("unsupported reinterpret cast");
}
} else {
emit("{} {} = {}({}({}));", dt_name, stmt->short_name(), dt_name,
unary_op_type_name(stmt->op_type), stmt->operand->short_name());
}
}

Expand Down
112 changes: 53 additions & 59 deletions taichi/codegen/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,71 +292,65 @@ void CodeGenLLVM::visit(UnaryOpStmt *stmt) {
llvm_val[stmt] = \
builder->CreateIntrinsic(llvm::Intrinsic::x, {input_type}, {input}); \
}

if (stmt->op_type != UnaryOpType::cast) {
if (op == UnaryOpType::rsqrt) {
llvm::Function *sqrt_fn = Intrinsic::getDeclaration(
module.get(), Intrinsic::sqrt, input->getType());
auto intermediate = builder->CreateCall(sqrt_fn, input, "sqrt");
llvm_val[stmt] = builder->CreateFDiv(
tlctx->get_constant(stmt->ret_type.data_type, 1.0), intermediate);
} else if (op == UnaryOpType::bit_not) {
llvm_val[stmt] = builder->CreateNot(input);
} else if (op == UnaryOpType::neg) {
if (is_real(stmt->operand->ret_type.data_type)) {
llvm_val[stmt] = builder->CreateFNeg(input, "neg");
if (stmt->op_type == UnaryOpType::cast_value) {
llvm::CastInst::CastOps cast_op;
auto from = stmt->operand->ret_type.data_type;
auto to = stmt->cast_type;
TI_ASSERT(from != to);
if (is_real(from) != is_real(to)) {
if (is_real(from) && is_integral(to)) {
cast_op = llvm::Instruction::CastOps::FPToSI;
} else if (is_integral(from) && is_real(to)) {
cast_op = llvm::Instruction::CastOps::SIToFP;
} else {
llvm_val[stmt] = builder->CreateNeg(input, "neg");
TI_P(data_type_name(from));
TI_P(data_type_name(to));
TI_NOT_IMPLEMENTED;
}
}
UNARY_INTRINSIC(floor)
UNARY_INTRINSIC(ceil)
else emit_extra_unary(stmt);
#undef UNARY_INTRINSIC
} else {
// op = cast
if (stmt->cast_by_value) {
llvm::CastInst::CastOps cast_op;
auto from = stmt->operand->ret_type.data_type;
auto to = stmt->cast_type;
TI_ASSERT(from != to);
if (is_real(from) != is_real(to)) {
if (is_real(from) && is_integral(to)) {
cast_op = llvm::Instruction::CastOps::FPToSI;
} else if (is_integral(from) && is_real(to)) {
cast_op = llvm::Instruction::CastOps::SIToFP;
} else {
TI_P(data_type_name(from));
TI_P(data_type_name(to));
TI_NOT_IMPLEMENTED;
}
llvm_val[stmt] =
builder->CreateCast(cast_op, llvm_val[stmt->operand],
tlctx->get_data_type(stmt->cast_type));
} else if (is_real(from) && is_real(to)) {
if (data_type_size(from) < data_type_size(to)) {
llvm_val[stmt] = builder->CreateFPExt(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
} else {
llvm_val[stmt] = builder->CreateFPTrunc(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
}
} else if (!is_real(from) && !is_real(to)) {
if (data_type_size(from) < data_type_size(to)) {
llvm_val[stmt] = builder->CreateSExt(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
} else {
llvm_val[stmt] = builder->CreateTrunc(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
}
llvm_val[stmt] =
builder->CreateCast(cast_op, llvm_val[stmt->operand],
tlctx->get_data_type(stmt->cast_type));
} else if (is_real(from) && is_real(to)) {
if (data_type_size(from) < data_type_size(to)) {
llvm_val[stmt] = builder->CreateFPExt(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
} else {
llvm_val[stmt] = builder->CreateFPTrunc(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
}
} else if (!is_real(from) && !is_real(to)) {
if (data_type_size(from) < data_type_size(to)) {
llvm_val[stmt] = builder->CreateSExt(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
} else {
llvm_val[stmt] = builder->CreateTrunc(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
}
}
} else if (stmt->op_type == UnaryOpType::cast_bits) {
TI_ASSERT(data_type_size(stmt->ret_type.data_type) ==
data_type_size(stmt->cast_type));
llvm_val[stmt] = builder->CreateBitCast(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
} else if (op == UnaryOpType::rsqrt) {
llvm::Function *sqrt_fn = Intrinsic::getDeclaration(
module.get(), Intrinsic::sqrt, input->getType());
auto intermediate = builder->CreateCall(sqrt_fn, input, "sqrt");
llvm_val[stmt] = builder->CreateFDiv(
tlctx->get_constant(stmt->ret_type.data_type, 1.0), intermediate);
} else if (op == UnaryOpType::bit_not) {
llvm_val[stmt] = builder->CreateNot(input);
} else if (op == UnaryOpType::neg) {
if (is_real(stmt->operand->ret_type.data_type)) {
llvm_val[stmt] = builder->CreateFNeg(input, "neg");
} else {
TI_ASSERT(data_type_size(stmt->ret_type.data_type) ==
data_type_size(stmt->cast_type));
llvm_val[stmt] = builder->CreateBitCast(
llvm_val[stmt->operand], tlctx->get_data_type(stmt->cast_type));
llvm_val[stmt] = builder->CreateNeg(input, "neg");
}
}
UNARY_INTRINSIC(floor)
UNARY_INTRINSIC(ceil)
else emit_extra_unary(stmt);
#undef UNARY_INTRINSIC
}

void CodeGenLLVM::visit(BinaryOpStmt *stmt) {
Expand Down
3 changes: 2 additions & 1 deletion taichi/inc/unary_op.inc.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@ PER_UNARY_OP(neg)
PER_UNARY_OP(sqrt)
PER_UNARY_OP(floor)
PER_UNARY_OP(ceil)
PER_UNARY_OP(cast)
PER_UNARY_OP(cast_value)
PER_UNARY_OP(cast_bits)
PER_UNARY_OP(abs)
PER_UNARY_OP(sgn)
PER_UNARY_OP(sin)
Expand Down
6 changes: 2 additions & 4 deletions taichi/ir/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,16 +35,14 @@ Expr operator~(const Expr &expr) {
}

Expr cast(const Expr &input, DataType dt) {
auto ret = std::make_shared<UnaryOpExpression>(UnaryOpType::cast, input);
auto ret = std::make_shared<UnaryOpExpression>(UnaryOpType::cast_value, input);
ret->cast_type = dt;
ret->cast_by_value = true;
return Expr(ret);
}

Expr bit_cast(const Expr &input, DataType dt) {
auto ret = std::make_shared<UnaryOpExpression>(UnaryOpType::cast, input);
auto ret = std::make_shared<UnaryOpExpression>(UnaryOpType::cast_bits, input);
ret->cast_type = dt;
ret->cast_by_value = false;
return Expr(ret);
}

Expand Down
18 changes: 12 additions & 6 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,13 +378,16 @@ UnaryOpStmt::UnaryOpStmt(UnaryOpType op_type, Stmt *operand)
: op_type(op_type), operand(operand) {
TI_ASSERT(!operand->is<AllocaStmt>());
cast_type = DataType::unknown;
cast_by_value = true;
TI_STMT_REG_FIELDS;
}

bool UnaryOpStmt::is_cast() const {
return unary_op_is_cast(op_type);
}

bool UnaryOpStmt::same_operation(UnaryOpStmt *o) const {
if (op_type == o->op_type) {
if (op_type == UnaryOpType::cast) {
if (is_cast()) {
return cast_type == o->cast_type;
} else {
return true;
Expand All @@ -394,8 +397,8 @@ bool UnaryOpStmt::same_operation(UnaryOpStmt *o) const {
}

std::string UnaryOpExpression::serialize() {
if (type == UnaryOpType::cast) {
std::string reint = cast_by_value ? "" : "reinterpret_";
if (is_cast()) {
std::string reint = type == UnaryOpType::cast_value ? "" : "reinterpret_";
return fmt::format("({}{}<{}> {})", reint, unary_op_type_name(type),
data_type_name(cast_type), operand->serialize());
} else {
Expand All @@ -404,12 +407,15 @@ std::string UnaryOpExpression::serialize() {
}
}

bool UnaryOpExpression::is_cast() const {
return unary_op_is_cast(type);
}

void UnaryOpExpression::flatten(VecStatement &ret) {
operand->flatten(ret);
auto unary = std::make_unique<UnaryOpStmt>(type, operand->stmt);
if (type == UnaryOpType::cast) {
if (is_cast()) {
unary->cast_type = cast_type;
unary->cast_by_value = cast_by_value;
}
stmt = unary.get();
stmt->tb = tb;
Expand Down
8 changes: 4 additions & 4 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -896,17 +896,17 @@ class UnaryOpStmt : public Stmt {
UnaryOpType op_type;
Stmt *operand;
DataType cast_type;
bool cast_by_value = true;

UnaryOpStmt(UnaryOpType op_type, Stmt *operand);

bool same_operation(UnaryOpStmt *o) const;
bool is_cast() const;

virtual bool has_global_side_effect() const override {
return false;
}

TI_STMT_DEF_FIELDS(ret_type, op_type, operand, cast_type, cast_by_value);
TI_STMT_DEF_FIELDS(ret_type, op_type, operand, cast_type);
DEFINE_ACCEPT
};

Expand Down Expand Up @@ -1020,14 +1020,14 @@ class UnaryOpExpression : public Expression {
UnaryOpType type;
Expr operand;
DataType cast_type;
bool cast_by_value;

UnaryOpExpression(UnaryOpType type, const Expr &operand)
: type(type), operand(smart_load(operand)) {
cast_type = DataType::unknown;
cast_by_value = true;
}

bool is_cast() const;

std::string serialize() override;

void flatten(VecStatement &ret) override;
Expand Down
4 changes: 4 additions & 0 deletions taichi/lang_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,10 @@ enum class UnaryOpType : int {

std::string unary_op_type_name(UnaryOpType type);

inline bool unary_op_is_cast(UnaryOpType op) {
return op == UnaryOpType::cast_value || op == UnaryOpType::cast_bits;
}

inline bool constexpr is_trigonometric(UnaryOpType op) {
return op == UnaryOpType::sin || op == UnaryOpType::asin ||
op == UnaryOpType::cos || op == UnaryOpType::acos ||
Expand Down
1 change: 1 addition & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -326,6 +326,7 @@ void export_lang(py::module &m) {
m.def("layout", layout);

m.def("value_cast", static_cast<Expr (*)(const Expr &expr, DataType)>(cast));
m.def("bits_cast", static_cast<Expr (*)(const Expr &expr, DataType)>(bit_cast));

m.def("expr_atomic_add", [&](const Expr &a, const Expr &b) {
return Expr::make<AtomicOpExpression>(AtomicOpType::add, ptr_if_global(a),
Expand Down
4 changes: 2 additions & 2 deletions taichi/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ class ConstantFold : public BasicStmtVisitor {
}

void visit(UnaryOpStmt *stmt) override {
if (stmt->width() == 1 && stmt->op_type == UnaryOpType::cast &&
stmt->cast_by_value && stmt->operand->is<ConstStmt>()) {
if (stmt->width() == 1 && stmt->op_type == UnaryOpType::cast_value &&
stmt->operand->is<ConstStmt>()) {
auto input = stmt->operand->as<ConstStmt>()->val[0];
auto src_type = stmt->operand->ret_type.data_type;
auto dst_type = stmt->ret_type.data_type;
Expand Down
5 changes: 3 additions & 2 deletions taichi/transforms/ir_printer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,9 @@ class IRPrinter : public IRVisitor {
}

void visit(UnaryOpStmt *stmt) override {
if (stmt->op_type == UnaryOpType::cast) {
std::string reint = stmt->cast_by_value ? "" : "reinterpret_";
if (stmt->is_cast()) {
std::string reint = stmt->op_type == UnaryOpType::cast_value ?
"" : "reinterpret_";
print("{}{} = {}{}<{}> {}", stmt->type_hint(), stmt->name(), reint,
unary_op_type_name(stmt->op_type),
data_type_short_name(stmt->cast_type), stmt->operand->name());
Expand Down
5 changes: 2 additions & 3 deletions taichi/transforms/make_adjoint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -241,9 +241,8 @@ class MakeAdjoint : public IRVisitor {
} else if (stmt->op_type == UnaryOpType::sqrt) {
accumulate(stmt->operand,
mul(adjoint(stmt), div(constant(0.5f), sqrt(stmt->operand))));
} else if (stmt->op_type == UnaryOpType::cast) {
if (stmt->cast_by_value && is_real(stmt->cast_type) &&
is_real(stmt->operand->ret_type.data_type)) {
} else if (stmt->op_type == UnaryOpType::cast_value) {
if (is_real(stmt->cast_type) && is_real(stmt->operand->ret_type.data_type)) {
accumulate(stmt->operand, adjoint(stmt));
}
} else if (stmt->op_type == UnaryOpType::logic_not) {
Expand Down
2 changes: 1 addition & 1 deletion taichi/transforms/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,7 @@ class BasicBlockSimplify : public IRVisitor {
void visit(UnaryOpStmt *stmt) override {
if (is_done(stmt))
return;
if (stmt->op_type == UnaryOpType::cast) {
if (stmt->is_cast()) {
if (stmt->cast_type == stmt->operand->ret_type.data_type) {
stmt->replace_with(stmt->operand);
stmt->parent->erase(current_stmt_id);
Expand Down
1 change: 0 additions & 1 deletion taichi/transforms/slp_vectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,6 @@ class BasicBlockSLP : public IRVisitor {
dynamic_cast<UnaryOpStmt *>(building_pack[0])->op_type,
tmp_operands[0]);
tmp_stmt->as<UnaryOpStmt>()->cast_type = stmt->cast_type;
tmp_stmt->as<UnaryOpStmt>()->cast_by_value = stmt->cast_by_value;
update_type(stmt);
/*
if (tmp_stmt->as<UnaryOpStmt>()->op_type == UnaryOpType::cast) {
Expand Down
Loading

0 comments on commit 75335e9

Please sign in to comment.