Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add '#[ts(bound)]' attribute #269

Merged
merged 10 commits into from
Mar 18, 2024
14 changes: 12 additions & 2 deletions macros/src/attr/enum.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use std::collections::HashMap;

use syn::{Attribute, Ident, Result, Type};
use syn::{Attribute, Ident, Result, Type, WherePredicate};

use crate::{
attr::{parse_assign_inflection, parse_assign_str, parse_concrete, Inflection},
utils::{parse_attrs, parse_docs},
};

use super::parse_bound;

#[derive(Default)]
pub struct EnumAttr {
pub rename_all: Option<Inflection>,
Expand All @@ -16,6 +18,7 @@ pub struct EnumAttr {
pub export: bool,
pub docs: String,
pub concrete: HashMap<Ident, Type>,
pub bound: Option<Vec<WherePredicate>>,
tag: Option<String>,
untagged: bool,
content: Option<String>,
Expand Down Expand Up @@ -71,6 +74,7 @@ impl EnumAttr {
export,
docs,
concrete,
bound,
}: EnumAttr,
) {
self.rename = self.rename.take().or(rename);
Expand All @@ -83,6 +87,10 @@ impl EnumAttr {
self.export_to = self.export_to.take().or(export_to);
self.docs = docs;
self.concrete.extend(concrete);
self.bound = self.bound
.take()
.map(|b| b.into_iter().chain(bound.clone().unwrap_or_default()).collect())
.or(bound);
}
}

Expand All @@ -97,6 +105,7 @@ impl_parse! {
"content" => out.content = Some(parse_assign_str(input)?),
"untagged" => out.untagged = true,
"concrete" => out.concrete = parse_concrete(input)?,
"bound" => out.bound = Some(parse_bound(input)?),
}
}

Expand All @@ -108,6 +117,7 @@ impl_parse! {
"rename_all_fields" => out.0.rename_all_fields = Some(parse_assign_inflection(input)?),
"tag" => out.0.tag = Some(parse_assign_str(input)?),
"content" => out.0.content = Some(parse_assign_str(input)?),
"untagged" => out.0.untagged = true
"untagged" => out.0.untagged = true,
"bound" => out.0.bound = Some(parse_bound(input)?),
}
}
16 changes: 14 additions & 2 deletions macros/src/attr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ pub use r#enum::*;
pub use r#struct::*;
use syn::{
parse::{Parse, ParseStream},
Error, Lit, Result, Token,
Error, Lit, Result, Token, WherePredicate, punctuated::Punctuated,
};
pub use variant::*;

Expand Down Expand Up @@ -107,7 +107,7 @@ fn parse_concrete(input: ParseStream) -> Result<HashMap<syn::Ident, syn::Type>>
syn::parenthesized!(content in input);

Ok(
syn::punctuated::Punctuated::<Concrete, Token![,]>::parse_terminated(&content)?
Punctuated::<Concrete, Token![,]>::parse_terminated(&content)?
.into_iter()
.map(|concrete| (concrete.ident, concrete.ty))
.collect(),
Expand All @@ -128,3 +128,15 @@ where
other => Err(Error::new(other.span(), "expected string")),
}
}

fn parse_bound(input: ParseStream) -> Result<Vec<WherePredicate>> {
input.parse::<Token![=]>()?;
match Lit::parse(input)? {
Lit::Str(string) => {
let parser = Punctuated::<WherePredicate, Token![,]>::parse_terminated;

Ok(string.parse_with(parser)?.into_iter().collect())
},
other => Err(Error::new(other.span(), "expected string")),
}
}
12 changes: 11 additions & 1 deletion macros/src/attr/struct.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
use std::{collections::HashMap, convert::TryFrom};

use syn::{Attribute, Ident, Result, Type};
use syn::{Attribute, Ident, Result, Type, WherePredicate};

use crate::{
attr::{parse_assign_str, parse_concrete, Inflection, VariantAttr},
utils::{parse_attrs, parse_docs},
};

use super::parse_bound;

#[derive(Default, Clone)]
pub struct StructAttr {
pub rename_all: Option<Inflection>,
Expand All @@ -16,6 +18,7 @@ pub struct StructAttr {
pub tag: Option<String>,
pub docs: String,
pub concrete: HashMap<Ident, Type>,
pub bound: Option<Vec<WherePredicate>>,
}

#[cfg(feature = "serde-compat")]
Expand Down Expand Up @@ -45,6 +48,7 @@ impl StructAttr {
tag,
docs,
concrete,
bound,
}: StructAttr,
) {
self.rename = self.rename.take().or(rename);
Expand All @@ -54,6 +58,10 @@ impl StructAttr {
self.tag = self.tag.take().or(tag);
self.docs = docs;
self.concrete.extend(concrete);
self.bound = self.bound
.take()
.map(|b| b.into_iter().chain(bound.clone().unwrap_or_default()).collect())
.or(bound);
}
}

Expand All @@ -80,6 +88,7 @@ impl_parse! {
"export" => out.export = true,
"export_to" => out.export_to = Some(parse_assign_str(input)?),
"concrete" => out.concrete = parse_concrete(input)?,
"bound" => out.bound = Some(parse_bound(input)?),
}
}

Expand All @@ -89,6 +98,7 @@ impl_parse! {
"rename" => out.0.rename = Some(parse_assign_str(input)?),
"rename_all" => out.0.rename_all = Some(parse_assign_str(input).and_then(Inflection::try_from)?),
"tag" => out.0.tag = Some(parse_assign_str(input)?),
"bound" => out.0.bound = Some(parse_bound(input)?),
// parse #[serde(default)] to not emit a warning
"deny_unknown_fields" | "default" => {
use syn::Token;
Expand Down
24 changes: 19 additions & 5 deletions macros/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use quote::{format_ident, quote};
use syn::{
parse_quote, spanned::Spanned, ConstParam, GenericParam, Generics, Item, LifetimeParam, Result,
Type, TypeArray, TypeParam, TypeParen, TypePath, TypeReference, TypeSlice, TypeTuple,
WhereClause,
WhereClause, WherePredicate,
};

use crate::{deps::Dependencies, utils::format_generics};
Expand All @@ -26,6 +26,7 @@ struct DerivedTS {
inline_flattened: Option<TokenStream>,
dependencies: Dependencies,
concrete: HashMap<Ident, Type>,
bound: Option<Vec<WherePredicate>>,

export: bool,
export_to: Option<String>,
Expand Down Expand Up @@ -59,7 +60,12 @@ impl DerivedTS {
};

let ident = self.ts_name.clone();
let impl_start = generate_impl_block_header(&rust_ty, &generics, &self.dependencies);
let impl_start = generate_impl_block_header(
&rust_ty,
&generics,
self.bound.as_deref(),
&self.dependencies,
);
let assoc_type = generate_assoc_type(&rust_ty, &generics, &self.concrete);
let name = self.generate_name_fn(&generics);
let inline = self.generate_inline_fn();
Expand Down Expand Up @@ -295,11 +301,12 @@ fn generate_assoc_type(
fn generate_impl_block_header(
ty: &Ident,
generics: &Generics,
bounds: Option<&[WherePredicate]>,
dependencies: &Dependencies,
) -> TokenStream {
use GenericParam as G;

let bounds = generics.params.iter().map(|param| match param {
let params = generics.params.iter().map(|param| match param {
G::Type(TypeParam {
ident,
colon_token,
Expand All @@ -325,8 +332,15 @@ fn generate_impl_block_header(
G::Lifetime(LifetimeParam { lifetime, .. }) => quote!(#lifetime),
});

let where_bound = generate_where_clause(generics, dependencies);
quote!(impl <#(#bounds),*> ::ts_rs::TS for #ty <#(#type_args),*> #where_bound)
let where_bound = match bounds {
Some(bounds) => quote! { where #(#bounds),* },
None => {
let bounds = generate_where_clause(generics, dependencies);
quote! { #bounds }
}
};

quote!(impl <#(#params),*> ::ts_rs::TS for #ty <#(#type_args),*> #where_bound)
}

fn generate_where_clause(generics: &Generics, dependencies: &Dependencies) -> WhereClause {
Expand Down
3 changes: 3 additions & 0 deletions macros/src/types/enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pub(crate) fn r#enum_def(s: &ItemEnum) -> syn::Result<DerivedTS> {
export: enum_attr.export,
export_to: enum_attr.export_to,
concrete: enum_attr.concrete,
bound: enum_attr.bound,
});
}

Expand All @@ -55,6 +56,7 @@ pub(crate) fn r#enum_def(s: &ItemEnum) -> syn::Result<DerivedTS> {
export_to: enum_attr.export_to,
ts_name: name,
concrete: enum_attr.concrete,
bound: enum_attr.bound,
})
}

Expand Down Expand Up @@ -205,5 +207,6 @@ fn empty_enum(name: impl Into<String>, enum_attr: EnumAttr) -> DerivedTS {
export_to: enum_attr.export_to,
ts_name: name,
concrete: enum_attr.concrete,
bound: enum_attr.bound,
}
}
1 change: 1 addition & 0 deletions macros/src/types/named.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ pub(crate) fn named(attr: &StructAttr, name: &str, fields: &FieldsNamed) -> Resu
export_to: attr.export_to.clone(),
ts_name: name.to_owned(),
concrete: attr.concrete.clone(),
bound: attr.bound.clone(),
})
}

Expand Down
1 change: 1 addition & 0 deletions macros/src/types/newtype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -71,5 +71,6 @@ pub(crate) fn newtype(attr: &StructAttr, name: &str, fields: &FieldsUnnamed) ->
export_to: attr.export_to.clone(),
ts_name: name.to_owned(),
concrete: attr.concrete.clone(),
bound: attr.bound.clone(),
})
}
1 change: 1 addition & 0 deletions macros/src/types/tuple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ pub(crate) fn tuple(attr: &StructAttr, name: &str, fields: &FieldsUnnamed) -> Re
export_to: attr.export_to.clone(),
ts_name: name.to_owned(),
concrete: attr.concrete.clone(),
bound: attr.bound.clone(),
})
}

Expand Down
3 changes: 3 additions & 0 deletions macros/src/types/unit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub(crate) fn empty_object(attr: &StructAttr, name: &str) -> Result<DerivedTS> {
export_to: attr.export_to.clone(),
ts_name: name.to_owned(),
concrete: attr.concrete.clone(),
bound: attr.bound.clone(),
})
}

Expand All @@ -30,6 +31,7 @@ pub(crate) fn empty_array(attr: &StructAttr, name: &str) -> Result<DerivedTS> {
export_to: attr.export_to.clone(),
ts_name: name.to_owned(),
concrete: attr.concrete.clone(),
bound: attr.bound.clone(),
})
}

Expand All @@ -45,6 +47,7 @@ pub(crate) fn null(attr: &StructAttr, name: &str) -> Result<DerivedTS> {
export_to: attr.export_to.clone(),
ts_name: name.to_owned(),
concrete: attr.concrete.clone(),
bound: attr.bound.clone(),
})
}

Expand Down
38 changes: 38 additions & 0 deletions ts-rs/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -223,6 +223,44 @@ pub mod typelist;
/// ```
/// <br/><br/>
///
/// - **`#[ts(bound)]`**
/// Override the bounds generated on the `TS` implementation for this type. This is useful in
/// combination with `#[ts(concrete)]`, when the type's generic parameters aren't directly used
/// in a field or variant.
///
/// Example:
/// ```
/// # use ts_rs::TS;
///
/// trait Container {
/// type Value: TS;
/// }
///
/// struct MyContainer;
///
/// ##[derive(TS)]
/// struct MyValue;
///
/// impl Container for MyContainer {
/// type Value = MyValue;
/// }
///
/// ##[derive(TS)]
/// ##[ts(export, concrete(C = MyContainer))]
/// struct Inner<C: Container> {
/// value: C::Value,
/// }
///
/// ##[derive(TS)]
/// // Without `#[ts(bound)]`, `#[derive(TS)]` would generate an unnecessary
/// // `C: TS` bound
/// ##[ts(export, concrete(C = MyContainer), bound = "C::Value: TS")]
/// struct Outer<C: Container> {
/// inner: Inner<C>,
/// }
/// ```
/// <br/><br/>
///
/// ### struct attributes
/// - **`#[ts(tag = "..")]`**
/// Include the structs name (or value of `#[ts(rename = "..")]`) as a field with the given key.
Expand Down
36 changes: 36 additions & 0 deletions ts-rs/tests/bound.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
#![allow(dead_code)]

use ts_rs::TS;

trait Driver {
type Info;
}

struct TsDriver;

#[derive(TS)]
struct TsInfo;

impl Driver for TsDriver {
type Info = TsInfo;
}

#[derive(TS)]
#[ts(export, export_to = "bound/")]
#[ts(concrete(D = TsDriver))]
struct Inner<D: Driver> {
info: D::Info,
}

#[derive(TS)]
#[ts(export, export_to = "bound/")]
#[ts(concrete(D = TsDriver), bound = "D::Info: TS")]
struct Outer<D: Driver> {
inner: Inner<D>,
}

#[test]
fn test_bound() {
assert_eq!(Outer::<TsDriver>::decl(), "type Outer = { inner: Inner, };");
assert_eq!(Inner::<TsDriver>::decl(), "type Inner = { info: TsInfo, };");
}
Loading