diff --git a/src/connector/mysql.rs b/src/connector/mysql.rs index 8a5d68a11..f9bd75ef1 100644 --- a/src/connector/mysql.rs +++ b/src/connector/mysql.rs @@ -329,7 +329,7 @@ mod tests { use url::Url; lazy_static! { - static ref CONN_STR: String = env::var("TEST_MYSQL").unwrap(); + static ref CONN_STR: String = env::var("TEST_MYSQL").expect("TEST_MYSQL env var"); } #[test] @@ -464,22 +464,48 @@ VALUES (1, 'Joe', 27, 20000.00 ); .await .unwrap(); - let res = conn - .query_raw("INSERT INTO test_null_constraint_violation () VALUES ()", &[]) - .await; + // Error code 1364 + { + let res = conn + .query_raw("INSERT INTO test_null_constraint_violation () VALUES ()", &[]) + .await; + + let err = res.unwrap_err(); + + match err.kind() { + ErrorKind::NullConstraintViolation { constraint } => { + assert_eq!(Some("1364"), err.original_code()); + assert_eq!( + Some("Field \'id1\' doesn\'t have a default value"), + err.original_message() + ); + assert_eq!(&DatabaseConstraint::Fields(vec![String::from("id1")]), constraint) + } + _ => panic!(err), + } + } - let err = res.unwrap_err(); + // Error code 1048 + { + conn.query_raw( + "INSERT INTO test_null_constraint_violation (id1, id2) VALUES (50, 55)", + &[], + ) + .await + .unwrap(); - match err.kind() { - ErrorKind::NullConstraintViolation { constraint } => { - assert_eq!(Some("1364"), err.original_code()); - assert_eq!( - Some("Field \'id1\' doesn\'t have a default value"), - err.original_message() - ); - assert_eq!(&DatabaseConstraint::Fields(vec![String::from("id1")]), constraint) + let err = conn + .query_raw("UPDATE test_null_constraint_violation SET id2 = NULL", &[]) + .await + .unwrap_err(); + + match err.kind() { + ErrorKind::NullConstraintViolation { constraint } => { + assert_eq!(Some("1048"), err.original_code()); + assert_eq!(&DatabaseConstraint::Fields(vec![String::from("id2")]), constraint); + } + _ => panic!("{:?}", err), } - _ => panic!(err), } } } diff --git a/src/connector/mysql/error.rs b/src/connector/mysql/error.rs index bb1295303..93cf769ef 100644 --- a/src/connector/mysql/error.rs +++ b/src/connector/mysql/error.rs @@ -55,7 +55,7 @@ impl From for Error { builder.build() } - my::error::Error::Server(ServerError { ref message, code, .. }) if code == 1364 => { + my::error::Error::Server(ServerError { ref message, code, .. }) if code == 1364 || code == 1048 => { let splitted: Vec<&str> = message.split_whitespace().collect(); let splitted: Vec<&str> = splitted.get(1).map(|s| s.split('\'').collect()).unwrap();