Skip to content

Commit

Permalink
stage2: error_set_merged type equality
Browse files Browse the repository at this point in the history
This implements type equality for error sets. This is done
through element-wise error set comparison.

Inferred error sets are always distinct types and other error sets are
always sorted. See ziglang#11022.
  • Loading branch information
mitchellh authored and andrewrk committed Mar 10, 2022
1 parent 0b82c02 commit 569870c
Show file tree
Hide file tree
Showing 8 changed files with 185 additions and 58 deletions.
14 changes: 13 additions & 1 deletion src/Module.zig
Original file line number Diff line number Diff line change
Expand Up @@ -824,7 +824,7 @@ pub const ErrorSet = struct {
/// Offset from Decl node index, points to the error set AST node.
node_offset: i32,
/// The string bytes are stored in the owner Decl arena.
/// They are in the same order they appear in the AST.
/// These must be in sorted order. See sortNames.
names: NameMap,

pub const NameMap = std.StringArrayHashMapUnmanaged(void);
Expand All @@ -836,6 +836,18 @@ pub const ErrorSet = struct {
.lazy = .{ .node_offset = self.node_offset },
};
}

/// sort the NameMap. This should be called whenever the map is modified.
/// alloc should be the allocator used for the NameMap data.
pub fn sortNames(names: *NameMap) void {
const Context = struct {
keys: [][]const u8,
pub fn lessThan(ctx: @This(), a_index: usize, b_index: usize) bool {
return std.mem.lessThan(u8, ctx.keys[a_index], ctx.keys[b_index]);
}
};
names.sort(Context{ .keys = names.keys() });
}
};

pub const RequiresComptime = enum { no, yes, unknown, wip };
Expand Down
4 changes: 4 additions & 0 deletions src/Sema.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2212,6 +2212,10 @@ fn zirErrorSetDecl(
return sema.fail(block, src, "duplicate error set field {s}", .{name});
}
}

// names must be sorted.
Module.ErrorSet.sortNames(&names);

error_set.* = .{
.owner_decl = new_decl,
.node_offset = inst_data.src_node,
Expand Down
66 changes: 45 additions & 21 deletions src/type.zig
Original file line number Diff line number Diff line change
Expand Up @@ -564,27 +564,30 @@ pub const Type = extern union {
=> {
if (b.zigTypeTag() != .ErrorSet) return false;

// TODO: revisit the language specification for how to evaluate equality
// for error set types.

if (a.tag() == .anyerror and b.tag() == .anyerror) {
return true;
// inferred error sets are only equal if both are inferred
// and they originate from the exact same function.
if (a.castTag(.error_set_inferred)) |a_pl| {
if (b.castTag(.error_set_inferred)) |b_pl| {
return a_pl.data.func == b_pl.data.func;
}
return false;
}

if (a.tag() == .error_set and b.tag() == .error_set) {
return a.castTag(.error_set).?.data.owner_decl == b.castTag(.error_set).?.data.owner_decl;
if (b.tag() == .error_set_inferred) return false;

// anyerror matches exactly.
const a_is_any = a.isAnyError();
const b_is_any = b.isAnyError();
if (a_is_any or b_is_any) return a_is_any and b_is_any;

// two resolved sets match if their error set names match.
const a_set = a.errorSetNames();
const b_set = b.errorSetNames();
if (a_set.len != b_set.len) return false;
for (b_set) |b_val| {
if (!a.errorSetHasField(b_val)) return false;
}

if (a.tag() == .error_set_inferred and b.tag() == .error_set_inferred) {
return a.castTag(.error_set_inferred).?.data == b.castTag(.error_set_inferred).?.data;
}

if (a.tag() == .error_set_single and b.tag() == .error_set_single) {
const a_data = a.castTag(.error_set_single).?.data;
const b_data = b.castTag(.error_set_single).?.data;
return std.mem.eql(u8, a_data, b_data);
}
return false;
return true;
},

.@"opaque" => {
Expand Down Expand Up @@ -961,12 +964,30 @@ pub const Type = extern union {

.error_set,
.error_set_single,
.anyerror,
.error_set_inferred,
.error_set_merged,
=> {
// all are treated like an "error set" for hashing
std.hash.autoHash(hasher, std.builtin.TypeId.ErrorSet);
std.hash.autoHash(hasher, Tag.error_set);

const names = ty.errorSetNames();
std.hash.autoHash(hasher, names.len);
assert(std.sort.isSorted([]const u8, names, u8, std.mem.lessThan));
for (names) |name| hasher.update(name);
},

.anyerror => {
// anyerror is distinct from other error sets
std.hash.autoHash(hasher, std.builtin.TypeId.ErrorSet);
// TODO implement this after revisiting Type.Eql for error sets
std.hash.autoHash(hasher, Tag.anyerror);
},

.error_set_inferred => {
// inferred error sets are compared using their data pointer
const data = ty.castTag(.error_set_inferred).?.data.func;
std.hash.autoHash(hasher, std.builtin.TypeId.ErrorSet);
std.hash.autoHash(hasher, Tag.error_set_inferred);
std.hash.autoHash(hasher, data);
},

.@"opaque" => {
Expand Down Expand Up @@ -4365,6 +4386,9 @@ pub const Type = extern union {
try names.put(arena, name, {});
}

// names must be sorted
Module.ErrorSet.sortNames(&names);

return try Tag.error_set_merged.create(arena, names);
}

Expand Down
10 changes: 10 additions & 0 deletions src/value.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1870,6 +1870,16 @@ pub const Value = extern union {

return eql(a_payload.container_ptr, b_payload.container_ptr, ty);
},
.@"error" => {
const a_name = a.castTag(.@"error").?.data.name;
const b_name = b.castTag(.@"error").?.data.name;
return std.mem.eql(u8, a_name, b_name);
},
.eu_payload => {
const a_payload = a.castTag(.eu_payload).?.data;
const b_payload = b.castTag(.eu_payload).?.data;
return eql(a_payload, b_payload, ty.errorUnionPayload());
},
.eu_payload_ptr => @panic("TODO: Implement more pointer eql cases"),
.opt_payload_ptr => @panic("TODO: Implement more pointer eql cases"),
.array => {
Expand Down
16 changes: 8 additions & 8 deletions test/behavior/cast.zig
Original file line number Diff line number Diff line change
Expand Up @@ -669,17 +669,17 @@ test "peer type resolution: disjoint error sets" {
try expect(error_set_info == .ErrorSet);
try expect(error_set_info.ErrorSet.?.len == 3);
try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "One"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Two"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Three"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Three"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Two"));
}

{
const ty = @TypeOf(b, a);
const error_set_info = @typeInfo(ty);
try expect(error_set_info == .ErrorSet);
try expect(error_set_info.ErrorSet.?.len == 3);
try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "Three"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "One"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "One"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Three"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Two"));
}
}
Expand All @@ -704,8 +704,8 @@ test "peer type resolution: error union and error set" {

const error_set_info = @typeInfo(info.ErrorUnion.error_set);
try expect(error_set_info.ErrorSet.?.len == 3);
try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "Three"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "One"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "One"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Three"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Two"));
}

Expand All @@ -717,8 +717,8 @@ test "peer type resolution: error union and error set" {
const error_set_info = @typeInfo(info.ErrorUnion.error_set);
try expect(error_set_info.ErrorSet.?.len == 3);
try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "One"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Two"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Three"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Three"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Two"));
}
}

Expand Down
107 changes: 100 additions & 7 deletions test/behavior/error.zig
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,11 @@ fn intLiteral(str: []const u8) !?i64 {
}

test "nested error union function call in optional unwrap" {
if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;

const S = struct {
const Foo = struct {
Expand Down Expand Up @@ -375,7 +379,11 @@ test "nested error union function call in optional unwrap" {
}

test "return function call to error set from error union function" {
if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;

const S = struct {
fn errorable() anyerror!i32 {
Expand Down Expand Up @@ -404,7 +412,11 @@ test "optional error set is the same size as error set" {
}

test "nested catch" {
if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;

const S = struct {
fn entry() !void {
Expand All @@ -428,11 +440,18 @@ test "nested catch" {
}

test "function pointer with return type that is error union with payload which is pointer of parent struct" {
if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
// This test uses the stage2 const fn pointer
if (builtin.zig_backend == .stage1) return error.SkipZigTest;

if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;

const S = struct {
const Foo = struct {
fun: fn (a: i32) (anyerror!*Foo),
fun: *const fn (a: i32) (anyerror!*Foo),
};

const Err = error{UnspecifiedErr};
Expand Down Expand Up @@ -480,7 +499,11 @@ test "return result loc as peer result loc in inferred error set function" {
}

test "error payload type is correctly resolved" {
if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;

const MyIntWrapper = struct {
const Self = @This();
Expand All @@ -496,7 +519,11 @@ test "error payload type is correctly resolved" {
}

test "error union comptime caching" {
if (builtin.zig_backend != .stage1) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;

const S = struct {
fn quux(comptime arg: anytype) void {
Expand Down Expand Up @@ -539,3 +566,69 @@ test "@errorName sentinel length matches slice length" {
pub fn testBuiltinErrorName(err: anyerror) [:0]const u8 {
return @errorName(err);
}

test "error set equality" {
// This tests using stage2 logic (#11022)
if (builtin.zig_backend == .stage1) return error.SkipZigTest;

if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;

const a = error{One};
const b = error{One};

try expect(a == a);
try expect(a == b);
try expect(a == error{One});

// should treat as a set
const c = error{ One, Two };
const d = error{ Two, One };

try expect(c == d);
}

test "inferred error set equality" {
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;

const S = struct {
fn foo() !void {
return @This().bar();
}

fn bar() !void {
return error.Bad;
}

fn baz() !void {
return quux();
}

fn quux() anyerror!void {}
};

const FooError = @typeInfo(@typeInfo(@TypeOf(S.foo)).Fn.return_type.?).ErrorUnion.error_set;
const BarError = @typeInfo(@typeInfo(@TypeOf(S.bar)).Fn.return_type.?).ErrorUnion.error_set;
const BazError = @typeInfo(@typeInfo(@TypeOf(S.baz)).Fn.return_type.?).ErrorUnion.error_set;

try expect(BarError != error{Bad});

try expect(FooError != anyerror);
try expect(BarError != anyerror);
try expect(BazError != anyerror);

try expect(FooError != BarError);
try expect(FooError != BazError);
try expect(BarError != BazError);

try expect(FooError == FooError);
try expect(BarError == BarError);
try expect(BazError == BazError);
}
7 changes: 5 additions & 2 deletions test/behavior/type_info.zig
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,9 @@ test "type info: error set single value" {
}

test "type info: error set merged" {
// #11022 forces ordering of error sets in stage2
if (builtin.zig_backend == .stage1) return error.SkipZigTest;

if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest;
if (builtin.zig_backend == .stage2_c) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
Expand All @@ -217,8 +220,8 @@ test "type info: error set merged" {
try expect(error_set_info == .ErrorSet);
try expect(error_set_info.ErrorSet.?.len == 3);
try expect(mem.eql(u8, error_set_info.ErrorSet.?[0].name, "One"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Two"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Three"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[1].name, "Three"));
try expect(mem.eql(u8, error_set_info.ErrorSet.?[2].name, "Two"));
}

test "type info: enum info" {
Expand Down
19 changes: 0 additions & 19 deletions test/stage2/x86_64.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1412,25 +1412,6 @@ pub fn addCases(ctx: *TestContext) !void {
});
}

{
var case = ctx.exe("error set equality", target);

case.addCompareOutput(
\\pub fn main() void {
\\ assert(@TypeOf(error.Foo) == @TypeOf(error.Foo));
\\ assert(@TypeOf(error.Bar) != @TypeOf(error.Foo));
\\ assert(anyerror == anyerror);
\\ assert(error{Foo} != error{Foo});
\\ // TODO put inferred error sets here when @typeInfo works
\\}
\\fn assert(b: bool) void {
\\ if (!b) unreachable;
\\}
,
"",
);
}

{
var case = ctx.exe("comptime var", target);

Expand Down

0 comments on commit 569870c

Please sign in to comment.