diff --git a/src/Module.zig b/src/Module.zig index 6a4575394acf..fdb782ee1a1b 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -2840,6 +2840,7 @@ pub fn resolvePeerTypes(self: *Module, scope: *Scope, instructions: []*Inst) !Ty chosen = candidate; continue; } + if (chosen.ty.isInt() and candidate.ty.isInt() and chosen.ty.isSignedInt() == candidate.ty.isSignedInt()) @@ -2950,6 +2951,14 @@ pub fn coerce(self: *Module, scope: *Scope, dest_type: Type, inst: *Inst) !*Inst } } + // error set widening + if (inst.ty.zigTypeTag() == .ErrorSet and dest_type.zigTypeTag() == .ErrorSet) { + if (dest_type.errorSetFitsInAnother(inst.ty)) { + inst.ty = dest_type; + return inst; + } + } + // comptime known number to other number if (try self.coerceNum(scope, dest_type, inst)) |some| return some; diff --git a/src/type.zig b/src/type.zig index 9e2cd321f01e..bacc5789793f 100644 --- a/src/type.zig +++ b/src/type.zig @@ -261,10 +261,54 @@ pub const Type = extern union { var buf_b: Payload.ElemType = undefined; return a.optionalChild(&buf_a).eql(b.optionalChild(&buf_b)); }, + .ErrorSet => { + if (a.tag() == .anyerror) { + if (b.tag() == .anyerror) + return true + else + return false; + } else if (b.tag() == .anyerror) { + return false; + } + + if (a.tag() == .error_set_single and b.tag() == .error_set_single) { + return std.mem.eql(u8, a.getErrs().err_single, b.getErrs().err_single); + } + if (a.tag() == .error_set_single and b.tag() == .error_set) { + var b_fields = b.getErrs().multiple; + if (b_fields.size != 1) return false; + return b_fields.contains(a.getErrs().err_single); + } + + if (b.tag() == .error_set_single and a.tag() == .error_set) { + var a_fields = a.getErrs().multiple; + if (a_fields.size != 1) return false; + return a_fields.contains(b.getErrs().err_single); + } + + if (a.tag() == .error_set and b.tag() == .error_set) { + // they both have to be >=1 size error sets + var a_fields = a.getErrs().multiple; + var b_fields = b.getErrs().multiple; + if (a_fields.size != b_fields.size) + return false; + var a_fields_it = a_fields.iterator(); + while (a_fields_it.next()) |entry| { + if (!b_fields.contains(entry.key)) { + return false; + } + } + return true; + } + unreachable; + }, + .ErrorUnion => { + const a_data = a.castTag(.error_union).?.data; + const b_data = b.castTag(.error_union).?.data; + return a_data.error_set.eql(b_data.error_set) and a_data.payload.eql(b_data.payload); + }, .Float, .Struct, - .ErrorUnion, - .ErrorSet, .Enum, .Union, .BoundFn, @@ -1596,6 +1640,123 @@ pub const Type = extern union { }; } + /// Asserts the type is error_set or error_set_single or anyerror and that if it is error_set, it's decl has been analyzed + /// If it is error_set_single, it will return the []const u8, otherwise a pointer to the map with the error_set or void for anyerror + pub fn getErrs(self: Type) union(enum) { err_single: []const u8, multiple: *std.StringHashMapUnmanaged(u16), anyerror: void } { + return switch (self.tag()) { + .u8, + .i8, + .u16, + .i16, + .u32, + .i32, + .u64, + .i64, + .usize, + .isize, + .c_short, + .c_ushort, + .c_int, + .c_uint, + .c_long, + .c_ulong, + .c_longlong, + .c_ulonglong, + .c_longdouble, + .f16, + .f32, + .f64, + .f128, + .c_void, + .bool, + .void, + .type, + .comptime_int, + .comptime_float, + .noreturn, + .@"null", + .@"undefined", + .fn_noreturn_no_args, + .fn_void_no_args, + .fn_naked_noreturn_no_args, + .fn_ccc_void_no_args, + .function, + .int_unsigned, + .int_signed, + .optional, + .optional_single_const_pointer, + .optional_single_mut_pointer, + .enum_literal, + .error_union, + .@"anyframe", + .anyframe_T, + .anyerror_void_error_union, + .empty_struct, + .array, + .array_sentinel, + .single_const_pointer, + .single_mut_pointer, + .many_const_pointer, + .many_mut_pointer, + .c_const_pointer, + .c_mut_pointer, + .const_slice, + .mut_slice, + .array_u8, + .array_u8_sentinel_0, + .const_slice_u8, + .single_const_pointer_to_comptime_int, + .pointer, + .inferred_alloc_mut, + .inferred_alloc_const, + => unreachable, + .anyerror => return .{ .anyerror = {} }, + .error_set_single => return .{ .err_single = self.castTag(.error_set_single).?.data }, + .error_set => return .{ .multiple = &self.castTag(.error_set).?.data.typed_value.most_recent.typed_value.val.castTag(.error_set).?.data.fields }, + }; + } + /// Asserts the type is error_set or error_set_single or anyerror and that if it is error_set, it's decl has been analyzed. + /// Returns true if fitee can fit into fitter, and false if not. + pub fn errorSetFitsInAnother(fitter: Type, fitee: Type) bool { + if (fitter.eql(fitee)) return true; + switch (fitee.getErrs()) { + .multiple => |fitee_set| { + switch (fitter.getErrs()) { + .multiple => |fitter_set| { + // we can do < because if they were equal then Type.eql wouldn't have let control flow get this far + if (fitee_set.size < fitter_set.size) { + var it = fitter_set.iterator(); + while (it.next()) |entry| { + // the smaller set has a key not in the larger set + if (fitee_set.get(entry.key) == null) return false; + } + return true; + } else return false; + }, + // we return false, because if they were equal Type.eql would have caught and if not, then not coercible + .err_single => return false, + // any set can fit into anyerror + .anyerror => return true, + } + }, + .err_single => |fitee_name| { + switch (fitter.getErrs()) { + .multiple => |fitter_set| { + // the smaller set has a key not in the larger set + if (fitter_set.get(fitee_name) == null) return false; + return true; + }, + // we return false, because if they were equal Type.eql would have caught and if not, then not coercible + .err_single => return false, + // any set can fit into anyerror + .anyerror => return true, + } + }, + + // if they were both anyerror, then Type.eql would have caught it, and if not, then nothing else can fit into anyerorr + .anyerror => return false, + } + } /// Asserts the type is a pointer or array type. pub fn elemType(self: Type) Type { return switch (self.tag()) { diff --git a/src/value.zig b/src/value.zig index 11c385b44654..905ac1da1e69 100644 --- a/src/value.zig +++ b/src/value.zig @@ -1383,6 +1383,8 @@ pub const Value = extern union { const a_name = a.castTag(.enum_literal).?.data; const b_name = b.castTag(.enum_literal).?.data; return std.mem.eql(u8, a_name, b_name); + } else if (a.tag() == .@"error" and b.tag() == .@"error") { + return a.castTag(.@"error").?.data.value == b.castTag(.@"error").?.data.value; } } if (a.isType() and b.isType()) { diff --git a/src/zir_sema.zig b/src/zir_sema.zig index dbe10c4bdea4..b915a74d0a5a 100644 --- a/src/zir_sema.zig +++ b/src/zir_sema.zig @@ -1138,7 +1138,74 @@ fn analyzeInstErrorSet(mod: *Module, scope: *Scope, inst: *zir.Inst.ErrorSet) In fn analyzeInstMergeErrorSets(mod: *Module, scope: *Scope, inst: *zir.Inst.BinOp) InnerError!*Inst { const tracy = trace(@src()); defer tracy.end(); - return mod.fail(scope, inst.base.src, "TODO implement merge_error_sets", .{}); + + const rhs_fields = (try resolveType(mod, scope, inst.positionals.rhs)).getErrs(); + const lhs_fields = (try resolveType(mod, scope, inst.positionals.lhs)).getErrs(); + if (lhs_fields == .anyerror or rhs_fields == .anyerror) + return mod.constInst(scope, inst.base.src, .{ + .ty = Type.initTag(.type), + .val = Value.initTag(.anyerror_type), + }); + // The declarations arena will store the hashmap. + var new_decl_arena = std.heap.ArenaAllocator.init(mod.gpa); + errdefer new_decl_arena.deinit(); + + const payload = try scope.arena().create(Value.Payload.ErrorSet); + payload.* = .{ + .base = .{ .tag = .error_set }, + .data = .{ + .fields = .{}, + .decl = undefined, // populated below + }, + }; + try payload.data.fields.ensureCapacity(&new_decl_arena.allocator, @intCast(u32, switch (rhs_fields) { + .err_single => 1, + .multiple => |mul| mul.size, + else => unreachable, + } + switch (lhs_fields) { + .err_single => 1, + .multiple => |mul| mul.size, + else => unreachable, + })); + + switch (lhs_fields) { + .err_single => |name| { + const num = mod.global_error_set.get(name).?; + payload.data.fields.putAssumeCapacity(name, num); + }, + .multiple => |multiple| { + var it = multiple.iterator(); + while (it.next()) |entry| { + payload.data.fields.putAssumeCapacity(entry.key, entry.value); + } + }, + else => unreachable, + } + + switch (rhs_fields) { + .err_single => |name| { + const num = mod.global_error_set.get(name).?; + payload.data.fields.putAssumeCapacity(name, num); + }, + .multiple => |multiple| { + var it = multiple.iterator(); + while (it.next()) |name| { + const entry = try mod.getErrorValue(name.key); + payload.data.fields.putAssumeCapacity(entry.key, entry.value); + } + }, + else => unreachable, + } + const new_decl = try mod.createAnonymousDecl(scope, &new_decl_arena, .{ + .ty = Type.initTag(.type), + .val = Value.initPayload(&payload.base), + }); + payload.data.decl = new_decl; + + return mod.constInst(scope, inst.base.src, .{ + .ty = Type.initTag(.type), + .val = Value.initPayload(&payload.base), + }); } fn analyzeInstEnumLiteral(mod: *Module, scope: *Scope, inst: *zir.Inst.EnumLiteral) InnerError!*Inst { @@ -1632,8 +1699,39 @@ fn validateSwitch(mod: *Module, scope: *Scope, target: *Inst, inst: *zir.Inst.Sw // validate for duplicate items/missing else prong switch (target.ty.zigTypeTag()) { .Enum => return mod.fail(scope, inst.base.src, "TODO validateSwitch .Enum", .{}), - .ErrorSet => return mod.fail(scope, inst.base.src, "TODO validateSwitch .ErrorSet", .{}), .Union => return mod.fail(scope, inst.base.src, "TODO validateSwitch .Union", .{}), + .ErrorSet => { + const gotten_err_set = target.ty.getErrs(); + const is_anyerror = gotten_err_set == .anyerror; + var is_single = gotten_err_set == .err_single; + if (is_anyerror and !(inst.kw_args.special_prong == .@"else")) { + return mod.fail(scope, inst.base.src, "else prong required when switching on type 'anyerror'", .{}); + } + var seen_values = std.HashMap(Value, usize, Value.hash, Value.eql, std.hash_map.DefaultMaxLoadPercentage).init(mod.gpa); + defer seen_values.deinit(); + for (inst.positionals.items) |item| { + const resolved = try resolveInst(mod, scope, item); + const casted = try mod.coerce(scope, target.ty, resolved); + const val = try mod.resolveConstValue(scope, casted); + const err_name = val.castTag(.@"error").?.data.name; + if (try seen_values.fetchPut(val, item.src)) |prev| { + return mod.fail(scope, item.src, "duplicate switch value", .{}); + } + if (is_anyerror) { + continue; + } else if (is_single and !mem.eql(u8, gotten_err_set.err_single, err_name)) { + return mod.fail(scope, item.src, "expected type '{}', found '{}'", .{ gotten_err_set.err_single, gotten_err_set.err_single }); + } + // TODO print this error, but it will never happen because coerce will handle it above + // else if (gotten_err_set == .multiple) { // we know it is an actual error set + // if (gotten_err_set.multiple.get(err_name) == null) { + // return mod.fail(scope, item.src, "'{}' not a member of destination error set", .{err_name}); + // } + // } + } + if (!is_single and !is_anyerror and gotten_err_set.multiple.size > inst.positionals.items.len and !(inst.kw_args.special_prong == .@"else")) + return mod.fail(scope, inst.base.src, "switch must handle all possibilities", .{}); + }, .Int, .ComptimeInt => { var range_set = @import("RangeSet.zig").init(mod.gpa); defer range_set.deinit(); @@ -1644,7 +1742,6 @@ fn validateSwitch(mod: *Module, scope: *Scope, target: *Inst, inst: *zir.Inst.Sw const start_casted = try mod.coerce(scope, target.ty, start_resolved); const end_resolved = try resolveInst(mod, scope, range.positionals.rhs); const end_casted = try mod.coerce(scope, target.ty, end_resolved); - break :blk try range_set.add( try mod.resolveConstValue(scope, start_casted), try mod.resolveConstValue(scope, end_casted), @@ -2049,7 +2146,12 @@ fn analyzeInstCmp( if (!is_equality_cmp) { return mod.fail(scope, inst.base.src, "{s} operator not allowed for errors", .{@tagName(op)}); } - return mod.fail(scope, inst.base.src, "TODO implement equality comparison between errors", .{}); + const lhs_val = lhs.value(); + const rhs_val = rhs.value(); + if (lhs_val != null and rhs_val != null) { + return mod.constBool(scope, inst.base.src, (lhs_val.?.castTag(.@"error").?.data.value == rhs_val.?.castTag(.@"error").?.data.value) == (op == .eq)); + } + return mod.fail(scope, inst.base.src, "TODO runtime error comparison", .{}); } else if (lhs.ty.isNumeric() and rhs.ty.isNumeric()) { // This operation allows any combination of integer and float types, regardless of the // signed-ness, comptime-ness, and bit-width. So peer type resolution is incorrect for @@ -2061,7 +2163,7 @@ fn analyzeInstCmp( } return mod.constBool(scope, inst.base.src, lhs.value().?.eql(rhs.value().?) == (op == .eq)); } - return mod.fail(scope, inst.base.src, "TODO implement more cmp analysis", .{}); + return mod.fail(scope, inst.base.src, "TODO implement more cmp analysis for type {} and {}", .{ @tagName(lhs_ty_tag), @tagName(rhs_ty_tag) }); } fn analyzeInstTypeOf(mod: *Module, scope: *Scope, inst: *zir.Inst.UnOp) InnerError!*Inst { diff --git a/test/stage2/test.zig b/test/stage2/test.zig index 6e25dc283b59..c35a529620a5 100644 --- a/test/stage2/test.zig +++ b/test/stage2/test.zig @@ -316,7 +316,223 @@ pub fn addCases(ctx: *TestContext) !void { "Hello, World!\n", ); } + { + var case = ctx.exe("switch and coerce error sets", linux_x64); + case.addCompareOutput( + \\export fn _start() noreturn { + \\ const T = error{ A, B, C, D }; + \\ const e = T.B; + \\ + \\ switch (e) { + \\ error.B => condPrint(), + \\ else => unreachable, + \\ } + \\ exit(); + \\} + \\ + \\fn condPrint() void { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (1), + \\ [arg1] "{rdi}" (1), + \\ [arg2] "{rsi}" (@ptrToInt("Reached\n")), + \\ [arg3] "{rdx}" (8) + \\ : "rcx", "r11", "memory" + \\ ); + \\ return; + \\} + \\ + \\fn exit() noreturn { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (231), + \\ [arg1] "{rdi}" (0) + \\ : "rcx", "r11", "memory" + \\ ); + \\ unreachable; + \\} + , "Reached\n"); + case.addError( + \\export fn _start() noreturn { + \\ const T = error{ A, B, C, D }; + \\ const e: T = T.B; + \\ + \\ switch (e) { + \\ error.Z => {}, + \\ } + \\ unreachable; // because it will give error above + \\} + , &[_][]const u8{":6:14: error: expected _start__anon_12, found error{Z}"}); + case.addError( + \\export fn _start() noreturn { + \\ const T = error{ A, B, C, D }; + \\ const e: T = T.B; + \\ + \\ switch (e) { + \\ error.B => {}, + \\ } + \\ unreachable; // because it will give error above + \\} + , &[_][]const u8{":5:5: error: switch must handle all possibilities"}); + case.addError( + \\export fn _start() noreturn { + \\ const T = error{ A, B, C, D }; + \\ const e: anyerror = T.B; + \\ + \\ switch (e) { + \\ error.B => {}, + \\ } + \\ unreachable; // because it will give error above + \\} + , &[_][]const u8{":5:5: error: else prong required when switching on type 'anyerror'"}); + } + { + var case = ctx.exe("merge error sets", linux_x64); + case.addCompareOutput( + \\export fn _start() noreturn { + \\ const b = error{ A, B, D } || error { A, B, C } == error { A, B, C, D }; + \\ if (b) { + \\ condPrint(); + \\ } + \\ exit(); + \\} + \\fn condPrint() void { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (1), + \\ [arg1] "{rdi}" (1), + \\ [arg2] "{rsi}" (@ptrToInt("The Types Were Equal\n")), + \\ [arg3] "{rdx}" (21) + \\ : "rcx", "r11", "memory" + \\ ); + \\ return; + \\} + \\fn exit() noreturn { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (231), + \\ [arg1] "{rdi}" (0) + \\ : "rcx", "r11", "memory" + \\ ); + \\ unreachable; + \\} + , + "The Types Were Equal\n", + ); + case.addCompareOutput( + \\export fn _start() noreturn { + \\ const T: type = error{ A, B, D } || error { A, B, C }; + \\ const x: T = error.D; + \\ assert(T == error { A, B, C, D }); + \\ + \\ exit(); + \\} + \\ + \\pub fn assert(ok: bool) void { + \\ if (!ok) unreachable; // assertion failure + \\} + \\ + \\fn condPrint() void { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (1), + \\ [arg1] "{rdi}" (1), + \\ [arg2] "{rsi}" (@ptrToInt("The Types Were Equal\n")), + \\ [arg3] "{rdx}" (21) + \\ : "rcx", "r11", "memory" + \\ ); + \\ return; + \\} + \\ + \\fn exit() noreturn { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (231), + \\ [arg1] "{rdi}" (0) + \\ : "rcx", "r11", "memory" + \\ ); + \\ unreachable; + \\} + , + "", + ); + } + { + var case = ctx.exe("comptime type equality", linux_x64); + case.addCompareOutput( + \\export fn _start() noreturn { + \\ const b1 = error{ T, V, E } == error{ V, T, E }; + \\ const b2 = u32 == u8; + \\ + \\ if (!b2 and b1) { + \\ condPrint(); + \\ } + \\ + \\ exit(); + \\} + \\ + \\fn condPrint() void { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (1), + \\ [arg1] "{rdi}" (1), + \\ [arg2] "{rsi}" (@ptrToInt("The Types Were Equal\n")), + \\ [arg3] "{rdx}" (21) + \\ : "rcx", "r11", "memory" + \\ ); + \\ return; + \\} + \\ + \\fn exit() noreturn { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (231), + \\ [arg1] "{rdi}" (0) + \\ : "rcx", "r11", "memory" + \\ ); + \\ unreachable; + \\} + , + "The Types Were Equal\n", + ); + } + { + var case = ctx.exe("comptime error equality", linux_x64); + case.addCompareOutput( + \\export fn _start() noreturn { + \\ if (error.T == error.T) { + \\ condPrint(); + \\ } + \\ + \\ exit(); + \\} + \\ + \\fn condPrint() void { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (1), + \\ [arg1] "{rdi}" (1), + \\ [arg2] "{rsi}" (@ptrToInt("The errs were equal\n")), + \\ [arg3] "{rdx}" (20) + \\ : "rcx", "r11", "memory" + \\ ); + \\ return; + \\} + \\ + \\fn exit() noreturn { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (231), + \\ [arg1] "{rdi}" (0) + \\ : "rcx", "r11", "memory" + \\ ); + \\ unreachable; + \\} + , + "The errs were equal\n", + ); + } { var case = ctx.exe("adding numbers at runtime and comptime", linux_x64); case.addCompareOutput(