From bd4e1b0ffe4347c5347cad4bc0b280a166989c52 Mon Sep 17 00:00:00 2001 From: Gus Monod Date: Tue, 26 Sep 2023 13:57:24 -0400 Subject: [PATCH] wip: Fix types --- .pre-commit-config.yaml | 13 ++++--- pyproject.toml | 7 ++++ src/github3/checks.py | 4 +- src/github3/events.py | 18 +++++---- src/github3/gists/gist.py | 7 +++- src/github3/gists/history.py | 2 +- src/github3/git.py | 72 ++++++++++++++++++------------------ src/github3/github.py | 4 +- src/github3/issues/issue.py | 2 +- src/github3/models.py | 25 ++++++------- src/github3/orgs.py | 8 ++-- src/github3/pulls.py | 2 +- src/github3/py.typed | 0 src/github3/repos/branch.py | 57 ++++++++++++++++++---------- src/github3/repos/release.py | 2 +- src/github3/repos/repo.py | 2 +- src/github3/structs.py | 40 ++++++++++---------- src/github3/users.py | 2 +- 18 files changed, 151 insertions(+), 116 deletions(-) create mode 100644 src/github3/py.typed diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1e5d818b6..5fb4406a9 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,11 +27,14 @@ repos: hooks: - id: pyupgrade args: [--py37-plus] - #- repo: https://github.com/pre-commit/mirrors-mypy - # rev: v0.910 - # hooks: - # - id: mypy - # exclude: ^(docs/|tests/) + - repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.5.1 + hooks: + - id: mypy + additional_dependencies: + - types-python-dateutil + - types-requests + exclude: ^(docs/|tests/) - repo: https://github.com/jorisroovers/gitlint rev: v0.19.1 hooks: diff --git a/pyproject.toml b/pyproject.toml index e615261ab..c6cdb1835 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,3 +86,10 @@ exclude = ''' )/ ) ''' + +[tool.mypy] +files = ["."] +exclude = [ + "^docs/", + "^tests/", +] diff --git a/src/github3/checks.py b/src/github3/checks.py index d2dcfb733..f6974beb0 100644 --- a/src/github3/checks.py +++ b/src/github3/checks.py @@ -45,7 +45,7 @@ def _update_attributes(self, pull): def _repr(self): return f"" - def to_pull(self): + def to_pull(self, conditional: bool = False): """Retrieve a full PullRequest object for this CheckPullRequest. :returns: @@ -119,7 +119,7 @@ def _repr(self): self.name, str(self.owner["login"]) ) - def to_app(self): + def to_app(self, conditional: bool = False) -> models.GitHubCore: """Retrieve a full App object for this CheckApp. :returns: diff --git a/src/github3/events.py b/src/github3/events.py index f67968c63..d97d2faaf 100644 --- a/src/github3/events.py +++ b/src/github3/events.py @@ -41,7 +41,7 @@ def _update_attributes(self, user): self.login = user["login"] self._api = self.url = user["url"] - def to_user(self): + def to_user(self, conditional: bool = False) -> models.GitHubCore: """Retrieve a full User object for this EventUser. :returns: @@ -93,7 +93,7 @@ def _update_attributes(self, org): self.login = org["login"] self._api = self.url = org["url"] - def to_org(self): + def to_org(self, conditional: bool = False) -> models.GitHubCore: """Retrieve a full Organization object for this EventOrganization. :returns: @@ -148,7 +148,7 @@ def _update_attributes(self, pull): self.locked = pull["locked"] self._api = self.url = pull["url"] - def to_pull(self): + def to_pull(self, conditional: bool = False) -> models.GitHubCore: """Retrieve a full PullRequest object for this EventPullRequest. :returns: @@ -258,7 +258,9 @@ def _update_attributes(self, comment): self.updated_at = self._strptime(comment["updated_at"]) self.user = users.ShortUser(comment["user"], self) - def to_review_comment(self): + def to_review_comment( + self, conditional: bool = False + ) -> models.GitHubCore: """Retrieve a full ReviewComment object for this EventReviewComment. :returns: @@ -269,7 +271,7 @@ def to_review_comment(self): from . import pulls comment = self._json(self._get(self._api), 200) - return pulls.ReviewComment(comment, self) + return pulls.ReviewComment(comment, self.session) refresh = to_review_comment @@ -285,7 +287,7 @@ def _update_attributes(self, issue): self.locked = issue["locked"] self._api = self.url = issue["url"] - def to_issue(self): + def to_issue(self, conditional: bool = False) -> models.GitHubCore: """Retrieve a full Issue object for this EventIssue.""" from . import issues @@ -352,7 +354,9 @@ def _update_attributes(self, comment): self.updated_at = self._strptime(comment["updated_at"]) self.user = users.ShortUser(comment["user"], self) - def to_issue_comment(self): + def to_issue_comment( + self, conditional: bool = False + ) -> models.GitHubCore: """Retrieve the full IssueComment object for this comment. :returns: diff --git a/src/github3/gists/gist.py b/src/github3/gists/gist.py index 160502ad6..740277f64 100644 --- a/src/github3/gists/gist.py +++ b/src/github3/gists/gist.py @@ -1,4 +1,5 @@ """This module contains the Gist, ShortGist, and GistFork objects.""" +import typing as t from json import dumps from . import comment @@ -31,7 +32,9 @@ class _Gist(models.GitHubCore): """ class_name = "_Gist" - _file_class = gistfile.ShortGistFile + _file_class: t.Type[ + t.Union[gistfile.ShortGistFile, gistfile.GistFile] + ] = gistfile.ShortGistFile def _update_attributes(self, gist): self.comments_count = gist["comments"] @@ -265,7 +268,7 @@ def _update_attributes(self, fork): def _repr(self): return f"" - def to_gist(self): + def to_gist(self, conditional: bool = False) -> models.GitHubCore: """Retrieve the full Gist representation of this fork. :returns: diff --git a/src/github3/gists/history.py b/src/github3/gists/history.py index 02d5c132c..61a81a69c 100644 --- a/src/github3/gists/history.py +++ b/src/github3/gists/history.py @@ -51,7 +51,7 @@ class GistHistory(models.GitHubCore): def _update_attributes(self, history) -> None: self.url = self._api = history["url"] self.version = history["version"] - self.user = users.ShortUser(history["user"], self) + self.user = users.ShortUser(history["user"], self.session) self.change_status = history["change_status"] self.additions = self.change_status.get("additions") self.deletions = self.change_status.get("deletions") diff --git a/src/github3/git.py b/src/github3/git.py index fde1eae20..150ac9182 100644 --- a/src/github3/git.py +++ b/src/github3/git.py @@ -297,42 +297,6 @@ def _repr(self): return f"" -class CommitTree(models.GitHubCore): - """This object represents the abbreviated tree data in a commit. - - The API returns different representations of different objects. When - representing a :class:`~github3.git.ShortCommit` or - :class:`~github3.git.Commit`, the API returns an abbreviated - representation of a git tree. - - This object has the following attributes: - - .. attribute:: sha - - The SHA1 of this tree in the git repository. - """ - - def _update_attributes(self, tree): - self._api = tree["url"] - self.sha = tree["sha"] - - def _repr(self): - return f"" - - def to_tree(self): - """Retrieve a full Tree object for this CommitTree. - - :returns: - The full git data about this tree - :rtype: - :class:`~github3.git.Tree` - """ - json = self._json(self._get(self._api), 200) - return self._instance_or_null(Tree, json) - - refresh = to_tree - - class Tree(models.GitHubCore): """This represents a tree object from a git repository. @@ -382,6 +346,42 @@ def recurse(self): return self._instance_or_null(Tree, json) +class CommitTree(models.GitHubCore): + """This object represents the abbreviated tree data in a commit. + + The API returns different representations of different objects. When + representing a :class:`~github3.git.ShortCommit` or + :class:`~github3.git.Commit`, the API returns an abbreviated + representation of a git tree. + + This object has the following attributes: + + .. attribute:: sha + + The SHA1 of this tree in the git repository. + """ + + def _update_attributes(self, tree): + self._api = tree["url"] + self.sha = tree["sha"] + + def _repr(self): + return f"" + + def to_tree(self, conditional: bool = False) -> models.GitHubCore: + """Retrieve a full Tree object for this CommitTree. + + :returns: + The full git data about this tree + :rtype: + :class:`~github3.git.Tree` + """ + json = self._json(self._get(self._api), 200) + return self._instance_or_null(Tree, json) + + refresh = to_tree + + class Hash(models.GitHubCore): """This is used to represent the elements of a tree. diff --git a/src/github3/github.py b/src/github3/github.py index 5844bc6ca..ac38cfb8d 100644 --- a/src/github3/github.py +++ b/src/github3/github.py @@ -3,7 +3,7 @@ import re import typing as t -import uritemplate +import uritemplate # type: ignore from . import apps from . import auths @@ -544,7 +544,7 @@ def authorize( @requires_auth def blocked_users( self, number: int = -1, etag: t.Optional[str] = None - ) -> t.Generator[users.ShortUser, None, None]: + ) -> t.Iterator[users.ShortUser]: """Iterate over the users blocked by this organization. .. versionadded:: 2.1.0 diff --git a/src/github3/issues/issue.py b/src/github3/issues/issue.py index 41b7cd6fe..ef96d9f39 100644 --- a/src/github3/issues/issue.py +++ b/src/github3/issues/issue.py @@ -1,7 +1,7 @@ """Module containing the Issue logic.""" from json import dumps -from uritemplate import URITemplate +from uritemplate import URITemplate # type: ignore from . import comment from . import event diff --git a/src/github3/models.py b/src/github3/models.py index ad4311cfa..fc4635eea 100644 --- a/src/github3/models.py +++ b/src/github3/models.py @@ -16,9 +16,6 @@ LOG = logging.getLogger(__package__) -T = t.TypeVar("T") - - class GitHubCore: """The base object for all objects that require a session. @@ -28,7 +25,7 @@ class GitHubCore: """ _ratelimit_resource = "core" - _refresh_to: t.Optional["GitHubCore"] = None + _refresh_to: t.Optional[t.Type["GitHubCore"]] = None def __init__(self, json, session: session.GitHubSession): """Initialize our basic object. @@ -240,26 +237,28 @@ def _api(self): value += f"?{self._uri.query}" return value - @staticmethod - def _uri_parse(uri): - return requests.compat.urlparse(uri) - @_api.setter def _api(self, uri): if uri: self._uri = self._uri_parse(uri) self.url = uri + @staticmethod + def _uri_parse(uri): + return requests.compat.urlparse(uri) + def _iter( self, count: int, url: str, - cls: t.Type[T], - params: t.Optional[t.Mapping[str, t.Optional[str]]] = None, + cls: t.Type["GitHubCore"], + params: t.Optional[ + t.MutableMapping[str, t.Union[str, int, None]] + ] = None, etag: t.Optional[str] = None, headers: t.Optional[t.Mapping[str, str]] = None, list_key: t.Optional[str] = None, - ) -> "structs.GitHubIterator[T]": + ) -> "structs.GitHubIterator": """Generic iterator for this project. :param int count: How many items to return. @@ -276,7 +275,7 @@ def _iter( from .structs import GitHubIterator return GitHubIterator( - count, url, cls, self, params, etag, headers, list_key + count, url, cls, self.session, params, etag, headers, list_key ) @property @@ -329,7 +328,7 @@ def refresh(self, conditional: bool = False) -> "GitHubCore": self._json_data = json self._update_attributes(json) else: - return self._refresh_to(json, self) + return self._refresh_to(json, self.session) return self def new_session(self): diff --git a/src/github3/orgs.py b/src/github3/orgs.py index 363fe357d..96fcc8390 100644 --- a/src/github3/orgs.py +++ b/src/github3/orgs.py @@ -2,7 +2,7 @@ import typing as t from json import dumps -from uritemplate import URITemplate +from uritemplate import URITemplate # type: ignore from . import models from . import users @@ -211,7 +211,7 @@ def permissions_for( headers = {"Accept": "application/vnd.github.v3.repository+json"} url = self._build_url("repos", repository, base_url=self._api) json = self._json(self._get(url, headers=headers), 200) - return ShortRepositoryWithPermissions(json, self) + return ShortRepositoryWithPermissions(json, self.session) @requires_auth def repositories(self, number=-1, etag=None): @@ -491,7 +491,7 @@ def add_repository(self, repository, team_id): # FIXME(jlk): add perms @requires_auth def blocked_users( self, number: int = -1, etag: t.Optional[str] = None - ) -> t.Generator[users.ShortUser, None, None]: + ) -> t.Iterator[users.ShortUser]: """Iterate over the users blocked by this organization. .. versionadded:: 2.1.0 @@ -749,7 +749,7 @@ def create_team( getattr(r, "full_name", r) for r in (repo_names or []) ], "maintainers": [ - getattr(m, "login", m) for m in (maintainers or []) + str(getattr(m, "login", m)) for m in (maintainers or []) ], "permission": permission, "privacy": privacy, diff --git a/src/github3/pulls.py b/src/github3/pulls.py index 28b335a80..8c6f91d54 100644 --- a/src/github3/pulls.py +++ b/src/github3/pulls.py @@ -1,7 +1,7 @@ """This module contains all the classes relating to pull requests.""" from json import dumps -from uritemplate import URITemplate +from uritemplate import URITemplate # type: ignore from . import models from . import users diff --git a/src/github3/py.typed b/src/github3/py.typed new file mode 100644 index 000000000..e69de29bb diff --git a/src/github3/repos/branch.py b/src/github3/repos/branch.py index 2dcf7a162..e0668fc78 100644 --- a/src/github3/repos/branch.py +++ b/src/github3/repos/branch.py @@ -8,7 +8,7 @@ if t.TYPE_CHECKING: from .. import apps as tapps from .. import users as tusers - from . import orgs + from .. import orgs class _Branch(models.GitHubCore): @@ -73,7 +73,7 @@ def protection(self) -> "BranchProtection": url = self._build_url("protection", base_url=self._api) resp = self._get(url) json = self._json(resp, 200) - return BranchProtection(json, self) + return BranchProtection(json, self.session) @decorators.requires_auth def protect( @@ -159,7 +159,7 @@ def protect( url = self._build_url("protection", base_url=self._api) resp = self._put(url, json=edit) json = self._json(resp, 200) - return BranchProtection(json, self) + return BranchProtection(json, self.session) @decorators.requires_auth def sync_with_upstream(self) -> t.Mapping[str, str]: @@ -277,8 +277,9 @@ class ShortBranch(_Branch): class_name = "Short Repository Branch" _refresh_to = Branch - @t.overload - def refresh(self, conditional: bool = False) -> Branch: # noqa: D102 + def refresh( # type: ignore[empty-body] + self, conditional: bool = False + ) -> models.GitHubCore: # noqa: D102 ... @@ -652,7 +653,11 @@ def add_teams( resp = self._post(self.teams_url, data=teams) json = self._json(resp, 200) - return [orgs.ShortTeam(team, self) for team in json] if json else [] + return ( + [orgs.ShortTeam(team, self.session) for team in json] + if json + else [] + ) @decorators.requires_auth def add_users( @@ -679,10 +684,14 @@ def add_users( from .. import users as _users json = self._json(self._post(self.users_url, data=users), 200) - return [_users.ShortUser(user, self) for user in json] if json else [] + return ( + [_users.ShortUser(user, self.session) for user in json] + if json + else [] + ) @decorators.requires_auth - def apps(self, number: int = -1) -> t.Generator["tapps.App", None, None]: + def apps(self, number: int = -1) -> t.Iterator["tapps.App"]: """Retrieve current list of apps with access to the protected branch. See @@ -732,7 +741,7 @@ def add_app_restrictions( apps = [getattr(a, "slug", a) for a in apps] json = self._json(self._post(self.apps_url, data=apps), 200) - return [_apps.App(a, self) for a in json] + return [_apps.App(a, self.session) for a in json] @decorators.requires_auth def replace_app_restrictions( @@ -764,7 +773,7 @@ def replace_app_restrictions( apps = [getattr(a, "slug", a) for a in apps] json = self._json(self._put(self.apps_url, data=apps), 200) - return [_apps.App(a, self) for a in json] + return [_apps.App(a, self.session) for a in json] @decorators.requires_auth def remove_app_restrictions( @@ -788,7 +797,7 @@ def remove_app_restrictions( apps = [getattr(a, "slug", a) for a in apps] json = self._json(self._delete(self.apps_url, data=apps), 200) - return [_apps.App(a, self) for a in json] + return [_apps.App(a, self.session) for a in json] @decorators.requires_auth def delete(self) -> bool: @@ -826,7 +835,11 @@ def remove_teams( resp = self._delete(self.teams_url, json=teams) json = self._json(resp, 200) - return [orgs.ShortTeam(team, self) for team in json] if json else [] + return ( + [orgs.ShortTeam(team, self.session) for team in json] + if json + else [] + ) @decorators.requires_auth def remove_users( @@ -849,7 +862,11 @@ def remove_users( json = self._json(resp, 200) from .. import users as _users - return [_users.ShortUser(user, self) for user in json] if json else [] + return ( + [_users.ShortUser(user, self.session) for user in json] + if json + else [] + ) @decorators.requires_auth def replace_teams( @@ -872,7 +889,11 @@ def replace_teams( resp = self._put(self.teams_url, json=teams) json = self._json(resp, 200) - return [orgs.ShortTeam(team, self) for team in json] if json else [] + return ( + [orgs.ShortTeam(team, self.session) for team in json] + if json + else [] + ) @decorators.requires_auth def replace_users( @@ -894,9 +915,7 @@ def replace_users( users_resp = self._put(self.users_url, json=users) return self._boolean(users_resp, 200, 404) - def teams( - self, number: int = -1 - ) -> t.Generator["orgs.ShortTeam", None, None]: + def teams(self, number: int = -1) -> t.Iterator["orgs.ShortTeam"]: """Retrieve an up-to-date listing of teams. :returns: @@ -912,9 +931,7 @@ def teams( orgs.ShortTeam, ) - def users( - self, number: int = -1 - ) -> t.Generator["tusers.ShortUser", None, None]: + def users(self, number: int = -1) -> t.Iterator["tusers.ShortUser"]: """Retrieve an up-to-date listing of users. :returns: diff --git a/src/github3/repos/release.py b/src/github3/repos/release.py index 4901b7111..c731052c7 100644 --- a/src/github3/repos/release.py +++ b/src/github3/repos/release.py @@ -1,7 +1,7 @@ """Release logic for the GitHub API.""" import json -from uritemplate import URITemplate +from uritemplate import URITemplate # type: ignore from .. import models from .. import users diff --git a/src/github3/repos/repo.py b/src/github3/repos/repo.py index 093030740..12fa01cbd 100644 --- a/src/github3/repos/repo.py +++ b/src/github3/repos/repo.py @@ -7,7 +7,7 @@ import base64 import json as jsonlib -import uritemplate as urit +import uritemplate as urit # type: ignore from . import branch from . import comment diff --git a/src/github3/structs.py b/src/github3/structs.py index 66e6ce4dc..362844d79 100644 --- a/src/github3/structs.py +++ b/src/github3/structs.py @@ -10,13 +10,11 @@ if t.TYPE_CHECKING: import requests.models + from typing_extensions import Final from . import session -T = t.TypeVar("T") - - class GitHubIterator(models.GitHubCore, collections.abc.Iterator): """The :class:`GitHubIterator` class powers all of the iter_* methods.""" @@ -24,16 +22,18 @@ def __init__( self, count: int, url: str, - cls: t.Type[T], + cls: t.Type[models.GitHubCore], session: "session.GitHubSession", - params: t.Optional[t.Mapping[str, t.Optional[str]]] = None, + params: t.Optional[ + t.MutableMapping[str, t.Union[str, int, None]] + ] = None, etag: t.Optional[str] = None, headers: t.Optional[t.Mapping[str, str]] = None, list_key: t.Optional[str] = None, ) -> None: models.GitHubCore.__init__(self, {}, session) #: Original number of items requested - self.original: t.Final[int] = count + self.original: Final[int] = count #: Number of items left in the iterator self.count: int = count #: URL the class used to make it's first GET @@ -42,9 +42,11 @@ def __init__( self.last_url: t.Optional[str] = None self._api: str = self.url #: Class for constructing an item to return - self.cls: t.Type[T] = cls + self.cls: t.Type[models.GitHubCore] = cls #: Parameters of the query string - self.params: t.Mapping[str, t.Optional[str]] = params or {} + self.params: t.MutableMapping[str, t.Optional[t.Union[str, int]]] = ( + params or {} + ) self._remove_none(self.params) # We do not set this from the parameter sent. We want this to # represent the ETag header returned by GitHub no matter what. @@ -55,11 +57,11 @@ def __init__( #: Headers generated for the GET request self.headers: t.Dict[str, str] = dict(headers or {}) #: The last response seen - self.last_response: "requests.models.Response" = None + self.last_response: t.Optional["requests.models.Response"] = None #: Last status code received self.last_status: int = 0 #: Key to get the list of items in case a dict is returned - self.list_key: t.Final[t.Optional[str]] = list_key + self.list_key: Final[t.Optional[str]] = list_key if etag: self.headers.update({"If-None-Match": etag}) @@ -69,7 +71,7 @@ def __init__( def _repr(self) -> str: return f"" - def __iter__(self) -> t.Generator[T, None, None]: + def __iter__(self) -> t.Iterator[models.GitHubCore]: self.last_url, params = self.url, self.params headers = self.headers @@ -79,9 +81,7 @@ def __iter__(self) -> t.Generator[T, None, None]: if "per_page" not in params and self.count == -1: params["per_page"] = 100 - cls = self.cls - if issubclass(self.cls, models.GitHubCore): - cls = functools.partial(self.cls, session=self) + cls = functools.partial(self.cls, session=self.session) while (self.count == -1 or self.count > 0) and self.last_url: response = self._get( @@ -90,7 +90,7 @@ def __iter__(self) -> t.Generator[T, None, None]: self.last_response = response self.last_status = response.status_code if params: - params = None # rel_next already has the params + params = {} # rel_next already has the params if not self.etag and response.headers.get("ETag"): self.etag = response.headers.get("ETag") @@ -136,7 +136,7 @@ def __iter__(self) -> t.Generator[T, None, None]: rel_next = response.links.get("next", {}) self.last_url = rel_next.get("url", "") - def __next__(self) -> T: + def __next__(self) -> models.GitHubCore: if not hasattr(self, "__i__"): self.__i__ = self.__iter__() return next(self.__i__) @@ -152,7 +152,7 @@ def refresh(self, conditional: bool = False) -> "GitHubIterator": self.__i__ = self.__iter__() return self - def next(self) -> T: + def next(self) -> models.GitHubCore: return self.__next__() @@ -172,9 +172,11 @@ def __init__( self, count: int, url: str, - cls: t.Type[T], + cls: t.Type[models.GitHubCore], session: "session.GitHubSession", - params: t.Optional[t.Mapping[str, t.Optional[str]]] = None, + params: t.Optional[ + t.MutableMapping[str, t.Union[int, str, None]] + ] = None, etag: t.Optional[str] = None, headers: t.Optional[t.Mapping[str, str]] = None, ): diff --git a/src/github3/users.py b/src/github3/users.py index a2a55644e..dc47fcb6f 100644 --- a/src/github3/users.py +++ b/src/github3/users.py @@ -2,7 +2,7 @@ import typing as t from json import dumps -from uritemplate import URITemplate +from uritemplate import URITemplate # type: ignore from . import models from .decorators import requires_auth