Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix incompatibility with Twisted < 21. (#10713)
Browse files Browse the repository at this point in the history
Turns out that the functionality added in #10546 to skip TLS was incompatible
with older Twisted versions, so we need to be a bit more inventive.

Also, add a test to (hopefully) not break this in future. Sadly, testing TLS is
really hard.
  • Loading branch information
richvdh authored Aug 27, 2021
1 parent f03cafb commit 8f98260
Show file tree
Hide file tree
Showing 5 changed files with 173 additions and 21 deletions.
1 change: 1 addition & 0 deletions changelog.d/10713.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix a regression introduced in Synapse 1.41 which broke email transmission on Systems using older versions of the Twisted library.
1 change: 1 addition & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ files =
tests/test_utils,
tests/handlers/test_password_providers.py,
tests/handlers/test_room_summary.py,
tests/handlers/test_send_email.py,
tests/rest/client/v1/test_login.py,
tests/rest/client/v2_alpha/test_auth.py,
tests/util/test_itertools.py,
Expand Down
65 changes: 47 additions & 18 deletions synapse/handlers/send_email.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,12 @@
from io import BytesIO
from typing import TYPE_CHECKING, Optional

from pkg_resources import parse_version

import twisted
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IReactorTCP
from twisted.mail.smtp import ESMTPSenderFactory
from twisted.internet.interfaces import IOpenSSLContextFactory, IReactorTCP
from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory

from synapse.logging.context import make_deferred_yieldable

Expand All @@ -30,6 +33,19 @@

logger = logging.getLogger(__name__)

_is_old_twisted = parse_version(twisted.__version__) < parse_version("21")


class _NoTLSESMTPSender(ESMTPSender):
"""Extend ESMTPSender to disable TLS
Unfortunately, before Twisted 21.2, ESMTPSender doesn't give an easy way to disable
TLS, so we override its internal method which it uses to generate a context factory.
"""

def _getContextFactory(self) -> Optional[IOpenSSLContextFactory]:
return None


async def _sendmail(
reactor: IReactorTCP,
Expand All @@ -42,7 +58,7 @@ async def _sendmail(
password: Optional[bytes] = None,
require_auth: bool = False,
require_tls: bool = False,
tls_hostname: Optional[str] = None,
enable_tls: bool = True,
) -> None:
"""A simple wrapper around ESMTPSenderFactory, to allow substitution in tests
Expand All @@ -57,24 +73,37 @@ async def _sendmail(
password: password to give when authenticating
require_auth: if auth is not offered, fail the request
require_tls: if TLS is not offered, fail the reqest
tls_hostname: TLS hostname to check for. None to disable TLS.
enable_tls: True to enable TLS. If this is False and require_tls is True,
the request will fail.
"""
msg = BytesIO(msg_bytes)

d: "Deferred[object]" = Deferred()

factory = ESMTPSenderFactory(
username,
password,
from_addr,
to_addr,
msg,
d,
heloFallback=True,
requireAuthentication=require_auth,
requireTransportSecurity=require_tls,
hostname=tls_hostname,
)
def build_sender_factory(**kwargs) -> ESMTPSenderFactory:
return ESMTPSenderFactory(
username,
password,
from_addr,
to_addr,
msg,
d,
heloFallback=True,
requireAuthentication=require_auth,
requireTransportSecurity=require_tls,
**kwargs,
)

if _is_old_twisted:
# before twisted 21.2, we have to override the ESMTPSender protocol to disable
# TLS
factory = build_sender_factory()

if not enable_tls:
factory.protocol = _NoTLSESMTPSender
else:
# for twisted 21.2 and later, there is a 'hostname' parameter which we should
# set to enable TLS.
factory = build_sender_factory(hostname=smtphost if enable_tls else None)

# the IReactorTCP interface claims host has to be a bytes, which seems to be wrong
reactor.connectTCP(smtphost, smtpport, factory, timeout=30, bindAddress=None) # type: ignore[arg-type]
Expand Down Expand Up @@ -154,5 +183,5 @@ async def send_email(
password=self._smtp_pass,
require_auth=self._smtp_user is not None,
require_tls=self._require_transport_security,
tls_hostname=self._smtp_host if self._enable_tls else None,
enable_tls=self._enable_tls,
)
112 changes: 112 additions & 0 deletions tests/handlers/test_send_email.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
# Copyright 2021 The Matrix.org Foundation C.I.C.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from typing import List, Tuple

from zope.interface import implementer

from twisted.internet import defer
from twisted.internet.address import IPv4Address
from twisted.internet.defer import ensureDeferred
from twisted.mail import interfaces, smtp

from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase


@implementer(interfaces.IMessageDelivery)
class _DummyMessageDelivery:
def __init__(self):
# (recipient, message) tuples
self.messages: List[Tuple[smtp.Address, bytes]] = []

def receivedHeader(self, helo, origin, recipients):
return None

def validateFrom(self, helo, origin):
return origin

def record_message(self, recipient: smtp.Address, message: bytes):
self.messages.append((recipient, message))

def validateTo(self, user: smtp.User):
return lambda: _DummyMessage(self, user)


@implementer(interfaces.IMessageSMTP)
class _DummyMessage:
"""IMessageSMTP implementation which saves the message delivered to it
to the _DummyMessageDelivery object.
"""

def __init__(self, delivery: _DummyMessageDelivery, user: smtp.User):
self._delivery = delivery
self._user = user
self._buffer: List[bytes] = []

def lineReceived(self, line):
self._buffer.append(line)

def eomReceived(self):
message = b"\n".join(self._buffer) + b"\n"
self._delivery.record_message(self._user.dest, message)
return defer.succeed(b"saved")

def connectionLost(self):
pass


class SendEmailHandlerTestCase(HomeserverTestCase):
def test_send_email(self):
"""Happy-path test that we can send email to a non-TLS server."""
h = self.hs.get_send_email_handler()
d = ensureDeferred(
h.send_email(
"[email protected]", "test subject", "Tests", "HTML content", "Text content"
)
)
# there should be an attempt to connect to localhost:25
self.assertEqual(len(self.reactor.tcpClients), 1)
(host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
0
]
self.assertEqual(host, "localhost")
self.assertEqual(port, 25)

# wire it up to an SMTP server
message_delivery = _DummyMessageDelivery()
server_protocol = smtp.ESMTP()
server_protocol.delivery = message_delivery
# make sure that the server uses the test reactor to set timeouts
server_protocol.callLater = self.reactor.callLater # type: ignore[assignment]

client_protocol = client_factory.buildProtocol(None)
client_protocol.makeConnection(FakeTransport(server_protocol, self.reactor))
server_protocol.makeConnection(
FakeTransport(
client_protocol,
self.reactor,
peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
)
)

# the message should now get delivered
self.get_success(d, by=0.1)

# check it arrived
self.assertEqual(len(message_delivery.messages), 1)
user, msg = message_delivery.messages.pop()
self.assertEqual(str(user), "[email protected]")
self.assertIn(b"Subject: test subject", msg)
15 changes: 12 additions & 3 deletions tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,10 @@

from twisted.internet import address, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.defer import Deferred, fail, succeed
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError
from twisted.internet.interfaces import (
IAddress,
IHostnameResolver,
IProtocol,
IPullProducer,
Expand Down Expand Up @@ -511,6 +512,9 @@ class FakeTransport:
will get called back for connectionLost() notifications etc.
"""

_peer_address: Optional[IAddress] = attr.ib(default=None)
"""The value to be returend by getPeer"""

disconnecting = False
disconnected = False
connected = True
Expand All @@ -519,7 +523,7 @@ class FakeTransport:
autoflush = attr.ib(default=True)

def getPeer(self):
return None
return self._peer_address

def getHost(self):
return None
Expand Down Expand Up @@ -572,7 +576,12 @@ def registerProducer(self, producer, streaming):
self.producerStreaming = streaming

def _produce():
d = self.producer.resumeProducing()
if not self.producer:
# we've been unregistered
return
# some implementations of IProducer (for example, FileSender)
# don't return a deferred.
d = maybeDeferred(self.producer.resumeProducing)
d.addCallback(lambda x: self._reactor.callLater(0.1, _produce))

if not streaming:
Expand Down

0 comments on commit 8f98260

Please sign in to comment.