Skip to content

Commit

Permalink
chore: Add strong type for protocol id
Browse files Browse the repository at this point in the history
  • Loading branch information
holzeis committed Feb 27, 2024
1 parent c46b28e commit 37c7d83
Show file tree
Hide file tree
Showing 11 changed files with 205 additions and 173 deletions.
47 changes: 12 additions & 35 deletions coordinator/src/db/dlc_protocols.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use crate::dlc_protocol;
use crate::dlc_protocol::ProtocolId;
use crate::schema::dlc_protocols;
use crate::schema::sql_types::ProtocolStateType;
use bitcoin::hashes::hex::FromHex;
use bitcoin::hashes::hex::ToHex;
use bitcoin::secp256k1::PublicKey;
use diesel::query_builder::QueryId;
use diesel::result::Error::RollbackTransaction;
use diesel::AsExpression;
use diesel::ExpressionMethods;
use diesel::FromSqlRow;
Expand All @@ -16,8 +16,6 @@ use diesel::Queryable;
use diesel::RunQueryDsl;
use dlc_manager::ContractId;
use dlc_manager::DlcChannelId;
use dlc_manager::ReferenceId;
use ln_dlc_node::util;
use std::any::TypeId;
use std::str::FromStr;
use time::OffsetDateTime;
Expand Down Expand Up @@ -56,27 +54,21 @@ pub(crate) struct DlcProtocol {

pub(crate) fn get_dlc_protocol(
conn: &mut PgConnection,
protocol_id: ReferenceId,
protocol_id: ProtocolId,
) -> QueryResult<dlc_protocol::DlcProtocol> {
let protocol_id =
util::parse_from_reference_id(protocol_id).map_err(|_| RollbackTransaction)?;

let contract_transaction: DlcProtocol = dlc_protocols::table
.filter(dlc_protocols::protocol_id.eq(protocol_id))
.filter(dlc_protocols::protocol_id.eq(protocol_id.to_uuid()))
.first(conn)?;

Ok(dlc_protocol::DlcProtocol::from(contract_transaction))
}

pub(crate) fn set_dlc_protocol_state_to_failed(
conn: &mut PgConnection,
protocol_id: ReferenceId,
protocol_id: ProtocolId,
) -> QueryResult<()> {
let protocol_id =
util::parse_from_reference_id(protocol_id).map_err(|_| RollbackTransaction)?;

let affected_rows = diesel::update(dlc_protocols::table)
.filter(dlc_protocols::protocol_id.eq(protocol_id))
.filter(dlc_protocols::protocol_id.eq(protocol_id.to_uuid()))
.set((dlc_protocols::protocol_state.eq(DlcProtocolState::Failed),))
.execute(conn)?;

Expand All @@ -89,15 +81,12 @@ pub(crate) fn set_dlc_protocol_state_to_failed(

pub(crate) fn set_dlc_protocol_state_to_success(
conn: &mut PgConnection,
protocol_id: ReferenceId,
protocol_id: ProtocolId,
contract_id: ContractId,
channel_id: DlcChannelId,
) -> QueryResult<()> {
let protocol_id =
util::parse_from_reference_id(protocol_id).map_err(|_| RollbackTransaction)?;

let affected_rows = diesel::update(dlc_protocols::table)
.filter(dlc_protocols::protocol_id.eq(protocol_id))
.filter(dlc_protocols::protocol_id.eq(protocol_id.to_uuid()))
.set((
dlc_protocols::protocol_state.eq(DlcProtocolState::Success),
dlc_protocols::contract_id.eq(contract_id.to_hex()),
Expand All @@ -114,28 +103,16 @@ pub(crate) fn set_dlc_protocol_state_to_success(

pub(crate) fn create(
conn: &mut PgConnection,
protocol_id: ReferenceId,
previous_protocol_id: Option<ReferenceId>,
protocol_id: ProtocolId,
previous_protocol_id: Option<ProtocolId>,
contract_id: ContractId,
channel_id: DlcChannelId,
trader: &PublicKey,
) -> QueryResult<()> {
let protocol_id =
util::parse_from_reference_id(protocol_id).map_err(|_| RollbackTransaction)?;

let previous_protocol_id = match previous_protocol_id {
Some(previous_protocol_id) => {
let previous_protocol_id = util::parse_from_reference_id(previous_protocol_id)
.map_err(|_| RollbackTransaction)?;
Some(previous_protocol_id)
}
None => None,
};

let affected_rows = diesel::insert_into(dlc_protocols::table)
.values(&(
dlc_protocols::protocol_id.eq(protocol_id),
dlc_protocols::previous_protocol_id.eq(previous_protocol_id),
dlc_protocols::protocol_id.eq(protocol_id.to_uuid()),
dlc_protocols::previous_protocol_id.eq(previous_protocol_id.map(|ppid| ppid.to_uuid())),
dlc_protocols::contract_id.eq(contract_id.to_hex()),
dlc_protocols::channel_id.eq(channel_id.to_hex()),
dlc_protocols::protocol_state.eq(DlcProtocolState::Pending),
Expand All @@ -154,7 +131,7 @@ pub(crate) fn create(
impl From<DlcProtocol> for dlc_protocol::DlcProtocol {
fn from(value: DlcProtocol) -> Self {
dlc_protocol::DlcProtocol {
id: value.protocol_id,
id: value.protocol_id.into(),
timestamp: value.timestamp,
channel_id: DlcChannelId::from_hex(&value.channel_id).expect("valid dlc channel id"),
contract_id: ContractId::from_hex(&value.contract_id).expect("valid contract id"),
Expand Down
24 changes: 8 additions & 16 deletions coordinator/src/db/trade_params.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,14 @@
use crate::dlc_protocol;
use crate::dlc_protocol::ProtocolId;
use crate::orderbook::db::custom_types::Direction;
use crate::schema::trade_params;
use bitcoin::secp256k1::PublicKey;
use diesel::result::Error::RollbackTransaction;
use diesel::ExpressionMethods;
use diesel::PgConnection;
use diesel::QueryDsl;
use diesel::QueryResult;
use diesel::Queryable;
use diesel::RunQueryDsl;
use dlc_manager::ReferenceId;
use ln_dlc_node::util;
use rust_decimal::prelude::ToPrimitive;
use std::str::FromStr;
use uuid::Uuid;
Expand All @@ -30,19 +28,17 @@ pub(crate) struct TradeParams {

pub(crate) fn insert(
conn: &mut PgConnection,
protocol_id: ReferenceId,
protocol_id: ProtocolId,
params: &commons::TradeParams,
) -> QueryResult<()> {
let protocol_id =
util::parse_from_reference_id(protocol_id).map_err(|_| RollbackTransaction)?;
let average_price = params
.average_execution_price()
.to_f32()
.expect("to fit into f32");

let affected_rows = diesel::insert_into(trade_params::table)
.values(&(
trade_params::protocol_id.eq(protocol_id),
trade_params::protocol_id.eq(protocol_id.to_uuid()),
trade_params::quantity.eq(params.quantity),
trade_params::leverage.eq(params.leverage),
trade_params::trader_pubkey.eq(params.pubkey.to_string()),
Expand All @@ -60,29 +56,25 @@ pub(crate) fn insert(

pub(crate) fn get(
conn: &mut PgConnection,
protocol_id: ReferenceId,
protocol_id: ProtocolId,
) -> QueryResult<dlc_protocol::TradeParams> {
let protocol_id =
util::parse_from_reference_id(protocol_id).map_err(|_| RollbackTransaction)?;
let trade_params: TradeParams = trade_params::table
.filter(trade_params::protocol_id.eq(protocol_id))
.filter(trade_params::protocol_id.eq(protocol_id.to_uuid()))
.first(conn)?;

Ok(dlc_protocol::TradeParams::from(trade_params))
}

pub(crate) fn delete(conn: &mut PgConnection, protocol_id: ReferenceId) -> QueryResult<usize> {
let protocol_id =
util::parse_from_reference_id(protocol_id).map_err(|_| RollbackTransaction)?;
pub(crate) fn delete(conn: &mut PgConnection, protocol_id: ProtocolId) -> QueryResult<usize> {
diesel::delete(trade_params::table)
.filter(trade_params::protocol_id.eq(protocol_id))
.filter(trade_params::protocol_id.eq(protocol_id.to_uuid()))
.execute(conn)
}

impl From<TradeParams> for dlc_protocol::TradeParams {
fn from(value: TradeParams) -> Self {
Self {
protocol_id: value.protocol_id,
protocol_id: value.protocol_id.into(),
trader: PublicKey::from_str(&value.trader_pubkey).expect("valid pubkey"),
quantity: value.quantity,
leverage: value.leverage,
Expand Down
109 changes: 101 additions & 8 deletions coordinator/src/dlc_protocol.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,90 @@ use dlc_manager::ReferenceId;
use ln_dlc_node::node::rust_dlc_manager::DlcChannelId;
use rust_decimal::prelude::FromPrimitive;
use rust_decimal::Decimal;
use std::fmt::Display;
use std::fmt::Formatter;
use std::str::from_utf8;
use time::OffsetDateTime;
use trade::cfd::calculate_margin;
use trade::cfd::calculate_pnl;
use trade::Direction;
use uuid::Uuid;

#[derive(Debug, Copy, Clone, PartialEq)]
pub struct ProtocolId(Uuid);

impl ProtocolId {
pub fn new() -> Self {
ProtocolId(Uuid::new_v4())
}

pub fn to_uuid(&self) -> Uuid {
self.0
}
}

impl Default for ProtocolId {
fn default() -> Self {
Self::new()
}
}

impl Display for ProtocolId {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
self.0.to_string().fmt(f)
}
}

impl From<ProtocolId> for ReferenceId {
fn from(value: ProtocolId) -> Self {
let uuid = value.to_uuid();

// 16 bytes.
let uuid_bytes = uuid.as_bytes();

// 32-digit hex string.
let hex = hex::encode(uuid_bytes);

// Derived `ReferenceId`: 32-bytes.
let hex_bytes = hex.as_bytes();

let mut array = [0u8; 32];
array.copy_from_slice(hex_bytes);

array
}
}

impl TryFrom<ReferenceId> for ProtocolId {
type Error = anyhow::Error;

fn try_from(value: ReferenceId) -> Result<Self> {
// 32-digit hex string.
let hex = from_utf8(&value)?;

// 16 bytes.
let uuid_bytes = hex::decode(hex)?;

let uuid = Uuid::from_slice(&uuid_bytes)?;

Ok(ProtocolId(uuid))
}
}

impl From<Uuid> for ProtocolId {
fn from(value: Uuid) -> Self {
ProtocolId(value)
}
}

impl From<ProtocolId> for Uuid {
fn from(value: ProtocolId) -> Self {
value.0
}
}

pub struct DlcProtocol {
pub id: Uuid,
pub id: ProtocolId,
pub timestamp: OffsetDateTime,
pub channel_id: DlcChannelId,
pub contract_id: ContractId,
Expand All @@ -29,7 +105,7 @@ pub struct DlcProtocol {
}

pub struct TradeParams {
pub protocol_id: Uuid,
pub protocol_id: ProtocolId,
pub trader: PublicKey,
pub quantity: f32,
pub leverage: f32,
Expand Down Expand Up @@ -58,8 +134,8 @@ impl DlcProtocolExecutor {
/// Returns a uniquely generated protocol id as [`dlc_manager::ReferenceId`]
pub fn start_dlc_protocol(
&self,
protocol_id: ReferenceId,
previous_protocol_id: Option<ReferenceId>,
protocol_id: ProtocolId,
previous_protocol_id: Option<ProtocolId>,
contract_id: ContractId,
channel_id: DlcChannelId,
trade_params: &commons::TradeParams,
Expand All @@ -82,7 +158,7 @@ impl DlcProtocolExecutor {
Ok(())
}

pub fn fail_dlc_protocol(&self, protocol_id: ReferenceId) -> Result<()> {
pub fn fail_dlc_protocol(&self, protocol_id: ProtocolId) -> Result<()> {
let mut conn = self.pool.get()?;
db::dlc_protocols::set_dlc_protocol_state_to_failed(&mut conn, protocol_id)?;

Expand All @@ -94,12 +170,12 @@ impl DlcProtocolExecutor {
/// - Set dlc protocol to success
/// - If not closing: Updates the `[PostionState::Proposed`] position state to
/// `[PostionState::Open]`
/// - If closing: Calculates the pnl and sets the `[PostionState::Closing`] position state to
/// `[PostionState::Closed`]
/// - If closing: Calculates the pnl and sets the `[PositionState::Closing`] position state to
/// `[PositionState::Closed`]
/// - Creates and inserts the new trade
pub fn finish_dlc_protocol(
&self,
protocol_id: ReferenceId,
protocol_id: ProtocolId,
closing: bool,
contract_id: ContractId,
channel_id: DlcChannelId,
Expand Down Expand Up @@ -210,3 +286,20 @@ impl DlcProtocolExecutor {
Ok(())
}
}

#[cfg(test)]
mod test {
use crate::dlc_protocol::ProtocolId;
use dlc_manager::ReferenceId;

#[test]
fn test_protocol_id_roundtrip() {
let protocol_id_0 = ProtocolId::new();

let reference_id = ReferenceId::from(protocol_id_0);

let protocol_id_1 = ProtocolId::try_from(reference_id).unwrap();

assert_eq!(protocol_id_0, protocol_id_1)
}
}
Loading

0 comments on commit 37c7d83

Please sign in to comment.