Skip to content

Commit

Permalink
Sema: resolve union payload switch captures with peer type resolution
Browse files Browse the repository at this point in the history
This is a bit harder than it seems at first glance. Actually resolving
the type is the easy part: the interesting thing is actually getting the
capture value. We split this into three cases:

* If all payload types are the same (as is required in status quo), we
  can just do what we already do: get the first field value.
* If all payloads are in-memory coercible to the resolved type, we still
  fetch the first field, but we also emit a `bitcast` to convert to the
  resolved type.
* Otherwise, we need to handle each case separately. We emit a nested
  `switch_br` which, for each possible case, gets the corresponding
  union field, and coerces it to the resolved type. As an optimization,
  the inner switch's 'else' prong is used for any peer which is
  in-memory coercible to the target type, and the bitcast approach
  described above is used.

Pointer captures have the additional constraint that all payload types
must be in-memory coercible to the resolved type.

Resolves: ziglang#2812
  • Loading branch information
mlugg committed May 29, 2023
1 parent eeca7c5 commit ba5333e
Show file tree
Hide file tree
Showing 3 changed files with 338 additions and 31 deletions.
274 changes: 243 additions & 31 deletions src/Sema.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2281,6 +2281,34 @@ fn failWithOwnedErrorMsg(sema: *Sema, err_msg: *Module.ErrorMsg) CompileError {
return error.AnalysisFail;
}

/// Given an ErrorMsg, modify its message and source location to the given values, turning the
/// original message into a note. Notes on the original message are preserved as further notes.
/// Reference trace is preserved.
fn reparentOwnedErrorMsg(
sema: *Sema,
block: *Block,
src: LazySrcLoc,
msg: *Module.ErrorMsg,
comptime format: []const u8,
args: anytype,
) !void {
const mod = sema.mod;
const src_decl = mod.declPtr(block.src_decl);
const resolved_src = src.toSrcLoc(src_decl);
const msg_str = try std.fmt.allocPrint(mod.gpa, format, args);

const orig_notes = msg.notes.len;
msg.notes = try sema.gpa.realloc(msg.notes, orig_notes + 1);
std.mem.copyBackwards(Module.ErrorMsg, msg.notes[1..], msg.notes[0..orig_notes]);
msg.notes[0] = .{
.src_loc = msg.src_loc,
.msg = msg.msg,
};

msg.src_loc = resolved_src;
msg.msg = msg_str;
}

const align_ty = Type.u29;

fn analyzeAsAlign(
Expand Down Expand Up @@ -10051,6 +10079,8 @@ const SwitchProngAnalysis = struct {
operand: Air.Inst.Ref,
/// May be `undefined` if no prong has a by-ref capture.
operand_ptr: Air.Inst.Ref,
/// The switch condition value. For unions, `operand` is the union and `cond` is its tag.
cond: Air.Inst.Ref,
/// If this switch is on an error set, this is the type to assign to the
/// `else` prong. If `null`, the prong should be unreachable.
else_error_ty: ?Type,
Expand Down Expand Up @@ -10286,62 +10316,243 @@ const SwitchProngAnalysis = struct {
const first_field_index = @intCast(u32, operand_ty.unionTagFieldIndex(first_item_val, sema.mod).?);
const first_field = union_obj.fields.values()[first_field_index];

for (case_vals[1..], 0..) |item, i| {
const field_tys = try sema.arena.alloc(Type, case_vals.len);
for (case_vals, field_tys) |item, *field_ty| {
const item_val = sema.resolveConstValue(block, .unneeded, item, "") catch unreachable;
const field_idx = @intCast(u32, operand_ty.unionTagFieldIndex(item_val, sema.mod).?);
field_ty.* = union_obj.fields.values()[field_idx].ty;
}

const field_index = operand_ty.unionTagFieldIndex(item_val, sema.mod).?;
const field = union_obj.fields.values()[field_index];
if (!field.ty.eql(first_field.ty, sema.mod)) {
const msg = msg: {
const capture_src = raw_capture_src.resolve(sema.gpa, sema.mod.declPtr(block.src_decl), switch_node_offset, .none);
// Fast path: if all the operands are the same type already, we don't need to hit
// PTR! This will also allow us to emit simpler code.
const same_types = for (field_tys[1..]) |field_ty| {
if (!field_ty.eql(field_tys[0], sema.mod)) break false;
} else true;

const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{});
errdefer msg.destroy(sema.gpa);
const capture_ty = if (same_types) field_tys[0] else capture_ty: {
// We need values to run PTR on, so make a bunch of undef constants.
const dummy_captures = try sema.arena.alloc(Air.Inst.Ref, case_vals.len);
for (dummy_captures, field_tys) |*dummy, field_ty| {
dummy.* = try sema.addConstUndef(field_ty);
}

const case_srcs = try sema.arena.alloc(?LazySrcLoc, case_vals.len);
@memset(case_srcs, .unneeded);

break :capture_ty sema.resolvePeerTypes(block, .unneeded, dummy_captures, .{ .override = case_srcs }) catch |err| switch (err) {
error.NeededSourceLocation => {
// This must be a multi-prong so this must be a `multi_capture` src
const multi_idx = raw_capture_src.multi_capture;
const src_decl_ptr = sema.mod.declPtr(block.src_decl);
for (case_srcs, 0..) |*case_src, i| {
const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, i) } };
case_src.* = raw_case_src.resolve(sema.gpa, src_decl_ptr, switch_node_offset, .none);
}
const capture_src = raw_capture_src.resolve(sema.gpa, src_decl_ptr, switch_node_offset, .none);
_ = sema.resolvePeerTypes(block, capture_src, dummy_captures, .{ .override = case_srcs }) catch |err1| switch (err1) {
error.AnalysisFail => {
const msg = sema.err orelse return error.AnalysisFail;
try sema.reparentOwnedErrorMsg(block, capture_src, msg, "capture group with incompatible types", .{});
return error.AnalysisFail;
},
else => |e| return e,
};
unreachable;
},
else => |e| return e,
};
};

const raw_first_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 0 } };
const first_item_src = raw_first_item_src.resolve(sema.gpa, sema.mod.declPtr(block.src_decl), switch_node_offset, .first);
const raw_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 1 + @intCast(u32, i) } };
const item_src = raw_item_src.resolve(sema.gpa, sema.mod.declPtr(block.src_decl), switch_node_offset, .first);
try sema.errNote(block, first_item_src, msg, "type '{}' here", .{first_field.ty.fmt(sema.mod)});
try sema.errNote(block, item_src, msg, "type '{}' here", .{field.ty.fmt(sema.mod)});
break :msg msg;
};
return sema.failWithOwnedErrorMsg(msg);
}
}

// By-reference captures have some further restrictions which make them easier to emit
if (capture_byref) {
const field_ty_ptr = try Type.ptr(sema.arena, sema.mod, .{
.pointee_type = first_field.ty,
.@"addrspace" = .generic,
.mutable = operand_ptr_ty.ptrIsMutable(),
const operand_ptr_info = operand_ptr_ty.ptrInfo().data;
const capture_ptr_ty = try Type.ptr(sema.arena, sema.mod, .{
.pointee_type = capture_ty,
.@"addrspace" = operand_ptr_info.@"addrspace",
.mutable = operand_ptr_info.mutable,
.@"volatile" = operand_ptr_info.@"volatile",
// TODO: alignment!
});

// By-ref captures of hetereogeneous types are only allowed if each field
// pointer type is in-memory coercible to the capture pointer type.
if (!same_types) {
for (field_tys, 0..) |field_ty, i| {
const field_ptr_ty = try Type.ptr(sema.arena, sema.mod, .{
.pointee_type = field_ty,
.@"addrspace" = operand_ptr_info.@"addrspace",
.mutable = operand_ptr_info.mutable,
.@"volatile" = operand_ptr_info.@"volatile",
// TODO: alignment!
});
if (.ok != try sema.coerceInMemoryAllowed(block, capture_ptr_ty, field_ptr_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) {
const multi_idx = raw_capture_src.multi_capture;
const src_decl_ptr = sema.mod.declPtr(block.src_decl);
const capture_src = raw_capture_src.resolve(sema.gpa, src_decl_ptr, switch_node_offset, .none);
const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, i) } };
const case_src = raw_case_src.resolve(sema.gpa, src_decl_ptr, switch_node_offset, .none);
const msg = msg: {
const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{});
errdefer msg.destroy(sema.gpa);
try sema.errNote(block, case_src, msg, "pointer type child '{}' cannot cast into resolved pointer type child '{}'", .{
field_ty.fmt(sema.mod),
capture_ty.fmt(sema.mod),
});
try sema.errNote(block, capture_src, msg, "this coercion is only possible when capturing by value", .{});
break :msg msg;
};
return sema.failWithOwnedErrorMsg(msg);
}
}
}

if (try sema.resolveDefinedValue(block, operand_src, spa.operand_ptr)) |op_ptr_val| {
if (op_ptr_val.isUndef()) return sema.addConstUndef(capture_ptr_ty);
return sema.addConstant(
field_ty_ptr,
capture_ptr_ty,
try Value.Tag.field_ptr.create(sema.arena, .{
.container_ptr = op_ptr_val,
.container_ty = operand_ty,
.field_index = first_field_index,
}),
);
}

try sema.requireRuntimeBlock(block, operand_src, null);
return block.addStructFieldPtr(spa.operand_ptr, first_field_index, field_ty_ptr);
return block.addStructFieldPtr(spa.operand_ptr, first_field_index, capture_ptr_ty);
}

if (try sema.resolveDefinedValue(block, operand_src, spa.operand)) |operand_val| {
return sema.addConstant(
first_field.ty,
operand_val.castTag(.@"union").?.data.val,
);
if (operand_val.isUndef()) return sema.addConstUndef(capture_ty);
const union_val = operand_val.castTag(.@"union").?.data;
if (union_val.tag.isUndef()) return sema.addConstUndef(capture_ty);
const active_field_idx = @intCast(u32, operand_ty.unionTagFieldIndex(union_val.tag, sema.mod).?);
const field_ty = union_obj.fields.values()[active_field_idx].ty;
const uncoerced = try sema.addConstant(field_ty, union_val.val);
return sema.coerce(block, capture_ty, uncoerced, operand_src);
}

try sema.requireRuntimeBlock(block, operand_src, null);
return block.addStructFieldVal(spa.operand, first_field_index, first_field.ty);

if (same_types) {
return block.addStructFieldVal(spa.operand, first_field_index, capture_ty);
}

// We may have to emit a switch block which coerces the operand to the capture type.
// If we can, try to avoid that using in-memory coercions.
const first_non_imc = in_mem: {
for (field_tys, 0..) |field_ty, i| {
if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) {
break :in_mem i;
}
}
// All fields are in-memory coercible to the resolved type!
// Just take the first field and bitcast the result.
const uncoerced = try block.addStructFieldVal(spa.operand, first_field_index, first_field.ty);
return block.addBitCast(capture_ty, uncoerced);
};

// By-val capture with heterogeneous types which are not all in-memory coercible to
// the resolved capture type. We finally have to fall back to the ugly method.

// However, let's first track which operands are in-memory coercible. There may well
// be several, and we can squash all of these cases into the same switch prong using
// a simple bitcast. We'll make this the 'else' prong.

var in_mem_coercible = try std.DynamicBitSet.initFull(sema.arena, field_tys.len);
in_mem_coercible.unset(first_non_imc);
{
const next = first_non_imc + 1;
for (field_tys[next..], next..) |field_ty, i| {
if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) {
in_mem_coercible.unset(i);
}
}
}

const capture_block_inst = try block.addInstAsIndex(.{
.tag = .block,
.data = .{
.ty_pl = .{
.ty = try sema.addType(capture_ty),
.payload = undefined, // updated below
},
},
});

const prong_count = field_tys.len - in_mem_coercible.count();

const estimated_extra = prong_count * 6; // 2 for Case, 1 item, probably 3 insts
var cases_extra = try std.ArrayList(u32).initCapacity(sema.gpa, estimated_extra);
defer cases_extra.deinit();

{
// Non-bitcast cases
var it = in_mem_coercible.iterator(.{ .kind = .unset });
while (it.next()) |idx| {
var coerce_block = block.makeSubBlock();
defer coerce_block.instructions.deinit(sema.gpa);

const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(u32, idx), field_tys[idx]);
const coerced = sema.coerce(&coerce_block, capture_ty, uncoerced, .unneeded) catch |err| switch (err) {
error.NeededSourceLocation => {
const multi_idx = raw_capture_src.multi_capture;
const src_decl_ptr = sema.mod.declPtr(block.src_decl);
const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, idx) } };
const case_src = raw_case_src.resolve(sema.gpa, src_decl_ptr, switch_node_offset, .none);
_ = try sema.coerce(&coerce_block, capture_ty, uncoerced, case_src);
unreachable;
},
else => |e| return e,
};
_ = try coerce_block.addBr(capture_block_inst, coerced);

try cases_extra.ensureUnusedCapacity(3 + coerce_block.instructions.items.len);
cases_extra.appendAssumeCapacity(1); // items_len
cases_extra.appendAssumeCapacity(@intCast(u32, coerce_block.instructions.items.len)); // body_len
cases_extra.appendAssumeCapacity(@enumToInt(case_vals[idx])); // item
cases_extra.appendSliceAssumeCapacity(coerce_block.instructions.items); // body
}
}
const else_body_len = len: {
// 'else' prong uses a bitcast
var coerce_block = block.makeSubBlock();
defer coerce_block.instructions.deinit(sema.gpa);

const first_imc = in_mem_coercible.findFirstSet().?;
const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(u32, first_imc), field_tys[first_imc]);
const coerced = try coerce_block.addBitCast(capture_ty, uncoerced);
_ = try coerce_block.addBr(capture_block_inst, coerced);

try cases_extra.appendSlice(coerce_block.instructions.items);
break :len coerce_block.instructions.items.len;
};

try sema.air_extra.ensureUnusedCapacity(sema.gpa, @typeInfo(Air.SwitchBr).Struct.fields.len +
cases_extra.items.len +
@typeInfo(Air.Block).Struct.fields.len +
1);

const switch_br_inst = @intCast(u32, sema.air_instructions.len);
try sema.air_instructions.append(sema.gpa, .{
.tag = .switch_br,
.data = .{ .pl_op = .{
.operand = spa.cond,
.payload = sema.addExtraAssumeCapacity(Air.SwitchBr{
.cases_len = @intCast(u32, prong_count),
.else_body_len = @intCast(u32, else_body_len),
}),
} },
});
sema.air_extra.appendSliceAssumeCapacity(cases_extra.items);

// Set up block body
sema.air_instructions.items(.data)[capture_block_inst].ty_pl.payload = sema.addExtraAssumeCapacity(Air.Block{
.body_len = 1,
});
sema.air_extra.appendAssumeCapacity(switch_br_inst);

return Air.indexToRef(capture_block_inst);
},
.ErrorSet => {
if (capture_byref) {
Expand Down Expand Up @@ -11078,6 +11289,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r
.parent_block = block,
.operand = raw_operand.val,
.operand_ptr = raw_operand.ptr,
.cond = operand,
.else_error_ty = else_error_ty,
.switch_block_inst = inst,
.tag_capture_inst = tag_capture_inst,
Expand Down
Loading

0 comments on commit ba5333e

Please sign in to comment.