Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[lang] Remove redundant codegen of integer pow #6048

Merged
merged 2 commits into from
Sep 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -663,14 +663,12 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
TI_NOT_IMPLEMENTED
}
} else {
// Note that ret_type here cannot be integral because pow with an
// integral exponent has been demoted in the demote_operations pass
if (ret_type->is_primitive(PrimitiveTypeID::f32)) {
llvm_val[stmt] = create_call("__nv_powf", {lhs, rhs});
} else if (ret_type->is_primitive(PrimitiveTypeID::f64)) {
llvm_val[stmt] = create_call("__nv_pow", {lhs, rhs});
} else if (ret_type->is_primitive(PrimitiveTypeID::i32)) {
ailzhang marked this conversation as resolved.
Show resolved Hide resolved
llvm_val[stmt] = create_call("pow_i32", {lhs, rhs});
} else if (ret_type->is_primitive(PrimitiveTypeID::i64)) {
llvm_val[stmt] = create_call("pow_i64", {lhs, rhs});
} else {
TI_P(data_type_name(ret_type));
TI_NOT_IMPLEMENTED
Expand Down
6 changes: 2 additions & 4 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -658,14 +658,12 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
}
} else if (op == BinaryOpType::pow) {
if (arch_is_cpu(current_arch())) {
// Note that ret_type here cannot be integral because pow with an
// integral exponent has been demoted in the demote_operations pass
if (ret_type->is_primitive(PrimitiveTypeID::f32)) {
llvm_val[stmt] = create_call("pow_f32", {lhs, rhs});
} else if (ret_type->is_primitive(PrimitiveTypeID::f64)) {
llvm_val[stmt] = create_call("pow_f64", {lhs, rhs});
} else if (ret_type->is_primitive(PrimitiveTypeID::i32)) {
llvm_val[stmt] = create_call("pow_i32", {lhs, rhs});
} else if (ret_type->is_primitive(PrimitiveTypeID::i64)) {
llvm_val[stmt] = create_call("pow_i64", {lhs, rhs});
} else {
TI_P(data_type_name(ret_type));
TI_NOT_IMPLEMENTED
Expand Down
6 changes: 0 additions & 6 deletions taichi/codegen/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -559,12 +559,6 @@ class KernelCodegenImpl : public IRVisitor {
}
return;
}
if (op_type == BinaryOpType::pow && is_integral(bin->ret_type)) {
// TODO(k-ye): Make sure the type is not i64?
emit("const {} {} = pow_i32({}, {});", dt_name, bin_name, lhs_name,
rhs_name);
return;
}
const auto binop = metal_binary_op_type_symbol(op_type);
if (is_metal_binary_op_infix(op_type)) {
if (is_comparison(op_type)) {
Expand Down
30 changes: 1 addition & 29 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -846,35 +846,6 @@ class TaskCodegen : public IRVisitor {
BINARY_OP_TO_SPIRV_LOGICAL(cmp_ne, ne)
#undef BINARY_OP_TO_SPIRV_LOGICAL

// Expands to one `else if (op_type == BinaryOpType::op)` branch that lowers a
// binary op through a GLSL450 instruction (via ir_->call_glsl450).
//
//   op             - BinaryOpType enumerator handled by this branch.
//   instruction    - name used for a local const holding the opcode; it is
//                    also stringified into the error message.
//   instruction_id - numeric GLSL450 opcode passed to call_glsl450.
//   max_bits       - widest element type accepted; wider types hit TI_ERROR.
//
// Behavior by element type of `bin`:
//   * real     : the instruction is called directly and yields dst_type.
//   * integral : both operands are cast to f32, the instruction is evaluated
//                in f32, 0.5 is added, and the sum is cast to dst_type.
//                NOTE(review): the +0.5 presumably rounds the float result to
//                the nearest integer for non-negative values; confirm how
//                ir_->cast converts float to int.
//   * other    : TI_NOT_IMPLEMENTED.
#define INT_OR_FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(op, instruction, \
instruction_id, max_bits) \
else if (op_type == BinaryOpType::op) { \
const uint32_t instruction = instruction_id; \
if (is_real(bin->element_type()) || is_integral(bin->element_type())) { \
if (data_type_bits(bin->element_type()) > max_bits) { \
TI_ERROR( \
"[glsl450] the operand type of instruction {}({}) must <= {}bits", \
#instruction, instruction_id, max_bits); \
} \
if (is_integral(bin->element_type())) { \
bin_value = ir_->cast( \
dst_type, \
ir_->add(ir_->call_glsl450(ir_->f32_type(), instruction, \
ir_->cast(ir_->f32_type(), lhs_value), \
ir_->cast(ir_->f32_type(), rhs_value)), \
ir_->float_immediate_number(ir_->f32_type(), 0.5f))); \
} else { \
bin_value = \
ir_->call_glsl450(dst_type, instruction, lhs_value, rhs_value); \
} \
} else { \
TI_NOT_IMPLEMENTED \
} \
}

INT_OR_FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(pow, Pow, 26, 32)
#undef INT_OR_FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC

#define FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(op, instruction, instruction_id, \
max_bits) \
else if (op_type == BinaryOpType::op) { \
Expand All @@ -893,6 +864,7 @@ class TaskCodegen : public IRVisitor {
}

FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(atan2, Atan2, 25, 32)
FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(pow, Pow, 26, 32)
#undef FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC

#define BINARY_OP_TO_SPIRV_FUNC(op, S_inst, S_inst_id, U_inst, U_inst_id, \
Expand Down
16 changes: 0 additions & 16 deletions taichi/runtime/llvm/runtime_module/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,22 +205,6 @@ DEFINE_UNARY_REAL_FUNC(asin)
DEFINE_UNARY_REAL_FUNC(cos)
DEFINE_UNARY_REAL_FUNC(sin)

// Defines `T pow_##T(T x, T n)`: integer power by repeated squaring.
//
//   x - base.
//   n - exponent; any n <= 0 yields 1 (the loop body never runs).
//
// The base is squared only while exponent bits remain. Squaring it once more
// after the last bit is consumed would be dead work, and for a signed T it can
// overflow (undefined behavior) even when the result itself fits, e.g.
// pow(2, 16) on a 32-bit T would compute 65536 * 65536.
#define DEFINE_FAST_POW(T)  \
  T pow_##T(T x, T n) {     \
    T result = 1;           \
    T base = x;             \
    while (n > 0) {         \
      if (n & 1)            \
        result *= base;     \
      n >>= 1;              \
      if (n > 0)            \
        base *= base;       \
    }                       \
    return result;          \
  }

DEFINE_FAST_POW(i32)
DEFINE_FAST_POW(i64)

// Returns `a` unchanged when it is non-negative, otherwise its negation.
i32 abs_i32(i32 a) {
  if (a < 0) {
    return -a;
  }
  return a;
}
Expand Down
12 changes: 0 additions & 12 deletions taichi/runtime/metal/shaders/helpers.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,6 @@ STR(
: intm);
}

// Integer power by repeated squaring: returns x raised to n.
// Any n <= 0 yields 1 because the loop body never runs.
int32_t pow_i32(int32_t x, int32_t n) {
  int32_t base = x;
  int32_t result = 1;
  while (n > (int32_t)(0)) {
    if (n & 1)
      result *= base;
    n >>= 1;
    // Square only while exponent bits remain. One more squaring after the
    // last bit would be unused and could overflow even when the result fits
    // (for example x = 2 with n = 16).
    if (n > (int32_t)(0))
      base *= base;
  }
  return result;
}

float fatomic_fetch_add(device float *dest, const float operand) {
// A huge hack! Metal does not support atomic floating point numbers
// natively.
Expand Down