From ec0b4cbdaac792f3b34b012c8f2a68e342b8d098 Mon Sep 17 00:00:00 2001 From: rzvxa <3788964+rzvxa@users.noreply.github.com> Date: Wed, 7 Aug 2024 20:29:06 +0000 Subject: [PATCH] feat(ast_codegen): add `derive_clone_in` generator. (#4731) Follow-on after #4276, related to #4284. --- .../oxc_ast/src/generated/derive_clone_in.rs | 7 + .../src/generators/derive_clone_in.rs | 126 ++++++++++++++++++ tasks/ast_codegen/src/generators/mod.rs | 2 + tasks/ast_codegen/src/main.rs | 3 +- tasks/ast_codegen/src/schema/defs.rs | 15 +++ 5 files changed, 152 insertions(+), 1 deletion(-) create mode 100644 crates/oxc_ast/src/generated/derive_clone_in.rs create mode 100644 tasks/ast_codegen/src/generators/derive_clone_in.rs diff --git a/crates/oxc_ast/src/generated/derive_clone_in.rs b/crates/oxc_ast/src/generated/derive_clone_in.rs new file mode 100644 index 0000000000000..63e7f0afe4337 --- /dev/null +++ b/crates/oxc_ast/src/generated/derive_clone_in.rs @@ -0,0 +1,7 @@ +// Auto-generated code, DO NOT EDIT DIRECTLY! +// To edit this generated file you have to edit `tasks/ast_codegen/src/generators/derive_clone_in.rs` + +use oxc_allocator::{Allocator, CloneIn}; + +use crate::ast::*; + diff --git a/tasks/ast_codegen/src/generators/derive_clone_in.rs b/tasks/ast_codegen/src/generators/derive_clone_in.rs new file mode 100644 index 0000000000000..2ffd680b8cd32 --- /dev/null +++ b/tasks/ast_codegen/src/generators/derive_clone_in.rs @@ -0,0 +1,126 @@ +use itertools::Itertools; +use proc_macro2::TokenStream; +use quote::{format_ident, quote}; +use syn::Ident; + +use crate::{ + output, + schema::{EnumDef, GetIdent, StructDef, TypeDef}, + GeneratorOutput, LateCtx, +}; + +use super::{define_generator, generated_header, Generator}; + +define_generator! { + pub struct DeriveCloneIn; +} + +impl Generator for DeriveCloneIn { + fn name(&self) -> &'static str { + stringify!(DeriveCloneIn) + } + + fn generate(&mut self, ctx: &LateCtx) -> GeneratorOutput { + let impls: Vec = ctx + .schema + .definitions + .iter() + .filter(|def| def.generates_derive("CloneIn")) + .map(|def| match &def { + TypeDef::Enum(it) => derive_enum(it), + TypeDef::Struct(it) => derive_struct(it), + }) + .collect(); + + let header = generated_header!(); + + GeneratorOutput::Stream(( + output(crate::AST_CRATE, "derive_clone_in.rs"), + quote! { + #header + + use oxc_allocator::{Allocator, CloneIn}; + endl!(); + use crate::ast::*; + endl!(); + + #(#impls)* + }, + )) + } +} + +fn derive_enum(def: &EnumDef) -> TokenStream { + let ty_ident = def.ident(); + let (alloc, body) = { + let mut used_alloc = false; + let matches = def + .all_variants() + .map(|var| { + let ident = var.ident(); + if var.is_unit() { + quote!(Self :: #ident => Self :: Cloned :: #ident) + } else { + used_alloc = true; + quote!(Self :: #ident(it) => Self :: Cloned :: #ident(it.clone_in(alloc))) + } + }) + .collect_vec(); + let alloc_ident = if used_alloc { format_ident!("alloc") } else { format_ident!("_") }; + ( + alloc_ident, + quote! { + match self { + #(#matches),* + } + }, + ) + }; + impl_clone_in(&ty_ident, def.has_lifetime, &alloc, &body) +} + +fn derive_struct(def: &StructDef) -> TokenStream { + let ty_ident = def.ident(); + let (alloc, body) = { + let (alloc_ident, body) = if def.fields.is_empty() { + (format_ident!("_"), TokenStream::default()) + } else { + let fields = def.fields.iter().map(|field| { + let ident = field.ident(); + quote!(#ident: self.#ident.clone_in(alloc)) + }); + (format_ident!("alloc"), quote!({ #(#fields),* })) + }; + (alloc_ident, quote!( #ty_ident #body )) + }; + impl_clone_in(&ty_ident, def.has_lifetime, &alloc, &body) +} + +fn impl_clone_in( + ty_ident: &Ident, + has_lifetime: bool, + alloc: &Ident, + body: &TokenStream, +) -> TokenStream { + if has_lifetime { + quote! { + endl!(); + impl <'old_alloc, 'new_alloc> CloneIn<'new_alloc> for #ty_ident<'old_alloc> { + type Cloned = #ty_ident<'new_alloc>; + fn clone_in(&self, #alloc: &'new_alloc Allocator) -> Self::Cloned { + #body + } + } + } + } else { + quote! { + endl!(); + impl <'alloc> CloneIn<'alloc> for #ty_ident { + type Cloned = #ty_ident; + fn clone_in(&self, #alloc: &'alloc Allocator) -> Self::Cloned { + #body + } + } + } + } +} diff --git a/tasks/ast_codegen/src/generators/mod.rs b/tasks/ast_codegen/src/generators/mod.rs index 2bce32dda52a5..862a6f9a369d0 100644 --- a/tasks/ast_codegen/src/generators/mod.rs +++ b/tasks/ast_codegen/src/generators/mod.rs @@ -1,6 +1,7 @@ mod assert_layouts; mod ast_builder; mod ast_kind; +mod derive_clone_in; mod impl_get_span; mod visit; @@ -43,6 +44,7 @@ pub(crate) use insert; pub use assert_layouts::AssertLayouts; pub use ast_builder::AstBuilderGenerator; pub use ast_kind::AstKindGenerator; +pub use derive_clone_in::DeriveCloneIn; pub use impl_get_span::ImplGetSpanGenerator; pub use visit::{VisitGenerator, VisitMutGenerator}; diff --git a/tasks/ast_codegen/src/main.rs b/tasks/ast_codegen/src/main.rs index 1c5cfaf1d3dde..c703263776496 100644 --- a/tasks/ast_codegen/src/main.rs +++ b/tasks/ast_codegen/src/main.rs @@ -16,7 +16,7 @@ mod util; use fmt::{cargo_fmt, pprint}; use generators::{ - AssertLayouts, AstBuilderGenerator, AstKindGenerator, Generator, VisitGenerator, + AssertLayouts, AstBuilderGenerator, AstKindGenerator, DeriveCloneIn, Generator, VisitGenerator, VisitMutGenerator, }; use passes::{CalcLayout, Linker, Pass}; @@ -297,6 +297,7 @@ fn main() -> std::result::Result<(), Box> { .gen(AssertLayouts) .gen(AstKindGenerator) .gen(AstBuilderGenerator) + .gen(DeriveCloneIn) .gen(ImplGetSpanGenerator) .gen(VisitGenerator) .gen(VisitMutGenerator) diff --git a/tasks/ast_codegen/src/schema/defs.rs b/tasks/ast_codegen/src/schema/defs.rs index 46bb2926614a1..9265361880e3b 100644 --- a/tasks/ast_codegen/src/schema/defs.rs +++ b/tasks/ast_codegen/src/schema/defs.rs @@ -79,6 +79,17 @@ impl EnumDef { pub fn all_variants(&self) -> impl Iterator { self.variants.iter().chain(self.inherits.iter().flat_map(|it| it.variants.iter())) } + + /// Are all the variants in this enum unit? + /// Example: + /// ``` + /// enum E { A, B, C, D } + /// + /// ``` + /// + pub fn is_unit(&self) -> bool { + self.all_variants().all(VariantDef::is_unit) + } } #[derive(Debug, Serialize)] @@ -93,6 +104,10 @@ impl VariantDef { pub fn ident(&self) -> syn::Ident { self.name.to_ident() } + + pub fn is_unit(&self) -> bool { + self.fields.is_empty() + } } #[derive(Debug, Serialize)]