Skip to content

Commit

Permalink
@byteswap on vectors
Browse files Browse the repository at this point in the history
  • Loading branch information
shawnl committed Sep 8, 2019
1 parent 169063d commit e63b240
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 17 deletions.
1 change: 1 addition & 0 deletions src/all_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1771,6 +1771,7 @@ struct ZigLLVMFnKey {
} overflow_arithmetic;
struct {
uint32_t bit_count;
uint32_t vector_len; // 0 means not a vector
} bswap;
struct {
uint32_t bit_count;
Expand Down
28 changes: 21 additions & 7 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4504,7 +4504,13 @@ static LLVMValueRef ir_render_optional_unwrap_ptr(CodeGen *g, IrExecutable *exec
}
}

static LLVMValueRef get_int_builtin_fn(CodeGen *g, ZigType *int_type, BuiltinFnId fn_id) {
static LLVMValueRef get_int_builtin_fn(CodeGen *g, ZigType *expr_type, BuiltinFnId fn_id) {
bool is_vector = expr_type->id == ZigTypeIdVector;
ZigType *int_type = is_vector ? expr_type->data.vector.elem_type : expr_type;
assert(int_type->id == ZigTypeIdInt);
uint32_t vector_len = 0;
if (is_vector)
vector_len = expr_type->data.vector.len;
ZigLLVMFnKey key = {};
const char *fn_name;
uint32_t n_args;
Expand All @@ -4528,6 +4534,7 @@ static LLVMValueRef get_int_builtin_fn(CodeGen *g, ZigType *int_type, BuiltinFnI
n_args = 1;
key.id = ZigLLVMFnIdBswap;
key.data.bswap.bit_count = (uint32_t)int_type->data.integral.bit_count;
key.data.bswap.vector_len = vector_len;
} else if (fn_id == BuiltinFnIdBitReverse) {
fn_name = "bitreverse";
n_args = 1;
Expand All @@ -4542,12 +4549,15 @@ static LLVMValueRef get_int_builtin_fn(CodeGen *g, ZigType *int_type, BuiltinFnI
return existing_entry->value;

char llvm_name[64];
sprintf(llvm_name, "llvm.%s.i%" PRIu32, fn_name, int_type->data.integral.bit_count);
if (is_vector)
sprintf(llvm_name, "llvm.%s.v%" PRIu32 "i%" PRIu32, fn_name, vector_len, int_type->data.integral.bit_count);
else
sprintf(llvm_name, "llvm.%s.i%" PRIu32, fn_name, int_type->data.integral.bit_count);
LLVMTypeRef param_types[] = {
get_llvm_type(g, int_type),
get_llvm_type(g, expr_type),
LLVMInt1Type(),
};
LLVMTypeRef fn_type = LLVMFunctionType(get_llvm_type(g, int_type), param_types, n_args, false);
LLVMTypeRef fn_type = LLVMFunctionType(get_llvm_type(g, expr_type), param_types, n_args, false);
LLVMValueRef fn_val = LLVMAddFunction(g->module, llvm_name, fn_type);
assert(LLVMGetIntrinsicID(fn_val));

Expand Down Expand Up @@ -5539,15 +5549,19 @@ static LLVMValueRef ir_render_mul_add(CodeGen *g, IrExecutable *executable, IrIn

static LLVMValueRef ir_render_bswap(CodeGen *g, IrExecutable *executable, IrInstructionBswap *instruction) {
LLVMValueRef op = ir_llvm_value(g, instruction->op);
ZigType *int_type = instruction->base.value.type;
ZigType *expr_type = instruction->base.value.type;
bool is_vector = expr_type->id == ZigTypeIdVector;
ZigType *int_type = is_vector ? expr_type->data.vector.elem_type : expr_type;
assert(int_type->id == ZigTypeIdInt);
if (int_type->data.integral.bit_count % 16 == 0) {
LLVMValueRef fn_val = get_int_builtin_fn(g, instruction->base.value.type, BuiltinFnIdBswap);
LLVMValueRef fn_val = get_int_builtin_fn(g, expr_type, BuiltinFnIdBswap);
return LLVMBuildCall(g->builder, fn_val, &op, 1, "");
}
// Not an even number of bytes, so we zext 1 byte, then bswap, shift right 1 byte, truncate
ZigType *extended_type = get_int_type(g, int_type->data.integral.is_signed,
int_type->data.integral.bit_count + 8);
if (is_vector)
extended_type = get_vector_type(g, expr_type->data.vector.len, extended_type);
// aabbcc
LLVMValueRef extended = LLVMBuildZExt(g->builder, op, get_llvm_type(g, extended_type), "");
// 00aabbcc
Expand All @@ -5557,7 +5571,7 @@ static LLVMValueRef ir_render_bswap(CodeGen *g, IrExecutable *executable, IrInst
LLVMValueRef shifted = ZigLLVMBuildLShrExact(g->builder, swapped,
LLVMConstInt(get_llvm_type(g, extended_type), 8, false), "");
// 00ccbbaa
return LLVMBuildTrunc(g->builder, shifted, get_llvm_type(g, int_type), "");
return LLVMBuildTrunc(g->builder, shifted, get_llvm_type(g, expr_type), "");
}

static LLVMValueRef ir_render_bit_reverse(CodeGen *g, IrExecutable *executable, IrInstructionBitReverse *instruction) {
Expand Down
62 changes: 52 additions & 10 deletions src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25400,16 +25400,42 @@ static IrInstruction *ir_analyze_instruction_float_op(IrAnalyze *ira, IrInstruct
}

static IrInstruction *ir_analyze_instruction_bswap(IrAnalyze *ira, IrInstructionBswap *instruction) {
ZigType *int_type = ir_resolve_int_type(ira, instruction->type->child);
if (type_is_invalid(int_type))
IrInstruction *op = instruction->op->child;
ZigType *type_expr = ir_resolve_type(ira, instruction->type->child);
if (type_is_invalid(type_expr))
return ira->codegen->invalid_instruction;

IrInstruction *op = ir_implicit_cast(ira, instruction->op->child, int_type);
if (type_expr->id != ZigTypeIdInt) {
ir_add_error(ira, instruction->type,
buf_sprintf("expected integer type, found '%s'", buf_ptr(&type_expr->name)));
if (type_expr->id == ZigTypeIdVector &&
type_expr->data.vector.elem_type->id == ZigTypeIdInt)
ir_add_error(ira, instruction->type,
buf_sprintf("represent vectors with their scalar types, i.e. '%s'",
buf_ptr(&type_expr->data.vector.elem_type->name)));
return ira->codegen->invalid_instruction;
}
ZigType *int_type = type_expr;

ZigType *expr_type = op->value.type;
bool is_vector = expr_type->id == ZigTypeIdVector;
ZigType *ret_type = int_type;
if (is_vector)
ret_type = get_vector_type(ira->codegen, expr_type->data.vector.len, int_type);

op = ir_implicit_cast(ira, instruction->op->child, ret_type);
if (type_is_invalid(op->value.type))
return ira->codegen->invalid_instruction;

if (int_type->data.integral.bit_count == 0) {
IrInstruction *result = ir_const(ira, &instruction->base, int_type);
IrInstruction *result = ir_const(ira, &instruction->base, ret_type);
if (is_vector) {
expand_undef_array(ira->codegen, &result->value);
result->value.data.x_array.data.s_none.elements =
allocate<ConstExprValue>(expr_type->data.vector.len);
for (unsigned i = 0; i < expr_type->data.vector.len; i++)
bigint_init_unsigned(&result->value.data.x_array.data.s_none.elements[i].data.x_bigint, 0);
}
bigint_init_unsigned(&result->value.data.x_bigint, 0);
return result;
}
Expand All @@ -25429,20 +25455,36 @@ static IrInstruction *ir_analyze_instruction_bswap(IrAnalyze *ira, IrInstruction
if (val == nullptr)
return ira->codegen->invalid_instruction;
if (val->special == ConstValSpecialUndef)
return ir_const_undef(ira, &instruction->base, int_type);
return ir_const_undef(ira, &instruction->base, ret_type);

IrInstruction *result = ir_const(ira, &instruction->base, int_type);
IrInstruction *result = ir_const(ira, &instruction->base, ret_type);
size_t buf_size = int_type->data.integral.bit_count / 8;
uint8_t *buf = allocate_nonzero<uint8_t>(buf_size);
bigint_write_twos_complement(&val->data.x_bigint, buf, int_type->data.integral.bit_count, true);
bigint_read_twos_complement(&result->value.data.x_bigint, buf, int_type->data.integral.bit_count, false,
int_type->data.integral.is_signed);
if (is_vector) {
expand_undef_array(ira->codegen, &result->value);
result->value.data.x_array.data.s_none.elements =
allocate<ConstExprValue>(expr_type->data.vector.len);
for (unsigned i = 0; i < expr_type->data.vector.len; i++) {
ConstExprValue *cur = &val->data.x_array.data.s_none.elements[i];
result->value.data.x_array.data.s_none.elements[i].special = cur->special;
if (cur->special == ConstValSpecialUndef)
continue;
bigint_write_twos_complement(&cur->data.x_bigint, buf, int_type->data.integral.bit_count, true);
bigint_read_twos_complement(&result->value.data.x_array.data.s_none.elements[i].data.x_bigint,
buf, int_type->data.integral.bit_count, false,
int_type->data.integral.is_signed);
}
} else {
bigint_write_twos_complement(&val->data.x_bigint, buf, int_type->data.integral.bit_count, true);
bigint_read_twos_complement(&result->value.data.x_bigint, buf, int_type->data.integral.bit_count, false,
int_type->data.integral.is_signed);
}
return result;
}

IrInstruction *result = ir_build_bswap(&ira->new_irb, instruction->base.scope,
instruction->base.source_node, nullptr, op);
result->value.type = int_type;
result->value.type = ret_type;
return result;
}

Expand Down
11 changes: 11 additions & 0 deletions test/stage1/behavior/byteswap.zig
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ test "@byteSwap" {
testByteSwap();
}

test "@byteSwap on vectors" {
comptime testVectorByteSwap();
testVectorByteSwap();
}

fn testByteSwap() void {
expect(@byteSwap(u0, 0) == 0);
expect(@byteSwap(u8, 0x12) == 0x12);
Expand All @@ -30,3 +35,9 @@ fn testByteSwap() void {
expect(@byteSwap(i128, @bitCast(i128, u128(0x123456789abcdef11121314151617181))) ==
@bitCast(i128, u128(0x8171615141312111f1debc9a78563412)));
}

fn testVectorByteSwap() void {
expect((@byteSwap(u8, @Vector(2, u8)([2]u8{0x12, 0x13})) == @Vector(2, u8)([2]u8{0x12, 0x13})).all);
expect((@byteSwap(u16, @Vector(2, u16)([2]u16{0x1234, 0x2345})) == @Vector(2, u16)([2]u16{0x3412, 0x4523})).all);
expect((@byteSwap(u24, @Vector(2, u24)([2]u24{0x123456, 0x234567})) == @Vector(2, u24)([2]u24{0x563412, 0x674523})).all);
}

0 comments on commit e63b240

Please sign in to comment.