From 1d48f69d65af74201314304623f37f5bcefa9a24 Mon Sep 17 00:00:00 2001
From: Michael Goulet <michael@errs.io>
Date: Thu, 4 Jan 2024 01:46:43 +0000
Subject: [PATCH] Check yield terminator's resume type in borrowck

---
 .../src/type_check/input_output.rs            | 33 +++++++----------
 .../src/type_check/liveness/mod.rs            |  1 +
 compiler/rustc_borrowck/src/type_check/mod.rs | 26 ++++++++++++--
 .../rustc_borrowck/src/universal_regions.rs   | 12 +++++--
 compiler/rustc_middle/src/mir/mod.rs          |  9 +++++
 compiler/rustc_middle/src/mir/visit.rs        |  8 +++++
 compiler/rustc_mir_build/src/build/mod.rs     | 27 ++++++++------
 compiler/rustc_mir_transform/src/coroutine.rs |  1 +
 .../coroutine/check-resume-ty-lifetimes-2.rs  | 35 ++++++++++++++++++
 .../check-resume-ty-lifetimes-2.stderr        | 36 +++++++++++++++++++
 .../ui/coroutine/check-resume-ty-lifetimes.rs | 27 ++++++++++++++
 .../check-resume-ty-lifetimes.stderr          | 11 ++++++
 12 files changed, 190 insertions(+), 36 deletions(-)
 create mode 100644 tests/ui/coroutine/check-resume-ty-lifetimes-2.rs
 create mode 100644 tests/ui/coroutine/check-resume-ty-lifetimes-2.stderr
 create mode 100644 tests/ui/coroutine/check-resume-ty-lifetimes.rs
 create mode 100644 tests/ui/coroutine/check-resume-ty-lifetimes.stderr

diff --git a/compiler/rustc_borrowck/src/type_check/input_output.rs b/compiler/rustc_borrowck/src/type_check/input_output.rs
index 5bd7cc9514ca2..61b6bef3b87b9 100644
--- a/compiler/rustc_borrowck/src/type_check/input_output.rs
+++ b/compiler/rustc_borrowck/src/type_check/input_output.rs
@@ -94,31 +94,22 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
             );
         }
 
-        debug!(
-            "equate_inputs_and_outputs: body.yield_ty {:?}, universal_regions.yield_ty {:?}",
-            body.yield_ty(),
-            universal_regions.yield_ty
-        );
-
-        // We will not have a universal_regions.yield_ty if we yield (by accident)
-        // outside of a coroutine and return an `impl Trait`, so emit a span_delayed_bug
-        // because we don't want to panic in an assert here if we've already got errors.
-        if body.yield_ty().is_some() != universal_regions.yield_ty.is_some() {
-            self.tcx().dcx().span_delayed_bug(
-                body.span,
-                format!(
-                    "Expected body to have yield_ty ({:?}) iff we have a UR yield_ty ({:?})",
-                    body.yield_ty(),
-                    universal_regions.yield_ty,
-                ),
+        if let Some(mir_yield_ty) = body.yield_ty() {
+            let yield_span = body.local_decls[RETURN_PLACE].source_info.span;
+            self.equate_normalized_input_or_output(
+                universal_regions.yield_ty.unwrap(),
+                mir_yield_ty,
+                yield_span,
             );
         }
 
-        if let (Some(mir_yield_ty), Some(ur_yield_ty)) =
-            (body.yield_ty(), universal_regions.yield_ty)
-        {
+        if let Some(mir_resume_ty) = body.resume_ty() {
             let yield_span = body.local_decls[RETURN_PLACE].source_info.span;
-            self.equate_normalized_input_or_output(ur_yield_ty, mir_yield_ty, yield_span);
+            self.equate_normalized_input_or_output(
+                universal_regions.resume_ty.unwrap(),
+                mir_resume_ty,
+                yield_span,
+            );
         }
 
         // Return types are a bit more complex. They may contain opaque `impl Trait` types.
diff --git a/compiler/rustc_borrowck/src/type_check/liveness/mod.rs b/compiler/rustc_borrowck/src/type_check/liveness/mod.rs
index dc4695fd2b058..e137bc1be0aeb 100644
--- a/compiler/rustc_borrowck/src/type_check/liveness/mod.rs
+++ b/compiler/rustc_borrowck/src/type_check/liveness/mod.rs
@@ -183,6 +183,7 @@ impl<'cx, 'tcx> Visitor<'tcx> for LiveVariablesVisitor<'cx, 'tcx> {
         match ty_context {
             TyContext::ReturnTy(SourceInfo { span, .. })
             | TyContext::YieldTy(SourceInfo { span, .. })
+            | TyContext::ResumeTy(SourceInfo { span, .. })
             | TyContext::UserTy(span)
             | TyContext::LocalDecl { source_info: SourceInfo { span, .. }, .. } => {
                 span_bug!(span, "should not be visiting outside of the CFG: {:?}", ty_context);
diff --git a/compiler/rustc_borrowck/src/type_check/mod.rs b/compiler/rustc_borrowck/src/type_check/mod.rs
index 80575e30a8d23..9c0f53ddb86fa 100644
--- a/compiler/rustc_borrowck/src/type_check/mod.rs
+++ b/compiler/rustc_borrowck/src/type_check/mod.rs
@@ -1450,13 +1450,13 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
                     }
                 }
             }
-            TerminatorKind::Yield { value, .. } => {
+            TerminatorKind::Yield { value, resume_arg, .. } => {
                 self.check_operand(value, term_location);
 
-                let value_ty = value.ty(body, tcx);
                 match body.yield_ty() {
                     None => span_mirbug!(self, term, "yield in non-coroutine"),
                     Some(ty) => {
+                        let value_ty = value.ty(body, tcx);
                         if let Err(terr) = self.sub_types(
                             value_ty,
                             ty,
@@ -1474,6 +1474,28 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
                         }
                     }
                 }
+
+                match body.resume_ty() {
+                    None => span_mirbug!(self, term, "yield in non-coroutine"),
+                    Some(ty) => {
+                        let resume_ty = resume_arg.ty(body, tcx);
+                        if let Err(terr) = self.sub_types(
+                            ty,
+                            resume_ty.ty,
+                            term_location.to_locations(),
+                            ConstraintCategory::Yield,
+                        ) {
+                            span_mirbug!(
+                                self,
+                                term,
+                                "type of resume place is {:?}, but the resume type is {:?}: {:?}",
+                                resume_ty,
+                                ty,
+                                terr
+                            );
+                        }
+                    }
+                }
             }
         }
     }
diff --git a/compiler/rustc_borrowck/src/universal_regions.rs b/compiler/rustc_borrowck/src/universal_regions.rs
index a02304a2f8b30..addb41ff5fc8f 100644
--- a/compiler/rustc_borrowck/src/universal_regions.rs
+++ b/compiler/rustc_borrowck/src/universal_regions.rs
@@ -76,6 +76,8 @@ pub struct UniversalRegions<'tcx> {
     pub unnormalized_input_tys: &'tcx [Ty<'tcx>],
 
     pub yield_ty: Option<Ty<'tcx>>,
+
+    pub resume_ty: Option<Ty<'tcx>>,
 }
 
 /// The "defining type" for this MIR. The key feature of the "defining
@@ -525,9 +527,12 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
         debug!("build: extern regions = {}..{}", first_extern_index, first_local_index);
         debug!("build: local regions  = {}..{}", first_local_index, num_universals);
 
-        let yield_ty = match defining_ty {
-            DefiningTy::Coroutine(_, args) => Some(args.as_coroutine().yield_ty()),
-            _ => None,
+        let (resume_ty, yield_ty) = match defining_ty {
+            DefiningTy::Coroutine(_, args) => {
+                let tys = args.as_coroutine();
+                (Some(tys.resume_ty()), Some(tys.yield_ty()))
+            }
+            _ => (None, None),
         };
 
         UniversalRegions {
@@ -541,6 +546,7 @@ impl<'cx, 'tcx> UniversalRegionsBuilder<'cx, 'tcx> {
             unnormalized_output_ty: *unnormalized_output_ty,
             unnormalized_input_tys,
             yield_ty,
+            resume_ty,
         }
     }
 
diff --git a/compiler/rustc_middle/src/mir/mod.rs b/compiler/rustc_middle/src/mir/mod.rs
index 5c425fef27ebc..01ad3aefffa3b 100644
--- a/compiler/rustc_middle/src/mir/mod.rs
+++ b/compiler/rustc_middle/src/mir/mod.rs
@@ -250,6 +250,9 @@ pub struct CoroutineInfo<'tcx> {
     /// The yield type of the function, if it is a coroutine.
     pub yield_ty: Option<Ty<'tcx>>,
 
+    /// The resume type of the function, if it is a coroutine.
+    pub resume_ty: Option<Ty<'tcx>>,
+
     /// Coroutine drop glue.
     pub coroutine_drop: Option<Body<'tcx>>,
 
@@ -385,6 +388,7 @@ impl<'tcx> Body<'tcx> {
             coroutine: coroutine_kind.map(|coroutine_kind| {
                 Box::new(CoroutineInfo {
                     yield_ty: None,
+                    resume_ty: None,
                     coroutine_drop: None,
                     coroutine_layout: None,
                     coroutine_kind,
@@ -551,6 +555,11 @@ impl<'tcx> Body<'tcx> {
         self.coroutine.as_ref().and_then(|coroutine| coroutine.yield_ty)
     }
 
+    #[inline]
+    pub fn resume_ty(&self) -> Option<Ty<'tcx>> {
+        self.coroutine.as_ref().and_then(|coroutine| coroutine.resume_ty)
+    }
+
     #[inline]
     pub fn coroutine_layout(&self) -> Option<&CoroutineLayout<'tcx>> {
         self.coroutine.as_ref().and_then(|coroutine| coroutine.coroutine_layout.as_ref())
diff --git a/compiler/rustc_middle/src/mir/visit.rs b/compiler/rustc_middle/src/mir/visit.rs
index 132ecf91af187..2ccf5a9f6f7ad 100644
--- a/compiler/rustc_middle/src/mir/visit.rs
+++ b/compiler/rustc_middle/src/mir/visit.rs
@@ -996,6 +996,12 @@ macro_rules! super_body {
                     TyContext::YieldTy(SourceInfo::outermost(span))
                 );
             }
+            if let Some(resume_ty) = $(& $mutability)? gen.resume_ty {
+                $self.visit_ty(
+                    resume_ty,
+                    TyContext::ResumeTy(SourceInfo::outermost(span))
+                );
+            }
         }
 
         for (bb, data) in basic_blocks_iter!($body, $($mutability, $invalidate)?) {
@@ -1244,6 +1250,8 @@ pub enum TyContext {
 
     YieldTy(SourceInfo),
 
+    ResumeTy(SourceInfo),
+
     /// A type found at some location.
     Location(Location),
 }
diff --git a/compiler/rustc_mir_build/src/build/mod.rs b/compiler/rustc_mir_build/src/build/mod.rs
index e0199fb876717..c4cade839478c 100644
--- a/compiler/rustc_mir_build/src/build/mod.rs
+++ b/compiler/rustc_mir_build/src/build/mod.rs
@@ -488,7 +488,7 @@ fn construct_fn<'tcx>(
 
     let arguments = &thir.params;
 
-    let (yield_ty, return_ty) = if coroutine_kind.is_some() {
+    let (resume_ty, yield_ty, return_ty) = if coroutine_kind.is_some() {
         let coroutine_ty = arguments[thir::UPVAR_ENV_PARAM].ty;
         let coroutine_sig = match coroutine_ty.kind() {
             ty::Coroutine(_, gen_args, ..) => gen_args.as_coroutine().sig(),
@@ -496,9 +496,9 @@ fn construct_fn<'tcx>(
                 span_bug!(span, "coroutine w/o coroutine type: {:?}", coroutine_ty)
             }
         };
-        (Some(coroutine_sig.yield_ty), coroutine_sig.return_ty)
+        (Some(coroutine_sig.resume_ty), Some(coroutine_sig.yield_ty), coroutine_sig.return_ty)
     } else {
-        (None, fn_sig.output())
+        (None, None, fn_sig.output())
     };
 
     if let Some(custom_mir_attr) =
@@ -562,9 +562,12 @@ fn construct_fn<'tcx>(
     } else {
         None
     };
-    if yield_ty.is_some() {
+
+    if coroutine_kind.is_some() {
         body.coroutine.as_mut().unwrap().yield_ty = yield_ty;
+        body.coroutine.as_mut().unwrap().resume_ty = resume_ty;
     }
+
     body
 }
 
@@ -631,18 +634,18 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
     let hir_id = tcx.local_def_id_to_hir_id(def_id);
     let coroutine_kind = tcx.coroutine_kind(def_id);
 
-    let (inputs, output, yield_ty) = match tcx.def_kind(def_id) {
+    let (inputs, output, resume_ty, yield_ty) = match tcx.def_kind(def_id) {
         DefKind::Const
         | DefKind::AssocConst
         | DefKind::AnonConst
         | DefKind::InlineConst
-        | DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None),
+        | DefKind::Static(_) => (vec![], tcx.type_of(def_id).instantiate_identity(), None, None),
         DefKind::Ctor(..) | DefKind::Fn | DefKind::AssocFn => {
             let sig = tcx.liberate_late_bound_regions(
                 def_id.to_def_id(),
                 tcx.fn_sig(def_id).instantiate_identity(),
             );
-            (sig.inputs().to_vec(), sig.output(), None)
+            (sig.inputs().to_vec(), sig.output(), None, None)
         }
         DefKind::Closure if coroutine_kind.is_some() => {
             let coroutine_ty = tcx.type_of(def_id).instantiate_identity();
@@ -650,9 +653,10 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
                 bug!("expected type of coroutine-like closure to be a coroutine")
             };
             let args = args.as_coroutine();
+            let resume_ty = args.resume_ty();
             let yield_ty = args.yield_ty();
             let return_ty = args.return_ty();
-            (vec![coroutine_ty, args.resume_ty()], return_ty, Some(yield_ty))
+            (vec![coroutine_ty, args.resume_ty()], return_ty, Some(resume_ty), Some(yield_ty))
         }
         DefKind::Closure => {
             let closure_ty = tcx.type_of(def_id).instantiate_identity();
@@ -666,7 +670,7 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
                 ty::ClosureKind::FnMut => Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, closure_ty),
                 ty::ClosureKind::FnOnce => closure_ty,
             };
-            ([self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), sig.output(), None)
+            ([self_ty].into_iter().chain(sig.inputs().to_vec()).collect(), sig.output(), None, None)
         }
         dk => bug!("{:?} is not a body: {:?}", def_id, dk),
     };
@@ -705,7 +709,10 @@ fn construct_error(tcx: TyCtxt<'_>, def_id: LocalDefId, guar: ErrorGuaranteed) -
         Some(guar),
     );
 
-    body.coroutine.as_mut().map(|gen| gen.yield_ty = yield_ty);
+    body.coroutine.as_mut().map(|gen| {
+        gen.yield_ty = yield_ty;
+        gen.resume_ty = resume_ty;
+    });
 
     body
 }
diff --git a/compiler/rustc_mir_transform/src/coroutine.rs b/compiler/rustc_mir_transform/src/coroutine.rs
index ce1a36cf67021..33e305497b505 100644
--- a/compiler/rustc_mir_transform/src/coroutine.rs
+++ b/compiler/rustc_mir_transform/src/coroutine.rs
@@ -1733,6 +1733,7 @@ impl<'tcx> MirPass<'tcx> for StateTransform {
         }
 
         body.coroutine.as_mut().unwrap().yield_ty = None;
+        body.coroutine.as_mut().unwrap().resume_ty = None;
         body.coroutine.as_mut().unwrap().coroutine_layout = Some(layout);
 
         // Insert `drop(coroutine_struct)` which is used to drop upvars for coroutines in
diff --git a/tests/ui/coroutine/check-resume-ty-lifetimes-2.rs b/tests/ui/coroutine/check-resume-ty-lifetimes-2.rs
new file mode 100644
index 0000000000000..a316c50e86732
--- /dev/null
+++ b/tests/ui/coroutine/check-resume-ty-lifetimes-2.rs
@@ -0,0 +1,35 @@
+#![feature(coroutine_trait)]
+#![feature(coroutines)]
+
+use std::ops::Coroutine;
+
+struct Contravariant<'a>(fn(&'a ()));
+struct Covariant<'a>(fn() -> &'a ());
+
+fn bad1<'short, 'long: 'short>() -> impl Coroutine<Covariant<'short>> {
+    |_: Covariant<'short>| {
+        let a: Covariant<'long> = yield ();
+        //~^ ERROR lifetime may not live long enough
+    }
+}
+
+fn bad2<'short, 'long: 'short>() -> impl Coroutine<Contravariant<'long>> {
+    |_: Contravariant<'long>| {
+        let a: Contravariant<'short> = yield ();
+        //~^ ERROR lifetime may not live long enough
+    }
+}
+
+fn good1<'short, 'long: 'short>() -> impl Coroutine<Covariant<'long>> {
+    |_: Covariant<'long>| {
+        let a: Covariant<'short> = yield ();
+    }
+}
+
+fn good2<'short, 'long: 'short>() -> impl Coroutine<Contravariant<'short>> {
+    |_: Contravariant<'short>| {
+        let a: Contravariant<'long> = yield ();
+    }
+}
+
+fn main() {}
diff --git a/tests/ui/coroutine/check-resume-ty-lifetimes-2.stderr b/tests/ui/coroutine/check-resume-ty-lifetimes-2.stderr
new file mode 100644
index 0000000000000..e0cbca2dd5267
--- /dev/null
+++ b/tests/ui/coroutine/check-resume-ty-lifetimes-2.stderr
@@ -0,0 +1,36 @@
+error: lifetime may not live long enough
+  --> $DIR/check-resume-ty-lifetimes-2.rs:11:16
+   |
+LL | fn bad1<'short, 'long: 'short>() -> impl Coroutine<Covariant<'short>> {
+   |         ------  ----- lifetime `'long` defined here
+   |         |
+   |         lifetime `'short` defined here
+LL |     |_: Covariant<'short>| {
+LL |         let a: Covariant<'long> = yield ();
+   |                ^^^^^^^^^^^^^^^^ type annotation requires that `'short` must outlive `'long`
+   |
+   = help: consider adding the following bound: `'short: 'long`
+help: consider adding 'move' keyword before the nested closure
+   |
+LL |     move |_: Covariant<'short>| {
+   |     ++++
+
+error: lifetime may not live long enough
+  --> $DIR/check-resume-ty-lifetimes-2.rs:18:40
+   |
+LL | fn bad2<'short, 'long: 'short>() -> impl Coroutine<Contravariant<'long>> {
+   |         ------  ----- lifetime `'long` defined here
+   |         |
+   |         lifetime `'short` defined here
+LL |     |_: Contravariant<'long>| {
+LL |         let a: Contravariant<'short> = yield ();
+   |                                        ^^^^^^^^ yielding this value requires that `'short` must outlive `'long`
+   |
+   = help: consider adding the following bound: `'short: 'long`
+help: consider adding 'move' keyword before the nested closure
+   |
+LL |     move |_: Contravariant<'long>| {
+   |     ++++
+
+error: aborting due to 2 previous errors
+
diff --git a/tests/ui/coroutine/check-resume-ty-lifetimes.rs b/tests/ui/coroutine/check-resume-ty-lifetimes.rs
new file mode 100644
index 0000000000000..add0b5080a8a8
--- /dev/null
+++ b/tests/ui/coroutine/check-resume-ty-lifetimes.rs
@@ -0,0 +1,27 @@
+#![feature(coroutine_trait)]
+#![feature(coroutines)]
+#![allow(unused)]
+
+use std::ops::Coroutine;
+use std::ops::CoroutineState;
+use std::pin::pin;
+
+fn mk_static(s: &str) -> &'static str {
+    let mut storage: Option<&'static str> = None;
+
+    let mut coroutine = pin!(|_: &str| {
+        let x: &'static str = yield ();
+        //~^ ERROR lifetime may not live long enough
+        storage = Some(x);
+    });
+
+    coroutine.as_mut().resume(s);
+    coroutine.as_mut().resume(s);
+
+    storage.unwrap()
+}
+
+fn main() {
+    let s = mk_static(&String::from("hello, world"));
+    println!("{s}");
+}
diff --git a/tests/ui/coroutine/check-resume-ty-lifetimes.stderr b/tests/ui/coroutine/check-resume-ty-lifetimes.stderr
new file mode 100644
index 0000000000000..f373aa778a82c
--- /dev/null
+++ b/tests/ui/coroutine/check-resume-ty-lifetimes.stderr
@@ -0,0 +1,11 @@
+error: lifetime may not live long enough
+  --> $DIR/check-resume-ty-lifetimes.rs:13:16
+   |
+LL | fn mk_static(s: &str) -> &'static str {
+   |                 - let's call the lifetime of this reference `'1`
+...
+LL |         let x: &'static str = yield ();
+   |                ^^^^^^^^^^^^ type annotation requires that `'1` must outlive `'static`
+
+error: aborting due to 1 previous error
+