Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lang] Added ti.u1 definition #7995

Merged
merged 3 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion python/taichi/lang/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
i16,
i32,
i64,
u1,
u8,
u16,
u32,
Expand Down Expand Up @@ -119,6 +120,8 @@ def to_numpy_type(dt):
return np.int8
if dt == i16:
return np.int16
if dt == u1:
return np.bool_
if dt == u8:
return np.uint8
if dt == u16:
Expand Down Expand Up @@ -157,6 +160,8 @@ def to_pytorch_type(dt):
return torch.int8
if dt == i16:
return torch.int16
if dt == u1:
return torch.bool
if dt == u8:
return torch.uint8
if dt == f16:
Expand Down Expand Up @@ -190,6 +195,8 @@ def to_paddle_type(dt):
return paddle.int8
if dt == i16:
return paddle.int16
if dt == u1:
return paddle.bool
if dt == u8:
return paddle.uint8
if dt == f16:
Expand Down Expand Up @@ -224,6 +231,8 @@ def to_taichi_type(dt):
return i8
if dt == np.int16:
return i16
if dt == np.bool_:
return u1
if dt == np.uint8:
return u8
if dt == np.uint16:
Expand Down Expand Up @@ -251,6 +260,8 @@ def to_taichi_type(dt):
return i8
if dt == torch.int16:
return i16
if dt == torch.bool:
return u1
if dt == torch.uint8:
return u8
if dt == torch.float16:
Expand All @@ -273,6 +284,8 @@ def to_taichi_type(dt):
return i8
if dt == paddle.int16:
return i16
if dt == paddle.bool:
return u1
if dt == paddle.uint8:
return u8
if dt == paddle.float16:
Expand All @@ -293,7 +306,7 @@ def cook_dtype(dtype):
if dtype is int:
return impl.get_runtime().default_ip
if dtype is bool:
return i32 # TODO[Xiaoyan]: Use i1 in the future
return i32 # TODO(zhantong): Replace it with u1
raise ValueError(f"Invalid data type {dtype}")


Expand Down
16 changes: 15 additions & 1 deletion python/taichi/types/primitive_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,18 @@

# ----------------------------------------

uint1 = ti_python_core.DataType_u1
"""1-bit unsigned integer data type. Same as booleans.
"""

# ----------------------------------------

u1 = uint1
"""Alias for :const:`~taichi.types.primitive_types.uint1`
"""

# ----------------------------------------

u8 = uint8
"""Alias for :const:`~taichi.types.primitive_types.uint8`
"""
Expand Down Expand Up @@ -154,7 +166,7 @@ def ref(tp):
real_types = [f16, f32, f64, float]
real_type_ids = [id(t) for t in real_types]

integer_types = [i8, i16, i32, i64, u8, u16, u32, u64, int, bool]
integer_types = [i8, i16, i32, i64, u1, u8, u16, u32, u64, int, bool]
integer_type_ids = [id(t) for t in integer_types]

all_types = real_types + integer_types
Expand All @@ -175,6 +187,8 @@ def ref(tp):
"i32",
"int64",
"i64",
"uint1",
"u1",
"uint8",
"u8",
"uint16",
Expand Down
3 changes: 3 additions & 0 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1015,6 +1015,9 @@ void TaskCodeGenLLVM::visit(ConstStmt *stmt) {
} else if (val.dt->is_primitive(PrimitiveTypeID::f64)) {
llvm_val[stmt] =
llvm::ConstantFP::get(*llvm_context, llvm::APFloat(val.val_float64()));
} else if (val.dt->is_primitive(PrimitiveTypeID::u1)) {
llvm_val[stmt] = llvm::ConstantInt::get(
*llvm_context, llvm::APInt(1, (uint64)val.val_uint1(), false));
} else if (val.dt->is_primitive(PrimitiveTypeID::i8)) {
llvm_val[stmt] = llvm::ConstantInt::get(
*llvm_context, llvm::APInt(8, (uint64)val.val_int8(), true));
Expand Down
3 changes: 3 additions & 0 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,9 @@ class TaskCodegen : public IRVisitor {
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
return ir_->int_immediate_number(
stype, static_cast<int64_t>(const_val.val_i16), false);
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return ir_->uint_immediate_number(
stype, static_cast<uint64_t>(const_val.val_u1), false);
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return ir_->uint_immediate_number(
stype, static_cast<uint64_t>(const_val.val_u8), false);
Expand Down
10 changes: 7 additions & 3 deletions taichi/codegen/spirv/spirv_ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,8 @@ SType IRBuilder::get_primitive_uint_type(const DataType &dt) const {
} else if (dt == PrimitiveType::i16 || dt == PrimitiveType::u16 ||
dt == PrimitiveType::f16) {
return t_uint16_;
} else if (dt == PrimitiveType::u1) {
return t_bool_;
} else {
return t_uint8_;
}
Expand All @@ -392,6 +394,8 @@ DataType IRBuilder::get_taichi_uint_type(const DataType &dt) const {
} else if (dt == PrimitiveType::i16 || dt == PrimitiveType::u16 ||
dt == PrimitiveType::f16) {
return PrimitiveType::u16;
} else if (dt == PrimitiveType::u1) {
return PrimitiveType::u1;
} else {
return PrimitiveType::u8;
}
Expand Down Expand Up @@ -1090,10 +1094,10 @@ DEFINE_BUILDER_CMP_OP(ge, GreaterThanEqual);
Value IRBuilder::_OpName(Value a, Value b) { \
TI_ASSERT(a.stype.id == b.stype.id); \
const auto &bool_type = t_bool_; /* TODO: Only scalar supported now */ \
if (is_integral(a.stype.dt)) { \
return make_value(spv::OpI##_Op, bool_type, a, b); \
} else if (a.stype.id == bool_type.id) { \
if (a.stype.id == bool_type.id) { \
return make_value(spv::OpLogical##_Op, bool_type, a, b); \
} else if (is_integral(a.stype.dt)) { \
return make_value(spv::OpI##_Op, bool_type, a, b); \
} else { \
TI_ASSERT(is_real(a.stype.dt)); \
return make_value(spv::OpFOrd##_Op, bool_type, a, b); \
Expand Down
12 changes: 11 additions & 1 deletion taichi/codegen/spirv/spirv_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,14 @@ const tinyir::Type *translate_ti_primitive(tinyir::Block &ir_module,
} else if (t == PrimitiveType::i64) {
return ir_module.emplace_back<IntType>(/*num_bits=*/64,
/*is_signed=*/true);
} else if (t == PrimitiveType::u1) {
// Spir-v has no full support for boolean types, using boolean types in
// backend may cause issues. These issues arise when we use boolean as
// return type, argument type and inner dtype of compount types. Since
// boolean types has the same width with int32 in GLSL, we use int32
// instead.
return ir_module.emplace_back<IntType>(/*num_bits=*/32,
/*is_signed=*/false);
} else if (t == PrimitiveType::u8) {
return ir_module.emplace_back<IntType>(/*num_bits=*/8,
/*is_signed=*/false);
Expand Down Expand Up @@ -395,7 +403,9 @@ class Translate2Spirv : public TypeVisitor {
vt = spir_builder_->i64_type();
}
} else {
if (type->num_bits() == 8) {
if (type->num_bits() == 1) {
vt = spir_builder_->bool_type();
} else if (type->num_bits() == 8) {
vt = spir_builder_->u8_type();
} else if (type->num_bits() == 16) {
vt = spir_builder_->u16_type();
Expand Down
2 changes: 2 additions & 0 deletions taichi/common/core.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ class CoreState {
// Types
//******************************************************************************

using uint1 = bool;

using uchar = unsigned char;

using int8 = int8_t;
Expand Down
2 changes: 2 additions & 0 deletions taichi/common/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

namespace taichi {

using uint1 = bool;

using uchar = unsigned char;

using int8 = int8_t;
Expand Down
3 changes: 2 additions & 1 deletion taichi/inc/data_type_with_c_type.inc.h
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
// Doesn't contain f16 and u1.
// Doesn't contain f16.
PER_C_TYPE(f32, float32)
PER_C_TYPE(f64, float64)
PER_C_TYPE(i8, int8)
PER_C_TYPE(i16, int16)
PER_C_TYPE(i32, int32)
PER_C_TYPE(i64, int64)
PER_C_TYPE(u1, uint1)
PER_C_TYPE(u8, uint8)
PER_C_TYPE(u16, uint16)
PER_C_TYPE(u32, uint32)
Expand Down
4 changes: 4 additions & 0 deletions taichi/ir/expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ void Expr::set_adjoint_checkbit(const Expr &o) {
this->cast<FieldExpression>()->adjoint_checkbit.set(o);
}

Expr::Expr(uint1 x) : Expr() {
expr = std::make_shared<ConstExpression>(PrimitiveType::u1, x);
}

Expr::Expr(int16 x) : Expr() {
expr = std::make_shared<ConstExpression>(PrimitiveType::i16, x);
}
Expand Down
2 changes: 2 additions & 0 deletions taichi/ir/expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ class Expr {
atomic = false;
}

explicit Expr(uint1 x);

explicit Expr(int16 x);

explicit Expr(int32 x);
Expand Down
9 changes: 5 additions & 4 deletions taichi/ir/frontend_ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,11 @@ class FrontendSNodeOpStmt : public Stmt {
ExprGroup indices;
Expr val;

FrontendSNodeOpStmt(SNodeOpType op_type,
SNode *snode,
const ExprGroup &indices,
const Expr &val = Expr(nullptr));
FrontendSNodeOpStmt(
SNodeOpType op_type,
SNode *snode,
const ExprGroup &indices,
const Expr &val = Expr(std::shared_ptr<Expression>(nullptr)));

TI_DEFINE_ACCEPT
TI_DEFINE_CLONE_FOR_FRONTEND_IR
Expand Down
11 changes: 11 additions & 0 deletions taichi/ir/type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,8 @@ std::string TypedConstant::stringify() const {
return fmt::format("{}", val_i8);
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
return fmt::format("{}", val_i16);
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return fmt::format("{}", val_u1);
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return fmt::format("{}", val_u8);
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
Expand Down Expand Up @@ -391,6 +393,8 @@ bool TypedConstant::equal_type_and_value(const TypedConstant &o) const {
return val_i8 == o.val_i8;
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
return val_i16 == o.val_i16;
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return val_u1 == o.val_u1;
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return val_u8 == o.val_u8;
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
Expand Down Expand Up @@ -440,6 +444,11 @@ int16 &TypedConstant::val_int16() {
return val_i16;
}

uint1 &TypedConstant::val_uint1() {
TI_ASSERT(get_data_type<uint1>() == dt);
return val_u1;
}

uint8 &TypedConstant::val_uint8() {
TI_ASSERT(get_data_type<uint8>() == dt);
return val_u8;
Expand Down Expand Up @@ -483,6 +492,8 @@ uint64 TypedConstant::val_uint() const {
return val_u64;
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return val_u8;
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return val_u1;
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
return val_u16;
} else {
Expand Down
7 changes: 7 additions & 0 deletions taichi/ir/type.h
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,7 @@ class TypedConstant {
float64 val_f64;
int8 val_i8;
int16 val_i16;
uint1 val_u1;
uint8 val_u8;
uint16 val_u16;
uint32 val_u32;
Expand Down Expand Up @@ -564,6 +565,9 @@ class TypedConstant {
explicit TypedConstant(int16 x) : dt(PrimitiveType::i16), val_i16(x) {
}

explicit TypedConstant(uint1 x) : dt(PrimitiveType::u1), val_u1(x) {
}

explicit TypedConstant(uint8 x) : dt(PrimitiveType::u8), val_u8(x) {
}

Expand Down Expand Up @@ -594,6 +598,8 @@ class TypedConstant {
val_i8 = value;
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
val_i16 = value;
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
val_u1 = value;
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
val_u8 = value;
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
Expand Down Expand Up @@ -627,6 +633,7 @@ class TypedConstant {
float64 &val_float64();
int8 &val_int8();
int16 &val_int16();
uint1 &val_uint1();
uint8 &val_uint8();
uint16 &val_uint16();
uint32 &val_uint32();
Expand Down
5 changes: 4 additions & 1 deletion taichi/ir/type_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ int data_type_size(DataType t) {
REGISTER_DATA_TYPE(i16, int16);
REGISTER_DATA_TYPE(i32, int32);
REGISTER_DATA_TYPE(i64, int64);
REGISTER_DATA_TYPE(u1, uint1);
REGISTER_DATA_TYPE(u8, uint8);
REGISTER_DATA_TYPE(u16, uint16);
REGISTER_DATA_TYPE(u32, uint32);
Expand Down Expand Up @@ -99,7 +100,9 @@ std::string tensor_type_format(DataType t, Arch arch) {
}

std::string data_type_format(DataType dt, Arch arch) {
if (dt->is_primitive(PrimitiveTypeID::i8)) {
if (dt->is_primitive(PrimitiveTypeID::u1)) {
return "%d";
} else if (dt->is_primitive(PrimitiveTypeID::i8)) {
// i8/u8 is converted to i16/u16 before printing, because CUDA doesn't
// support the "%hhd"/"%hhu" specifiers.
return "%hd";
Expand Down
7 changes: 7 additions & 0 deletions taichi/ir/type_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ inline DataType get_data_type() {
return PrimitiveType::i32;
} else if (std::is_same<T, int64>()) {
return PrimitiveType::i64;
} else if (std::is_same<T, uint1>()) {
return PrimitiveType::u1;
} else if (std::is_same<T, uint8>()) {
return PrimitiveType::u8;
} else if (std::is_same<T, uint16>()) {
Expand Down Expand Up @@ -101,6 +103,7 @@ inline bool is_integral(DataType dt) {
dt->is_primitive(PrimitiveTypeID::i16) ||
dt->is_primitive(PrimitiveTypeID::i32) ||
dt->is_primitive(PrimitiveTypeID::i64) ||
dt->is_primitive(PrimitiveTypeID::u1) ||
dt->is_primitive(PrimitiveTypeID::u8) ||
dt->is_primitive(PrimitiveTypeID::u16) ||
dt->is_primitive(PrimitiveTypeID::u32) ||
Expand Down Expand Up @@ -146,6 +149,8 @@ inline TypedConstant get_max_value(DataType dt) {
return {dt, std::numeric_limits<int32>::max()};
} else if (dt->is_primitive(PrimitiveTypeID::i64)) {
return {dt, std::numeric_limits<int64>::max()};
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return {dt, std::numeric_limits<uint1>::max()};
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return {dt, std::numeric_limits<uint8>::max()};
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
Expand All @@ -172,6 +177,8 @@ inline TypedConstant get_min_value(DataType dt) {
return {dt, std::numeric_limits<int32>::lowest()};
} else if (dt->is_primitive(PrimitiveTypeID::i64)) {
return {dt, std::numeric_limits<int64>::lowest()};
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return {dt, std::numeric_limits<uint1>::lowest()};
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return {dt, std::numeric_limits<uint8>::lowest()};
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
Expand Down
2 changes: 2 additions & 0 deletions taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ T Kernel::fetch_ret(DataType dt, int i) {
return (T)program->fetch_result<int8>(i);
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
return (T)program->fetch_result<int16>(i);
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return (T)program->fetch_result<uint1>(i);
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
return (T)program->fetch_result<uint8>(i);
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
Expand Down
2 changes: 2 additions & 0 deletions taichi/program/launch_context_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ void LaunchContextBuilder::set_arg_int(int arg_id, int64 d) {
set_arg(arg_id, (int8)d);
} else if (dt->is_primitive(PrimitiveTypeID::i16)) {
set_arg(arg_id, (int16)d);
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
set_arg(arg_id, (uint1)d);
} else if (dt->is_primitive(PrimitiveTypeID::u8)) {
set_arg(arg_id, (uint8)d);
} else if (dt->is_primitive(PrimitiveTypeID::u16)) {
Expand Down
Loading