Skip to content

Commit

Permalink
[lang] Added ti.u1 definition
Browse files Browse the repository at this point in the history
ghstack-source-id: ac52abfd5e136811d4a7f3b86d5b36362ed94f4f
Pull Request resolved: #7995
  • Loading branch information
listerily authored and feisuzhu committed May 16, 2023
1 parent bf7998b commit 1b84a2e
Show file tree
Hide file tree
Showing 23 changed files with 130 additions and 29 deletions.
15 changes: 14 additions & 1 deletion python/taichi/lang/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
i16,
i32,
i64,
u1,
u8,
u16,
u32,
Expand Down Expand Up @@ -119,6 +120,8 @@ def to_numpy_type(dt):
return np.int8
if dt == i16:
return np.int16
if dt == u1:
return np.bool_
if dt == u8:
return np.uint8
if dt == u16:
Expand Down Expand Up @@ -157,6 +160,8 @@ def to_pytorch_type(dt):
return torch.int8
if dt == i16:
return torch.int16
if dt == u1:
return torch.bool
if dt == u8:
return torch.uint8
if dt == f16:
Expand Down Expand Up @@ -190,6 +195,8 @@ def to_paddle_type(dt):
return paddle.int8
if dt == i16:
return paddle.int16
if dt == u1:
return paddle.bool
if dt == u8:
return paddle.uint8
if dt == f16:
Expand Down Expand Up @@ -224,6 +231,8 @@ def to_taichi_type(dt):
return i8
if dt == np.int16:
return i16
if dt == np.bool_:
return u1
if dt == np.uint8:
return u8
if dt == np.uint16:
Expand Down Expand Up @@ -251,6 +260,8 @@ def to_taichi_type(dt):
return i8
if dt == torch.int16:
return i16
if dt == torch.bool:
return u1
if dt == torch.uint8:
return u8
if dt == torch.float16:
Expand All @@ -273,6 +284,8 @@ def to_taichi_type(dt):
return i8
if dt == paddle.int16:
return i16
if dt == paddle.bool:
return u1
if dt == paddle.uint8:
return u8
if dt == paddle.float16:
Expand All @@ -293,7 +306,7 @@ def cook_dtype(dtype):
if dtype is int:
return impl.get_runtime().default_ip
if dtype is bool:
return i32 # TODO[Xiaoyan]: Use i1 in the future
return i32 # TODO(zhantong): Replace it with u1
raise ValueError(f"Invalid data type {dtype}")


Expand Down
16 changes: 15 additions & 1 deletion python/taichi/types/primitive_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,18 @@

# ----------------------------------------

uint1 = ti_python_core.DataType_u1
"""1-bit unsigned integer data type. Same as booleans.
"""

# ----------------------------------------

u1 = uint1
"""Alias for :const:`~taichi.types.primitive_types.uint1`
"""

# ----------------------------------------

u8 = uint8
"""Alias for :const:`~taichi.types.primitive_types.uint8`
"""
Expand Down Expand Up @@ -154,7 +166,7 @@ def ref(tp):
real_types = [f16, f32, f64, float]
real_type_ids = [id(t) for t in real_types]

integer_types = [i8, i16, i32, i64, u8, u16, u32, u64, int, bool]
integer_types = [i8, i16, i32, i64, u1, u8, u16, u32, u64, int, bool]
integer_type_ids = [id(t) for t in integer_types]

all_types = real_types + integer_types
Expand All @@ -175,6 +187,8 @@ def ref(tp):
"i32",
"int64",
"i64",
"uint1",
"u1",
"uint8",
"u8",
"uint16",
Expand Down
3 changes: 3 additions & 0 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,9 @@ void TaskCodeGenLLVM::visit(ConstStmt *stmt) {
} else if (val.dt->is_primitive(PrimitiveTypeID::f64)) {
llvm_val[stmt] =
llvm::ConstantFP::get(*llvm_context, llvm::APFloat(val.val_float64()));
} else if (val.dt->is_primitive(PrimitiveTypeID::u1)) {
llvm_val[stmt] = llvm::ConstantInt::get(
*llvm_context, llvm::APInt(1, (uint64)val.val_uint1(), false));
} else if (val.dt->is_primitive(PrimitiveTypeID::i8)) {
llvm_val[stmt] = llvm::ConstantInt::get(
*llvm_context, llvm::APInt(8, (uint64)val.val_int8(), true));
Expand Down
3 changes: 3 additions & 0 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ class TaskCodegen : public IRVisitor {
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
return ir_->int_immediate_number(
stype, static_cast<int64_t>(const_val.val_i16), false);
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return ir_->uint_immediate_number(
stype, static_cast<uint64_t>(const_val.val_u1), false);
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return ir_->uint_immediate_number(
stype, static_cast<uint64_t>(const_val.val_u8), false);
Expand Down
10 changes: 7 additions & 3 deletions taichi/codegen/spirv/spirv_ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,8 @@ SType IRBuilder::get_primitive_uint_type(const DataType &dt) const {
} else if (dt == PrimitiveType::i16 || dt == PrimitiveType::u16 ||
dt == PrimitiveType::f16) {
return t_uint16_;
} else if (dt == PrimitiveType::u1) {
return t_bool_;
} else {
return t_uint8_;
}
Expand All @@ -392,6 +394,8 @@ DataType IRBuilder::get_taichi_uint_type(const DataType &dt) const {
} else if (dt == PrimitiveType::i16 || dt == PrimitiveType::u16 ||
dt == PrimitiveType::f16) {
return PrimitiveType::u16;
} else if (dt == PrimitiveType::u1) {
return PrimitiveType::u1;
} else {
return PrimitiveType::u8;
}
Expand Down Expand Up @@ -1090,10 +1094,10 @@ DEFINE_BUILDER_CMP_OP(ge, GreaterThanEqual);
Value IRBuilder::_OpName(Value a, Value b) { \
TI_ASSERT(a.stype.id == b.stype.id); \
const auto &bool_type = t_bool_; /* TODO: Only scalar supported now */ \
if (is_integral(a.stype.dt)) { \
return make_value(spv::OpI##_Op, bool_type, a, b); \
} else if (a.stype.id == bool_type.id) { \
if (a.stype.id == bool_type.id) { \
return make_value(spv::OpLogical##_Op, bool_type, a, b); \
} else if (is_integral(a.stype.dt)) { \
return make_value(spv::OpI##_Op, bool_type, a, b); \
} else { \
TI_ASSERT(is_real(a.stype.dt)); \
return make_value(spv::OpFOrd##_Op, bool_type, a, b); \
Expand Down
12 changes: 11 additions & 1 deletion taichi/codegen/spirv/spirv_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,14 @@ const tinyir::Type *translate_ti_primitive(tinyir::Block &ir_module,
} else if (t == PrimitiveType::i64) {
return ir_module.emplace_back<IntType>(/*num_bits=*/64,
/*is_signed=*/true);
} else if (t == PrimitiveType::u1) {
// Spir-v has no full support for boolean types, using boolean types in
// backend may cause issues. These issues arise when we use boolean as
// return type, argument type and inner dtype of compount types. Since
// boolean types has the same width with int32 in GLSL, we use int32
// instead.
return ir_module.emplace_back<IntType>(/*num_bits=*/32,
/*is_signed=*/false);
} else if (t == PrimitiveType::u8) {
return ir_module.emplace_back<IntType>(/*num_bits=*/8,
/*is_signed=*/false);
Expand Down Expand Up @@ -395,7 +403,9 @@ class Translate2Spirv : public TypeVisitor {
vt = spir_builder_->i64_type();
}
} else {
if (type->num_bits() == 8) {
if (type->num_bits() == 1) {
vt = spir_builder_->bool_type();
} else if (type->num_bits() == 8) {
vt = spir_builder_->u8_type();
} else if (type->num_bits() == 16) {
vt = spir_builder_->u16_type();
Expand Down
2 changes: 2 additions & 0 deletions taichi/common/core.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ class CoreState {
// Types
//******************************************************************************

using uint1 = bool;

using uchar = unsigned char;

using int8 = int8_t;
Expand Down
2 changes: 2 additions & 0 deletions taichi/common/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

namespace taichi {

using uint1 = bool;

using uchar = unsigned char;

using int8 = int8_t;
Expand Down
3 changes: 2 additions & 1 deletion taichi/inc/data_type_with_c_type.inc.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
// Doesn't contain f16 and u1.
// Doesn't contain f16.
PER_C_TYPE(f32, float32)
PER_C_TYPE(f64, float64)
PER_C_TYPE(i8, int8)
PER_C_TYPE(i16, int16)
PER_C_TYPE(i32, int32)
PER_C_TYPE(i64, int64)
PER_C_TYPE(u1, uint1)
PER_C_TYPE(u8, uint8)
PER_C_TYPE(u16, uint16)
PER_C_TYPE(u32, uint32)
Expand Down
4 changes: 4 additions & 0 deletions taichi/ir/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ void Expr::set_adjoint_checkbit(const Expr &o) {
this->cast<FieldExpression>()->adjoint_checkbit.set(o);
}

Expr::Expr(uint1 x) : Expr() {
expr = std::make_shared<ConstExpression>(PrimitiveType::u1, x);
}

Expr::Expr(int16 x) : Expr() {
expr = std::make_shared<ConstExpression>(PrimitiveType::i16, x);
}
Expand Down
2 changes: 2 additions & 0 deletions taichi/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class Expr {
atomic = false;
}

explicit Expr(uint1 x);

explicit Expr(int16 x);

explicit Expr(int32 x);
Expand Down
9 changes: 5 additions & 4 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,11 @@ class FrontendSNodeOpStmt : public Stmt {
ExprGroup indices;
Expr val;

FrontendSNodeOpStmt(SNodeOpType op_type,
SNode *snode,
const ExprGroup &indices,
const Expr &val = Expr(nullptr));
FrontendSNodeOpStmt(
SNodeOpType op_type,
SNode *snode,
const ExprGroup &indices,
const Expr &val = Expr(std::shared_ptr<Expression>(nullptr)));

TI_DEFINE_ACCEPT
TI_DEFINE_CLONE_FOR_FRONTEND_IR
Expand Down
11 changes: 11 additions & 0 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ std::string TypedConstant::stringify() const {
return fmt::format("{}", val_i8);
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
return fmt::format("{}", val_i16);
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return fmt::format("{}", val_u1);
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return fmt::format("{}", val_u8);
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
Expand Down Expand Up @@ -391,6 +393,8 @@ bool TypedConstant::equal_type_and_value(const TypedConstant &o) const {
return val_i8 == o.val_i8;
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
return val_i16 == o.val_i16;
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return val_u1 == o.val_u1;
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return val_u8 == o.val_u8;
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
Expand Down Expand Up @@ -440,6 +444,11 @@ int16 &TypedConstant::val_int16() {
return val_i16;
}

uint1 &TypedConstant::val_uint1() {
TI_ASSERT(get_data_type<uint1>() == dt);
return val_u1;
}

uint8 &TypedConstant::val_uint8() {
TI_ASSERT(get_data_type<uint8>() == dt);
return val_u8;
Expand Down Expand Up @@ -483,6 +492,8 @@ uint64 TypedConstant::val_uint() const {
return val_u64;
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return val_u8;
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return val_u1;
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
return val_u16;
} else {
Expand Down
7 changes: 7 additions & 0 deletions taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ class TypedConstant {
float64 val_f64;
int8 val_i8;
int16 val_i16;
uint1 val_u1;
uint8 val_u8;
uint16 val_u16;
uint32 val_u32;
Expand Down Expand Up @@ -564,6 +565,9 @@ class TypedConstant {
explicit TypedConstant(int16 x) : dt(PrimitiveType::i16), val_i16(x) {
}

explicit TypedConstant(uint1 x) : dt(PrimitiveType::u1), val_u1(x) {
}

explicit TypedConstant(uint8 x) : dt(PrimitiveType::u8), val_u8(x) {
}

Expand Down Expand Up @@ -594,6 +598,8 @@ class TypedConstant {
val_i8 = value;
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
val_i16 = value;
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
val_u1 = value;
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
val_u8 = value;
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
Expand Down Expand Up @@ -627,6 +633,7 @@ class TypedConstant {
float64 &val_float64();
int8 &val_int8();
int16 &val_int16();
uint1 &val_uint1();
uint8 &val_uint8();
uint16 &val_uint16();
uint32 &val_uint32();
Expand Down
5 changes: 4 additions & 1 deletion taichi/ir/type_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ int data_type_size(DataType t) {
REGISTER_DATA_TYPE(i16, int16);
REGISTER_DATA_TYPE(i32, int32);
REGISTER_DATA_TYPE(i64, int64);
REGISTER_DATA_TYPE(u1, uint1);
REGISTER_DATA_TYPE(u8, uint8);
REGISTER_DATA_TYPE(u16, uint16);
REGISTER_DATA_TYPE(u32, uint32);
Expand Down Expand Up @@ -99,7 +100,9 @@ std::string tensor_type_format(DataType t, Arch arch) {
}

std::string data_type_format(DataType dt, Arch arch) {
if (dt->is_primitive(PrimitiveTypeID::i8)) {
if (dt->is_primitive(PrimitiveTypeID::u1)) {
return "%d";
} else if (dt->is_primitive(PrimitiveTypeID::i8)) {
// i8/u8 is converted to i16/u16 before printing, because CUDA doesn't
// support the "%hhd"/"%hhu" specifiers.
return "%hd";
Expand Down
7 changes: 7 additions & 0 deletions taichi/ir/type_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ inline DataType get_data_type() {
return PrimitiveType::i32;
} else if (std::is_same<T, int64>()) {
return PrimitiveType::i64;
} else if (std::is_same<T, uint1>()) {
return PrimitiveType::u1;
} else if (std::is_same<T, uint8>()) {
return PrimitiveType::u8;
} else if (std::is_same<T, uint16>()) {
Expand Down Expand Up @@ -101,6 +103,7 @@ inline bool is_integral(DataType dt) {
dt->is_primitive(PrimitiveTypeID::i16) ||
dt->is_primitive(PrimitiveTypeID::i32) ||
dt->is_primitive(PrimitiveTypeID::i64) ||
dt->is_primitive(PrimitiveTypeID::u1) ||
dt->is_primitive(PrimitiveTypeID::u8) ||
dt->is_primitive(PrimitiveTypeID::u16) ||
dt->is_primitive(PrimitiveTypeID::u32) ||
Expand Down Expand Up @@ -146,6 +149,8 @@ inline TypedConstant get_max_value(DataType dt) {
return {dt, std::numeric_limits<int32>::max()};
} else if (dt->is_primitive(PrimitiveTypeID::i64)) {
return {dt, std::numeric_limits<int64>::max()};
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return {dt, std::numeric_limits<uint1>::max()};
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return {dt, std::numeric_limits<uint8>::max()};
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
Expand All @@ -172,6 +177,8 @@ inline TypedConstant get_min_value(DataType dt) {
return {dt, std::numeric_limits<int32>::lowest()};
} else if (dt->is_primitive(PrimitiveTypeID::i64)) {
return {dt, std::numeric_limits<int64>::lowest()};
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return {dt, std::numeric_limits<uint1>::lowest()};
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return {dt, std::numeric_limits<uint8>::lowest()};
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
Expand Down
2 changes: 2 additions & 0 deletions taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ T Kernel::fetch_ret(DataType dt, int i) {
return (T)program->fetch_result<int8>(i);
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
return (T)program->fetch_result<int16>(i);
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return (T)program->fetch_result<uint1>(i);
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return (T)program->fetch_result<uint8>(i);
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
Expand Down
2 changes: 2 additions & 0 deletions taichi/program/launch_context_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ void LaunchContextBuilder::set_arg_int(int arg_id, int64 d) {
set_arg(arg_id, (int8)d);
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
set_arg(arg_id, (int16)d);
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
set_arg(arg_id, (uint1)d);
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
set_arg(arg_id, (uint8)d);
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
Expand Down
Loading

0 comments on commit 1b84a2e

Please sign in to comment.