Skip to content

Commit

Permalink
mypy: disallow_untyped_defs
Browse files Browse the repository at this point in the history
  • Loading branch information
tamird committed May 9, 2024
1 parent 68019d7 commit 5734e5e
Show file tree
Hide file tree
Showing 11 changed files with 103 additions and 84 deletions.
12 changes: 8 additions & 4 deletions email_validator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import TYPE_CHECKING

# Export the main method, helper methods, and the public data types.
from .exceptions_types import ValidatedEmail, EmailNotValidError, \
EmailSyntaxError, EmailUndeliverableError
Expand All @@ -9,12 +11,14 @@
"EmailSyntaxError", "EmailUndeliverableError",
"caching_resolver", "__version__"]


def caching_resolver(*args, **kwargs):
# Lazy load `deliverability` as it is slow to import (due to dns.resolver)
if TYPE_CHECKING:
from .deliverability import caching_resolver
else:
def caching_resolver(*args, **kwargs):
# Lazy load `deliverability` as it is slow to import (due to dns.resolver)
from .deliverability import caching_resolver

return caching_resolver(*args, **kwargs)
return caching_resolver(*args, **kwargs)


# These global attributes are a part of the library's API and can be
Expand Down
6 changes: 3 additions & 3 deletions email_validator/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,14 @@
import json
import os
import sys
from typing import Any, Dict
from typing import Any, Dict, Optional

from .validate_email import validate_email
from .validate_email import validate_email, _Resolver
from .deliverability import caching_resolver
from .exceptions_types import EmailNotValidError


def main(dns_resolver=None):
def main(dns_resolver: Optional[_Resolver] = None) -> None:
# The dns_resolver argument is for tests.

# Set options from environment variables.
Expand Down
4 changes: 2 additions & 2 deletions email_validator/deliverability.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,9 @@ def validate_email_deliverability(domain: str, domain_i18n: str, timeout: Option
# https://www.iana.org/assignments/iana-ipv4-special-registry/iana-ipv4-special-registry.xhtml
# https://www.iana.org/assignments/iana-ipv6-special-registry/iana-ipv6-special-registry.xhtml
# (Issue #134.)
def is_global_addr(ipaddr):
def is_global_addr(address: Any) -> bool:
try:
ipaddr = ipaddress.ip_address(ipaddr)
ipaddr = ipaddress.ip_address(address)
except ValueError:
return False
return ipaddr.is_global
Expand Down
20 changes: 10 additions & 10 deletions email_validator/exceptions_types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import warnings
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union


class EmailNotValidError(ValueError):
Expand Down Expand Up @@ -63,32 +63,32 @@ class ValidatedEmail:
mx_fallback_type: str

"""The display name in the original input text, unquoted and unescaped, or None."""
display_name: str
display_name: Optional[str]

"""Tests use this constructor."""
def __init__(self, **kwargs):
def __init__(self, **kwargs: Any) -> None:
for k, v in kwargs.items():
setattr(self, k, v)

def __repr__(self):
def __repr__(self) -> str:
return f"<ValidatedEmail {self.normalized}>"

"""For backwards compatibility, support old field names."""
def __getattr__(self, key):
def __getattr__(self, key: str) -> str:
if key == "original_email":
return self.original
if key == "email":
return self.normalized
raise AttributeError(key)

@property
def email(self):
def email(self) -> str:
warnings.warn("ValidatedEmail.email is deprecated and will be removed, use ValidatedEmail.normalized instead", DeprecationWarning)
return self.normalized

"""For backwards compatibility, some fields are also exposed through a dict-like interface. Note
that some of the names changed when they became attributes."""
def __getitem__(self, key):
def __getitem__(self, key: str) -> Union[Optional[str], bool, List[Tuple[int, str]]]:
warnings.warn("dict-like access to the return value of validate_email is deprecated and may not be supported in the future.", DeprecationWarning, stacklevel=2)
if key == "email":
return self.normalized
Expand All @@ -109,7 +109,7 @@ def __getitem__(self, key):
raise KeyError()

"""Tests use this."""
def __eq__(self, other):
def __eq__(self, other: object) -> bool:
if not isinstance(other, ValidatedEmail):
return False
return (
Expand All @@ -127,7 +127,7 @@ def __eq__(self, other):
)

"""This helps producing the README."""
def as_constructor(self):
def as_constructor(self) -> str:
return "ValidatedEmail(" \
+ ",".join(f"\n {key}={repr(getattr(self, key))}"
for key in ('normalized', 'local_part', 'domain',
Expand All @@ -139,7 +139,7 @@ def as_constructor(self):
+ ")"

"""Convenience method for accessing ValidatedEmail as a dict"""
def as_dict(self):
def as_dict(self) -> Dict[str, Any]:
d = self.__dict__
if d.get('domain_address'):
d['domain_address'] = repr(d['domain_address'])
Expand Down
34 changes: 22 additions & 12 deletions email_validator/syntax.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .exceptions_types import EmailSyntaxError
from .exceptions_types import EmailSyntaxError, ValidatedEmail
from .rfc_constants import EMAIL_MAX_LENGTH, LOCAL_PART_MAX_LENGTH, DOMAIN_MAX_LENGTH, \
DOT_ATOM_TEXT, DOT_ATOM_TEXT_INTL, ATEXT_RE, ATEXT_INTL_DOT_RE, ATEXT_HOSTNAME_INTL, QTEXT_INTL, \
DNS_LABEL_LENGTH_LIMIT, DOT_ATOM_TEXT_HOSTNAME, DOMAIN_NAME_REGEX, DOMAIN_LITERAL_CHARS
Expand All @@ -7,10 +7,10 @@
import unicodedata
import idna # implements IDNA 2008; Python's codec is only IDNA 2003
import ipaddress
from typing import Optional, TypedDict, Union
from typing import Optional, Tuple, TypedDict, Union


def split_email(email):
def split_email(email: str) -> Tuple[Optional[str], str, str, bool]:
# Return the display name, unescaped local part, and domain part
# of the address, and whether the local part was quoted. If no
# display name was present and angle brackets do not surround
Expand Down Expand Up @@ -46,7 +46,7 @@ def split_email(email):
# We assume the input string is already stripped of leading and
# trailing CFWS.

def split_string_at_unquoted_special(text, specials):
def split_string_at_unquoted_special(text: str, specials: Tuple[str, ...]) -> Tuple[str, str]:
# Split the string at the first character in specials (an @-sign
# or left angle bracket) that does not occur within quotes.
inside_quote = False
Expand Down Expand Up @@ -77,7 +77,7 @@ def split_string_at_unquoted_special(text, specials):

return left_part, right_part

def unquote_quoted_string(text):
def unquote_quoted_string(text: str) -> Tuple[str, bool]:
# Remove surrounding quotes and unescape escaped backslashes
# and quotes. Escapes are parsed liberally. I think only
# backslashes and quotes can be escaped but we'll allow anything
Expand Down Expand Up @@ -155,15 +155,15 @@ def unquote_quoted_string(text):
return display_name, local_part, domain_part, is_quoted_local_part


def get_length_reason(addr, utf8=False, limit=EMAIL_MAX_LENGTH):
def get_length_reason(addr: str, utf8: bool = False, limit: int = EMAIL_MAX_LENGTH) -> str:
"""Helper function to return an error message related to invalid length."""
diff = len(addr) - limit
prefix = "at least " if utf8 else ""
suffix = "s" if diff > 1 else ""
return f"({prefix}{diff} character{suffix} too many)"


def safe_character_display(c):
def safe_character_display(c: str) -> str:
# Return safely displayable characters in quotes.
if c == '\\':
return f"\"{c}\"" # can't use repr because it escapes it
Expand Down Expand Up @@ -351,7 +351,7 @@ def validate_email_local_part(local: str, allow_smtputf8: bool = True, allow_emp
raise EmailSyntaxError("The email address contains invalid characters before the @-sign.")


def check_unsafe_chars(s, allow_space=False):
def check_unsafe_chars(s: str, allow_space: bool = False) -> None:
# Check for unsafe characters or characters that would make the string
# invalid or non-sensible Unicode.
bad_chars = set()
Expand Down Expand Up @@ -403,7 +403,7 @@ def check_unsafe_chars(s, allow_space=False):
+ ", ".join(safe_character_display(c) for c in sorted(bad_chars)) + ".")


def check_dot_atom(label, start_descr, end_descr, is_hostname):
def check_dot_atom(label: str, start_descr: str, end_descr: str, is_hostname: bool) -> None:
# RFC 5322 3.2.3
if label.endswith("."):
raise EmailSyntaxError(end_descr.format("period"))
Expand All @@ -422,7 +422,12 @@ def check_dot_atom(label, start_descr, end_descr, is_hostname):
raise EmailSyntaxError("An email address cannot have a period and a hyphen next to each other.")


def validate_email_domain_name(domain, test_environment=False, globally_deliverable=True):
class DomainNameValidationResult(TypedDict):
ascii_domain: str
domain: str


def validate_email_domain_name(domain: str, test_environment: bool = False, globally_deliverable: bool = True) -> DomainNameValidationResult:
"""Validates the syntax of the domain part of an email address."""

# Check for invalid characters before normalization.
Expand Down Expand Up @@ -586,7 +591,7 @@ def validate_email_domain_name(domain, test_environment=False, globally_delivera
}


def validate_email_length(addrinfo):
def validate_email_length(addrinfo: ValidatedEmail) -> None:
# If the email address has an ASCII representation, then we assume it may be
# transmitted in ASCII (we can't assume SMTPUTF8 will be used on all hops to
# the destination) and the length limit applies to ASCII characters (which is
Expand Down Expand Up @@ -627,7 +632,12 @@ def validate_email_length(addrinfo):
raise EmailSyntaxError(f"The email address is too long {reason}.")


def validate_email_domain_literal(domain_literal):
class DomainLiteralValidationResult(TypedDict):
domain_address: Union[ipaddress.IPv4Address, ipaddress.IPv6Address]
domain: str


def validate_email_domain_literal(domain_literal: str) -> DomainLiteralValidationResult:
# This is obscure domain-literal syntax. Parse it and return
# a compressed/normalized address.
# RFC 5321 4.1.3 and RFC 5322 3.4.1.
Expand Down
14 changes: 7 additions & 7 deletions email_validator/validate_email.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,20 +110,20 @@ def validate_email(

elif domain_part.startswith("[") and domain_part.endswith("]"):
# Parse the address in the domain literal and get back a normalized domain.
domain_part_info = validate_email_domain_literal(domain_part[1:-1])
domain_literal_info = validate_email_domain_literal(domain_part[1:-1])
if not allow_domain_literal:
raise EmailSyntaxError("A bracketed IP address after the @-sign is not allowed here.")
ret.domain = domain_part_info["domain"]
ret.ascii_domain = domain_part_info["domain"] # Domain literals are always ASCII.
ret.domain_address = domain_part_info["domain_address"]
ret.domain = domain_literal_info["domain"]
ret.ascii_domain = domain_literal_info["domain"] # Domain literals are always ASCII.
ret.domain_address = domain_literal_info["domain_address"]
is_domain_literal = True # Prevent deliverability checks.

else:
# Check the syntax of the domain and get back a normalized
# internationalized and ASCII form.
domain_part_info = validate_email_domain_name(domain_part, test_environment=test_environment, globally_deliverable=globally_deliverable)
ret.domain = domain_part_info["domain"]
ret.ascii_domain = domain_part_info["ascii_domain"]
domain_name_info = validate_email_domain_name(domain_part, test_environment=test_environment, globally_deliverable=globally_deliverable)
ret.domain = domain_name_info["domain"]
ret.ascii_domain = domain_name_info["ascii_domain"]

# Construct the complete normalized form.
ret.normalized = ret.local_part + "@" + ret.domain
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ check_untyped_defs = true
disallow_incomplete_defs = true
# disallow_untyped_calls = true
disallow_untyped_decorators = true
# disallow_untyped_defs = true
disallow_untyped_defs = true

warn_redundant_casts = true
warn_unused_ignores = true
Expand Down
39 changes: 21 additions & 18 deletions tests/mocked_dns_response.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from typing import Any, Dict, Iterator, Optional

import dns.rdataset
import dns.resolver
import json
import os.path
Expand All @@ -23,7 +26,7 @@ class MockedDnsResponseData:
INSTANCE = None

@staticmethod
def create_resolver():
def create_resolver() -> dns.resolver.Resolver:
if MockedDnsResponseData.INSTANCE is None:
# Create a singleton instance of this class and load the saved DNS responses.
# Except when BUILD_MOCKED_DNS_RESPONSE_DATA is true, don't load the data.
Expand All @@ -37,20 +40,19 @@ def create_resolver():
dns_resolver = dns.resolver.Resolver(configure=BUILD_MOCKED_DNS_RESPONSE_DATA)
return caching_resolver(cache=MockedDnsResponseData.INSTANCE, dns_resolver=dns_resolver)

def __init__(self):
self.data = {}

def load(self):
# Loads the saved DNS response data from the JSON file and
# re-structures it into dnspython classes.
class Ans: # mocks the dns.resolver.Answer class
def __init__(self) -> None:
self.data: Dict[dns.resolver.CacheKey, Optional[MockedDnsResponseData.Ans]] = {}

def __init__(self, rrset):
self.rrset = rrset
# Loads the saved DNS response data from the JSON file and
# re-structures it into dnspython classes.
class Ans: # mocks the dns.resolver.Answer class
def __init__(self, rrset: dns.rdataset.Rdataset) -> None:
self.rrset = rrset

def __iter__(self):
return iter(self.rrset)
def __iter__(self) -> Iterator[Any]:
return iter(self.rrset)

def load(self) -> None:
with open(self.DATA_PATH) as f:
data = json.load(f)
for item in data:
Expand All @@ -62,11 +64,11 @@ def __iter__(self):
for rr in item["answer"]
]
if item["answer"]:
self.data[key] = Ans(dns.rdataset.from_rdata_list(0, rdatas=rdatas))
self.data[key] = MockedDnsResponseData.Ans(dns.rdataset.from_rdata_list(0, rdatas=rdatas))
else:
self.data[key] = None

def save(self):
def save(self) -> None:
# Re-structure as a list with basic data types.
data = [
{
Expand All @@ -81,11 +83,12 @@ def save(self):
])
}
for key, value in self.data.items()
if value is not None
]
with open(self.DATA_PATH, "w") as f:
json.dump(data, f, indent=True)

def get(self, key):
def get(self, key: dns.resolver.CacheKey) -> Optional[Ans]:
# Special-case a domain to create a timeout.
if key[0].to_text() == "timeout.com.":
raise dns.exception.Timeout()
Expand All @@ -108,16 +111,16 @@ def get(self, key):

raise ValueError(f"Saved DNS data did not contain query: {key}")

def put(self, key, value):
def put(self, key: dns.resolver.CacheKey, value: Ans) -> None:
# Build the DNS data by saving the live query response.
if not BUILD_MOCKED_DNS_RESPONSE_DATA:
raise ValueError("Should not get here.")
self.data[key] = value


@pytest.fixture(scope="session", autouse=True)
def MockedDnsResponseDataCleanup(request):
def cleanup_func():
def MockedDnsResponseDataCleanup(request: pytest.FixtureRequest) -> None:
def cleanup_func() -> None:
if BUILD_MOCKED_DNS_RESPONSE_DATA and MockedDnsResponseData.INSTANCE is not None:
MockedDnsResponseData.INSTANCE.save()
request.addfinalizer(cleanup_func)
Loading

0 comments on commit 5734e5e

Please sign in to comment.