diff --git a/sea-orm-macros/src/derives/column.rs b/sea-orm-macros/src/derives/column.rs index 034e966d4..16f9bb225 100644 --- a/sea-orm-macros/src/derives/column.rs +++ b/sea-orm-macros/src/derives/column.rs @@ -1,4 +1,4 @@ -use heck::SnakeCase; +use heck::{MixedCase, SnakeCase}; use proc_macro2::{Ident, TokenStream}; use quote::{quote, quote_spanned}; use syn::{Data, DataEnum, Fields, Variant}; @@ -41,6 +41,39 @@ pub fn impl_default_as_str(ident: &Ident, data: &Data) -> syn::Result syn::Result { + let data_enum = match data { + Data::Enum(data_enum) => data_enum, + _ => { + return Ok(quote_spanned! { + ident.span() => compile_error!("you can only derive DeriveColumn on enums"); + }) + } + }; + + let columns = data_enum.variants.iter().map(|column| { + let column_iden = column.ident.clone(); + let column_str_snake = column_iden.to_string().to_snake_case(); + let column_str_mixed = column_iden.to_string().to_mixed_case(); + quote!( + #column_str_snake | #column_str_mixed => Ok(#ident::#column_iden) + ) + }); + + Ok(quote!( + impl std::str::FromStr for #ident { + type Err = sea_orm::ColumnFromStrErr; + + fn from_str(s: &str) -> Result { + match s { + #(#columns),*, + _ => Err(sea_orm::ColumnFromStrErr(format!("Failed to parse '{}' as `{}`", s, stringify!(#ident)))), + } + } + } + )) +} + pub fn expand_derive_column(ident: &Ident, data: &Data) -> syn::Result { let impl_iden = expand_derive_custom_column(ident, data)?; @@ -57,10 +90,13 @@ pub fn expand_derive_column(ident: &Ident, data: &Data) -> syn::Result syn::Result { let impl_default_as_str = impl_default_as_str(ident, data)?; + let impl_col_from_str = impl_col_from_str(ident, data)?; Ok(quote!( #impl_default_as_str + #impl_col_from_str + impl sea_orm::Iden for #ident { fn unquoted(&self, s: &mut dyn std::fmt::Write) { write!(s, "{}", self.as_str()).unwrap(); diff --git a/src/entity/column.rs b/src/entity/column.rs index 045a85ba6..165460573 100644 --- a/src/entity/column.rs +++ b/src/entity/column.rs @@ -1,3 +1,4 @@ +use std::str::FromStr; use crate::{EntityName, IdenStatic, Iterable}; use sea_query::{DynIden, Expr, SeaRc, SelectStatement, SimpleExpr, Value}; @@ -77,7 +78,7 @@ macro_rules! bind_subquery_func { // LINT: when the operand value does not match column type /// Wrapper of the identically named method in [`sea_query::Expr`] -pub trait ColumnTrait: IdenStatic + Iterable { +pub trait ColumnTrait: IdenStatic + Iterable + FromStr { type EntityName: EntityName; fn def(&self) -> ColumnDef; @@ -348,4 +349,30 @@ mod tests { .join(" ") ); } + + #[test] + fn test_col_from_str() { + use std::str::FromStr; + + assert!(matches!( + fruit::Column::from_str("id"), + Ok(fruit::Column::Id) + )); + assert!(matches!( + fruit::Column::from_str("name"), + Ok(fruit::Column::Name) + )); + assert!(matches!( + fruit::Column::from_str("cake_id"), + Ok(fruit::Column::CakeId) + )); + assert!(matches!( + fruit::Column::from_str("cakeId"), + Ok(fruit::Column::CakeId) + )); + assert!(matches!( + fruit::Column::from_str("does_not_exist"), + Err(crate::ColumnFromStrErr(_)) + )); + } } diff --git a/src/error.rs b/src/error.rs index eff999121..8a695dac3 100644 --- a/src/error.rs +++ b/src/error.rs @@ -16,3 +16,14 @@ impl std::fmt::Display for DbErr { } } } + +#[derive(Debug, Clone)] +pub struct ColumnFromStrErr(pub String); + +impl std::error::Error for ColumnFromStrErr {} + +impl std::fmt::Display for ColumnFromStrErr { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "{}", self.0.as_str()) + } +}