From fb73c7aadb3f5d0a0bed0bdbf79064bd44399bc7 Mon Sep 17 00:00:00 2001 From: mlugg Date: Sun, 28 May 2023 01:45:15 +0100 Subject: [PATCH] Sema: resolve union payload switch captures with peer type resolution This is a bit harder than it seems at first glance. Actually resolving the type is the easy part: the interesting thing is actually getting the capture value. We split this into three cases: * If all payload types are the same (as is required in status quo), we can just do what we already do: get the first field value. * If all payloads are in-memory coercible to the resolved type, we still fetch the first field, but we also emit a `bitcast` to convert to the resolved type. * Otherwise, we need to handle each case separately. We emit a nested `switch_br` which, for each possible case, gets the corresponding union field, and coerces it to the resolved type. As an optimization, the inner switch's 'else' prong is used for any peer which is in-memory coercible to the target type, and the bitcast approach described above is used. Pointer captures have the additional constraint that all payload types must be in-memory coercible to the resolved type. Resolves: #2812 --- src/Sema.zig | 274 ++++++++++++++++-- test/behavior/switch.zig | 68 +++++ .../switch_capture_incompatible_types.zig | 27 ++ 3 files changed, 338 insertions(+), 31 deletions(-) create mode 100644 test/cases/compile_errors/switch_capture_incompatible_types.zig diff --git a/src/Sema.zig b/src/Sema.zig index c34e11a7eede..04a0077d440f 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -2281,6 +2281,34 @@ fn failWithOwnedErrorMsg(sema: *Sema, err_msg: *Module.ErrorMsg) CompileError { return error.AnalysisFail; } +/// Given an ErrorMsg, modify its message and source location to the given values, turning the +/// original message into a note. Notes on the original message are preserved as further notes. +/// Reference trace is preserved. +fn reparentOwnedErrorMsg( + sema: *Sema, + block: *Block, + src: LazySrcLoc, + msg: *Module.ErrorMsg, + comptime format: []const u8, + args: anytype, +) !void { + const mod = sema.mod; + const src_decl = mod.declPtr(block.src_decl); + const resolved_src = src.toSrcLoc(src_decl); + const msg_str = try std.fmt.allocPrint(mod.gpa, format, args); + + const orig_notes = msg.notes.len; + msg.notes = try sema.gpa.realloc(msg.notes, orig_notes + 1); + std.mem.copyBackwards(Module.ErrorMsg, msg.notes[1..], msg.notes[0..orig_notes]); + msg.notes[0] = .{ + .src_loc = msg.src_loc, + .msg = msg.msg, + }; + + msg.src_loc = resolved_src; + msg.msg = msg_str; +} + const align_ty = Type.u29; fn analyzeAsAlign( @@ -10051,6 +10079,8 @@ const SwitchProngAnalysis = struct { operand: Air.Inst.Ref, /// May be `undefined` if no prong has a by-ref capture. operand_ptr: Air.Inst.Ref, + /// The switch condition value. For unions, `operand` is the union and `cond` is its tag. + cond: Air.Inst.Ref, /// If this switch is on an error set, this is the type to assign to the /// `else` prong. If `null`, the prong should be unreachable. else_error_ty: ?Type, @@ -10286,43 +10316,100 @@ const SwitchProngAnalysis = struct { const first_field_index = @intCast(u32, operand_ty.unionTagFieldIndex(first_item_val, sema.mod).?); const first_field = union_obj.fields.values()[first_field_index]; - for (case_vals[1..], 0..) |item, i| { + const field_tys = try sema.arena.alloc(Type, case_vals.len); + for (case_vals, field_tys) |item, *field_ty| { const item_val = sema.resolveConstValue(block, .unneeded, item, "") catch unreachable; + const field_idx = @intCast(u32, operand_ty.unionTagFieldIndex(item_val, sema.mod).?); + field_ty.* = union_obj.fields.values()[field_idx].ty; + } - const field_index = operand_ty.unionTagFieldIndex(item_val, sema.mod).?; - const field = union_obj.fields.values()[field_index]; - if (!field.ty.eql(first_field.ty, sema.mod)) { - const msg = msg: { - const capture_src = raw_capture_src.resolve(sema.gpa, sema.mod.declPtr(block.src_decl), switch_node_offset, .none); + // Fast path: if all the operands are the same type already, we don't need to hit + // PTR! This will also allow us to emit simpler code. + const same_types = for (field_tys[1..]) |field_ty| { + if (!field_ty.eql(field_tys[0], sema.mod)) break false; + } else true; - const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{}); - errdefer msg.destroy(sema.gpa); + const capture_ty = if (same_types) field_tys[0] else capture_ty: { + // We need values to run PTR on, so make a bunch of undef constants. + const dummy_captures = try sema.arena.alloc(Air.Inst.Ref, case_vals.len); + for (dummy_captures, field_tys) |*dummy, field_ty| { + dummy.* = try sema.addConstUndef(field_ty); + } + const case_srcs = try sema.arena.alloc(?LazySrcLoc, case_vals.len); + @memset(case_srcs, .unneeded); + + break :capture_ty sema.resolvePeerTypes(block, .unneeded, dummy_captures, .{ .override = case_srcs }) catch |err| switch (err) { + error.NeededSourceLocation => { // This must be a multi-prong so this must be a `multi_capture` src const multi_idx = raw_capture_src.multi_capture; + const src_decl_ptr = sema.mod.declPtr(block.src_decl); + for (case_srcs, 0..) |*case_src, i| { + const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, i) } }; + case_src.* = raw_case_src.resolve(sema.gpa, src_decl_ptr, switch_node_offset, .none); + } + const capture_src = raw_capture_src.resolve(sema.gpa, src_decl_ptr, switch_node_offset, .none); + _ = sema.resolvePeerTypes(block, capture_src, dummy_captures, .{ .override = case_srcs }) catch |err1| switch (err1) { + error.AnalysisFail => { + const msg = sema.err orelse return error.AnalysisFail; + try sema.reparentOwnedErrorMsg(block, capture_src, msg, "capture group with incompatible types", .{}); + return error.AnalysisFail; + }, + else => |e| return e, + }; + unreachable; + }, + else => |e| return e, + }; + }; - const raw_first_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 0 } }; - const first_item_src = raw_first_item_src.resolve(sema.gpa, sema.mod.declPtr(block.src_decl), switch_node_offset, .first); - const raw_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 1 + @intCast(u32, i) } }; - const item_src = raw_item_src.resolve(sema.gpa, sema.mod.declPtr(block.src_decl), switch_node_offset, .first); - try sema.errNote(block, first_item_src, msg, "type '{}' here", .{first_field.ty.fmt(sema.mod)}); - try sema.errNote(block, item_src, msg, "type '{}' here", .{field.ty.fmt(sema.mod)}); - break :msg msg; - }; - return sema.failWithOwnedErrorMsg(msg); - } - } - + // By-reference captures have some further restrictions which make them easier to emit if (capture_byref) { - const field_ty_ptr = try Type.ptr(sema.arena, sema.mod, .{ - .pointee_type = first_field.ty, - .@"addrspace" = .generic, - .mutable = operand_ptr_ty.ptrIsMutable(), + const operand_ptr_info = operand_ptr_ty.ptrInfo().data; + const capture_ptr_ty = try Type.ptr(sema.arena, sema.mod, .{ + .pointee_type = capture_ty, + .@"addrspace" = operand_ptr_info.@"addrspace", + .mutable = operand_ptr_info.mutable, + .@"volatile" = operand_ptr_info.@"volatile", + // TODO: alignment! }); + // By-ref captures of hetereogeneous types are only allowed if each field + // pointer type is in-memory coercible to the capture pointer type. + if (!same_types) { + for (field_tys, 0..) |field_ty, i| { + const field_ptr_ty = try Type.ptr(sema.arena, sema.mod, .{ + .pointee_type = field_ty, + .@"addrspace" = operand_ptr_info.@"addrspace", + .mutable = operand_ptr_info.mutable, + .@"volatile" = operand_ptr_info.@"volatile", + // TODO: alignment! + }); + if (.ok != try sema.coerceInMemoryAllowed(block, capture_ptr_ty, field_ptr_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) { + const multi_idx = raw_capture_src.multi_capture; + const src_decl_ptr = sema.mod.declPtr(block.src_decl); + const capture_src = raw_capture_src.resolve(sema.gpa, src_decl_ptr, switch_node_offset, .none); + const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, i) } }; + const case_src = raw_case_src.resolve(sema.gpa, src_decl_ptr, switch_node_offset, .none); + const msg = msg: { + const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{}); + errdefer msg.destroy(sema.gpa); + try sema.errNote(block, case_src, msg, "pointer type child '{}' cannot cast into resolved pointer type child '{}'", .{ + field_ty.fmt(sema.mod), + capture_ty.fmt(sema.mod), + }); + try sema.errNote(block, capture_src, msg, "this coercion is only possible when capturing by value", .{}); + break :msg msg; + }; + return sema.failWithOwnedErrorMsg(msg); + } + } + } + if (try sema.resolveDefinedValue(block, operand_src, spa.operand_ptr)) |op_ptr_val| { + if (op_ptr_val.isUndef()) return sema.addConstUndef(capture_ptr_ty); return sema.addConstant( - field_ty_ptr, + capture_ptr_ty, try Value.Tag.field_ptr.create(sema.arena, .{ .container_ptr = op_ptr_val, .container_ty = operand_ty, @@ -10330,18 +10417,142 @@ const SwitchProngAnalysis = struct { }), ); } + try sema.requireRuntimeBlock(block, operand_src, null); - return block.addStructFieldPtr(spa.operand_ptr, first_field_index, field_ty_ptr); + return block.addStructFieldPtr(spa.operand_ptr, first_field_index, capture_ptr_ty); } if (try sema.resolveDefinedValue(block, operand_src, spa.operand)) |operand_val| { - return sema.addConstant( - first_field.ty, - operand_val.castTag(.@"union").?.data.val, - ); + if (operand_val.isUndef()) return sema.addConstUndef(capture_ty); + const union_val = operand_val.castTag(.@"union").?.data; + if (union_val.tag.isUndef()) return sema.addConstUndef(capture_ty); + const active_field_idx = @intCast(u32, operand_ty.unionTagFieldIndex(union_val.tag, sema.mod).?); + const field_ty = union_obj.fields.values()[active_field_idx].ty; + const uncoerced = try sema.addConstant(field_ty, union_val.val); + return sema.coerce(block, capture_ty, uncoerced, operand_src); } + try sema.requireRuntimeBlock(block, operand_src, null); - return block.addStructFieldVal(spa.operand, first_field_index, first_field.ty); + + if (same_types) { + return block.addStructFieldVal(spa.operand, first_field_index, capture_ty); + } + + // We may have to emit a switch block which coerces the operand to the capture type. + // If we can, try to avoid that using in-memory coercions. + const first_non_imc = in_mem: { + for (field_tys, 0..) |field_ty, i| { + if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) { + break :in_mem i; + } + } + // All fields are in-memory coercible to the resolved type! + // Just take the first field and bitcast the result. + const uncoerced = try block.addStructFieldVal(spa.operand, first_field_index, first_field.ty); + return block.addBitCast(capture_ty, uncoerced); + }; + + // By-val capture with heterogeneous types which are not all in-memory coercible to + // the resolved capture type. We finally have to fall back to the ugly method. + + // However, let's first track which operands are in-memory coercible. There may well + // be several, and we can squash all of these cases into the same switch prong using + // a simple bitcast. We'll make this the 'else' prong. + + var in_mem_coercible = try std.DynamicBitSet.initFull(sema.arena, field_tys.len); + in_mem_coercible.unset(first_non_imc); + { + const next = first_non_imc + 1; + for (field_tys[next..], next..) |field_ty, i| { + if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) { + in_mem_coercible.unset(i); + } + } + } + + const capture_block_inst = try block.addInstAsIndex(.{ + .tag = .block, + .data = .{ + .ty_pl = .{ + .ty = try sema.addType(capture_ty), + .payload = undefined, // updated below + }, + }, + }); + + const prong_count = field_tys.len - in_mem_coercible.count(); + + const estimated_extra = prong_count * 6; // 2 for Case, 1 item, probably 3 insts + var cases_extra = try std.ArrayList(u32).initCapacity(sema.gpa, estimated_extra); + defer cases_extra.deinit(); + + { + // Non-bitcast cases + var it = in_mem_coercible.iterator(.{ .kind = .unset }); + while (it.next()) |idx| { + var coerce_block = block.makeSubBlock(); + defer coerce_block.instructions.deinit(sema.gpa); + + const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(u32, idx), field_tys[idx]); + const coerced = sema.coerce(&coerce_block, capture_ty, uncoerced, .unneeded) catch |err| switch (err) { + error.NeededSourceLocation => { + const multi_idx = raw_capture_src.multi_capture; + const src_decl_ptr = sema.mod.declPtr(block.src_decl); + const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, idx) } }; + const case_src = raw_case_src.resolve(sema.gpa, src_decl_ptr, switch_node_offset, .none); + _ = try sema.coerce(&coerce_block, capture_ty, uncoerced, case_src); + unreachable; + }, + else => |e| return e, + }; + _ = try coerce_block.addBr(capture_block_inst, coerced); + + try cases_extra.ensureUnusedCapacity(3 + coerce_block.instructions.items.len); + cases_extra.appendAssumeCapacity(1); // items_len + cases_extra.appendAssumeCapacity(@intCast(u32, coerce_block.instructions.items.len)); // body_len + cases_extra.appendAssumeCapacity(@enumToInt(case_vals[idx])); // item + cases_extra.appendSliceAssumeCapacity(coerce_block.instructions.items); // body + } + } + const else_body_len = len: { + // 'else' prong uses a bitcast + var coerce_block = block.makeSubBlock(); + defer coerce_block.instructions.deinit(sema.gpa); + + const first_imc = in_mem_coercible.findFirstSet().?; + const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(u32, first_imc), field_tys[first_imc]); + const coerced = try coerce_block.addBitCast(capture_ty, uncoerced); + _ = try coerce_block.addBr(capture_block_inst, coerced); + + try cases_extra.appendSlice(coerce_block.instructions.items); + break :len coerce_block.instructions.items.len; + }; + + try sema.air_extra.ensureUnusedCapacity(sema.gpa, @typeInfo(Air.SwitchBr).Struct.fields.len + + cases_extra.items.len + + @typeInfo(Air.Block).Struct.fields.len + + 1); + + const switch_br_inst = @intCast(u32, sema.air_instructions.len); + try sema.air_instructions.append(sema.gpa, .{ + .tag = .switch_br, + .data = .{ .pl_op = .{ + .operand = spa.cond, + .payload = sema.addExtraAssumeCapacity(Air.SwitchBr{ + .cases_len = @intCast(u32, prong_count), + .else_body_len = @intCast(u32, else_body_len), + }), + } }, + }); + sema.air_extra.appendSliceAssumeCapacity(cases_extra.items); + + // Set up block body + sema.air_instructions.items(.data)[capture_block_inst].ty_pl.payload = sema.addExtraAssumeCapacity(Air.Block{ + .body_len = 1, + }); + sema.air_extra.appendAssumeCapacity(switch_br_inst); + + return Air.indexToRef(capture_block_inst); }, .ErrorSet => { if (capture_byref) { @@ -11078,6 +11289,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r .parent_block = block, .operand = raw_operand.val, .operand_ptr = raw_operand.ptr, + .cond = operand, .else_error_ty = else_error_ty, .switch_block_inst = inst, .tag_capture_inst = tag_capture_inst, diff --git a/test/behavior/switch.zig b/test/behavior/switch.zig index 3f6cd3729873..72a36c988362 100644 --- a/test/behavior/switch.zig +++ b/test/behavior/switch.zig @@ -1,5 +1,6 @@ const builtin = @import("builtin"); const std = @import("std"); +const assert = std.debug.assert; const expect = std.testing.expect; const expectError = std.testing.expectError; const expectEqual = std.testing.expectEqual; @@ -717,3 +718,70 @@ test "comptime inline switch" { try expectEqual(u32, value); } + +test "switch capture peer type resolution" { + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; + + const U = union(enum) { + a: u32, + b: u64, + fn innerVal(u: @This()) u64 { + switch (u) { + .a, .b => |x| return x, + } + } + }; + + try expectEqual(@as(u64, 100), U.innerVal(.{ .a = 100 })); + try expectEqual(@as(u64, 200), U.innerVal(.{ .b = 200 })); +} + +test "switch capture peer type resolution for in-memory coercible payloads" { + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; + + const T1 = c_int; + const T2 = @Type(@typeInfo(T1)); + + comptime assert(T1 != T2); + + const U = union(enum) { + a: T1, + b: T2, + fn innerVal(u: @This()) c_int { + switch (u) { + .a, .b => |x| return x, + } + } + }; + + try expectEqual(@as(c_int, 100), U.innerVal(.{ .a = 100 })); + try expectEqual(@as(c_int, 200), U.innerVal(.{ .b = 200 })); +} + +test "switch pointer capture peer type resolution" { + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; + + const T1 = c_int; + const T2 = @Type(@typeInfo(T1)); + + comptime assert(T1 != T2); + + const U = union(enum) { + a: T1, + b: T2, + fn innerVal(u: *@This()) *c_int { + switch (u.*) { + .a, .b => |*ptr| return ptr, + } + } + }; + + var ua: U = .{ .a = 100 }; + var ub: U = .{ .b = 200 }; + + ua.innerVal().* = 111; + ub.innerVal().* = 222; + + try expectEqual(U{ .a = 111 }, ua); + try expectEqual(U{ .b = 222 }, ub); +} diff --git a/test/cases/compile_errors/switch_capture_incompatible_types.zig b/test/cases/compile_errors/switch_capture_incompatible_types.zig new file mode 100644 index 000000000000..b6de7d5bf5b3 --- /dev/null +++ b/test/cases/compile_errors/switch_capture_incompatible_types.zig @@ -0,0 +1,27 @@ +export fn f() void { + const U = union(enum) { a: u32, b: *u8 }; + var u: U = undefined; + switch (u) { + .a, .b => |val| _ = val, + } +} + +export fn g() void { + const U = union(enum) { a: u64, b: u32 }; + var u: U = undefined; + switch (u) { + .a, .b => |*ptr| _ = ptr, + } +} + +// error +// backend=stage2 +// target=native +// +// :5:20: error: capture group with incompatible types +// :5:20: note: incompatible types: 'u32' and '*u8' +// :5:10: note: type 'u32' here +// :5:14: note: type '*u8' here +// :13:20: error: capture group with incompatible types +// :13:14: note: pointer type child 'u32' cannot cast into resolved pointer type child 'u64' +// :13:20: note: this coercion is only possible when capturing by value