Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Fix IPv6-only bugs on SMTP settings
Browse files Browse the repository at this point in the history
While there, do it in such a fashion that we both document and prepare
the groundwork for similar issues relating to direct usage of
reactor.connectTCP, which lead to IPv6 incompatibilities.

Closes #7720

Signed-off-by: Nico Schottelius <[email protected]>
  • Loading branch information
evilham committed Aug 28, 2023
1 parent 224c2bb commit e0c054c
Show file tree
Hide file tree
Showing 6 changed files with 142 additions and 30 deletions.
1 change: 1 addition & 0 deletions changelog.d/16155.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix IPv6-related bugs on SMTP settings, adding groundwork to fix similar issues. Contributed by @evilham and @telmich (ungleich.ch).
30 changes: 13 additions & 17 deletions synapse/handlers/send_email.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@

import twisted
from twisted.internet.defer import Deferred
from twisted.internet.interfaces import IOpenSSLContextFactory
from twisted.internet.endpoints import HostnameEndpoint
from twisted.internet.interfaces import IOpenSSLContextFactory, IProtocolFactory
from twisted.internet.ssl import optionsForClientTLS
from twisted.mail.smtp import ESMTPSender, ESMTPSenderFactory
from twisted.protocols.tls import TLSMemoryBIOFactory

from synapse.logging.context import make_deferred_yieldable
from synapse.types import ISynapseReactor
Expand All @@ -37,6 +39,9 @@

_is_old_twisted = parse_version(twisted.__version__) < parse_version("21")

# We assign the name ESMTPTLSClientFactory, to be able to redefine it in tests
ESMTPTLSClientFactory = TLSMemoryBIOFactory


class _NoTLSESMTPSender(ESMTPSender):
"""Extend ESMTPSender to disable TLS
Expand Down Expand Up @@ -97,6 +102,7 @@ def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory:
**kwargs,
)

factory: IProtocolFactory
if _is_old_twisted:
# before twisted 21.2, we have to override the ESMTPSender protocol to disable
# TLS
Expand All @@ -109,23 +115,13 @@ def build_sender_factory(**kwargs: Any) -> ESMTPSenderFactory:
# set to enable TLS.
factory = build_sender_factory(hostname=smtphost if enable_tls else None)

endpoint = HostnameEndpoint(
reactor, smtphost, smtpport, timeout=30, bindAddress=None
)
if force_tls:
reactor.connectSSL(
smtphost,
smtpport,
factory,
optionsForClientTLS(smtphost),
timeout=30,
bindAddress=None,
)
else:
reactor.connectTCP(
smtphost,
smtpport,
factory,
timeout=30,
bindAddress=None,
)
factory = ESMTPTLSClientFactory(optionsForClientTLS(smtphost), True, factory)

await make_deferred_yieldable(endpoint.connect(factory))

await make_deferred_yieldable(d)

Expand Down
70 changes: 60 additions & 10 deletions tests/handlers/test_send_email.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,42 @@
# limitations under the License.


from typing import Callable, List, Tuple
from typing import Callable, List, Tuple, Type, Union
from unittest.mock import patch

from zope.interface import implementer

from twisted.internet import defer
from twisted.internet.address import IPv4Address
from twisted.internet._sslverify import ClientTLSOptions
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.defer import ensureDeferred
from twisted.internet.interfaces import IProtocolFactory
from twisted.internet.ssl import ContextFactory
from twisted.mail import interfaces, smtp

import synapse.handlers.send_email

from tests.server import FakeTransport
from tests.unittest import HomeserverTestCase, override_config


def TestingESMTPTLSClientFactory(
contextFactory: ContextFactory,
_connectWrapped: bool,
wrappedProtocol: IProtocolFactory,
) -> IProtocolFactory:
"""We use this to pass through in testing without using TLS, but
saving the context information to check that it would have happened.
Note that this is what the MemoryReactor does on connectSSL.
It only saves the contextFactory, but starts the connection with the
underlying Factory.
See: L{twisted.internet.testing.MemoryReactor.connectSSL}"""

wrappedProtocol._testingContextFactory = contextFactory # type: ignore[attr-defined]
return wrappedProtocol


@implementer(interfaces.IMessageDelivery)
class _DummyMessageDelivery:
def __init__(self) -> None:
Expand Down Expand Up @@ -75,7 +98,15 @@ def connectionLost(self) -> None:
pass


class SendEmailHandlerTestCase(HomeserverTestCase):
class SendEmailHandlerTestCaseIPv4(HomeserverTestCase):
ip_class: Union[Type[IPv4Address], Type[IPv6Address]] = IPv4Address

def setUp(self) -> None:
HomeserverTestCase.lookups["localhost"] = HomeserverTestCase.lookups.get(
"localhost", "127.0.0.1"
)
super().setUp()

def test_send_email(self) -> None:
"""Happy-path test that we can send email to a non-TLS server."""
h = self.hs.get_send_email_handler()
Expand All @@ -89,7 +120,7 @@ def test_send_email(self) -> None:
(host, port, client_factory, _timeout, _bindAddress) = self.reactor.tcpClients[
0
]
self.assertEqual(host, "localhost")
self.assertEqual(host, self.lookups["localhost"])
self.assertEqual(port, 25)

# wire it up to an SMTP server
Expand All @@ -105,7 +136,7 @@ def test_send_email(self) -> None:
FakeTransport(
client_protocol,
self.reactor,
peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
peer_address=self.ip_class("TCP", self.lookups["localhost"], 1234),
)
)

Expand All @@ -118,6 +149,11 @@ def test_send_email(self) -> None:
self.assertEqual(str(user), "[email protected]")
self.assertIn(b"Subject: test subject", msg)

@patch.object(
synapse.handlers.send_email,
"ESMTPTLSClientFactory",
TestingESMTPTLSClientFactory,
)
@override_config(
{
"email": {
Expand All @@ -135,17 +171,23 @@ def test_send_email_force_tls(self) -> None:
)
)
# there should be an attempt to connect to localhost:465
self.assertEqual(len(self.reactor.sslClients), 1)
self.assertEqual(len(self.reactor.tcpClients), 1)
(
host,
port,
client_factory,
contextFactory,
_timeout,
_bindAddress,
) = self.reactor.sslClients[0]
self.assertEqual(host, "localhost")
) = self.reactor.tcpClients[0]
self.assertEqual(host, self.lookups["localhost"])
self.assertEqual(port, 465)
# We need to make sure that TLS is happenning
self.assertIsInstance(
client_factory._wrappedFactory._testingContextFactory,
ClientTLSOptions,
)
# And since we use endpoints, they go through reactor.connectTCP
# which works differently to connectSSL on the testing reactor

# wire it up to an SMTP server
message_delivery = _DummyMessageDelivery()
Expand All @@ -160,7 +202,7 @@ def test_send_email_force_tls(self) -> None:
FakeTransport(
client_protocol,
self.reactor,
peer_address=IPv4Address("TCP", "127.0.0.1", 1234),
peer_address=self.ip_class("TCP", self.lookups["localhost"], 1234),
)
)

Expand All @@ -172,3 +214,11 @@ def test_send_email_force_tls(self) -> None:
user, msg = message_delivery.messages.pop()
self.assertEqual(str(user), "[email protected]")
self.assertIn(b"Subject: test subject", msg)


class SendEmailHandlerTestCaseIPv6(SendEmailHandlerTestCaseIPv4):
ip_class = IPv6Address

def setUp(self) -> None:
HomeserverTestCase.lookups["localhost"] = "::1"
super().setUp()
7 changes: 5 additions & 2 deletions tests/rest/media/test_url_preview.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import json
import os
import re
from typing import Any, Dict, Optional, Sequence, Tuple, Type
from typing import Dict, List, Optional, Sequence, Tuple, Type, Union
from urllib.parse import quote, urlencode

from twisted.internet._resolver import HostResolution
Expand Down Expand Up @@ -48,6 +48,9 @@ class URLPreviewTests(unittest.HomeserverTestCase):
skip = "url preview feature requires lxml"

hijack_auth = True
lookups: Dict[ # type: ignore[misc, assignment]
str, List[Tuple[Union[Type[IPv4Address], Type[IPv6Address]], str]]
]
user_id = "@test:user"
end_content = (
b"<html><head>"
Expand Down Expand Up @@ -120,7 +123,7 @@ def prepare(self, reactor: MemoryReactor, clock: Clock, hs: HomeServer) -> None:
media_repo_resource = hs.get_media_repository_resource()
self.preview_url = media_repo_resource.children[b"preview_url"]

self.lookups: Dict[str, Any] = {}
self.lookups = {}

class Resolver:
def resolveHostName(
Expand Down
53 changes: 52 additions & 1 deletion tests/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import hashlib
import ipaddress
import json
import logging
import os
Expand Down Expand Up @@ -45,7 +46,7 @@
from typing_extensions import ParamSpec
from zope.interface import implementer

from twisted.internet import address, threads, udp
from twisted.internet import address, tcp, threads, udp
from twisted.internet._resolver import SimpleResolverComplexifier
from twisted.internet.defer import Deferred, fail, maybeDeferred, succeed
from twisted.internet.error import DNSLookupError
Expand Down Expand Up @@ -567,6 +568,8 @@ def connectTCP(
conn = super().connectTCP(
host, port, factory, timeout=timeout, bindAddress=None
)
if self.lookups and host in self.lookups:
validate_connector(conn, self.lookups[host])

callback = self._tcp_callbacks.get((host, port))
if callback:
Expand Down Expand Up @@ -599,6 +602,54 @@ def advance(self, amount: float) -> None:
super().advance(0)


def validate_connector(connector: tcp.Connector, expected_ip: str) -> None:
"""Try to validate the obtained connector as it would happen when
synapse is running and the conection will be established.
This method will raise a useful exception when necessary, else it will
just do nothing.
This is in order to help catch quirks related to reactor.connectTCP,
since when called directly, the connector's destination will be of type
IPv4Address, with the hostname as the literal host that was given (which
could be an IPv6-only host or an IPv6 literal).
But when called from reactor.connectTCP *through* e.g. an Endpoint, the
connector's destination will contain the specific IP address with the
correct network stack class.
Note that testing code paths that use connectTCP directly should not be
affected by this check, unless they specifically add a test with a
matching HomeserverTestCase.lookups[HOSTNAME] = "IPv6Literal".
For an example of implementing such tests, see test/handlers/send_email.py.
"""
destination = connector.getDestination()

def check_ip(
cls: Union[Type[ipaddress.IPv4Address], Type[ipaddress.IPv6Address]]
) -> None:
"""With this class we produce a more informative error if needed"""
try:
cls(expected_ip)
except Exception:
raise ValueError(
"Invalid IP type and resolution, got %s, expected: %s %s"
% (destination, expected_ip, cls)
)

# We use address.IPv{4,6}Address to check what the reactor thinks it is
# is sending but check for validity with with ipaddress.IPv{4,6}Address
# because they fail with IPs on the wrong network stack.
if isinstance(destination, address.IPv4Address):
check_ip(ipaddress.IPv4Address)
elif isinstance(destination, address.IPv6Address):
check_ip(ipaddress.IPv6Address)
else:
raise ValueError(
"Unknown address type %s for %s" % (type(destination), destination)
)


class ThreadPool:
"""
Threadless thread pool.
Expand Down
11 changes: 11 additions & 0 deletions tests/unittest.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,11 +314,14 @@ class HomeserverTestCase(TestCase):
user_id (str): The user ID to assume if auth is hijacked.
hijack_auth: Whether to hijack auth to return the user specified
in user_id.
lookups: Dictionary of Hostname to IP address (either v4 or v6) that
will be fed to the reactor before the homeserver is started.
"""

hijack_auth: ClassVar[bool] = True
needs_threadpool: ClassVar[bool] = False
servlets: ClassVar[List[RegisterServletsFunc]] = []
lookups: ClassVar[Dict[str, str]] = {}

def __init__(self, methodName: str):
super().__init__(methodName)
Expand All @@ -334,6 +337,12 @@ def setUp(self) -> None:
calling the prepare function.
"""
self.reactor, self.clock = get_clock()

# The homeserver will start connecting places as soon as it starts,
# so we must update some fake DNS entries early.
if HomeserverTestCase.lookups:
self.reactor.lookups.update(HomeserverTestCase.lookups)

self._hs_args = {"clock": self.clock, "reactor": self.reactor}
self.hs = self.make_homeserver(self.reactor, self.clock)

Expand Down Expand Up @@ -410,6 +419,8 @@ async def get_requester(*args: Any, **kwargs: Any) -> Requester:
def tearDown(self) -> None:
# Reset to not use frozen dicts.
events.USE_FROZEN_DICTS = False
# Clear any possible forced lookups
HomeserverTestCase.lookups.clear()

def wait_on_thread(self, deferred: Deferred, timeout: int = 10) -> None:
"""
Expand Down

0 comments on commit e0c054c

Please sign in to comment.