From 04bda9eb571c58a18a389bcd3dc8e6caef28ec9b Mon Sep 17 00:00:00 2001 From: Chris Sellers Date: Sun, 21 Apr 2024 10:25:41 +1000 Subject: [PATCH] Fix ParquetDataCatalog bar queries by instrument_id --- RELEASES.md | 2 +- .../persistence/catalog/parquet.py | 17 ++++++++++++-- tests/unit_tests/persistence/test_catalog.py | 22 ++++++++++++++++++- 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/RELEASES.md b/RELEASES.md index 57b19d21e4d0..4f25cd44276d 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -9,7 +9,7 @@ None None ### Fixes -None +- Fixed `ParquetDataCatalog` bar queries by `instrument_id` which were no longer returning data (the intent is to use `bar_type`, however using `instrument_id` now returns all matching bars) --- diff --git a/nautilus_trader/persistence/catalog/parquet.py b/nautilus_trader/persistence/catalog/parquet.py index d9793df276eb..3068a592254c 100644 --- a/nautilus_trader/persistence/catalog/parquet.py +++ b/nautilus_trader/persistence/catalog/parquet.py @@ -421,10 +421,23 @@ def backend_session( # Parse the parent directory which *should* be the instrument ID, # this prevents us matching all instrument ID substrings. dir = path.split("/")[-2] - if instrument_ids and not any(dir == urisafe_instrument_id(x) for x in instrument_ids): - continue + + # Filter by instrument ID + if data_cls == Bar: + if instrument_ids and not any( + dir.startswith(urisafe_instrument_id(x) + "-") for x in instrument_ids + ): + continue + else: + if instrument_ids and not any( + dir == urisafe_instrument_id(x) for x in instrument_ids + ): + continue + + # Filter by bar type if bar_types and not any(dir == urisafe_instrument_id(x) for x in bar_types): continue + table = f"{file_prefix}_{idx}" query = self._build_query( table, diff --git a/tests/unit_tests/persistence/test_catalog.py b/tests/unit_tests/persistence/test_catalog.py index 4af347ab62fa..3bb577dd3d38 100644 --- a/tests/unit_tests/persistence/test_catalog.py +++ b/tests/unit_tests/persistence/test_catalog.py @@ -241,7 +241,7 @@ def test_catalog_custom_data(catalog: ParquetDataCatalog) -> None: assert isinstance(data[0], CustomData) -def test_catalog_bars(catalog: ParquetDataCatalog) -> None: +def test_catalog_bars_querying_by_bar_type(catalog: ParquetDataCatalog) -> None: # Arrange bar_type = TestDataStubs.bartype_adabtc_binance_1min_last() instrument = TestInstrumentProvider.adabtc_binance() @@ -261,6 +261,24 @@ def test_catalog_bars(catalog: ParquetDataCatalog) -> None: assert len(bars) == len(stub_bars) == 10 +def test_catalog_bars_querying_by_instrument_id(catalog: ParquetDataCatalog) -> None: + # Arrange + bar_type = TestDataStubs.bartype_adabtc_binance_1min_last() + instrument = TestInstrumentProvider.adabtc_binance() + stub_bars = TestDataStubs.binance_bars_from_csv( + "ADABTC-1m-2021-11-27.csv", + bar_type, + instrument, + ) + + # Act + catalog.write_data(stub_bars) + + # Assert + bars = catalog.bars(instrument_ids=[instrument.id.value]) + assert len(bars) == len(stub_bars) == 10 + + def test_catalog_write_pyo3_order_book_depth10(catalog: ParquetDataCatalog) -> None: # Arrange instrument = TestInstrumentProvider.ethusdt_binance() @@ -339,9 +357,11 @@ def test_catalog_multiple_bar_types(catalog: ParquetDataCatalog) -> None: # Assert bars1 = catalog.bars(bar_types=[str(bar_type1)]) bars2 = catalog.bars(bar_types=[str(bar_type2)]) + bars3 = catalog.bars(instrument_ids=[instrument1.id.value]) all_bars = catalog.bars() assert len(bars1) == 10 assert len(bars2) == 10 + assert len(bars3) == 10 assert len(all_bars) == 20