Skip to content

Commit

Permalink
Refine mypy config and fix issues
Browse files Browse the repository at this point in the history
  • Loading branch information
cjdsellers committed Apr 1, 2024
1 parent 4f939b1 commit 5d12ae7
Show file tree
Hide file tree
Showing 18 changed files with 118 additions and 47 deletions.
4 changes: 1 addition & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -115,10 +115,8 @@ repos:
hooks:
- id: mypy
args: [
"--ignore-missing-imports",
"--config", "pyproject.toml",
"--allow-incomplete-defs",
"--no-strict-optional", # Fixing in progress
"--warn-no-return",
]
additional_dependencies: [
msgspec,
Expand Down
10 changes: 5 additions & 5 deletions nautilus_trader/adapters/betfair/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ def __init__(
username: str,
password: str,
app_key: str,
):
) -> None:
# Config
self.username = username
self.password = password
Expand Down Expand Up @@ -124,15 +124,15 @@ def update_headers(self, login_resp: LoginResponse) -> None:
},
)

def reset_headers(self):
def reset_headers(self) -> None:
self._headers = {
"Accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded",
"User-Agent": nautilus_trader.USER_AGENT,
"X-Application": self.app_key,
}

async def connect(self):
async def connect(self) -> None:
if self.session_token is not None:
self._log.warning("Session token exists (already connected), skipping")
return
Expand All @@ -145,12 +145,12 @@ async def connect(self):
self._log.info("Login success", color=LogColor.GREEN)
self.update_headers(login_resp=resp)

async def disconnect(self):
async def disconnect(self) -> None:
self._log.info("Disconnecting...")
self.reset_headers()
self._log.info("Disconnected", color=LogColor.GREEN)

async def keep_alive(self):
async def keep_alive(self) -> None:
"""
Renew authentication.
"""
Expand Down
1 change: 1 addition & 0 deletions nautilus_trader/backtest/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def main(
with fsspec.open(fsspec_url, "rb") as f:
data = f.read().decode()
else:
assert raw is not None # Type checking
data = raw.encode()

configs = msgspec.json.decode(
Expand Down
2 changes: 1 addition & 1 deletion nautilus_trader/core/nautilus_pyo3.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -1126,7 +1126,7 @@ class AccountState:
self,
account_id: AccountId,
account_type: AccountType,
base_currency: Currency,
base_currency: Currency | None,
balances: list[AccountBalance],
margins: list[MarginBalance],
is_reported: bool,
Expand Down
11 changes: 10 additions & 1 deletion nautilus_trader/examples/strategies/ema_cross_stop_entry.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from nautilus_trader.model.events import OrderFilled
from nautilus_trader.model.identifiers import InstrumentId
from nautilus_trader.model.instruments import Instrument
from nautilus_trader.model.objects import Price
from nautilus_trader.model.orders import MarketIfTouchedOrder
from nautilus_trader.model.orders import TrailingStopMarketOrder
from nautilus_trader.trading.strategy import Strategy
Expand Down Expand Up @@ -145,7 +146,7 @@ def __init__(self, config: EMACrossStopEntryConfig) -> None:
self.atr = AverageTrueRange(config.atr_period)

self.instrument: Instrument | None = None # Initialized in `on_start()`
self.tick_size = None # Initialized in `on_start()`
self.tick_size: Price | None = None # Initialized in `on_start()`

# Users order management variables
self.entry = None
Expand Down Expand Up @@ -266,6 +267,10 @@ def entry_buy(self, last_bar: Bar) -> None:
self.log.error("No instrument loaded")
return

if not self.tick_size:
self.log.error("No tick size loaded")
return

order: MarketIfTouchedOrder = self.order_factory.market_if_touched(
instrument_id=self.instrument_id,
order_side=OrderSide.BUY,
Expand Down Expand Up @@ -301,6 +306,10 @@ def entry_sell(self, last_bar: Bar) -> None:
self.log.error("No instrument loaded")
return

if not self.tick_size:
self.log.error("No tick size loaded")
return

order: MarketIfTouchedOrder = self.order_factory.market_if_touched(
instrument_id=self.instrument_id,
order_side=OrderSide.SELL,
Expand Down
4 changes: 4 additions & 0 deletions nautilus_trader/indicators/ta_lib/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,6 +392,9 @@ def _update_ta_outputs(self, append: bool = True) -> None:
"""
self._log.debug("Calculating outputs.")

if self._input_deque is None:
return

combined_output = np.zeros(1, dtype=self._output_dtypes)
combined_output["ts_event"] = self._input_deque[-1]["ts_event"].item()
combined_output["ts_init"] = self._input_deque[-1]["ts_init"].item()
Expand All @@ -402,6 +405,7 @@ def _update_ta_outputs(self, append: bool = True) -> None:
combined_output["volume"] = self._input_deque[-1]["volume"].item()

input_array = np.concatenate(self._input_deque)
assert self._indicators # Type checking
for indicator in self._indicators:
self._log.debug(f"Calculating {indicator.name} outputs.")
inputs_dict = {name: input_array[name] for name in input_array.dtype.names}
Expand Down
60 changes: 45 additions & 15 deletions nautilus_trader/test_kit/mocks/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,28 +47,40 @@ def __init__(self, bar_type: BarType) -> None:
self.calls: list[str] = []

def on_start(self) -> None:
self.calls.append(inspect.currentframe().f_code.co_name)
current_frame = inspect.currentframe()
assert current_frame # Type checking
self.calls.append(current_frame.f_code.co_name)
self.register_indicator_for_bars(self.bar_type, self.ema1)
self.register_indicator_for_bars(self.bar_type, self.ema2)

def on_instrument(self, instrument) -> None:
self.calls.append(inspect.currentframe().f_code.co_name)
current_frame = inspect.currentframe()
assert current_frame # Type checking
self.calls.append(current_frame.f_code.co_name)
self.store.append(instrument)

def on_ticker(self, ticker):
self.calls.append(inspect.currentframe().f_code.co_name)
current_frame = inspect.currentframe()
assert current_frame # Type checking
self.calls.append(current_frame.f_code.co_name)
self.store.append(ticker)

def on_quote_tick(self, tick):
self.calls.append(inspect.currentframe().f_code.co_name)
current_frame = inspect.currentframe()
assert current_frame # Type checking
self.calls.append(current_frame.f_code.co_name)
self.store.append(tick)

def on_trade_tick(self, tick) -> None:
self.calls.append(inspect.currentframe().f_code.co_name)
current_frame = inspect.currentframe()
assert current_frame # Type checking
self.calls.append(current_frame.f_code.co_name)
self.store.append(tick)

def on_bar(self, bar) -> None:
self.calls.append(inspect.currentframe().f_code.co_name)
current_frame = inspect.currentframe()
assert current_frame # Type checking
self.calls.append(current_frame.f_code.co_name)
self.store.append(bar)

if bar.bar_type != self.bar_type:
Expand All @@ -94,36 +106,54 @@ def on_bar(self, bar) -> None:
self.position_id = sell_order.client_order_id

def on_data(self, data) -> None:
self.calls.append(inspect.currentframe().f_code.co_name)
current_frame = inspect.currentframe()
assert current_frame # Type checking
self.calls.append(current_frame.f_code.co_name)
self.store.append(data)

def on_strategy_data(self, data) -> None:
self.calls.append(inspect.currentframe().f_code.co_name)
current_frame = inspect.currentframe()
assert current_frame # Type checking
self.calls.append(current_frame.f_code.co_name)
self.store.append(data)

def on_event(self, event) -> None:
self.calls.append(inspect.currentframe().f_code.co_name)
current_frame = inspect.currentframe()
assert current_frame # Type checking
self.calls.append(current_frame.f_code.co_name)
self.store.append(event)

def on_stop(self) -> None:
self.calls.append(inspect.currentframe().f_code.co_name)
current_frame = inspect.currentframe()
assert current_frame # Type checking
self.calls.append(current_frame.f_code.co_name)

def on_resume(self) -> None:
self.calls.append(inspect.currentframe().f_code.co_name)
current_frame = inspect.currentframe()
assert current_frame # Type checking
self.calls.append(current_frame.f_code.co_name)

def on_reset(self) -> None:
self.calls.append(inspect.currentframe().f_code.co_name)
current_frame = inspect.currentframe()
assert current_frame # Type checking
self.calls.append(current_frame.f_code.co_name)

def on_save(self) -> dict[str, bytes]:
self.calls.append(inspect.currentframe().f_code.co_name)
current_frame = inspect.currentframe()
assert current_frame # Type checking
self.calls.append(current_frame.f_code.co_name)
return {"UserState": b"1"}

def on_load(self, state: dict[str, bytes]) -> None:
self.calls.append(inspect.currentframe().f_code.co_name)
current_frame = inspect.currentframe()
assert current_frame # Type checking
self.calls.append(current_frame.f_code.co_name)
self.store.append(state)

def on_dispose(self) -> None:
self.calls.append(inspect.currentframe().f_code.co_name)
current_frame = inspect.currentframe()
assert current_frame # Type checking
self.calls.append(current_frame.f_code.co_name)


class KaboomStrategy(Strategy):
Expand Down
2 changes: 1 addition & 1 deletion nautilus_trader/test_kit/rust/events_pyo3.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,7 +483,7 @@ def order_filled(
order_side=order.side,
order_type=order.order_type,
last_qty=last_qty,
last_px=last_px or order.price,
last_px=last_px or order.price or Price.from_str("1.00000"),
currency=instrument.quote_currency,
commission=commission,
liquidity_side=liquidity_side,
Expand Down
2 changes: 1 addition & 1 deletion nautilus_trader/test_kit/stubs/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,12 @@ def strategies_config() -> list[ImportableStrategyConfig]:

@staticmethod
def backtest_engine_config(
catalog: ParquetDataCatalog,
log_level="INFO",
bypass_logging: bool = True,
bypass_risk: bool = False,
allow_cash_position: bool = True,
persist: bool = False,
catalog: ParquetDataCatalog | None = None,
strategies: list[ImportableStrategyConfig] | None = None,
) -> BacktestEngineConfig:
if persist:
Expand Down
11 changes: 11 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -259,10 +259,21 @@ disallow_incomplete_defs = true
explicit_package_bases = true
ignore_missing_imports = true
namespace_packages = true
no_strict_optional = false
warn_no_return = true
warn_unused_configs = true
warn_unused_ignores = true

[[tool.mypy.overrides]]
no_strict_optional = true
module = [
"examples/*",
"nautilus_trader/adapters/betfair/*",
"nautilus_trader/adapters/binance/*",
"nautilus_trader/adapters/interactive_brokers/*",
"nautilus_trader/indicators/ta_lib/*",
]

[tool.pytest.ini_options]
testpaths = ["tests"]
addopts = "-ra --new-first --failed-first --doctest-modules --doctest-glob=\"*.pyx\""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -845,13 +845,14 @@ async def test_generate_order_status_report_client_id(
instrument_provider.add(instrument)

# Act
report: OrderStatusReport = await exec_client.generate_order_status_report(
report: OrderStatusReport | None = await exec_client.generate_order_status_report(
instrument_id=instrument.id,
venue_order_id=VenueOrderId("1"),
client_order_id=None,
)

# Assert
assert report
assert report.order_status == OrderStatus.ACCEPTED
assert report.price == Price(5.0, BETFAIR_PRICE_PRECISION)
assert report.quantity == Quantity(10.0, BETFAIR_QUANTITY_PRECISION)
Expand All @@ -874,13 +875,14 @@ async def test_generate_order_status_report_venue_order_id(
venue_order_id = VenueOrderId("323427122115")

# Act
report: OrderStatusReport = await exec_client.generate_order_status_report(
report: OrderStatusReport | None = await exec_client.generate_order_status_report(
instrument_id=instrument.id,
venue_order_id=venue_order_id,
client_order_id=client_order_id,
)

# Assert
assert report
assert report.order_status == OrderStatus.ACCEPTED
assert report.price == Price(5.0, BETFAIR_PRICE_PRECISION)
assert report.quantity == Quantity(10.0, BETFAIR_QUANTITY_PRECISION)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def test_order_book_integrity(self, filename, book_count) -> None:
result = [book.count for book in books.values()]
assert result == book_count

def test_betfair_trade_sizes(self): # noqa: C901
def test_betfair_trade_sizes(self) -> None: # noqa: C901
mcms = BetfairDataProvider.read_mcm("1.206064380.bz2")
trade_ticks: dict[InstrumentId, list[TradeTick]] = defaultdict(list)
betfair_tv: dict[int, dict[float, float]] = {}
Expand Down Expand Up @@ -338,7 +338,7 @@ def test_betfair_trade_sizes(self): # noqa: C901


class TestBetfairParsing:
def setup(self):
def setup(self) -> None:
# Fixture Setup
self.loop = asyncio.new_event_loop()
self.clock = LiveClock()
Expand Down
2 changes: 1 addition & 1 deletion tests/integration_tests/adapters/betfair/test_kit.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def betfair_backtest_run_config(
),
]
if add_strategy
else None
else []
),
)
run_config = BacktestRunConfig( # typing: ignore
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,12 @@


class TestBybitInstruments:
def setup(self):
def setup(self) -> None:
# linear
linear_data: BybitInstrumentsLinearResponse = msgspec.json.Decoder(
BybitInstrumentsLinearResponse,
).decode(
pkgutil.get_data(
pkgutil.get_data( # type: ignore [arg-type]
"tests.integration_tests.adapters.bybit.resources.http_responses.linear",
"instruments.json",
),
Expand All @@ -49,7 +49,7 @@ def setup(self):
spot_data: BybitInstrumentsSpotResponse = msgspec.json.Decoder(
BybitInstrumentsSpotResponse,
).decode(
pkgutil.get_data(
pkgutil.get_data( # type: ignore [arg-type]
"tests.integration_tests.adapters.bybit.resources.http_responses.spot",
"instruments.json",
),
Expand All @@ -59,7 +59,7 @@ def setup(self):
option_data: BybitInstrumentsOptionResponse = msgspec.json.Decoder(
BybitInstrumentsOptionResponse,
).decode(
pkgutil.get_data(
pkgutil.get_data( # type: ignore [arg-type]
"tests.integration_tests.adapters.bybit.resources.http_responses.option",
"instruments.json",
),
Expand Down
Loading

0 comments on commit 5d12ae7

Please sign in to comment.