-
Notifications
You must be signed in to change notification settings - Fork 658
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
5 changed files
with
244 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import importlib.util | ||
import sys | ||
|
||
import pytest | ||
|
||
|
||
if importlib.util.find_spec("talib") is None: | ||
if sys.platform == "linux": | ||
# Raise the exception (expecting talib to be available on Linux) | ||
error_message = ( | ||
"Failed to import TA-Lib. This module requires TA-Lib to be installed. " | ||
"Please visit https://github.com/TA-Lib/ta-lib-python for installation instructions. " | ||
"If TA-Lib is already installed, ensure it is correctly added to your Python environment." | ||
) | ||
raise ImportError(error_message) | ||
pytestmark = pytest.mark.skip(reason="talib is not installed") | ||
else: | ||
import talib | ||
from talib import abstract | ||
|
||
from nautilus_trader.indicators.ta_lib.manager import TAFunctionWrapper | ||
|
||
|
||
def test_init_with_valid_name_and_no_params(): | ||
wrapper = TAFunctionWrapper(name="SMA") | ||
assert wrapper.name == "SMA" | ||
assert isinstance(wrapper.fn, talib._ta_lib.Function) | ||
assert isinstance(wrapper.output_names, list) | ||
assert all(isinstance(o, str) for o in wrapper.output_names) | ||
|
||
|
||
def test_init_with_valid_name_and_params(): | ||
wrapper = TAFunctionWrapper(name="EMA", params={"timeperiod": 10}) | ||
assert wrapper.name == "EMA" | ||
assert wrapper.fn.parameters["timeperiod"] == 10 | ||
|
||
|
||
def test_repr(): | ||
wrapper = TAFunctionWrapper(name="SMA", params={"timeperiod": 5}) | ||
assert repr(wrapper) == "TAFunctionWrapper(SMA_5)" | ||
|
||
|
||
def test_get_outputs_names(): | ||
fn = abstract.Function("SMA") | ||
fn.set_parameters({"timeperiod": 5}) | ||
output_names = TAFunctionWrapper._get_outputs_names("SMA", fn) | ||
assert output_names == ["SMA_5"] | ||
|
||
|
||
def test_from_str_valid(): | ||
wrapper = TAFunctionWrapper.from_str("SMA_5") | ||
assert wrapper.name == "SMA" | ||
assert wrapper.fn.parameters["timeperiod"] == 5 | ||
|
||
|
||
def test_from_str_invalid(): | ||
with pytest.raises(Exception): | ||
TAFunctionWrapper.from_str("INVALID_5") | ||
|
||
|
||
def test_from_list_of_str(): | ||
indicators = ["SMA_5", "EMA_10"] | ||
wrappers = TAFunctionWrapper.from_list_of_str(indicators) | ||
assert len(wrappers) == 2 | ||
assert all(isinstance(w, TAFunctionWrapper) for w in wrappers) |
150 changes: 150 additions & 0 deletions
150
tests/unit_tests/indicators/test_talib_indicator_manager.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,150 @@ | ||
import importlib.util | ||
import sys | ||
from unittest.mock import Mock | ||
|
||
import pytest | ||
|
||
from nautilus_trader.model.data import BarType | ||
|
||
|
||
if importlib.util.find_spec("talib") is None: | ||
if sys.platform == "linux": | ||
# Raise the exception (expecting talib to be available on Linux) | ||
error_message = ( | ||
"Failed to import TA-Lib. This module requires TA-Lib to be installed. " | ||
"Please visit https://github.com/TA-Lib/ta-lib-python for installation instructions. " | ||
"If TA-Lib is already installed, ensure it is correctly added to your Python environment." | ||
) | ||
raise ImportError(error_message) | ||
pytestmark = pytest.mark.skip(reason="talib is not installed") | ||
else: | ||
from nautilus_trader.indicators.ta_lib.manager import TAFunctionWrapper | ||
from nautilus_trader.indicators.ta_lib.manager import TALibIndicatorManager | ||
|
||
|
||
@pytest.fixture(scope="session") | ||
def bar_type() -> BarType: | ||
return BarType.from_str("EUR/USD.IDEALPRO-1-HOUR-MID-EXTERNAL") | ||
|
||
|
||
@pytest.fixture() | ||
def indicator_manager() -> "TALibIndicatorManager": | ||
return TALibIndicatorManager( | ||
bar_type=BarType.from_str("EUR/USD.IDEALPRO-1-HOUR-MID-EXTERNAL"), | ||
period=10, | ||
) | ||
|
||
|
||
def test_setup(): | ||
# Arrange | ||
bar_type = BarType.from_str("EUR/USD.IDEALPRO-1-HOUR-MID-EXTERNAL") | ||
period = 10 | ||
|
||
# Act | ||
indicator_manager = TALibIndicatorManager(bar_type=bar_type, period=period) | ||
|
||
# Assert | ||
assert indicator_manager.bar_type == bar_type | ||
assert indicator_manager.period == period | ||
|
||
|
||
def test_invalid_bar_type(): | ||
# Arrange, Act, Assert | ||
with pytest.raises(TypeError): | ||
TALibIndicatorManager(bar_type="invalid_bar_type", period=10) | ||
|
||
|
||
def test_not_positive_period(bar_type): | ||
# Arrange, Act, Assert | ||
with pytest.raises(ValueError): | ||
TALibIndicatorManager(bar_type=bar_type, period=0) | ||
|
||
|
||
def test_not_positive_buffer_size(bar_type): | ||
# Arrange, Act, Assert | ||
with pytest.raises(ValueError): | ||
TALibIndicatorManager(bar_type=bar_type, period=10, buffer_size=0) | ||
|
||
|
||
def test_skip_uniform_price_bar_default_true(indicator_manager): | ||
# Assert | ||
indicator_manager._skip_uniform_price_bar = True | ||
|
||
|
||
def test_skip_zero_close_bar_default_true(indicator_manager): | ||
# Assert | ||
indicator_manager._skip_zero_close_bar = True | ||
|
||
|
||
def test_set_indicators(indicator_manager): | ||
# Arrange | ||
indicators = (TAFunctionWrapper("SMA", {"timeperiod": 50}), TAFunctionWrapper("MACD")) | ||
|
||
# Act | ||
indicator_manager.set_indicators(indicators) | ||
|
||
# Assert | ||
assert len(indicator_manager._indicators) == len(indicators) | ||
|
||
|
||
def test_output_names_generation(indicator_manager): | ||
# Arrange | ||
indicators = (TAFunctionWrapper("SMA", {"timeperiod": 50}),) | ||
# Act | ||
indicator_manager.set_indicators(indicators) | ||
expected_output_names = indicator_manager.input_names() + indicators[0].output_names | ||
|
||
# Assert | ||
assert indicator_manager.output_names == tuple(expected_output_names) | ||
|
||
|
||
def test_increment_count_correctly_increases_counter(indicator_manager): | ||
# Arrange | ||
indicator_manager._stable_period = 5 | ||
|
||
# Act | ||
for i in range(2): | ||
indicator_manager._increment_count() | ||
|
||
# Assert | ||
assert indicator_manager.count == 2 | ||
|
||
|
||
def test_indicator_remains_uninitialized_with_insufficient_input_count(indicator_manager): | ||
# Arrange | ||
indicator_manager._stable_period = 5 | ||
|
||
# Act | ||
for i in range(4): | ||
indicator_manager._increment_count() | ||
|
||
# Assert | ||
assert indicator_manager.has_inputs is True | ||
assert indicator_manager.initialized is False | ||
|
||
|
||
def test_indicator_initializes_after_receiving_required_input_count(indicator_manager): | ||
# Arrange | ||
indicator_manager._stable_period = 5 | ||
indicator_manager._calculate_ta = Mock() | ||
|
||
# Act | ||
for i in range(5): | ||
indicator_manager._increment_count() | ||
|
||
# Assert | ||
assert indicator_manager.has_inputs is True | ||
assert indicator_manager.initialized is True | ||
|
||
|
||
def test_calculate_ta_called_on_initialization(indicator_manager): | ||
# Arrange | ||
indicator_manager._stable_period = 5 | ||
indicator_manager._calculate_ta = Mock() | ||
|
||
# Act | ||
for i in range(5): | ||
indicator_manager._increment_count() | ||
|
||
# Assert | ||
indicator_manager._calculate_ta.assert_called_once() |