Skip to content

Commit

Permalink
fix fallback codegen on bool
Browse files Browse the repository at this point in the history
  • Loading branch information
Mike-Leo-Smith committed Jan 10, 2025
1 parent 854eb7c commit f1004f4
Showing 1 changed file with 25 additions and 61 deletions.
86 changes: 25 additions & 61 deletions src/backends/fallback/fallback_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -549,6 +549,10 @@ class FallbackCodegen {
LUISA_ASSERT(llvm_operand->getType()->isIntOrIntVectorTy() &&
!llvm_operand->getType()->isIntOrIntVectorTy(1),
"Invalid operand type.");
if (operand->type()->is_bool() || operand->type()->is_bool_vector()) {// !b <=> (b == 0)
auto i1_operand = _cmp_eq_zero(b, llvm_operand);
return _zext_i1_to_i8(b, i1_operand);
}
return b.CreateNot(llvm_operand);
}

Expand Down Expand Up @@ -660,7 +664,7 @@ class FallbackCodegen {
LUISA_ERROR_WITH_LOCATION("Invalid binary mod operand type: {}.", elem_type->description());
}

[[nodiscard]] llvm::Value *_translate_binary_bit_and(CurrentFunction &current, IRBuilder &b, const xir::Value *lhs, const xir::Value *rhs) noexcept {
[[nodiscard]] llvm::Value *_translate_binary_bit_op(CurrentFunction &current, IRBuilder &b, xir::ArithmeticOp op, const xir::Value *lhs, const xir::Value *rhs) noexcept {
// Lookup LLVM values for operands
auto llvm_lhs = _lookup_value(current, b, lhs);
auto llvm_rhs = _lookup_value(current, b, rhs);
Expand All @@ -682,68 +686,28 @@ class FallbackCodegen {
case Type::Tag::UINT8: [[fallthrough]];
case Type::Tag::UINT16: [[fallthrough]];
case Type::Tag::UINT32: [[fallthrough]];
case Type::Tag::UINT64: return b.CreateAnd(llvm_lhs, llvm_rhs);
case Type::Tag::UINT64: {
auto is_bool = elem_type->is_bool();
if (is_bool) {
llvm_lhs = _cmp_ne_zero(b, llvm_lhs);
llvm_rhs = _cmp_ne_zero(b, llvm_rhs);
}
auto result = [&] {
switch (op) {
case xir::ArithmeticOp::BINARY_BIT_AND: return b.CreateAnd(llvm_lhs, llvm_rhs);
case xir::ArithmeticOp::BINARY_BIT_OR: return b.CreateOr(llvm_lhs, llvm_rhs);
case xir::ArithmeticOp::BINARY_BIT_XOR: return b.CreateXor(llvm_lhs, llvm_rhs);
default: break;
}
LUISA_ERROR_WITH_LOCATION("Invalid binary bit operation: {}.", static_cast<uint32_t>(op));
}();
return is_bool ? _zext_i1_to_i8(b, result) : result;
}
default: break;
}
LUISA_ERROR_WITH_LOCATION("Invalid binary bit and operand type: {}.", elem_type->description());
}

[[nodiscard]] llvm::Value *_translate_binary_bit_or(CurrentFunction &current, IRBuilder &b, const xir::Value *lhs, const xir::Value *rhs) noexcept {
// Lookup LLVM values for operands
auto llvm_lhs = _lookup_value(current, b, lhs);
auto llvm_rhs = _lookup_value(current, b, rhs);
auto lhs_type = lhs->type();
auto rhs_type = rhs->type();
auto elem_type = lhs->type()->is_vector() ? lhs->type()->element() : lhs->type();
// Type and null checks
LUISA_ASSERT(lhs_type != nullptr && rhs_type != nullptr, "Operand type is null.");
LUISA_ASSERT(lhs_type == rhs_type, "Type mismatch for bitwise and.");
LUISA_ASSERT(lhs_type->is_scalar() || lhs_type->is_vector(), "Invalid operand type.");

// Perform bitwise AND operation
switch (elem_type->tag()) {
case Type::Tag::BOOL: [[fallthrough]];
case Type::Tag::INT8: [[fallthrough]];
case Type::Tag::INT16: [[fallthrough]];
case Type::Tag::INT32: [[fallthrough]];
case Type::Tag::INT64: [[fallthrough]];
case Type::Tag::UINT8: [[fallthrough]];
case Type::Tag::UINT16: [[fallthrough]];
case Type::Tag::UINT32: [[fallthrough]];
case Type::Tag::UINT64: return b.CreateOr(llvm_lhs, llvm_rhs);
default: break;
}
LUISA_ERROR_WITH_LOCATION("Invalid binary bit or operand type: {}.", elem_type->description());
}

[[nodiscard]] llvm::Value *_translate_binary_bit_xor(CurrentFunction &current, IRBuilder &b, const xir::Value *lhs, const xir::Value *rhs) noexcept {
// Lookup LLVM values for operands
auto llvm_lhs = _lookup_value(current, b, lhs);
auto llvm_rhs = _lookup_value(current, b, rhs);
auto lhs_type = lhs->type();
auto rhs_type = rhs->type();
auto elem_type = lhs->type()->is_vector() ? lhs->type()->element() : lhs->type();
// Type and null checks
LUISA_ASSERT(lhs_type != nullptr && rhs_type != nullptr, "Operand type is null.");
LUISA_ASSERT(lhs_type == rhs_type, "Type mismatch for bitwise and.");
LUISA_ASSERT(lhs_type->is_scalar() || lhs_type->is_vector(), "Invalid operand type.");

// Perform bitwise AND operation
switch (elem_type->tag()) {
case Type::Tag::BOOL: [[fallthrough]];
case Type::Tag::INT8: [[fallthrough]];
case Type::Tag::INT16: [[fallthrough]];
case Type::Tag::INT32: [[fallthrough]];
case Type::Tag::INT64: [[fallthrough]];
case Type::Tag::UINT8: [[fallthrough]];
case Type::Tag::UINT16: [[fallthrough]];
case Type::Tag::UINT32: [[fallthrough]];
case Type::Tag::UINT64: return b.CreateXor(llvm_lhs, llvm_rhs);
default: break;
}
LUISA_ERROR_WITH_LOCATION("Invalid binary bit xor operand type: {}.", elem_type->description());
}

[[nodiscard]] llvm::Value *_translate_binary_shift_left(CurrentFunction &current, IRBuilder &b, const xir::Value *lhs, const xir::Value *rhs) noexcept {
// Lookup LLVM values for operands
auto llvm_lhs = _lookup_value(current, b, lhs);
Expand Down Expand Up @@ -1922,9 +1886,9 @@ class FallbackCodegen {
case xir::ArithmeticOp::BINARY_MUL: return _translate_binary_mul(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_DIV: return _translate_binary_div(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_MOD: return _translate_binary_mod(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_BIT_AND: return _translate_binary_bit_and(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_BIT_OR: return _translate_binary_bit_or(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_BIT_XOR: return _translate_binary_bit_xor(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_BIT_AND: return _translate_binary_bit_op(current, b, inst->op(), inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_BIT_OR: return _translate_binary_bit_op(current, b, inst->op(), inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_BIT_XOR: return _translate_binary_bit_op(current, b, inst->op(), inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_SHIFT_LEFT: return _translate_binary_shift_left(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_SHIFT_RIGHT: return _translate_binary_shift_right(current, b, inst->operand(0u), inst->operand(1u));
case xir::ArithmeticOp::BINARY_ROTATE_LEFT: return _translate_binary_rotate_left(current, b, inst->operand(0u), inst->operand(1u));
Expand Down

0 comments on commit f1004f4

Please sign in to comment.