stage1: Implement @reduce builtin for vector types
The builtin folds a Vector(N,T) into a scalar T using a specified
operator.

Closes #2698
LemonBoy authored and andrewrk committed Oct 5, 2020
1 parent 7c5a24e commit 22b5e47
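
A minimal usage sketch of the new builtin (not part of this diff, and written against the std of this commit, where std.testing.expect panics on failure rather than returning an error):

    const std = @import("std");

    test "fold an integer vector with @reduce" {
        const v: @Vector(4, i32) = [4]i32{ 3, -1, 7, 2 };
        // Min/Max fold all lanes down to a single i32 scalar.
        std.testing.expect(@reduce(.Min, v) == -1);
        std.testing.expect(@reduce(.Max, v) == 7);
    }
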
Showing 8 changed files with 430 additions and 39 deletions.
10 changes: 10 additions & 0 deletions lib/std/builtin.zig
@@ -98,6 +98,16 @@ pub const AtomicOrder = enum {
SeqCst,
};

/// This data structure is used by the Zig language code generation and
/// therefore must be kept in sync with the compiler implementation.
pub const ReduceOp = enum {
And,
Or,
Xor,
Min,
Max,
};

/// This data structure is used by the Zig language code generation and
/// therefore must be kept in sync with the compiler implementation.
pub const AtomicRmwOp = enum {
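
Since the first argument of @reduce is typed as std.builtin.ReduceOp, both enum literals and explicit, comptime-known ReduceOp values should be accepted at the call site. A hedged sketch (how the operator argument is resolved is an assumption based on other op-taking builtins such as @atomicRmw, not something shown in this hunk):

    const std = @import("std");

    test "the op argument is a std.builtin.ReduceOp" {
        const v: @Vector(4, u32) = [4]u32{ 1, 2, 3, 4 };
        // Enum literal, coerced to ReduceOp at the call site.
        const a = @reduce(.Xor, v);
        // Explicit, comptime-known enum value.
        const op = std.builtin.ReduceOp.Xor;
        const b = @reduce(op, v);
        std.testing.expect(a == b);
        std.testing.expect(a == (1 ^ 2 ^ 3 ^ 4)); // == 4
    }
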
26 changes: 26 additions & 0 deletions src/stage1/all_types.hpp
@@ -1821,6 +1821,7 @@ enum BuiltinFnId {
BuiltinFnIdWasmMemorySize,
BuiltinFnIdWasmMemoryGrow,
BuiltinFnIdSrc,
BuiltinFnIdReduce,
};

struct BuiltinFnEntry {
@@ -2436,6 +2437,15 @@ enum AtomicOrder {
AtomicOrderSeqCst,
};

// synchronized with code in define_builtin_compile_vars
enum ReduceOp {
ReduceOp_and,
ReduceOp_or,
ReduceOp_xor,
ReduceOp_min,
ReduceOp_max,
};

// synchronized with the code in define_builtin_compile_vars
enum AtomicRmwOp {
AtomicRmwOp_xchg,
@@ -2545,6 +2555,7 @@ enum IrInstSrcId {
IrInstSrcIdEmbedFile,
IrInstSrcIdCmpxchg,
IrInstSrcIdFence,
IrInstSrcIdReduce,
IrInstSrcIdTruncate,
IrInstSrcIdIntCast,
IrInstSrcIdFloatCast,
@@ -2667,6 +2678,7 @@ enum IrInstGenId {
IrInstGenIdErrName,
IrInstGenIdCmpxchg,
IrInstGenIdFence,
IrInstGenIdReduce,
IrInstGenIdTruncate,
IrInstGenIdShuffleVector,
IrInstGenIdSplat,
@@ -3516,6 +3528,20 @@ struct IrInstGenFence {
AtomicOrder order;
};

struct IrInstSrcReduce {
IrInstSrc base;

IrInstSrc *op;
IrInstSrc *value;
};

struct IrInstGenReduce {
IrInstGen base;

ReduceOp op;
IrInstGen *value;
};

struct IrInstSrcTruncate {
IrInstSrc base;

95 changes: 56 additions & 39 deletions src/stage1/codegen.cpp
@@ -2583,36 +2583,6 @@ static LLVMValueRef ir_render_return(CodeGen *g, IrExecutableGen *executable, Ir
return nullptr;
}

enum class ScalarizePredicate {
// Returns true iff all the elements in the vector are 1.
// Equivalent to folding all the bits with `and`.
All,
// Returns true iff there's at least one element in the vector that is 1.
// Equivalent to folding all the bits with `or`.
Any,
};

// Collapses a <N x i1> vector into a single i1 according to the given predicate
static LLVMValueRef scalarize_cmp_result(CodeGen *g, LLVMValueRef val, ScalarizePredicate predicate) {
assert(LLVMGetTypeKind(LLVMTypeOf(val)) == LLVMVectorTypeKind);
LLVMTypeRef scalar_type = LLVMIntType(LLVMGetVectorSize(LLVMTypeOf(val)));
LLVMValueRef casted = LLVMBuildBitCast(g->builder, val, scalar_type, "");

switch (predicate) {
case ScalarizePredicate::Any: {
LLVMValueRef all_zeros = LLVMConstNull(scalar_type);
return LLVMBuildICmp(g->builder, LLVMIntNE, casted, all_zeros, "");
}
case ScalarizePredicate::All: {
LLVMValueRef all_ones = LLVMConstAllOnes(scalar_type);
return LLVMBuildICmp(g->builder, LLVMIntEQ, casted, all_ones, "");
}
}

zig_unreachable();
}
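
The removed helper's All/Any predicates are exactly the folds the new builtin exposes at the language level: @reduce(.And, ...) asks whether every lane is true and @reduce(.Or, ...) whether any lane is. A small sketch over a bool vector (assuming bool vectors coerce from array literals like other element types):

    const std = @import("std");

    test "all/any over a bool vector" {
        const b: @Vector(4, bool) = [4]bool{ true, true, false, true };
        std.testing.expect(@reduce(.Or, b)); // at least one lane is true
        std.testing.expect(!@reduce(.And, b)); // but not all of them
    }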


static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
LLVMValueRef val1, LLVMValueRef val2)
{
@@ -2637,7 +2607,7 @@ static LLVMValueRef gen_overflow_shl_op(CodeGen *g, ZigType *operand_type,
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);

@@ -2668,7 +2638,7 @@ static LLVMValueRef gen_overflow_shr_op(CodeGen *g, ZigType *operand_type,
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "OverflowFail");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);

Expand Down Expand Up @@ -2745,7 +2715,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
}

if (operand_type->id == ZigTypeIdVector) {
is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
is_zero_bit = ZigLLVMBuildOrReduce(g->builder, is_zero_bit);
}

LLVMBasicBlockRef div_zero_fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivZeroFail");
@@ -2770,7 +2740,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMValueRef den_is_neg_1 = LLVMBuildICmp(g->builder, LLVMIntEQ, val2, neg_1_value, "");
LLVMValueRef overflow_fail_bit = LLVMBuildAnd(g->builder, num_is_int_min, den_is_neg_1, "");
if (operand_type->id == ZigTypeIdVector) {
overflow_fail_bit = scalarize_cmp_result(g, overflow_fail_bit, ScalarizePredicate::Any);
overflow_fail_bit = ZigLLVMBuildOrReduce(g->builder, overflow_fail_bit);
}
LLVMBuildCondBr(g->builder, overflow_fail_bit, overflow_fail_block, overflow_ok_block);

@@ -2795,7 +2765,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
LLVMValueRef ok_bit = LLVMBuildFCmp(g->builder, LLVMRealOEQ, floored, result, "");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);

@@ -2812,7 +2782,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBasicBlockRef end_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivTruncEnd");
LLVMValueRef ltz = LLVMBuildFCmp(g->builder, LLVMRealOLT, val1, zero, "");
if (operand_type->id == ZigTypeIdVector) {
ltz = scalarize_cmp_result(g, ltz, ScalarizePredicate::Any);
ltz = ZigLLVMBuildOrReduce(g->builder, ltz);
}
LLVMBuildCondBr(g->builder, ltz, ltz_block, gez_block);

@@ -2864,7 +2834,7 @@ static LLVMValueRef gen_div(CodeGen *g, bool want_runtime_safety, bool want_fast
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "DivExactFail");
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, remainder_val, zero, "");
if (operand_type->id == ZigTypeIdVector) {
ok_bit = scalarize_cmp_result(g, ok_bit, ScalarizePredicate::All);
ok_bit = ZigLLVMBuildAndReduce(g->builder, ok_bit);
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);
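
These AND/OR reductions back the per-lane safety checks: for a vector @divExact, every lane's remainder is compared against zero and the comparison vector is AND-reduced into the single branch condition above. A user-level sketch (assuming vector operands are accepted by @divExact, as the vector branches in gen_div suggest):

    const std = @import("std");

    test "vector @divExact has a per-lane exactness check" {
        const num: @Vector(4, i32) = [4]i32{ 8, 12, 16, 20 };
        const den: @Vector(4, i32) = [4]i32{ 2, 3, 4, 5 };
        // Every lane divides exactly, so the AND-reduced safety check passes.
        const q = @divExact(num, den);
        std.testing.expect(q[0] == 4 and q[3] == 4);
    }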

@@ -2928,7 +2898,7 @@ static LLVMValueRef gen_rem(CodeGen *g, bool want_runtime_safety, bool want_fast
}

if (operand_type->id == ZigTypeIdVector) {
is_zero_bit = scalarize_cmp_result(g, is_zero_bit, ScalarizePredicate::Any);
is_zero_bit = ZigLLVMBuildOrReduce(g->builder, is_zero_bit);
}

LLVMBasicBlockRef rem_zero_ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "RemZeroOk");
@@ -2985,7 +2955,7 @@ static void gen_shift_rhs_check(CodeGen *g, ZigType *lhs_type, ZigType *rhs_type
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CheckOk");
LLVMValueRef less_than_bit = LLVMBuildICmp(g->builder, LLVMIntULT, value, bit_count_value, "");
if (rhs_type->id == ZigTypeIdVector) {
less_than_bit = scalarize_cmp_result(g, less_than_bit, ScalarizePredicate::Any);
less_than_bit = ZigLLVMBuildOrReduce(g->builder, less_than_bit);
}
LLVMBuildCondBr(g->builder, less_than_bit, ok_block, fail_block);

@@ -5470,6 +5440,50 @@ static LLVMValueRef ir_render_cmpxchg(CodeGen *g, IrExecutableGen *executable, I
return result_loc;
}

static LLVMValueRef ir_render_reduce(CodeGen *g, IrExecutableGen *executable, IrInstGenReduce *instruction) {
LLVMValueRef value = ir_llvm_value(g, instruction->value);

ZigType *value_type = instruction->value->value->type;
assert(value_type->id == ZigTypeIdVector);
ZigType *scalar_type = value_type->data.vector.elem_type;

LLVMValueRef result_val;
switch (instruction->op) {
case ReduceOp_and:
assert(scalar_type->id == ZigTypeIdInt || scalar_type->id == ZigTypeIdBool);
result_val = ZigLLVMBuildAndReduce(g->builder, value);
break;
case ReduceOp_or:
assert(scalar_type->id == ZigTypeIdInt || scalar_type->id == ZigTypeIdBool);
result_val = ZigLLVMBuildOrReduce(g->builder, value);
break;
case ReduceOp_xor:
assert(scalar_type->id == ZigTypeIdInt || scalar_type->id == ZigTypeIdBool);
result_val = ZigLLVMBuildXorReduce(g->builder, value);
break;
case ReduceOp_min: {
if (scalar_type->id == ZigTypeIdInt) {
const bool is_signed = scalar_type->data.integral.is_signed;
result_val = ZigLLVMBuildIntMinReduce(g->builder, value, is_signed);
} else if (scalar_type->id == ZigTypeIdFloat) {
result_val = ZigLLVMBuildFPMinReduce(g->builder, value);
} else zig_unreachable();
} break;
case ReduceOp_max: {
if (scalar_type->id == ZigTypeIdInt) {
const bool is_signed = scalar_type->data.integral.is_signed;
result_val = ZigLLVMBuildIntMaxReduce(g->builder, value, is_signed);
} else if (scalar_type->id == ZigTypeIdFloat) {
result_val = ZigLLVMBuildFPMaxReduce(g->builder, value);
} else zig_unreachable();
} break;
default:
zig_unreachable();
}

return result_val;
}
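
Per the asserts above, And/Or/Xor are restricted to integer and bool element types, while Min and Max also take the float path. A short sketch (NaN behaviour is left to LLVM's FP reduce intrinsics and is not pinned down here):

    const std = @import("std");

    test "Min/Max on a float vector" {
        const v: @Vector(4, f32) = [4]f32{ 1.5, -2.5, 3.5, 0.0 };
        std.testing.expect(@reduce(.Min, v) == -2.5);
        std.testing.expect(@reduce(.Max, v) == 3.5);
    }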

static LLVMValueRef ir_render_fence(CodeGen *g, IrExecutableGen *executable, IrInstGenFence *instruction) {
LLVMAtomicOrdering atomic_order = to_LLVMAtomicOrdering(instruction->order);
LLVMBuildFence(g->builder, atomic_order, false, "");
@@ -6674,6 +6688,8 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutableGen *executabl
return ir_render_cmpxchg(g, executable, (IrInstGenCmpxchg *)instruction);
case IrInstGenIdFence:
return ir_render_fence(g, executable, (IrInstGenFence *)instruction);
case IrInstGenIdReduce:
return ir_render_reduce(g, executable, (IrInstGenReduce *)instruction);
case IrInstGenIdTruncate:
return ir_render_truncate(g, executable, (IrInstGenTruncate *)instruction);
case IrInstGenIdBoolNot:
@@ -8630,6 +8646,7 @@ static void define_builtin_fns(CodeGen *g) {
create_builtin_fn(g, BuiltinFnIdWasmMemorySize, "wasmMemorySize", 1);
create_builtin_fn(g, BuiltinFnIdWasmMemoryGrow, "wasmMemoryGrow", 2);
create_builtin_fn(g, BuiltinFnIdSrc, "src", 0);
create_builtin_fn(g, BuiltinFnIdReduce, "reduce", 2);
}

static const char *bool_to_str(bool b) {