From a1ffedfe72b320437b5068ce0bec76090c338fc4 Mon Sep 17 00:00:00 2001 From: Raphael Deem Date: Tue, 22 Oct 2024 21:42:59 -0700 Subject: [PATCH] one failing test --- brokers/base_broker.py | 138 +++++++++++++++++---------- database/db_manager.py | 46 ++++++++- tests/test_brokers.py | 212 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 341 insertions(+), 55 deletions(-) diff --git a/brokers/base_broker.py b/brokers/base_broker.py index 249d370d..4c3a7fa9 100644 --- a/brokers/base_broker.py +++ b/brokers/base_broker.py @@ -100,79 +100,123 @@ async def has_bought_today(self, symbol): logger.error('Failed to check if bought today', extra={'error': str(e)}) return False + async def update_positions(self, trade_id, session): '''Update the positions based on the trade''' - result = await session.execute(select(Trade).filter_by(id=trade_id)) - trade = result.scalars().first() - logger.info('Updating positions', extra={'trade': trade}) - - if trade.quantity == 0: - logger.error('Trade quantity is 0, doing nothing', extra={'trade': trade}) - return - try: - # Log before querying the position - logger.debug(f"Querying position for symbol: {trade.symbol}, broker: {self.broker_name}, strategy: {trade.strategy}") + # Fetch the trade + result = await session.execute(select(Trade).filter_by(id=trade_id)) + trade = result.scalars().first() + logger.info('Updating positions', extra={'trade': trade}) + if trade.quantity == 0: + logger.error('Trade quantity is 0, doing nothing', extra={'trade': trade}) + return + + # Query the current position for the trade's symbol, broker, and strategy result = await session.execute( - select(Position).filter_by( - symbol=trade.symbol, broker=self.broker_name, strategy=trade.strategy - ) + select(Position).filter_by(symbol=trade.symbol, broker=self.broker_name, strategy=trade.strategy) ) position = result.scalars().first() - - # Log after querying the position logger.debug(f"Queried position: {position}") + # Initialize profit/loss + profit_loss = 0 + + # Handling Buy Orders if trade.order_type == 'buy': - logger.info('Processing buy order', extra={'trade': trade}) - if position: - logger.debug(f"Updating existing position: {position}") - if is_option(trade.symbol): - position.cost_basis += float(trade.executed_price) * float(trade.quantity) * OPTION_MULTIPLIER - if is_futures_symbol(trade.symbol): - multiplier = futures_contract_size(trade.symbol) - position.cost_basis += float(trade.executed_price) * float(trade.quantity) * multiplier + if position and position.quantity < 0: # This is a short cover + logger.info('Processing short cover', extra={'trade': trade}) + + # Calculate P/L for short cover (covering short position) + cost_per_share = position.cost_basis / abs(position.quantity) + profit_loss = (cost_per_share - float(trade.executed_price)) * abs(trade.quantity) + logger.info(f'Short cover profit/loss calculated: {profit_loss}', extra={'trade': trade, 'position': position}) + + # Update or remove the short position + if abs(position.quantity) == trade.quantity: + logger.info('Fully covering short position, removing position', extra={'position': position}) + await session.delete(position) + else: + position.cost_basis -= cost_per_share * abs(trade.quantity) + position.quantity += trade.quantity # Add back the covered quantity + position.latest_price = float(trade.executed_price) + position.timestamp = datetime.now() + session.add(position) + trade.profit_loss = profit_loss + session.add(trade) + + else: # Regular Buy + logger.info('Processing regular buy order', extra={'trade': trade}) + if position: + # Update existing position + cost_increment = float(trade.executed_price) * trade.quantity + if is_option(trade.symbol): + position.cost_basis += cost_increment * OPTION_MULTIPLIER + elif is_futures_symbol(trade.symbol): + multiplier = futures_contract_size(trade.symbol) + position.cost_basis += cost_increment * multiplier + else: + position.cost_basis += cost_increment + position.quantity += trade.quantity + position.latest_price = float(trade.executed_price) + position.timestamp = datetime.now() + session.add(position) else: - position.cost_basis += float(trade.executed_price) * float(trade.quantity) - position.quantity += trade.quantity - position.latest_price = float(trade.executed_price) - position.timestamp = datetime.now() - else: - logger.debug(f"Creating new position for symbol: {trade.symbol}") + # Create a new position + position = Position( + broker=self.broker_name, + strategy=trade.strategy, + symbol=trade.symbol, + quantity=trade.quantity, + latest_price=float(trade.executed_price), + cost_basis=float(trade.executed_price) * trade.quantity, + ) + session.add(position) + + # Handling Sell Orders + elif trade.order_type == 'sell': + logger.info('Processing sell order', extra={'trade': trade}) + + # Short sales + if position is None: + logger.info('Short sale detected', extra={'trade': trade, 'quantity': trade.quantity, 'symbol': trade.symbol}) position = Position( broker=self.broker_name, strategy=trade.strategy, symbol=trade.symbol, - quantity=trade.quantity, + quantity=-trade.quantity, latest_price=float(trade.executed_price), - cost_basis=float(trade.executed_price) * float(trade.quantity), + cost_basis=float(trade.executed_price) * trade.quantity, ) session.add(position) - - elif trade.order_type == 'sell': - logger.info('Processing sell order', extra={'trade': trade}) if position: - if position.quantity == trade.quantity: + cost_per_share = position.cost_basis / position.quantity + profit_loss = (float(trade.executed_price) - cost_per_share) * trade.quantity + logger.info(f'Sell order profit/loss calculated: {profit_loss}', extra={'trade': trade, 'position': position}) + + if position.quantity == trade.quantity: # Full sell logger.info('Deleting sold position', extra={'position': position}) await session.delete(position) - await session.commit() - elif position.quantity > trade.quantity: - logger.debug(f"Reducing quantity of position: {position}") - cost_per_share = position.cost_basis / position.quantity + else: # Partial sell position.cost_basis -= trade.quantity * cost_per_share position.quantity -= trade.quantity position.latest_price = float(trade.executed_price) - session.add(position) + session.add(position) + trade.profit_loss = profit_loss + session.add(trade) + # Update the trade with the calculated profit/loss + session.add(trade) + + # Commit the transaction await session.commit() - # Log after committing changes logger.info('Position updated', extra={'position': position}) except Exception as e: logger.error('Failed to update positions', extra={'error': str(e)}) - + await session.rollback() async def place_future_option_order(self, symbol, quantity, order_type, strategy, price=None): multiplier = futures_contract_size(symbol) @@ -227,12 +271,6 @@ async def _place_order_generic(self, symbol, quantity, order_type, strategy, pri success='yes' ) - # Calculate profit/loss if it's a sell order - if order_type == 'sell': - profit_loss = await self.db_manager.calculate_profit_loss(trade) - logger.info('Profit/Loss calculated', extra={'profit_loss': profit_loss}) - trade.profit_loss = profit_loss - # Update the trade and positions in the database async with self.Session() as session: session.add(trade) @@ -325,12 +363,10 @@ async def update_trade(self, session, trade_id, order_info): executed_price = order_info.get('filled_price', trade.price) trade.executed_price = executed_price - profit_loss = await self.db_manager.calculate_profit_loss(trade) - success = "success" if profit_loss > 0 else "failure" + success = "success" if trade.profit_loss > 0 else "failure" trade.executed_price = executed_price trade.success = success - trade.profit_loss = profit_loss await session.commit() logger.info('Trade updated', extra={'trade': trade}) except Exception as e: diff --git a/database/db_manager.py b/database/db_manager.py index 4aab0018..928beeca 100644 --- a/database/db_manager.py +++ b/database/db_manager.py @@ -75,6 +75,7 @@ async def get_position(self, broker, symbol, strategy): logger.error('Failed to retrieve position', extra={'error': str(e)}) return None + async def calculate_profit_loss(self, trade): async with self.Session() as session: try: @@ -88,10 +89,19 @@ async def calculate_profit_loss(self, trade): logger.error('Executed price is None, cannot calculate profit/loss', extra={'trade': trade}) return None - # Handling buy trades + # Handling buy trades that cover a short position if trade.order_type.lower() == 'buy': - logger.info('Buy order detected, no P/L calculation needed.', extra={'trade': trade}) - return profit_loss + position = await self.get_position(trade.broker, trade.symbol, trade.strategy) + if position and position.quantity < 0: # Detect if this is a short cover + logger.info('Short cover detected, calculating P/L.', extra={'trade': trade}) + + # Calculate P/L for covering short (short sell price - buy price) + cost_per_share = float(position.cost_basis) / abs(position.quantity) + profit_loss = (cost_per_share - current_price) * abs(trade.quantity) + logger.info(f'Short cover P/L calculated as: {profit_loss}', extra={'trade': trade}) + else: + logger.info('Regular buy order detected, no P/L calculation needed.', extra={'trade': trade}) + return profit_loss # Handling sell trades elif trade.order_type.lower() == 'sell': @@ -120,15 +130,27 @@ async def calculate_profit_loss(self, trade): logger.error('Failed to calculate profit/loss', extra={'error': str(e), 'trade': trade}) return None - async def calculate_partial_profit_loss(self, trade, position): try: profit_loss = None logger.info('Calculating partial profit/loss', extra={'trade': trade, 'position': position}) + if trade.order_type.lower() == 'sell': + # Partial sell for regular positions + logger.info('Partial sell order detected, calculating P/L.', extra={'trade': trade}) profit_loss = (float(trade.executed_price) - (float(position.cost_basis) / position.quantity)) * trade.quantity + + elif trade.order_type.lower() == 'buy' and position.quantity < 0: + # Partial short cover (buying back part of the short position) + logger.info('Partial short cover detected, calculating P/L.', extra={'trade': trade}) + + # Calculate P/L for covering a short (short sell price - cover price) + cost_per_share = float(position.cost_basis) / abs(position.quantity) + profit_loss = (cost_per_share - float(trade.executed_price)) * abs(trade.quantity) + logger.info('Partial profit/loss calculated', extra={'trade': trade, 'position': position, 'profit_loss': profit_loss}) return profit_loss + except Exception as e: logger.error('Failed to calculate partial profit/loss', extra={'error': str(e)}) return None @@ -187,3 +209,19 @@ async def rename_strategy(self, broker, old_strategy_name, new_strategy_name): except Exception as e: await session.rollback() logger.error('Failed to update strategy name', extra={'error': str(e)}) + + + async def get_profit_loss(self, trade_id): + async with self.Session() as session: + try: + logger.info('Retrieving profit/loss', extra={'trade': trade_id}) + result = await session.execute(select(Trade).filter_by(id=trade_id)) + trade = result.scalar() + if trade is None: + logger.warning(f"No trade found with id {trade.id}") + return None + logger.info('Profit/loss retrieved', extra={'trade': trade}) + return trade.profit_loss + except Exception as e: + logger.error('Failed to retrieve profit/loss', extra={'error': str(e)}) + return None diff --git a/tests/test_brokers.py b/tests/test_brokers.py index 729cb8bc..fa02165e 100644 --- a/tests/test_brokers.py +++ b/tests/test_brokers.py @@ -369,3 +369,215 @@ async def test_pl_calculation_no_position(session, broker): profit_loss = await broker.db_manager.calculate_profit_loss(trade) assert profit_loss is None + +@pytest.mark.asyncio +async def test_short_cover_full(session, broker): + # Create a short position + position = Position( + symbol="AAPL", + broker="dummy_broker", + quantity=-10, # Short 10 shares + latest_price=150.0, + cost_basis=1500.0, + last_updated=datetime.now(), + strategy="test_strategy" + ) + async with session.begin(): + session.add(position) + await session.commit() + + # Cover the short position + trade = Trade( + symbol="AAPL", + quantity=10, # Buying back 10 shares + price=140.0, + executed_price=140.0, + order_type="buy", + status="executed", + timestamp=datetime.now(), + broker="dummy_broker", + strategy="test_strategy", + profit_loss=None, + success="yes" + ) + async with session.begin(): + session.add(trade) + await session.commit() + await session.refresh(trade) + + # Update positions and check P/L + trade_id = trade.id + await broker.update_positions(trade_id, session) + # Get the trade's profit/loss from the database + profit_loss = await broker.db_manager.get_profit_loss(trade_id) + + # Ensure the short position was fully covered and deleted + result = await session.execute(select(Position).filter_by(symbol="AAPL")) + position = result.scalars().first() + assert position is None + + # Check P/L (should be profit since buy price < sell price) + assert profit_loss == 100.0 # (150 - 140) * 10 + +@pytest.mark.asyncio +async def test_short_cover_partial(session, broker): + # Create a short position + position = Position( + symbol="AAPL", + broker="dummy_broker", + quantity=-10, # Short 10 shares + latest_price=150.0, + cost_basis=1500.0, + last_updated=datetime.now(), + strategy="test_strategy" + ) + async with session.begin(): + session.add(position) + await session.commit() + + # Partially cover the short position (buying back 5 shares) + trade = Trade( + symbol="AAPL", + quantity=5, # Buying back 5 shares + price=140.0, + executed_price=140.0, + order_type="buy", + status="executed", + timestamp=datetime.now(), + broker="dummy_broker", + strategy="test_strategy", + profit_loss=None, + success="yes" + ) + async with session.begin(): + session.add(trade) + await session.commit() + await session.refresh(trade) + + # Update positions and check P/L + trade_id = trade.id + await broker.update_positions(trade_id, session) + profit_loss = await broker.db_manager.get_profit_loss(trade_id) + + # Ensure the position was partially covered (remaining short 5 shares) + result = await session.execute(select(Position).filter_by(symbol="AAPL")) + position = result.scalars().first() + assert position.quantity == -5 # Short 5 shares remaining + assert position.cost_basis == 750.0 # Updated cost basis for 5 remaining shares + + # Check P/L (should be profit since buy price < sell price) + assert profit_loss == 50.0 # (150 - 140) * 5 + + +@pytest.mark.asyncio +async def test_normal_buy_sell(session, broker): + # Buy 10 shares of AAPL + trade1 = Trade( + symbol="AAPL", + quantity=10, + price=150.0, + executed_price=150.0, + order_type="buy", + status="executed", + timestamp=datetime.now(), + broker="dummy_broker", + strategy="test_strategy", + profit_loss=None, + success="yes" + ) + async with session.begin(): + session.add(trade1) + await session.commit() + await session.refresh(trade1) + + # Update positions after the buy + await broker.update_positions(trade1.id, session) + + # Sell 5 shares of AAPL + trade2 = Trade( + symbol="AAPL", + quantity=5, + price=155.0, + executed_price=155.0, + order_type="sell", + status="executed", + timestamp=datetime.now(), + broker="dummy_broker", + strategy="test_strategy", + profit_loss=None, + success="yes" + ) + async with session.begin(): + session.add(trade2) + await session.commit() + await session.refresh(trade2) + + # Update positions after the sell + trade2_id = trade2.id + await broker.update_positions(trade2_id, session) + # Get the trade's profit/loss from the database + profit_loss = await broker.db_manager.get_profit_loss(trade2_id) + assert profit_loss == 25.0 + + # Check the position after the sell + result = await session.execute(select(Position).filter_by(symbol="AAPL")) + position = result.scalars().first() + assert position.quantity == 5 + + +@pytest.mark.asyncio +async def test_normal_sell_buy(session, broker): + # Sell 10 shares of AAPL + trade1 = Trade( + symbol="AAPL", + quantity=10, + price=150.0, + executed_price=150.0, + order_type="sell", + status="executed", + timestamp=datetime.now(), + broker="dummy_broker", + strategy="test_strategy", + profit_loss=None, + success="yes" + ) + async with session.begin(): + session.add(trade1) + await session.commit() + await session.refresh(trade1) + + # Update positions after the sell + await broker.update_positions(trade1.id, session) + session.commit() + session.refresh(trade1) + + # Buy 5 shares of AAPL + trade2 = Trade( + symbol="AAPL", + quantity=5, + price=155.0, + executed_price=155.0, + order_type="buy", + status="executed", + timestamp=datetime.now(), + broker="dummy_broker", + strategy="test_strategy", + profit_loss=None, + success="yes" + ) + async with session.begin(): + session.add(trade2) + session.commit() + await session.refresh(trade2) + + # Update positions after the buy + trade2_id = trade2.id + await broker.update_positions(trade2_id, session) + # Get the trade's profit/loss from the database + profit_loss = await broker.db_manager.get_profit_loss(trade2_id) + assert profit_loss == -25.0 + + # Check the position after the buy + result = await session.execute(select(Position).filter_by(symbol="AAPL")) + position = result.scalars().first() + assert position.quantity == -5