Skip to content

Commit

Permalink
add basic typing
Browse files Browse the repository at this point in the history
  • Loading branch information
honzajavorek committed Dec 28, 2024
1 parent 7ffbdec commit c7d2339
Show file tree
Hide file tree
Showing 4 changed files with 48 additions and 36 deletions.
58 changes: 34 additions & 24 deletions fiobank.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import warnings
from datetime import date, datetime
from decimal import Decimal
from typing import Any, Callable, Generator

import requests
from tenacity import (
Expand All @@ -15,15 +16,15 @@
__all__ = ("FioBank", "ThrottlingError")


def coerce_amount(value):
def coerce_amount(value: int | float) -> Decimal:
if isinstance(value, int):
return Decimal(value)
if isinstance(value, float):
return Decimal(str(value))
raise ValueError(value)


def coerce_date(value):
def coerce_date(value: datetime | date | str):
if isinstance(value, datetime):
return value.date()
elif isinstance(value, date):
Expand All @@ -32,7 +33,7 @@ def coerce_date(value):
return datetime.strptime(value[:10], "%Y-%m-%d").date()


def sanitize_value(value, convert=None):
def sanitize_value(value: Any, convert: Callable | None = None) -> Any:
if isinstance(value, str):
value = value.strip() or None
if convert and value is not None:
Expand All @@ -43,7 +44,7 @@ def sanitize_value(value, convert=None):
class ThrottlingError(Exception):
"""Throttling error raised when the API is being used too fast."""

def __str__(self):
def __str__(self) -> str:
return "Token can be used only once per 30s."


Expand All @@ -60,7 +61,7 @@ class FioBank(object):

_amount_re = re.compile(r"\-?\d+(\.\d+)? [A-Z]{3}")

def __init__(self, token, decimal=False):
def __init__(self, token: str, decimal=False):
self.token = token

if decimal:
Expand Down Expand Up @@ -114,7 +115,7 @@ def __init__(self, token, decimal=False):
stop=stop_after_attempt(3),
wait=wait_random_exponential(max=2 * 60),
)
def _request(self, action, **params):
def _request(self, action: str, **params) -> dict | None:
url_template = self.base_url + self.actions[action]
url = url_template.format(token=self.token, **params)

Expand All @@ -127,7 +128,7 @@ def _request(self, action, **params):
return response.json(parse_float=self.float_type)
return None

def _parse_info(self, data):
def _parse_info(self, data: dict) -> dict:
# parse data from API
info = {}
for key, value in data["accountStatement"]["info"].items():
Expand All @@ -143,7 +144,7 @@ def _parse_info(self, data):
# return data
return info

def _parse_transactions(self, data):
def _parse_transactions(self, data: dict) -> Generator[dict, None, None]:
schema = self.transaction_schema
try:
entries = data["accountStatement"]["transactionList"]["transaction"] # NOQA
Expand Down Expand Up @@ -180,7 +181,7 @@ def _parse_transactions(self, data):
# generate transaction data
yield trans

def _add_account_number_full(self, obj):
def _add_account_number_full(self, obj: dict) -> None:
account_number = obj.get("account_number")
bank_code = obj.get("bank_code")

Expand All @@ -191,22 +192,29 @@ def _add_account_number_full(self, obj):

obj["account_number_full"] = account_number_full

def info(self):
def info(self) -> dict:
today = date.today()
data = self._request("periods", from_date=today, to_date=today)
return self._parse_info(data)

def period(self, from_date, to_date):
data = self._request(
if data := self._request("periods", from_date=today, to_date=today):
return self._parse_info(data)
raise ValueError("No data available")

def period(
self, from_date: date | datetime | str, to_date: date | datetime | str
) -> Generator[dict, None, None]:
if data := self._request(
"periods", from_date=coerce_date(from_date), to_date=coerce_date(to_date)
)
return self._parse_transactions(data)

def statement(self, year, number):
data = self._request("by-id", year=year, number=number)
return self._parse_transactions(data)

def last(self, from_id=None, from_date=None):
):
return self._parse_transactions(data)
raise ValueError("No data available")

def statement(self, year: int, number: int) -> Generator[dict, None, None]:
if data := self._request("by-id", year=year, number=number):
return self._parse_transactions(data)
raise ValueError("No data available")

def last(
self, from_id: int | None = None, from_date: date | datetime | str | None = None
) -> Generator[dict, None, None]:
if from_id and from_date:
raise ValueError("Only one constraint is allowed.")

Expand All @@ -215,4 +223,6 @@ def last(self, from_id=None, from_date=None):
elif from_date:
self._request("set-last-date", from_date=coerce_date(from_date))

return self._parse_transactions(self._request("last"))
if data := self._request("last"):
return self._parse_transactions(data)
raise ValueError("No data available")
8 changes: 4 additions & 4 deletions tests/test_coerce_date.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,17 @@
"2016-08-03T21:03:42",
],
)
def test_coerce_date(test_input):
def test_coerce_date(test_input: date | datetime | str):
assert coerce_date(test_input) == date(2016, 8, 3)


@pytest.mark.parametrize("test_input", [42, True])
def test_coerce_date_invalid_type(test_input):
def test_coerce_date_invalid_type(test_input: int | bool):
with pytest.raises(TypeError):
coerce_date(test_input)
coerce_date(test_input) # type: ignore


@pytest.mark.parametrize("test_input", ["21:03:42", "[email protected]"])
def test_coerce_date_invalid_value(test_input):
def test_coerce_date_invalid_value(test_input: str):
with pytest.raises(ValueError):
coerce_date(test_input)
10 changes: 5 additions & 5 deletions tests/test_fiobank.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@


@pytest.fixture()
def token() -> uuid.UUID:
return uuid.uuid4()
def token() -> str:
return str(uuid.uuid4())


@pytest.fixture()
Expand All @@ -32,7 +32,7 @@ def transactions_json() -> dict:


@pytest.fixture()
def client_float(token: uuid.UUID, transactions_text: str):
def client_float(token: str, transactions_text: str):
with responses.RequestsMock(assert_all_requests_are_fired=False) as resps:
url = re.compile(
re.escape(FioBank.base_url)
Expand All @@ -50,7 +50,7 @@ def client_float(token: uuid.UUID, transactions_text: str):


@pytest.fixture()
def client_decimal(token: uuid.UUID, transactions_text: str):
def client_decimal(token: str, transactions_text: str):
with responses.RequestsMock(assert_all_requests_are_fired=False) as resps:
url = re.compile(
re.escape(FioBank.base_url)
Expand Down Expand Up @@ -447,7 +447,7 @@ def test_transactions_parse_no_account_number_full(transactions_json):
assert sdk_transaction["account_number_full"] is None


def test_409_conflict(token: uuid.UUID, transactions_text: str):
def test_409_conflict(token: str, transactions_text: str):
with responses.RequestsMock(registry=OrderedRegistry) as resps:
url = re.compile(
re.escape(FioBank.base_url)
Expand Down
8 changes: 5 additions & 3 deletions tests/test_sanitize_value.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Any, Callable

import pytest

from fiobank import sanitize_value
Expand All @@ -14,7 +16,7 @@
(30.8, 30.8),
],
)
def test_sanitize_value_no_effect(test_input, expected):
def test_sanitize_value_no_effect(test_input: Any, expected: Any):
assert sanitize_value(test_input) == expected


Expand All @@ -26,7 +28,7 @@ def test_sanitize_value_no_effect(test_input, expected):
("\nfio ", "fio"),
],
)
def test_sanitize_value_strip(test_input, expected):
def test_sanitize_value_strip(test_input: str, expected: str | None):
assert sanitize_value(test_input) == expected


Expand All @@ -44,5 +46,5 @@ def test_sanitize_value_strip(test_input, expected):
(False, bool, False),
],
)
def test_sanitize_value_convert(test_input, convert, expected):
def test_sanitize_value_convert(test_input: Any, convert: Callable, expected: Any):
assert sanitize_value(test_input, convert) == expected

0 comments on commit c7d2339

Please sign in to comment.