From 7732e8e3d6180efd9aa247cb9d10ec9351494189 Mon Sep 17 00:00:00 2001 From: rsmb7z <105105941+rsmb7z@users.noreply.github.com> Date: Fri, 22 Dec 2023 00:37:35 +0300 Subject: [PATCH] Add more tests for TALibIndicatorManager (#1422) --- nautilus_trader/indicators/ta_lib/manager.py | 95 +++-- .../indicators/test_talib_func_wrapper.py | 5 + .../test_talib_indicator_manager.py | 366 +++++++++++++++++- 3 files changed, 412 insertions(+), 54 deletions(-) diff --git a/nautilus_trader/indicators/ta_lib/manager.py b/nautilus_trader/indicators/ta_lib/manager.py index 8ee4bfed3aae..fe50b7281a1e 100644 --- a/nautilus_trader/indicators/ta_lib/manager.py +++ b/nautilus_trader/indicators/ta_lib/manager.py @@ -260,6 +260,9 @@ def __init__( logger = Logger(clock) self._log = LoggerAdapter(component_name=repr(self), logger=logger) + # Initialize with empty indicators (acts as OHLCV placeholder in case no indicators are set) + self.set_indicators(()) + def change_logger(self, logger: Logger): PyCondition().type(logger, Logger, "logger") self._log = LoggerAdapter(component_name=repr(self), logger=logger) @@ -297,6 +300,10 @@ def set_indicators(self, indicators: tuple[TAFunctionWrapper, ...]) -> None: info levels, respectively. """ + if self.initialized: + self._log.info("Indicator already initialized. Skipping set_indicators.") + return + self._log.debug(f"Setting indicators {indicators}") PyCondition().list_type(list(indicators), TAFunctionWrapper, "ta_functions") @@ -308,9 +315,9 @@ def set_indicators(self, indicators: tuple[TAFunctionWrapper, ...]) -> None: output_names.extend(indicator.output_names) lookback = max(lookback, indicator.fn.lookback) - self.output_names = tuple(output_names) self._stable_period = lookback + self._period - self._input_deque = deque(maxlen=self._stable_period) + self._input_deque = deque(maxlen=lookback + 1) + self.output_names = tuple(output_names) # Initialize the output dtypes self._output_dtypes = [ @@ -343,39 +350,38 @@ def bar_type(self) -> BarType: def period(self) -> int: return self._period - def _calculate_ta(self, append: bool = True) -> None: + def _update_ta_outputs(self, append: bool = True) -> None: """ - Calculate technical analysis indicators and update the output deque. + Update the output deque with calculated technical analysis indicators. - This private method computes the output values for technical analysis indicators - set in the instance. It first initializes a combined output array with the base - values (like 'open', 'high', 'low', 'close', 'volume', and 'ts_event') from the - most recent entry in the input deque. It then iterates through each indicator, - computes its output, and updates the combined output array accordingly. - Depending on the 'append' flag, the method either appends or updates the latest - entry in the output deque with the combined output. + This private method computes and updates the output values for technical + analysis indicators based on the latest data in the input deque. It initializes + a combined output array with base values (e.g., 'open', 'high', 'low', 'close', + 'volume', 'ts_event') from the most recent input deque entry. Each indicator's + output is calculated and used to update the combined output array. The updated + data is either appended to or replaces the latest entry in the output deque, + depending on the value of the 'append' argument. Args: ---- - append : bool, optional - A flag to determine whether to append the new output to the output deque (True) or - to replace the most recent output (False). Defaults to True. + append (bool): Determines whether to append the new output to the output + deque (True) or replace the most recent output (False). + Defaults to True. - Returns + Returns: ------- - None + None The method performs the following steps: - - Initializes a combined output array with the default data types. - - Extracts the base values from the most recent entry in the input deque. + - Initializes a combined output array with base values from the latest input + deque entry. - Iterates through each indicator, calculates its output, and updates the combined output array. - - Depending on the 'append' flag, either appends the combined output to the - output deque or replaces its most recent entry. - - Resets the internal output array to ensure that it is rebuilt during the - next access. + - Appends the combined output to the output deque or replaces its most recent + entry based on the 'append' flag. + - Resets the internal output array for reconstruction during the next access. - This method logs actions at the debug level for tracking the calculation and + This method logs actions at the debug level to track the calculation and updating process. """ @@ -390,31 +396,30 @@ def _calculate_ta(self, append: bool = True) -> None: combined_output["close"] = self._input_deque[-1]["close"].item() combined_output["volume"] = self._input_deque[-1]["volume"].item() - input_array = None + input_array = np.concatenate(self._input_deque) for indicator in self._indicators: - if input_array is None: - input_array = np.concatenate(self._input_deque) - period = indicator.fn.lookback + 1 - inputs_dict = {name: input_array[name][-period:] for name in input_array.dtype.names} + self._log.debug(f"Calculating {indicator.name} outputs.") + inputs_dict = {name: input_array[name] for name in input_array.dtype.names} indicator.fn.set_input_arrays(inputs_dict) results = indicator.fn.run() if len(indicator.output_names) == 1: + self._log.debug("Single output.") combined_output[indicator.output_names[0]] = results[-1] else: + self._log.debug("Multiple outputs.") for i, output_name in enumerate(indicator.output_names): combined_output[output_name] = results[i][-1] - if self.initialized: - if append: - self._log.debug("Appending output.") - self._output_deque.append(combined_output) - else: - self._log.debug("Prepending output.") - self._output_deque[-1] = combined_output + if append: + self._log.debug("Appending output.") + self._output_deque.append(combined_output) + else: + self._log.debug("Prepending output.") + self._output_deque[-1] = combined_output - # Reset output array to force rebuild on next access - self._output_array = None + # Reset output array to force rebuild on next access + self._output_array = None def _increment_count(self) -> None: self.count += 1 @@ -424,7 +429,6 @@ def _increment_count(self) -> None: if self.count >= self._stable_period: self._set_initialized(True) self._log.info(f"Initialized with {self.count} bars") - self._calculate_ta() # Immediately make the first calculation def value(self, name: str, index=0): """ @@ -478,7 +482,7 @@ def value(self, name: str, index=0): return self.output_array[name][translated_index] @property - def output_array(self) -> np.recarray: + def output_array(self) -> np.recarray | None: """ Retrieve or generate the output array for the indicator. @@ -512,7 +516,7 @@ def output_array(self) -> np.recarray: self._log.debug("Using cached output array.") return self._output_array - def generate_output_array(self, truncate: bool) -> np.recarray: + def generate_output_array(self, truncate: bool) -> np.recarray | None: """ Generate the output array for the indicator, either truncated or complete. @@ -551,6 +555,10 @@ def generate_output_array(self, truncate: bool) -> np.recarray: ``` """ + if not self.initialized: + self._log.info("Indicator not initialized. Returning None.") + return None + if truncate: self._log.debug("Generating truncated output array.") output_array = np.concatenate(list(self._output_deque)[-self.period :]) @@ -705,11 +713,14 @@ def handle_bar(self, bar: Bar) -> None: if bar.ts_event == self._last_ts_event: self._input_deque[-1] = bar_data - self._calculate_ta(append=False) + self._update_ta_outputs(append=False) elif bar.ts_event > self._last_ts_event: self._input_deque.append(bar_data) self._increment_count() - self._calculate_ta() + self._update_ta_outputs() else: self._data_error_counter += 1 self._log.error(f"Received out of sync bar: {bar!r}") + return + + self._last_ts_event = bar.ts_event diff --git a/tests/unit_tests/indicators/test_talib_func_wrapper.py b/tests/unit_tests/indicators/test_talib_func_wrapper.py index dbaac075c324..810132c5f209 100644 --- a/tests/unit_tests/indicators/test_talib_func_wrapper.py +++ b/tests/unit_tests/indicators/test_talib_func_wrapper.py @@ -73,6 +73,11 @@ def test_from_str_invalid(): TAFunctionWrapper.from_str("INVALID_5") +def test_from_str_invalid_params(): + with pytest.raises(ValueError): + TAFunctionWrapper.from_str("SMA_20_10") + + def test_from_list_of_str(): indicators = ["SMA_5", "EMA_10"] wrappers = TAFunctionWrapper.from_list_of_str(indicators) diff --git a/tests/unit_tests/indicators/test_talib_indicator_manager.py b/tests/unit_tests/indicators/test_talib_indicator_manager.py index 931cebf6b780..e74f80727209 100644 --- a/tests/unit_tests/indicators/test_talib_indicator_manager.py +++ b/tests/unit_tests/indicators/test_talib_indicator_manager.py @@ -14,12 +14,20 @@ # ------------------------------------------------------------------------------------------------- import importlib.util +import inspect import sys from unittest.mock import Mock +import numpy as np import pytest +from nautilus_trader.model.data import Bar from nautilus_trader.model.data import BarType +from nautilus_trader.model.objects import Price +from nautilus_trader.model.objects import Quantity +from nautilus_trader.persistence.wranglers import BarDataWrangler +from nautilus_trader.test_kit.providers import TestDataProvider +from nautilus_trader.test_kit.providers import TestInstrumentProvider if importlib.util.find_spec("talib") is None: @@ -39,15 +47,68 @@ @pytest.fixture(scope="session") def bar_type() -> BarType: - return BarType.from_str("EUR/USD.IDEALPRO-1-HOUR-MID-EXTERNAL") + return BarType.from_str("GBP/USD.SIM-1-MINUTE-BID-EXTERNAL") @pytest.fixture() -def indicator_manager() -> "TALibIndicatorManager": - return TALibIndicatorManager( - bar_type=BarType.from_str("EUR/USD.IDEALPRO-1-HOUR-MID-EXTERNAL"), - period=10, +def indicator_manager(bar_type) -> "TALibIndicatorManager": + return TALibIndicatorManager(bar_type=bar_type, period=10) + + +@pytest.fixture() +def sample_bar_1(bar_type): + return Bar( + bar_type=bar_type, + open=Price.from_str("1.57593"), + high=Price.from_str("1.57614"), + low=Price.from_str("1.57593"), + close=Price.from_str("1.57610"), + volume=Quantity.from_int(1_000_000), + ts_event=1, + ts_init=1, + ) + + +@pytest.fixture() +def sample_bar_1_update(bar_type): + return Bar( + bar_type=bar_type, + open=Price.from_str("1.57593"), + high=Price.from_str("1.57619"), + low=Price.from_str("1.57593"), + close=Price.from_str("1.57619"), + volume=Quantity.from_int(2_000_000), + ts_event=1, + ts_init=1, + ) + + +@pytest.fixture() +def sample_bar_2(bar_type): + return Bar( + bar_type=bar_type, + open=Price.from_str("1.57610"), + high=Price.from_str("1.57621"), + low=Price.from_str("1.57606"), + close=Price.from_str("1.57608"), + volume=Quantity.from_int(1_000_000), + ts_event=2, + ts_init=2, + ) + + +@pytest.fixture() +def sample_data(bar_type) -> list[Bar]: + provider = TestDataProvider() + instrument = TestInstrumentProvider.default_fx_ccy( + symbol=bar_type.instrument_id.symbol.value, + venue=bar_type.instrument_id.venue, + ) + wrangler = BarDataWrangler(bar_type=bar_type, instrument=instrument) + bars = wrangler.process( + data=provider.read_csv_bars("fxcm/gbpusd-m1-bid-2012.csv")[:50], ) + return bars def test_setup(): @@ -141,7 +202,7 @@ def test_indicator_remains_uninitialized_with_insufficient_input_count(indicator def test_indicator_initializes_after_receiving_required_input_count(indicator_manager): # Arrange indicator_manager._stable_period = 5 - indicator_manager._calculate_ta = Mock() + indicator_manager._update_ta_outputs = Mock() # Act for i in range(5): @@ -152,14 +213,295 @@ def test_indicator_initializes_after_receiving_required_input_count(indicator_ma assert indicator_manager.initialized is True -def test_calculate_ta_called_on_initialization(indicator_manager): +def test_update_ta_outputs_default_append_is_true(indicator_manager): + # Arrange, Act + sig = inspect.signature(indicator_manager._update_ta_outputs) + default_append = sig.parameters["append"].default + + # Assert + assert default_append is True, "Default value for 'append' should be True" + + +def test_handle_bar_new(indicator_manager, sample_bar_1): # Arrange - indicator_manager._stable_period = 5 - indicator_manager._calculate_ta = Mock() + indicator_manager._update_ta_outputs = Mock() + indicator_manager.set_indicators(TAFunctionWrapper.from_list_of_str(["SMA_10"])) + expected = np.array( + [ + ( + sample_bar_1.ts_event, + sample_bar_1.ts_init, + sample_bar_1.open.as_double(), + sample_bar_1.high.as_double(), + sample_bar_1.low.as_double(), + sample_bar_1.close.as_double(), + sample_bar_1.volume.as_double(), + ), + ], + dtype=indicator_manager.input_dtypes(), + ) # Act - for i in range(5): - indicator_manager._increment_count() + indicator_manager.handle_bar(sample_bar_1) + + # Assert + assert indicator_manager._input_deque[-1] == [expected] + assert len(indicator_manager._input_deque) == 1 + assert indicator_manager.count == 1 + indicator_manager._update_ta_outputs.assert_called_once() + + +def test_handle_bar_update(indicator_manager, sample_bar_1, sample_bar_1_update): + # Arrange + indicator_manager._update_ta_outputs = Mock(side_effect=lambda append=True: None) + indicator_manager.set_indicators(TAFunctionWrapper.from_list_of_str(["SMA_10"])) + expected = np.array( + [ + ( + sample_bar_1_update.ts_event, + sample_bar_1_update.ts_init, + sample_bar_1_update.open.as_double(), + sample_bar_1_update.high.as_double(), + sample_bar_1_update.low.as_double(), + sample_bar_1_update.close.as_double(), + sample_bar_1_update.volume.as_double(), + ), + ], + dtype=indicator_manager.input_dtypes(), + ) + + # Act + indicator_manager.handle_bar(sample_bar_1) + indicator_manager.handle_bar(sample_bar_1_update) + + # Assert + assert indicator_manager._input_deque[-1] == [expected] + assert len(indicator_manager._input_deque) == 1 + assert indicator_manager.count == 1 + second_call_args, second_call_kwargs = indicator_manager._update_ta_outputs.call_args_list[1] + assert ( + second_call_kwargs.get("append", None) is False + ), "Second call was not made with append=False" + + +def test_handle_bar_out_of_sync(indicator_manager, sample_bar_1, sample_bar_2): + # Arrange + indicator_manager._update_ta_outputs = Mock(side_effect=lambda append=True: None) + indicator_manager.set_indicators(TAFunctionWrapper.from_list_of_str(["SMA_10"])) + expected = np.array( + [ + ( + sample_bar_2.ts_event, + sample_bar_2.ts_init, + sample_bar_2.open.as_double(), + sample_bar_2.high.as_double(), + sample_bar_2.low.as_double(), + sample_bar_2.close.as_double(), + sample_bar_2.volume.as_double(), + ), + ], + dtype=indicator_manager.input_dtypes(), + ) + + # Act + indicator_manager.handle_bar(sample_bar_2) + indicator_manager.handle_bar(sample_bar_1) # <- old bar received later + + # Assert + assert indicator_manager._input_deque[-1] == [expected] + assert len(indicator_manager._input_deque) == 1 + assert indicator_manager.count == 1 + assert indicator_manager._data_error_counter == 1 + assert indicator_manager._last_ts_event == 2 + + +def test_input_names(): + # Arrange + expected = ["ts_event", "ts_init", "open", "high", "low", "close", "volume"] + + # Act, Assert + assert TALibIndicatorManager.input_names() == expected + + +def test_input_dtypes(): + # Arrange + expected = [ + ("ts_event", np.dtype("uint64")), + ("ts_init", np.dtype("uint64")), + ("open", np.dtype("float64")), + ("high", np.dtype("float64")), + ("low", np.dtype("float64")), + ("close", np.dtype("float64")), + ("volume", np.dtype("float64")), + ] + + # Act, Assert + assert TALibIndicatorManager.input_dtypes() == expected + + +def test_output_dtypes(indicator_manager): + # Arrange + expected = [ + ("ts_event", np.dtype("uint64")), + ("ts_init", np.dtype("uint64")), + ("open", np.dtype("float64")), + ("high", np.dtype("float64")), + ("low", np.dtype("float64")), + ("close", np.dtype("float64")), + ("volume", np.dtype("float64")), + ("SMA_10", np.dtype("float64")), + ] + + # Act + indicator_manager.set_indicators(TAFunctionWrapper.from_list_of_str(["SMA_10"])) + + # Assert + assert indicator_manager._output_dtypes == expected + + +def test_input_deque_maxlen_is_one_more_than_lookback(indicator_manager): + # Arrange, Act + indicator_manager.set_indicators(TAFunctionWrapper.from_list_of_str(["SMA_10"])) + output_names = indicator_manager.input_names() + lookback = 0 + for indicator in indicator_manager._indicators: + output_names.extend(indicator.output_names) + lookback = max(lookback, indicator.fn.lookback) + expected_maxlen = lookback + 1 + + # Assert + assert indicator_manager._input_deque.maxlen == expected_maxlen + + +def test_stable_period_single_indicator(indicator_manager, sample_data): + # Arrange, Act + indicator_manager.set_indicators(TAFunctionWrapper.from_list_of_str(["SMA_10"])) + expected = 19 # indicator_manager.period + lookback of sma_10 + + # Assert + assert indicator_manager._stable_period == expected + + +def test_stable_period_multiple_indicators(indicator_manager, sample_data): + # Arrange, Act + indicator_manager.set_indicators( + TAFunctionWrapper.from_list_of_str(["SMA_10", "EMA_20", "ATR_14"]), + ) + expected = 29 # indicator_manager.period + max lookback of all indicator + + # Assert + assert indicator_manager._stable_period == expected + + +def test_output_array_when_not_initialized(indicator_manager, sample_data): + # Arrange + indicator_manager.set_indicators(TAFunctionWrapper.from_list_of_str(["SMA_10"])) + + # Act + for bar in sample_data[:15]: + indicator_manager.handle_bar(bar) + + # Assert + assert indicator_manager.output_array is None + + +def test_output_array_multiple_output_indicator(indicator_manager, sample_data): + # Arrange + indicator_manager.set_indicators(TAFunctionWrapper.from_list_of_str(["MACD_12_26_9"])) + expected_1 = np.array( + [ + 1.2788974555189014e-04, + 1.0112435988784974e-04, + 8.5582518167148791e-05, + 6.8531219604039961e-05, + 8.8447096904031852e-05, + 1.0224003011005678e-04, + 1.0908857559432938e-04, + 1.1207251438904997e-04, + 1.1158850602077663e-04, + 1.0202237800371883e-04, + ], + ) + expected_2 = np.array( + [ + -6.1579287446791549e-05, + -7.6999476312591649e-05, + -8.4977673358046753e-05, + -8.8624485371132294e-05, + -5.3610638737606260e-05, + -2.7077967138818465e-05, + -1.3647902407444410e-05, + -3.0235643303387811e-06, + 2.3138898333188349e-06, + -2.1016908817234661e-06, + ], + ) + + # Act + for bar in sample_data[:45]: + indicator_manager.handle_bar(bar) + + # Assert + assert np.array_equal(indicator_manager.output_array["MACD_12_26_9"], expected_1) + assert np.array_equal(indicator_manager.output_array["MACD_12_26_9_HIST"], expected_2) + + +def test_output_array_single_output_indicator(indicator_manager, sample_data): + # Arrange + indicator_manager.set_indicators(TAFunctionWrapper.from_list_of_str(["SMA_10"])) + expected = np.array( + [ + 1.575571, + 1.5755810000000001, + 1.575589, + 1.575594, + 1.575593, + 1.575602, + 1.575633, + 1.575673, + 1.5757189999999999, + 1.575764, + ], + ) + + # Act + for bar in sample_data[:20]: + indicator_manager.handle_bar(bar) + + # Assert + assert np.array_equal(indicator_manager.output_array["SMA_10"], expected) + + +def test_value(indicator_manager, sample_data): + # Arrange + indicator_manager.set_indicators(TAFunctionWrapper.from_list_of_str(["SMA_10"])) + + # Act + for bar in sample_data[:20]: + indicator_manager.handle_bar(bar) + + # Assert + assert indicator_manager.value("SMA_10") == 1.575764 + assert indicator_manager.value("SMA_10", 1) == 1.5757189999999999 + + +def test_value_with_invalid_index(indicator_manager, sample_data): + # Arrange + indicator_manager.set_indicators(TAFunctionWrapper.from_list_of_str(["SMA_10"])) + + # Act + for bar in sample_data[:20]: + indicator_manager.handle_bar(bar) + + # Assert + with pytest.raises(ValueError): + indicator_manager.value("SMA_10", 30) + + +def test_ohlcv_when_no_indicators_are_set(indicator_manager, sample_data): + # Act + for bar in sample_data[:20]: + indicator_manager.handle_bar(bar) # Assert - indicator_manager._calculate_ta.assert_called_once() + assert indicator_manager.output_array.shape == (10,)