Skip to content

Commit

Permalink
[ir] Constant fold support for u1
Browse files Browse the repository at this point in the history
ghstack-source-id: 0dbe30141e30f436c9831bf018e3c488a896df7f
Pull Request resolved: #8019
  • Loading branch information
listerily committed May 16, 2023
1 parent f47966c commit d92d0a4
Showing 1 changed file with 32 additions and 21 deletions.
53 changes: 32 additions & 21 deletions taichi/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class ConstantFold : public BasicStmtVisitor {
// https://github.com/taichi-dev/taichi/pull/839#issuecomment-625902727
if (dt->is_primitive(PrimitiveTypeID::i32) ||
dt->is_primitive(PrimitiveTypeID::i64) ||
dt->is_primitive(PrimitiveTypeID::u1) ||
dt->is_primitive(PrimitiveTypeID::u32) ||
dt->is_primitive(PrimitiveTypeID::u64) ||
dt->is_primitive(PrimitiveTypeID::f32) ||
Expand Down Expand Up @@ -66,26 +67,30 @@ class ConstantFold : public BasicStmtVisitor {
auto dt = lhs->val.dt;
switch (stmt->op_type) {
#define COMMA ,
#define HANDLE_REAL_AND_INTEGRAL_BINARY(OP_TYPE, PREFIX, OP_CPP) \
case BinaryOpType::OP_TYPE: { \
if (dt->is_primitive(PrimitiveTypeID::f32) || \
dt->is_primitive(PrimitiveTypeID::f64)) { \
auto res = TypedConstant( \
dst_type, PREFIX(lhs->val.val_cast_to_float64() \
OP_CPP rhs->val.val_cast_to_float64())); \
insert_and_erase(stmt, res); \
} else if (dt->is_primitive(PrimitiveTypeID::i32) || \
dt->is_primitive(PrimitiveTypeID::i64)) { \
auto res = TypedConstant( \
dst_type, PREFIX(lhs->val.val_int() OP_CPP rhs->val.val_int())); \
insert_and_erase(stmt, res); \
} else if (dt->is_primitive(PrimitiveTypeID::u32) || \
dt->is_primitive(PrimitiveTypeID::u64)) { \
auto res = TypedConstant( \
dst_type, PREFIX(lhs->val.val_uint() OP_CPP rhs->val.val_uint())); \
insert_and_erase(stmt, res); \
} \
break; \
#define HANDLE_REAL_AND_INTEGRAL_BINARY(OP_TYPE, PREFIX, OP_CPP) \
case BinaryOpType::OP_TYPE: { \
if (dt->is_primitive(PrimitiveTypeID::f32) || \
dt->is_primitive(PrimitiveTypeID::f64)) { \
auto res = TypedConstant( \
dst_type, PREFIX(lhs->val.val_cast_to_float64() \
OP_CPP rhs->val.val_cast_to_float64())); \
insert_and_erase(stmt, res); \
} else if (dt->is_primitive(PrimitiveTypeID::i32) || \
dt->is_primitive(PrimitiveTypeID::i64)) { \
auto res = TypedConstant( \
dst_type, PREFIX(lhs->val.val_int() OP_CPP rhs->val.val_int())); \
insert_and_erase(stmt, res); \
} else if (dt->is_primitive(PrimitiveTypeID::u32) || \
dt->is_primitive(PrimitiveTypeID::u64)) { \
auto res = TypedConstant( \
dst_type, PREFIX(lhs->val.val_uint() OP_CPP rhs->val.val_uint())); \
insert_and_erase(stmt, res); \
} else if (dt->is_primitive(PrimitiveTypeID::u1)) { \
auto res = TypedConstant( \
dst_type, PREFIX(lhs->val.val_uint1() OP_CPP rhs->val.val_uint1())); \
insert_and_erase(stmt, res); \
} \
break; \
}

HANDLE_REAL_AND_INTEGRAL_BINARY(mul, , *)
Expand Down Expand Up @@ -179,6 +184,9 @@ class ConstantFold : public BasicStmtVisitor {
dt->is_primitive(PrimitiveTypeID::u64)) { \
auto res = TypedConstant(dst_type, OP_CPP(operand->val.val_uint())); \
insert_and_erase(stmt, res); \
} else if (dt->is_primitive(PrimitiveTypeID::u1)) { \
auto res = TypedConstant(dst_type, OP_CPP(operand->val.val_uint1())); \
insert_and_erase(stmt, res); \
} \
break; \
}
Expand All @@ -199,7 +207,6 @@ class ConstantFold : public BasicStmtVisitor {
HANDLE_REAL_AND_INTEGRAL_UNARY(exp, std::exp)
HANDLE_REAL_AND_INTEGRAL_UNARY(cast_value, )
HANDLE_REAL_AND_INTEGRAL_UNARY(rsqrt, 1.0 / std::sqrt)
HANDLE_REAL_AND_INTEGRAL_UNARY(logic_not, !)
#undef HANDLE_REAL_AND_INTEGRAL_UNARY

#define HANDLE_INTEGRAL_UNARY(OP_TYPE, OP_CPP) \
Expand All @@ -212,11 +219,15 @@ class ConstantFold : public BasicStmtVisitor {
dt->is_primitive(PrimitiveTypeID::u64)) { \
auto res = TypedConstant(dst_type, OP_CPP(operand->val.val_uint())); \
insert_and_erase(stmt, res); \
} else if (dt->is_primitive(PrimitiveTypeID::u1)) { \
auto res = TypedConstant(dst_type, !operand->val.val_uint1()); \
insert_and_erase(stmt, res); \
} \
break; \
}

HANDLE_INTEGRAL_UNARY(bit_not, ~)
HANDLE_INTEGRAL_UNARY(logic_not, !)
#undef HANDLE_INTEGRAL_UNARY

default:
Expand Down

0 comments on commit d92d0a4

Please sign in to comment.