Skip to content

Commit

Permalink
Add TA-Lib tests (#1417)
Browse files Browse the repository at this point in the history
  • Loading branch information
rsmb7z authored Dec 19, 2023
1 parent f50ae7b commit a3d8e74
Show file tree
Hide file tree
Showing 5 changed files with 244 additions and 15 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ jobs:
if: runner.os == 'Linux'
run: |
make install-talib
pip install ta-lib
poetry run pip install ta-lib
- name: Setup cached pre-commit
id: cached-pre-commit
Expand Down
1 change: 1 addition & 0 deletions nautilus_trader/examples/strategies/talib_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def on_start(self) -> None:

# Register the indicators for updating
self.register_indicator_for_bars(self.bar_type, self.indicator_manager)
self.indicator_manager.change_logger(self.log.get_logger())

# Subscribe to live data
self.subscribe_bars(self.bar_type)
Expand Down
41 changes: 27 additions & 14 deletions nautilus_trader/indicators/ta_lib/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,23 @@
# limitations under the License.
# -------------------------------------------------------------------------------------------------

try:
import talib
from talib import abstract
except ImportError as e:
error_message = (
"Failed to import TA-Lib. This module requires TA-Lib to be installed. "
"Please visit https://github.com/TA-Lib/ta-lib-python for installation instructions. "
"If TA-Lib is already installed, ensure it is correctly added to your Python environment."
)
raise ImportError(error_message) from e


import os
from collections import deque

import numpy as np
import pandas as pd
import talib
from talib import abstract

from nautilus_trader.common.clock import LiveClock
from nautilus_trader.common.logging import Logger
Expand Down Expand Up @@ -54,12 +64,6 @@ class TAFunctionWrapper:
- output_names (list[str]): A list of formatted output names for the technical indicator,
generated based on the `name` and `params`.
Methods:
-------
- __init__(self, name: str, params: dict[str, Union[int, float]] = {}): Initializes the
TAFunctionWrapper instance with a given name and optional parameters for the TA-Lib
function.
Note:
----
- The class utilizes TA-Lib, a popular technical analysis library, to handle the underlying
Expand All @@ -76,10 +80,10 @@ class TAFunctionWrapper:
"""

def __init__(self, name: str, params: dict[str, int | float] = {}):
def __init__(self, name: str, params: dict[str, int | float] | None = None):
self.name = name
self.fn = abstract.Function(name)
self.fn.set_parameters(params)
self.fn.set_parameters(params or {})
self.output_names = self._get_outputs_names(self.name, self.fn)

def __repr__(self):
Expand Down Expand Up @@ -223,17 +227,21 @@ def __init__(
bar_type: BarType,
period: int = 1,
buffer_size: int | None = None,
skip_uniform_price_bar: bool = True,
skip_zero_close_bar: bool = True,
):
super().__init__([])

PyCondition().type(bar_type, BarType, "bar_type")
PyCondition().positive_int(period, "period")
if buffer_size:
PyCondition().positive_int(period, "buffer_size")
if buffer_size is not None:
PyCondition().positive_int(buffer_size, "buffer_size")

# Initialize variables
self._bar_type = bar_type
self._period = period
self._skip_uniform_price_bar = skip_uniform_price_bar
self._skip_zero_close_bar = skip_zero_close_bar
self._output_array: np.recarray | None = None
self._last_ts_event: int = 0
self._data_error_counter: int = 0
Expand Down Expand Up @@ -306,7 +314,7 @@ def set_indicators(self, indicators: tuple[TAFunctionWrapper, ...]) -> None:

# Initialize the output dtypes
self._output_dtypes = [
(col, np.dtype("uint64") if col == "ts_event" else np.dtype("float64"))
(col, np.dtype("uint64") if col in ["ts_event", "ts_init"] else np.dtype("float64"))
for col in self.output_names
]

Expand Down Expand Up @@ -375,6 +383,7 @@ def _calculate_ta(self, append: bool = True) -> None:

combined_output = np.zeros(1, dtype=self._output_dtypes)
combined_output["ts_event"] = self._input_deque[-1]["ts_event"].item()
combined_output["ts_init"] = self._input_deque[-1]["ts_init"].item()
combined_output["open"] = self._input_deque[-1]["open"].item()
combined_output["high"] = self._input_deque[-1]["high"].item()
combined_output["low"] = self._input_deque[-1]["low"].item()
Expand Down Expand Up @@ -672,7 +681,11 @@ def handle_bar(self, bar: Bar) -> None:

self._log.debug(f"Handling bar: {bar!r}")

if bar.is_single_price() and bar.open.as_double() == 0:
if self._skip_uniform_price_bar and bar.is_single_price():
self._log.warning(f"Skipping uniform_price bar: {bar!r}")
return
if self._skip_zero_close_bar and bar.close.raw == 0:
self._log.warning(f"Skipping zero close bar: {bar!r}")
return

bar_data = np.array(
Expand Down
65 changes: 65 additions & 0 deletions tests/unit_tests/indicators/test_talib_func_wrapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
import importlib.util
import sys

import pytest


if importlib.util.find_spec("talib") is None:
if sys.platform == "linux":
# Raise the exception (expecting talib to be available on Linux)
error_message = (
"Failed to import TA-Lib. This module requires TA-Lib to be installed. "
"Please visit https://github.com/TA-Lib/ta-lib-python for installation instructions. "
"If TA-Lib is already installed, ensure it is correctly added to your Python environment."
)
raise ImportError(error_message)
pytestmark = pytest.mark.skip(reason="talib is not installed")
else:
import talib
from talib import abstract

from nautilus_trader.indicators.ta_lib.manager import TAFunctionWrapper


def test_init_with_valid_name_and_no_params():
wrapper = TAFunctionWrapper(name="SMA")
assert wrapper.name == "SMA"
assert isinstance(wrapper.fn, talib._ta_lib.Function)
assert isinstance(wrapper.output_names, list)
assert all(isinstance(o, str) for o in wrapper.output_names)


def test_init_with_valid_name_and_params():
wrapper = TAFunctionWrapper(name="EMA", params={"timeperiod": 10})
assert wrapper.name == "EMA"
assert wrapper.fn.parameters["timeperiod"] == 10


def test_repr():
wrapper = TAFunctionWrapper(name="SMA", params={"timeperiod": 5})
assert repr(wrapper) == "TAFunctionWrapper(SMA_5)"


def test_get_outputs_names():
fn = abstract.Function("SMA")
fn.set_parameters({"timeperiod": 5})
output_names = TAFunctionWrapper._get_outputs_names("SMA", fn)
assert output_names == ["SMA_5"]


def test_from_str_valid():
wrapper = TAFunctionWrapper.from_str("SMA_5")
assert wrapper.name == "SMA"
assert wrapper.fn.parameters["timeperiod"] == 5


def test_from_str_invalid():
with pytest.raises(Exception):
TAFunctionWrapper.from_str("INVALID_5")


def test_from_list_of_str():
indicators = ["SMA_5", "EMA_10"]
wrappers = TAFunctionWrapper.from_list_of_str(indicators)
assert len(wrappers) == 2
assert all(isinstance(w, TAFunctionWrapper) for w in wrappers)
150 changes: 150 additions & 0 deletions tests/unit_tests/indicators/test_talib_indicator_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
import importlib.util
import sys
from unittest.mock import Mock

import pytest

from nautilus_trader.model.data import BarType


if importlib.util.find_spec("talib") is None:
if sys.platform == "linux":
# Raise the exception (expecting talib to be available on Linux)
error_message = (
"Failed to import TA-Lib. This module requires TA-Lib to be installed. "
"Please visit https://github.com/TA-Lib/ta-lib-python for installation instructions. "
"If TA-Lib is already installed, ensure it is correctly added to your Python environment."
)
raise ImportError(error_message)
pytestmark = pytest.mark.skip(reason="talib is not installed")
else:
from nautilus_trader.indicators.ta_lib.manager import TAFunctionWrapper
from nautilus_trader.indicators.ta_lib.manager import TALibIndicatorManager


@pytest.fixture(scope="session")
def bar_type() -> BarType:
return BarType.from_str("EUR/USD.IDEALPRO-1-HOUR-MID-EXTERNAL")


@pytest.fixture()
def indicator_manager() -> "TALibIndicatorManager":
return TALibIndicatorManager(
bar_type=BarType.from_str("EUR/USD.IDEALPRO-1-HOUR-MID-EXTERNAL"),
period=10,
)


def test_setup():
# Arrange
bar_type = BarType.from_str("EUR/USD.IDEALPRO-1-HOUR-MID-EXTERNAL")
period = 10

# Act
indicator_manager = TALibIndicatorManager(bar_type=bar_type, period=period)

# Assert
assert indicator_manager.bar_type == bar_type
assert indicator_manager.period == period


def test_invalid_bar_type():
# Arrange, Act, Assert
with pytest.raises(TypeError):
TALibIndicatorManager(bar_type="invalid_bar_type", period=10)


def test_not_positive_period(bar_type):
# Arrange, Act, Assert
with pytest.raises(ValueError):
TALibIndicatorManager(bar_type=bar_type, period=0)


def test_not_positive_buffer_size(bar_type):
# Arrange, Act, Assert
with pytest.raises(ValueError):
TALibIndicatorManager(bar_type=bar_type, period=10, buffer_size=0)


def test_skip_uniform_price_bar_default_true(indicator_manager):
# Assert
indicator_manager._skip_uniform_price_bar = True


def test_skip_zero_close_bar_default_true(indicator_manager):
# Assert
indicator_manager._skip_zero_close_bar = True


def test_set_indicators(indicator_manager):
# Arrange
indicators = (TAFunctionWrapper("SMA", {"timeperiod": 50}), TAFunctionWrapper("MACD"))

# Act
indicator_manager.set_indicators(indicators)

# Assert
assert len(indicator_manager._indicators) == len(indicators)


def test_output_names_generation(indicator_manager):
# Arrange
indicators = (TAFunctionWrapper("SMA", {"timeperiod": 50}),)
# Act
indicator_manager.set_indicators(indicators)
expected_output_names = indicator_manager.input_names() + indicators[0].output_names

# Assert
assert indicator_manager.output_names == tuple(expected_output_names)


def test_increment_count_correctly_increases_counter(indicator_manager):
# Arrange
indicator_manager._stable_period = 5

# Act
for i in range(2):
indicator_manager._increment_count()

# Assert
assert indicator_manager.count == 2


def test_indicator_remains_uninitialized_with_insufficient_input_count(indicator_manager):
# Arrange
indicator_manager._stable_period = 5

# Act
for i in range(4):
indicator_manager._increment_count()

# Assert
assert indicator_manager.has_inputs is True
assert indicator_manager.initialized is False


def test_indicator_initializes_after_receiving_required_input_count(indicator_manager):
# Arrange
indicator_manager._stable_period = 5
indicator_manager._calculate_ta = Mock()

# Act
for i in range(5):
indicator_manager._increment_count()

# Assert
assert indicator_manager.has_inputs is True
assert indicator_manager.initialized is True


def test_calculate_ta_called_on_initialization(indicator_manager):
# Arrange
indicator_manager._stable_period = 5
indicator_manager._calculate_ta = Mock()

# Act
for i in range(5):
indicator_manager._increment_count()

# Assert
indicator_manager._calculate_ta.assert_called_once()

0 comments on commit a3d8e74

Please sign in to comment.