Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Eliminate switch_capture and switch_capture_ref ZIR tags
Browse files Browse the repository at this point in the history
These tags are unnecessary, as this information can be more efficiently
encoded within the switch_block instruction itself. We also use a neat
little trick to avoid needing a dummy instruction (like is used for
errdefer captures): since the switch_block itself cannot otherwise be
referenced within a prong, we can repurpose its index within prongs to
refer to the captured value.
mlugg committed May 30, 2023
1 parent b8f050f commit 9c70360
Showing 5 changed files with 544 additions and 325 deletions.
68 changes: 30 additions & 38 deletions src/AstGen.zig
Original file line number Diff line number Diff line change
@@ -2612,8 +2612,6 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As
.switch_block,
.switch_cond,
.switch_cond_ref,
.switch_capture,
.switch_capture_ref,
.switch_capture_tag,
.struct_init_empty,
.struct_init,
@@ -6860,17 +6858,22 @@ fn switchExpr(
var dbg_var_inst: Zir.Inst.Ref = undefined;
var dbg_var_tag_name: ?u32 = null;
var dbg_var_tag_inst: Zir.Inst.Ref = undefined;
var capture_inst: Zir.Inst.Index = 0;
var tag_inst: Zir.Inst.Index = 0;
var capture_val_scope: Scope.LocalVal = undefined;
var tag_scope: Scope.LocalVal = undefined;

var capture: Zir.Inst.SwitchBlock.ProngInfo.Capture = .none;

const sub_scope = blk: {
const payload_token = case.payload_token orelse break :blk &case_scope.base;
const ident = if (token_tags[payload_token] == .asterisk)
payload_token + 1
else
payload_token;

const is_ptr = ident != payload_token;
capture = if (is_ptr) .by_ref else .by_val;

const ident_slice = tree.tokenSlice(ident);
var payload_sub_scope: *Scope = undefined;
if (mem.eql(u8, ident_slice, "_")) {
@@ -6879,46 +6882,18 @@ fn switchExpr(
}
payload_sub_scope = &case_scope.base;
} else {
if (case_node == special_node) {
const capture_tag: Zir.Inst.Tag = if (is_ptr)
.switch_capture_ref
else
.switch_capture;
capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len);
try astgen.instructions.append(gpa, .{
.tag = capture_tag,
.data = .{
.switch_capture = .{
.switch_inst = switch_block,
// Max int communicates that this is the else/underscore prong.
.prong_index = std.math.maxInt(u32),
},
},
});
} else {
const capture_tag: Zir.Inst.Tag = if (is_ptr) .switch_capture_ref else .switch_capture;
const capture_index = if (is_multi_case) scalar_cases_len + multi_case_index else scalar_case_index;
capture_inst = @intCast(Zir.Inst.Index, astgen.instructions.len);
try astgen.instructions.append(gpa, .{
.tag = capture_tag,
.data = .{ .switch_capture = .{
.switch_inst = switch_block,
.prong_index = capture_index,
} },
});
}
const capture_name = try astgen.identAsString(ident);
try astgen.detectLocalShadowing(&case_scope.base, capture_name, ident, ident_slice, .capture);
capture_val_scope = .{
.parent = &case_scope.base,
.gen_zir = &case_scope,
.name = capture_name,
.inst = indexToRef(capture_inst),
.inst = indexToRef(switch_block),
.token_src = payload_token,
.id_cat = .capture,
};
dbg_var_name = capture_name;
dbg_var_inst = indexToRef(capture_inst);
dbg_var_inst = indexToRef(switch_block);
payload_sub_scope = &capture_val_scope.base;
}

@@ -7007,7 +6982,6 @@ fn switchExpr(
case_scope.instructions_top = parent_gz.instructions.items.len;
defer case_scope.unstack();

if (capture_inst != 0) try case_scope.instructions.append(gpa, capture_inst);
if (tag_inst != 0) try case_scope.instructions.append(gpa, tag_inst);
try case_scope.addDbgBlockBegin();
if (dbg_var_name) |some| {
@@ -7026,10 +7000,28 @@ fn switchExpr(
}

const case_slice = case_scope.instructionsSlice();
const body_len = astgen.countBodyLenAfterFixups(case_slice);
// Since we use the switch_block instruction itself to refer to the
// capture, which will not be added to the child block, we need to
// handle ref_table manually.
const refs_len = refs: {
var n: usize = 0;
var check_inst = switch_block;
while (astgen.ref_table.get(check_inst)) |ref_inst| {
n += 1;
check_inst = ref_inst;
}
break :refs n;
};
const body_len = refs_len + astgen.countBodyLenAfterFixups(case_slice);
try payloads.ensureUnusedCapacity(gpa, body_len);
const inline_bit = @as(u32, @boolToInt(case.inline_token != null)) << 31;
payloads.items[body_len_index] = body_len | inline_bit;
payloads.items[body_len_index] = @bitCast(u32, Zir.Inst.SwitchBlock.ProngInfo{
.body_len = @intCast(u29, body_len),
.capture = capture,
.is_inline = case.inline_token != null,
});
if (astgen.ref_table.fetchRemove(switch_block)) |kv| {
appendPossiblyRefdBodyInst(astgen, payloads, kv.value);
}
appendBodyWithFixupsArrayList(astgen, payloads, case_slice);
}
}
@@ -7076,7 +7068,7 @@ fn switchExpr(
end_index += 3 + items_len + 2 * ranges_len;
}

const body_len = @truncate(u31, payloads.items[body_len_index]);
const body_len = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, payloads.items[body_len_index]).body_len;
end_index += body_len;

switch (strat.tag) {
24 changes: 17 additions & 7 deletions src/Module.zig
Original file line number Diff line number Diff line change
@@ -6013,6 +6013,7 @@ pub const SwitchProngSrc = union(enum) {
multi: Multi,
range: Multi,
multi_capture: u32,
special,

pub const Multi = struct {
prong: u32,
@@ -6049,14 +6050,22 @@ pub const SwitchProngSrc = union(enum) {
var scalar_i: u32 = 0;
for (case_nodes) |case_node| {
const case = tree.fullSwitchCase(case_node).?;
if (case.ast.values.len == 0)
continue;
if (case.ast.values.len == 1 and
node_tags[case.ast.values[0]] == .identifier and
mem.eql(u8, tree.tokenSlice(main_tokens[case.ast.values[0]]), "_"))
{
continue;

const is_special = special: {
if (case.ast.values.len == 0) break :special true;
if (case.ast.values.len == 1 and node_tags[case.ast.values[0]] == .identifier) {
break :special mem.eql(u8, tree.tokenSlice(main_tokens[case.ast.values[0]]), "_");
}
break :special false;
};

if (is_special) {
if (prong_src != .special) continue;
return LazySrcLoc.nodeOffset(
decl.nodeIndexToRelative(case.ast.values[0]),
);
}

const is_multi = case.ast.values.len != 1 or
node_tags[case.ast.values[0]] == .switch_range;

@@ -6097,6 +6106,7 @@ pub const SwitchProngSrc = union(enum) {
range_i += 1;
} else unreachable;
},
.special => {},
}
if (is_multi) {
multi_i += 1;
617 changes: 447 additions & 170 deletions src/Sema.zig

Large diffs are not rendered by default.

109 changes: 23 additions & 86 deletions src/Zir.zig
Original file line number Diff line number Diff line change
@@ -675,17 +675,6 @@ pub const Inst = struct {
/// what will be switched on.
/// Uses the `un_node` union field.
switch_cond_ref,
/// Produces the capture value for a switch prong.
/// Uses the `switch_capture` field.
/// If the `prong_index` field is max int, it means this is the capture
/// for the else/`_` prong.
switch_capture,
/// Produces the capture value for a switch prong.
/// Result is a pointer to the value.
/// Uses the `switch_capture` field.
/// If the `prong_index` field is max int, it means this is the capture
/// for the else/`_` prong.
switch_capture_ref,
/// Produces the capture value for an inline switch prong tag capture.
/// Uses the `un_tok` field.
switch_capture_tag,
@@ -1134,8 +1123,6 @@ pub const Inst = struct {
.typeof_log2_int_type,
.resolve_inferred_alloc,
.set_eval_branch_quota,
.switch_capture,
.switch_capture_ref,
.switch_capture_tag,
.switch_block,
.switch_cond,
@@ -1426,8 +1413,6 @@ pub const Inst = struct {
.slice_length,
.import,
.typeof_log2_int_type,
.switch_capture,
.switch_capture_ref,
.switch_capture_tag,
.switch_block,
.switch_cond,
@@ -1684,8 +1669,6 @@ pub const Inst = struct {
.switch_block = .pl_node,
.switch_cond = .un_node,
.switch_cond_ref = .un_node,
.switch_capture = .switch_capture,
.switch_capture_ref = .switch_capture,
.switch_capture_tag = .un_tok,
.array_base_ptr = .un_node,
.field_base_ptr = .un_node,
@@ -2598,10 +2581,6 @@ pub const Inst = struct {
operand: Ref,
payload_index: u32,
},
switch_capture: struct {
switch_inst: Index,
prong_index: u32,
},
dbg_stmt: LineColumn,
/// Used for unary operators which reference an inst,
/// with an AST node source location.
@@ -2671,7 +2650,6 @@ pub const Inst = struct {
bool_br,
@"unreachable",
@"break",
switch_capture,
dbg_stmt,
inst_node,
str_op,
@@ -3011,25 +2989,29 @@ pub const Inst = struct {

/// 0. multi_cases_len: u32 // If has_multi_cases is set.
/// 1. else_body { // If has_else or has_under is set.
/// body_len: u32,
/// body member Index for every body_len
/// info: ProngInfo,
/// body member Index for every info.body_len
/// }
/// 2. scalar_cases: { // for every scalar_cases_len
/// item: Ref,
/// body_len: u32,
/// body member Index for every body_len
/// info: ProngInfo,
/// body member Index for every info.body_len
/// }
/// 3. multi_cases: { // for every multi_cases_len
/// items_len: u32,
/// ranges_len: u32,
/// body_len: u32,
/// info: ProngInfo,
/// item: Ref // for every items_len
/// ranges: { // for every ranges_len
/// item_first: Ref,
/// item_last: Ref,
/// }
/// body member Index for every body_len
/// body member Index for every info.body_len
/// }
///
/// When analyzing a case body, the switch instruction itself refers to the
/// captured payload. Whether this is captured by reference or by value
/// depends on whether the `byref` bit is set for the corresponding body.
pub const SwitchBlock = struct {
/// This is always a `switch_cond` or `switch_cond_ref` instruction.
/// If it is a `switch_cond_ref` instruction, bits.is_ref is always true.
@@ -3041,6 +3023,19 @@ pub const Inst = struct {
operand: Ref,
bits: Bits,

/// These are stored in trailing data in `extra` for each prong.
pub const ProngInfo = packed struct(u32) {
body_len: u29,
capture: Capture,
is_inline: bool,

pub const Capture = enum(u2) {
none,
by_val,
by_ref,
};
};

pub const Bits = packed struct {
/// If true, one or more prongs have multiple items.
has_multi_cases: bool,
@@ -3068,64 +3063,6 @@ pub const Inst = struct {
items: []const Ref,
body: []const Index,
};

/// TODO performance optimization: instead of having this helper method
/// change the definition of switch_capture instruction to store extra_index
/// instead of prong_index. This way, Sema won't be doing O(N^2) iterations
/// over the switch prongs.
pub fn getProng(
self: SwitchBlock,
zir: Zir,
extra_end: usize,
prong_index: usize,
) MultiProng {
var extra_index: usize = extra_end + @boolToInt(self.bits.has_multi_cases);

if (self.bits.specialProng() != .none) {
const body_len = @truncate(u31, zir.extra[extra_index]);
extra_index += 1;
const body = zir.extra[extra_index..][0..body_len];
extra_index += body.len;
}

var cur_idx: usize = 0;
while (cur_idx < self.bits.scalar_cases_len) : (cur_idx += 1) {
const items = zir.refSlice(extra_index, 1);
extra_index += 1;
const body_len = @truncate(u31, zir.extra[extra_index]);
extra_index += 1;
const body = zir.extra[extra_index..][0..body_len];
extra_index += body_len;
if (cur_idx == prong_index) {
return .{
.items = items,
.body = body,
};
}
}
while (true) : (cur_idx += 1) {
const items_len = zir.extra[extra_index];
extra_index += 1;
const ranges_len = zir.extra[extra_index];
extra_index += 1;
const body_len = @truncate(u31, zir.extra[extra_index]);
extra_index += 1;
const items = zir.refSlice(extra_index, items_len);
extra_index += items_len;
// Each range has a start and an end.
extra_index += 2 * ranges_len;

const body = zir.extra[extra_index..][0..body_len];
extra_index += body_len;

if (cur_idx == prong_index) {
return .{
.items = items,
.body = body,
};
}
}
}
};

pub const Field = struct {
51 changes: 27 additions & 24 deletions src/print_zir.zig
Original file line number Diff line number Diff line change
@@ -435,10 +435,6 @@ const Writer = struct {

.@"unreachable" => try self.writeUnreachable(stream, inst),

.switch_capture,
.switch_capture_ref,
=> try self.writeSwitchCapture(stream, inst),

.dbg_stmt => try self.writeDbgStmt(stream, inst),

.dbg_block_begin,
@@ -1912,15 +1908,20 @@ const Writer = struct {
else => break :else_prong,
};

const body_len = @truncate(u31, self.code.extra[extra_index]);
const inline_text = if (self.code.extra[extra_index] >> 31 != 0) "inline " else "";
const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, self.code.extra[extra_index]);
const capture_text = switch (info.capture) {
.none => "",
.by_val => "by_val ",
.by_ref => "by_ref ",
};
const inline_text = if (info.is_inline) "inline " else "";
extra_index += 1;
const body = self.code.extra[extra_index..][0..body_len];
const body = self.code.extra[extra_index..][0..info.body_len];
extra_index += body.len;

try stream.writeAll(",\n");
try stream.writeByteNTimes(' ', self.indent);
try stream.print("{s}{s} => ", .{ inline_text, prong_name });
try stream.print("{s}{s}{s} => ", .{ capture_text, inline_text, prong_name });
try self.writeBracedBody(stream, body);
}

@@ -1930,15 +1931,19 @@ const Writer = struct {
while (scalar_i < scalar_cases_len) : (scalar_i += 1) {
const item_ref = @intToEnum(Zir.Inst.Ref, self.code.extra[extra_index]);
extra_index += 1;
const body_len = @truncate(u31, self.code.extra[extra_index]);
const is_inline = self.code.extra[extra_index] >> 31 != 0;
const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, self.code.extra[extra_index]);
extra_index += 1;
const body = self.code.extra[extra_index..][0..body_len];
extra_index += body_len;
const body = self.code.extra[extra_index..][0..info.body_len];
extra_index += info.body_len;

try stream.writeAll(",\n");
try stream.writeByteNTimes(' ', self.indent);
if (is_inline) try stream.writeAll("inline ");
switch (info.capture) {
.none => {},
.by_val => try stream.writeAll("by_val "),
.by_ref => try stream.writeAll("by_ref "),
}
if (info.is_inline) try stream.writeAll("inline ");
try self.writeInstRef(stream, item_ref);
try stream.writeAll(" => ");
try self.writeBracedBody(stream, body);
@@ -1951,15 +1956,19 @@ const Writer = struct {
extra_index += 1;
const ranges_len = self.code.extra[extra_index];
extra_index += 1;
const body_len = @truncate(u31, self.code.extra[extra_index]);
const is_inline = self.code.extra[extra_index] >> 31 != 0;
const info = @bitCast(Zir.Inst.SwitchBlock.ProngInfo, self.code.extra[extra_index]);
extra_index += 1;
const items = self.code.refSlice(extra_index, items_len);
extra_index += items_len;

try stream.writeAll(",\n");
try stream.writeByteNTimes(' ', self.indent);
if (is_inline) try stream.writeAll("inline ");
switch (info.capture) {
.none => {},
.by_val => try stream.writeAll("by_val "),
.by_ref => try stream.writeAll("by_ref "),
}
if (info.is_inline) try stream.writeAll("inline ");

for (items, 0..) |item_ref, item_i| {
if (item_i != 0) try stream.writeAll(", ");
@@ -1981,8 +1990,8 @@ const Writer = struct {
try self.writeInstRef(stream, item_last);
}

const body = self.code.extra[extra_index..][0..body_len];
extra_index += body_len;
const body = self.code.extra[extra_index..][0..info.body_len];
extra_index += info.body_len;
try stream.writeAll(" => ");
try self.writeBracedBody(stream, body);
}
@@ -2434,12 +2443,6 @@ const Writer = struct {
try self.writeSrc(stream, src);
}

fn writeSwitchCapture(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
const inst_data = self.code.instructions.items(.data)[inst].switch_capture;
try self.writeInstIndex(stream, inst_data.switch_inst);
try stream.print(", {d})", .{inst_data.prong_index});
}

fn writeDbgStmt(self: *Writer, stream: anytype, inst: Zir.Inst.Index) !void {
const inst_data = self.code.instructions.items(.data)[inst].dbg_stmt;
try stream.print("{d}, {d})", .{ inst_data.line + 1, inst_data.column + 1 });

0 comments on commit 9c70360

Please sign in to comment.