Skip to content

Commit

Permalink
Add implicit FTPS support (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
oleksandr-kuzmenko authored and pohmelie committed Oct 15, 2018
1 parent 95cd305 commit 09958d1
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 57 deletions.
18 changes: 15 additions & 3 deletions aioftp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@
logger = logging.getLogger(__name__)


async def open_connection(host, port, loop, create_connection):
async def open_connection(host, port, loop, create_connection, ssl=None):
reader = asyncio.StreamReader(loop=loop)
protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
transport, _ = await create_connection(lambda: protocol, host, port)
transport, _ = await create_connection(lambda: protocol,
host, port, ssl=ssl)
writer = asyncio.StreamWriter(transport, protocol, reader, loop)
return reader, writer

Expand Down Expand Up @@ -108,7 +109,8 @@ class BaseClient:
def __init__(self, *, loop=None, create_connection=None,
socket_timeout=None, read_speed_limit=None,
write_speed_limit=None, path_timeout=None,
path_io_factory=pathio.PathIO, encoding="utf-8"):
path_io_factory=pathio.PathIO, encoding="utf-8",
ssl=None):
self.loop = loop or asyncio.get_event_loop()
self.create_connection = create_connection or \
self.loop.create_connection
Expand All @@ -122,6 +124,7 @@ def __init__(self, *, loop=None, create_connection=None,
self.path_io = path_io_factory(timeout=path_timeout, loop=loop)
self.encoding = encoding
self.stream = None
self.ssl = ssl

async def connect(self, host, port=DEFAULT_PORT):
self.server_host = host
Expand All @@ -131,6 +134,7 @@ async def connect(self, host, port=DEFAULT_PORT):
port,
self.loop,
self.create_connection,
self.ssl,
)
self.stream = ThrottleStreamIO(
reader,
Expand Down Expand Up @@ -491,6 +495,14 @@ class Client(BaseClient):
:param encoding: encoding to use for convertion strings to bytes
:type encoding: :py:class:`str`
:param ssl: if given and not false, a SSL/TLS transport is created
(by default a plain TCP transport is created).
If ssl is a ssl.SSLContext object, this context is used to create
the transport; if ssl is True, a default context returned from
ssl.create_default_context() is used.
Please look :py:meth:`asyncio.loop.create_connection` docs.
:type ssl: :py:class:`bool` or :py:class:`ssl.SSLContext`
"""
async def connect(self, host, port=DEFAULT_PORT):
"""
Expand Down
10 changes: 9 additions & 1 deletion aioftp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,7 @@ async def start(self, host=None, port=0, **kwargs):
host,
port,
loop=self.loop,
ssl=self.ssl,
**self._start_server_extra_arguments,
)
for sock in self.server.sockets:
Expand Down Expand Up @@ -776,6 +777,11 @@ class Server(AbstractServer):
:param encoding: encoding to use for convertion strings to bytes
:type encoding: :py:class:`str`
:param ssl: can be set to an :py:class:`ssl.SSLContext` instance
to enable TLS over the accepted connections.
Please look :py:meth:`asyncio.loop.create_server` docs.
:type ssl: :py:class:`ssl.SSLContext`
"""
path_facts = (
("st_size", "Size"),
Expand All @@ -799,7 +805,8 @@ def __init__(self,
read_speed_limit_per_connection=None,
write_speed_limit_per_connection=None,
data_ports=None,
encoding="utf-8"):
encoding="utf-8",
ssl=None):
self.loop = loop or asyncio.get_event_loop()
self.block_size = block_size
self.socket_timeout = socket_timeout
Expand Down Expand Up @@ -832,6 +839,7 @@ def __init__(self,
)
self.throttle_per_user = {}
self.encoding = encoding
self.ssl = ssl

async def dispatcher(self, reader, writer):
host, port, *_ = writer.transport.get_extra_info("peername", ("", ""))
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def run_tests(self):
packages=find_packages(),
python_requires=" >= 3.5.3",
install_requires=[],
tests_require=["nose", "coverage"],
tests_require=["nose", "coverage", "trustme"],
cmdclass={"test": NoseTestCommand},
include_package_data=True
)
21 changes: 17 additions & 4 deletions tests/common.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
import asyncio
import functools
import pathlib
import logging
import pathlib
import shutil
import socket
import ssl

import nose
import trustme

import aioftp


ca = trustme.CA()
server_cert = ca.issue_server_cert("127.0.0.1", "::1")

ssl_server = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH)
server_cert.configure_cert(ssl_server)

ssl_client = ssl.create_default_context(ssl.Purpose.SERVER_AUTH)
ca.configure_trust(ssl_client)

PORT = 8888


Expand All @@ -23,16 +34,16 @@ def wrapper():
s_args, s_kwargs = server_args
c_args, c_kwargs = client_args

def run_in_loop(s_args, s_kwargs, c_args, c_kwargs):
def run_in_loop(s_args, s_kwargs, c_args, c_kwargs, s_ssl=None, c_ssl=None):
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(name)s] %(message)s",
datefmt="[%H:%M:%S]:",
)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(None)
server = aioftp.Server(*s_args, loop=loop, **s_kwargs)
client = aioftp.Client(*c_args, loop=loop, **c_kwargs)
server = aioftp.Server(*s_args, loop=loop, ssl=s_ssl, **s_kwargs)
client = aioftp.Client(*c_args, loop=loop, ssl=c_ssl, **c_kwargs)
try:
loop.run_until_complete(f(loop, client, server))
finally:
Expand All @@ -46,8 +57,10 @@ def run_in_loop(s_args, s_kwargs, c_args, c_kwargs):
for factory in (aioftp.PathIO, aioftp.AsyncPathIO):
s_kwargs["path_io_factory"] = factory
run_in_loop(s_args, s_kwargs, c_args, c_kwargs)
run_in_loop(s_args, s_kwargs, c_args, c_kwargs, ssl_server, ssl_client)
else:
run_in_loop(s_args, s_kwargs, c_args, c_kwargs)
run_in_loop(s_args, s_kwargs, c_args, c_kwargs, ssl_server, ssl_client)

return wrapper

Expand Down
4 changes: 2 additions & 2 deletions tests/test-connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ async def test_pasv_connection_ports_not_added(loop, client, server):
@with_connection
async def test_pasv_connection_ports(loop, client, server):

clients = [aioftp.Client(loop=loop) for _ in range(2)]
clients = [aioftp.Client(loop=loop, ssl=client.ssl) for _ in range(2)]
expected_data_ports = [30000, 30001]

for i, client in enumerate(clients):
Expand Down Expand Up @@ -128,7 +128,7 @@ async def test_data_ports_remains_empty(loop, client, server):
@with_connection
async def test_pasv_connection_port_reused(loop, client, server):

clients = [aioftp.Client(loop=loop) for _ in range(2)]
clients = [aioftp.Client(loop=loop, ssl=client.ssl) for _ in range(2)]

for client in clients:

Expand Down
12 changes: 6 additions & 6 deletions tests/test-maximum-connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
@with_connection
async def test_multiply_connections_no_limits(loop, client, server):

clients = [aioftp.Client(loop=loop) for _ in range(4)]
clients = [aioftp.Client(loop=loop, ssl=client.ssl) for _ in range(4)]
for client in clients:

await client.connect("127.0.0.1", PORT)
Expand All @@ -23,7 +23,7 @@ async def test_multiply_connections_no_limits(loop, client, server):
@with_connection
async def test_multiply_connections_limited_error(loop, client, server):

clients = [aioftp.Client(loop=loop) for _ in range(5)]
clients = [aioftp.Client(loop=loop, ssl=client.ssl) for _ in range(5)]
for client in clients:

await client.connect("127.0.0.1", PORT)
Expand Down Expand Up @@ -53,7 +53,7 @@ async def test_multiply_user_commands(loop, client, server):
async def test_multiply_connections_with_user_limited_error(loop, client,
server):

clients = [aioftp.Client(loop=loop) for _ in range(5)]
clients = [aioftp.Client(loop=loop, ssl=client.ssl) for _ in range(5)]
for client in clients:

await client.connect("127.0.0.1", PORT)
Expand All @@ -69,7 +69,7 @@ async def test_multiply_connections_with_user_limited_error(loop, client,
@with_connection
async def test_multiply_connections_relogin_balanced(loop, client, server):

clients = [aioftp.Client(loop=loop) for _ in range(5)]
clients = [aioftp.Client(loop=loop, ssl=client.ssl) for _ in range(5)]
for client in clients[:-1]:

await client.connect("127.0.0.1", PORT)
Expand All @@ -90,7 +90,7 @@ async def test_multiply_connections_relogin_balanced(loop, client, server):
@expect_codes_in_exception("421")
async def test_multiply_connections_server_limit_error(loop, client, server):

clients = [aioftp.Client(loop=loop) for _ in range(5)]
clients = [aioftp.Client(loop=loop, ssl=client.ssl) for _ in range(5)]
for client in clients:

await client.connect("127.0.0.1", PORT)
Expand All @@ -107,7 +107,7 @@ async def test_multiply_connections_server_limit_error(loop, client, server):
async def test_multiply_connections_server_relogin_balanced(loop, client,
server):

clients = [aioftp.Client(loop=loop) for _ in range(5)]
clients = [aioftp.Client(loop=loop, ssl=client.ssl) for _ in range(5)]
for client in clients[:-1]:

await client.connect("127.0.0.1", PORT)
Expand Down
Loading

0 comments on commit 09958d1

Please sign in to comment.