Skip to content

Commit

Permalink
Add more tests for TALibIndicatorManager (#1422)
Browse files Browse the repository at this point in the history
  • Loading branch information
rsmb7z authored Dec 21, 2023
1 parent cc6efb1 commit 7732e8e
Show file tree
Hide file tree
Showing 3 changed files with 412 additions and 54 deletions.
95 changes: 53 additions & 42 deletions nautilus_trader/indicators/ta_lib/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,9 @@ def __init__(
logger = Logger(clock)
self._log = LoggerAdapter(component_name=repr(self), logger=logger)

# Initialize with empty indicators (acts as OHLCV placeholder in case no indicators are set)
self.set_indicators(())

def change_logger(self, logger: Logger):
PyCondition().type(logger, Logger, "logger")
self._log = LoggerAdapter(component_name=repr(self), logger=logger)
Expand Down Expand Up @@ -297,6 +300,10 @@ def set_indicators(self, indicators: tuple[TAFunctionWrapper, ...]) -> None:
info levels, respectively.
"""
if self.initialized:
self._log.info("Indicator already initialized. Skipping set_indicators.")
return

self._log.debug(f"Setting indicators {indicators}")
PyCondition().list_type(list(indicators), TAFunctionWrapper, "ta_functions")

Expand All @@ -308,9 +315,9 @@ def set_indicators(self, indicators: tuple[TAFunctionWrapper, ...]) -> None:
output_names.extend(indicator.output_names)
lookback = max(lookback, indicator.fn.lookback)

self.output_names = tuple(output_names)
self._stable_period = lookback + self._period
self._input_deque = deque(maxlen=self._stable_period)
self._input_deque = deque(maxlen=lookback + 1)
self.output_names = tuple(output_names)

# Initialize the output dtypes
self._output_dtypes = [
Expand Down Expand Up @@ -343,39 +350,38 @@ def bar_type(self) -> BarType:
def period(self) -> int:
return self._period

def _calculate_ta(self, append: bool = True) -> None:
def _update_ta_outputs(self, append: bool = True) -> None:
"""
Calculate technical analysis indicators and update the output deque.
Update the output deque with calculated technical analysis indicators.
This private method computes the output values for technical analysis indicators
set in the instance. It first initializes a combined output array with the base
values (like 'open', 'high', 'low', 'close', 'volume', and 'ts_event') from the
most recent entry in the input deque. It then iterates through each indicator,
computes its output, and updates the combined output array accordingly.
Depending on the 'append' flag, the method either appends or updates the latest
entry in the output deque with the combined output.
This private method computes and updates the output values for technical
analysis indicators based on the latest data in the input deque. It initializes
a combined output array with base values (e.g., 'open', 'high', 'low', 'close',
'volume', 'ts_event') from the most recent input deque entry. Each indicator's
output is calculated and used to update the combined output array. The updated
data is either appended to or replaces the latest entry in the output deque,
depending on the value of the 'append' argument.
Args:
----
append : bool, optional
A flag to determine whether to append the new output to the output deque (True) or
to replace the most recent output (False). Defaults to True.
append (bool): Determines whether to append the new output to the output
deque (True) or replace the most recent output (False).
Defaults to True.
Returns
Returns:
-------
None
None
The method performs the following steps:
- Initializes a combined output array with the default data types.
- Extracts the base values from the most recent entry in the input deque.
- Initializes a combined output array with base values from the latest input
deque entry.
- Iterates through each indicator, calculates its output, and updates the
combined output array.
- Depending on the 'append' flag, either appends the combined output to the
output deque or replaces its most recent entry.
- Resets the internal output array to ensure that it is rebuilt during the
next access.
- Appends the combined output to the output deque or replaces its most recent
entry based on the 'append' flag.
- Resets the internal output array for reconstruction during the next access.
This method logs actions at the debug level for tracking the calculation and
This method logs actions at the debug level to track the calculation and
updating process.
"""
Expand All @@ -390,31 +396,30 @@ def _calculate_ta(self, append: bool = True) -> None:
combined_output["close"] = self._input_deque[-1]["close"].item()
combined_output["volume"] = self._input_deque[-1]["volume"].item()

input_array = None
input_array = np.concatenate(self._input_deque)
for indicator in self._indicators:
if input_array is None:
input_array = np.concatenate(self._input_deque)
period = indicator.fn.lookback + 1
inputs_dict = {name: input_array[name][-period:] for name in input_array.dtype.names}
self._log.debug(f"Calculating {indicator.name} outputs.")
inputs_dict = {name: input_array[name] for name in input_array.dtype.names}
indicator.fn.set_input_arrays(inputs_dict)
results = indicator.fn.run()

if len(indicator.output_names) == 1:
self._log.debug("Single output.")
combined_output[indicator.output_names[0]] = results[-1]
else:
self._log.debug("Multiple outputs.")
for i, output_name in enumerate(indicator.output_names):
combined_output[output_name] = results[i][-1]

if self.initialized:
if append:
self._log.debug("Appending output.")
self._output_deque.append(combined_output)
else:
self._log.debug("Prepending output.")
self._output_deque[-1] = combined_output
if append:
self._log.debug("Appending output.")
self._output_deque.append(combined_output)
else:
self._log.debug("Prepending output.")
self._output_deque[-1] = combined_output

# Reset output array to force rebuild on next access
self._output_array = None
# Reset output array to force rebuild on next access
self._output_array = None

def _increment_count(self) -> None:
self.count += 1
Expand All @@ -424,7 +429,6 @@ def _increment_count(self) -> None:
if self.count >= self._stable_period:
self._set_initialized(True)
self._log.info(f"Initialized with {self.count} bars")
self._calculate_ta() # Immediately make the first calculation

def value(self, name: str, index=0):
"""
Expand Down Expand Up @@ -478,7 +482,7 @@ def value(self, name: str, index=0):
return self.output_array[name][translated_index]

@property
def output_array(self) -> np.recarray:
def output_array(self) -> np.recarray | None:
"""
Retrieve or generate the output array for the indicator.
Expand Down Expand Up @@ -512,7 +516,7 @@ def output_array(self) -> np.recarray:
self._log.debug("Using cached output array.")
return self._output_array

def generate_output_array(self, truncate: bool) -> np.recarray:
def generate_output_array(self, truncate: bool) -> np.recarray | None:
"""
Generate the output array for the indicator, either truncated or complete.
Expand Down Expand Up @@ -551,6 +555,10 @@ def generate_output_array(self, truncate: bool) -> np.recarray:
```
"""
if not self.initialized:
self._log.info("Indicator not initialized. Returning None.")
return None

if truncate:
self._log.debug("Generating truncated output array.")
output_array = np.concatenate(list(self._output_deque)[-self.period :])
Expand Down Expand Up @@ -705,11 +713,14 @@ def handle_bar(self, bar: Bar) -> None:

if bar.ts_event == self._last_ts_event:
self._input_deque[-1] = bar_data
self._calculate_ta(append=False)
self._update_ta_outputs(append=False)
elif bar.ts_event > self._last_ts_event:
self._input_deque.append(bar_data)
self._increment_count()
self._calculate_ta()
self._update_ta_outputs()
else:
self._data_error_counter += 1
self._log.error(f"Received out of sync bar: {bar!r}")
return

self._last_ts_event = bar.ts_event
5 changes: 5 additions & 0 deletions tests/unit_tests/indicators/test_talib_func_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,11 @@ def test_from_str_invalid():
TAFunctionWrapper.from_str("INVALID_5")


def test_from_str_invalid_params():
with pytest.raises(ValueError):
TAFunctionWrapper.from_str("SMA_20_10")


def test_from_list_of_str():
indicators = ["SMA_5", "EMA_10"]
wrappers = TAFunctionWrapper.from_list_of_str(indicators)
Expand Down
Loading

0 comments on commit 7732e8e

Please sign in to comment.