From b9571dfbabd29e6098281191ac6afad3a6b064c1 Mon Sep 17 00:00:00 2001
From: Takeshi Yoneda
Date: Fri, 14 Jun 2024 10:47:07 -0700
Subject: [PATCH] compiler: memory usage optimization around br_table (#2251)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

This optimizes the memory usage during compilation for br_table
instructions.

As you can see in the bench results below, for some cases where lots of
br_tables exist (the case named `zz`), the compilation uses 10% fewer
allocations and 5% less memory, hence the slightly faster compilation.

```
goos: darwin
goarch: arm64
pkg: github.com/tetratelabs/wazero
                        │   old.txt   │              new.txt               │
                        │   sec/op    │   sec/op     vs base               │
Compilation/wazero-10     2.015 ±  2%    1.993 ± 0%   -1.09% (p=0.002 n=6)
Compilation/zig-10        4.200 ±  0%    4.161 ± 1%   -0.93% (p=0.004 n=6)
Compilation/zz-10         18.70 ±  0%    18.57 ± 0%   -0.69% (p=0.002 n=6)
geomean                   5.409          5.360        -0.90%

                        │   old.txt    │               new.txt               │
                        │     B/op     │     B/op      vs base               │
Compilation/wazero-10     297.5Mi ± 0%   287.1Mi ± 0%   -3.48% (p=0.002 n=6)
Compilation/zig-10        593.9Mi ± 0%   590.3Mi ± 0%   -0.61% (p=0.002 n=6)
Compilation/zz-10         582.6Mi ± 0%   553.7Mi ± 0%   -4.96% (p=0.002 n=6)
geomean                   468.7Mi        454.4Mi        -3.03%

                        │   old.txt   │               new.txt               │
                        │  allocs/op  │  allocs/op    vs base               │
Compilation/wazero-10     457.0k ± 0%    449.1k ± 0%   -1.72% (p=0.002 n=6)
Compilation/zig-10        275.8k ± 0%    273.8k ± 0%   -0.70% (p=0.002 n=6)
Compilation/zz-10         926.5k ± 0%    830.9k ± 0%  -10.32% (p=0.002 n=6)
geomean                   488.7k         467.5k        -4.35%
```

Signed-off-by: Takeshi Yoneda
---
 .../engine/wazevo/backend/compiler_lower.go   |  3 +-
 .../wazevo/backend/isa/amd64/machine.go       | 48 +++++++++++--------
 .../wazevo/backend/isa/arm64/lower_instr.go   | 17 ++++---
 .../wazevo/backend/isa/arm64/machine.go       | 23 +++++----
 .../wazevo/backend/isa/arm64/util_test.go     |  4 +-
 internal/engine/wazevo/frontend/lower.go      | 13 ++---
 internal/engine/wazevo/ssa/basic_block.go     | 21 ++++----
 internal/engine/wazevo/ssa/builder.go         | 17 ++++++-
 internal/engine/wazevo/ssa/instructions.go    | 47 ++++++++++--------
 .../engine/wazevo/ssa/pass_blk_layouts.go     | 17 +++----
 .../wazevo/ssa/pass_blk_layouts_test.go       | 16 +++----
 11 files changed, 137 insertions(+), 89 deletions(-)

diff --git a/internal/engine/wazevo/backend/compiler_lower.go b/internal/engine/wazevo/backend/compiler_lower.go
index 80e65668ad..a58e71dd9d 100644
--- a/internal/engine/wazevo/backend/compiler_lower.go
+++ b/internal/engine/wazevo/backend/compiler_lower.go
@@ -105,11 +105,12 @@ func (c *compiler) lowerBranches(br0, br1 *ssa.Instruction) {
 	}
 
 	if br0.Opcode() == ssa.OpcodeJump {
-		_, args, target := br0.BranchData()
+		_, args, targetBlockID := br0.BranchData()
 		argExists := len(args) != 0
 		if argExists && br1 != nil {
 			panic("BUG: critical edge split failed")
 		}
+		target := c.ssaBuilder.BasicBlock(targetBlockID)
 		if argExists && target.ReturnBlock() {
 			if len(args) > 0 {
 				c.mach.LowerReturns(args)
diff --git a/internal/engine/wazevo/backend/isa/amd64/machine.go b/internal/engine/wazevo/backend/isa/amd64/machine.go
index 61ae6f4061..1d118d1578 100644
--- a/internal/engine/wazevo/backend/isa/amd64/machine.go
+++ b/internal/engine/wazevo/backend/isa/amd64/machine.go
@@ -67,8 +67,11 @@ type (
 
 		labelResolutionPends []labelResolutionPend
 
+		// jmpTableTargets holds the labels of the jump table targets.
 		jmpTableTargets [][]uint32
-		consts          []_const
+		// jmpTableTargetsNext is the index into the jmpTableTargets slice to be used for the next jump table.
+ jmpTableTargetsNext int + consts []_const constSwizzleMaskConstIndex, constSqmulRoundSatIndex, constI8x16SHLMaskTableIndex, constI8x16LogicalSHRMaskTableIndex, @@ -131,7 +134,7 @@ func (m *machine) Reset() { m.maxRequiredStackSizeForCalls = 0 m.amodePool.Reset() - m.jmpTableTargets = m.jmpTableTargets[:0] + m.jmpTableTargetsNext = 0 m.constSwizzleMaskConstIndex = -1 m.constSqmulRoundSatIndex = -1 m.constI8x16SHLMaskTableIndex = -1 @@ -187,12 +190,12 @@ func (m *machine) LowerSingleBranch(b *ssa.Instruction) { ectx := m.ectx switch b.Opcode() { case ssa.OpcodeJump: - _, _, targetBlk := b.BranchData() + _, _, targetBlkID := b.BranchData() if b.IsFallthroughJump() { return } jmp := m.allocateInstr() - target := ectx.GetOrAllocateSSABlockLabel(targetBlk) + target := ectx.GetOrAllocateSSABlockLabel(m.c.SSABuilder().BasicBlock(targetBlkID)) if target == backend.LabelReturn { jmp.asRet() } else { @@ -200,33 +203,40 @@ func (m *machine) LowerSingleBranch(b *ssa.Instruction) { } m.insert(jmp) case ssa.OpcodeBrTable: - index, target := b.BrTableData() - m.lowerBrTable(index, target) + index, targetBlkIDs := b.BrTableData() + m.lowerBrTable(index, targetBlkIDs) default: panic("BUG: unexpected branch opcode" + b.Opcode().String()) } } -func (m *machine) addJmpTableTarget(targets []ssa.BasicBlock) (index int) { - // TODO: reuse the slice! - labels := make([]uint32, len(targets)) - for j, target := range targets { - labels[j] = uint32(m.ectx.GetOrAllocateSSABlockLabel(target)) +func (m *machine) addJmpTableTarget(targets ssa.Values) (index int) { + if m.jmpTableTargetsNext == len(m.jmpTableTargets) { + m.jmpTableTargets = append(m.jmpTableTargets, make([]uint32, 0, len(targets.View()))) + } + + index = m.jmpTableTargetsNext + m.jmpTableTargetsNext++ + m.jmpTableTargets[index] = m.jmpTableTargets[index][:0] + for _, targetBlockID := range targets.View() { + target := m.c.SSABuilder().BasicBlock(ssa.BasicBlockID(targetBlockID)) + m.jmpTableTargets[index] = append(m.jmpTableTargets[index], + uint32(m.ectx.GetOrAllocateSSABlockLabel(target))) } - index = len(m.jmpTableTargets) - m.jmpTableTargets = append(m.jmpTableTargets, labels) return } var condBranchMatches = [...]ssa.Opcode{ssa.OpcodeIcmp, ssa.OpcodeFcmp} -func (m *machine) lowerBrTable(index ssa.Value, targets []ssa.BasicBlock) { +func (m *machine) lowerBrTable(index ssa.Value, targets ssa.Values) { _v := m.getOperand_Reg(m.c.ValueDefinition(index)) v := m.copyToTmp(_v.reg()) + targetCount := len(targets.View()) + // First, we need to do the bounds check. maxIndex := m.c.AllocateVReg(ssa.TypeI32) - m.lowerIconst(maxIndex, uint64(len(targets)-1), false) + m.lowerIconst(maxIndex, uint64(targetCount-1), false) cmp := m.allocateInstr().asCmpRmiR(true, newOperandReg(maxIndex), v, false) m.insert(cmp) @@ -255,23 +265,23 @@ func (m *machine) lowerBrTable(index ssa.Value, targets []ssa.BasicBlock) { jmpTable := m.allocateInstr() targetSliceIndex := m.addJmpTableTarget(targets) - jmpTable.asJmpTableSequence(targetSliceIndex, len(targets)) + jmpTable.asJmpTableSequence(targetSliceIndex, targetCount) m.insert(jmpTable) } // LowerConditionalBranch implements backend.Machine. 
func (m *machine) LowerConditionalBranch(b *ssa.Instruction) { exctx := m.ectx - cval, args, targetBlk := b.BranchData() + cval, args, targetBlkID := b.BranchData() if len(args) > 0 { panic(fmt.Sprintf( "conditional branch shouldn't have args; likely a bug in critical edge splitting: from %s to %s", exctx.CurrentSSABlk, - targetBlk, + targetBlkID, )) } - target := exctx.GetOrAllocateSSABlockLabel(targetBlk) + target := exctx.GetOrAllocateSSABlockLabel(m.c.SSABuilder().BasicBlock(targetBlkID)) cvalDef := m.c.ValueDefinition(cval) switch m.c.MatchInstrOneOf(cvalDef, condBranchMatches[:]) { diff --git a/internal/engine/wazevo/backend/isa/arm64/lower_instr.go b/internal/engine/wazevo/backend/isa/arm64/lower_instr.go index 048bf32040..87652569ae 100644 --- a/internal/engine/wazevo/backend/isa/arm64/lower_instr.go +++ b/internal/engine/wazevo/backend/isa/arm64/lower_instr.go @@ -20,11 +20,12 @@ func (m *machine) LowerSingleBranch(br *ssa.Instruction) { ectx := m.executableContext switch br.Opcode() { case ssa.OpcodeJump: - _, _, targetBlk := br.BranchData() + _, _, targetBlkID := br.BranchData() if br.IsFallthroughJump() { return } b := m.allocateInstr() + targetBlk := m.compiler.SSABuilder().BasicBlock(targetBlkID) target := ectx.GetOrAllocateSSABlockLabel(targetBlk) if target == labelReturn { b.asRet() @@ -40,7 +41,8 @@ func (m *machine) LowerSingleBranch(br *ssa.Instruction) { } func (m *machine) lowerBrTable(i *ssa.Instruction) { - index, targets := i.BrTableData() + index, targetBlockIDs := i.BrTableData() + targetBlockCount := len(targetBlockIDs.View()) indexOperand := m.getOperand_NR(m.compiler.ValueDefinition(index), extModeNone) // Firstly, we have to do the bounds check of the index, and @@ -50,7 +52,7 @@ func (m *machine) lowerBrTable(i *ssa.Instruction) { // subs wzr, index, maxIndexReg // csel adjustedIndex, maxIndexReg, index, hs ;; if index is higher or equal than maxIndexReg. maxIndexReg := m.compiler.AllocateVReg(ssa.TypeI32) - m.lowerConstantI32(maxIndexReg, int32(len(targets)-1)) + m.lowerConstantI32(maxIndexReg, int32(targetBlockCount-1)) subs := m.allocateInstr() subs.asALU(aluOpSubS, xzrVReg, indexOperand, operandNR(maxIndexReg), false) m.insert(subs) @@ -61,23 +63,24 @@ func (m *machine) lowerBrTable(i *ssa.Instruction) { brSequence := m.allocateInstr() - tableIndex := m.addJmpTableTarget(targets) - brSequence.asBrTableSequence(adjustedIndex, tableIndex, len(targets)) + tableIndex := m.addJmpTableTarget(targetBlockIDs) + brSequence.asBrTableSequence(adjustedIndex, tableIndex, targetBlockCount) m.insert(brSequence) } // LowerConditionalBranch implements backend.Machine. func (m *machine) LowerConditionalBranch(b *ssa.Instruction) { exctx := m.executableContext - cval, args, targetBlk := b.BranchData() + cval, args, targetBlkID := b.BranchData() if len(args) > 0 { panic(fmt.Sprintf( "conditional branch shouldn't have args; likely a bug in critical edge splitting: from %s to %s", exctx.CurrentSSABlk, - targetBlk, + targetBlkID, )) } + targetBlk := m.compiler.SSABuilder().BasicBlock(targetBlkID) target := exctx.GetOrAllocateSSABlockLabel(targetBlk) cvalDef := m.compiler.ValueDefinition(cval) diff --git a/internal/engine/wazevo/backend/isa/arm64/machine.go b/internal/engine/wazevo/backend/isa/arm64/machine.go index 5f584f928b..506c263936 100644 --- a/internal/engine/wazevo/backend/isa/arm64/machine.go +++ b/internal/engine/wazevo/backend/isa/arm64/machine.go @@ -35,6 +35,8 @@ type ( // jmpTableTargets holds the labels of the jump table targets. 
 		jmpTableTargets [][]uint32
+		// jmpTableTargetsNext is the index into the jmpTableTargets slice to be used for the next jump table.
+		jmpTableTargetsNext int
 
 		// spillSlotSize is the size of the stack slot in bytes used for spilling registers.
 		// During the execution of the function, the stack looks like:
@@ -151,7 +153,7 @@ func (m *machine) Reset() {
 	m.unresolvedAddressModes = m.unresolvedAddressModes[:0]
 	m.maxRequiredStackSizeForCalls = 0
 	m.executableContext.Reset()
-	m.jmpTableTargets = m.jmpTableTargets[:0]
+	m.jmpTableTargetsNext = 0
 	m.amodePool.Reset()
 }
 
@@ -508,13 +510,18 @@ func (m *machine) frameSize() int64 {
 	return s
 }
 
-func (m *machine) addJmpTableTarget(targets []ssa.BasicBlock) (index int) {
-	// TODO: reuse the slice!
-	labels := make([]uint32, len(targets))
-	for j, target := range targets {
-		labels[j] = uint32(m.executableContext.GetOrAllocateSSABlockLabel(target))
+func (m *machine) addJmpTableTarget(targets ssa.Values) (index int) {
+	if m.jmpTableTargetsNext == len(m.jmpTableTargets) {
+		m.jmpTableTargets = append(m.jmpTableTargets, make([]uint32, 0, len(targets.View())))
+	}
+
+	index = m.jmpTableTargetsNext
+	m.jmpTableTargetsNext++
+	m.jmpTableTargets[index] = m.jmpTableTargets[index][:0]
+	for _, targetBlockID := range targets.View() {
+		target := m.compiler.SSABuilder().BasicBlock(ssa.BasicBlockID(targetBlockID))
+		m.jmpTableTargets[index] = append(m.jmpTableTargets[index],
+			uint32(m.executableContext.GetOrAllocateSSABlockLabel(target)))
 	}
-	index = len(m.jmpTableTargets)
-	m.jmpTableTargets = append(m.jmpTableTargets, labels)
 	return
 }
diff --git a/internal/engine/wazevo/backend/isa/arm64/util_test.go b/internal/engine/wazevo/backend/isa/arm64/util_test.go
index a19fb33657..1ab27aa6f9 100644
--- a/internal/engine/wazevo/backend/isa/arm64/util_test.go
+++ b/internal/engine/wazevo/backend/isa/arm64/util_test.go
@@ -36,6 +36,7 @@ func newSetupWithMockContext() (*mockCompiler, ssa.Builder, *machine) {
 	m := NewBackend().(*machine)
 	m.SetCompiler(ctx)
 	ssaB := ssa.NewBuilder()
+	ctx.ssaBuilder = ssaB
 	blk := ssaB.AllocateBasicBlock()
 	ssaB.SetCurrentBlock(blk)
 	return ctx, ssaB, m
@@ -57,6 +58,7 @@ type mockCompiler struct {
 	definitions map[ssa.Value]*backend.SSAValueDefinition
 	sigs        map[ssa.SignatureID]*ssa.Signature
 	typeOf      map[regalloc.VRegID]ssa.Type
+	ssaBuilder  ssa.Builder
 	relocs      []backend.RelocationInfo
 	buf         []byte
 }
@@ -68,7 +70,7 @@ func (m *mockCompiler) GetFunctionABI(sig *ssa.Signature) *backend.FunctionABI {
 	panic("implement me")
 }
 
-func (m *mockCompiler) SSABuilder() ssa.Builder { return nil }
+func (m *mockCompiler) SSABuilder() ssa.Builder { return m.ssaBuilder }
 
 func (m *mockCompiler) LoopNestingForestRoots() []ssa.BasicBlock { panic("TODO") }
 
diff --git a/internal/engine/wazevo/frontend/lower.go b/internal/engine/wazevo/frontend/lower.go
index ff963e605b..ef0b205ad1 100644
--- a/internal/engine/wazevo/frontend/lower.go
+++ b/internal/engine/wazevo/frontend/lower.go
@@ -4068,13 +4068,14 @@ func (c *Compiler) lowerBrTable(labels []uint32, index ssa.Value) {
 		numArgs = len(f.blockType.Results)
 	}
 
-	targets := make([]ssa.BasicBlock, len(labels))
+	varPool := builder.VarLengthPool()
+	trampolineBlockIDs := varPool.Allocate(len(labels))
 
 	// We need trampoline blocks since depending on the target block structure, we might end up inserting moves before jumps,
 	// which cannot be done with br_table. Instead, we can do such per-block moves in the trampoline blocks.
 	// At the linking phase (very end of the backend), we can remove the unnecessary jumps, and therefore no runtime overhead.
 	currentBlk := builder.CurrentBlock()
-	for i, l := range labels {
+	for _, l := range labels {
 		// Args are always on the top of the stack. Note that we should not share the args slice
 		// among the jump instructions since the args are modified during passes (e.g. redundant phi elimination).
 		args := c.nPeekDup(numArgs)
@@ -4082,17 +4083,17 @@ func (c *Compiler) lowerBrTable(labels []uint32, index ssa.Value) {
 		trampoline := builder.AllocateBasicBlock()
 		builder.SetCurrentBlock(trampoline)
 		c.insertJumpToBlock(args, targetBlk)
-		targets[i] = trampoline
+		trampolineBlockIDs = trampolineBlockIDs.Append(builder.VarLengthPool(), ssa.Value(trampoline.ID()))
 	}
 	builder.SetCurrentBlock(currentBlk)
 
 	// If the target block has no arguments, we can just jump to the target block.
 	brTable := builder.AllocateInstruction()
-	brTable.AsBrTable(index, targets)
+	brTable.AsBrTable(index, trampolineBlockIDs)
 	builder.InsertInstruction(brTable)
 
-	for _, trampoline := range targets {
-		builder.Seal(trampoline)
+	for _, trampolineID := range trampolineBlockIDs.View() {
+		builder.Seal(builder.BasicBlock(ssa.BasicBlockID(trampolineID)))
 	}
 }
diff --git a/internal/engine/wazevo/ssa/basic_block.go b/internal/engine/wazevo/ssa/basic_block.go
index 0e1bab02d7..cf7f14d3b1 100644
--- a/internal/engine/wazevo/ssa/basic_block.go
+++ b/internal/engine/wazevo/ssa/basic_block.go
@@ -34,9 +34,6 @@ type BasicBlock interface {
 	// The returned Value is the definition of the param in this block.
 	Param(i int) Value
 
-	// InsertInstruction inserts an instruction that implements Value into the tail of this block.
-	InsertInstruction(raw *Instruction)
-
 	// Root returns the root instruction of this block.
 	Root() *Instruction
 
@@ -208,8 +205,8 @@ func (bb *basicBlock) Sealed() bool {
 	return bb.sealed
 }
 
-// InsertInstruction implements BasicBlock.InsertInstruction.
-func (bb *basicBlock) InsertInstruction(next *Instruction) {
+// insertInstruction inserts the instruction into the tail of this block.
+func (bb *basicBlock) insertInstruction(b *builder, next *Instruction) {
 	current := bb.currentInstr
 	if current != nil {
 		current.next = next
@@ -221,12 +218,12 @@ func (bb *basicBlock) InsertInstruction(next *Instruction) {
 
 	switch next.opcode {
 	case OpcodeJump, OpcodeBrz, OpcodeBrnz:
-		target := next.blk.(*basicBlock)
-		target.addPred(bb, next)
+		target := BasicBlockID(next.rValue)
+		b.basicBlock(target).addPred(bb, next)
 	case OpcodeBrTable:
-		for _, _target := range next.targets {
-			target := _target.(*basicBlock)
-			target.addPred(bb, next)
+		for _, _target := range next.rValues.View() {
+			target := BasicBlockID(_target)
+			b.basicBlock(target).addPred(bb, next)
 		}
 	}
 }
@@ -339,7 +336,9 @@ func (bb *basicBlock) validate(b *builder) {
 	if len(bb.preds) > 0 {
 		for _, pred := range bb.preds {
 			if pred.branch.opcode != OpcodeBrTable {
-				if target := pred.branch.blk; target != bb {
+				blockID := int(pred.branch.rValue)
+				target := b.basicBlocksPool.View(blockID)
+				if target != bb {
 					panic(fmt.Sprintf("BUG: '%s' is not branch to %s, but to %s",
 						pred.branch.Format(b), bb.Name(), target.Name()))
 				}
diff --git a/internal/engine/wazevo/ssa/builder.go b/internal/engine/wazevo/ssa/builder.go
index 8debe17b76..0eb0fc518f 100644
--- a/internal/engine/wazevo/ssa/builder.go
+++ b/internal/engine/wazevo/ssa/builder.go
@@ -129,6 +129,9 @@ type Builder interface {
 
 	// InsertZeroValue inserts a zero value constant instruction of the given type.
 	InsertZeroValue(t Type)
+
+	// BasicBlock returns the BasicBlock of the given ID.
+	BasicBlock(id BasicBlockID) BasicBlock
 }
 
 // NewBuilder returns a new Builder implementation.
@@ -214,6 +217,18 @@ type redundantParam struct {
 	uniqueValue Value
 }
 
+// BasicBlock implements Builder.BasicBlock.
+func (b *builder) BasicBlock(id BasicBlockID) BasicBlock {
+	return b.basicBlock(id)
+}
+
+func (b *builder) basicBlock(id BasicBlockID) *basicBlock {
+	if id == basicBlockIDReturnBlock {
+		return b.returnBlk
+	}
+	return b.basicBlocksPool.View(int(id))
+}
+
 // InsertZeroValue implements Builder.InsertZeroValue.
 func (b *builder) InsertZeroValue(t Type) {
 	if b.zeros[t].Valid() {
@@ -362,7 +377,7 @@ func (b *builder) Idom(blk BasicBlock) BasicBlock {
 
 // InsertInstruction implements Builder.InsertInstruction.
 func (b *builder) InsertInstruction(instr *Instruction) {
-	b.currentBB.InsertInstruction(instr)
+	b.currentBB.insertInstruction(b, instr)
 
 	if l := b.currentSourceOffset; l.Valid() {
 		// Emit the source offset info only when the instruction has side effect because
diff --git a/internal/engine/wazevo/ssa/instructions.go b/internal/engine/wazevo/ssa/instructions.go
index 3e3482efc4..9a3d1da6e9 100644
--- a/internal/engine/wazevo/ssa/instructions.go
+++ b/internal/engine/wazevo/ssa/instructions.go
@@ -25,11 +25,13 @@ type Instruction struct {
 	v3           Value
 	vs           Values
 	typ          Type
-	blk          BasicBlock
-	targets      []BasicBlock
 	prev, next   *Instruction
 
-	rValue       Value
+	// rValue is the (first) return value of this instruction.
+	// For branching instructions other than OpcodeBrTable, it holds the target BasicBlockID cast to Value.
+	rValue Value
+	// rValues are the rest of the return values of this instruction.
+	// For OpcodeBrTable, it holds the list of target BasicBlockIDs cast to Values.
 	rValues      Values
 	gid          InstructionGroupID
 	sourceOffset SourceOffset
@@ -105,6 +107,9 @@ type InstructionGroupID uint32
 // Returns Value(s) produced by this instruction if any.
 // The `first` is the first return value, and `rest` is the rest of the values.
 func (i *Instruction) Returns() (first Value, rest []Value) {
+	if i.IsBranching() {
+		return ValueInvalid, nil
+	}
 	return i.rValue, i.rValues.View()
 }
 
@@ -2077,7 +2082,7 @@ func (i *Instruction) InvertBrx() {
 }
 
 // BranchData returns the branch data for this instruction necessary for backends.
-func (i *Instruction) BranchData() (condVal Value, blockArgs []Value, target BasicBlock) {
+func (i *Instruction) BranchData() (condVal Value, blockArgs []Value, target BasicBlockID) {
 	switch i.opcode {
 	case OpcodeJump:
 		condVal = ValueInvalid
@@ -2087,17 +2092,17 @@ func (i *Instruction) BranchData() (condVal Value, blockArgs []Value, target Bas
 		panic("BUG")
 	}
 	blockArgs = i.vs.View()
-	target = i.blk
+	target = BasicBlockID(i.rValue)
 	return
 }
 
 // BrTableData returns the branch table data for this instruction necessary for backends.
-func (i *Instruction) BrTableData() (index Value, targets []BasicBlock) {
+func (i *Instruction) BrTableData() (index Value, targets Values) {
 	if i.opcode != OpcodeBrTable {
 		panic("BUG: BrTableData only available for OpcodeBrTable")
 	}
 	index = i.v
-	targets = i.targets
+	targets = i.rValues
 	return
 }
 
@@ -2105,7 +2110,7 @@ func (i *Instruction) AsJump(vs Values, target BasicBlock) *Instruction {
 	i.opcode = OpcodeJump
 	i.vs = vs
-	i.blk = target
+	i.rValue = Value(target.ID())
 	return i
 }
 
@@ -2130,7 +2135,7 @@ func (i *Instruction) AsBrz(v Value, args Values, target BasicBlock) {
 	i.opcode = OpcodeBrz
 	i.v = v
 	i.vs = args
-	i.blk = target
+	i.rValue = Value(target.ID())
 }
 
 // AsBrnz initializes this instruction as a branch-if-not-zero instruction with OpcodeBrnz.
@@ -2138,15 +2143,16 @@ func (i *Instruction) AsBrnz(v Value, args Values, target BasicBlock) *Instructi
 	i.opcode = OpcodeBrnz
 	i.v = v
 	i.vs = args
-	i.blk = target
+	i.rValue = Value(target.ID())
 	return i
 }
 
 // AsBrTable initializes this instruction as a branch-table instruction with OpcodeBrTable.
-func (i *Instruction) AsBrTable(index Value, targets []BasicBlock) {
+// targets is a list of basic block IDs cast to Values.
+func (i *Instruction) AsBrTable(index Value, targets Values) {
 	i.opcode = OpcodeBrTable
 	i.v = index
-	i.targets = targets
+	i.rValues = targets
 }
 
 // AsCall initializes this instruction as a call instruction with OpcodeCall.
@@ -2531,7 +2537,8 @@ func (i *Instruction) Format(b Builder) string {
 		if i.IsFallthroughJump() {
 			vs[0] = " fallthrough"
 		} else {
-			vs[0] = " " + i.blk.(*basicBlock).Name()
+			blockID := BasicBlockID(i.rValue)
+			vs[0] = " " + b.BasicBlock(blockID).Name()
 		}
 		for idx := range view {
 			vs[idx+1] = view[idx].Format(b)
@@ -2542,7 +2549,8 @@ func (i *Instruction) Format(b Builder) string {
 		view := i.vs.View()
 		vs := make([]string, len(view)+2)
 		vs[0] = " " + i.v.Format(b)
-		vs[1] = i.blk.(*basicBlock).Name()
+		blockID := BasicBlockID(i.rValue)
+		vs[1] = b.BasicBlock(blockID).Name()
 		for idx := range view {
 			vs[idx+2] = view[idx].Format(b)
 		}
@@ -2551,8 +2559,8 @@ func (i *Instruction) Format(b Builder) string {
 		// `BrTable index, [label1, label2, ...
labelN]` instSuffix = fmt.Sprintf(" %s", i.v.Format(b)) instSuffix += ", [" - for i, target := range i.targets { - blk := target.(*basicBlock) + for i, target := range i.rValues.View() { + blk := b.BasicBlock(BasicBlockID(target)) if i == 0 { instSuffix += blk.Name() } else { @@ -2621,11 +2629,12 @@ func (i *Instruction) Format(b Builder) string { instr := i.opcode.String() + instSuffix var rvs []string - if rv := i.rValue; rv.Valid() { - rvs = append(rvs, rv.formatWithType(b)) + r1, rs := i.Returns() + if r1.Valid() { + rvs = append(rvs, r1.formatWithType(b)) } - for _, v := range i.rValues.View() { + for _, v := range rs { rvs = append(rvs, v.formatWithType(b)) } diff --git a/internal/engine/wazevo/ssa/pass_blk_layouts.go b/internal/engine/wazevo/ssa/pass_blk_layouts.go index c7e14fb489..0118e8b2e5 100644 --- a/internal/engine/wazevo/ssa/pass_blk_layouts.go +++ b/internal/engine/wazevo/ssa/pass_blk_layouts.go @@ -33,7 +33,7 @@ func passLayoutBlocks(b *builder) { } nonSplitBlocks = append(nonSplitBlocks, blk) if i != len(b.reversePostOrderedBasicBlocks)-1 { - _ = maybeInvertBranches(blk, b.reversePostOrderedBasicBlocks[i+1]) + _ = maybeInvertBranches(b, blk, b.reversePostOrderedBasicBlocks[i+1]) } } @@ -111,7 +111,7 @@ func passLayoutBlocks(b *builder) { } fallthroughBranch := blk.currentInstr - if fallthroughBranch.opcode == OpcodeJump && fallthroughBranch.blk == trampoline { + if fallthroughBranch.opcode == OpcodeJump && BasicBlockID(fallthroughBranch.rValue) == trampoline.id { // This can be lowered as fallthrough at the end of the block. b.reversePostOrderedBasicBlocks = append(b.reversePostOrderedBasicBlocks, trampoline) trampoline.visited = 1 // mark as inserted. @@ -157,7 +157,7 @@ func (b *builder) markFallthroughJumps() { for i, blk := range b.reversePostOrderedBasicBlocks { if i < l { cur := blk.currentInstr - if cur.opcode == OpcodeJump && cur.blk == b.reversePostOrderedBasicBlocks[i+1] { + if cur.opcode == OpcodeJump && BasicBlockID(cur.rValue) == b.reversePostOrderedBasicBlocks[i+1].id { cur.AsFallthroughJump() } } @@ -168,7 +168,7 @@ func (b *builder) markFallthroughJumps() { // nextInRPO is the next block in the reverse post-order. // // Returns true if the branch is inverted for testing purpose. -func maybeInvertBranches(now *basicBlock, nextInRPO *basicBlock) bool { +func maybeInvertBranches(b *builder, now *basicBlock, nextInRPO *basicBlock) bool { fallthroughBranch := now.currentInstr if fallthroughBranch.opcode == OpcodeBrTable { return false @@ -187,7 +187,8 @@ func maybeInvertBranches(now *basicBlock, nextInRPO *basicBlock) bool { // So this block has two branches (a conditional branch followed by an unconditional branch) at the end. // We can invert the condition of the branch if it makes the fallthrough more likely. 
- fallthroughTarget, condTarget := fallthroughBranch.blk.(*basicBlock), condBranch.blk.(*basicBlock) + fallthroughTarget := b.basicBlock(BasicBlockID(fallthroughBranch.rValue)) + condTarget := b.basicBlock(BasicBlockID(condBranch.rValue)) if fallthroughTarget.loopHeader { // First, if the tail's target is loopHeader, we don't need to do anything here, @@ -231,8 +232,8 @@ invert: } condBranch.InvertBrx() - condBranch.blk = fallthroughTarget - fallthroughBranch.blk = condTarget + condBranch.rValue = Value(fallthroughTarget.ID()) + fallthroughBranch.rValue = Value(condTarget.ID()) if wazevoapi.SSALoggingEnabled { fmt.Printf("inverting branches at %d->%d and %d->%d\n", now.ID(), fallthroughTarget.ID(), now.ID(), condTarget.ID()) @@ -275,7 +276,7 @@ func (b *builder) splitCriticalEdge(pred, succ *basicBlock, predInfo *basicBlock // Replace originalBranch with the newBranch. newBranch := b.AllocateInstruction() newBranch.opcode = originalBranch.opcode - newBranch.blk = trampoline + newBranch.rValue = Value(trampoline.ID()) switch originalBranch.opcode { case OpcodeJump: case OpcodeBrz, OpcodeBrnz: diff --git a/internal/engine/wazevo/ssa/pass_blk_layouts_test.go b/internal/engine/wazevo/ssa/pass_blk_layouts_test.go index cc4e004374..6fa5411dbf 100644 --- a/internal/engine/wazevo/ssa/pass_blk_layouts_test.go +++ b/internal/engine/wazevo/ssa/pass_blk_layouts_test.go @@ -120,9 +120,9 @@ func Test_maybeInvertBranch(t *testing.T) { require.Equal(t, tail, next.preds[0].branch) verify = func(t *testing.T) { require.Equal(t, OpcodeJump, tail.opcode) - require.Equal(t, OpcodeBrnz, conditionalBr.opcode) // inversion. - require.Equal(t, loopHeader, tail.blk) // swapped. - require.Equal(t, next, conditionalBr.blk) // swapped. + require.Equal(t, OpcodeBrnz, conditionalBr.opcode) // inversion. + require.Equal(t, loopHeader, b.basicBlock(BasicBlockID(tail.rValue))) // swapped. + require.Equal(t, next, b.basicBlock(BasicBlockID(conditionalBr.rValue))) // swapped. require.Equal(t, conditionalBr, tail.prev) // Predecessor info should correctly point to the inverted jump instruction. @@ -150,9 +150,9 @@ func Test_maybeInvertBranch(t *testing.T) { verify = func(t *testing.T) { require.Equal(t, OpcodeJump, tail.opcode) - require.Equal(t, OpcodeBrnz, conditionalBr.opcode) // inversion. - require.Equal(t, next, tail.blk) // swapped. - require.Equal(t, nowTarget, conditionalBr.blk) // swapped. + require.Equal(t, OpcodeBrnz, conditionalBr.opcode) // inversion. + require.Equal(t, next, b.basicBlock(BasicBlockID(tail.rValue))) // swapped. + require.Equal(t, nowTarget, b.basicBlock(BasicBlockID(conditionalBr.rValue))) // swapped. require.Equal(t, conditionalBr, tail.prev) require.Equal(t, conditionalBr, nowTarget.preds[0].branch) @@ -166,7 +166,7 @@ func Test_maybeInvertBranch(t *testing.T) { t.Run(tc.name, func(t *testing.T) { b := NewBuilder().(*builder) now, next, verify := tc.setup(b) - actual := maybeInvertBranches(now, next) + actual := maybeInvertBranches(b, now, next) verify(t) require.Equal(t, tc.exp, actual) }) @@ -202,7 +202,7 @@ func TestBuilder_splitCriticalEdge(t *testing.T) { replacedBrz := predBlk.rootInstr.next require.Equal(t, OpcodeBrz, replacedBrz.opcode) - require.Equal(t, trampoline, replacedBrz.blk) + require.Equal(t, trampoline, b.basicBlock(BasicBlockID(replacedBrz.rValue))) } func Test_swapInstruction(t *testing.T) {
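
The heart of the allocation saving is the pair of addJmpTableTarget rewrites above: Reset now rewinds a cursor (m.jmpTableTargetsNext = 0) instead of truncating jmpTableTargets, so the inner []uint32 label buffers keep their capacity and are reused by the next compilation, and the per-br_table allocation flagged by the removed `// TODO: reuse the slice!` disappears. Below is a minimal, self-contained sketch of that pattern; jumpTablePool and its methods are illustrative names, not wazero's API.

```
package main

import "fmt"

// jumpTablePool mimics the jmpTableTargets/jmpTableTargetsNext pair:
// the outer slice grows once, and reset only rewinds the cursor, so the
// inner label slices keep their capacity across compilations.
type jumpTablePool struct {
	tables [][]uint32
	next   int
}

func (p *jumpTablePool) add(labels []uint32) (index int) {
	if p.next == len(p.tables) {
		// First compilation to reach this depth: allocate a fresh buffer.
		p.tables = append(p.tables, make([]uint32, 0, len(labels)))
	}
	index = p.next
	p.next++
	p.tables[index] = append(p.tables[index][:0], labels...) // reuse capacity
	return index
}

// reset keeps the buffers, mirroring m.jmpTableTargetsNext = 0 in the patch.
func (p *jumpTablePool) reset() { p.next = 0 }

func main() {
	var p jumpTablePool
	for run := 0; run < 2; run++ {
		p.reset()
		i := p.add([]uint32{10, 20, 30})
		fmt.Println(run, p.tables[i]) // the second run allocates nothing new
	}
}
```

With many br_tables per module, as in the `zz` case, the steady state allocates nothing here, which is consistent with the ~10% allocs/op drop in the benchmark.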
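The SSA-side change is representational: Instruction loses its blk BasicBlock and targets []BasicBlock fields, and branch targets are instead BasicBlockIDs cast into the existing rValue/rValues slots (branches produce no SSA results, so Returns() now yields nothing for branching instructions). Backends recover the block on demand through the new Builder.BasicBlock(id). A toy model of that encode/decode round-trip, with stand-in types rather than wazero's:

```
package main

import "fmt"

type (
	Value        uint64
	BasicBlockID uint32
)

type basicBlock struct{ id BasicBlockID }

// builder owns all blocks, so a compact ID is enough to recover the pointer.
type builder struct{ blocks []*basicBlock }

func (b *builder) basicBlock(id BasicBlockID) *basicBlock { return b.blocks[id] }

// instruction reuses its result slot to carry the branch target's ID.
type instruction struct{ rValue Value }

func main() {
	b := &builder{blocks: []*basicBlock{{id: 0}, {id: 1}}}

	jump := instruction{rValue: Value(1)}             // encode: block ID as Value
	target := b.basicBlock(BasicBlockID(jump.rValue)) // decode when lowering

	fmt.Println(target.id) // 1
}
```

Packing IDs into Values also lets br_table targets live in the pooled ssa.Values arena rather than a dedicated []BasicBlock per instruction, which is the second source of the savings.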
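For context on the lowering logic both backends share (unchanged by this patch apart from the target representation): the br_table index is first clamped to the table's last entry, the Wasm default target, using branchless selects (cmp/cmov on amd64, subs/csel on arm64) before the indirect jump. A plain-Go model of that clamp-then-dispatch shape, purely illustrative:

```
package main

import "fmt"

// brTable models the lowered shape: clamp the index, then one indexed jump.
// targets must be non-empty; Wasm guarantees at least the default label.
func brTable(index uint32, targets []func()) {
	max := uint32(len(targets) - 1)
	if index >= max { // emitted as a conditional move/select, not a branch
		index = max
	}
	targets[index]()
}

func main() {
	table := []func(){
		func() { fmt.Println("case 0") },
		func() { fmt.Println("case 1") },
		func() { fmt.Println("default") },
	}
	brTable(1, table)   // case 1
	brTable(100, table) // out of range: clamped to the default entry
}
```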
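Finally, the frontend comments in lowerBrTable explain why every br_table entry targets a trampoline block: the table jump itself cannot perform per-target argument moves, so each entry lands in a tiny block that copies the args (they must not be shared across edges, since later passes may rewrite them) and then jumps unconditionally to the real target, a jump the linking phase can often erase. A closure-based toy of that shape, not the actual SSA machinery:

```
package main

import "fmt"

func main() {
	// Real targets, each expecting its own copy of the block arguments.
	blk0 := func(args []int) { fmt.Println("blk0", args) }
	blk1 := func(args []int) { fmt.Println("blk1", args) }

	stackTop := []int{7, 9} // the args currently on the value stack

	// One trampoline per br_table entry: perform the per-target moves
	// (here, a fresh copy of the args), then jump to the real target.
	trampolines := []func(){
		func() { blk0(append([]int(nil), stackTop...)) },
		func() { blk1(append([]int(nil), stackTop...)) },
	}

	// br_table itself stays a bare indexed jump with no per-target logic.
	trampolines[1]()
}
```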