Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

InsertResult to return the primary key's type #117

Merged
merged 15 commits into from
Sep 3, 2021
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ let pear = fruit::ActiveModel {
};

// insert one
let res: InsertResult = Fruit::insert(pear).exec(db).await?;
let res = Fruit::insert(pear).exec(db).await?;

println!("InsertResult: {}", res.last_insert_id);

Expand Down
2 changes: 2 additions & 0 deletions examples/async-std/src/example_cake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ pub enum PrimaryKey {
}

impl PrimaryKeyTrait for PrimaryKey {
type ValueType = i32;

fn auto_increment() -> bool {
true
}
Expand Down
2 changes: 2 additions & 0 deletions examples/async-std/src/example_cake_filling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ pub enum PrimaryKey {
}

impl PrimaryKeyTrait for PrimaryKey {
type ValueType = i32;

fn auto_increment() -> bool {
false
}
Expand Down
2 changes: 2 additions & 0 deletions examples/async-std/src/example_filling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ pub enum PrimaryKey {
}

impl PrimaryKeyTrait for PrimaryKey {
type ValueType = i32;

fn auto_increment() -> bool {
true
}
Expand Down
2 changes: 2 additions & 0 deletions examples/async-std/src/example_fruit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ pub enum PrimaryKey {
}

impl PrimaryKeyTrait for PrimaryKey {
type ValueType = i32;

fn auto_increment() -> bool {
true
}
Expand Down
2 changes: 1 addition & 1 deletion examples/async-std/src/operation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ pub async fn insert_and_update(db: &DbConn) -> Result<(), DbErr> {
name: Set("pear".to_owned()),
..Default::default()
};
let res: InsertResult = Fruit::insert(pear).exec(db).await?;
let res = Fruit::insert(pear).exec(db).await?;

println!();
println!("Inserted: last_insert_id = {}\n", res.last_insert_id);
Expand Down
2 changes: 2 additions & 0 deletions examples/tokio/src/cake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ pub enum PrimaryKey {
}

impl PrimaryKeyTrait for PrimaryKey {
type ValueType = i32;

fn auto_increment() -> bool {
true
}
Expand Down
24 changes: 23 additions & 1 deletion sea-orm-codegen/src/entity/base_entity.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,18 @@ impl Entity {
format_ident!("{}", auto_increment)
}

pub fn get_primary_key_rs_type(&self) -> TokenStream {
if let Some(primary_key) = self.primary_keys.first() {
self.columns
.iter()
.find(|col| col.name.eq(&primary_key.name))
.unwrap()
.get_rs_type()
} else {
TokenStream::new()
}
}

pub fn get_conjunct_relations_via_snake_case(&self) -> Vec<Ident> {
self.conjunct_relations
.iter()
Expand Down Expand Up @@ -151,7 +163,7 @@ mod tests {
columns: vec![
Column {
name: "id".to_owned(),
col_type: ColumnType::String(None),
col_type: ColumnType::Integer(None),
auto_increment: false,
not_null: false,
unique: false,
Expand Down Expand Up @@ -373,6 +385,16 @@ mod tests {
);
}

#[test]
fn test_get_primary_key_rs_type() {
let entity = setup();

assert_eq!(
entity.get_primary_key_rs_type().to_string(),
entity.columns[0].get_rs_type().to_string()
);
}

#[test]
fn test_get_conjunct_relations_via_snake_case() {
let entity = setup();
Expand Down
3 changes: 3 additions & 0 deletions sea-orm-codegen/src/entity/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,8 +173,11 @@ impl EntityWriter {

pub fn gen_impl_primary_key(entity: &Entity) -> TokenStream {
let primary_key_auto_increment = entity.get_primary_key_auto_increment();
let value_type = entity.get_primary_key_rs_type();
quote! {
impl PrimaryKeyTrait for PrimaryKey {
type ValueType = #value_type;

fn auto_increment() -> bool {
#primary_key_auto_increment
}
Expand Down
2 changes: 2 additions & 0 deletions sea-orm-codegen/tests/entity/cake.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ pub enum PrimaryKey {
}

impl PrimaryKeyTrait for PrimaryKey {
type ValueType = i32;

fn auto_increment() -> bool {
true
}
Expand Down
2 changes: 2 additions & 0 deletions sea-orm-codegen/tests/entity/cake_filling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ pub enum PrimaryKey {
}

impl PrimaryKeyTrait for PrimaryKey {
type ValueType = i32;

fn auto_increment() -> bool {
false
}
Expand Down
2 changes: 2 additions & 0 deletions sea-orm-codegen/tests/entity/filling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ pub enum PrimaryKey {
}

impl PrimaryKeyTrait for PrimaryKey {
type ValueType = i32;

fn auto_increment() -> bool {
true
}
Expand Down
2 changes: 2 additions & 0 deletions sea-orm-codegen/tests/entity/fruit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ pub enum PrimaryKey {
}

impl PrimaryKeyTrait for PrimaryKey {
type ValueType = i32;

fn auto_increment() -> bool {
true
}
Expand Down
2 changes: 2 additions & 0 deletions sea-orm-codegen/tests/entity/vendor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ pub enum PrimaryKey {
}

impl PrimaryKeyTrait for PrimaryKey {
type ValueType = i32;

fn auto_increment() -> bool {
true
}
Expand Down
15 changes: 1 addition & 14 deletions src/driver/sqlx_postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -102,24 +102,11 @@ impl From<PgRow> for QueryResult {
impl From<PgQueryResult> for ExecResult {
fn from(result: PgQueryResult) -> ExecResult {
ExecResult {
result: ExecResultHolder::SqlxPostgres {
last_insert_id: 0,
rows_affected: result.rows_affected(),
},
result: ExecResultHolder::SqlxPostgres(result),
}
}
}

pub(crate) fn query_result_into_exec_result(res: QueryResult) -> Result<ExecResult, DbErr> {
let last_insert_id: i32 = res.try_get("", "last_insert_id")?;
Ok(ExecResult {
result: ExecResultHolder::SqlxPostgres {
last_insert_id: last_insert_id as u64,
rows_affected: 0,
},
})
}

fn sqlx_query(stmt: &Statement) -> sqlx::query::Query<'_, Postgres, PgArguments> {
let mut query = sqlx::query(&stmt.sql);
if let Some(values) = &stmt.values {
Expand Down
8 changes: 6 additions & 2 deletions src/entity/active_model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,11 @@ where
let exec = E::insert(am).exec(db);
let res = exec.await?;
// TODO: if the entity does not have auto increment primary key, then last_insert_id is a wrong value
if <E::PrimaryKey as PrimaryKeyTrait>::auto_increment() && res.last_insert_id != 0 {
// FIXME: Assumed valid last_insert_id is not equals to Default::default()
if <E::PrimaryKey as PrimaryKeyTrait>::auto_increment()
&& res.last_insert_id != <E::PrimaryKey as PrimaryKeyTrait>::ValueType::default()
{
let last_insert_id = res.last_insert_id.to_string();
let find = E::find_by_id(res.last_insert_id).one(db);
let found = find.await;
let model: Option<E::Model> = found?;
Expand All @@ -230,7 +234,7 @@ where
None => Err(DbErr::Exec(format!(
"Failed to find inserted item: {} {}",
E::default().to_string(),
res.last_insert_id
last_insert_id
))),
}
} else {
Expand Down
2 changes: 1 addition & 1 deletion src/entity/column.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::str::FromStr;
use crate::{EntityName, IdenStatic, Iterable};
use sea_query::{DynIden, Expr, SeaRc, SelectStatement, SimpleExpr, Value};
use std::str::FromStr;

#[derive(Debug, Clone)]
pub struct ColumnDef {
Expand Down
15 changes: 15 additions & 0 deletions src/entity/primary_key.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,22 @@
use super::{ColumnTrait, IdenStatic, Iterable};
use crate::TryGetable;
use sea_query::IntoValueTuple;
use std::{
fmt::{Debug, Display},
str::FromStr,
};

//LINT: composite primary key cannot auto increment
pub trait PrimaryKeyTrait: IdenStatic + Iterable {
type ValueType: Sized
+ Default
+ Debug
+ Display
+ PartialEq
+ IntoValueTuple
+ TryGetable
+ FromStr;

fn auto_increment() -> bool;
}

Expand Down
9 changes: 3 additions & 6 deletions src/executor/execute.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,7 @@ pub(crate) enum ExecResultHolder {
#[cfg(feature = "sqlx-mysql")]
SqlxMySql(sqlx::mysql::MySqlQueryResult),
#[cfg(feature = "sqlx-postgres")]
SqlxPostgres {
last_insert_id: u64,
rows_affected: u64,
},
SqlxPostgres(sqlx::postgres::PgQueryResult),
#[cfg(feature = "sqlx-sqlite")]
SqlxSqlite(sqlx::sqlite::SqliteQueryResult),
#[cfg(feature = "mock")]
Expand All @@ -26,7 +23,7 @@ impl ExecResult {
#[cfg(feature = "sqlx-mysql")]
ExecResultHolder::SqlxMySql(result) => result.last_insert_id(),
#[cfg(feature = "sqlx-postgres")]
ExecResultHolder::SqlxPostgres { last_insert_id, .. } => last_insert_id.to_owned(),
ExecResultHolder::SqlxPostgres(_) => panic!("Should not retrieve last_insert_id this way"),
#[cfg(feature = "sqlx-sqlite")]
ExecResultHolder::SqlxSqlite(result) => {
let last_insert_rowid = result.last_insert_rowid();
Expand All @@ -46,7 +43,7 @@ impl ExecResult {
#[cfg(feature = "sqlx-mysql")]
ExecResultHolder::SqlxMySql(result) => result.rows_affected(),
#[cfg(feature = "sqlx-postgres")]
ExecResultHolder::SqlxPostgres { rows_affected, .. } => rows_affected.to_owned(),
ExecResultHolder::SqlxPostgres(result) => result.rows_affected(),
#[cfg(feature = "sqlx-sqlite")]
ExecResultHolder::SqlxSqlite(result) => result.rows_affected(),
#[cfg(feature = "mock")]
Expand Down
75 changes: 54 additions & 21 deletions src/executor/insert.rs
Original file line number Diff line number Diff line change
@@ -1,31 +1,43 @@
use crate::{error::*, ActiveModelTrait, DatabaseConnection, Insert, Statement};
use crate::{
error::*, ActiveModelTrait, DatabaseConnection, EntityTrait, Insert, PrimaryKeyTrait, Statement,
};
use sea_query::InsertStatement;
use std::future::Future;
use std::{future::Future, marker::PhantomData};

#[derive(Clone, Debug)]
pub struct Inserter {
pub struct Inserter<A>
where
A: ActiveModelTrait,
{
query: InsertStatement,
model: PhantomData<A>,
}

#[derive(Clone, Debug)]
pub struct InsertResult {
pub last_insert_id: u64,
#[derive(Debug)]
pub struct InsertResult<A>
where
A: ActiveModelTrait,
{
pub last_insert_id: <<<A as ActiveModelTrait>::Entity as EntityTrait>::PrimaryKey as PrimaryKeyTrait>::ValueType,
}

impl<A> Insert<A>
where
A: ActiveModelTrait,
{
#[allow(unused_mut)]
pub fn exec(
pub fn exec<'a>(
self,
db: &DatabaseConnection,
) -> impl Future<Output = Result<InsertResult, DbErr>> + '_ {
db: &'a DatabaseConnection,
) -> impl Future<Output = Result<InsertResult<A>, DbErr>> + 'a
where
A: 'a,
{
// so that self is dropped before entering await
let mut query = self.query;
#[cfg(feature = "sqlx-postgres")]
if let DatabaseConnection::SqlxPostgresPoolConnection(_) = db {
use crate::{EntityTrait, Iterable};
use crate::Iterable;
use sea_query::{Alias, Expr, Query};
for key in <A::Entity as EntityTrait>::PrimaryKey::iter() {
query.returning(
Expand All @@ -35,36 +47,57 @@ where
);
}
}
Inserter::new(query).exec(db)
Inserter::<A>::new(query).exec(db)
}
}

impl Inserter {
impl<A> Inserter<A>
where
A: ActiveModelTrait,
{
pub fn new(query: InsertStatement) -> Self {
Self { query }
Self {
query,
model: PhantomData,
}
}

pub fn exec(
pub fn exec<'a>(
self,
db: &DatabaseConnection,
) -> impl Future<Output = Result<InsertResult, DbErr>> + '_ {
db: &'a DatabaseConnection,
) -> impl Future<Output = Result<InsertResult<A>, DbErr>> + 'a
where
A: 'a,
{
let builder = db.get_database_backend();
exec_insert(builder.build(&self.query), db)
}
}

// Only Statement impl Send
async fn exec_insert(statement: Statement, db: &DatabaseConnection) -> Result<InsertResult, DbErr> {
async fn exec_insert<A>(
statement: Statement,
db: &DatabaseConnection,
) -> Result<InsertResult<A>, DbErr>
where
A: ActiveModelTrait,
{
// TODO: Postgres instead use query_one + returning clause
let result = match db {
let last_insert_id = match db {
#[cfg(feature = "sqlx-postgres")]
DatabaseConnection::SqlxPostgresPoolConnection(conn) => {
let res = conn.query_one(statement).await?.unwrap();
crate::query_result_into_exec_result(res)?
res.try_get("", "last_insert_id").unwrap_or_default()
}
_ => db.execute(statement).await?,
_ => {
db.execute(statement).await?
.last_insert_id()
.to_string()
.parse()
billy1624 marked this conversation as resolved.
Show resolved Hide resolved
.unwrap_or_default()
},
};
Ok(InsertResult {
last_insert_id: result.last_insert_id(),
last_insert_id,
})
}
Loading