diff --git a/coordinator/src/db/positions.rs b/coordinator/src/db/positions.rs index d19ffd754..b972013cd 100644 --- a/coordinator/src/db/positions.rs +++ b/coordinator/src/db/positions.rs @@ -301,8 +301,8 @@ impl Position { conn: &mut PgConnection, trader_pubkey: String, temporary_contract_id: ContractId, - ) -> Result<()> { - let affected_rows = diesel::update(positions::table) + ) -> QueryResult<()> { + diesel::update(positions::table) .filter(positions::trader_pubkey.eq(trader_pubkey)) .filter( positions::position_state @@ -316,8 +316,6 @@ impl Position { )) .execute(conn)?; - ensure!(affected_rows > 0, "Could not set position to open"); - Ok(()) } diff --git a/coordinator/src/dlc_protocol.rs b/coordinator/src/dlc_protocol.rs index a07207a4b..f2b29922d 100644 --- a/coordinator/src/dlc_protocol.rs +++ b/coordinator/src/dlc_protocol.rs @@ -8,6 +8,7 @@ use diesel::r2d2::Pool; use diesel::result::Error::RollbackTransaction; use diesel::Connection; use diesel::PgConnection; +use diesel::QueryResult; use dlc_manager::ContractId; use dlc_manager::ReferenceId; use ln_dlc_node::node::rust_dlc_manager::DlcChannelId; @@ -222,6 +223,50 @@ impl DlcProtocolExecutor { Ok(()) } + /// Finishes a dlc protocol by the corresponding dlc protocol type handling. + pub fn finish_dlc_protocol( + &self, + protocol_id: ProtocolId, + trader_id: &PublicKey, + contract_id: &ContractId, + channel_id: &DlcChannelId, + ) -> Result<()> { + let mut conn = self.pool.get()?; + conn.transaction(|conn| { + let dlc_protocol = db::dlc_protocols::get_dlc_protocol(conn, protocol_id)?; + + match dlc_protocol.protocol_type { + DlcProtocolType::Open { trade_params } + | DlcProtocolType::Renew { trade_params } => self.finish_trade_dlc_protocol( + conn, + trade_params, + protocol_id, + false, + contract_id, + channel_id, + ), + DlcProtocolType::Settle { trade_params } => self.finish_trade_dlc_protocol( + conn, + trade_params, + protocol_id, + true, + contract_id, + channel_id, + ), + DlcProtocolType::Rollover { .. } => self.finish_rollover_dlc_protocol( + conn, + trader_id, + protocol_id, + contract_id, + channel_id, + ), + _ => Ok(()), + } + })?; + + Ok(()) + } + /// Completes the trade dlc protocol as successful and updates the 10101 meta data /// accordingly in a single database transaction. /// - Set dlc protocol to success @@ -230,142 +275,128 @@ impl DlcProtocolExecutor { /// - If closing: Calculates the pnl and sets the `[PositionState::Closing`] position state to /// `[PositionState::Closed`] /// - Creates and inserts the new trade - pub fn finish_trade_dlc_protocol( + fn finish_trade_dlc_protocol( &self, + conn: &mut PgConnection, + trade_params: TradeParams, protocol_id: ProtocolId, closing: bool, contract_id: &ContractId, channel_id: &DlcChannelId, - ) -> Result<()> { - let mut conn = self.pool.get()?; - - conn.transaction(|conn| { - let trade_params: TradeParams = db::trade_params::get(conn, protocol_id)?; - - db::dlc_protocols::set_dlc_protocol_state_to_success( + ) -> QueryResult<()> { + db::dlc_protocols::set_dlc_protocol_state_to_success( + conn, + protocol_id, + contract_id, + channel_id, + )?; + + // TODO(holzeis): We are still updating the position based on the position state. This + // will change once we only have a single position per user and representing + // the position only as view on multiple trades. + let position = match closing { + false => db::positions::Position::update_proposed_position( conn, - protocol_id, - contract_id, - channel_id, - )?; - - // TODO(holzeis): We are still updating the position based on the position state. This - // will change once we only have a single position per user and representing - // the position only as view on multiple trades. - let position = match closing { - false => db::positions::Position::update_proposed_position( + trade_params.trader.to_string(), + PositionState::Open, + ), + true => { + let position = match db::positions::Position::get_position_by_trader( conn, - trade_params.trader.to_string(), - PositionState::Open, - ), - true => { - let position = match db::positions::Position::get_position_by_trader( - conn, - trade_params.trader, - vec![ - // The price doesn't matter here. - PositionState::Closing { closing_price: 0.0 }, - ], - )? { - Some(position) => position, - None => { - tracing::error!("No position in state Closing found."); - return Err(RollbackTransaction); - } + trade_params.trader, + vec![ + // The price doesn't matter here. + PositionState::Closing { closing_price: 0.0 }, + ], + )? { + Some(position) => position, + None => { + tracing::error!("No position in state Closing found."); + return Err(RollbackTransaction); + } + }; + + tracing::debug!( + ?position, + trader_id = %trade_params.trader, + "Finalize closing position", + ); + + let pnl = { + let (initial_margin_long, initial_margin_short) = match trade_params.direction { + Direction::Long => (position.trader_margin, position.coordinator_margin), + Direction::Short => (position.coordinator_margin, position.trader_margin), }; - tracing::debug!( - ?position, - trader_id = %trade_params.trader, - "Finalize closing position", - ); - - let pnl = { - let (initial_margin_long, initial_margin_short) = - match trade_params.direction { - Direction::Long => { - (position.trader_margin, position.coordinator_margin) - } - Direction::Short => { - (position.coordinator_margin, position.trader_margin) - } - }; - - match calculate_pnl( - Decimal::from_f32(position.average_entry_price) - .expect("to fit into decimal"), - Decimal::from_f32(trade_params.average_price) - .expect("to fit into decimal"), - trade_params.quantity, - trade_params.direction, - initial_margin_long as u64, - initial_margin_short as u64, - ) { - Ok(pnl) => pnl, - Err(e) => { - tracing::error!("Failed to calculate pnl. Error: {e:#}"); - return Err(RollbackTransaction); - } + match calculate_pnl( + Decimal::from_f32(position.average_entry_price) + .expect("to fit into decimal"), + Decimal::from_f32(trade_params.average_price).expect("to fit into decimal"), + trade_params.quantity, + trade_params.direction, + initial_margin_long as u64, + initial_margin_short as u64, + ) { + Ok(pnl) => pnl, + Err(e) => { + tracing::error!("Failed to calculate pnl. Error: {e:#}"); + return Err(RollbackTransaction); } - }; + } + }; - db::positions::Position::set_position_to_closed_with_pnl(conn, position.id, pnl) - } - }?; - - let coordinator_margin = calculate_margin( - Decimal::try_from(trade_params.average_price).expect("to fit into decimal"), - trade_params.quantity, - crate::trade::coordinator_leverage_for_trade(&trade_params.trader) - .map_err(|_| RollbackTransaction)?, - ); - - // TODO(holzeis): Add optional pnl to trade. - // Instead of tracking pnl on the position we want to track pnl on the trade. e.g. Long - // -> Short or Short -> Long. - let new_trade = NewTrade { - position_id: position.id, - contract_symbol: position.contract_symbol, - trader_pubkey: trade_params.trader, - quantity: trade_params.quantity, - trader_leverage: trade_params.leverage, - coordinator_margin: coordinator_margin as i64, - trader_direction: trade_params.direction, - average_price: trade_params.average_price, - dlc_expiry_timestamp: None, - }; - - db::trades::insert(conn, new_trade)?; - - db::trade_params::delete(conn, protocol_id) - })?; + db::positions::Position::set_position_to_closed_with_pnl(conn, position.id, pnl) + } + }?; + + let coordinator_margin = calculate_margin( + Decimal::try_from(trade_params.average_price).expect("to fit into decimal"), + trade_params.quantity, + crate::trade::coordinator_leverage_for_trade(&trade_params.trader) + .map_err(|_| RollbackTransaction)?, + ); + + // TODO(holzeis): Add optional pnl to trade. + // Instead of tracking pnl on the position we want to track pnl on the trade. e.g. Long + // -> Short or Short -> Long. + let new_trade = NewTrade { + position_id: position.id, + contract_symbol: position.contract_symbol, + trader_pubkey: trade_params.trader, + quantity: trade_params.quantity, + trader_leverage: trade_params.leverage, + coordinator_margin: coordinator_margin as i64, + trader_direction: trade_params.direction, + average_price: trade_params.average_price, + dlc_expiry_timestamp: None, + }; + + db::trades::insert(conn, new_trade)?; + + db::trade_params::delete(conn, protocol_id)?; Ok(()) } /// Completes the rollover dlc protocol as successful and updates the 10101 meta data /// accordingly in a single database transaction. - pub fn finish_rollover_dlc_protocol( + fn finish_rollover_dlc_protocol( &self, + conn: &mut PgConnection, + trader: &PublicKey, protocol_id: ProtocolId, contract_id: &ContractId, channel_id: &DlcChannelId, - trader: &PublicKey, - ) -> Result<()> { + ) -> QueryResult<()> { tracing::debug!(%trader, %protocol_id, "Finalizing rollover"); - let mut conn = self.pool.get()?; - - conn.transaction(|conn| { - db::dlc_protocols::set_dlc_protocol_state_to_success( - conn, - protocol_id, - contract_id, - channel_id, - )?; - - db::positions::Position::set_position_to_open(conn, trader.to_string(), *contract_id) - })?; + db::dlc_protocols::set_dlc_protocol_state_to_success( + conn, + protocol_id, + contract_id, + channel_id, + )?; + db::positions::Position::set_position_to_open(conn, trader.to_string(), *contract_id)?; Ok(()) } } diff --git a/coordinator/src/node.rs b/coordinator/src/node.rs index c9955abec..f6765aa61 100644 --- a/coordinator/src/node.rs +++ b/coordinator/src/node.rs @@ -293,21 +293,12 @@ impl Node { let protocol_executor = dlc_protocol::DlcProtocolExecutor::new(self.pool.clone()); - if self.is_in_rollover(node_id)? { - protocol_executor.finish_rollover_dlc_protocol( - protocol_id, - &contract_id, - &channel.get_id(), - &channel.get_counter_party_id(), - )?; - } else { - protocol_executor.finish_trade_dlc_protocol( - protocol_id, - false, - &contract_id, - &channel.get_id(), - )?; - } + protocol_executor.finish_dlc_protocol( + protocol_id, + &channel.get_counter_party_id(), + &contract_id, + channel_id, + )?; } ChannelMessage::SettleFinalize(SettleFinalize { channel_id, @@ -367,9 +358,9 @@ impl Node { let protocol_executor = dlc_protocol::DlcProtocolExecutor::new(self.pool.clone()); - protocol_executor.finish_trade_dlc_protocol( + protocol_executor.finish_dlc_protocol( protocol_id, - true, + &node_id, &contract_id, channel_id, )?; @@ -427,9 +418,9 @@ impl Node { let protocol_executor = dlc_protocol::DlcProtocolExecutor::new(self.pool.clone()); - protocol_executor.finish_trade_dlc_protocol( + protocol_executor.finish_dlc_protocol( protocol_id, - false, + &node_id, &contract_id, &channel_id, )?;