diff --git a/src/Sema.zig b/src/Sema.zig index c34e11a7eede..04a0077d440f 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -2281,6 +2281,34 @@ fn failWithOwnedErrorMsg(sema: *Sema, err_msg: *Module.ErrorMsg) CompileError { return error.AnalysisFail; } +/// Given an ErrorMsg, modify its message and source location to the given values, turning the +/// original message into a note. Notes on the original message are preserved as further notes. +/// Reference trace is preserved. +fn reparentOwnedErrorMsg( + sema: *Sema, + block: *Block, + src: LazySrcLoc, + msg: *Module.ErrorMsg, + comptime format: []const u8, + args: anytype, +) !void { + const mod = sema.mod; + const src_decl = mod.declPtr(block.src_decl); + const resolved_src = src.toSrcLoc(src_decl); + const msg_str = try std.fmt.allocPrint(mod.gpa, format, args); + + const orig_notes = msg.notes.len; + msg.notes = try sema.gpa.realloc(msg.notes, orig_notes + 1); + std.mem.copyBackwards(Module.ErrorMsg, msg.notes[1..], msg.notes[0..orig_notes]); + msg.notes[0] = .{ + .src_loc = msg.src_loc, + .msg = msg.msg, + }; + + msg.src_loc = resolved_src; + msg.msg = msg_str; +} + const align_ty = Type.u29; fn analyzeAsAlign( @@ -10051,6 +10079,8 @@ const SwitchProngAnalysis = struct { operand: Air.Inst.Ref, /// May be `undefined` if no prong has a by-ref capture. operand_ptr: Air.Inst.Ref, + /// The switch condition value. For unions, `operand` is the union and `cond` is its tag. + cond: Air.Inst.Ref, /// If this switch is on an error set, this is the type to assign to the /// `else` prong. If `null`, the prong should be unreachable. else_error_ty: ?Type, @@ -10286,43 +10316,100 @@ const SwitchProngAnalysis = struct { const first_field_index = @intCast(u32, operand_ty.unionTagFieldIndex(first_item_val, sema.mod).?); const first_field = union_obj.fields.values()[first_field_index]; - for (case_vals[1..], 0..) |item, i| { + const field_tys = try sema.arena.alloc(Type, case_vals.len); + for (case_vals, field_tys) |item, *field_ty| { const item_val = sema.resolveConstValue(block, .unneeded, item, "") catch unreachable; + const field_idx = @intCast(u32, operand_ty.unionTagFieldIndex(item_val, sema.mod).?); + field_ty.* = union_obj.fields.values()[field_idx].ty; + } - const field_index = operand_ty.unionTagFieldIndex(item_val, sema.mod).?; - const field = union_obj.fields.values()[field_index]; - if (!field.ty.eql(first_field.ty, sema.mod)) { - const msg = msg: { - const capture_src = raw_capture_src.resolve(sema.gpa, sema.mod.declPtr(block.src_decl), switch_node_offset, .none); + // Fast path: if all the operands are the same type already, we don't need to hit + // PTR! This will also allow us to emit simpler code. + const same_types = for (field_tys[1..]) |field_ty| { + if (!field_ty.eql(field_tys[0], sema.mod)) break false; + } else true; - const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{}); - errdefer msg.destroy(sema.gpa); + const capture_ty = if (same_types) field_tys[0] else capture_ty: { + // We need values to run PTR on, so make a bunch of undef constants. + const dummy_captures = try sema.arena.alloc(Air.Inst.Ref, case_vals.len); + for (dummy_captures, field_tys) |*dummy, field_ty| { + dummy.* = try sema.addConstUndef(field_ty); + } + const case_srcs = try sema.arena.alloc(?LazySrcLoc, case_vals.len); + @memset(case_srcs, .unneeded); + + break :capture_ty sema.resolvePeerTypes(block, .unneeded, dummy_captures, .{ .override = case_srcs }) catch |err| switch (err) { + error.NeededSourceLocation => { // This must be a multi-prong so this must be a `multi_capture` src const multi_idx = raw_capture_src.multi_capture; + const src_decl_ptr = sema.mod.declPtr(block.src_decl); + for (case_srcs, 0..) |*case_src, i| { + const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, i) } }; + case_src.* = raw_case_src.resolve(sema.gpa, src_decl_ptr, switch_node_offset, .none); + } + const capture_src = raw_capture_src.resolve(sema.gpa, src_decl_ptr, switch_node_offset, .none); + _ = sema.resolvePeerTypes(block, capture_src, dummy_captures, .{ .override = case_srcs }) catch |err1| switch (err1) { + error.AnalysisFail => { + const msg = sema.err orelse return error.AnalysisFail; + try sema.reparentOwnedErrorMsg(block, capture_src, msg, "capture group with incompatible types", .{}); + return error.AnalysisFail; + }, + else => |e| return e, + }; + unreachable; + }, + else => |e| return e, + }; + }; - const raw_first_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 0 } }; - const first_item_src = raw_first_item_src.resolve(sema.gpa, sema.mod.declPtr(block.src_decl), switch_node_offset, .first); - const raw_item_src = Module.SwitchProngSrc{ .multi = .{ .prong = multi_idx, .item = 1 + @intCast(u32, i) } }; - const item_src = raw_item_src.resolve(sema.gpa, sema.mod.declPtr(block.src_decl), switch_node_offset, .first); - try sema.errNote(block, first_item_src, msg, "type '{}' here", .{first_field.ty.fmt(sema.mod)}); - try sema.errNote(block, item_src, msg, "type '{}' here", .{field.ty.fmt(sema.mod)}); - break :msg msg; - }; - return sema.failWithOwnedErrorMsg(msg); - } - } - + // By-reference captures have some further restrictions which make them easier to emit if (capture_byref) { - const field_ty_ptr = try Type.ptr(sema.arena, sema.mod, .{ - .pointee_type = first_field.ty, - .@"addrspace" = .generic, - .mutable = operand_ptr_ty.ptrIsMutable(), + const operand_ptr_info = operand_ptr_ty.ptrInfo().data; + const capture_ptr_ty = try Type.ptr(sema.arena, sema.mod, .{ + .pointee_type = capture_ty, + .@"addrspace" = operand_ptr_info.@"addrspace", + .mutable = operand_ptr_info.mutable, + .@"volatile" = operand_ptr_info.@"volatile", + // TODO: alignment! }); + // By-ref captures of hetereogeneous types are only allowed if each field + // pointer type is in-memory coercible to the capture pointer type. + if (!same_types) { + for (field_tys, 0..) |field_ty, i| { + const field_ptr_ty = try Type.ptr(sema.arena, sema.mod, .{ + .pointee_type = field_ty, + .@"addrspace" = operand_ptr_info.@"addrspace", + .mutable = operand_ptr_info.mutable, + .@"volatile" = operand_ptr_info.@"volatile", + // TODO: alignment! + }); + if (.ok != try sema.coerceInMemoryAllowed(block, capture_ptr_ty, field_ptr_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) { + const multi_idx = raw_capture_src.multi_capture; + const src_decl_ptr = sema.mod.declPtr(block.src_decl); + const capture_src = raw_capture_src.resolve(sema.gpa, src_decl_ptr, switch_node_offset, .none); + const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, i) } }; + const case_src = raw_case_src.resolve(sema.gpa, src_decl_ptr, switch_node_offset, .none); + const msg = msg: { + const msg = try sema.errMsg(block, capture_src, "capture group with incompatible types", .{}); + errdefer msg.destroy(sema.gpa); + try sema.errNote(block, case_src, msg, "pointer type child '{}' cannot cast into resolved pointer type child '{}'", .{ + field_ty.fmt(sema.mod), + capture_ty.fmt(sema.mod), + }); + try sema.errNote(block, capture_src, msg, "this coercion is only possible when capturing by value", .{}); + break :msg msg; + }; + return sema.failWithOwnedErrorMsg(msg); + } + } + } + if (try sema.resolveDefinedValue(block, operand_src, spa.operand_ptr)) |op_ptr_val| { + if (op_ptr_val.isUndef()) return sema.addConstUndef(capture_ptr_ty); return sema.addConstant( - field_ty_ptr, + capture_ptr_ty, try Value.Tag.field_ptr.create(sema.arena, .{ .container_ptr = op_ptr_val, .container_ty = operand_ty, @@ -10330,18 +10417,142 @@ const SwitchProngAnalysis = struct { }), ); } + try sema.requireRuntimeBlock(block, operand_src, null); - return block.addStructFieldPtr(spa.operand_ptr, first_field_index, field_ty_ptr); + return block.addStructFieldPtr(spa.operand_ptr, first_field_index, capture_ptr_ty); } if (try sema.resolveDefinedValue(block, operand_src, spa.operand)) |operand_val| { - return sema.addConstant( - first_field.ty, - operand_val.castTag(.@"union").?.data.val, - ); + if (operand_val.isUndef()) return sema.addConstUndef(capture_ty); + const union_val = operand_val.castTag(.@"union").?.data; + if (union_val.tag.isUndef()) return sema.addConstUndef(capture_ty); + const active_field_idx = @intCast(u32, operand_ty.unionTagFieldIndex(union_val.tag, sema.mod).?); + const field_ty = union_obj.fields.values()[active_field_idx].ty; + const uncoerced = try sema.addConstant(field_ty, union_val.val); + return sema.coerce(block, capture_ty, uncoerced, operand_src); } + try sema.requireRuntimeBlock(block, operand_src, null); - return block.addStructFieldVal(spa.operand, first_field_index, first_field.ty); + + if (same_types) { + return block.addStructFieldVal(spa.operand, first_field_index, capture_ty); + } + + // We may have to emit a switch block which coerces the operand to the capture type. + // If we can, try to avoid that using in-memory coercions. + const first_non_imc = in_mem: { + for (field_tys, 0..) |field_ty, i| { + if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) { + break :in_mem i; + } + } + // All fields are in-memory coercible to the resolved type! + // Just take the first field and bitcast the result. + const uncoerced = try block.addStructFieldVal(spa.operand, first_field_index, first_field.ty); + return block.addBitCast(capture_ty, uncoerced); + }; + + // By-val capture with heterogeneous types which are not all in-memory coercible to + // the resolved capture type. We finally have to fall back to the ugly method. + + // However, let's first track which operands are in-memory coercible. There may well + // be several, and we can squash all of these cases into the same switch prong using + // a simple bitcast. We'll make this the 'else' prong. + + var in_mem_coercible = try std.DynamicBitSet.initFull(sema.arena, field_tys.len); + in_mem_coercible.unset(first_non_imc); + { + const next = first_non_imc + 1; + for (field_tys[next..], next..) |field_ty, i| { + if (.ok != try sema.coerceInMemoryAllowed(block, capture_ty, field_ty, false, sema.mod.getTarget(), .unneeded, .unneeded)) { + in_mem_coercible.unset(i); + } + } + } + + const capture_block_inst = try block.addInstAsIndex(.{ + .tag = .block, + .data = .{ + .ty_pl = .{ + .ty = try sema.addType(capture_ty), + .payload = undefined, // updated below + }, + }, + }); + + const prong_count = field_tys.len - in_mem_coercible.count(); + + const estimated_extra = prong_count * 6; // 2 for Case, 1 item, probably 3 insts + var cases_extra = try std.ArrayList(u32).initCapacity(sema.gpa, estimated_extra); + defer cases_extra.deinit(); + + { + // Non-bitcast cases + var it = in_mem_coercible.iterator(.{ .kind = .unset }); + while (it.next()) |idx| { + var coerce_block = block.makeSubBlock(); + defer coerce_block.instructions.deinit(sema.gpa); + + const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(u32, idx), field_tys[idx]); + const coerced = sema.coerce(&coerce_block, capture_ty, uncoerced, .unneeded) catch |err| switch (err) { + error.NeededSourceLocation => { + const multi_idx = raw_capture_src.multi_capture; + const src_decl_ptr = sema.mod.declPtr(block.src_decl); + const raw_case_src: Module.SwitchProngSrc = .{ .multi = .{ .prong = multi_idx, .item = @intCast(u32, idx) } }; + const case_src = raw_case_src.resolve(sema.gpa, src_decl_ptr, switch_node_offset, .none); + _ = try sema.coerce(&coerce_block, capture_ty, uncoerced, case_src); + unreachable; + }, + else => |e| return e, + }; + _ = try coerce_block.addBr(capture_block_inst, coerced); + + try cases_extra.ensureUnusedCapacity(3 + coerce_block.instructions.items.len); + cases_extra.appendAssumeCapacity(1); // items_len + cases_extra.appendAssumeCapacity(@intCast(u32, coerce_block.instructions.items.len)); // body_len + cases_extra.appendAssumeCapacity(@enumToInt(case_vals[idx])); // item + cases_extra.appendSliceAssumeCapacity(coerce_block.instructions.items); // body + } + } + const else_body_len = len: { + // 'else' prong uses a bitcast + var coerce_block = block.makeSubBlock(); + defer coerce_block.instructions.deinit(sema.gpa); + + const first_imc = in_mem_coercible.findFirstSet().?; + const uncoerced = try coerce_block.addStructFieldVal(spa.operand, @intCast(u32, first_imc), field_tys[first_imc]); + const coerced = try coerce_block.addBitCast(capture_ty, uncoerced); + _ = try coerce_block.addBr(capture_block_inst, coerced); + + try cases_extra.appendSlice(coerce_block.instructions.items); + break :len coerce_block.instructions.items.len; + }; + + try sema.air_extra.ensureUnusedCapacity(sema.gpa, @typeInfo(Air.SwitchBr).Struct.fields.len + + cases_extra.items.len + + @typeInfo(Air.Block).Struct.fields.len + + 1); + + const switch_br_inst = @intCast(u32, sema.air_instructions.len); + try sema.air_instructions.append(sema.gpa, .{ + .tag = .switch_br, + .data = .{ .pl_op = .{ + .operand = spa.cond, + .payload = sema.addExtraAssumeCapacity(Air.SwitchBr{ + .cases_len = @intCast(u32, prong_count), + .else_body_len = @intCast(u32, else_body_len), + }), + } }, + }); + sema.air_extra.appendSliceAssumeCapacity(cases_extra.items); + + // Set up block body + sema.air_instructions.items(.data)[capture_block_inst].ty_pl.payload = sema.addExtraAssumeCapacity(Air.Block{ + .body_len = 1, + }); + sema.air_extra.appendAssumeCapacity(switch_br_inst); + + return Air.indexToRef(capture_block_inst); }, .ErrorSet => { if (capture_byref) { @@ -11078,6 +11289,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r .parent_block = block, .operand = raw_operand.val, .operand_ptr = raw_operand.ptr, + .cond = operand, .else_error_ty = else_error_ty, .switch_block_inst = inst, .tag_capture_inst = tag_capture_inst, diff --git a/test/behavior/switch.zig b/test/behavior/switch.zig index 3f6cd3729873..72a36c988362 100644 --- a/test/behavior/switch.zig +++ b/test/behavior/switch.zig @@ -1,5 +1,6 @@ const builtin = @import("builtin"); const std = @import("std"); +const assert = std.debug.assert; const expect = std.testing.expect; const expectError = std.testing.expectError; const expectEqual = std.testing.expectEqual; @@ -717,3 +718,70 @@ test "comptime inline switch" { try expectEqual(u32, value); } + +test "switch capture peer type resolution" { + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; + + const U = union(enum) { + a: u32, + b: u64, + fn innerVal(u: @This()) u64 { + switch (u) { + .a, .b => |x| return x, + } + } + }; + + try expectEqual(@as(u64, 100), U.innerVal(.{ .a = 100 })); + try expectEqual(@as(u64, 200), U.innerVal(.{ .b = 200 })); +} + +test "switch capture peer type resolution for in-memory coercible payloads" { + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; + + const T1 = c_int; + const T2 = @Type(@typeInfo(T1)); + + comptime assert(T1 != T2); + + const U = union(enum) { + a: T1, + b: T2, + fn innerVal(u: @This()) c_int { + switch (u) { + .a, .b => |x| return x, + } + } + }; + + try expectEqual(@as(c_int, 100), U.innerVal(.{ .a = 100 })); + try expectEqual(@as(c_int, 200), U.innerVal(.{ .b = 200 })); +} + +test "switch pointer capture peer type resolution" { + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; + + const T1 = c_int; + const T2 = @Type(@typeInfo(T1)); + + comptime assert(T1 != T2); + + const U = union(enum) { + a: T1, + b: T2, + fn innerVal(u: *@This()) *c_int { + switch (u.*) { + .a, .b => |*ptr| return ptr, + } + } + }; + + var ua: U = .{ .a = 100 }; + var ub: U = .{ .b = 200 }; + + ua.innerVal().* = 111; + ub.innerVal().* = 222; + + try expectEqual(U{ .a = 111 }, ua); + try expectEqual(U{ .b = 222 }, ub); +} diff --git a/test/cases/compile_errors/switch_capture_incompatible_types.zig b/test/cases/compile_errors/switch_capture_incompatible_types.zig new file mode 100644 index 000000000000..b6de7d5bf5b3 --- /dev/null +++ b/test/cases/compile_errors/switch_capture_incompatible_types.zig @@ -0,0 +1,27 @@ +export fn f() void { + const U = union(enum) { a: u32, b: *u8 }; + var u: U = undefined; + switch (u) { + .a, .b => |val| _ = val, + } +} + +export fn g() void { + const U = union(enum) { a: u64, b: u32 }; + var u: U = undefined; + switch (u) { + .a, .b => |*ptr| _ = ptr, + } +} + +// error +// backend=stage2 +// target=native +// +// :5:20: error: capture group with incompatible types +// :5:20: note: incompatible types: 'u32' and '*u8' +// :5:10: note: type 'u32' here +// :5:14: note: type '*u8' here +// :13:20: error: capture group with incompatible types +// :13:14: note: pointer type child 'u32' cannot cast into resolved pointer type child 'u64' +// :13:20: note: this coercion is only possible when capturing by value