Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Fallback to Bar Price in Cache and Portfolio when Ticks unavailable #1594

Merged
merged 2 commits into from
Apr 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions nautilus_trader/cache/cache.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# -------------------------------------------------------------------------------------------------

from cpython.datetime cimport datetime
from cpython.datetime cimport timedelta
from libc.stdint cimport uint64_t

from nautilus_trader.accounting.accounts.base cimport Account
Expand All @@ -22,13 +23,15 @@ from nautilus_trader.cache.base cimport CacheFacade
from nautilus_trader.cache.facade cimport CacheDatabaseFacade
from nautilus_trader.common.actor cimport Actor
from nautilus_trader.common.component cimport Logger
from nautilus_trader.core.rust.model cimport AggregationSource
from nautilus_trader.core.rust.model cimport OmsType
from nautilus_trader.core.rust.model cimport OrderSide
from nautilus_trader.core.rust.model cimport PositionSide
from nautilus_trader.execution.messages cimport SubmitOrder
from nautilus_trader.execution.messages cimport SubmitOrderList
from nautilus_trader.model.book cimport OrderBook
from nautilus_trader.model.data cimport Bar
from nautilus_trader.model.data cimport BarType
from nautilus_trader.model.data cimport QuoteTick
from nautilus_trader.model.data cimport TradeTick
from nautilus_trader.model.identifiers cimport AccountId
Expand Down Expand Up @@ -175,3 +178,12 @@ cdef class Cache(CacheFacade):
cpdef void delete_strategy(self, Strategy strategy)

cpdef void heartbeat(self, datetime timestamp)

cdef timedelta _get_timedelta(self, BarType bar_type)

cpdef list bar_types(
self,
InstrumentId instrument_id=*,
object price_type=*,
AggregationSource aggregation_source=*,
)
61 changes: 59 additions & 2 deletions nautilus_trader/cache/cache.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ from decimal import Decimal
from nautilus_trader.cache.config import CacheConfig

from cpython.datetime cimport datetime
from cpython.datetime cimport timedelta
from libc.stdint cimport uint8_t
from libc.stdint cimport uint64_t

Expand All @@ -37,6 +38,10 @@ from nautilus_trader.core.rust.model cimport OmsType
from nautilus_trader.core.rust.model cimport OrderSide
from nautilus_trader.core.rust.model cimport PositionSide
from nautilus_trader.core.rust.model cimport PriceType

from nautilus_trader.core.rust.model import PriceType as PriceType_py

from nautilus_trader.core.rust.model cimport AggregationSource
from nautilus_trader.core.rust.model cimport TriggerType
from nautilus_trader.execution.messages cimport SubmitOrder
from nautilus_trader.model.data cimport Bar
Expand Down Expand Up @@ -2033,10 +2038,20 @@ cdef class Cache(CacheFacade):

if price_type == PriceType.LAST:
trade_tick = self.trade_tick(instrument_id)
return trade_tick.price if trade_tick is not None else None
if trade_tick is not None:
return trade_tick.price
else:
quote_tick = self.quote_tick(instrument_id)
return quote_tick.extract_price(price_type) if quote_tick is not None else None
if quote_tick is not None:
return quote_tick.extract_price(price_type)

# Fallback to bar pricing
cdef Bar bar
cdef list bar_types = self.bar_types(instrument_id, price_type, AggregationSource.EXTERNAL)
if bar_types:
bar = self.bar(bar_types[0]) # Bar with smallest timedelta
if bar is not None:
return bar.close

cpdef OrderBook order_book(self, InstrumentId instrument_id):
"""
Expand Down Expand Up @@ -2439,6 +2454,48 @@ cdef class Cache(CacheFacade):
"""
return [x for x in self._instruments.values() if venue is None or venue == x.id.venue]

cdef timedelta _get_timedelta(self, BarType bar_type):
""" Helper method to get the timedelta from a BarType. """
return bar_type.spec.timedelta

cpdef list bar_types(
self,
InstrumentId instrument_id = None,
object price_type = None,
AggregationSource aggregation_source = AggregationSource.EXTERNAL,
):
"""
Return a list of BarType for the given instrument ID and price type.

Parameters
----------
instrument_id : InstrumentId, optional
The instrument ID to filter the BarType objects. If None, no filtering is done based on instrument ID.
price_type : PriceType or None, optional
The price type to filter the BarType objects. If None, no filtering is done based on price type.
aggregation_source : AggregationSource, default AggregationSource.EXTERNAL
The aggregation source to filter the BarType objects.
Returns
-------
list[BarType]
"""
Condition.type_or_none(instrument_id, InstrumentId, "instrument_id")
Condition.type_or_none(price_type, PriceType_py, "price_type")

cdef list bar_types = [bar_type for bar_type in self._bars.keys()
if bar_type.aggregation_source == aggregation_source]

if instrument_id is not None:
bar_types = [bar_type for bar_type in bar_types if bar_type.instrument_id == instrument_id]

if price_type is not None:
bar_types = [bar_type for bar_type in bar_types if bar_type.spec.price_type == price_type]

if instrument_id and price_type:
bar_types.sort(key=self._get_timedelta)

return bar_types

# -- SYNTHETIC QUERIES ----------------------------------------------------------------------------

cpdef SyntheticInstrument synthetic(self, InstrumentId instrument_id):
Expand Down
39 changes: 27 additions & 12 deletions nautilus_trader/portfolio/portfolio.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,13 @@ cdef class Portfolio(PortfolioFacade):
)
return None # Cannot calculate

if position.side == PositionSide.FLAT:
self._log.error(
f"Cannot calculate net exposures: "
f"position is flat for {position.instrument_id}"
)
continue # Nothing to calculate

last = self._get_last_price(position)
if last is None:
self._log.error(
Expand Down Expand Up @@ -1070,6 +1077,9 @@ cdef class Portfolio(PortfolioFacade):
if position.instrument_id != instrument_id:
continue # Nothing to calculate

if position.side == PositionSide.FLAT:
continue # Nothing to calculate

last = self._get_last_price(position)
if last is None:
self._log.debug(
Expand Down Expand Up @@ -1102,19 +1112,24 @@ cdef class Portfolio(PortfolioFacade):
return Money(total_pnl, currency)

cdef Price _get_last_price(self, Position position):
cdef QuoteTick quote_tick = self._cache.quote_tick(position.instrument_id)
if quote_tick is not None:
if position.side == PositionSide.LONG:
return quote_tick.bid_price
elif position.side == PositionSide.SHORT:
return quote_tick.ask_price
else: # pragma: no cover (design-time error)
raise RuntimeError(
f"invalid `PositionSide`, was {position_side_to_str(position.side)}",
)
cdef PriceType price_type
if position.side == PositionSide.LONG:
price_type = PriceType.BID
elif position.side == PositionSide.SHORT:
price_type = PriceType.ASK
else: # pragma: no cover (design-time error)
raise RuntimeError(
f"invalid `PositionSide`, was {position_side_to_str(position.side)}",
)

cdef TradeTick trade_tick = self._cache.trade_tick(position.instrument_id)
return trade_tick.price if trade_tick is not None else None
cdef Price price
return self._cache.price(
instrument_id=position.instrument_id,
price_type=price_type,
) or self._cache.price(
instrument_id=position.instrument_id,
price_type=PriceType.LAST,
)

cdef double _calculate_xrate_to_base(self, Account account, Instrument instrument, OrderSide side):
if account.base_currency is not None:
Expand Down
21 changes: 21 additions & 0 deletions nautilus_trader/test_kit/stubs/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ def quote_ticks_usdjpy() -> list[QuoteTick]:
def bar_spec_1min_bid() -> BarSpecification:
return BarSpecification(1, BarAggregation.MINUTE, PriceType.BID)

@staticmethod
def bar_spec_5min_bid() -> BarSpecification:
return BarSpecification(5, BarAggregation.MINUTE, PriceType.BID)

@staticmethod
def bar_spec_1min_ask() -> BarSpecification:
return BarSpecification(1, BarAggregation.MINUTE, PriceType.ASK)
Expand All @@ -144,6 +148,10 @@ def bar_spec_100tick_last() -> BarSpecification:
def bartype_audusd_1min_bid() -> BarType:
return BarType(TestIdStubs.audusd_id(), TestDataStubs.bar_spec_1min_bid())

@staticmethod
def bartype_audusd_5min_bid() -> BarType:
return BarType(TestIdStubs.audusd_id(), TestDataStubs.bar_spec_5min_bid())

@staticmethod
def bartype_audusd_1min_ask() -> BarType:
return BarType(TestIdStubs.audusd_id(), TestDataStubs.bar_spec_1min_ask())
Expand Down Expand Up @@ -189,6 +197,19 @@ def bar_5decimal() -> Bar:
ts_init=0,
)

@staticmethod
def bar_5decimal_5min_bid() -> Bar:
return Bar(
bar_type=TestDataStubs.bartype_audusd_5min_bid(),
open=Price.from_str("1.00101"),
high=Price.from_str("1.00208"),
low=Price.from_str("1.00100"),
close=Price.from_str("1.00205"),
volume=Quantity.from_int(1_000_000),
ts_event=0,
ts_init=0,
)

@staticmethod
def bar_3decimal() -> Bar:
return Bar(
Expand Down
65 changes: 65 additions & 0 deletions tests/unit_tests/cache/test_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import pytest

from nautilus_trader.core.rust.model import AggregationSource
from nautilus_trader.model.currencies import AUD
from nautilus_trader.model.currencies import JPY
from nautilus_trader.model.currencies import USD
Expand Down Expand Up @@ -387,6 +388,70 @@ def test_price_given_various_quote_price_types_when_quote_tick_returns_expected_
# Assert
assert result == expected

@pytest.mark.parametrize(
("price_type", "expected"),
[[PriceType.BID, Price.from_str("1.00003")], [PriceType.LAST, None]],
)
def test_price_returned_with_external_bars(self, price_type, expected):
# Arrange
self.cache.add_bar(TestDataStubs.bar_5decimal())
self.cache.add_bar(TestDataStubs.bar_5decimal_5min_bid())
self.cache.add_bar(TestDataStubs.bar_3decimal())

# Act
result = self.cache.price(AUDUSD_SIM.id, price_type)

# Assert
assert result == expected

@pytest.mark.parametrize(
("instrument_id", "price_type", "aggregation_source", "expected"),
[
[
AUDUSD_SIM.id,
PriceType.BID,
AggregationSource.EXTERNAL,
[TestDataStubs.bartype_audusd_1min_bid(), TestDataStubs.bartype_audusd_5min_bid()],
],
[AUDUSD_SIM.id, PriceType.BID, AggregationSource.INTERNAL, []],
[AUDUSD_SIM.id, PriceType.ASK, AggregationSource.EXTERNAL, []],
[ETHUSDT_BINANCE.id, PriceType.BID, AggregationSource.EXTERNAL, []],
],
)
def test_retrieved_bar_types_match_expected(
self,
instrument_id,
price_type,
aggregation_source,
expected,
):
# Arrange
self.cache.add_bar(TestDataStubs.bar_5decimal())
self.cache.add_bar(TestDataStubs.bar_5decimal_5min_bid())
self.cache.add_bar(TestDataStubs.bar_3decimal())

# Act
result = self.cache.bar_types(
instrument_id=instrument_id,
price_type=price_type,
aggregation_source=aggregation_source,
)

# Assert
assert result == expected

def test_retrieved_all_bar_types_match_expected(self):
# Arrange
self.cache.add_bar(TestDataStubs.bar_5decimal())
self.cache.add_bar(TestDataStubs.bar_5decimal_5min_bid())
self.cache.add_bar(TestDataStubs.bar_3decimal())

# Act
result = self.cache.bar_types()

# Assert
assert len(result) == 3

def test_quote_tick_when_index_out_of_range_returns_none(self):
# Arrange
tick = TestDataStubs.quote_tick()
Expand Down