Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Disable jump threading of float equality #128271

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions compiler/rustc_mir_transform/src/jump_threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -509,6 +509,13 @@ impl<'tcx, 'a> TOFinder<'tcx, 'a> {
BinOp::Ne => ScalarInt::FALSE,
_ => return None,
};
if value.const_.ty().is_floating_point() {
// Floating point equality does not follow bit-patterns.
// -0.0 and NaN both have special rules for equality,
// and therefore we cannot use integer comparisons for them.
// Avoid handling them, though this could be extended in the future.
return None;
}
let value = value.const_.normalize(self.tcx, self.param_env).try_to_scalar_int()?;
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While this PR is an easily backportable targeted fix, it might be worth thinking about how to prevent this more generally. It seems highly confusing that a method called try_to_scalar_int will return floats, and it doesn't surprise me that this pass was confused by this. I would also expect try_to_scalar_int to be OK for comparing bitwise.
Maybe it should be changed to not return anything for floats?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Possibly bad suggestion, but maybe we shouldn't use valtrees to represent floats 🤔

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

try_to_bitwise_repr?

let conds = conditions.map(self.arena, |c| Condition {
value,
Expand Down
59 changes: 59 additions & 0 deletions tests/mir-opt/jump_threading.floats.JumpThreading.panic-abort.diff
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
- // MIR for `floats` before JumpThreading
+ // MIR for `floats` after JumpThreading

fn floats() -> u32 {
let mut _0: u32;
let _1: f64;
let mut _2: bool;
let mut _3: bool;
let mut _4: f64;
scope 1 {
debug x => _1;
}

bb0: {
StorageLive(_1);
StorageLive(_2);
_2 = const true;
- switchInt(move _2) -> [0: bb2, otherwise: bb1];
+ goto -> bb1;
}

bb1: {
_1 = const -0f64;
goto -> bb3;
}

bb2: {
_1 = const 1f64;
goto -> bb3;
}

bb3: {
StorageDead(_2);
StorageLive(_3);
StorageLive(_4);
_4 = _1;
_3 = Eq(move _4, const 0f64);
switchInt(move _3) -> [0: bb5, otherwise: bb4];
}

bb4: {
StorageDead(_4);
_0 = const 0_u32;
goto -> bb6;
}

bb5: {
StorageDead(_4);
_0 = const 1_u32;
goto -> bb6;
}

bb6: {
StorageDead(_3);
StorageDead(_1);
return;
}
}

Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
- // MIR for `floats` before JumpThreading
+ // MIR for `floats` after JumpThreading

fn floats() -> u32 {
let mut _0: u32;
let _1: f64;
let mut _2: bool;
let mut _3: bool;
let mut _4: f64;
scope 1 {
debug x => _1;
}

bb0: {
StorageLive(_1);
StorageLive(_2);
_2 = const true;
- switchInt(move _2) -> [0: bb2, otherwise: bb1];
+ goto -> bb1;
}

bb1: {
_1 = const -0f64;
goto -> bb3;
}

bb2: {
_1 = const 1f64;
goto -> bb3;
}

bb3: {
StorageDead(_2);
StorageLive(_3);
StorageLive(_4);
_4 = _1;
_3 = Eq(move _4, const 0f64);
switchInt(move _3) -> [0: bb5, otherwise: bb4];
}

bb4: {
StorageDead(_4);
_0 = const 0_u32;
goto -> bb6;
}

bb5: {
StorageDead(_4);
_0 = const 1_u32;
goto -> bb6;
}

bb6: {
StorageDead(_3);
StorageDead(_1);
return;
}
}

12 changes: 12 additions & 0 deletions tests/mir-opt/jump_threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,16 @@ fn aggregate_copy() -> u32 {
if c == 2 { b.0 } else { 13 }
}

fn floats() -> u32 {
// CHECK-LABEL: fn floats(
// CHECK: switchInt(

// Test for issue #128243, where float equality was assumed to be bitwise.
// When adding float support, it must be ensured that this continues working properly.
let x = if true { -0.0 } else { 1.0 };
if x == 0.0 { 0 } else { 1 }
}

fn main() {
// CHECK-LABEL: fn main(
too_complex(Ok(0));
Expand All @@ -535,6 +545,7 @@ fn main() {
disappearing_bb(7);
aggregate(7);
assume(7, false);
floats();
}

// EMIT_MIR jump_threading.too_complex.JumpThreading.diff
Expand All @@ -550,3 +561,4 @@ fn main() {
// EMIT_MIR jump_threading.aggregate.JumpThreading.diff
// EMIT_MIR jump_threading.assume.JumpThreading.diff
// EMIT_MIR jump_threading.aggregate_copy.JumpThreading.diff
// EMIT_MIR jump_threading.floats.JumpThreading.diff
Loading