Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Add additional type hints to HTTP client. (#8812)
Browse files Browse the repository at this point in the history
This also removes some duplicated code between the simple
HTTP client and matrix federation client.
clokep authored Nov 25, 2020

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature.
1 parent 4fd222a commit 968939b
Showing 5 changed files with 142 additions and 149 deletions.
2 changes: 1 addition & 1 deletion changelog.d/8806.misc
Original file line number Diff line number Diff line change
@@ -1 +1 @@
Add type hints to matrix federation client and agent.
Add type hints to HTTP abstractions.
1 change: 1 addition & 0 deletions changelog.d/8812.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add type hints to HTTP abstractions.
3 changes: 2 additions & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -45,6 +45,7 @@ files =
synapse/handlers/saml_handler.py,
synapse/handlers/sync.py,
synapse/handlers/ui_auth,
synapse/http/client.py,
synapse/http/federation/matrix_federation_agent.py,
synapse/http/federation/well_known_resolver.py,
synapse/http/matrixfederationclient.py,
@@ -109,7 +110,7 @@ ignore_missing_imports = True
[mypy-opentracing]
ignore_missing_imports = True

[mypy-OpenSSL]
[mypy-OpenSSL.*]
ignore_missing_imports = True

[mypy-netaddr]
211 changes: 128 additions & 83 deletions synapse/http/client.py
Original file line number Diff line number Diff line change
@@ -14,9 +14,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import urllib
import urllib.parse
from io import BytesIO
from typing import (
TYPE_CHECKING,
Any,
BinaryIO,
Dict,
@@ -31,14 +32,16 @@

import treq
from canonicaljson import encode_canonical_json
from netaddr import IPAddress
from netaddr import IPAddress, IPSet
from prometheus_client import Counter
from zope.interface import implementer, provider

from OpenSSL import SSL
from OpenSSL.SSL import VERIFY_NONE
from twisted.internet import defer, error as twisted_error, protocol, ssl
from twisted.internet.interfaces import (
IAddress,
IHostResolution,
IReactorPluggableNameResolver,
IResolutionReceiver,
)
@@ -53,7 +56,7 @@
)
from twisted.web.http import PotentialDataLoss
from twisted.web.http_headers import Headers
from twisted.web.iweb import IResponse
from twisted.web.iweb import IAgent, IBodyProducer, IResponse

from synapse.api.errors import Codes, HttpResponseException, SynapseError
from synapse.http import QuieterFileBodyProducer, RequestTimedOutError, redact_uri
@@ -63,6 +66,9 @@
from synapse.util import json_decoder
from synapse.util.async_helpers import timeout_deferred

if TYPE_CHECKING:
from synapse.app.homeserver import HomeServer

logger = logging.getLogger(__name__)

outgoing_requests_counter = Counter("synapse_http_client_requests", "", ["method"])
@@ -84,12 +90,19 @@
QueryParams = Union[Mapping[str, QueryParamValue], Mapping[bytes, QueryParamValue]]


def check_against_blacklist(ip_address, ip_whitelist, ip_blacklist):
def check_against_blacklist(
ip_address: IPAddress, ip_whitelist: Optional[IPSet], ip_blacklist: IPSet
) -> bool:
"""
Compares an IP address to allowed and disallowed IP sets.
Args:
ip_address (netaddr.IPAddress)
ip_whitelist (netaddr.IPSet)
ip_blacklist (netaddr.IPSet)
ip_address: The IP address to check
ip_whitelist: Allowed IP addresses.
ip_blacklist: Disallowed IP addresses.
Returns:
True if the IP address is in the blacklist and not in the whitelist.
"""
if ip_address in ip_blacklist:
if ip_whitelist is None or ip_address not in ip_whitelist:
@@ -118,23 +131,30 @@ class IPBlacklistingResolver:
addresses, preventing DNS rebinding attacks on URL preview.
"""

def __init__(self, reactor, ip_whitelist, ip_blacklist):
def __init__(
self,
reactor: IReactorPluggableNameResolver,
ip_whitelist: Optional[IPSet],
ip_blacklist: IPSet,
):
"""
Args:
reactor (twisted.internet.reactor)
ip_whitelist (netaddr.IPSet)
ip_blacklist (netaddr.IPSet)
reactor: The twisted reactor.
ip_whitelist: IP addresses to allow.
ip_blacklist: IP addresses to disallow.
"""
self._reactor = reactor
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist

def resolveHostName(self, recv, hostname, portNumber=0):
def resolveHostName(
self, recv: IResolutionReceiver, hostname: str, portNumber: int = 0
) -> IResolutionReceiver:

r = recv()
addresses = []
addresses = [] # type: List[IAddress]

def _callback():
def _callback() -> None:
r.resolutionBegan(None)

has_bad_ip = False
@@ -161,15 +181,15 @@ def _callback():
@provider(IResolutionReceiver)
class EndpointReceiver:
@staticmethod
def resolutionBegan(resolutionInProgress):
def resolutionBegan(resolutionInProgress: IHostResolution) -> None:
pass

@staticmethod
def addressResolved(address):
def addressResolved(address: IAddress) -> None:
addresses.append(address)

@staticmethod
def resolutionComplete():
def resolutionComplete() -> None:
_callback()

self._reactor.nameResolver.resolveHostName(
@@ -185,19 +205,29 @@ class BlacklistingAgentWrapper(Agent):
directly (without an IP address lookup).
"""

def __init__(self, agent, reactor, ip_whitelist=None, ip_blacklist=None):
def __init__(
self,
agent: IAgent,
ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None,
):
"""
Args:
agent (twisted.web.client.Agent): The Agent to wrap.
reactor (twisted.internet.reactor)
ip_whitelist (netaddr.IPSet)
ip_blacklist (netaddr.IPSet)
agent: The Agent to wrap.
ip_whitelist: IP addresses to allow.
ip_blacklist: IP addresses to disallow.
"""
self._agent = agent
self._ip_whitelist = ip_whitelist
self._ip_blacklist = ip_blacklist

def request(self, method, uri, headers=None, bodyProducer=None):
def request(
self,
method: bytes,
uri: bytes,
headers: Optional[Headers] = None,
bodyProducer: Optional[IBodyProducer] = None,
) -> defer.Deferred:
h = urllib.parse.urlparse(uri.decode("ascii"))

try:
@@ -226,23 +256,23 @@ class SimpleHttpClient:

def __init__(
self,
hs,
treq_args={},
ip_whitelist=None,
ip_blacklist=None,
http_proxy=None,
https_proxy=None,
hs: "HomeServer",
treq_args: Dict[str, Any] = {},
ip_whitelist: Optional[IPSet] = None,
ip_blacklist: Optional[IPSet] = None,
http_proxy: Optional[bytes] = None,
https_proxy: Optional[bytes] = None,
):
"""
Args:
hs (synapse.server.HomeServer)
treq_args (dict): Extra keyword arguments to be given to treq.request.
ip_blacklist (netaddr.IPSet): The IP addresses that are blacklisted that
hs
treq_args: Extra keyword arguments to be given to treq.request.
ip_blacklist: The IP addresses that are blacklisted that
we may not request.
ip_whitelist (netaddr.IPSet): The whitelisted IP addresses, that we can
ip_whitelist: The whitelisted IP addresses, that we can
request if it were otherwise caught in a blacklist.
http_proxy (bytes): proxy server to use for http connections. host[:port]
https_proxy (bytes): proxy server to use for https connections. host[:port]
http_proxy: proxy server to use for http connections. host[:port]
https_proxy: proxy server to use for https connections. host[:port]
"""
self.hs = hs

@@ -306,7 +336,6 @@ def __getattr__(_self, attr):
# by the DNS resolution.
self.agent = BlacklistingAgentWrapper(
self.agent,
self.reactor,
ip_whitelist=self._ip_whitelist,
ip_blacklist=self._ip_blacklist,
)
@@ -397,7 +426,7 @@ async def request(
async def post_urlencoded_get_json(
self,
uri: str,
args: Mapping[str, Union[str, List[str]]] = {},
args: Optional[Mapping[str, Union[str, List[str]]]] = None,
headers: Optional[RawHeaders] = None,
) -> Any:
"""
@@ -422,17 +451,15 @@ async def post_urlencoded_get_json(
# TODO: Do we ever want to log message contents?
logger.debug("post_urlencoded_get_json args: %s", args)

query_bytes = urllib.parse.urlencode(encode_urlencode_args(args), True).encode(
"utf8"
)
query_bytes = encode_query_args(args)

actual_headers = {
b"Content-Type": [b"application/x-www-form-urlencoded"],
b"User-Agent": [self.user_agent],
b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore

response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=query_bytes
@@ -479,7 +506,7 @@ async def post_json_get_json(
b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore

response = await self.request(
"POST", uri, headers=Headers(actual_headers), data=json_str
@@ -495,7 +522,10 @@ async def post_json_get_json(
)

async def get_json(
self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None,
self,
uri: str,
args: Optional[QueryParams] = None,
headers: Optional[RawHeaders] = None,
) -> Any:
"""Gets some json from the given URI.
@@ -516,7 +546,7 @@ async def get_json(
"""
actual_headers = {b"Accept": [b"application/json"]}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore

body = await self.get_raw(uri, args, headers=headers)
return json_decoder.decode(body.decode("utf-8"))
@@ -525,7 +555,7 @@ async def put_json(
self,
uri: str,
json_body: Any,
args: QueryParams = {},
args: Optional[QueryParams] = None,
headers: RawHeaders = None,
) -> Any:
"""Puts some json to the given URI.
@@ -546,9 +576,9 @@ async def put_json(
ValueError: if the response was not JSON
"""
if len(args):
query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
if args:
query_str = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_str)

json_str = encode_canonical_json(json_body)

@@ -558,7 +588,7 @@ async def put_json(
b"Accept": [b"application/json"],
}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore

response = await self.request(
"PUT", uri, headers=Headers(actual_headers), data=json_str
@@ -574,7 +604,10 @@ async def put_json(
)

async def get_raw(
self, uri: str, args: QueryParams = {}, headers: Optional[RawHeaders] = None
self,
uri: str,
args: Optional[QueryParams] = None,
headers: Optional[RawHeaders] = None,
) -> bytes:
"""Gets raw text from the given URI.
@@ -592,13 +625,13 @@ async def get_raw(
HttpResponseException on a non-2xx HTTP response.
"""
if len(args):
query_bytes = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_bytes)
if args:
query_str = urllib.parse.urlencode(args, True)
uri = "%s?%s" % (uri, query_str)

actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore

response = await self.request("GET", uri, headers=Headers(actual_headers))

@@ -641,20 +674,21 @@ async def get_file(

actual_headers = {b"User-Agent": [self.user_agent]}
if headers:
actual_headers.update(headers)
actual_headers.update(headers) # type: ignore

response = await self.request("GET", url, headers=Headers(actual_headers))

resp_headers = dict(response.headers.getAllRawHeaders())

if (
b"Content-Length" in resp_headers
and max_size
and int(resp_headers[b"Content-Length"][0]) > max_size
):
logger.warning("Requested URL is too large > %r bytes" % (self.max_size,))
logger.warning("Requested URL is too large > %r bytes" % (max_size,))
raise SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
"Requested file is too large > %r bytes" % (max_size,),
Codes.TOO_LARGE,
)

@@ -668,7 +702,7 @@ async def get_file(

try:
length = await make_deferred_yieldable(
_readBodyToFile(response, output_stream, max_size)
readBodyToFile(response, output_stream, max_size)
)
except SynapseError:
# This can happen e.g. because the body is too large.
@@ -696,18 +730,16 @@ def _timeout_to_request_timed_out_error(f: Failure):
return f


# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.


class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(self, stream, deferred, max_size):
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
self.stream = stream
self.deferred = deferred
self.length = 0
self.max_size = max_size

def dataReceived(self, data):
def dataReceived(self, data: bytes) -> None:
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
@@ -721,7 +753,7 @@ def dataReceived(self, data):
self.deferred = defer.Deferred()
self.transport.loseConnection()

def connectionLost(self, reason):
def connectionLost(self, reason: Failure) -> None:
if reason.check(ResponseDone):
self.deferred.callback(self.length)
elif reason.check(PotentialDataLoss):
@@ -732,35 +764,48 @@ def connectionLost(self, reason):
self.deferred.errback(reason)


# XXX: FIXME: This is horribly copy-pasted from matrixfederationclient.
# The two should be factored out.
def readBodyToFile(
response: IResponse, stream: BinaryIO, max_size: Optional[int]
) -> defer.Deferred:
"""
Read an HTTP response body to a file-object, optionally enforcing a maximum file size.
Args:
response: The HTTP response to read from.
stream: The file-object to write to.
max_size: The maximum file size to allow.
Returns:
A Deferred which resolves to the length of the read body.
"""

def _readBodyToFile(response, stream, max_size):
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
return d


def encode_urlencode_args(args):
return {k: encode_urlencode_arg(v) for k, v in args.items()}
def encode_query_args(args: Optional[Mapping[str, Union[str, List[str]]]]) -> bytes:
"""
Encodes a map of query arguments to bytes which can be appended to a URL.
Args:
args: The query arguments, a mapping of string to string or list of strings.
Returns:
The query arguments encoded as bytes.
"""
if args is None:
return b""

def encode_urlencode_arg(arg):
if isinstance(arg, str):
return arg.encode("utf-8")
elif isinstance(arg, list):
return [encode_urlencode_arg(i) for i in arg]
else:
return arg
encoded_args = {}
for k, vs in args.items():
if isinstance(vs, str):
vs = [vs]
encoded_args[k] = [v.encode("utf8") for v in vs]

query_str = urllib.parse.urlencode(encoded_args, True)

def _print_ex(e):
if hasattr(e, "reasons") and e.reasons:
for ex in e.reasons:
_print_ex(ex)
else:
logger.exception(e)
return query_str.encode("utf8")


class InsecureInterceptableContextFactory(ssl.ContextFactory):
74 changes: 10 additions & 64 deletions synapse/http/matrixfederationclient.py
Original file line number Diff line number Diff line change
@@ -19,7 +19,7 @@
import sys
import urllib.parse
from io import BytesIO
from typing import BinaryIO, Callable, Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union

import attr
import treq
@@ -28,26 +28,27 @@
from signedjson.sign import sign_json
from zope.interface import implementer

from twisted.internet import defer, protocol
from twisted.internet import defer
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import IReactorPluggableNameResolver, IReactorTime
from twisted.internet.task import _EPSILON, Cooperator
from twisted.python.failure import Failure
from twisted.web._newclient import ResponseDone
from twisted.web.http_headers import Headers
from twisted.web.iweb import IBodyProducer, IResponse

import synapse.metrics
import synapse.util.retryutils
from synapse.api.errors import (
Codes,
FederationDeniedError,
HttpResponseException,
RequestSendFailed,
SynapseError,
)
from synapse.http import QuieterFileBodyProducer
from synapse.http.client import BlacklistingAgentWrapper, IPBlacklistingResolver
from synapse.http.client import (
BlacklistingAgentWrapper,
IPBlacklistingResolver,
encode_query_args,
readBodyToFile,
)
from synapse.http.federation.matrix_federation_agent import MatrixFederationAgent
from synapse.logging.context import make_deferred_yieldable
from synapse.logging.opentracing import (
@@ -250,9 +251,7 @@ def __getattr__(_self, attr):
# Use a BlacklistingAgentWrapper to prevent circumventing the IP
# blacklist via IP literals in server names
self.agent = BlacklistingAgentWrapper(
self.agent,
self.reactor,
ip_blacklist=hs.config.federation_ip_range_blacklist,
self.agent, ip_blacklist=hs.config.federation_ip_range_blacklist,
)

self.clock = hs.get_clock()
@@ -986,7 +985,7 @@ async def get_file(
headers = dict(response.headers.getAllRawHeaders())

try:
d = _readBodyToFile(response, output_stream, max_size)
d = readBodyToFile(response, output_stream, max_size)
d.addTimeout(self.default_timeout, self.reactor)
length = await make_deferred_yieldable(d)
except Exception as e:
@@ -1010,44 +1009,6 @@ async def get_file(
return (length, headers)


class _ReadBodyToFileProtocol(protocol.Protocol):
def __init__(
self, stream: BinaryIO, deferred: defer.Deferred, max_size: Optional[int]
):
self.stream = stream
self.deferred = deferred
self.length = 0
self.max_size = max_size

def dataReceived(self, data: bytes) -> None:
self.stream.write(data)
self.length += len(data)
if self.max_size is not None and self.length >= self.max_size:
self.deferred.errback(
SynapseError(
502,
"Requested file is too large > %r bytes" % (self.max_size,),
Codes.TOO_LARGE,
)
)
self.deferred = defer.Deferred()
self.transport.loseConnection()

def connectionLost(self, reason: Failure) -> None:
if reason.check(ResponseDone):
self.deferred.callback(self.length)
else:
self.deferred.errback(reason)


def _readBodyToFile(
response: IResponse, stream: BinaryIO, max_size: Optional[int]
) -> defer.Deferred:
d = defer.Deferred()
response.deliverBody(_ReadBodyToFileProtocol(stream, d, max_size))
return d


def _flatten_response_never_received(e):
if hasattr(e, "reasons"):
reasons = ", ".join(
@@ -1088,18 +1049,3 @@ def check_content_type_is_json(headers: Headers) -> None:
),
can_retry=False,
)


def encode_query_args(args: Optional[QueryArgs]) -> bytes:
if args is None:
return b""

encoded_args = {}
for k, vs in args.items():
if isinstance(vs, str):
vs = [vs]
encoded_args[k] = [v.encode("utf8") for v in vs]

query_str = urllib.parse.urlencode(encoded_args, True)

return query_str.encode("utf8")

0 comments on commit 968939b

Please sign in to comment.