Skip to content

Commit

Permalink
Widening and shortening of vectors (in LLVM, sext, zext, fext, and tr…
Browse files Browse the repository at this point in the history
…unc), with safety checks.

Finishing this depends on ziglang#1757. I'd rather not re-work ir_gen_node_raw for explicit casts
(signed to unsigned, and safe narrowing casts) when that is upcoming.
  • Loading branch information
shawnl committed Sep 8, 2019
1 parent e63b240 commit 80ded98
Show file tree
Hide file tree
Showing 5 changed files with 330 additions and 57 deletions.
2 changes: 1 addition & 1 deletion doc/langref.html.in
Original file line number Diff line number Diff line change
Expand Up @@ -7857,7 +7857,7 @@ fn List(comptime T: type) type {
{#header_open|@truncate#}
<pre>{#syntax#}@truncate(comptime T: type, integer: var) T{#endsyntax#}</pre>
<p>
This function truncates bits from an integer type, resulting in a smaller
This function truncates bits from an integer type (or the integers of a vector), resulting in a smaller
or same-sized integer type.
</p>
<p>
Expand Down
72 changes: 54 additions & 18 deletions src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1389,11 +1389,22 @@ static void add_bounds_check(CodeGen *g, LLVMValueRef target_val,
LLVMPositionBuilderAtEnd(g->builder, ok_block);
}

static LLVMValueRef gen_assert_zero(CodeGen *g, LLVMValueRef expr_val, ZigType *int_type) {
LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, int_type));
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, expr_val, zero, "");
static LLVMValueRef gen_assert_zero(CodeGen *g, LLVMValueRef expr_val, ZigType *type) {
assert(type->id == ZigTypeIdInt || type->id == ZigTypeIdVector);
bool is_vector = type->id == ZigTypeIdVector;
LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, type));
LLVMValueRef ok_bits = LLVMBuildICmp(g->builder, LLVMIntEQ, expr_val, zero, "");
LLVMValueRef ok_bit = ok_bits;
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CastShortenOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "CastShortenFail");
if (is_vector) {
ok_bit = LLVMConstInt(g->builtin_types.entry_bool->llvm_type, 1, false);
for (size_t i = 0;i < type->data.vector.len;i++) {
LLVMValueRef i_val = LLVMConstInt(g->builtin_types.entry_i32->llvm_type, i, false);
LLVMValueRef extract = LLVMBuildExtractElement(g->builder, ok_bits, i_val, "");
ok_bit = LLVMBuildAnd(g->builder, ok_bit, extract, "");
}
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);

LLVMPositionBuilderAtEnd(g->builder, fail_block);
Expand All @@ -1407,29 +1418,45 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z
ZigType *wanted_type, LLVMValueRef expr_val)
{
assert(actual_type->id == wanted_type->id);
bool is_vector = actual_type->id == ZigTypeIdVector;
ZigType *actual_scalar_type = is_vector ? actual_type->data.vector.elem_type : actual_type;
ZigType *wanted_scalar_type = is_vector ? wanted_type->data.vector.elem_type : wanted_type;
assert(actual_scalar_type->id == wanted_scalar_type->id);
assert(expr_val != nullptr);
if (is_vector) {
assert(actual_type->data.vector.len == wanted_type->data.vector.len);
}

uint64_t actual_bits;
uint64_t wanted_bits;
if (actual_type->id == ZigTypeIdFloat) {
actual_bits = actual_type->data.floating.bit_count;
wanted_bits = wanted_type->data.floating.bit_count;
} else if (actual_type->id == ZigTypeIdInt) {
actual_bits = actual_type->data.integral.bit_count;
wanted_bits = wanted_type->data.integral.bit_count;
if (actual_scalar_type->id == ZigTypeIdFloat) {
actual_bits = actual_scalar_type->data.floating.bit_count;
wanted_bits = wanted_scalar_type->data.floating.bit_count;
} else if (actual_scalar_type->id == ZigTypeIdInt) {
actual_bits = actual_scalar_type->data.integral.bit_count;
wanted_bits = wanted_scalar_type->data.integral.bit_count;
} else {
zig_unreachable();
}

if (actual_type->id == ZigTypeIdInt &&
!wanted_type->data.integral.is_signed && actual_type->data.integral.is_signed &&
if (actual_scalar_type->id == ZigTypeIdInt &&
!wanted_scalar_type->data.integral.is_signed && actual_scalar_type->data.integral.is_signed &&
want_runtime_safety)
{
LLVMValueRef zero = LLVMConstNull(get_llvm_type(g, actual_type));
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntSGE, expr_val, zero, "");
LLVMValueRef ok_bits = LLVMBuildICmp(g->builder, LLVMIntSGE, expr_val, zero, "");
LLVMValueRef ok_bit = ok_bits;

LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "SignCastOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "SignCastFail");
if (is_vector) {
ok_bit = LLVMConstInt(g->builtin_types.entry_bool->llvm_type, 1, false);
for (size_t i = 0;i < wanted_type->data.vector.len;i++) {
LLVMValueRef i_val = LLVMConstInt(g->builtin_types.entry_i32->llvm_type, i, false);
LLVMValueRef extract = LLVMBuildExtractElement(g->builder, ok_bits, i_val, "");
ok_bit = LLVMBuildAnd(g->builder, ok_bit, extract, "");
}
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);

LLVMPositionBuilderAtEnd(g->builder, fail_block);
Expand All @@ -1441,9 +1468,9 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z
if (actual_bits == wanted_bits) {
return expr_val;
} else if (actual_bits < wanted_bits) {
if (actual_type->id == ZigTypeIdFloat) {
if (actual_scalar_type->id == ZigTypeIdFloat) {
return LLVMBuildFPExt(g->builder, expr_val, get_llvm_type(g, wanted_type), "");
} else if (actual_type->id == ZigTypeIdInt) {
} else if (actual_scalar_type->id == ZigTypeIdInt) {
if (actual_type->data.integral.is_signed) {
return LLVMBuildSExt(g->builder, expr_val, get_llvm_type(g, wanted_type), "");
} else {
Expand All @@ -1453,9 +1480,9 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z
zig_unreachable();
}
} else if (actual_bits > wanted_bits) {
if (actual_type->id == ZigTypeIdFloat) {
if (actual_scalar_type->id == ZigTypeIdFloat) {
return LLVMBuildFPTrunc(g->builder, expr_val, get_llvm_type(g, wanted_type), "");
} else if (actual_type->id == ZigTypeIdInt) {
} else if (actual_scalar_type->id == ZigTypeIdInt) {
if (wanted_bits == 0) {
if (!want_runtime_safety)
return nullptr;
Expand All @@ -1467,14 +1494,23 @@ static LLVMValueRef gen_widen_or_shorten(CodeGen *g, bool want_runtime_safety, Z
return trunc_val;
}
LLVMValueRef orig_val;
if (wanted_type->data.integral.is_signed) {
if (wanted_scalar_type->data.integral.is_signed) {
orig_val = LLVMBuildSExt(g->builder, trunc_val, get_llvm_type(g, actual_type), "");
} else {
orig_val = LLVMBuildZExt(g->builder, trunc_val, get_llvm_type(g, actual_type), "");
}
LLVMValueRef ok_bit = LLVMBuildICmp(g->builder, LLVMIntEQ, expr_val, orig_val, "");
LLVMValueRef ok_bits = LLVMBuildICmp(g->builder, LLVMIntEQ, expr_val, orig_val, "");
LLVMValueRef ok_bit = ok_bits;
LLVMBasicBlockRef ok_block = LLVMAppendBasicBlock(g->cur_fn_val, "CastShortenOk");
LLVMBasicBlockRef fail_block = LLVMAppendBasicBlock(g->cur_fn_val, "CastShortenFail");
if (is_vector) {
ok_bit = LLVMConstInt(g->builtin_types.entry_bool->llvm_type, 1, false);
for (size_t i = 0;i < wanted_type->data.vector.len;i++) {
LLVMValueRef i_val = LLVMConstInt(g->builtin_types.entry_i32->llvm_type, i, false);
LLVMValueRef extract = LLVMBuildExtractElement(g->builder, ok_bits, i_val, "");
ok_bit = LLVMBuildAnd(g->builder, ok_bit, extract, "");
}
}
LLVMBuildCondBr(g->builder, ok_bit, ok_block, fail_block);

LLVMPositionBuilderAtEnd(g->builder, fail_block);
Expand Down
Loading

0 comments on commit 80ded98

Please sign in to comment.