diff --git a/docs/conf.py b/docs/conf.py
index 1be4d0cbe5..7e56b3502a 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -256,6 +256,7 @@
     (_any_role, 'HttpProtocolHandler'),
     (_any_role, 'multiprocessing.Manager'),
     (_any_role, 'work_klass'),
+    (_any_role, 'proxy.core.base.tcp_upstream.TcpUpstreamConnectionHandler'),
     (_py_class_role, 'CacheStore'),
     (_py_class_role, 'HttpParser'),
     (_py_class_role, 'HttpProtocolHandlerPlugin'),
diff --git a/proxy/core/base/__init__.py b/proxy/core/base/__init__.py
index 721d83e2e1..8a307776d0 100644
--- a/proxy/core/base/__init__.py
+++ b/proxy/core/base/__init__.py
@@ -14,8 +14,10 @@
 """
 from .tcp_server import BaseTcpServerHandler
 from .tcp_tunnel import BaseTcpTunnelHandler
+from .tcp_upstream import TcpUpstreamConnectionHandler
 
 __all__ = [
     'BaseTcpServerHandler',
     'BaseTcpTunnelHandler',
+    'TcpUpstreamConnectionHandler',
 ]
diff --git a/proxy/core/base/tcp_server.py b/proxy/core/base/tcp_server.py
index fcab812341..4db61463cd 100644
--- a/proxy/core/base/tcp_server.py
+++ b/proxy/core/base/tcp_server.py
@@ -45,7 +45,6 @@ class BaseTcpServerHandler(Work):
     a. handle_data(data: memoryview) implementation
     b. Optionally, also implement other Work method
        e.g. initialize, is_inactive, shutdown
-
     """
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
diff --git a/proxy/core/base/tcp_upstream.py b/proxy/core/base/tcp_upstream.py
new file mode 100644
index 0000000000..3f94edc2fe
--- /dev/null
+++ b/proxy/core/base/tcp_upstream.py
@@ -0,0 +1,102 @@
+# -*- coding: utf-8 -*-
+"""
+    proxy.py
+    ~~~~~~~~
+    ⚡⚡⚡ Fast, Lightweight, Pluggable, TLS interception capable proxy server focused on
+    Network monitoring, controls & Application development, testing, debugging.
+
+    :copyright: (c) 2013-present by Abhinav Singh and contributors.
+    :license: BSD, see LICENSE for more details.
+"""
+from abc import ABC, abstractmethod
+
+import ssl
+import socket
+import logging
+
+from typing import Tuple, List, Optional, Any
+
+from ...common.types import Readables, Writables
+from ...core.connection import TcpServerConnection
+
+logger = logging.getLogger(__name__)
+
+
+class TcpUpstreamConnectionHandler(ABC):
+    """:class:`~proxy.core.base.TcpUpstreamConnectionHandler` can
+    be used to insert an upstream server connection lifecycle within
+    asynchronous proxy.py lifecycle.
+
+    Call `initialize_upstream` to initialize the upstream connection object.
+    Then, directly use ``self.upstream`` object within your class.
+
+    .. spelling::
+
+       tcp
+    """
+
+    def __init__(self, *args: Any, **kwargs: Any) -> None:
+        # This is currently a hack, see comments below for rationale,
+        # will be fixed later.
+        super().__init__(*args, **kwargs)   # type: ignore
+        self.upstream: Optional[TcpServerConnection] = None
+        # TODO: Currently, :class:`~proxy.core.base.TcpUpstreamConnectionHandler`
+        # is used within :class:`~proxy.plugin.ReverseProxyPlugin` and
+        # :class:`~proxy.plugin.ProxyPoolPlugin`.
+        #
+        # For both of which we expect a 4-tuple as arguments
+        # containing (uuid, flags, client, event_queue).
+        # We really don't need the rest of the args here.
+        # Maybe uuid? Maybe event_queue in the future.
+        # But we certainly don't need client here.
+        # A separate tunnel class must be created which handles
+        # client connection too.
+        #
+        # Both :class:`~proxy.plugin.ReverseProxyPlugin` and
+        # :class:`~proxy.plugin.ProxyPoolPlugin` are currently
+        # calling client queue within `handle_upstream_data` callback.
+        #
+        # This can be abstracted out too.
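
A minimal sketch of how a plugin is expected to consume ``TcpUpstreamConnectionHandler``,
mirroring what ``ReverseProxyPlugin`` does further down in this diff. The plugin name,
route and ``localhost:8899`` endpoint are illustrative only, and TLS wrapping plus error
handling are omitted::

    from typing import List, Tuple

    from proxy.core.base import TcpUpstreamConnectionHandler
    from proxy.http.parser import HttpParser
    from proxy.http.server import HttpWebServerBasePlugin, httpProtocolTypes


    class UpstreamEchoPlugin(TcpUpstreamConnectionHandler, HttpWebServerBasePlugin):
        """Hypothetical web server plugin relaying requests to an upstream server."""

        def handle_upstream_data(self, raw: memoryview) -> None:
            # Bytes received from upstream are queued back to the client as-is.
            self.client.queue(raw)

        def routes(self) -> List[Tuple[int, str]]:
            return [(httpProtocolTypes.HTTP, r'/upstream/$')]

        def handle_request(self, request: HttpParser) -> None:
            # initialize_upstream() only constructs the TcpServerConnection;
            # connect() and queue() drive the actual exchange, while the mixin's
            # get_descriptors / read_from_descriptors / write_to_descriptors
            # hooks plug it into proxy.py's event loop.
            self.initialize_upstream('localhost', 8899)
            assert self.upstream
            self.upstream.connect()
            self.upstream.queue(memoryview(request.build()))

Because the descriptor hooks come from the mixin, a subclass only supplies
``handle_upstream_data`` and decides when to call ``initialize_upstream``.
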
+        self.server_recvbuf_size = args[1].server_recvbuf_size
+        self.total_size = 0
+
+    @abstractmethod
+    def handle_upstream_data(self, raw: memoryview) -> None:
+        pass
+
+    def initialize_upstream(self, addr: str, port: int) -> None:
+        self.upstream = TcpServerConnection(addr, port)
+
+    def get_descriptors(self) -> Tuple[List[socket.socket], List[socket.socket]]:
+        if not self.upstream:
+            return [], []
+        return [self.upstream.connection], [self.upstream.connection] if self.upstream.has_buffer() else []
+
+    def read_from_descriptors(self, r: Readables) -> bool:
+        if self.upstream and self.upstream.connection in r:
+            try:
+                raw = self.upstream.recv(self.server_recvbuf_size)
+                if raw is not None:
+                    self.total_size += len(raw)
+                    self.handle_upstream_data(raw)
+                else:
+                    return True     # Teardown because upstream proxy closed the connection
+            except ssl.SSLWantReadError:
+                logger.info('Upstream SSLWantReadError, will retry')
+                return False
+            except ConnectionResetError:
+                logger.debug('Connection reset by upstream')
+                return True
+        return False
+
+    def write_to_descriptors(self, w: Writables) -> bool:
+        if self.upstream and self.upstream.connection in w and self.upstream.has_buffer():
+            try:
+                self.upstream.flush()
+            except ssl.SSLWantWriteError:
+                logger.info('Upstream SSLWantWriteError, will retry')
+                return False
+            except BrokenPipeError:
+                logger.debug('BrokenPipeError when flushing to upstream')
+                return True
+        return False
diff --git a/proxy/http/server/pac_plugin.py b/proxy/http/server/pac_plugin.py
index 581f185a4b..68aad02cc8 100644
--- a/proxy/http/server/pac_plugin.py
+++ b/proxy/http/server/pac_plugin.py
@@ -20,7 +20,6 @@
 
 from .plugin import HttpWebServerBasePlugin
 from .protocols import httpProtocolTypes
-from ..websocket import WebsocketFrame
 from ..parser import HttpParser
 
 from ...common.utils import bytes_, text_, build_http_response
@@ -64,15 +63,6 @@ def handle_request(self, request: HttpParser) -> None:
         if self.flags.pac_file and self.pac_file_response:
             self.client.queue(self.pac_file_response)
 
-    def on_websocket_open(self) -> None:
-        pass    # pragma: no cover
-
-    def on_websocket_message(self, frame: WebsocketFrame) -> None:
-        pass    # pragma: no cover
-
-    def on_client_connection_close(self) -> None:
-        pass    # pragma: no cover
-
     def cache_pac_file_response(self) -> None:
         if self.flags.pac_file:
             try:
diff --git a/proxy/http/server/plugin.py b/proxy/http/server/plugin.py
index 55e66f39c2..11c2e3ec76 100644
--- a/proxy/http/server/plugin.py
+++ b/proxy/http/server/plugin.py
@@ -95,15 +95,19 @@ def on_client_connection_close(self) -> None:
         """Client has closed the connection, do any clean up task now."""
         pass
 
-    @abstractmethod
+    # No longer abstract since v2.4.0
+    #
+    # @abstractmethod
     def on_websocket_open(self) -> None:
         """Called when websocket handshake has finished."""
-        raise NotImplementedError()     # pragma: no cover
+        pass    # pragma: no cover
 
-    @abstractmethod
+    # No longer abstract since v2.4.0
+    #
+    # @abstractmethod
     def on_websocket_message(self, frame: WebsocketFrame) -> None:
         """Handle websocket frame."""
-        raise NotImplementedError()     # pragma: no cover
+        return None     # pragma: no cover
 
     # Deprecated since v2.4.0
     #
diff --git a/proxy/http/server/web.py b/proxy/http/server/web.py
index 9c44f836c3..513dbf5502 100644
--- a/proxy/http/server/web.py
+++ b/proxy/http/server/web.py
@@ -327,7 +327,5 @@ def on_client_connection_close(self) -> None:
         if not log_handled:
             self.access_log(context)
 
-    # TODO: Allow plugins to customize access_log, similar
-    # to how proxy server plugins are able to do it.
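
Since ``on_websocket_open`` and ``on_websocket_message`` are no longer abstract, web
server plugins only override the callbacks they actually need, which is why
``pac_plugin.py`` can drop its empty stubs above. A hypothetical sketch; the plugin
name and route are made up, and ``httpProtocolTypes.WEBSOCKET`` is assumed from the
existing web server API::

    from typing import List, Tuple

    from proxy.http.parser import HttpParser
    from proxy.http.server import HttpWebServerBasePlugin, httpProtocolTypes
    from proxy.http.websocket import WebsocketFrame


    class FramesOnlyPlugin(HttpWebServerBasePlugin):
        """Hypothetical plugin that only reacts to websocket frames."""

        def routes(self) -> List[Tuple[int, str]]:
            return [(httpProtocolTypes.WEBSOCKET, r'/ws$')]

        def handle_request(self, request: HttpParser) -> None:
            pass    # routes() only registers a websocket path

        def on_websocket_message(self, frame: WebsocketFrame) -> None:
            # on_websocket_open / on_client_connection_close are now optional
            # and are intentionally not stubbed out here.
            pass
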
     def access_log(self, context: Dict[str, Any]) -> None:
         logger.info(DEFAULT_WEB_ACCESS_LOG_FORMAT.format_map(context))
diff --git a/proxy/http/url.py b/proxy/http/url.py
index 5c7eb7c7b9..2d50743a71 100644
--- a/proxy/http/url.py
+++ b/proxy/http/url.py
@@ -16,6 +16,7 @@
 from typing import Optional, Tuple
 
 from ..common.constants import COLON, SLASH
+from ..common.utils import text_
 
 
 class Url:
@@ -36,6 +37,18 @@ def __init__(
         self.port: Optional[int] = port
         self.remainder: Optional[bytes] = remainder
 
+    def __str__(self) -> str:
+        url = ''
+        if self.scheme:
+            url += '{0}://'.format(text_(self.scheme))
+        if self.hostname:
+            url += text_(self.hostname)
+        if self.port:
+            url += ':{0}'.format(self.port)
+        if self.remainder:
+            url += text_(self.remainder)
+        return url
+
     @classmethod
     def from_bytes(cls, raw: bytes) -> 'Url':
         """A URL within proxy.py core can have several styles,
@@ -57,7 +70,9 @@ def from_bytes(cls, raw: bytes) -> 'Url':
             return cls(remainder=raw)
         if sraw.startswith('https://') or sraw.startswith('http://'):
             is_https = sraw.startswith('https://')
-            rest = raw[len(b'https://'):] if is_https else raw[len(b'http://'):]
+            rest = raw[len(b'https://'):] \
+                if is_https \
+                else raw[len(b'http://'):]
             parts = rest.split(SLASH)
             host, port = Url.parse_host_and_port(parts[0])
             return cls(
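
With ``__str__`` in place, a ``Url`` parsed by ``from_bytes()`` can be turned back into
the original string, which is what lets ``ReverseProxyPlugin`` log its chosen
``proxy_pass`` target later in this diff. A quick illustration; the URL is arbitrary::

    from proxy.http import Url

    url = Url.from_bytes(b'https://example.com:8443/get?a=b')
    assert url.port == 8443
    assert str(url) == 'https://example.com:8443/get?a=b'
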
diff --git a/proxy/plugin/proxy_pool.py b/proxy/plugin/proxy_pool.py
index 3751fd37ad..02278671af 100644
--- a/proxy/plugin/proxy_pool.py
+++ b/proxy/plugin/proxy_pool.py
@@ -9,20 +9,18 @@
     :license: BSD, see LICENSE for more details.
 """
 import random
-import socket
 import logging
 
-from typing import Dict, List, Optional, Any, Tuple
+from typing import Dict, List, Optional, Any
 
 from ..common.flag import flags
-from ..common.types import Readables, Writables
 
 from ..http import Url, httpMethods
 from ..http.parser import HttpParser
 from ..http.exception import HttpProtocolException
 from ..http.proxy import HttpProxyBasePlugin
 
-from ..core.connection.server import TcpServerConnection
+from ..core.base import TcpUpstreamConnectionHandler
 
 logger = logging.getLogger(__name__)
 
@@ -54,7 +52,7 @@
 )
 
 
-class ProxyPoolPlugin(HttpProxyBasePlugin):
+class ProxyPoolPlugin(TcpUpstreamConnectionHandler, HttpProxyBasePlugin):
     """Proxy pool plugin simply acts as a proxy adapter for proxy.py itself.
 
     Imagine this plugin as setting up proxy settings for proxy.py instance itself.
@@ -62,42 +60,13 @@ class ProxyPoolPlugin(HttpProxyBasePlugin):
 
     def __init__(self, *args: Any, **kwargs: Any) -> None:
         super().__init__(*args, **kwargs)
-        self.upstream: Optional[TcpServerConnection] = None
         # Cached attributes to be used during access log override
         self.request_host_port_path_method: List[Any] = [
             None, None, None, None,
         ]
-        self.total_size = 0
-
-    def get_descriptors(self) -> Tuple[List[socket.socket], List[socket.socket]]:
-        if not self.upstream:
-            return [], []
-        return [self.upstream.connection], [self.upstream.connection] if self.upstream.has_buffer() else []
-
-    def read_from_descriptors(self, r: Readables) -> bool:
-        # Read from upstream proxy and queue for client
-        if self.upstream and self.upstream.connection in r:
-            try:
-                raw = self.upstream.recv(self.flags.server_recvbuf_size)
-                if raw is not None:
-                    self.total_size += len(raw)
-                    self.client.queue(raw)
-                else:
-                    return True     # Teardown because upstream proxy closed the connection
-            except ConnectionResetError:
-                logger.debug('Connection reset by upstream proxy')
-                return True
-        return False    # Do not teardown connection
-
-    def write_to_descriptors(self, w: Writables) -> bool:
-        # Flush queued data to upstream proxy now
-        if self.upstream and self.upstream.connection in w and self.upstream.has_buffer():
-            try:
-                self.upstream.flush()
-            except BrokenPipeError:
-                logger.debug('BrokenPipeError when flushing to upstream proxy')
-                return True
-        return False
+
+    def handle_upstream_data(self, raw: memoryview) -> None:
+        self.client.queue(raw)
 
     def before_upstream_connection(
             self, request: HttpParser,
@@ -109,12 +78,14 @@ def before_upstream_connection(
         # must be bootstrapped within it's own re-usable and gc'd pool, to avoid establishing
         # a fresh upstream proxy connection for each client request.
         #
+        # See :class:`~proxy.core.connection.pool.ConnectionPool` which is a work
+        # in progress for SSL cache handling.
+        #
         # Implement your own logic here e.g. round-robin, least connection etc.
         endpoint = random.choice(self.flags.proxy_pool)[0].split(':')
         logger.debug('Using endpoint: {0}:{1}'.format(*endpoint))
-        self.upstream = TcpServerConnection(
-            endpoint[0], int(endpoint[1]),
-        )
+        self.initialize_upstream(endpoint[0], int(endpoint[1]))
+        assert self.upstream
         try:
             self.upstream.connect()
         except ConnectionRefusedError:
diff --git a/proxy/plugin/reverse_proxy.py b/proxy/plugin/reverse_proxy.py
index 23afd986a1..8eb3b91aaf 100644
--- a/proxy/plugin/reverse_proxy.py
+++ b/proxy/plugin/reverse_proxy.py
@@ -8,30 +8,25 @@
     :copyright: (c) 2013-present by Abhinav Singh and contributors.
     :license: BSD, see LICENSE for more details.
 """
-import ssl
 import random
-import socket
 import logging
 
-from typing import List, Optional, Tuple, Any
-from urllib import parse as urlparse
+from typing import List, Tuple, Any, Dict, Optional
 
 from ..common.utils import text_
 from ..common.constants import DEFAULT_HTTPS_PORT, DEFAULT_HTTP_PORT
-from ..common.types import Readables, Writables
 
-from ..core.connection import TcpServerConnection
+
+from ..http import Url
 from ..http.exception import HttpProtocolException
 from ..http.parser import HttpParser
-from ..http.websocket import WebsocketFrame
 from ..http.server import HttpWebServerBasePlugin, httpProtocolTypes
 
+from ..core.base import TcpUpstreamConnectionHandler
+
 logger = logging.getLogger(__name__)
 
-# TODO: ReverseProxyPlugin and ProxyPoolPlugin are implementing
-# a similar behavior. Abstract that particular logic out into its
-# own class.
-class ReverseProxyPlugin(HttpWebServerBasePlugin):
+class ReverseProxyPlugin(TcpUpstreamConnectionHandler, HttpWebServerBasePlugin):
     """Extend in-built Web Server to add Reverse Proxy capabilities.
 
     This example plugin is equivalent to following Nginx configuration::
@@ -74,7 +69,11 @@ class ReverseProxyPlugin(HttpWebServerBasePlugin):
 
     def __init__(self, *args: Any, **kwargs: Any):
         super().__init__(*args, **kwargs)
-        self.upstream: Optional[TcpServerConnection] = None
+        # Chosen upstream proxy_pass url
+        self.choice: Optional[Url] = None
+
+    def handle_upstream_data(self, raw: memoryview) -> None:
+        self.client.queue(raw)
 
     def routes(self) -> List[Tuple[int, str]]:
         return [
@@ -82,76 +81,45 @@ def routes(self) -> List[Tuple[int, str]]:
             (httpProtocolTypes.HTTPS, ReverseProxyPlugin.REVERSE_PROXY_LOCATION),
         ]
 
-    def get_descriptors(self) -> Tuple[List[socket.socket], List[socket.socket]]:
-        if not self.upstream:
-            return [], []
-        return [self.upstream.connection], [self.upstream.connection] if self.upstream.has_buffer() else []
-
-    def read_from_descriptors(self, r: Readables) -> bool:
-        if self.upstream and self.upstream.connection in r:
-            try:
-                raw = self.upstream.recv(self.flags.server_recvbuf_size)
-                if raw is not None:
-                    self.client.queue(raw)
-                else:
-                    return True     # Teardown because upstream server closed the connection
-            except ssl.SSLWantReadError:
-                logger.info('Upstream server SSLWantReadError, will retry')
-                return False
-            except ConnectionResetError:
-                logger.debug('Connection reset by upstream server')
-                return True
-        return super().read_from_descriptors(r)
-
-    def write_to_descriptors(self, w: Writables) -> bool:
-        if self.upstream and self.upstream.connection in w and self.upstream.has_buffer():
-            try:
-                self.upstream.flush()
-            except ssl.SSLWantWriteError:
-                logger.info('Upstream server SSLWantWriteError, will retry')
-                return False
-            except BrokenPipeError:
-                logger.debug(
-                    'BrokenPipeError when flushing to upstream server',
-                )
-                return True
-        return super().write_to_descriptors(w)
-
     def handle_request(self, request: HttpParser) -> None:
-        url = urlparse.urlsplit(
+        self.choice = Url.from_bytes(
             random.choice(ReverseProxyPlugin.REVERSE_PROXY_PASS),
         )
-        assert url.hostname
-        port = url.port or (
-            DEFAULT_HTTP_PORT if url.scheme ==
-            b'http' else DEFAULT_HTTPS_PORT
-        )
-        self.upstream = TcpServerConnection(text_(url.hostname), port)
+        assert self.choice.hostname
+        port = self.choice.port or (
+            DEFAULT_HTTP_PORT if self.choice.scheme == b'http'
+            else DEFAULT_HTTPS_PORT
+        )
+
+        self.initialize_upstream(text_(self.choice.hostname), port)
+        assert self.upstream
         try:
             self.upstream.connect()
-            if url.scheme == b'https':
+            if self.choice.scheme == b'https':
                 self.upstream.wrap(
                     text_(
-                        url.hostname,
+                        self.choice.hostname,
                     ), ca_file=str(self.flags.ca_file),
                 )
             self.upstream.queue(memoryview(request.build()))
         except ConnectionRefusedError:
             logger.info(
                 'Connection refused by upstream server {0}:{1}'.format(
-                    text_(url.hostname), port,
+                    text_(self.choice.hostname), port,
                 ),
             )
             raise HttpProtocolException()
 
-    def on_websocket_open(self) -> None:
-        pass
-
-    def on_websocket_message(self, frame: WebsocketFrame) -> None:
-        pass
-
     def on_client_connection_close(self) -> None:
         if self.upstream and not self.upstream.closed:
             logger.debug('Closing upstream server connection')
             self.upstream.close()
             self.upstream = None
+
+    def on_access_log(self, context: Dict[str, Any]) -> Optional[Dict[str, Any]]:
+        log_format = '{client_addr} - {request_method} {request_path} -> {upstream_proxy_pass} - {connection_time_ms}ms'
+        context.update({
+            'upstream_proxy_pass': str(self.choice) if self.choice else None,
+        })
+        logger.info(log_format.format_map(context))
+        return None
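
A note on the parenthesised port computation in ``handle_request`` above: Python's
conditional expression binds more loosely than ``or``, so without the grouping an
explicit https port would be silently replaced by ``DEFAULT_HTTPS_PORT``. A
plain-Python check with arbitrary values::

    DEFAULT_HTTP_PORT, DEFAULT_HTTPS_PORT = 80, 443
    explicit_port, scheme = 8443, b'https'

    # Ungrouped: parses as `(explicit_port or 80) if scheme == b'http' else 443`.
    ungrouped = explicit_port or DEFAULT_HTTP_PORT if scheme == b'http' else DEFAULT_HTTPS_PORT
    assert ungrouped == 443

    # Grouped: falls back to a default only when no explicit port was given.
    grouped = explicit_port or (DEFAULT_HTTP_PORT if scheme == b'http' else DEFAULT_HTTPS_PORT)
    assert grouped == 8443
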
diff --git a/tests/http/test_url.py b/tests/http/test_url.py
index 11a3c54a45..de8ec0e71a 100644
--- a/tests/http/test_url.py
+++ b/tests/http/test_url.py
@@ -15,6 +15,22 @@
 
 class TestUrl(unittest.TestCase):
 
+    def test_url_str(self) -> None:
+        url = Url.from_bytes(b'localhost')
+        self.assertEqual(str(url), 'localhost')
+        url = Url.from_bytes(b'/')
+        self.assertEqual(str(url), '/')
+        url = Url.from_bytes(b'http://httpbin.org/get')
+        self.assertEqual(str(url), 'http://httpbin.org/get')
+        url = Url.from_bytes(b'httpbin.org:443')
+        self.assertEqual(str(url), 'httpbin.org:443')
+        url = Url.from_bytes('å∫ç.com'.encode('utf-8'))
+        self.assertEqual(str(url), 'å∫ç.com')
+        url = Url.from_bytes(b'https://example.com/path/dir/?a=b&c=d#p=q')
+        self.assertEqual(str(url), 'https://example.com/path/dir/?a=b&c=d#p=q')
+        url = Url.from_bytes(b'http://localhost:12345/v1/users/')
+        self.assertEqual(str(url), 'http://localhost:12345/v1/users/')
+
     def test_just_domain_name_url(self) -> None:
         url = Url.from_bytes(b'localhost')
         self.assertEqual(url.scheme, None)