Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove more PtrToPtr casts in GVN #126844

Merged
merged 6 commits into from
Jun 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 1 addition & 13 deletions compiler/rustc_middle/src/mir/tcx.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,19 +289,7 @@ impl<'tcx> UnOp {
pub fn ty(&self, tcx: TyCtxt<'tcx>, arg_ty: Ty<'tcx>) -> Ty<'tcx> {
match self {
UnOp::Not | UnOp::Neg => arg_ty,
UnOp::PtrMetadata => {
let pointee_ty = arg_ty
.builtin_deref(true)
.unwrap_or_else(|| bug!("PtrMetadata of non-dereferenceable ty {arg_ty:?}"));
if pointee_ty.is_trivially_sized(tcx) {
tcx.types.unit
} else {
let Some(metadata_def_id) = tcx.lang_items().metadata_type() else {
bug!("No metadata_type lang item while looking at {arg_ty:?}")
};
Ty::new_projection(tcx, metadata_def_id, [pointee_ty])
}
}
UnOp::PtrMetadata => arg_ty.pointee_metadata_ty_or_projection(tcx),
}
}
}
Expand Down
28 changes: 28 additions & 0 deletions compiler/rustc_middle/src/ty/sty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1638,6 +1638,34 @@ impl<'tcx> Ty<'tcx> {
}
}

/// Given a pointer or reference type, returns the type of the *pointee*'s
/// metadata. If it can't be determined exactly (perhaps due to still
/// being generic) then a projection through `ptr::Pointee` will be returned.
///
/// This is particularly useful for getting the type of the result of
/// [`UnOp::PtrMetadata`](crate::mir::UnOp::PtrMetadata).
///
/// Panics if `self` is not dereferencable.
#[track_caller]
pub fn pointee_metadata_ty_or_projection(self, tcx: TyCtxt<'tcx>) -> Ty<'tcx> {
let Some(pointee_ty) = self.builtin_deref(true) else {
bug!("Type {self:?} is not a pointer or reference type")
};
if pointee_ty.is_trivially_sized(tcx) {
tcx.types.unit
} else {
match pointee_ty.ptr_metadata_ty_or_tail(tcx, |x| x) {
Ok(metadata_ty) => metadata_ty,
Err(tail_ty) => {
let Some(metadata_def_id) = tcx.lang_items().metadata_type() else {
bug!("No metadata_type lang item while looking at {self:?}")
};
Ty::new_projection(tcx, metadata_def_id, [tail_ty])
}
}
}
}

/// When we create a closure, we record its kind (i.e., what trait
/// it implements, constrained by how it uses its borrows) into its
/// [`ty::ClosureArgs`] or [`ty::CoroutineClosureArgs`] using a type
Expand Down
29 changes: 22 additions & 7 deletions compiler/rustc_mir_transform/src/cost_checker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,15 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {

fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, _location: Location) {
match rvalue {
Rvalue::NullaryOp(NullOp::UbChecks, ..) if !self.tcx.sess.ub_checks() => {
Rvalue::NullaryOp(NullOp::UbChecks, ..)
if !self
.tcx
.sess
.opts
.unstable_opts
.inline_mir_preserve_debug
.unwrap_or(self.tcx.sess.ub_checks()) =>
{
// If this is in optimized MIR it's because it's used later,
// so if we don't need UB checks this session, give a bonus
// here to offset the cost of the call later.
Expand Down Expand Up @@ -111,12 +119,19 @@ impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
}
}
TerminatorKind::Assert { unwind, msg, .. } => {
self.penalty +=
if msg.is_optional_overflow_check() && !self.tcx.sess.overflow_checks() {
INSTR_COST
} else {
CALL_PENALTY
};
self.penalty += if msg.is_optional_overflow_check()
&& !self
.tcx
.sess
.opts
.unstable_opts
.inline_mir_preserve_debug
.unwrap_or(self.tcx.sess.overflow_checks())
{
INSTR_COST
} else {
CALL_PENALTY
};
if let UnwindAction::Cleanup(_) = unwind {
self.penalty += LANDINGPAD_PENALTY;
}
Expand Down
114 changes: 85 additions & 29 deletions compiler/rustc_mir_transform/src/gvn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -823,18 +823,7 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
return self.simplify_cast(kind, value, to, location);
}
Rvalue::BinaryOp(op, box (ref mut lhs, ref mut rhs)) => {
let ty = lhs.ty(self.local_decls, self.tcx);
let lhs = self.simplify_operand(lhs, location);
let rhs = self.simplify_operand(rhs, location);
// Only short-circuit options after we called `simplify_operand`
// on both operands for side effect.
let lhs = lhs?;
let rhs = rhs?;

if let Some(value) = self.simplify_binary(op, ty, lhs, rhs) {
return Some(value);
}
Value::BinaryOp(op, lhs, rhs)
return self.simplify_binary(op, lhs, rhs, location);
}
Rvalue::UnaryOp(op, ref mut arg_op) => {
return self.simplify_unary(op, arg_op, location);
Expand Down Expand Up @@ -987,23 +976,10 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
// `*const [T]` -> `*const T` which remove metadata.
// We run on potentially-generic MIR, though, so unlike codegen
// we can't always know exactly what the metadata are.
// Thankfully, equality on `ptr_metadata_ty_or_tail` gives us
// what we need: `Ok(meta_ty)` if the metadata is known, or
// `Err(tail_ty)` if not. Matching metadata is ok, but if
// that's not known, then matching tail types is also ok,
// allowing things like `*mut (?A, ?T)` <-> `*mut (?B, ?T)`.
// FIXME: Would it be worth trying to normalize, rather than
// passing the identity closure? Or are the types in the
// Cast realistically about as normalized as we can get anyway?
// To allow things like `*mut (?A, ?T)` <-> `*mut (?B, ?T)`,
// it's fine to get a projection as the type.
Value::Cast { kind: CastKind::PtrToPtr, value: inner, from, to }
if from
.builtin_deref(true)
.unwrap()
.ptr_metadata_ty_or_tail(self.tcx, |t| t)
== to
.builtin_deref(true)
.unwrap()
.ptr_metadata_ty_or_tail(self.tcx, |t| t) =>
if self.pointers_have_same_metadata(*from, *to) =>
{
arg_index = *inner;
was_updated = true;
Expand Down Expand Up @@ -1068,6 +1044,52 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {

#[instrument(level = "trace", skip(self), ret)]
fn simplify_binary(
&mut self,
op: BinOp,
lhs_operand: &mut Operand<'tcx>,
rhs_operand: &mut Operand<'tcx>,
location: Location,
) -> Option<VnIndex> {
let lhs = self.simplify_operand(lhs_operand, location);
let rhs = self.simplify_operand(rhs_operand, location);
// Only short-circuit options after we called `simplify_operand`
// on both operands for side effect.
let mut lhs = lhs?;
let mut rhs = rhs?;

let lhs_ty = lhs_operand.ty(self.local_decls, self.tcx);

// If we're comparing pointers, remove `PtrToPtr` casts if the from
// types of both casts and the metadata all match.
if let BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge = op
&& lhs_ty.is_any_ptr()
&& let Value::Cast {
kind: CastKind::PtrToPtr, value: lhs_value, from: lhs_from, ..
} = self.get(lhs)
&& let Value::Cast {
kind: CastKind::PtrToPtr, value: rhs_value, from: rhs_from, ..
} = self.get(rhs)
&& lhs_from == rhs_from
&& self.pointers_have_same_metadata(*lhs_from, lhs_ty)
{
lhs = *lhs_value;
rhs = *rhs_value;
if let Some(op) = self.try_as_operand(lhs, location) {
*lhs_operand = op;
}
if let Some(op) = self.try_as_operand(rhs, location) {
*rhs_operand = op;
}
}

if let Some(value) = self.simplify_binary_inner(op, lhs_ty, lhs, rhs) {
return Some(value);
}
let value = Value::BinaryOp(op, lhs, rhs);
Some(self.insert(value))
}

fn simplify_binary_inner(
&mut self,
op: BinOp,
lhs_ty: Ty<'tcx>,
Expand Down Expand Up @@ -1228,14 +1250,33 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
}
}

// PtrToPtr-then-PtrToPtr can skip the intermediate step
if let PtrToPtr = kind
&& let Value::Cast { kind: inner_kind, value: inner_value, from: inner_from, to: _ } =
*self.get(value)
&& let PtrToPtr = inner_kind
{
from = inner_from;
value = inner_value;
*kind = PtrToPtr;
was_updated = true;
if inner_from == to {
return Some(inner_value);
}
}

// PtrToPtr-then-Transmute can just transmute the original, so long as the
// PtrToPtr didn't change metadata (and thus the size of the pointer)
if let Transmute = kind
&& let Value::Cast {
kind: PtrToPtr,
value: inner_value,
from: inner_from,
to: inner_to,
} = *self.get(value)
&& self.pointers_have_same_metadata(inner_from, inner_to)
{
from = inner_from;
value = inner_value;
was_updated = true;
if inner_from == to {
return Some(inner_value);
Expand Down Expand Up @@ -1289,6 +1330,21 @@ impl<'body, 'tcx> VnState<'body, 'tcx> {
// Fallback: a symbolic `Len`.
Some(self.insert(Value::Len(inner)))
}

fn pointers_have_same_metadata(&self, left_ptr_ty: Ty<'tcx>, right_ptr_ty: Ty<'tcx>) -> bool {
let left_meta_ty = left_ptr_ty.pointee_metadata_ty_or_projection(self.tcx);
let right_meta_ty = right_ptr_ty.pointee_metadata_ty_or_projection(self.tcx);
if left_meta_ty == right_meta_ty {
true
} else if let Ok(left) =
self.tcx.try_normalize_erasing_regions(self.param_env, left_meta_ty)
&& let Ok(right) = self.tcx.try_normalize_erasing_regions(self.param_env, right_meta_ty)
{
left == right
} else {
false
}
}
}

fn op_to_prop_const<'tcx>(
Expand Down
8 changes: 4 additions & 4 deletions tests/coverage/closure_macro.cov-map
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,19 @@ Number of file 0 mappings: 5
- Code(Counter(0)) at (prev + 3, 1) to (start + 0, 2)

Function name: closure_macro::main::{closure#0}
Raw bytes (35): 0x[01, 01, 03, 01, 05, 05, 0b, 09, 0d, 05, 01, 10, 1c, 03, 21, 05, 04, 11, 01, 27, 02, 03, 11, 00, 16, 0d, 00, 17, 00, 1e, 07, 02, 09, 00, 0a]
Raw bytes (35): 0x[01, 01, 03, 01, 05, 05, 0b, 09, 00, 05, 01, 10, 1c, 03, 21, 05, 04, 11, 01, 27, 02, 03, 11, 00, 16, 00, 00, 17, 00, 1e, 07, 02, 09, 00, 0a]
Number of files: 1
- file 0 => global file 1
Number of expressions: 3
- expression 0 operands: lhs = Counter(0), rhs = Counter(1)
- expression 1 operands: lhs = Counter(1), rhs = Expression(2, Add)
- expression 2 operands: lhs = Counter(2), rhs = Counter(3)
- expression 2 operands: lhs = Counter(2), rhs = Zero
Number of file 0 mappings: 5
- Code(Counter(0)) at (prev + 16, 28) to (start + 3, 33)
- Code(Counter(1)) at (prev + 4, 17) to (start + 1, 39)
- Code(Expression(0, Sub)) at (prev + 3, 17) to (start + 0, 22)
= (c0 - c1)
- Code(Counter(3)) at (prev + 0, 23) to (start + 0, 30)
- Code(Zero) at (prev + 0, 23) to (start + 0, 30)
- Code(Expression(1, Add)) at (prev + 2, 9) to (start + 0, 10)
= (c1 + (c2 + c3))
= (c1 + (c2 + Zero))

8 changes: 4 additions & 4 deletions tests/coverage/closure_macro_async.cov-map
Original file line number Diff line number Diff line change
Expand Up @@ -30,19 +30,19 @@ Number of file 0 mappings: 5
- Code(Counter(0)) at (prev + 3, 1) to (start + 0, 2)

Function name: closure_macro_async::test::{closure#0}::{closure#0}
Raw bytes (35): 0x[01, 01, 03, 01, 05, 05, 0b, 09, 0d, 05, 01, 12, 1c, 03, 21, 05, 04, 11, 01, 27, 02, 03, 11, 00, 16, 0d, 00, 17, 00, 1e, 07, 02, 09, 00, 0a]
Raw bytes (35): 0x[01, 01, 03, 01, 05, 05, 0b, 09, 00, 05, 01, 12, 1c, 03, 21, 05, 04, 11, 01, 27, 02, 03, 11, 00, 16, 00, 00, 17, 00, 1e, 07, 02, 09, 00, 0a]
Number of files: 1
- file 0 => global file 1
Number of expressions: 3
- expression 0 operands: lhs = Counter(0), rhs = Counter(1)
- expression 1 operands: lhs = Counter(1), rhs = Expression(2, Add)
- expression 2 operands: lhs = Counter(2), rhs = Counter(3)
- expression 2 operands: lhs = Counter(2), rhs = Zero
Number of file 0 mappings: 5
- Code(Counter(0)) at (prev + 18, 28) to (start + 3, 33)
- Code(Counter(1)) at (prev + 4, 17) to (start + 1, 39)
- Code(Expression(0, Sub)) at (prev + 3, 17) to (start + 0, 22)
= (c0 - c1)
- Code(Counter(3)) at (prev + 0, 23) to (start + 0, 30)
- Code(Zero) at (prev + 0, 23) to (start + 0, 30)
- Code(Expression(1, Add)) at (prev + 2, 9) to (start + 0, 10)
= (c1 + (c2 + c3))
= (c1 + (c2 + Zero))

Loading
Loading