From 9be87d0746d8acbf43387e496b013d14754925c2 Mon Sep 17 00:00:00 2001 From: Raphael Mosaner Date: Tue, 3 Oct 2023 13:27:07 +0200 Subject: [PATCH] [GR-48705] Infer input msb during stamp inversion for integer SignExtend. --- .../nodes/test/StampInverterTest.java | 16 +++++++++ .../core/common/type/IntegerStamp.java | 35 +++++++++++++++++-- 2 files changed, 48 insertions(+), 3 deletions(-) diff --git a/compiler/src/jdk.internal.vm.compiler.test/src/org/graalvm/compiler/nodes/test/StampInverterTest.java b/compiler/src/jdk.internal.vm.compiler.test/src/org/graalvm/compiler/nodes/test/StampInverterTest.java index d9eede51abb7..ac91399a2e32 100644 --- a/compiler/src/jdk.internal.vm.compiler.test/src/org/graalvm/compiler/nodes/test/StampInverterTest.java +++ b/compiler/src/jdk.internal.vm.compiler.test/src/org/graalvm/compiler/nodes/test/StampInverterTest.java @@ -80,6 +80,22 @@ public void invertIntegerSignExtend04() { assertTrue("Stamp cannot be inverted and should be empty!", invertSignExtend(stamp).isEmpty()); } + @Test + public void invertIntegerSignExtend05() { + // 32 -> 8bit: xx...x0 xxxxxxxx -> 0xxxxxxx (msb has to be 0) + IntegerStamp stamp = IntegerStamp.stampForMask(32, 0, CodeUtil.mask(32) ^ 256); + Stamp expected = IntegerStamp.stampForMask(8, 0, CodeUtil.mask(7)); + assertEquals(expected, invertSignExtend(stamp)); + } + + @Test + public void invertIntegerSignExtend06() { + // 32 -> 8bit: xx...x1 xxxxxxxx -> 1xxxxxxx (msb has to be 1) + IntegerStamp stamp = IntegerStamp.stampForMask(32, 256, CodeUtil.mask(32)); + Stamp expected = IntegerStamp.stampForMask(8, 128, CodeUtil.mask(8)); + assertEquals(expected, invertSignExtend(stamp)); + } + private static Stamp invertZeroExtend(Stamp toInvert) { IntegerConvertOp signExtend = ArithmeticOpTable.forStamp(toInvert).getZeroExtend(); return signExtend.invertStamp(8, 32, toInvert); diff --git a/compiler/src/jdk.internal.vm.compiler/src/org/graalvm/compiler/core/common/type/IntegerStamp.java b/compiler/src/jdk.internal.vm.compiler/src/org/graalvm/compiler/core/common/type/IntegerStamp.java index e29e8f4d53bf..37ef6bc4b3b3 100644 --- a/compiler/src/jdk.internal.vm.compiler/src/org/graalvm/compiler/core/common/type/IntegerStamp.java +++ b/compiler/src/jdk.internal.vm.compiler/src/org/graalvm/compiler/core/common/type/IntegerStamp.java @@ -2092,11 +2092,40 @@ public Stamp invertStamp(int inputBits, int resultBits, Stamp outStamp) { return createEmptyStamp(inputBits); } + /* + * Calculate bounds and mayBeSet/mustBeSet bits for the input based on + * bit width and potentially inferred msb. + */ long inputMask = CodeUtil.mask(inputBits); - long inputUpperBound = maxValueForMasks(inputBits, stamp.mustBeSet() & inputMask, stamp.mayBeSet() & inputMask); - long inputLowerBound = minValueForMasks(inputBits, stamp.mustBeSet() & inputMask, stamp.mayBeSet() & inputMask); + long inputMustBeSet = stamp.mustBeSet() & inputMask; + long inputMayBeSet = stamp.mayBeSet() & inputMask; - return StampFactory.forIntegerWithMask(inputBits, inputLowerBound, inputUpperBound, stamp.mustBeSet() & inputMask, stamp.mayBeSet() & inputMask); + if (!inputMSBOne && !inputMSBZero) { + /* + * Input MSB yet unknown, try to infer it from the extension: + * + * @formatter:off + * + * xx0x xxxx implies that the extension is 0000 which implies that the MSB of the input is 0 + * x1xx xxxx implies that the extension is 1111 which implies that the MSB of the input is 1 + * + * @formatter:on + */ + if (zeroInExtension) { + long msbZeroMask = inputMask ^ (1 << (inputBits - 1)); + inputMustBeSet &= msbZeroMask; + inputMayBeSet &= msbZeroMask; + } else if (oneInExtension) { + long msbOneMask = 1 << (inputBits - 1); + inputMustBeSet |= msbOneMask; + inputMayBeSet |= msbOneMask; + } + } + + long inputUpperBound = maxValueForMasks(inputBits, inputMustBeSet, inputMayBeSet); + long inputLowerBound = minValueForMasks(inputBits, inputMustBeSet, inputMayBeSet); + + return StampFactory.forIntegerWithMask(inputBits, inputLowerBound, inputUpperBound, inputMustBeSet, inputMayBeSet); } },