Skip to content

Commit

Permalink
Array accesses on vectors
Browse files Browse the repository at this point in the history
Lots of pain was caused by a5cb0f7
  • Loading branch information
shawnl committed Sep 8, 2019
1 parent d603ecd commit 169063d
Show file tree
Hide file tree
Showing 6 changed files with 364 additions and 2 deletions.
26 changes: 26 additions & 0 deletions src/all_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2370,6 +2370,9 @@ enum IrInstructionId {
IrInstructionIdLoadPtr,
IrInstructionIdLoadPtrGen,
IrInstructionIdStorePtr,
IrInstructionIdVectorElem,
IrInstructionIdExtract,
IrInstructionIdInsert,
IrInstructionIdFieldPtr,
IrInstructionIdStructFieldPtr,
IrInstructionIdUnionFieldPtr,
Expand Down Expand Up @@ -2689,6 +2692,29 @@ struct IrInstructionLoadPtr {
IrInstruction *ptr;
};

struct IrInstructionVectorElem {
IrInstruction base;

IrInstruction *agg;
IrInstruction *index;
IrInstruction *result_loc;
};

struct IrInstructionExtract {
IrInstruction base;

IrInstruction *agg;
IrInstruction *index;
};

struct IrInstructionInsert {
IrInstruction base;

IrInstruction *agg;
IrInstruction *index;
IrInstruction *value;
};

struct IrInstructionLoadPtrGen {
IrInstruction base;

Expand Down
34 changes: 33 additions & 1 deletion src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3350,6 +3350,30 @@ static LLVMValueRef ir_render_decl_var(CodeGen *g, IrExecutable *executable, IrI
return nullptr;
}

static LLVMValueRef ir_render_extract(CodeGen *g, IrExecutable *executable, IrInstructionExtract *instruction) {
assert(instruction->index);
ZigType *child_type = instruction->base.value.type;
if (!type_has_bits(child_type))
return nullptr;

LLVMValueRef agg = ir_llvm_value(g, instruction->agg);
ZigType *agg_type = instruction->agg->value.type;
assert(agg_type->id == ZigTypeIdVector && "Arrays not yet implemented");

return LLVMBuildExtractElement(g->builder, agg, ir_llvm_value(g, instruction->index), "");
}

static LLVMValueRef ir_render_insert(CodeGen *g, IrExecutable *executable, IrInstructionInsert *instruction) {
assert(instruction->index);

LLVMValueRef agg = ir_llvm_value(g, instruction->agg);
ZigType *agg_type = instruction->agg->value.type;
assert(agg_type->id == ZigTypeIdVector && "Arrays not yet implemented");

return LLVMBuildInsertElement(g->builder,
agg, ir_llvm_value(g, instruction->value), ir_llvm_value(g, instruction->index), "");
}

static LLVMValueRef ir_render_load_ptr(CodeGen *g, IrExecutable *executable, IrInstructionLoadPtrGen *instruction) {
ZigType *child_type = instruction->base.value.type;
if (!type_has_bits(child_type))
Expand Down Expand Up @@ -3591,7 +3615,8 @@ static LLVMValueRef ir_render_return_ptr(CodeGen *g, IrExecutable *executable,
static LLVMValueRef ir_render_elem_ptr(CodeGen *g, IrExecutable *executable, IrInstructionElemPtr *instruction) {
LLVMValueRef array_ptr_ptr = ir_llvm_value(g, instruction->array_ptr);
ZigType *array_ptr_type = instruction->array_ptr->value.type;
assert(array_ptr_type->id == ZigTypeIdPointer);
if (array_ptr_type->id != ZigTypeIdPointer)
return ir_llvm_value(g, instruction->array_ptr);
ZigType *array_type = array_ptr_type->data.pointer.child_type;
LLVMValueRef array_ptr = get_handle_value(g, array_ptr_ptr, array_type, array_ptr_type);
LLVMValueRef subscript_value = ir_llvm_value(g, instruction->elem_index);
Expand Down Expand Up @@ -5906,6 +5931,7 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
case IrInstructionIdFrameSizeSrc:
case IrInstructionIdAllocaGen:
case IrInstructionIdAwaitSrc:
case IrInstructionIdVectorElem:
zig_unreachable();

case IrInstructionIdDeclVarGen:
Expand All @@ -5924,6 +5950,10 @@ static LLVMValueRef ir_render_instruction(CodeGen *g, IrExecutable *executable,
return ir_render_br(g, executable, (IrInstructionBr *)instruction);
case IrInstructionIdUnOp:
return ir_render_un_op(g, executable, (IrInstructionUnOp *)instruction);
case IrInstructionIdExtract:
return ir_render_extract(g, executable, (IrInstructionExtract *)instruction);
case IrInstructionIdInsert:
return ir_render_insert(g, executable, (IrInstructionInsert *)instruction);
case IrInstructionIdLoadPtrGen:
return ir_render_load_ptr(g, executable, (IrInstructionLoadPtrGen *)instruction);
case IrInstructionIdStorePtr:
Expand Down Expand Up @@ -6085,6 +6115,8 @@ static void ir_render(CodeGen *g, ZigFn *fn_entry) {
LLVMPositionBuilderAtEnd(g->builder, current_block->llvm_block);
for (size_t instr_i = 0; instr_i < current_block->instruction_list.length; instr_i += 1) {
IrInstruction *instruction = current_block->instruction_list.at(instr_i);
if (instruction->id == IrInstructionIdLoadPtr)
abort();
if (instruction->ref_count == 0 && !ir_has_side_effects(instruction))
continue;
if (get_scope_typeof(instruction->scope) != nullptr)
Expand Down
194 changes: 193 additions & 1 deletion src/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -485,6 +485,18 @@ static constexpr IrInstructionId ir_instruction_id(IrInstructionStorePtr *) {
return IrInstructionIdStorePtr;
}

static constexpr IrInstructionId ir_instruction_id(IrInstructionExtract *) {
return IrInstructionIdExtract;
}

static constexpr IrInstructionId ir_instruction_id(IrInstructionVectorElem *) {
return IrInstructionIdVectorElem;
}

static constexpr IrInstructionId ir_instruction_id(IrInstructionInsert *) {
return IrInstructionIdInsert;
}

static constexpr IrInstructionId ir_instruction_id(IrInstructionFieldPtr *) {
return IrInstructionIdFieldPtr;
}
Expand Down Expand Up @@ -1667,6 +1679,47 @@ static IrInstruction *ir_build_load_ptr(IrBuilder *irb, Scope *scope, AstNode *s
return &instruction->base;
}

static IrInstruction *ir_build_vector_elem(IrBuilder *irb, Scope *scope, AstNode *source_node,
IrInstruction *agg, IrInstruction *index, IrInstruction *result_loc) {
IrInstructionVectorElem *instruction = ir_build_instruction<IrInstructionVectorElem>(irb, scope, source_node);
instruction->agg = agg;
instruction->index = index;
instruction->result_loc = result_loc;

ir_ref_instruction(agg, irb->current_basic_block);
ir_ref_instruction(index, irb->current_basic_block);
if (result_loc != nullptr)
ir_ref_instruction(result_loc, irb->current_basic_block);

return &instruction->base;
}

static IrInstruction *ir_build_extract(IrBuilder *irb, Scope *scope, AstNode *source_node,
IrInstruction *agg, IrInstruction *index) {
IrInstructionExtract *instruction = ir_build_instruction<IrInstructionExtract>(irb, scope, source_node);
instruction->agg = agg;
instruction->index = index;

ir_ref_instruction(agg, irb->current_basic_block);
ir_ref_instruction(index, irb->current_basic_block);

return &instruction->base;
}

static IrInstruction *ir_build_insert(IrBuilder *irb, Scope *scope, AstNode *source_node,
IrInstruction *agg, IrInstruction *index, IrInstruction *value) {
IrInstructionInsert *instruction = ir_build_instruction<IrInstructionInsert>(irb, scope, source_node);
instruction->agg = agg;
instruction->index = index;
instruction->value = value;

ir_ref_instruction(agg, irb->current_basic_block);
ir_ref_instruction(index, irb->current_basic_block);
ir_ref_instruction(value, irb->current_basic_block);

return &instruction->base;
}

static IrInstruction *ir_build_typeof(IrBuilder *irb, Scope *scope, AstNode *source_node, IrInstruction *value) {
IrInstructionTypeOf *instruction = ir_build_instruction<IrInstructionTypeOf>(irb, scope, source_node);
instruction->value = value;
Expand Down Expand Up @@ -15352,7 +15405,8 @@ static IrInstruction *ir_resolve_result(IrAnalyze *ira, IrInstruction *suspend_s
result_loc->value.special = ConstValSpecialRuntime;
}

ir_assert(result_loc->value.type->id == ZigTypeIdPointer, suspend_source_instr);
ir_assert(result_loc->value.type->id == ZigTypeIdPointer ||
result_loc->id == IrInstructionIdVectorElem, suspend_source_instr);
ZigType *actual_elem_type = result_loc->value.type->data.pointer.child_type;
if (actual_elem_type->id == ZigTypeIdOptional && value_type->id != ZigTypeIdOptional &&
value_type->id != ZigTypeIdNull)
Expand Down Expand Up @@ -15739,9 +15793,77 @@ static void mark_comptime_value_escape(IrAnalyze *ira, IrInstruction *source_ins
}
}

static IrInstruction *ir_analyze_insert(IrAnalyze *ira, IrInstruction *source_instr,
IrInstruction *agg, IrInstruction *index_arg, IrInstruction *value_arg) {
ZigType *vector_type = agg->value.type;

IrInstruction *index = ir_implicit_cast(ira, index_arg, ira->codegen->builtin_types.entry_u32);
if (type_is_invalid(index->value.type)) {
ir_add_error(ira, source_instr,
buf_sprintf("vector element index invalid, expected '%s', got '%s'",
buf_ptr(&ira->codegen->builtin_types.entry_u32->name),
buf_ptr(&index_arg->value.type->name)));
return ira->codegen->invalid_instruction;
}

IrInstruction *value = ir_implicit_cast(ira, value_arg, vector_type->data.vector.elem_type);
if (type_is_invalid(value->value.type)) {
ir_add_error(ira, source_instr,
buf_sprintf("element to insert of invalid type, expected '%s', got '%s'",
buf_ptr(&vector_type->data.vector.elem_type->name),
buf_ptr(&value_arg->value.type->name)));
return ira->codegen->invalid_instruction;
}

if (instr_is_comptime(index)) {
uint64_t index_int;
if (!ir_resolve_unsigned(ira, index, ira->codegen->builtin_types.entry_u32, &index_int))
return ira->codegen->invalid_instruction;

if (index_int >= vector_type->data.vector.len) {
ir_add_error_node(ira, index->source_node,
buf_sprintf("vector index out of range; max is %" ZIG_PRI_u64 ", got %" ZIG_PRI_u64,
(uint64_t)vector_type->data.vector.len - 1, index_int));
return ira->codegen->invalid_instruction;
}

if (instr_is_comptime(agg) && instr_is_comptime(index)) {
IrInstruction *result = ir_const(ira, source_instr, agg->value.type);
result->value = agg->value;
agg->value.data.x_array.data.s_none.elements[index_int] = value->value;
result->value.special = ConstValSpecialStatic;
return result;
}
}

IrInstruction *result = ir_build_insert(&ira->new_irb, source_instr->scope,
source_instr->source_node, agg, index, value);
result->value.type = ira->codegen->builtin_types.entry_void;
result->value.special = ConstValSpecialRuntime;
return result;
}

static IrInstruction *ir_analyze_instruction_insert(IrAnalyze *ira, IrInstructionInsert *instruction) {
return ir_analyze_insert(ira, &instruction->base, instruction->agg, instruction->index, instruction->value);
}

static IrInstruction *ir_analyze_store_ptr(IrAnalyze *ira, IrInstruction *source_instr,
IrInstruction *ptr, IrInstruction *uncasted_value, bool allow_write_through_const)
{
if (ptr->id == IrInstructionIdVectorElem) {
IrInstructionVectorElem *velem = (IrInstructionVectorElem *)ptr;
IrInstruction *result = ir_analyze_insert(ira, source_instr,
velem->agg, velem->index, uncasted_value);
if (type_is_invalid(result->value.type))
return ira->codegen->invalid_instruction;

result->value.type = velem->agg->value.type;
IrInstruction *store_back_because_ssa_is_a_bad_idea_apparently = ir_analyze_store_ptr(ira, source_instr,
velem->result_loc, result, false);
if (type_is_invalid(store_back_because_ssa_is_a_bad_idea_apparently->value.type))
return ira->codegen->invalid_instruction;
return result;
}
assert(ptr->value.type->id == ZigTypeIdPointer);

if (ptr->value.data.x_ptr.special == ConstPtrSpecialDiscard) {
Expand Down Expand Up @@ -17243,6 +17365,16 @@ static IrInstruction *ir_analyze_instruction_elem_ptr(IrAnalyze *ira, IrInstruct
return ir_get_const_ptr(ira, &elem_ptr_instruction->base, &ira->codegen->const_void_val,
ira->codegen->builtin_types.entry_void, ConstPtrMutComptimeConst, is_const, is_volatile, 0);
}
} else if (array_type->id == ZigTypeIdVector) {
IrInstruction *agg = ir_get_deref(ira, &elem_ptr_instruction->base, elem_ptr_instruction->array_ptr->child, nullptr);
if (!agg)
return ira->codegen->invalid_instruction;

IrInstruction *result = ir_build_vector_elem(&ira->new_irb, elem_ptr_instruction->base.scope,
elem_ptr_instruction->base.source_node,
agg, elem_ptr_instruction->elem_index->child, elem_ptr_instruction->array_ptr->child);
result->value.type = array_type;
return result;
} else {
ir_add_error_node(ira, elem_ptr_instruction->base.source_node,
buf_sprintf("array access of non-array type '%s'", buf_ptr(&array_type->name)));
Expand Down Expand Up @@ -18212,6 +18344,51 @@ static IrInstruction *ir_analyze_instruction_field_ptr(IrAnalyze *ira, IrInstruc
}
}

static IrInstruction *ir_analyze_extract(IrAnalyze *ira, IrInstruction *source_instr,
IrInstruction *agg, IrInstruction *index_arg) {
ZigType *vector_type = agg->value.type;
assert(vector_type->id == ZigTypeIdVector);
ZigType *return_type = vector_type->data.vector.elem_type;

IrInstruction *index = ir_implicit_cast(ira, index_arg, ira->codegen->builtin_types.entry_u32);
if (type_is_invalid(index->value.type)) {
ir_add_error(ira, source_instr,
buf_sprintf("vector element index invalid, expected '%s', got '%s'",
buf_ptr(&ira->codegen->builtin_types.entry_u32->name),
buf_ptr(&index->value.type->name)));
return ira->codegen->invalid_instruction;
}

if (instr_is_comptime(index)) {
uint64_t index_int;
if (!ir_resolve_unsigned(ira, index, ira->codegen->builtin_types.entry_u32, &index_int))
return ira->codegen->invalid_instruction;

if (index_int >= vector_type->data.vector.len) {
ir_add_error_node(ira, index->source_node,
buf_sprintf("Vector index out of range. Max is %" ZIG_PRI_u64 ", got %" ZIG_PRI_u64 ".",
(uint64_t)vector_type->data.vector.len - 1, index_int));
return ira->codegen->invalid_instruction;
}

if (instr_is_comptime(agg)) {
IrInstruction *comptime_result = ir_const(ira, source_instr, return_type);
ConstExprValue *vector_val = ir_resolve_const(ira, agg, UndefOk);
comptime_result->value = vector_val->data.x_array.data.s_none.elements[index_int];
return comptime_result;
}
}

IrInstruction *result = ir_build_extract(&ira->new_irb, source_instr->scope, source_instr->source_node,
agg, index);
result->value.type = return_type;
return result;
}

static IrInstruction *ir_analyze_instruction_extract(IrAnalyze *ira, IrInstructionExtract *instruction) {
return ir_analyze_extract(ira, &instruction->base, instruction->agg, instruction->index);
}

static IrInstruction *ir_analyze_instruction_store_ptr(IrAnalyze *ira, IrInstructionStorePtr *instruction) {
IrInstruction *ptr = instruction->ptr->child;
if (type_is_invalid(ptr->value.type))
Expand All @@ -18228,6 +18405,13 @@ static IrInstruction *ir_analyze_instruction_load_ptr(IrAnalyze *ira, IrInstruct
IrInstruction *ptr = instruction->ptr->child;
if (type_is_invalid(ptr->value.type))
return ira->codegen->invalid_instruction;
if (ptr->id == IrInstructionIdVectorElem) {
IrInstructionVectorElem *velem = (IrInstructionVectorElem *)ptr;
IrInstruction *result = ir_analyze_extract(ira, &velem->base, velem->agg, velem->index);
if (!result)
return ira->codegen->invalid_instruction;
return result;
}
return ir_get_deref(ira, &instruction->base, ptr, nullptr);
}

Expand Down Expand Up @@ -25696,6 +25880,7 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction
case IrInstructionIdTestErrGen:
case IrInstructionIdFrameSizeGen:
case IrInstructionIdAwaitGen:
case IrInstructionIdVectorElem:
zig_unreachable();

case IrInstructionIdReturn:
Expand All @@ -25714,6 +25899,10 @@ static IrInstruction *ir_analyze_instruction_base(IrAnalyze *ira, IrInstruction
return ir_analyze_instruction_store_ptr(ira, (IrInstructionStorePtr *)instruction);
case IrInstructionIdElemPtr:
return ir_analyze_instruction_elem_ptr(ira, (IrInstructionElemPtr *)instruction);
case IrInstructionIdExtract:
return ir_analyze_instruction_extract(ira, (IrInstructionExtract *)instruction);
case IrInstructionIdInsert:
return ir_analyze_instruction_insert(ira, (IrInstructionInsert *)instruction);
case IrInstructionIdVarPtr:
return ir_analyze_instruction_var_ptr(ira, (IrInstructionVarPtr *)instruction);
case IrInstructionIdFieldPtr:
Expand Down Expand Up @@ -26102,6 +26291,7 @@ bool ir_has_side_effects(IrInstruction *instruction) {
case IrInstructionIdPtrType:
case IrInstructionIdSetAlignStack:
case IrInstructionIdExport:
case IrInstructionIdInsert:
case IrInstructionIdSaveErrRetAddr:
case IrInstructionIdAddImplicitReturnType:
case IrInstructionIdAtomicRmw:
Expand All @@ -26126,6 +26316,8 @@ bool ir_has_side_effects(IrInstruction *instruction) {
case IrInstructionIdSpillBegin:
return true;

case IrInstructionIdVectorElem:
case IrInstructionIdExtract:
case IrInstructionIdPhi:
case IrInstructionIdUnOp:
case IrInstructionIdBinOp:
Expand Down
Loading

0 comments on commit 169063d

Please sign in to comment.