From 43aa51e3d8399d27e70c62793f443df513a5ac81 Mon Sep 17 00:00:00 2001 From: g-w1 Date: Thu, 31 Dec 2020 17:10:49 -0500 Subject: [PATCH] improve stage2 to allow catch at comptime: * add error_union value tag. * add analyzeIsErr * add Value.isError * add TZIR wrap_errunion_payload and wrap_errunion_err for wrapping from T -> E!T and E -> E!T * add anlyzeInstUnwrapErrCode and analyzeInstUnwrapErr * add analyzeInstEnsureErrPayloadVoid: * Fix bug in astgen where .? was in wrong spot causing segfault when using this check * add wrapErrorUnion * add comptime error comparison for tests * tests!!! --- src/Module.zig | 70 +++++++++++++++++++++++-- src/astgen.zig | 4 +- src/codegen.zig | 62 +++++++++++++++++++++++ src/ir.zig | 18 +++++++ src/value.zig | 118 ++++++++++++++++++++++++++++++++++++++++++- src/zir.zig | 12 +++++ src/zir_sema.zig | 107 ++++++++++++++++++++++++++++++++++++--- test/stage2/test.zig | 107 +++++++++++++++++++++++++++++++++++++++ 8 files changed, 485 insertions(+), 13 deletions(-) diff --git a/src/Module.zig b/src/Module.zig index 8de03b54ab12..5475877189a8 100644 --- a/src/Module.zig +++ b/src/Module.zig @@ -2533,7 +2533,15 @@ pub fn analyzeIsNull( } pub fn analyzeIsErr(self: *Module, scope: *Scope, src: usize, operand: *Inst) InnerError!*Inst { - return self.fail(scope, src, "TODO implement analysis of iserr", .{}); + const ot = operand.ty.zigTypeTag(); + if (ot != .ErrorSet and ot != .ErrorUnion) return self.constBool(scope, src, false); + if (ot == .ErrorSet) return self.constBool(scope, src, true); + assert(ot == .ErrorUnion); + if (operand.value()) |err_union| { + return self.constBool(scope, src, err_union.getError() != null); + } + const b = try self.requireRuntimeBlock(scope, src); + return self.addUnOp(b, src, Type.initTag(.bool), .is_err, operand); } pub fn analyzeSlice(self: *Module, scope: *Scope, src: usize, array_ptr: *Inst, start: *Inst, end_opt: ?*Inst, sentinel_opt: ?*Inst) InnerError!*Inst { @@ -2836,6 +2844,56 @@ fn wrapOptional(self: *Module, scope: *Scope, dest_type: Type, inst: *Inst) !*In return self.addUnOp(b, inst.src, dest_type, .wrap_optional, inst); } +fn wrapErrorUnion(self: *Module, scope: *Scope, dest_type: Type, inst: *Inst) !*Inst { + // TODO deal with inferred error sets + const err_union = dest_type.castTag(.error_union).?; + if (inst.value()) |val| { + return self.constInst(scope, inst.src, .{ + .ty = dest_type, + // creating a SubValue for the error_union payload + .val = try Value.Tag.error_union.create( + scope.arena(), + blk: { + if (inst.ty.zigTypeTag() != .ErrorSet) { + // .? because we know it has a value + _ = try self.coerce(scope, err_union.data.payload, inst); + break :blk val; + } else { + switch (err_union.data.error_set.tag()) { + // TODO these should be moved to coerce with error union widening + .anyerror => break :blk val, + .error_set_single => { + const n = err_union.data.error_set.castTag(.error_set_single).?.data; + if (!mem.eql(u8, val.castTag(.@"error").?.data.name, n)) + return self.fail(scope, inst.src, "expected type '{}', found type '{}'", .{ err_union.data.error_set, inst.ty }); + break :blk val; + }, + .error_set => { + const f = err_union.data.error_set.castTag(.error_set).?.data.typed_value.most_recent.typed_value.val.castTag(.error_set).?.data.fields; + if (f.get(val.castTag(.@"error").?.data.name) == null) + return self.fail(scope, inst.src, "expected type '{}', found type '{}'", .{ err_union.data.error_set, inst.ty }); + break :blk val; + }, + else => unreachable, + } + } + }, + ), + }); + } + + const b = try self.requireRuntimeBlock(scope, inst.src); + + // we are coercing from E to E!T + if (inst.ty.zigTypeTag() == .ErrorSet) { + var coerced = try self.coerce(scope, err_union.data.error_set, inst); + return self.addUnOp(b, inst.src, dest_type, .wrap_errunion_err, coerced); + } else { + var coerced = try self.coerce(scope, err_union.data.payload, inst); + return self.addUnOp(b, inst.src, dest_type, .wrap_errunion_payload, coerced); + } +} + fn makeIntType(self: *Module, scope: *Scope, signed: bool, bits: u16) !Type { const int_payload = try scope.arena().create(Type.Payload.Bits); int_payload.* = .{ @@ -2902,7 +2960,7 @@ pub fn resolvePeerTypes(self: *Module, scope: *Scope, instructions: []*Inst) !Ty return chosen.ty; } -pub fn coerce(self: *Module, scope: *Scope, dest_type: Type, inst: *Inst) !*Inst { +pub fn coerce(self: *Module, scope: *Scope, dest_type: Type, inst: *Inst) InnerError!*Inst { // If the types are the same, we can return the operand. if (dest_type.eql(inst.ty)) return inst; @@ -2936,6 +2994,11 @@ pub fn coerce(self: *Module, scope: *Scope, dest_type: Type, inst: *Inst) !*Inst } } + // T to E!T or E to E!T + if (dest_type.tag() == .error_union) { + return try self.wrapErrorUnion(scope, dest_type, inst); + } + // Coercions where the source is a single pointer to an array. src_array_ptr: { if (!inst.ty.isSinglePointer()) break :src_array_ptr; @@ -3014,7 +3077,7 @@ pub fn coerce(self: *Module, scope: *Scope, dest_type: Type, inst: *Inst) !*Inst return self.fail(scope, inst.src, "expected {}, found {}", .{ dest_type, inst.ty }); } -pub fn coerceNum(self: *Module, scope: *Scope, dest_type: Type, inst: *Inst) !?*Inst { +pub fn coerceNum(self: *Module, scope: *Scope, dest_type: Type, inst: *Inst) InnerError!?*Inst { const val = inst.value() orelse return null; const src_zig_tag = inst.ty.zigTypeTag(); const dst_zig_tag = dest_type.zigTypeTag(); @@ -3504,6 +3567,7 @@ pub fn dumpInst(self: *Module, scope: *Scope, inst: *Inst) void { pub const PanicId = enum { unreach, unwrap_null, + unwrap_errunion, }; pub fn addSafetyCheck(mod: *Module, parent_block: *Scope.Block, ok: *Inst, panic_id: PanicId) !void { diff --git a/src/astgen.zig b/src/astgen.zig index ece16d70da3e..a4850a99c8c0 100644 --- a/src/astgen.zig +++ b/src/astgen.zig @@ -1762,8 +1762,8 @@ const CondKind = union(enum) { fn thenSubScope(self: CondKind, mod: *Module, then_scope: *Scope.GenZIR, src: usize, payload_node: ?*ast.Node) !*Scope { if (self == .bool) return &then_scope.base; - - const payload = payload_node.?.castTag(.PointerPayload) orelse { + const payload = + payload_node.?.castTag(.PointerPayload) orelse { // condition is error union and payload is not explicitly ignored _ = try addZIRUnOp(mod, &then_scope.base, src, .ensure_err_payload_void, self.err_union.?); return &then_scope.base; diff --git a/src/codegen.zig b/src/codegen.zig index 904fda0debb3..3caaba91625a 100644 --- a/src/codegen.zig +++ b/src/codegen.zig @@ -880,7 +880,13 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type { .unreach => return MCValue{ .unreach = {} }, .optional_payload => return self.genOptionalPayload(inst.castTag(.optional_payload).?), .optional_payload_ptr => return self.genOptionalPayloadPtr(inst.castTag(.optional_payload_ptr).?), + .unwrap_errunion_err => return self.genUnwrapErrErr(inst.castTag(.unwrap_errunion_err).?), + .unwrap_errunion_payload => return self.genUnwrapErrPayload(inst.castTag(.unwrap_errunion_payload).?), + .unwrap_errunion_err_ptr => return self.genUnwrapErrErrPtr(inst.castTag(.unwrap_errunion_err_ptr).?), + .unwrap_errunion_payload_ptr => return self.genUnwrapErrPayloadPtr(inst.castTag(.unwrap_errunion_payload_ptr).?), .wrap_optional => return self.genWrapOptional(inst.castTag(.wrap_optional).?), + .wrap_errunion_payload => return self.genWrapErrUnionPayload(inst.castTag(.wrap_errunion_payload).?), + .wrap_errunion_err => return self.genWrapErrUnionErr(inst.castTag(.wrap_errunion_err).?), .varptr => return self.genVarPtr(inst.castTag(.varptr).?), .xor => return self.genXor(inst.castTag(.xor).?), } @@ -1141,6 +1147,41 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type { } } + fn genUnwrapErrErr(self: *Self, inst: *ir.Inst.UnOp) !MCValue { + // No side effects, so if it's unreferenced, do nothing. + if (inst.base.isUnused()) + return MCValue.dead; + switch (arch) { + else => return self.fail(inst.base.src, "TODO implement unwrap error union error for {}", .{self.target.cpu.arch}), + } + } + + fn genUnwrapErrPayload(self: *Self, inst: *ir.Inst.UnOp) !MCValue { + // No side effects, so if it's unreferenced, do nothing. + if (inst.base.isUnused()) + return MCValue.dead; + switch (arch) { + else => return self.fail(inst.base.src, "TODO implement unwrap error union payload for {}", .{self.target.cpu.arch}), + } + } + // *(E!T) -> E + fn genUnwrapErrErrPtr(self: *Self, inst: *ir.Inst.UnOp) !MCValue { + // No side effects, so if it's unreferenced, do nothing. + if (inst.base.isUnused()) + return MCValue.dead; + switch (arch) { + else => return self.fail(inst.base.src, "TODO implement unwrap error union error ptr for {}", .{self.target.cpu.arch}), + } + } + // *(E!T) -> *T + fn genUnwrapErrPayloadPtr(self: *Self, inst: *ir.Inst.UnOp) !MCValue { + // No side effects, so if it's unreferenced, do nothing. + if (inst.base.isUnused()) + return MCValue.dead; + switch (arch) { + else => return self.fail(inst.base.src, "TODO implement unwrap error union payload ptr for {}", .{self.target.cpu.arch}), + } + } fn genWrapOptional(self: *Self, inst: *ir.Inst.UnOp) !MCValue { const optional_ty = inst.base.ty; @@ -1157,6 +1198,27 @@ fn Function(comptime arch: std.Target.Cpu.Arch) type { } } + /// T to E!T + fn genWrapErrUnionPayload(self: *Self, inst: *ir.Inst.UnOp) !MCValue { + // No side effects, so if it's unreferenced, do nothing. + if (inst.base.isUnused()) + return MCValue.dead; + + switch (arch) { + else => return self.fail(inst.base.src, "TODO implement wrap errunion payload for {}", .{self.target.cpu.arch}), + } + } + + /// E to E!T + fn genWrapErrUnionErr(self: *Self, inst: *ir.Inst.UnOp) !MCValue { + // No side effects, so if it's unreferenced, do nothing. + if (inst.base.isUnused()) + return MCValue.dead; + + switch (arch) { + else => return self.fail(inst.base.src, "TODO implement wrap errunion error for {}", .{self.target.cpu.arch}), + } + } fn genVarPtr(self: *Self, inst: *ir.Inst.VarPtr) !MCValue { // No side effects, so if it's unreferenced, do nothing. if (inst.base.isUnused()) diff --git a/src/ir.zig b/src/ir.zig index 0e83dbfd5608..aa503c91d10a 100644 --- a/src/ir.zig +++ b/src/ir.zig @@ -114,6 +114,18 @@ pub const Inst = struct { // *?T => *T optional_payload_ptr, wrap_optional, + /// E!T -> T + unwrap_errunion_payload, + /// E!T -> E + unwrap_errunion_err, + /// *(E!T) -> *T + unwrap_errunion_payload_ptr, + /// *(E!T) -> E + unwrap_errunion_err_ptr, + /// wrap from T to E!T + wrap_errunion_payload, + /// wrap from E to E!T + wrap_errunion_err, xor, switchbr, @@ -143,6 +155,12 @@ pub const Inst = struct { .optional_payload, .optional_payload_ptr, .wrap_optional, + .unwrap_errunion_payload, + .unwrap_errunion_err, + .unwrap_errunion_payload_ptr, + .unwrap_errunion_err_ptr, + .wrap_errunion_payload, + .wrap_errunion_err, => UnOp, .add, diff --git a/src/value.zig b/src/value.zig index 50298da68289..a602d08c0675 100644 --- a/src/value.zig +++ b/src/value.zig @@ -102,6 +102,7 @@ pub const Value = extern union { enum_literal, error_set, @"error", + error_union, /// This is a special value that tracks a set of types that have been stored /// to an inferred allocation. It does not support any of the normal value queries. inferred_alloc, @@ -174,6 +175,7 @@ pub const Value = extern union { .ref_val, .repeated, + .error_union, => Payload.SubValue, .bytes, @@ -388,9 +390,17 @@ pub const Value = extern union { return Value{ .ptr_otherwise = &new_payload.base }; }, .@"error" => return self.copyPayloadShallow(allocator, Payload.Error), + .error_union => { + const payload = self.castTag(.error_union).?; + const new_payload = try allocator.create(Payload.SubValue); + new_payload.* = .{ + .base = payload.base, + .data = try payload.data.copy(allocator), + }; + return Value{ .ptr_otherwise = &new_payload.base }; + }, .error_set => return self.copyPayloadShallow(allocator, Payload.ErrorSet), - .inferred_alloc => unreachable, } } @@ -510,6 +520,8 @@ pub const Value = extern union { return out_stream.writeAll("}"); }, .@"error" => return out_stream.print("error.{s}", .{val.castTag(.@"error").?.data.name}), + // TODO to print this it should be error{ Set, Items }!T(val), but we need the type for that + .error_union => return out_stream.print("error_union_val({})", .{val.castTag(.error_union).?.data}), .inferred_alloc => return out_stream.writeAll("(inferred allocation value)"), }; } @@ -622,6 +634,7 @@ pub const Value = extern union { .float_128, .enum_literal, .@"error", + .error_union, .empty_struct_value, .inferred_alloc, => unreachable, @@ -692,6 +705,7 @@ pub const Value = extern union { .empty_array, .enum_literal, .error_set, + .error_union, .@"error", .empty_struct_value, .inferred_alloc, @@ -779,6 +793,7 @@ pub const Value = extern union { .enum_literal, .error_set, .@"error", + .error_union, .empty_struct_value, .inferred_alloc, => unreachable, @@ -865,6 +880,7 @@ pub const Value = extern union { .enum_literal, .error_set, .@"error", + .error_union, .empty_struct_value, .inferred_alloc, => unreachable, @@ -979,6 +995,7 @@ pub const Value = extern union { .enum_literal, .error_set, .@"error", + .error_union, .empty_struct_value, .inferred_alloc, => unreachable, @@ -1069,6 +1086,7 @@ pub const Value = extern union { .enum_literal, .error_set, .@"error", + .error_union, .empty_struct_value, .inferred_alloc, => unreachable, @@ -1228,6 +1246,7 @@ pub const Value = extern union { .enum_literal, .error_set, .@"error", + .error_union, .empty_struct_value, .inferred_alloc, => unreachable, @@ -1305,6 +1324,7 @@ pub const Value = extern union { .enum_literal, .error_set, .@"error", + .error_union, .empty_struct_value, .inferred_alloc, => unreachable, @@ -1543,7 +1563,10 @@ pub const Value = extern union { hasher.update(payload.name); std.hash.autoHash(&hasher, payload.value); }, - + .error_union => { + const payload = self.castTag(.error_union).?.data; + std.hash.autoHash(&hasher, payload.hash()); + }, .inferred_alloc => unreachable, } return hasher.final(); @@ -1621,6 +1644,7 @@ pub const Value = extern union { .enum_literal, .error_set, .@"error", + .error_union, .empty_struct_value, .inferred_alloc, => unreachable, @@ -1707,6 +1731,7 @@ pub const Value = extern union { .enum_literal, .error_set, .@"error", + .error_union, .empty_struct_value, .inferred_alloc, => unreachable, @@ -1810,6 +1835,7 @@ pub const Value = extern union { .enum_literal, .error_set, .@"error", + .error_union, .empty_struct_value, => false, @@ -1820,6 +1846,93 @@ pub const Value = extern union { }; } + /// Valid for all types. Asserts the value is not undefined and not unreachable. + pub fn getError(self: Value) ?[]const u8 { + return switch (self.tag()) { + .ty, + .int_type, + .u8_type, + .i8_type, + .u16_type, + .i16_type, + .u32_type, + .i32_type, + .u64_type, + .i64_type, + .usize_type, + .isize_type, + .c_short_type, + .c_ushort_type, + .c_int_type, + .c_uint_type, + .c_long_type, + .c_ulong_type, + .c_longlong_type, + .c_ulonglong_type, + .c_longdouble_type, + .f16_type, + .f32_type, + .f64_type, + .f128_type, + .c_void_type, + .bool_type, + .void_type, + .type_type, + .anyerror_type, + .comptime_int_type, + .comptime_float_type, + .noreturn_type, + .null_type, + .undefined_type, + .fn_noreturn_no_args_type, + .fn_void_no_args_type, + .fn_naked_noreturn_no_args_type, + .fn_ccc_void_no_args_type, + .single_const_pointer_to_comptime_int_type, + .const_slice_u8_type, + .enum_literal_type, + .anyframe_type, + .zero, + .one, + .null_value, + .empty_array, + .bool_true, + .bool_false, + .function, + .extern_fn, + .variable, + .int_u64, + .int_i64, + .int_big_positive, + .int_big_negative, + .ref_val, + .decl_ref, + .elem_ptr, + .bytes, + .repeated, + .float_16, + .float_32, + .float_64, + .float_128, + .void_value, + .enum_literal, + .error_set, + .empty_struct_value, + => null, + + .error_union => { + const data = self.castTag(.error_union).?.data; + return if (data.tag() == .@"error") + data.castTag(.@"error").?.data.name + else + null; + }, + .@"error" => self.castTag(.@"error").?.data.name, + .undef => unreachable, + .unreachable_value => unreachable, + .inferred_alloc => unreachable, + }; + } /// Valid for all types. Asserts the value is not undefined. pub fn isFloat(self: Value) bool { return switch (self.tag()) { @@ -1908,6 +2021,7 @@ pub const Value = extern union { .void_value, .enum_literal, .@"error", + .error_union, .empty_struct_value, .null_value, => false, diff --git a/src/zir.zig b/src/zir.zig index 30bfeead9bb3..576d571c4c1a 100644 --- a/src/zir.zig +++ b/src/zir.zig @@ -1608,6 +1608,12 @@ const DumpTzir = struct { .optional_payload, .optional_payload_ptr, .wrap_optional, + .wrap_errunion_payload, + .wrap_errunion_err, + .unwrap_errunion_payload, + .unwrap_errunion_err, + .unwrap_errunion_payload_ptr, + .unwrap_errunion_err_ptr, => { const un_op = inst.cast(ir.Inst.UnOp).?; try dtz.findConst(un_op.operand); @@ -1721,6 +1727,12 @@ const DumpTzir = struct { .optional_payload, .optional_payload_ptr, .wrap_optional, + .wrap_errunion_err, + .wrap_errunion_payload, + .unwrap_errunion_err, + .unwrap_errunion_payload, + .unwrap_errunion_payload_ptr, + .unwrap_errunion_err_ptr, => { const un_op = inst.cast(ir.Inst.UnOp).?; const kinky = try dtz.writeInst(writer, un_op.operand); diff --git a/src/zir_sema.zig b/src/zir_sema.zig index f373d7174d76..eb0d21254815 100644 --- a/src/zir_sema.zig +++ b/src/zir_sema.zig @@ -1271,34 +1271,124 @@ fn zirOptionalPayload( fn zirErrUnionPayload(mod: *Module, scope: *Scope, unwrap: *zir.Inst.UnOp, safety_check: bool) InnerError!*Inst { const tracy = trace(@src()); defer tracy.end(); - return mod.fail(scope, unwrap.base.src, "TODO implement zir_sema.zirErrUnionPayload", .{}); + + const operand = try resolveInst(mod, scope, unwrap.positionals.operand); + if (operand.ty.zigTypeTag() != .ErrorUnion) + return mod.fail(scope, operand.src, "expected error union type, found '{}'", .{operand.ty}); + + if (operand.value()) |val| { + if (val.getError()) |name| { + return mod.fail(scope, unwrap.base.src, "caught unexpected error '{s}'", .{name}); + } + const data = val.castTag(.error_union).?.data; + return mod.constInst(scope, unwrap.base.src, .{ + .ty = operand.ty.castTag(.error_union).?.data.payload, + .val = data, + }); + } + const b = try mod.requireRuntimeBlock(scope, unwrap.base.src); + if (safety_check and mod.wantSafety(scope)) { + const is_non_err = try mod.addUnOp(b, unwrap.base.src, Type.initTag(.bool), .is_err, operand); + try mod.addSafetyCheck(b, is_non_err, .unwrap_errunion); + } + return mod.addUnOp(b, unwrap.base.src, operand.ty.castTag(.error_union).?.data.payload, .unwrap_errunion_payload, operand); } /// Pointer in, pointer out fn zirErrUnionPayloadPtr(mod: *Module, scope: *Scope, unwrap: *zir.Inst.UnOp, safety_check: bool) InnerError!*Inst { const tracy = trace(@src()); defer tracy.end(); - return mod.fail(scope, unwrap.base.src, "TODO implement zir_sema.zirErrUnionPayloadPtr", .{}); + + const operand = try resolveInst(mod, scope, unwrap.positionals.operand); + assert(operand.ty.zigTypeTag() == .Pointer); + + if (operand.ty.elemType().zigTypeTag() != .ErrorUnion) + return mod.fail(scope, unwrap.base.src, "expected error union type, found {}", .{operand.ty.elemType()}); + + const operand_pointer_ty = try mod.simplePtrType(scope, unwrap.base.src, operand.ty.elemType().castTag(.error_union).?.data.payload, !operand.ty.isConstPtr(), .One); + + if (operand.value()) |pointer_val| { + const val = try pointer_val.pointerDeref(scope.arena()); + if (val.getError()) |name| { + return mod.fail(scope, unwrap.base.src, "caught unexpected error '{s}'", .{name}); + } + const data = val.castTag(.error_union).?.data; + // The same Value represents the pointer to the error union and the payload. + return mod.constInst(scope, unwrap.base.src, .{ + .ty = operand_pointer_ty, + .val = try Value.Tag.ref_val.create( + scope.arena(), + data, + ), + }); + } + + const b = try mod.requireRuntimeBlock(scope, unwrap.base.src); + if (safety_check and mod.wantSafety(scope)) { + const is_non_err = try mod.addUnOp(b, unwrap.base.src, Type.initTag(.bool), .is_err, operand); + try mod.addSafetyCheck(b, is_non_err, .unwrap_errunion); + } + return mod.addUnOp(b, unwrap.base.src, operand_pointer_ty, .unwrap_errunion_payload_ptr, operand); } /// Value in, value out fn zirErrUnionCode(mod: *Module, scope: *Scope, unwrap: *zir.Inst.UnOp) InnerError!*Inst { const tracy = trace(@src()); defer tracy.end(); - return mod.fail(scope, unwrap.base.src, "TODO implement zir_sema.zirErrUnionCode", .{}); + + const operand = try resolveInst(mod, scope, unwrap.positionals.operand); + if (operand.ty.zigTypeTag() != .ErrorUnion) + return mod.fail(scope, unwrap.base.src, "expected error union type, found '{}'", .{operand.ty}); + + if (operand.value()) |val| { + assert(val.getError() != null); + const data = val.castTag(.error_union).?.data; + return mod.constInst(scope, unwrap.base.src, .{ + .ty = operand.ty.castTag(.error_union).?.data.error_set, + .val = data, + }); + } + + const b = try mod.requireRuntimeBlock(scope, unwrap.base.src); + return mod.addUnOp(b, unwrap.base.src, operand.ty.castTag(.error_union).?.data.payload, .unwrap_errunion_err, operand); } /// Pointer in, value out fn zirErrUnionCodePtr(mod: *Module, scope: *Scope, unwrap: *zir.Inst.UnOp) InnerError!*Inst { const tracy = trace(@src()); defer tracy.end(); - return mod.fail(scope, unwrap.base.src, "TODO implement zir_sema.zirErrUnionCodePtr", .{}); + + const operand = try resolveInst(mod, scope, unwrap.positionals.operand); + assert(operand.ty.zigTypeTag() == .Pointer); + + if (operand.ty.elemType().zigTypeTag() != .ErrorUnion) + return mod.fail(scope, unwrap.base.src, "expected error union type, found {}", .{operand.ty.elemType()}); + + if (operand.value()) |pointer_val| { + const val = try pointer_val.pointerDeref(scope.arena()); + assert(val.getError() != null); + const data = val.castTag(.error_union).?.data; + return mod.constInst(scope, unwrap.base.src, .{ + .ty = operand.ty.elemType().castTag(.error_union).?.data.error_set, + .val = data, + }); + } + + const b = try mod.requireRuntimeBlock(scope, unwrap.base.src); + return mod.addUnOp(b, unwrap.base.src, operand.ty.castTag(.error_union).?.data.payload, .unwrap_errunion_err_ptr, operand); } fn zirEnsureErrPayloadVoid(mod: *Module, scope: *Scope, unwrap: *zir.Inst.UnOp) InnerError!*Inst { const tracy = trace(@src()); defer tracy.end(); - return mod.fail(scope, unwrap.base.src, "TODO implement zirEnsureErrPayloadVoid", .{}); + + const operand = try resolveInst(mod, scope, unwrap.positionals.operand); + if (operand.ty.zigTypeTag() != .ErrorUnion) + return mod.fail(scope, unwrap.base.src, "expected error union type, found '{}'", .{operand.ty}); + if (operand.ty.castTag(.error_union).?.data.payload.zigTypeTag() != .Void) { + return mod.fail(scope, unwrap.base.src, "expression value is ignored", .{}); + } + return mod.constVoid(scope, unwrap.base.src); } fn zirFnType(mod: *Module, scope: *Scope, fntype: *zir.Inst.FnType) InnerError!*Inst { @@ -2068,7 +2158,12 @@ fn zirCmp( if (!is_equality_cmp) { return mod.fail(scope, inst.base.src, "{s} operator not allowed for errors", .{@tagName(op)}); } - return mod.fail(scope, inst.base.src, "TODO implement equality comparison between errors", .{}); + if (rhs.value()) |rval| { + if (lhs.value()) |lval| { + return mod.constBool(scope, inst.base.src, (lval.castTag(.@"error").?.data.value == rval.castTag(.@"error").?.data.value) == (op == .eq)); + } + } + return mod.fail(scope, inst.base.src, "TODO implement equality comparison between runtime errors", .{}); } else if (lhs.ty.isNumeric() and rhs.ty.isNumeric()) { // This operation allows any combination of integer and float types, regardless of the // signed-ness, comptime-ness, and bit-width. So peer type resolution is incorrect for diff --git a/test/stage2/test.zig b/test/stage2/test.zig index 78d7eba26221..1fbd329d0dfa 100644 --- a/test/stage2/test.zig +++ b/test/stage2/test.zig @@ -1393,4 +1393,111 @@ pub fn addCases(ctx: *TestContext) !void { "", ); } + { + var case = ctx.exe("catch at comptime", linux_x64); + case.addCompareOutput( + \\export fn _start() noreturn { + \\ const i: anyerror!u64 = 0; + \\ const caught = i catch 5; + \\ assert(caught == 0); + \\ exit(); + \\} + \\fn assert(b: bool) void { + \\ if (!b) unreachable; + \\} + \\fn exit() noreturn { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (231), + \\ [arg1] "{rdi}" (0) + \\ : "rcx", "r11", "memory" + \\ ); + \\ unreachable; + \\} + , + "", + ); + case.addCompareOutput( + \\export fn _start() noreturn { + \\ const i: anyerror!u64 = error.B; + \\ const caught = i catch 5; + \\ assert(caught == 5); + \\ exit(); + \\} + \\fn assert(b: bool) void { + \\ if (!b) unreachable; + \\} + \\fn exit() noreturn { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (231), + \\ [arg1] "{rdi}" (0) + \\ : "rcx", "r11", "memory" + \\ ); + \\ unreachable; + \\} + , + "", + ); + case.addCompareOutput( + \\export fn _start() noreturn { + \\ const a: anyerror!comptime_int = 42; + \\ const b: *const comptime_int = &(a catch unreachable); + \\ assert(b.* == 42); + \\ + \\ exit(); + \\} + \\fn assert(b: bool) void { + \\ if (!b) unreachable; // assertion failure + \\} + \\fn exit() noreturn { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (231), + \\ [arg1] "{rdi}" (0) + \\ : "rcx", "r11", "memory" + \\ ); + \\ unreachable; + \\} + , ""); + case.addCompareOutput( + \\export fn _start() noreturn { + \\const a: anyerror!u32 = error.B; + \\_ = &(a catch |err| assert(err == error.B)); + \\exit(); + \\} + \\fn assert(b: bool) void { + \\ if (!b) unreachable; + \\} + \\fn exit() noreturn { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (231), + \\ [arg1] "{rdi}" (0) + \\ : "rcx", "r11", "memory" + \\ ); + \\ unreachable; + \\} + , ""); + case.addCompareOutput( + \\export fn _start() noreturn { + \\ const a: anyerror!u32 = error.Bar; + \\ a catch |err| assert(err == error.Bar); + \\ + \\ exit(); + \\} + \\fn assert(b: bool) void { + \\ if (!b) unreachable; + \\} + \\fn exit() noreturn { + \\ asm volatile ("syscall" + \\ : + \\ : [number] "{rax}" (231), + \\ [arg1] "{rdi}" (0) + \\ : "rcx", "r11", "memory" + \\ ); + \\ unreachable; + \\} + , ""); + } }