diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 055193cac66d..27477ee31fa3 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -75,7 +75,7 @@ jobs: if: runner.os == 'Linux' run: | make install-talib - pip install ta-lib + poetry run pip install ta-lib - name: Setup cached pre-commit id: cached-pre-commit diff --git a/nautilus_trader/examples/strategies/talib_strategy.py b/nautilus_trader/examples/strategies/talib_strategy.py index c13b92620b63..0674197d356b 100644 --- a/nautilus_trader/examples/strategies/talib_strategy.py +++ b/nautilus_trader/examples/strategies/talib_strategy.py @@ -119,6 +119,7 @@ def on_start(self) -> None: # Register the indicators for updating self.register_indicator_for_bars(self.bar_type, self.indicator_manager) + self.indicator_manager.change_logger(self.log.get_logger()) # Subscribe to live data self.subscribe_bars(self.bar_type) diff --git a/nautilus_trader/indicators/ta_lib/manager.py b/nautilus_trader/indicators/ta_lib/manager.py index 7b411b1ef724..8ee4bfed3aae 100644 --- a/nautilus_trader/indicators/ta_lib/manager.py +++ b/nautilus_trader/indicators/ta_lib/manager.py @@ -13,13 +13,23 @@ # limitations under the License. # ------------------------------------------------------------------------------------------------- +try: + import talib + from talib import abstract +except ImportError as e: + error_message = ( + "Failed to import TA-Lib. This module requires TA-Lib to be installed. " + "Please visit https://github.com/TA-Lib/ta-lib-python for installation instructions. " + "If TA-Lib is already installed, ensure it is correctly added to your Python environment." + ) + raise ImportError(error_message) from e + + import os from collections import deque import numpy as np import pandas as pd -import talib -from talib import abstract from nautilus_trader.common.clock import LiveClock from nautilus_trader.common.logging import Logger @@ -54,12 +64,6 @@ class TAFunctionWrapper: - output_names (list[str]): A list of formatted output names for the technical indicator, generated based on the `name` and `params`. - Methods: - ------- - - __init__(self, name: str, params: dict[str, Union[int, float]] = {}): Initializes the - TAFunctionWrapper instance with a given name and optional parameters for the TA-Lib - function. - Note: ---- - The class utilizes TA-Lib, a popular technical analysis library, to handle the underlying @@ -76,10 +80,10 @@ class TAFunctionWrapper: """ - def __init__(self, name: str, params: dict[str, int | float] = {}): + def __init__(self, name: str, params: dict[str, int | float] | None = None): self.name = name self.fn = abstract.Function(name) - self.fn.set_parameters(params) + self.fn.set_parameters(params or {}) self.output_names = self._get_outputs_names(self.name, self.fn) def __repr__(self): @@ -223,17 +227,21 @@ def __init__( bar_type: BarType, period: int = 1, buffer_size: int | None = None, + skip_uniform_price_bar: bool = True, + skip_zero_close_bar: bool = True, ): super().__init__([]) PyCondition().type(bar_type, BarType, "bar_type") PyCondition().positive_int(period, "period") - if buffer_size: - PyCondition().positive_int(period, "buffer_size") + if buffer_size is not None: + PyCondition().positive_int(buffer_size, "buffer_size") # Initialize variables self._bar_type = bar_type self._period = period + self._skip_uniform_price_bar = skip_uniform_price_bar + self._skip_zero_close_bar = skip_zero_close_bar self._output_array: np.recarray | None = None self._last_ts_event: int = 0 self._data_error_counter: int = 0 @@ -306,7 +314,7 @@ def set_indicators(self, indicators: tuple[TAFunctionWrapper, ...]) -> None: # Initialize the output dtypes self._output_dtypes = [ - (col, np.dtype("uint64") if col == "ts_event" else np.dtype("float64")) + (col, np.dtype("uint64") if col in ["ts_event", "ts_init"] else np.dtype("float64")) for col in self.output_names ] @@ -375,6 +383,7 @@ def _calculate_ta(self, append: bool = True) -> None: combined_output = np.zeros(1, dtype=self._output_dtypes) combined_output["ts_event"] = self._input_deque[-1]["ts_event"].item() + combined_output["ts_init"] = self._input_deque[-1]["ts_init"].item() combined_output["open"] = self._input_deque[-1]["open"].item() combined_output["high"] = self._input_deque[-1]["high"].item() combined_output["low"] = self._input_deque[-1]["low"].item() @@ -672,7 +681,11 @@ def handle_bar(self, bar: Bar) -> None: self._log.debug(f"Handling bar: {bar!r}") - if bar.is_single_price() and bar.open.as_double() == 0: + if self._skip_uniform_price_bar and bar.is_single_price(): + self._log.warning(f"Skipping uniform_price bar: {bar!r}") + return + if self._skip_zero_close_bar and bar.close.raw == 0: + self._log.warning(f"Skipping zero close bar: {bar!r}") return bar_data = np.array( diff --git a/tests/unit_tests/indicators/test_talib_func_wrapper.py b/tests/unit_tests/indicators/test_talib_func_wrapper.py new file mode 100644 index 000000000000..58721b128817 --- /dev/null +++ b/tests/unit_tests/indicators/test_talib_func_wrapper.py @@ -0,0 +1,65 @@ +import importlib.util +import sys + +import pytest + + +if importlib.util.find_spec("talib") is None: + if sys.platform == "linux": + # Raise the exception (expecting talib to be available on Linux) + error_message = ( + "Failed to import TA-Lib. This module requires TA-Lib to be installed. " + "Please visit https://github.com/TA-Lib/ta-lib-python for installation instructions. " + "If TA-Lib is already installed, ensure it is correctly added to your Python environment." + ) + raise ImportError(error_message) + pytestmark = pytest.mark.skip(reason="talib is not installed") +else: + import talib + from talib import abstract + + from nautilus_trader.indicators.ta_lib.manager import TAFunctionWrapper + + +def test_init_with_valid_name_and_no_params(): + wrapper = TAFunctionWrapper(name="SMA") + assert wrapper.name == "SMA" + assert isinstance(wrapper.fn, talib._ta_lib.Function) + assert isinstance(wrapper.output_names, list) + assert all(isinstance(o, str) for o in wrapper.output_names) + + +def test_init_with_valid_name_and_params(): + wrapper = TAFunctionWrapper(name="EMA", params={"timeperiod": 10}) + assert wrapper.name == "EMA" + assert wrapper.fn.parameters["timeperiod"] == 10 + + +def test_repr(): + wrapper = TAFunctionWrapper(name="SMA", params={"timeperiod": 5}) + assert repr(wrapper) == "TAFunctionWrapper(SMA_5)" + + +def test_get_outputs_names(): + fn = abstract.Function("SMA") + fn.set_parameters({"timeperiod": 5}) + output_names = TAFunctionWrapper._get_outputs_names("SMA", fn) + assert output_names == ["SMA_5"] + + +def test_from_str_valid(): + wrapper = TAFunctionWrapper.from_str("SMA_5") + assert wrapper.name == "SMA" + assert wrapper.fn.parameters["timeperiod"] == 5 + + +def test_from_str_invalid(): + with pytest.raises(Exception): + TAFunctionWrapper.from_str("INVALID_5") + + +def test_from_list_of_str(): + indicators = ["SMA_5", "EMA_10"] + wrappers = TAFunctionWrapper.from_list_of_str(indicators) + assert len(wrappers) == 2 + assert all(isinstance(w, TAFunctionWrapper) for w in wrappers) diff --git a/tests/unit_tests/indicators/test_talib_indicator_manager.py b/tests/unit_tests/indicators/test_talib_indicator_manager.py new file mode 100644 index 000000000000..a3937e0b61f4 --- /dev/null +++ b/tests/unit_tests/indicators/test_talib_indicator_manager.py @@ -0,0 +1,150 @@ +import importlib.util +import sys +from unittest.mock import Mock + +import pytest + +from nautilus_trader.model.data import BarType + + +if importlib.util.find_spec("talib") is None: + if sys.platform == "linux": + # Raise the exception (expecting talib to be available on Linux) + error_message = ( + "Failed to import TA-Lib. This module requires TA-Lib to be installed. " + "Please visit https://github.com/TA-Lib/ta-lib-python for installation instructions. " + "If TA-Lib is already installed, ensure it is correctly added to your Python environment." + ) + raise ImportError(error_message) + pytestmark = pytest.mark.skip(reason="talib is not installed") +else: + from nautilus_trader.indicators.ta_lib.manager import TAFunctionWrapper + from nautilus_trader.indicators.ta_lib.manager import TALibIndicatorManager + + +@pytest.fixture(scope="session") +def bar_type() -> BarType: + return BarType.from_str("EUR/USD.IDEALPRO-1-HOUR-MID-EXTERNAL") + + +@pytest.fixture() +def indicator_manager() -> "TALibIndicatorManager": + return TALibIndicatorManager( + bar_type=BarType.from_str("EUR/USD.IDEALPRO-1-HOUR-MID-EXTERNAL"), + period=10, + ) + + +def test_setup(): + # Arrange + bar_type = BarType.from_str("EUR/USD.IDEALPRO-1-HOUR-MID-EXTERNAL") + period = 10 + + # Act + indicator_manager = TALibIndicatorManager(bar_type=bar_type, period=period) + + # Assert + assert indicator_manager.bar_type == bar_type + assert indicator_manager.period == period + + +def test_invalid_bar_type(): + # Arrange, Act, Assert + with pytest.raises(TypeError): + TALibIndicatorManager(bar_type="invalid_bar_type", period=10) + + +def test_not_positive_period(bar_type): + # Arrange, Act, Assert + with pytest.raises(ValueError): + TALibIndicatorManager(bar_type=bar_type, period=0) + + +def test_not_positive_buffer_size(bar_type): + # Arrange, Act, Assert + with pytest.raises(ValueError): + TALibIndicatorManager(bar_type=bar_type, period=10, buffer_size=0) + + +def test_skip_uniform_price_bar_default_true(indicator_manager): + # Assert + indicator_manager._skip_uniform_price_bar = True + + +def test_skip_zero_close_bar_default_true(indicator_manager): + # Assert + indicator_manager._skip_zero_close_bar = True + + +def test_set_indicators(indicator_manager): + # Arrange + indicators = (TAFunctionWrapper("SMA", {"timeperiod": 50}), TAFunctionWrapper("MACD")) + + # Act + indicator_manager.set_indicators(indicators) + + # Assert + assert len(indicator_manager._indicators) == len(indicators) + + +def test_output_names_generation(indicator_manager): + # Arrange + indicators = (TAFunctionWrapper("SMA", {"timeperiod": 50}),) + # Act + indicator_manager.set_indicators(indicators) + expected_output_names = indicator_manager.input_names() + indicators[0].output_names + + # Assert + assert indicator_manager.output_names == tuple(expected_output_names) + + +def test_increment_count_correctly_increases_counter(indicator_manager): + # Arrange + indicator_manager._stable_period = 5 + + # Act + for i in range(2): + indicator_manager._increment_count() + + # Assert + assert indicator_manager.count == 2 + + +def test_indicator_remains_uninitialized_with_insufficient_input_count(indicator_manager): + # Arrange + indicator_manager._stable_period = 5 + + # Act + for i in range(4): + indicator_manager._increment_count() + + # Assert + assert indicator_manager.has_inputs is True + assert indicator_manager.initialized is False + + +def test_indicator_initializes_after_receiving_required_input_count(indicator_manager): + # Arrange + indicator_manager._stable_period = 5 + indicator_manager._calculate_ta = Mock() + + # Act + for i in range(5): + indicator_manager._increment_count() + + # Assert + assert indicator_manager.has_inputs is True + assert indicator_manager.initialized is True + + +def test_calculate_ta_called_on_initialization(indicator_manager): + # Arrange + indicator_manager._stable_period = 5 + indicator_manager._calculate_ta = Mock() + + # Act + for i in range(5): + indicator_manager._increment_count() + + # Assert + indicator_manager._calculate_ta.assert_called_once()