diff --git a/src/all_types.hpp b/src/all_types.hpp index e77753ec4d8f..281855443525 100644 --- a/src/all_types.hpp +++ b/src/all_types.hpp @@ -2370,6 +2370,9 @@ enum IrInstructionId { IrInstructionIdLoadPtr, IrInstructionIdLoadPtrGen, IrInstructionIdStorePtr, + IrInstructionIdVectorElem, + IrInstructionIdExtract, + IrInstructionIdInsert, IrInstructionIdFieldPtr, IrInstructionIdStructFieldPtr, IrInstructionIdUnionFieldPtr, @@ -2689,6 +2692,29 @@ struct IrInstructionLoadPtr { IrInstruction *ptr; }; +struct IrInstructionVectorElem { + IrInstruction base; + + IrInstruction *agg; + IrInstruction *index; + IrInstruction *result_loc; +}; + +struct IrInstructionExtract { + IrInstruction base; + + IrInstruction *agg; + IrInstruction *index; +}; + +struct IrInstructionInsert { + IrInstruction base; + + IrInstruction *agg; + IrInstruction *index; + IrInstruction *value; +}; + struct IrInstructionLoadPtrGen { IrInstruction base; diff --git a/src/codegen.cpp b/src/codegen.cpp index 026be0f4c5df..d84ae6e89434 100644 --- a/src/codegen.cpp +++ b/src/codegen.cpp @@ -3350,6 +3350,30 @@ static LLVMValueRef ir_render_decl_var(CodeGen *g, IrExecutable *executable, IrI return nullptr; } +static LLVMValueRef ir_render_extract(CodeGen *g, IrExecutable *executable, IrInstructionExtract *instruction) { + assert(instruction->index); + ZigType *child_type = instruction->base.value.type; + if (!type_has_bits(child_type)) + return nullptr; + + LLVMValueRef agg = ir_llvm_value(g, instruction->agg); + ZigType *agg_type = instruction->agg->value.type; + assert(agg_type->id == ZigTypeIdVector && "Arrays not yet implemented"); + + return LLVMBuildExtractElement(g->builder, agg, ir_llvm_value(g, instruction->index), ""); +} + +static LLVMValueRef ir_render_insert(CodeGen *g, IrExecutable *executable, IrInstructionInsert *instruction) { + assert(instruction->index); + + LLVMValueRef agg = ir_llvm_value(g, instruction->agg); + ZigType *agg_type = instruction->agg->value.type; + assert(agg_type->id == ZigTypeIdVector && "Arrays not yet implemented"); + + return LLVMBuildInsertElement(g->builder, + agg, ir_llvm_value(g, instruction->value), ir_llvm_value(g, instruction->index), ""); +} + static LLVMValueRef ir_render_load_ptr(CodeGen *g, IrExecutable *executable, IrInstructionLoadPtrGen *instruction) { ZigType *child_type = instruction->base.value.type; if (!type_has_bits(child_type)) @@ -3591,7 +3615,8 @@ static LLVMValueRef ir_render_return_ptr(CodeGen *g, IrExecutable *executable, static LLVMValueRef ir_render_elem_ptr(CodeGen *g, IrExecutable *executable, IrInstructionElemPtr *instruction) { LLVMValueRef array_ptr_ptr = ir_llvm_value(g, instruction->array_ptr); ZigType *array_ptr_type = instruction->array_ptr->value.type; - assert(array_ptr_type->id == ZigTypeIdPointer); + if (array_ptr_type->id != ZigTypeIdPointer) + return ir_llvm_value(g, instruction->array_ptr); ZigType *array_type = array_ptr_type->data.pointer.child_type; LLVMValueRef array_ptr = get_handle_value(g, array_ptr_ptr, array_type, array_ptr_type); LLVMValueRef subscript_value = ir_llvm_value(g, instruction->elem_index); @@ -5906,6 +5931,7 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, case IrInstructionIdFrameSizeSrc: case IrInstructionIdAllocaGen: case IrInstructionIdAwaitSrc: + case IrInstructionIdVectorElem: zig_unreachable(); case IrInstructionIdDeclVarGen: @@ -5924,6 +5950,10 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable, return ir_render_br(g, executable, (IrInstructionBr *)instruction); case IrInstructionIdUnOp: return ir_render_un_op(g, executable, (IrInstructionUnOp *)instruction); + case IrInstructionIdExtract: + return ir_render_extract(g, executable, (IrInstructionExtract *)instruction); + case IrInstructionIdInsert: + return ir_render_insert(g, executable, (IrInstructionInsert *)instruction); case IrInstructionIdLoadPtrGen: return ir_render_load_ptr(g, executable, (IrInstructionLoadPtrGen *)instruction); case IrInstructionIdStorePtr: @@ -6085,6 +6115,8 @@ static void ir_render(CodeGen *g, ZigFn *fn_entry) { LLVMPositionBuilderAtEnd(g->builder, current_block->llvm_block); for (size_t instr_i = 0; instr_i < current_block->instruction_list.length; instr_i += 1) { IrInstruction *instruction = current_block->instruction_list.at(instr_i); + if (instruction->id == IrInstructionIdLoadPtr) + abort(); if (instruction->ref_count == 0 && !ir_has_side_effects(instruction)) continue; if (get_scope_typeof(instruction->scope) != nullptr) diff --git a/src/ir.cpp b/src/ir.cpp index 8fd66249edc3..3aa4e77cc1e4 100644 --- a/src/ir.cpp +++ b/src/ir.cpp @@ -485,6 +485,18 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionStorePtr *) { return IrInstructionIdStorePtr; } +static constexpr IrInstructionId ir_instruction_id(IrInstructionExtract *) { + return IrInstructionIdExtract; +} + +static constexpr IrInstructionId ir_instruction_id(IrInstructionVectorElem *) { + return IrInstructionIdVectorElem; +} + +static constexpr IrInstructionId ir_instruction_id(IrInstructionInsert *) { + return IrInstructionIdInsert; +} + static constexpr IrInstructionId ir_instruction_id(IrInstructionFieldPtr *) { return IrInstructionIdFieldPtr; } @@ -1667,6 +1679,47 @@ static IrInstruction *ir_build_load_ptr(IrBuilder *irb, Scope *scope, AstNode *s return &instruction->base; } +static IrInstruction *ir_build_vector_elem(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstruction *agg, IrInstruction *index, IrInstruction *result_loc) { + IrInstructionVectorElem *instruction = ir_build_instruction(irb, scope, source_node); + instruction->agg = agg; + instruction->index = index; + instruction->result_loc = result_loc; + + ir_ref_instruction(agg, irb->current_basic_block); + ir_ref_instruction(index, irb->current_basic_block); + if (result_loc != nullptr) + ir_ref_instruction(result_loc, irb->current_basic_block); + + return &instruction->base; +} + +static IrInstruction *ir_build_extract(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstruction *agg, IrInstruction *index) { + IrInstructionExtract *instruction = ir_build_instruction(irb, scope, source_node); + instruction->agg = agg; + instruction->index = index; + + ir_ref_instruction(agg, irb->current_basic_block); + ir_ref_instruction(index, irb->current_basic_block); + + return &instruction->base; +} + +static IrInstruction *ir_build_insert(IrBuilder *irb, Scope *scope, AstNode *source_node, + IrInstruction *agg, IrInstruction *index, IrInstruction *value) { + IrInstructionInsert *instruction = ir_build_instruction(irb, scope, source_node); + instruction->agg = agg; + instruction->index = index; + instruction->value = value; + + ir_ref_instruction(agg, irb->current_basic_block); + ir_ref_instruction(index, irb->current_basic_block); + ir_ref_instruction(value, irb->current_basic_block); + + return &instruction->base; +} + static IrInstruction *ir_build_typeof(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *value) { IrInstructionTypeOf *instruction = ir_build_instruction(irb, scope, source_node); instruction->value = value; @@ -15352,7 +15405,8 @@ static IrInstruction *ir_resolve_result(IrAnalyze *ira, IrInstruction *suspend_s result_loc->value.special = ConstValSpecialRuntime; } - ir_assert(result_loc->value.type->id == ZigTypeIdPointer, suspend_source_instr); + ir_assert(result_loc->value.type->id == ZigTypeIdPointer || + result_loc->id == IrInstructionIdVectorElem, suspend_source_instr); ZigType *actual_elem_type = result_loc->value.type->data.pointer.child_type; if (actual_elem_type->id == ZigTypeIdOptional && value_type->id != ZigTypeIdOptional && value_type->id != ZigTypeIdNull) @@ -15739,9 +15793,77 @@ static void mark_comptime_value_escape(IrAnalyze *ira, IrInstruction *source_ins } } +static IrInstruction *ir_analyze_insert(IrAnalyze *ira, IrInstruction *source_instr, + IrInstruction *agg, IrInstruction *index_arg, IrInstruction *value_arg) { + ZigType *vector_type = agg->value.type; + + IrInstruction *index = ir_implicit_cast(ira, index_arg, ira->codegen->builtin_types.entry_u32); + if (type_is_invalid(index->value.type)) { + ir_add_error(ira, source_instr, + buf_sprintf("vector element index invalid, expected '%s', got '%s'", + buf_ptr(&ira->codegen->builtin_types.entry_u32->name), + buf_ptr(&index_arg->value.type->name))); + return ira->codegen->invalid_instruction; + } + + IrInstruction *value = ir_implicit_cast(ira, value_arg, vector_type->data.vector.elem_type); + if (type_is_invalid(value->value.type)) { + ir_add_error(ira, source_instr, + buf_sprintf("element to insert of invalid type, expected '%s', got '%s'", + buf_ptr(&vector_type->data.vector.elem_type->name), + buf_ptr(&value_arg->value.type->name))); + return ira->codegen->invalid_instruction; + } + + if (instr_is_comptime(index)) { + uint64_t index_int; + if (!ir_resolve_unsigned(ira, index, ira->codegen->builtin_types.entry_u32, &index_int)) + return ira->codegen->invalid_instruction; + + if (index_int >= vector_type->data.vector.len) { + ir_add_error_node(ira, index->source_node, + buf_sprintf("vector index out of range; max is %" ZIG_PRI_u64 ", got %" ZIG_PRI_u64, + (uint64_t)vector_type->data.vector.len - 1, index_int)); + return ira->codegen->invalid_instruction; + } + + if (instr_is_comptime(agg) && instr_is_comptime(index)) { + IrInstruction *result = ir_const(ira, source_instr, agg->value.type); + result->value = agg->value; + agg->value.data.x_array.data.s_none.elements[index_int] = value->value; + result->value.special = ConstValSpecialStatic; + return result; + } + } + + IrInstruction *result = ir_build_insert(&ira->new_irb, source_instr->scope, + source_instr->source_node, agg, index, value); + result->value.type = ira->codegen->builtin_types.entry_void; + result->value.special = ConstValSpecialRuntime; + return result; +} + +static IrInstruction *ir_analyze_instruction_insert(IrAnalyze *ira, IrInstructionInsert *instruction) { + return ir_analyze_insert(ira, &instruction->base, instruction->agg, instruction->index, instruction->value); +} + static IrInstruction *ir_analyze_store_ptr(IrAnalyze *ira, IrInstruction *source_instr, IrInstruction *ptr, IrInstruction *uncasted_value, bool allow_write_through_const) { + if (ptr->id == IrInstructionIdVectorElem) { + IrInstructionVectorElem *velem = (IrInstructionVectorElem *)ptr; + IrInstruction *result = ir_analyze_insert(ira, source_instr, + velem->agg, velem->index, uncasted_value); + if (type_is_invalid(result->value.type)) + return ira->codegen->invalid_instruction; + + result->value.type = velem->agg->value.type; + IrInstruction *store_back_because_ssa_is_a_bad_idea_apparently = ir_analyze_store_ptr(ira, source_instr, + velem->result_loc, result, false); + if (type_is_invalid(store_back_because_ssa_is_a_bad_idea_apparently->value.type)) + return ira->codegen->invalid_instruction; + return result; + } assert(ptr->value.type->id == ZigTypeIdPointer); if (ptr->value.data.x_ptr.special == ConstPtrSpecialDiscard) { @@ -17243,6 +17365,16 @@ static IrInstruction *ir_analyze_instruction_elem_ptr(IrAnalyze *ira, IrInstruct return ir_get_const_ptr(ira, &elem_ptr_instruction->base, &ira->codegen->const_void_val, ira->codegen->builtin_types.entry_void, ConstPtrMutComptimeConst, is_const, is_volatile, 0); } + } else if (array_type->id == ZigTypeIdVector) { + IrInstruction *agg = ir_get_deref(ira, &elem_ptr_instruction->base, elem_ptr_instruction->array_ptr->child, nullptr); + if (!agg) + return ira->codegen->invalid_instruction; + + IrInstruction *result = ir_build_vector_elem(&ira->new_irb, elem_ptr_instruction->base.scope, + elem_ptr_instruction->base.source_node, + agg, elem_ptr_instruction->elem_index->child, elem_ptr_instruction->array_ptr->child); + result->value.type = array_type; + return result; } else { ir_add_error_node(ira, elem_ptr_instruction->base.source_node, buf_sprintf("array access of non-array type '%s'", buf_ptr(&array_type->name))); @@ -18212,6 +18344,51 @@ static IrInstruction *ir_analyze_instruction_field_ptr(IrAnalyze *ira, IrInstruc } } +static IrInstruction *ir_analyze_extract(IrAnalyze *ira, IrInstruction *source_instr, + IrInstruction *agg, IrInstruction *index_arg) { + ZigType *vector_type = agg->value.type; + assert(vector_type->id == ZigTypeIdVector); + ZigType *return_type = vector_type->data.vector.elem_type; + + IrInstruction *index = ir_implicit_cast(ira, index_arg, ira->codegen->builtin_types.entry_u32); + if (type_is_invalid(index->value.type)) { + ir_add_error(ira, source_instr, + buf_sprintf("vector element index invalid, expected '%s', got '%s'", + buf_ptr(&ira->codegen->builtin_types.entry_u32->name), + buf_ptr(&index->value.type->name))); + return ira->codegen->invalid_instruction; + } + + if (instr_is_comptime(index)) { + uint64_t index_int; + if (!ir_resolve_unsigned(ira, index, ira->codegen->builtin_types.entry_u32, &index_int)) + return ira->codegen->invalid_instruction; + + if (index_int >= vector_type->data.vector.len) { + ir_add_error_node(ira, index->source_node, + buf_sprintf("Vector index out of range. Max is %" ZIG_PRI_u64 ", got %" ZIG_PRI_u64 ".", + (uint64_t)vector_type->data.vector.len - 1, index_int)); + return ira->codegen->invalid_instruction; + } + + if (instr_is_comptime(agg)) { + IrInstruction *comptime_result = ir_const(ira, source_instr, return_type); + ConstExprValue *vector_val = ir_resolve_const(ira, agg, UndefOk); + comptime_result->value = vector_val->data.x_array.data.s_none.elements[index_int]; + return comptime_result; + } + } + + IrInstruction *result = ir_build_extract(&ira->new_irb, source_instr->scope, source_instr->source_node, + agg, index); + result->value.type = return_type; + return result; +} + +static IrInstruction *ir_analyze_instruction_extract(IrAnalyze *ira, IrInstructionExtract *instruction) { + return ir_analyze_extract(ira, &instruction->base, instruction->agg, instruction->index); +} + static IrInstruction *ir_analyze_instruction_store_ptr(IrAnalyze *ira, IrInstructionStorePtr *instruction) { IrInstruction *ptr = instruction->ptr->child; if (type_is_invalid(ptr->value.type)) @@ -18228,6 +18405,13 @@ static IrInstruction *ir_analyze_instruction_load_ptr(IrAnalyze *ira, IrInstruct IrInstruction *ptr = instruction->ptr->child; if (type_is_invalid(ptr->value.type)) return ira->codegen->invalid_instruction; + if (ptr->id == IrInstructionIdVectorElem) { + IrInstructionVectorElem *velem = (IrInstructionVectorElem *)ptr; + IrInstruction *result = ir_analyze_extract(ira, &velem->base, velem->agg, velem->index); + if (!result) + return ira->codegen->invalid_instruction; + return result; + } return ir_get_deref(ira, &instruction->base, ptr, nullptr); } @@ -25696,6 +25880,7 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction case IrInstructionIdTestErrGen: case IrInstructionIdFrameSizeGen: case IrInstructionIdAwaitGen: + case IrInstructionIdVectorElem: zig_unreachable(); case IrInstructionIdReturn: @@ -25714,6 +25899,10 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction return ir_analyze_instruction_store_ptr(ira, (IrInstructionStorePtr *)instruction); case IrInstructionIdElemPtr: return ir_analyze_instruction_elem_ptr(ira, (IrInstructionElemPtr *)instruction); + case IrInstructionIdExtract: + return ir_analyze_instruction_extract(ira, (IrInstructionExtract *)instruction); + case IrInstructionIdInsert: + return ir_analyze_instruction_insert(ira, (IrInstructionInsert *)instruction); case IrInstructionIdVarPtr: return ir_analyze_instruction_var_ptr(ira, (IrInstructionVarPtr *)instruction); case IrInstructionIdFieldPtr: @@ -26102,6 +26291,7 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdPtrType: case IrInstructionIdSetAlignStack: case IrInstructionIdExport: + case IrInstructionIdInsert: case IrInstructionIdSaveErrRetAddr: case IrInstructionIdAddImplicitReturnType: case IrInstructionIdAtomicRmw: @@ -26126,6 +26316,8 @@ bool ir_has_side_effects(IrInstruction *instruction) { case IrInstructionIdSpillBegin: return true; + case IrInstructionIdVectorElem: + case IrInstructionIdExtract: case IrInstructionIdPhi: case IrInstructionIdUnOp: case IrInstructionIdBinOp: diff --git a/src/ir_print.cpp b/src/ir_print.cpp index f869766e7a65..f0291cada01f 100644 --- a/src/ir_print.cpp +++ b/src/ir_print.cpp @@ -44,6 +44,12 @@ static const char* ir_instruction_type_str(IrInstruction* instruction) { return "Invalid"; case IrInstructionIdShuffleVector: return "Shuffle"; + case IrInstructionIdVectorElem: + return "VectorElem"; + case IrInstructionIdExtract: + return "Extract"; + case IrInstructionIdInsert: + return "Insert"; case IrInstructionIdDeclVarSrc: return "DeclVarSrc"; case IrInstructionIdDeclVarGen: @@ -768,6 +774,22 @@ static void ir_print_load_ptr_gen(IrPrint *irp, IrInstructionLoadPtrGen *instruc ir_print_other_instruction(irp, instruction->result_loc); } +static void ir_print_insert(IrPrint *irp, IrInstructionInsert *instruction) { + ir_print_var_instruction(irp, instruction->agg); + fprintf(irp->f, "["); + ir_print_var_instruction(irp, instruction->index); + fprintf(irp->f, "]"); + fprintf(irp->f, " = "); + ir_print_other_instruction(irp, instruction->value); +} + +static void ir_print_extract(IrPrint *irp, IrInstructionExtract *instruction) { + ir_print_var_instruction(irp, instruction->agg); + fprintf(irp->f, "["); + ir_print_var_instruction(irp, instruction->index); + fprintf(irp->f, "]"); +} + static void ir_print_store_ptr(IrPrint *irp, IrInstructionStorePtr *instruction) { fprintf(irp->f, "*"); ir_print_var_instruction(irp, instruction->ptr); @@ -2007,6 +2029,13 @@ static void ir_print_instruction(IrPrint *irp, IrInstruction *instruction, bool case IrInstructionIdStorePtr: ir_print_store_ptr(irp, (IrInstructionStorePtr *)instruction); break; + case IrInstructionIdExtract: + case IrInstructionIdVectorElem: + ir_print_extract(irp, (IrInstructionExtract *)instruction); + break; + case IrInstructionIdInsert: + ir_print_insert(irp, (IrInstructionInsert *)instruction); + break; case IrInstructionIdTypeOf: ir_print_typeof(irp, (IrInstructionTypeOf *)instruction); break; diff --git a/test/compile_errors.zig b/test/compile_errors.zig index 64c322e711ab..c7254f14c049 100644 --- a/test/compile_errors.zig +++ b/test/compile_errors.zig @@ -6507,6 +6507,16 @@ pub fn addCases(cases: *tests.CompileErrorContext) void { "tmp.zig:2:26: error: vector element type must be integer, float, bool, or pointer; '@Vector(4, u8)' is invalid", ); + cases.addTest( + "vector out-of bounds index", + \\export fn entry() void { + \\ var v: @Vector(4, u32) = [4]u32{0, 1, 2, 3}; + \\ v[5] = 5; + \\} + , + "tmp.zig:3:7: error: vector index out of range; max is 3, got 5", + ); + cases.add("compileLog of tagged enum doesn't crash the compiler", \\const Bar = union(enum(u32)) { \\ X: i32 = 1 diff --git a/test/stage1/behavior/vector.zig b/test/stage1/behavior/vector.zig index 27277b5e5224..4c24cab25493 100644 --- a/test/stage1/behavior/vector.zig +++ b/test/stage1/behavior/vector.zig @@ -138,3 +138,76 @@ test "vector casts of sizes not divisable by 8" { S.doTheTest(); comptime S.doTheTest(); } + +test "implicit cast vector to array - bool" { + const S = struct { + fn doTheTest() void { + const a: @Vector(4, bool) = [_]bool{ true, false, true, false }; + const result_array: [4]bool = a; + expect(mem.eql(bool, result_array, [4]bool{ true, false, true, false })); + } + }; + S.doTheTest(); + comptime S.doTheTest(); +} + +test "vector bin compares with mem.eql" { + const S = struct { + fn doTheTest() void { + var v: @Vector(4, i32) = [4]i32{ 2147483647, -2, 30, 40 }; + var x: @Vector(4, i32) = [4]i32{ 1, 2147483647, 30, 4 }; + expect(mem.eql(bool, ([4]bool)(v == x), [4]bool{ false, false, true, false})); + expect(mem.eql(bool, ([4]bool)(v != x), [4]bool{ true, true, false, true})); + expect(mem.eql(bool, ([4]bool)(v < x), [4]bool{ false, true, false, false})); + expect(mem.eql(bool, ([4]bool)(v > x), [4]bool{ true, false, false, true})); + expect(mem.eql(bool, ([4]bool)(v <= x), [4]bool{ false, true, true, false})); + expect(mem.eql(bool, ([4]bool)(v >= x), [4]bool{ true, false, true, true})); + } + }; + S.doTheTest(); + comptime S.doTheTest(); +} + +test "vector access elements - load" { + { + var a: @Vector(4, i32) = [_]i32{ 1, 2, 3, undefined }; + var i: u32 = 2; + expect(a[i] == 3); + expect(3 == a[2]); + i -= 1; + expect(a[i] == i32(2)); + } + + comptime { + comptime var a: @Vector(4, i32) = [_]i32{ 1, 2, 3, undefined }; + var i: u32 = 0; + expect(a[0] == 1); + i += 1; + expect(a[i] == i32(2)); + i += 1; + expect(3 == a[i]); + } +} + +test "vector access elements - store" { + { + var a: @Vector(4, i32) = [_]i32{ 1, 5, 3, undefined }; + var i: u32 = 2; + a[i] = 1; + expect(a[1] == 5); + expect(a[2] == i32(1)); + i += 1; + a[i] = -364; + expect(-364 == a[3]); + } + + comptime { + comptime var a: @Vector(4, i32) = [_]i32{ 1, 2, 3, undefined }; + var i: u32 = 2; + a[i] = 5; + expect(a[2] == i32(5)); + i += 1; + a[i] = -364; + expect(-364 == a[3]); + } +}