Add mypy (#152)
* Fix type hints with mypy

* Add to Travis, fix request import in mypy 0.730
andersonberg authored and manycoding committed Oct 4, 2019
1 parent 67a24b8 commit 0538719
Showing 20 changed files with 95 additions and 92 deletions.
2 changes: 1 addition & 1 deletion .travis.yml
@@ -13,7 +13,7 @@ install:
 - pip install tox-travis pip -U --no-cache-dir
 script:
 - tox
-- tox -e docs
+- tox -e pep8,mypy,docs
 after_success:
 - tox -e codecov
 deploy:
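
The tox environments themselves live in tox.ini, which this commit does not touch. For reference, a minimal sketch of running the same check locally through mypy's documented Python API (the `src/arche` target path is an assumption):

    # Sketch: run mypy programmatically, as `tox -e mypy` would via the CLI.
    # api.run returns (stdout_report, stderr_report, exit_status).
    from mypy import api

    stdout, stderr, exit_status = api.run(["src/arche"])
    print(stdout)
    if exit_status != 0:
        raise SystemExit("mypy found type errors")
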
1 change: 1 addition & 0 deletions Pipfile
@@ -42,6 +42,7 @@ pyarrow = "*"
 cufflinks = "*"
 tables = "*"
 nb-black = "*"
+mypy = "*"

 [requires]
 python_version = "3.7"
10 changes: 5 additions & 5 deletions src/arche/arche.py
@@ -1,6 +1,6 @@
 from functools import lru_cache
 import logging
-from typing import Iterable, Optional, Union
+from typing import Iterable, Optional, Union, cast

 from arche.data_quality_report import DataQualityReport
 from arche.readers.items import Items, CollectionItems, JobItems, RawItems
@@ -106,15 +106,15 @@ def schema(self, schema_source):
 def get_items(
     source: Union[str, pd.DataFrame, RawItems],
     count: Optional[int],
-    start: Union[str, int],
+    start: Optional[str],
     filters: Optional[api.Filters],
 ) -> Items:
     if isinstance(source, pd.DataFrame):
         return Items.from_df(source)
     elif isinstance(source, Iterable) and not isinstance(source, str):
-        return Items.from_array(source)
+        return Items.from_array(cast(RawItems, source))
     elif helpers.is_job_key(source):
-        return JobItems(source, count, start or 0, filters)
+        return JobItems(source, count, int(start or 0), filters)
     elif helpers.is_collection_key(source):
         return CollectionItems(source, count, start, filters)
     else:
@@ -140,7 +140,7 @@ def run_all_rules(self):
         self.run_schema_rules()

     def data_quality_report(self, bucket: Optional[str] = None):
-        if helpers.is_collection_key(self.source):
+        if helpers.is_collection_key(str(self.source)):
             raise ValueError("Collections are not supported")
         if not self.schema:
             raise ValueError("Schema is empty")
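
After ruling out `pd.DataFrame` and `str`, the remaining `source` is a `RawItems` by construction, but mypy's narrowing of the `Union` cannot prove it, so `cast` records the fact. `cast` does nothing at runtime; it only tells the checker which type to assume. A minimal sketch of the pattern (names hypothetical):

    from typing import Dict, cast

    def get_port(config: Dict[str, object]) -> int:
        # The value is typed `object`; cast() asserts the concrete type
        # for the checker without any runtime check or conversion.
        return cast(int, config["port"])
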
10 changes: 5 additions & 5 deletions src/arche/data_quality_report.py
@@ -1,11 +1,11 @@
 from io import StringIO
 import json
-from typing import Optional
+from typing import Optional, List


 from arche.figures import tables
 from arche.quality_estimation_algorithm import generate_quality_estimation
-from arche.readers.items import CloudItems
+from arche.readers.items import JobItems
 from arche.readers.schema import Schema
 from arche.report import Report
 import arche.rules.coverage as coverage_rules
@@ -23,7 +23,7 @@
 class DataQualityReport:
     def __init__(
         self,
-        items: CloudItems,
+        items: JobItems,
         schema: Schema,
         report: Report,
         bucket: Optional[str] = None,
@@ -36,7 +36,7 @@ def __init__(
         """
         self.schema = schema
         self.report = report
-        self.figures = []
+        self.figures: List = []
         self.appendix = self.create_appendix(self.schema.raw)
         self.create_figures(items)
         self.plot_to_notebook()
@@ -48,7 +48,7 @@ def __init__(
             bucket=bucket,
         )

-    def create_figures(self, items: CloudItems):
+    def create_figures(self, items: JobItems):
         name_url_dups = self.report.results.get(
             "Duplicates By **name_field, product_url_field** Tags",
             duplicate_rules.find_by_name_url(items.df, self.schema.tags),
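
mypy cannot infer an element type from an empty literal, so a bare `self.figures = []` fails with "Need type annotation". A trimmed, hypothetical sketch of the fix (a precise `List[SomeFigure]` would be better still if the element type is known):

    from typing import List

    class ReportFigures:  # hypothetical stand-in
        def __init__(self) -> None:
            # Annotate the empty list so mypy has a type to work with.
            self.figures: List = []
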
11 changes: 2 additions & 9 deletions src/arche/readers/items.py
@@ -41,13 +41,6 @@ def categorize(df: pd.DataFrame) -> pd.DataFrame:
            except TypeError:
                continue

-    def origin_column_name(self, new: str) -> str:
-        if new in self.df.columns:
-            return new
-        for column in self.df.columns:
-            if column in new:
-                return column
-
     @classmethod
     def from_df(cls, df: pd.DataFrame):
         return cls(raw=np.array(df.to_dict("records")), df=df)
@@ -66,7 +59,7 @@ def __init__(
     ):
         self.key = key
         self._count = count
-        self._limit = None
+        self._limit: int = 0
         self.filters = filters
         raw = self.fetch_data()
         df = pd.DataFrame(list(raw))
@@ -104,7 +97,7 @@ def __init__(
         filters: Optional[api.Filters] = None,
     ):
         self.start_index = start_index
-        self.start: int = f"{key}/{start_index}"
+        self.start: str = f"{key}/{start_index}"
         self._job: Job = None
         super().__init__(key, count, filters)

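
Two distinct fixes appear here: `self.start` was annotated `int` but assigned an f-string, and `_limit` traded a `None` default for a typed sentinel so the attribute stays a plain `int` rather than `Optional[int]`. A condensed, hypothetical sketch of both:

    class Reader:  # hypothetical stand-in for the items readers
        def __init__(self, key: str, start_index: int) -> None:
            # An `int` annotation here would fail: the value is a str.
            self.start: str = f"{key}/{start_index}"
            # 0 as "no limit yet" avoids Optional[int] and None checks.
            self._limit: int = 0
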
23 changes: 13 additions & 10 deletions src/arche/readers/schema.py
@@ -2,7 +2,7 @@
 from enum import Enum
 import json
 import pprint
-from typing import Dict, List, Union
+from typing import Dict, List, Union, Any, Set, DefaultDict

 from arche.tools import s3
 import perfect_jsonschema
@@ -41,25 +41,28 @@ def __repr__(self):
         return pprint.pformat(self.raw)

     def get_enums(self) -> List[str]:
-        enums = []
+        enums: List[str] = []
         for k, v in self.raw["properties"].items():
-            if "enum" in v.keys():
+            if isinstance(v, Dict) and "enum" in v.keys():
                 enums.append(k)
         return enums

     @staticmethod
     def get_tags(schema: RawSchema) -> TaggedFields:
-        tagged_fields = defaultdict(list)
+        tagged_fields: DefaultDict[str, List[str]] = defaultdict(list)
         for key, value in schema["properties"].items():
-            property_tags = value.get("tag", [])
-            if property_tags:
-                tagged_fields = Schema.get_field_tags(property_tags, key, tagged_fields)
-        return tagged_fields
+            if isinstance(value, Dict):
+                property_tags = value.get("tag")
+                if property_tags:
+                    tagged_fields = Schema.get_field_tags(
+                        property_tags, key, tagged_fields
+                    )
+        return dict(tagged_fields)

     @classmethod
     def get_field_tags(
-        cls, tags: List[str], field: str, tagged_fields: defaultdict
-    ) -> TaggedFields:
+        cls, tags: Set[Any], field: str, tagged_fields: DefaultDict
+    ) -> DefaultDict[str, List[str]]:
         tags = cls.parse_tag(tags)
         if not tags:
             raise ValueError(
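
`typing.DefaultDict` (rather than plain `Dict` or the runtime `defaultdict` class) is what lets mypy type the container precisely. A small sketch:

    from collections import defaultdict
    from typing import DefaultDict, List

    # The annotation names both key and value types; defaultdict(list)
    # alone would be inferred too loosely for the checker.
    tagged_fields: DefaultDict[str, List[str]] = defaultdict(list)
    tagged_fields["name_field"].append("title")
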
6 changes: 4 additions & 2 deletions src/arche/report.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Dict
+from typing import Dict, Union

 from arche import SH_URL
 from arche.rules.result import Level, Outcome, Result
@@ -44,7 +44,9 @@ def write_summary(cls, result: Result) -> None:
         cls.write_rule_outcome(rule_msg.summary, level)

     @classmethod
-    def write_rule_outcome(cls, outcome: str, level: Level = Level.INFO) -> None:
+    def write_rule_outcome(
+        cls, outcome: Union[str, Outcome], level: Level = Level.INFO
+    ) -> None:
         if isinstance(outcome, Outcome):
             outcome = outcome.name
         msg = outcome
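
Widening the parameter to `Union[str, Outcome]` matches how the body already handles both cases; `isinstance()` then narrows the type per branch. A self-contained sketch (the `Outcome` enum is redeclared here as a stand-in):

    from enum import Enum
    from typing import Union

    class Outcome(Enum):  # stand-in for arche.rules.result.Outcome
        SKIPPED = "skipped"

    def rule_outcome(outcome: Union[str, Outcome]) -> str:
        # Inside the branch mypy sees Outcome; afterwards only str remains.
        if isinstance(outcome, Outcome):
            outcome = outcome.name
        return outcome
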
4 changes: 2 additions & 2 deletions src/arche/rules/duplicates.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Set

 from arche.readers.schema import TaggedFields
 from arche.rules.result import Result, Outcome
@@ -18,7 +18,7 @@ def find_by_unique(df: pd.DataFrame, tagged_fields: TaggedFields) -> Result:
         result.add_info(Outcome.SKIPPED)
         return result

-    err_keys = set()
+    err_keys: Set = set()
     for field in unique_fields:
         result.items_count = df[field].count()
         duplicates = df[df.duplicated(field, keep=False)][[field]]
2 changes: 1 addition & 1 deletion src/arche/rules/json_schema.py
@@ -25,7 +25,7 @@ def validate(
     err_items = len(set(itertools.chain.from_iterable(errors.values())))
     if errors:
         result.add_error(
-            f"{err_items} ({err_items/len(raw_items):.0%}) items have {len(errors)} errors",
+            f"{err_items} ({err_items/len(list(raw_items)):.0%}) items have {len(errors)} errors",  # noqa
            errors=errors,
        )
    return result
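
`len()` requires a `Sized` argument, so mypy rejects it on a value whose annotated type does not guarantee `__len__`; `list()` materializes the items first, at the cost of a copy, which is acceptable for a one-off report line. A sketch of the constraint:

    from typing import Iterable

    def error_share(err_items: int, raw_items: Iterable[dict]) -> float:
        # A bare Iterable has no __len__, so len(raw_items) fails type
        # checking; list() produces a Sized copy.
        return err_items / len(list(raw_items))
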
3 changes: 2 additions & 1 deletion src/arche/rules/others.py
@@ -1,5 +1,6 @@
 import codecs
 import re
+from typing import Set

 from arche.rules.result import Outcome, Result
 import numpy as np
@@ -90,7 +91,7 @@ def garbage_symbols(df: pd.DataFrame) -> Result:
     )

     errors = {}
-    row_keys = set()
+    row_keys: Set = set()
     rule_result = Result("Garbage Symbols", items_count=len(df))

     for column in tqdm_notebook(
36 changes: 19 additions & 17 deletions src/arche/rules/price.py
@@ -1,3 +1,5 @@
+from typing import Optional, List
+
 from arche.readers.schema import TaggedFields
 from arche.rules.result import Result, Outcome
 from arche.tools.helpers import is_number, ratio_diff
@@ -75,12 +77,12 @@ def compare_prices_for_same_urls(
     missing and new `product_url_field` tagged fields.
     """
     result = Result("Compare Prices For Same Urls")
-    url_field = tagged_fields.get("product_url_field")
-    if not url_field:
+    url_field_list: Optional[List[str]] = tagged_fields.get("product_url_field")
+    if not url_field_list:
         result.add_info(Outcome.SKIPPED)
         return result

-    url_field = url_field[0]
+    url_field = url_field_list[0]

     source_df = source_df.dropna(subset=[url_field])
     target_df = target_df.dropna(subset=[url_field])
@@ -108,11 +110,11 @@ def compare_prices_for_same_urls(
     result.add_info(f"{len(same_urls)} same urls in both jobs")

     diff_prices_count = 0
-    price_field = tagged_fields.get("product_price_field")
-    if not price_field:
+    price_field_tag = tagged_fields.get("product_price_field")
+    if not price_field_tag:
         result.add_info("product_price_field tag is not set")
     else:
-        price_field = price_field[0]
+        price_field = price_field_tag[0]
         detailed_messages = []
         for url in same_urls:
             if url.strip() != "nan":
@@ -153,14 +155,14 @@ def compare_names_for_same_urls(
     compare `name_field` field"""

     result = Result("Compare Names Per Url")
-    url_field = tagged_fields.get("product_url_field")
-    name_field = tagged_fields.get("name_field")
-    if not url_field or not name_field:
+    url_field_list: Optional[List[str]] = tagged_fields.get("product_url_field")
+    name_field_list: Optional[List[str]] = tagged_fields.get("name_field")
+    if not url_field_list or not name_field_list:
         result.add_info(Outcome.SKIPPED)
         return result

-    name_field = name_field[0]
-    url_field = url_field[0]
+    name_field: str = name_field_list[0]
+    url_field: str = url_field_list[0]
     diff_names_count = 0

     same_urls = source_df[(source_df[url_field].isin(target_df[url_field].values))][
@@ -200,12 +202,12 @@ def compare_prices_for_same_names(
     source_df: pd.DataFrame, target_df: pd.DataFrame, tagged_fields: TaggedFields
 ):
     result = Result("Compare Prices For Same Names")
-    name_field = tagged_fields.get("name_field")
-    if not name_field:
+    name_field_tag = tagged_fields.get("name_field")
+    if not name_field_tag:
         result.add_info(Outcome.SKIPPED)
         return result

-    name_field = name_field[0]
+    name_field = name_field_tag[0]
     source_df = source_df[source_df[name_field].notnull()]
     target_df = target_df[target_df[name_field].notnull()]

@@ -232,12 +234,12 @@ def compare_prices_for_same_names(
     result.add_info(f"{len(same_names)} same names in both jobs")

     price_tag = "product_price_field"
-    price_field = tagged_fields.get(price_tag)
-    if not price_field:
+    price_field_tag = tagged_fields.get(price_tag)
+    if not price_field_tag:
         result.add_info("product_price_field tag is not set")
         return result

-    price_field = price_field[0]
+    price_field = price_field_tag[0]
     count = 0

     detailed_messages = []
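
The renames throughout this file all serve one rule: mypy assigns a single type to a variable per scope, so reusing `url_field` first for the `Optional[List[str]]` lookup and then for the unwrapped `str` is flagged as an incompatible redefinition. A minimal sketch of the pattern:

    from typing import Dict, List, Optional

    def first_tagged(tagged: Dict[str, List[str]], tag: str) -> Optional[str]:
        # One name per type: the Optional list and the unwrapped str
        # each get their own variable.
        field_list: Optional[List[str]] = tagged.get(tag)
        if not field_list:
            return None
        field_name: str = field_list[0]
        return field_name
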
4 changes: 2 additions & 2 deletions src/arche/rules/result.py
@@ -40,7 +40,7 @@ class Message:
     summary: str
     detailed: Optional[str] = None
     errors: Optional[Dict[str, Set]] = None
-    _err_keys: Optional[Set[Union[str, int]]] = field(default_factory=set)
+    _err_keys: Set[Union[str, int]] = field(default_factory=set)

     @property
     def err_keys(self):
@@ -246,7 +246,7 @@ def build_stack_bar_data(values_counts: List[pd.Series]) -> List[go.Bar]:
     Returns:
         A list of Bar objects.
     """
-    data = []
+    data: List[go.Bar] = []
     for vc in values_counts:
         data = data + [
             go.Bar(
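
With `default_factory=set` the field always holds a real set, so the old `Optional[...]` annotation was inaccurate and would only force needless `None` checks downstream. A trimmed sketch:

    from dataclasses import dataclass, field
    from typing import Set, Union

    @dataclass
    class Message:  # trimmed stand-in for arche's Message
        summary: str
        # default_factory guarantees a set at construction, never None.
        _err_keys: Set[Union[str, int]] = field(default_factory=set)
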
5 changes: 4 additions & 1 deletion src/arche/tools/api.py
@@ -144,7 +144,10 @@ def get_items_with_pool(
     A numpy array of items
     """
     active_connections_limit = 10
-    processes_count = min(max(helpers.cpus_count(), workers), active_connections_limit)
+    processes_count: int = min(
+        max(helpers.cpus_count() or 0, workers), active_connections_limit
+    )
+
     batch_size = math.ceil(count / processes_count)

     start_idxs = range(start_index, start_index + count, batch_size)
15 changes: 7 additions & 8 deletions src/arche/tools/bitbucket.py
@@ -2,7 +2,7 @@
 import os
 import re
 from typing import Dict
-import urllib
+from urllib.request import Request


 NETLOC = os.getenv("BITBUCKET_NETLOC") or "bitbucket.org"
@@ -11,22 +11,21 @@
 PASS = os.getenv("BITBUCKET_PASSWORD")


-def prepare_request(url: str) -> urllib.request.Request:
+def prepare_request(url: str) -> Request:
     if not USER or not PASS:
         msg = "Credentials not found: `BITBUCKET_USER` or `BITBUCKET_PASSWORD` not set."
         raise ValueError(msg)

     api_url = convert_to_api_url(url, NETLOC, API_NETLOC)
-    return urllib.request.Request(api_url, headers=get_auth_header(USER, PASS))
+    return Request(api_url, headers=get_auth_header(USER, PASS))


 def convert_to_api_url(url: str, netloc: str, api_netloc: str) -> str:
     """Support both regular and raw URLs"""
-    try:
-        user, repo, path = re.search(
-            f"https://{netloc}/(.*?)/(.*?)/(?:raw|src)/(.*)", url
-        ).groups()
-    except AttributeError:
+    match = re.search(f"https://{netloc}/(.*?)/(.*?)/(?:raw|src)/(.*)", url)
+    if match:
+        user, repo, path = match.groups()
+    else:
         raise ValueError(f"Not a valid bitbucket URL: {url}")
     return f"https://{api_netloc}/2.0/repositories/{user}/{repo}/src/{path}"

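
`re.search` returns `Optional[Match]`, so calling `.groups()` on it directly is a type error; the old code papered over the `None` case with `except AttributeError`. Checking the match explicitly lets mypy narrow it. A standalone sketch:

    import re
    from typing import Tuple

    def split_repo_url(url: str) -> Tuple[str, str, str]:
        # Explicitly testing the Optional[Match] narrows it for mypy,
        # instead of letting .groups() raise AttributeError on None.
        match = re.search(r"https://bitbucket\.org/(.*?)/(.*?)/(?:raw|src)/(.*)", url)
        if match:
            user, repo, path = match.groups()
            return user, repo, path
        raise ValueError(f"Not a valid bitbucket URL: {url}")
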
2 changes: 1 addition & 1 deletion src/arche/tools/helpers.py
@@ -76,7 +76,7 @@ def is_number(s):
         return True


-def cpus_count() -> int:
+def cpus_count() -> Optional[int]:
     try:
         return len(os.sched_getaffinity(0))
     except AttributeError:
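
`os.sched_getaffinity` only exists on some platforms, hence the `AttributeError` path, and the usual fallback, `os.cpu_count()`, is itself `Optional[int]`, so the honest return type is `Optional[int]`; callers such as `get_items_with_pool` above then collapse the `None` with `or 0`. A sketch (the fallback body is cut off in this diff, so `os.cpu_count()` is an assumption):

    import os
    from typing import Optional

    def cpus_count() -> Optional[int]:
        try:
            return len(os.sched_getaffinity(0))  # Linux-only API
        except AttributeError:
            return os.cpu_count()  # assumption: Optional[int] fallback

    # `or 0` turns Optional[int] into int before max()/arithmetic.
    workers = min(max(cpus_count() or 0, 4), 10)
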
(Diffs for the remaining 5 changed files were not loaded.)
