Skip to content

Commit

Permalink
one failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
r0fls committed Oct 23, 2024
1 parent c8ff0a1 commit a1ffedf
Show file tree
Hide file tree
Showing 3 changed files with 341 additions and 55 deletions.
138 changes: 87 additions & 51 deletions brokers/base_broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,79 +100,123 @@ async def has_bought_today(self, symbol):
logger.error('Failed to check if bought today', extra={'error': str(e)})
return False


async def update_positions(self, trade_id, session):
'''Update the positions based on the trade'''
result = await session.execute(select(Trade).filter_by(id=trade_id))
trade = result.scalars().first()
logger.info('Updating positions', extra={'trade': trade})

if trade.quantity == 0:
logger.error('Trade quantity is 0, doing nothing', extra={'trade': trade})
return

try:
# Log before querying the position
logger.debug(f"Querying position for symbol: {trade.symbol}, broker: {self.broker_name}, strategy: {trade.strategy}")
# Fetch the trade
result = await session.execute(select(Trade).filter_by(id=trade_id))
trade = result.scalars().first()
logger.info('Updating positions', extra={'trade': trade})

if trade.quantity == 0:
logger.error('Trade quantity is 0, doing nothing', extra={'trade': trade})
return

# Query the current position for the trade's symbol, broker, and strategy
result = await session.execute(
select(Position).filter_by(
symbol=trade.symbol, broker=self.broker_name, strategy=trade.strategy
)
select(Position).filter_by(symbol=trade.symbol, broker=self.broker_name, strategy=trade.strategy)
)
position = result.scalars().first()

# Log after querying the position
logger.debug(f"Queried position: {position}")

# Initialize profit/loss
profit_loss = 0

# Handling Buy Orders
if trade.order_type == 'buy':
logger.info('Processing buy order', extra={'trade': trade})
if position:
logger.debug(f"Updating existing position: {position}")
if is_option(trade.symbol):
position.cost_basis += float(trade.executed_price) * float(trade.quantity) * OPTION_MULTIPLIER
if is_futures_symbol(trade.symbol):
multiplier = futures_contract_size(trade.symbol)
position.cost_basis += float(trade.executed_price) * float(trade.quantity) * multiplier
if position and position.quantity < 0: # This is a short cover
logger.info('Processing short cover', extra={'trade': trade})

# Calculate P/L for short cover (covering short position)
cost_per_share = position.cost_basis / abs(position.quantity)
profit_loss = (cost_per_share - float(trade.executed_price)) * abs(trade.quantity)
logger.info(f'Short cover profit/loss calculated: {profit_loss}', extra={'trade': trade, 'position': position})

# Update or remove the short position
if abs(position.quantity) == trade.quantity:
logger.info('Fully covering short position, removing position', extra={'position': position})
await session.delete(position)
else:
position.cost_basis -= cost_per_share * abs(trade.quantity)
position.quantity += trade.quantity # Add back the covered quantity
position.latest_price = float(trade.executed_price)
position.timestamp = datetime.now()
session.add(position)
trade.profit_loss = profit_loss
session.add(trade)

else: # Regular Buy
logger.info('Processing regular buy order', extra={'trade': trade})
if position:
# Update existing position
cost_increment = float(trade.executed_price) * trade.quantity
if is_option(trade.symbol):
position.cost_basis += cost_increment * OPTION_MULTIPLIER
elif is_futures_symbol(trade.symbol):
multiplier = futures_contract_size(trade.symbol)
position.cost_basis += cost_increment * multiplier
else:
position.cost_basis += cost_increment
position.quantity += trade.quantity
position.latest_price = float(trade.executed_price)
position.timestamp = datetime.now()
session.add(position)
else:
position.cost_basis += float(trade.executed_price) * float(trade.quantity)
position.quantity += trade.quantity
position.latest_price = float(trade.executed_price)
position.timestamp = datetime.now()
else:
logger.debug(f"Creating new position for symbol: {trade.symbol}")
# Create a new position
position = Position(
broker=self.broker_name,
strategy=trade.strategy,
symbol=trade.symbol,
quantity=trade.quantity,
latest_price=float(trade.executed_price),
cost_basis=float(trade.executed_price) * trade.quantity,
)
session.add(position)

# Handling Sell Orders
elif trade.order_type == 'sell':
logger.info('Processing sell order', extra={'trade': trade})

# Short sales
if position is None:
logger.info('Short sale detected', extra={'trade': trade, 'quantity': trade.quantity, 'symbol': trade.symbol})
position = Position(
broker=self.broker_name,
strategy=trade.strategy,
symbol=trade.symbol,
quantity=trade.quantity,
quantity=-trade.quantity,
latest_price=float(trade.executed_price),
cost_basis=float(trade.executed_price) * float(trade.quantity),
cost_basis=float(trade.executed_price) * trade.quantity,
)
session.add(position)

elif trade.order_type == 'sell':
logger.info('Processing sell order', extra={'trade': trade})
if position:
if position.quantity == trade.quantity:
cost_per_share = position.cost_basis / position.quantity
profit_loss = (float(trade.executed_price) - cost_per_share) * trade.quantity
logger.info(f'Sell order profit/loss calculated: {profit_loss}', extra={'trade': trade, 'position': position})

if position.quantity == trade.quantity: # Full sell
logger.info('Deleting sold position', extra={'position': position})
await session.delete(position)
await session.commit()
elif position.quantity > trade.quantity:
logger.debug(f"Reducing quantity of position: {position}")
cost_per_share = position.cost_basis / position.quantity
else: # Partial sell
position.cost_basis -= trade.quantity * cost_per_share
position.quantity -= trade.quantity
position.latest_price = float(trade.executed_price)
session.add(position)
session.add(position)
trade.profit_loss = profit_loss
session.add(trade)

# Update the trade with the calculated profit/loss
session.add(trade)

# Commit the transaction
await session.commit()

# Log after committing changes
logger.info('Position updated', extra={'position': position})

except Exception as e:
logger.error('Failed to update positions', extra={'error': str(e)})

await session.rollback()

async def place_future_option_order(self, symbol, quantity, order_type, strategy, price=None):
multiplier = futures_contract_size(symbol)
Expand Down Expand Up @@ -227,12 +271,6 @@ async def _place_order_generic(self, symbol, quantity, order_type, strategy, pri
success='yes'
)

# Calculate profit/loss if it's a sell order
if order_type == 'sell':
profit_loss = await self.db_manager.calculate_profit_loss(trade)
logger.info('Profit/Loss calculated', extra={'profit_loss': profit_loss})
trade.profit_loss = profit_loss

# Update the trade and positions in the database
async with self.Session() as session:
session.add(trade)
Expand Down Expand Up @@ -325,12 +363,10 @@ async def update_trade(self, session, trade_id, order_info):

executed_price = order_info.get('filled_price', trade.price)
trade.executed_price = executed_price
profit_loss = await self.db_manager.calculate_profit_loss(trade)
success = "success" if profit_loss > 0 else "failure"
success = "success" if trade.profit_loss > 0 else "failure"

trade.executed_price = executed_price
trade.success = success
trade.profit_loss = profit_loss
await session.commit()
logger.info('Trade updated', extra={'trade': trade})
except Exception as e:
Expand Down
46 changes: 42 additions & 4 deletions database/db_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ async def get_position(self, broker, symbol, strategy):
logger.error('Failed to retrieve position', extra={'error': str(e)})
return None


async def calculate_profit_loss(self, trade):
async with self.Session() as session:
try:
Expand All @@ -88,10 +89,19 @@ async def calculate_profit_loss(self, trade):
logger.error('Executed price is None, cannot calculate profit/loss', extra={'trade': trade})
return None

# Handling buy trades
# Handling buy trades that cover a short position
if trade.order_type.lower() == 'buy':
logger.info('Buy order detected, no P/L calculation needed.', extra={'trade': trade})
return profit_loss
position = await self.get_position(trade.broker, trade.symbol, trade.strategy)
if position and position.quantity < 0: # Detect if this is a short cover
logger.info('Short cover detected, calculating P/L.', extra={'trade': trade})

# Calculate P/L for covering short (short sell price - buy price)
cost_per_share = float(position.cost_basis) / abs(position.quantity)
profit_loss = (cost_per_share - current_price) * abs(trade.quantity)
logger.info(f'Short cover P/L calculated as: {profit_loss}', extra={'trade': trade})
else:
logger.info('Regular buy order detected, no P/L calculation needed.', extra={'trade': trade})
return profit_loss

# Handling sell trades
elif trade.order_type.lower() == 'sell':
Expand Down Expand Up @@ -120,15 +130,27 @@ async def calculate_profit_loss(self, trade):
logger.error('Failed to calculate profit/loss', extra={'error': str(e), 'trade': trade})
return None


async def calculate_partial_profit_loss(self, trade, position):
try:
profit_loss = None
logger.info('Calculating partial profit/loss', extra={'trade': trade, 'position': position})

if trade.order_type.lower() == 'sell':
# Partial sell for regular positions
logger.info('Partial sell order detected, calculating P/L.', extra={'trade': trade})
profit_loss = (float(trade.executed_price) - (float(position.cost_basis) / position.quantity)) * trade.quantity

elif trade.order_type.lower() == 'buy' and position.quantity < 0:
# Partial short cover (buying back part of the short position)
logger.info('Partial short cover detected, calculating P/L.', extra={'trade': trade})

# Calculate P/L for covering a short (short sell price - cover price)
cost_per_share = float(position.cost_basis) / abs(position.quantity)
profit_loss = (cost_per_share - float(trade.executed_price)) * abs(trade.quantity)

logger.info('Partial profit/loss calculated', extra={'trade': trade, 'position': position, 'profit_loss': profit_loss})
return profit_loss

except Exception as e:
logger.error('Failed to calculate partial profit/loss', extra={'error': str(e)})
return None
Expand Down Expand Up @@ -187,3 +209,19 @@ async def rename_strategy(self, broker, old_strategy_name, new_strategy_name):
except Exception as e:
await session.rollback()
logger.error('Failed to update strategy name', extra={'error': str(e)})


async def get_profit_loss(self, trade_id):
async with self.Session() as session:
try:
logger.info('Retrieving profit/loss', extra={'trade': trade_id})
result = await session.execute(select(Trade).filter_by(id=trade_id))
trade = result.scalar()
if trade is None:
logger.warning(f"No trade found with id {trade.id}")
return None
logger.info('Profit/loss retrieved', extra={'trade': trade})
return trade.profit_loss
except Exception as e:
logger.error('Failed to retrieve profit/loss', extra={'error': str(e)})
return None
Loading

0 comments on commit a1ffedf

Please sign in to comment.