Skip to content

Commit

Permalink
feat(derive): add #[postgres(allow_mismatch)]
Browse files Browse the repository at this point in the history
  • Loading branch information
viniciusth committed Apr 30, 2023
1 parent 8b9b5d0 commit 3d92ed1
Show file tree
Hide file tree
Showing 8 changed files with 207 additions and 7 deletions.
23 changes: 23 additions & 0 deletions postgres-derive-test/src/compile-fail/invalid-allow-mismatch.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use postgres_types::{FromSql, ToSql};

#[derive(ToSql, Debug)]
#[postgres(allow_mismatch)]
struct ToSqlAllowMismatchStruct {
a: i32,
}

#[derive(FromSql, Debug)]
#[postgres(allow_mismatch)]
struct FromSqlAllowMismatchStruct {
a: i32,
}

#[derive(ToSql, Debug)]
#[postgres(allow_mismatch)]
struct ToSqlAllowMismatchTupleStruct(i32, i32);

#[derive(FromSql, Debug)]
#[postgres(allow_mismatch)]
struct FromSqlAllowMismatchTupleStruct(i32, i32);

fn main() {}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
error: #[postgres(allow_mismatch)] may only be applied to enums
--> src/compile-fail/invalid-allow-mismatch.rs:4:1
|
4 | / #[postgres(allow_mismatch)]
5 | | struct ToSqlAllowMismatchStruct {
6 | | a: i32,
7 | | }
| |_^

error: #[postgres(allow_mismatch)] may only be applied to enums
--> src/compile-fail/invalid-allow-mismatch.rs:10:1
|
10 | / #[postgres(allow_mismatch)]
11 | | struct FromSqlAllowMismatchStruct {
12 | | a: i32,
13 | | }
| |_^

error: #[postgres(allow_mismatch)] may only be applied to enums
--> src/compile-fail/invalid-allow-mismatch.rs:16:1
|
16 | / #[postgres(allow_mismatch)]
17 | | struct ToSqlAllowMismatchTupleStruct(i32, i32);
| |_______________________________________________^

error: #[postgres(allow_mismatch)] may only be applied to enums
--> src/compile-fail/invalid-allow-mismatch.rs:20:1
|
20 | / #[postgres(allow_mismatch)]
21 | | struct FromSqlAllowMismatchTupleStruct(i32, i32);
| |_________________________________________________^
72 changes: 71 additions & 1 deletion postgres-derive-test/src/enums.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::test_type;
use postgres::{Client, NoTls};
use postgres::{error::DbError, Client, NoTls};
use postgres_types::{FromSql, ToSql, WrongType};
use std::error::Error;

Expand Down Expand Up @@ -102,3 +102,73 @@ fn missing_variant() {
let err = conn.execute("SELECT $1::foo", &[&Foo::Bar]).unwrap_err();
assert!(err.source().unwrap().is::<WrongType>());
}

#[test]
fn allow_mismatch_enums() {
#[derive(Debug, ToSql, FromSql, PartialEq)]
#[postgres(allow_mismatch)]
enum Foo {
Bar,
}

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[])
.unwrap();

let row = conn.query_one("SELECT $1::\"Foo\"", &[&Foo::Bar]).unwrap();
assert_eq!(row.get::<_, Foo>(0), Foo::Bar);
}

#[test]
fn missing_enum_variant() {
#[derive(Debug, ToSql, FromSql, PartialEq)]
#[postgres(allow_mismatch)]
enum Foo {
Bar,
Buz,
}

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute("CREATE TYPE pg_temp.\"Foo\" AS ENUM ('Bar', 'Baz')", &[])
.unwrap();

let err = conn
.query_one("SELECT $1::\"Foo\"", &[&Foo::Buz])
.unwrap_err();
assert!(err.source().unwrap().is::<DbError>());
}

#[test]
fn allow_mismatch_and_renaming() {
#[derive(Debug, ToSql, FromSql, PartialEq)]
#[postgres(name = "foo", allow_mismatch)]
enum Foo {
#[postgres(name = "bar")]
Bar,
#[postgres(name = "buz")]
Buz,
}

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('bar', 'baz', 'buz')", &[])
.unwrap();

let row = conn.query_one("SELECT $1::foo", &[&Foo::Buz]).unwrap();
assert_eq!(row.get::<_, Foo>(0), Foo::Buz);
}

#[test]
fn wrong_name_and_allow_mismatch() {
#[derive(Debug, ToSql, FromSql, PartialEq)]
#[postgres(allow_mismatch)]
enum Foo {
Bar,
}

let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
conn.execute("CREATE TYPE pg_temp.foo AS ENUM ('Bar', 'Baz')", &[])
.unwrap();

let err = conn.query_one("SELECT $1::foo", &[&Foo::Bar]).unwrap_err();
assert!(err.source().unwrap().is::<WrongType>());
}
6 changes: 5 additions & 1 deletion postgres-derive/src/accepts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub fn domain_body(name: &str, field: &syn::Field) -> TokenStream {
}
}

pub fn enum_body(name: &str, variants: &[Variant]) -> TokenStream {
pub fn enum_body(name: &str, variants: &[Variant], allow_mismatch: bool) -> TokenStream {
let num_variants = variants.len();
let variant_names = variants.iter().map(|v| &v.name);

Expand All @@ -40,6 +40,10 @@ pub fn enum_body(name: &str, variants: &[Variant]) -> TokenStream {
return false;
}

if #allow_mismatch {
return true;
}

match *type_.kind() {
::postgres_types::Kind::Enum(ref variants) => {
if variants.len() != #num_variants {
Expand Down
22 changes: 21 additions & 1 deletion postgres-derive/src/fromsql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,26 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
))
}
}
} else if overrides.allow_mismatch {
match input.data {
Data::Enum(ref data) => {
let variants = data
.variants
.iter()
.map(Variant::parse)
.collect::<Result<Vec<_>, _>>()?;
(
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
enum_body(&input.ident, &variants),
)
}
_ => {
return Err(Error::new_spanned(
input,
"#[postgres(allow_mismatch)] may only be applied to enums",
));
}
}
} else {
match input.data {
Data::Enum(ref data) => {
Expand All @@ -54,7 +74,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
.map(Variant::parse)
.collect::<Result<Vec<_>, _>>()?;
(
accepts::enum_body(&name, &variants),
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
enum_body(&input.ident, &variants),
)
}
Expand Down
10 changes: 7 additions & 3 deletions postgres-derive/src/overrides.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@ use syn::{Attribute, Error, Expr, ExprLit, Lit, Meta, Token};
pub struct Overrides {
pub name: Option<String>,
pub transparent: bool,
pub allow_mismatch: bool,
}

impl Overrides {
pub fn extract(attrs: &[Attribute]) -> Result<Overrides, Error> {
let mut overrides = Overrides {
name: None,
transparent: false,
allow_mismatch: false,
};

for attr in attrs {
Expand Down Expand Up @@ -44,11 +46,13 @@ impl Overrides {
overrides.name = Some(value);
}
Meta::Path(path) => {
if !path.is_ident("transparent") {
if path.is_ident("transparent") {
overrides.transparent = true;
} else if path.is_ident("allow_mismatch") {
overrides.allow_mismatch = true;
} else {
return Err(Error::new_spanned(path, "unknown override"));
}

overrides.transparent = true;
}
bad => return Err(Error::new_spanned(bad, "unknown attribute")),
}
Expand Down
22 changes: 21 additions & 1 deletion postgres-derive/src/tosql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,26 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
));
}
}
} else if overrides.allow_mismatch {
match input.data {
Data::Enum(ref data) => {
let variants = data
.variants
.iter()
.map(Variant::parse)
.collect::<Result<Vec<_>, _>>()?;
(
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
enum_body(&input.ident, &variants),
)
}
_ => {
return Err(Error::new_spanned(
input,
"#[postgres(allow_mismatch)] may only be applied to enums",
));
}
}
} else {
match input.data {
Data::Enum(ref data) => {
Expand All @@ -50,7 +70,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
.map(Variant::parse)
.collect::<Result<Vec<_>, _>>()?;
(
accepts::enum_body(&name, &variants),
accepts::enum_body(&name, &variants, overrides.allow_mismatch),
enum_body(&input.ident, &variants),
)
}
Expand Down
28 changes: 28 additions & 0 deletions postgres-types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,34 @@
//! Happy,
//! }
//! ```
//!
//! ## Allowing Enum Mismatches
//!
//! By default the generated implementation of [`ToSql`] & [`FromSql`] for enums will require an exact match of the enum
//! variants between the Rust and Postgres types.
//! To allow mismatches, the `#[postgres(allow_mismatch)]` attribute can be used on the enum definition:
//!
//! ```sql
//! CREATE TYPE mood AS ENUM (
//! 'Sad',
//! 'Ok',
//! 'Happy'
//! );
//! ```
//!
//! ```rust
//! # #[cfg(feature = "derive")]
//! use postgres_types::{ToSql, FromSql};
//!
//! # #[cfg(feature = "derive")]
//! #[derive(Debug, ToSql, FromSql)]
//! #[postgres(allow_mismatch)]
//! enum Mood {
//! Sad,
//! Happy,
//! Meh,
//! }
//! ```
#![doc(html_root_url = "https://docs.rs/postgres-types/0.2")]
#![warn(clippy::all, rust_2018_idioms, missing_docs)]

Expand Down

0 comments on commit 3d92ed1

Please sign in to comment.