Skip to content

Commit

Permalink
Standardize setup_catalog fixture naming
Browse files Browse the repository at this point in the history
  • Loading branch information
cjdsellers committed Jan 5, 2024
1 parent 9789e10 commit b59bae2
Show file tree
Hide file tree
Showing 11 changed files with 40 additions and 42 deletions.
22 changes: 9 additions & 13 deletions nautilus_trader/test_kit/mocks/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@
# -------------------------------------------------------------------------------------------------

from pathlib import Path
from typing import Literal

from nautilus_trader.common.clock import TestClock
from nautilus_trader.common.logging import Logger
from nautilus_trader.common.providers import InstrumentProvider
from nautilus_trader.model.identifiers import Venue
from nautilus_trader.persistence.catalog.parquet import ParquetDataCatalog
from nautilus_trader.persistence.catalog.singleton import clear_singleton_instances
from nautilus_trader.persistence.wranglers import QuoteTickDataWrangler
Expand All @@ -27,7 +27,7 @@
from nautilus_trader.trading.filters import NewsEvent


AUDUSD_SIM = TestInstrumentProvider.default_fx_ccy("AUD/USD")
_AUDUSD_SIM = TestInstrumentProvider.default_fx_ccy("AUD/USD")


class NewsEventData(NewsEvent):
Expand All @@ -36,9 +36,9 @@ class NewsEventData(NewsEvent):
"""


def data_catalog_setup(
protocol: str,
path: str | Path | None = None,
def setup_catalog(
protocol: Literal["memory", "file"],
path: Path | str | None = None,
) -> ParquetDataCatalog:
if protocol not in ("memory", "file"):
raise ValueError("`protocol` should only be one of `memory` or `file` for testing")
Expand All @@ -62,21 +62,17 @@ def data_catalog_setup(
return catalog


def aud_usd_data_loader(catalog: ParquetDataCatalog) -> None:
from nautilus_trader.test_kit.providers import TestInstrumentProvider

instrument = TestInstrumentProvider.default_fx_ccy("AUD/USD", venue=Venue("SIM"))

def load_catalog_with_stub_quote_ticks_audusd(catalog: ParquetDataCatalog) -> None:
clock = TestClock()
logger = Logger(clock)

instrument_provider = InstrumentProvider(
logger=logger,
)
instrument_provider.add(instrument)
instrument_provider.add(_AUDUSD_SIM)

wrangler = QuoteTickDataWrangler(instrument)
wrangler = QuoteTickDataWrangler(_AUDUSD_SIM)
ticks = wrangler.process(TestDataProvider().read_csv_ticks("truefx/audusd-ticks.csv"))
ticks.sort(key=lambda x: x.ts_init) # CAUTION: data was not originally sorted
catalog.write_data([instrument])
catalog.write_data([_AUDUSD_SIM])
catalog.write_data(ticks)
6 changes: 3 additions & 3 deletions tests/acceptance_tests/test_backtest.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@
from nautilus_trader.persistence.wranglers import BarDataWrangler
from nautilus_trader.persistence.wranglers import QuoteTickDataWrangler
from nautilus_trader.persistence.wranglers import TradeTickDataWrangler
from nautilus_trader.test_kit.mocks.data import data_catalog_setup
from nautilus_trader.test_kit.mocks.data import setup_catalog
from nautilus_trader.test_kit.providers import TestDataProvider
from nautilus_trader.test_kit.providers import TestInstrumentProvider
from tests import TEST_DATA_DIR
Expand Down Expand Up @@ -697,7 +697,7 @@ def test_run_ema_cross_with_tick_bar_spec(self):
class TestBacktestAcceptanceTestsOrderBookImbalance:
def setup(self):
# Fixture Setup
data_catalog_setup(protocol="memory")
setup_catalog(protocol="memory")

config = BacktestEngineConfig(
logging=LoggingConfig(bypass_logging=True),
Expand Down Expand Up @@ -758,7 +758,7 @@ def test_run_order_book_imbalance(self):
class TestBacktestAcceptanceTestsMarketMaking:
def setup(self):
# Fixture Setup
data_catalog_setup(protocol="memory")
setup_catalog(protocol="memory")

config = BacktestEngineConfig(
logging=LoggingConfig(bypass_logging=True),
Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/adapters/betfair/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from nautilus_trader.model.identifiers import AccountId
from nautilus_trader.model.identifiers import Venue
from nautilus_trader.persistence.catalog import ParquetDataCatalog
from nautilus_trader.test_kit.mocks.data import data_catalog_setup
from nautilus_trader.test_kit.mocks.data import setup_catalog
from nautilus_trader.test_kit.stubs.events import TestEventStubs
from tests.integration_tests.adapters.betfair.test_kit import BetfairResponses
from tests.integration_tests.adapters.betfair.test_kit import BetfairTestStubs
Expand Down Expand Up @@ -169,7 +169,7 @@ def exec_client(

@pytest.fixture()
def data_catalog() -> ParquetDataCatalog:
catalog: ParquetDataCatalog = data_catalog_setup(protocol="memory", path="/")
catalog: ParquetDataCatalog = setup_catalog(protocol="memory", path="/")
load_betfair_data(catalog)
return catalog

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,14 +21,14 @@
from nautilus_trader.model.objects import Price
from nautilus_trader.model.objects import Quantity
from nautilus_trader.serialization.arrow.serializer import ArrowSerializer
from nautilus_trader.test_kit.mocks.data import data_catalog_setup
from nautilus_trader.test_kit.mocks.data import setup_catalog
from tests.integration_tests.adapters.betfair.test_kit import betting_instrument
from tests.integration_tests.adapters.betfair.test_kit import load_betfair_data


class TestBetfairPersistence:
def setup(self):
self.catalog = data_catalog_setup(protocol="memory", path="/catalog")
self.catalog = setup_catalog(protocol="memory", path="/catalog")
self.fs = self.catalog.fs
self.instrument = betting_instrument()

Expand Down
4 changes: 2 additions & 2 deletions tests/integration_tests/adapters/databento/test_loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
DATABENTO_TEST_DATA_DIR = TEST_DATA_DIR / "databento"


@pytest.mark.skip(reason="Used for development")
@pytest.mark.skip(reason="development_only")
def test_loader_definition_glbx_all_symbols() -> None:
# Arrange
loader = DatabentoDataLoader()
Expand All @@ -59,7 +59,7 @@ def test_loader_definition_glbx_all_symbols() -> None:
assert len(data) == 10_000_000


@pytest.mark.skip(reason="Used for development")
@pytest.mark.skip(reason="development_only")
def test_loader_spy_xnas_itch_mbo() -> None:
# Arrange
loader = DatabentoDataLoader()
Expand Down
10 changes: 5 additions & 5 deletions tests/unit_tests/backtest/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
from nautilus_trader.model.identifiers import InstrumentId
from nautilus_trader.model.identifiers import Venue
from nautilus_trader.test_kit.mocks.data import NewsEventData
from nautilus_trader.test_kit.mocks.data import aud_usd_data_loader
from nautilus_trader.test_kit.mocks.data import data_catalog_setup
from nautilus_trader.test_kit.mocks.data import load_catalog_with_stub_quote_ticks_audusd
from nautilus_trader.test_kit.mocks.data import setup_catalog
from nautilus_trader.test_kit.providers import TestDataProvider
from nautilus_trader.test_kit.providers import TestInstrumentProvider
from nautilus_trader.test_kit.stubs.config import TestConfigStubs
Expand All @@ -51,8 +51,8 @@
class TestBacktestConfig:
def setup(self):
self.fs_protocol = "file"
self.catalog = data_catalog_setup(protocol=self.fs_protocol)
aud_usd_data_loader(self.catalog)
self.catalog = setup_catalog(protocol=self.fs_protocol)
load_catalog_with_stub_quote_ticks_audusd(self.catalog)
self.venue = Venue("SIM")
self.instrument = TestInstrumentProvider.default_fx_ccy("AUD/USD", venue=self.venue)
self.backtest_config = TestConfigStubs.backtest_run_config(catalog=self.catalog)
Expand Down Expand Up @@ -191,7 +191,7 @@ def test_backtest_config_to_json(self):

class TestBacktestConfigParsing:
def setup(self):
self.catalog = data_catalog_setup(protocol="memory", path="/.nautilus/")
self.catalog = setup_catalog(protocol="memory", path="/.nautilus/")
self.venue = Venue("SIM")
self.instrument = TestInstrumentProvider.default_fx_ccy("AUD/USD", venue=self.venue)
self.backtest_config = TestConfigStubs.backtest_run_config(catalog=self.catalog)
Expand Down
10 changes: 6 additions & 4 deletions tests/unit_tests/backtest/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
# -------------------------------------------------------------------------------------------------

import sys
import tempfile
from decimal import Decimal
from pathlib import Path

import pandas as pd
import pytest
Expand Down Expand Up @@ -215,11 +215,10 @@ def test_account_state_timestamp(self):
assert report.index[0] == start

@pytest.mark.skipif(sys.platform == "win32", reason="Failing on windows")
def test_persistence_files_cleaned_up(self):
def test_persistence_files_cleaned_up(self, tmp_path: Path) -> None:
# Arrange
temp_dir = tempfile.mkdtemp()
catalog = ParquetDataCatalog(
path=str(temp_dir),
path=tmp_path,
fs_protocol="file",
)
config = TestConfigStubs.backtest_engine_config(persist=True, catalog=catalog)
Expand All @@ -228,9 +227,12 @@ def test_persistence_files_cleaned_up(self):
instrument=self.usdjpy,
ticks=TestDataStubs.quote_ticks_usdjpy(),
)

# Act
engine.run()
engine.dispose()

# Assert
assert all(f.closed for f in engine.kernel.writer._files.values())

def test_backtest_engine_multiple_runs(self):
Expand Down
10 changes: 5 additions & 5 deletions tests/unit_tests/backtest/test_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from decimal import Decimal

import msgspec.json
import msgspec
import pytest

from nautilus_trader.backtest.engine import BacktestEngineConfig
Expand All @@ -29,13 +29,13 @@
from nautilus_trader.model.data import QuoteTick
from nautilus_trader.model.identifiers import InstrumentId
from nautilus_trader.persistence.funcs import parse_bytes
from nautilus_trader.test_kit.mocks.data import aud_usd_data_loader
from nautilus_trader.test_kit.mocks.data import data_catalog_setup
from nautilus_trader.test_kit.mocks.data import load_catalog_with_stub_quote_ticks_audusd
from nautilus_trader.test_kit.mocks.data import setup_catalog


class TestBacktestNode:
def setup(self):
self.catalog = data_catalog_setup(protocol="file", path="./data_catalog")
self.catalog = setup_catalog(protocol="file", path="./data_catalog")
self.venue_config = BacktestVenueConfig(
name="SIM",
oms_type="HEDGING",
Expand Down Expand Up @@ -76,7 +76,7 @@ def setup(self):
data=[self.data_config],
),
]
aud_usd_data_loader(self.catalog) # Load sample data
load_catalog_with_stub_quote_ticks_audusd(self.catalog) # Load sample data

def test_init(self):
node = BacktestNode(configs=self.backtest_configs)
Expand Down
4 changes: 2 additions & 2 deletions tests/unit_tests/common/test_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
from nautilus_trader.portfolio.portfolio import Portfolio
from nautilus_trader.test_kit.mocks.actors import KaboomActor
from nautilus_trader.test_kit.mocks.actors import MockActor
from nautilus_trader.test_kit.mocks.data import data_catalog_setup
from nautilus_trader.test_kit.mocks.data import setup_catalog
from nautilus_trader.test_kit.providers import TestInstrumentProvider
from nautilus_trader.test_kit.stubs.component import TestComponentStubs
from nautilus_trader.test_kit.stubs.data import UNIX_EPOCH
Expand Down Expand Up @@ -2051,7 +2051,7 @@ def test_publish_data_persist(self) -> None:
clock=self.clock,
logger=self.logger,
)
catalog = data_catalog_setup(protocol="memory", path="/catalog")
catalog = setup_catalog(protocol="memory", path="/catalog")

writer = StreamingFeatherWriter(
path=catalog.path,
Expand Down
6 changes: 3 additions & 3 deletions tests/unit_tests/data/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
from nautilus_trader.model.objects import Price
from nautilus_trader.model.objects import Quantity
from nautilus_trader.portfolio.portfolio import Portfolio
from nautilus_trader.test_kit.mocks.data import data_catalog_setup
from nautilus_trader.test_kit.mocks.data import setup_catalog
from nautilus_trader.test_kit.providers import TestInstrumentProvider
from nautilus_trader.test_kit.stubs.component import TestComponentStubs
from nautilus_trader.test_kit.stubs.data import TestDataStubs
Expand Down Expand Up @@ -2126,7 +2126,7 @@ def test_request_instruments_reaches_client(self):
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on windows")
def test_request_instrument_when_catalog_registered(self):
# Arrange
catalog = data_catalog_setup(protocol="file")
catalog = setup_catalog(protocol="file")

idealpro = Venue("IDEALPRO")
instrument = TestInstrumentProvider.default_fx_ccy("AUD/USD", venue=idealpro)
Expand Down Expand Up @@ -2156,7 +2156,7 @@ def test_request_instrument_when_catalog_registered(self):
@pytest.mark.skipif(sys.platform == "win32", reason="Failing on windows")
def test_request_instruments_for_venue_when_catalog_registered(self):
# Arrange
catalog = data_catalog_setup(protocol="file")
catalog = setup_catalog(protocol="file")

idealpro = Venue("IDEALPRO")
instrument = TestInstrumentProvider.default_fx_ccy("AUD/USD", venue=idealpro)
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/model/test_orderbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -664,7 +664,7 @@ def make_delta(side: OrderSide, price: float, size: float, ts):
assert book.ts_last == new.ts_last
assert book.sequence == new.sequence

@pytest.mark.skip(reason="Used for development")
@pytest.mark.skip(reason="development_only")
def test_orderbook_spy_xnas_itch_mbo_l3(self) -> None:
# Arrange
loader = DatabentoDataLoader()
Expand Down

0 comments on commit b59bae2

Please sign in to comment.