diff --git a/compiler/rustc_builtin_macros/src/deriving/mod.rs b/compiler/rustc_builtin_macros/src/deriving/mod.rs index e3a93ae13e45f..32936ac183df9 100644 --- a/compiler/rustc_builtin_macros/src/deriving/mod.rs +++ b/compiler/rustc_builtin_macros/src/deriving/mod.rs @@ -27,6 +27,7 @@ pub(crate) mod decodable; pub(crate) mod default; pub(crate) mod encodable; pub(crate) mod hash; +pub(crate) mod smart_ptr; #[path = "cmp/eq.rs"] pub(crate) mod eq; diff --git a/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs b/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs new file mode 100644 index 0000000000000..f976602f7c093 --- /dev/null +++ b/compiler/rustc_builtin_macros/src/deriving/smart_ptr.rs @@ -0,0 +1,135 @@ +use ast::HasAttrs; +use rustc_ast::{ + self as ast, GenericArg, GenericBound, GenericParamKind, ItemKind, MetaItem, + TraitBoundModifiers, +}; +use rustc_expand::base::{Annotatable, ExtCtxt}; +use rustc_span::symbol::{sym, Ident}; +use rustc_span::Span; +use smallvec::{smallvec, SmallVec}; +use thin_vec::{thin_vec, ThinVec}; + +macro_rules! path { + ($span:expr, $($part:ident)::*) => { vec![$(Ident::new(sym::$part, $span),)*] } +} + +pub fn expand_deriving_smart_ptr( + cx: &ExtCtxt<'_>, + span: Span, + _mitem: &MetaItem, + item: &Annotatable, + push: &mut dyn FnMut(Annotatable), + _is_const: bool, +) { + let (name_ident, generics) = if let Annotatable::Item(aitem) = item + && let ItemKind::Struct(_, g) = &aitem.kind + { + (aitem.ident, g) + } else { + cx.dcx().struct_span_err(span, "`SmartPointer` can only be derived on `struct`s").emit(); + return; + }; + + // Convert generic parameters (from the struct) into generic args. + let mut pointee_param = None; + let mut multiple_pointee_diag: SmallVec<[_; 2]> = smallvec![]; + let self_params = generics + .params + .iter() + .enumerate() + .map(|(idx, p)| match p.kind { + GenericParamKind::Lifetime => GenericArg::Lifetime(cx.lifetime(p.span(), p.ident)), + GenericParamKind::Type { .. } => { + if p.attrs().iter().any(|attr| attr.has_name(sym::pointee)) { + if pointee_param.is_some() { + multiple_pointee_diag.push(cx.dcx().struct_span_err( + p.span(), + "`SmartPointer` can only admit one type as pointee", + )); + } else { + pointee_param = Some(idx); + } + } + GenericArg::Type(cx.ty_ident(p.span(), p.ident)) + } + GenericParamKind::Const { .. } => GenericArg::Const(cx.const_ident(p.span(), p.ident)), + }) + .collect::>(); + let Some(pointee_param_idx) = pointee_param else { + cx.dcx().struct_span_err( + span, + "At least one generic type should be designated as `#[pointee]` in order to derive `SmartPointer` traits", + ).emit(); + return; + }; + if !multiple_pointee_diag.is_empty() { + for diag in multiple_pointee_diag { + diag.emit(); + } + return; + } + + // Create the type of `self`. + let path = cx.path_all(span, false, vec![name_ident], self_params.clone()); + let self_type = cx.ty_path(path); + + // Declare helper function that adds implementation blocks. + // FIXME: Copy attrs from struct? + let attrs = thin_vec![cx.attr_word(sym::automatically_derived, span),]; + let mut add_impl_block = |generics, trait_symbol, trait_args| { + let mut parts = path!(span, core::ops); + parts.push(Ident::new(trait_symbol, span)); + let trait_path = cx.path_all(span, true, parts, trait_args); + let trait_ref = cx.trait_ref(trait_path); + let item = cx.item( + span, + Ident::empty(), + attrs.clone(), + ast::ItemKind::Impl(Box::new(ast::Impl { + safety: ast::Safety::Default, + polarity: ast::ImplPolarity::Positive, + defaultness: ast::Defaultness::Final, + constness: ast::Const::No, + generics, + of_trait: Some(trait_ref), + self_ty: self_type.clone(), + items: ThinVec::new(), + })), + ); + push(Annotatable::Item(item)); + }; + + // Create unsized `self`, that is, one where the first type arg is replace with `__S`. For + // example, instead of `MyType<'a, T>`, it will be `MyType<'a, __S>`. + let s_ty = cx.ty_ident(span, Ident::new(sym::__S, span)); + let mut alt_self_params = self_params; + alt_self_params[pointee_param_idx] = GenericArg::Type(s_ty.clone()); + let alt_self_type = cx.ty_path(cx.path_all(span, false, vec![name_ident], alt_self_params)); + + // Find the first type parameter and add an `Unsize<__S>` bound to it. + let mut impl_generics = generics.clone(); + { + let p = &mut impl_generics.params[pointee_param_idx]; + let arg = GenericArg::Type(s_ty.clone()); + let unsize = cx.path_all(span, true, path!(span, core::marker::Unsize), vec![arg]); + p.bounds.push(cx.trait_bound(unsize, false)); + } + + // Add the `__S: ?Sized` extra parameter to the impl block. + let sized = cx.path_global(span, path!(span, core::marker::Sized)); + let bound = GenericBound::Trait( + cx.poly_trait_ref(span, sized), + TraitBoundModifiers { + polarity: ast::BoundPolarity::Maybe(span), + constness: ast::BoundConstness::Never, + asyncness: ast::BoundAsyncness::Normal, + }, + ); + let extra_param = cx.typaram(span, Ident::new(sym::__S, span), vec![bound], None); + impl_generics.params.push(extra_param); + + // Add the impl blocks for `DispatchFromDyn` and `CoerceUnsized`. + let gen_args = vec![GenericArg::Type(alt_self_type.clone())]; + add_impl_block(impl_generics.clone(), sym::DispatchFromDyn, gen_args.clone()); + add_impl_block(impl_generics.clone(), sym::CoerceUnsized, gen_args.clone()); +} diff --git a/compiler/rustc_builtin_macros/src/lib.rs b/compiler/rustc_builtin_macros/src/lib.rs index 744c7f9d09006..aedcc30f1c8b0 100644 --- a/compiler/rustc_builtin_macros/src/lib.rs +++ b/compiler/rustc_builtin_macros/src/lib.rs @@ -125,6 +125,7 @@ pub fn register_builtin_macros(resolver: &mut dyn ResolverExpand) { PartialOrd: partial_ord::expand_deriving_partial_ord, RustcDecodable: decodable::expand_deriving_rustc_decodable, RustcEncodable: encodable::expand_deriving_rustc_encodable, + SmartPointer: smart_ptr::expand_deriving_smart_ptr, } let client = proc_macro::bridge::client::Client::expand1(proc_macro::quote); diff --git a/compiler/rustc_feature/src/builtin_attrs.rs b/compiler/rustc_feature/src/builtin_attrs.rs index 0b4a871dd50c2..e76e1fac87861 100644 --- a/compiler/rustc_feature/src/builtin_attrs.rs +++ b/compiler/rustc_feature/src/builtin_attrs.rs @@ -527,6 +527,12 @@ pub const BUILTIN_ATTRIBUTES: &[BuiltinAttribute] = &[ EncodeCrossCrate::No, coroutines, experimental!(coroutines) ), + // `#[pointee]` attribute to designate the pointee type in SmartPointer derive-macro + gated!( + pointee, Normal, template!(Word), ErrorFollowing, + EncodeCrossCrate::No, derive_smart_pointer, experimental!(derive_smart_pointer) + ), + // ========================================================================== // Internal attributes: Stability, deprecation, and unsafe: // ========================================================================== diff --git a/compiler/rustc_feature/src/unstable.rs b/compiler/rustc_feature/src/unstable.rs index dc4807bab2d3d..0ede9cd8542b9 100644 --- a/compiler/rustc_feature/src/unstable.rs +++ b/compiler/rustc_feature/src/unstable.rs @@ -438,6 +438,8 @@ declare_features! ( (unstable, deprecated_safe, "1.61.0", Some(94978)), /// Allows having using `suggestion` in the `#[deprecated]` attribute. (unstable, deprecated_suggestion, "1.61.0", Some(94785)), + /// Allows deriving `SmartPointer` traits + (unstable, derive_smart_pointer, "1.79.0", Some(123430)), /// Allows deref patterns. (incomplete, deref_patterns, "1.79.0", Some(87121)), /// Controls errors in trait implementations. diff --git a/compiler/rustc_span/src/symbol.rs b/compiler/rustc_span/src/symbol.rs index ace4dff46aa0a..c2e639f014cf9 100644 --- a/compiler/rustc_span/src/symbol.rs +++ b/compiler/rustc_span/src/symbol.rs @@ -172,6 +172,7 @@ symbols! { Center, Cleanup, Clone, + CoerceUnsized, Command, ConstParamTy, Context, @@ -187,6 +188,7 @@ symbols! { DiagMessage, Diagnostic, DirBuilder, + DispatchFromDyn, Display, DoubleEndedIterator, Duration, @@ -298,8 +300,10 @@ symbols! { Saturating, Send, SeqCst, + Sized, SliceIndex, SliceIter, + SmartPointer, Some, SpanCtxt, String, @@ -322,6 +326,7 @@ symbols! { TyCtxt, TyKind, Unknown, + Unsize, Upvars, Vec, VecDeque, @@ -699,6 +704,7 @@ symbols! { derive, derive_const, derive_default_enum, + derive_smart_pointer, destruct, destructuring_assignment, diagnostic, @@ -1302,6 +1308,7 @@ symbols! { on, on_unimplemented, opaque, + ops, opt_out_copy, optimize, optimize_attribute, @@ -1376,6 +1383,7 @@ symbols! { plugin, plugin_registrar, plugins, + pointee, pointee_trait, pointer, pointer_like, diff --git a/library/core/src/lib.rs b/library/core/src/lib.rs index 86078c8377bdb..0770f23648ddb 100644 --- a/library/core/src/lib.rs +++ b/library/core/src/lib.rs @@ -475,3 +475,12 @@ pub mod simd { } include!("primitive_docs.rs"); + +/// Derive macro generating impls of traits related to smart pointers. +#[cfg(not(bootstrap))] +#[rustc_builtin_macro] +#[allow_internal_unstable(dispatch_from_dyn, coerce_unsized, receiver_trait, unsize)] +#[unstable(feature = "derive_smart_pointer", issue = "123430")] +pub macro SmartPointer($item:item) { + /* compiler built-in */ +} diff --git a/library/std/src/lib.rs b/library/std/src/lib.rs index 9d6576fa84117..759c13b8d096f 100644 --- a/library/std/src/lib.rs +++ b/library/std/src/lib.rs @@ -569,6 +569,9 @@ pub use core::u8; #[stable(feature = "rust1", since = "1.0.0")] #[allow(deprecated, deprecated_in_future)] pub use core::usize; +#[cfg(not(bootstrap))] +#[unstable(feature = "derive_smart_pointer", issue = "123430")] +pub use core::SmartPointer; #[unstable(feature = "f128", issue = "116909")] pub mod f128; diff --git a/tests/ui/deriving/deriving-smart-pointer.rs b/tests/ui/deriving/deriving-smart-pointer.rs new file mode 100644 index 0000000000000..7d7591fde5861 --- /dev/null +++ b/tests/ui/deriving/deriving-smart-pointer.rs @@ -0,0 +1,53 @@ +//@ run-pass +#![feature(derive_smart_pointer, arbitrary_self_types)] + +#[derive(SmartPointer)] +struct MyPointer<'a, #[pointee] T: ?Sized> { + ptr: &'a T, +} + +impl Copy for MyPointer<'_, T> {} +impl Clone for MyPointer<'_, T> { + fn clone(&self) -> Self { + Self { ptr: self.ptr } + } +} + +impl<'a, T: ?Sized> core::ops::Deref for MyPointer<'a, T> { + type Target = T; + fn deref(&self) -> &'a T { + self.ptr + } +} + +struct MyValue(u32); +impl MyValue { + fn through_pointer(self: MyPointer<'_, Self>) -> u32 { + self.ptr.0 + } +} + +trait MyTrait { + fn through_trait(&self) -> u32; + fn through_trait_and_pointer(self: MyPointer<'_, Self>) -> u32; +} + +impl MyTrait for MyValue { + fn through_trait(&self) -> u32 { + self.0 + } + + fn through_trait_and_pointer(self: MyPointer<'_, Self>) -> u32 { + self.ptr.0 + } +} + +pub fn main() { + let v = MyValue(10); + let ptr = MyPointer { ptr: &v }; + assert_eq!(v.0, ptr.through_pointer()); + assert_eq!(v.0, ptr.through_pointer()); + let dptr = ptr as MyPointer; + assert_eq!(v.0, dptr.through_trait()); + assert_eq!(v.0, dptr.through_trait_and_pointer()); +}