Skip to content

Commit

Permalink
fix: don't rerun SSA transform in ConstructorVisitor (#2236)
Browse files Browse the repository at this point in the history
  • Loading branch information
skylot committed Aug 2, 2024
1 parent 287ba49 commit 821cc66
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ public enum AFlag {

REQUEST_IF_REGION_OPTIMIZE, // run if region visitor again
REQUEST_CODE_SHRINK,
RERUN_SSA_TRANSFORM,

METHOD_CANDIDATE_FOR_INLINE,
USE_LINES_HINTS, // source lines info in methods can be trusted
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ public void bindArg(RegisterArg arg, BlockNode pred) {
if (blockBinds.contains(pred)) {
throw new JadxRuntimeException("Duplicate predecessors in PHI insn: " + pred + ", " + this);
}
if (pred == null) {
throw new JadxRuntimeException("Null bind block in PHI insn: " + this);
}
super.addArg(arg);
blockBinds.add(pred);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import jadx.core.dex.attributes.AFlag;
import jadx.core.dex.instructions.InsnType;
import jadx.core.dex.instructions.InvokeNode;
import jadx.core.dex.instructions.PhiInsn;
import jadx.core.dex.instructions.args.InsnArg;
import jadx.core.dex.instructions.args.LiteralArg;
import jadx.core.dex.instructions.args.RegisterArg;
Expand All @@ -18,6 +19,7 @@
import jadx.core.dex.visitors.typeinference.TypeInferenceVisitor;
import jadx.core.utils.BlockUtils;
import jadx.core.utils.InsnRemover;
import jadx.core.utils.exceptions.JadxRuntimeException;

@JadxVisitor(
name = "ConstructorVisitor",
Expand All @@ -34,9 +36,6 @@ public void visit(MethodNode mth) {
if (replaceInvoke(mth)) {
MoveInlineVisitor.moveInline(mth);
}
if (mth.contains(AFlag.RERUN_SSA_TRANSFORM)) {
SSATransform.rerun(mth);
}
}

private static boolean replaceInvoke(MethodNode mth) {
Expand Down Expand Up @@ -75,7 +74,8 @@ private static boolean processInvoke(MethodNode mth, BlockNode block, int indexI
if (assignInsn != null) {
if (assignInsn.getType() == InsnType.CONSTRUCTOR) {
// arg already used in another constructor instruction
mth.add(AFlag.RERUN_SSA_TRANSFORM);
// insert new PHI insn to merge these branched constructors results
instanceArg = insertPhiInsn(mth, block, instanceArg, ((ConstructorInsn) assignInsn));
} else {
InsnNode newInstInsn = removeAssignChain(mth, assignInsn, remover, InsnType.NEW_INSTANCE);
if (newInstInsn != null) {
Expand All @@ -99,6 +99,31 @@ private static boolean processInvoke(MethodNode mth, BlockNode block, int indexI
return true;
}

private static RegisterArg insertPhiInsn(MethodNode mth, BlockNode curBlock,
RegisterArg instArg, ConstructorInsn otherCtr) {
BlockNode otherBlock = BlockUtils.getBlockByInsn(mth, otherCtr);
if (otherBlock == null) {
throw new JadxRuntimeException("Block not found by insn: " + otherCtr);
}
BlockNode crossBlock = BlockUtils.getPathCross(mth, curBlock, otherBlock);
if (crossBlock == null) {
throw new JadxRuntimeException("Path cross not found for blocks: " + curBlock + " and " + otherBlock);
}
RegisterArg newResArg = instArg.duplicateWithNewSSAVar(mth);
RegisterArg useArg = otherCtr.getResult();
RegisterArg otherResArg = useArg.duplicateWithNewSSAVar(mth);

PhiInsn phiInsn = SSATransform.addPhi(mth, crossBlock, useArg.getRegNum());
phiInsn.setResult(useArg.duplicate());
phiInsn.bindArg(newResArg.duplicate(), BlockUtils.getPrevBlockOnPath(crossBlock, curBlock));
phiInsn.bindArg(otherResArg.duplicate(), BlockUtils.getPrevBlockOnPath(crossBlock, otherBlock));
phiInsn.rebindArgs();

otherCtr.setResult(otherResArg.duplicate());
otherCtr.rebindArgs();
return newResArg;
}

private static boolean canRemoveConstructor(MethodNode mth, ConstructorInsn co) {
ClassNode parentClass = mth.getParentClass();
if (co.isSuper() && (co.getArgsCount() == 0 || parentClass.isEnum())) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,6 @@ public void visit(MethodNode mth) throws JadxException {
process(mth);
}

public static void rerun(MethodNode mth) {
mth.remove(AFlag.RERUN_SSA_TRANSFORM);
resetSSAVars(mth);
process(mth);
}

private static void process(MethodNode mth) {
if (!mth.getSVars().isEmpty()) {
return;
Expand Down Expand Up @@ -87,7 +81,8 @@ private static void placePhi(MethodNode mth, int regNum, LiveVarAnalysis la) {
for (int id = domFrontier.nextSetBit(0); id >= 0; id = domFrontier.nextSetBit(id + 1)) {
if (!hasPhi.get(id) && la.isLive(id, regNum)) {
BlockNode df = blocks.get(id);
addPhi(mth, df, regNum);
PhiInsn phiInsn = addPhi(mth, df, regNum);
df.getInstructions().add(0, phiInsn);
hasPhi.set(id);
if (!processed.get(id)) {
processed.set(id);
Expand Down Expand Up @@ -121,7 +116,6 @@ public static PhiInsn addPhi(MethodNode mth, BlockNode block, int regNum) {
PhiInsn phiInsn = new PhiInsn(regNum, size);
phiList.getList().add(phiInsn);
phiInsn.setOffset(block.getStartOffset());
block.getInstructions().add(0, phiInsn);
return phiInsn;
}

Expand Down Expand Up @@ -457,17 +451,6 @@ private static void hidePhiInsns(MethodNode mth) {
}
}

private static void resetSSAVars(MethodNode mth) {
for (SSAVar ssaVar : mth.getSVars()) {
ssaVar.getAssign().resetSSAVar();
ssaVar.getUseList().forEach(RegisterArg::resetSSAVar);
}
for (BlockNode block : mth.getBasicBlocks()) {
block.remove(AType.PHI_LIST);
}
mth.getSVars().clear();
}

private static void removeUnusedInvokeResults(MethodNode mth) {
Iterator<SSAVar> it = mth.getSVars().iterator();
while (it.hasNext()) {
Expand Down
17 changes: 17 additions & 0 deletions jadx-core/src/main/java/jadx/core/utils/BlockUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,23 @@ public static BlockNode getNextBlockToPath(BlockNode block, BlockNode pathEnd) {
return null;
}

/**
* Return predecessor on path from 'pathStart' block
*/
public static @Nullable BlockNode getPrevBlockOnPath(BlockNode block, BlockNode pathStart) {
List<BlockNode> preds = block.getPredecessors();
if (preds.contains(pathStart)) {
return pathStart;
}
Set<BlockNode> path = getAllPathsBlocks(pathStart, block);
for (BlockNode p : preds) {
if (path.contains(p)) {
return p;
}
}
return null;
}

/**
* Visit blocks on any path from start to end.
* Only one path will be visited!
Expand Down

0 comments on commit 821cc66

Please sign in to comment.