Skip to content

Commit

Permalink
SocketStream: add support for unix domain sockets
Browse files Browse the repository at this point in the history
Resolves #100, #208.

Cherry-pick: ebc0bbd

Thomas Gläßle: conflict resolution + removed whitespace
  • Loading branch information
alonho authored and coldfix committed Jul 27, 2017
1 parent 4fbf7e0 commit d0cb0fb
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 30 deletions.
13 changes: 13 additions & 0 deletions rpyc/core/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,19 @@ def connect(cls, host, port, **kwargs):
kwargs["family"] = socket.AF_INET6
return cls(cls._connect(host, port, **kwargs))

@classmethod
def unix_connect(cls, path, timeout = 3):
"""factory method that creates a ``SocketStream `` over a unix domain socket
located in *path*
:param path: the path to the unix domain socket
:param timeout: socket timeout
"""
s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
s.settimeout(timeout)
s.connect(path)
return cls(s)

@classmethod
def ssl_connect(cls, host, port, ssl_kwargs, **kwargs):
"""factory method that creates a ``SocketStream`` over an SSL-wrapped
Expand Down
10 changes: 10 additions & 0 deletions rpyc/utils/classic.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,16 @@ def connect(host, port = DEFAULT_SERVER_PORT, ipv6 = False, keepalive = False):
"""
return factory.connect(host, port, SlaveService, ipv6 = ipv6, keepalive = keepalive)

def unix_connect(path):
"""
Creates a socket connection to the given host and port.
:param path: the path to the unix domain socket
:returns: an RPyC connection exposing ``SlaveService``
"""
return factory.unix_connect(path, SlaveService)

def ssl_connect(host, port = DEFAULT_SERVER_SSL_PORT, keyfile = None,
certfile = None, ca_certs = None, cert_reqs = None, ssl_version = None,
ciphers = None, ipv6 = False):
Expand Down
13 changes: 13 additions & 0 deletions rpyc/utils/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,19 @@ def connect(host, port, service = VoidService, config = {}, ipv6 = False, keepal
s = SocketStream.connect(host, port, ipv6 = ipv6, keepalive = keepalive)
return connect_stream(s, service, config)

def unix_connect(path, service = VoidService, config = {}):
"""
creates a socket-connection to the given host and port
:param path: the path to the unix domain socket
:param service: the local service to expose (defaults to Void)
:param config: configuration dict
:returns: an RPyC connection
"""
s = SocketStream.unix_connect(path)
return connect_stream(s, service, config)

def ssl_connect(host, port, keyfile = None, certfile = None, ca_certs = None,
cert_reqs = None, ssl_version = None, ciphers = None,
service = VoidService, config = {}, ipv6 = False, keepalive = False):
Expand Down
57 changes: 30 additions & 27 deletions rpyc/utils/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ class Server(object):

def __init__(self, service, hostname = "", ipv6 = False, port = 0,
backlog = 10, reuse_addr = True, authenticator = None, registrar = None,
auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5):
auto_register = None, protocol_config = {}, logger = None, listener_timeout = 0.5,
socket_path = None):
self.active = False
self._closed = False
self.service = service
Expand All @@ -61,29 +62,35 @@ def __init__(self, service, hostname = "", ipv6 = False, port = 0,
self.protocol_config = protocol_config
self.clients = set()

if ipv6:
if hostname == "localhost" and sys.platform != "win32":
if socket_path is not None:
if hostname != "" or port != 0 or ipv6 != False:
raise ValueError("socket_path is mutually exclusive with: hostname, port, ipv6")
self.listener = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.listener.bind(socket_path)
# set the self.port to the path as it's used for the registry and logging
self.host, self.port = "", socket_path
else:
if ipv6 and hostname == "localhost" and sys.platform != "win32":
# on windows, you should bind to localhost even for ipv6
hostname = "localhost6"
self.listener = socket.socket(socket.AF_INET6, socket.SOCK_STREAM)
else:
self.listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
family = socket.AF_INET6 if ipv6 else socket.AF_INET
self.listener = socket.socket(family, socket.SOCK_STREAM)

if reuse_addr and sys.platform != "win32":
# warning: reuseaddr is not what you'd expect on windows!
# it allows you to bind an already bound port, resulting in "unexpected behavior"
# (quoting MSDN)
self.listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
if reuse_addr and sys.platform != "win32":
# warning: reuseaddr is not what you'd expect on windows!
# it allows you to bind an already bound port, resulting in "unexpected behavior"
# (quoting MSDN)
self.listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

self.listener.bind((hostname, port))
self.listener.settimeout(listener_timeout)
self.listener.bind((hostname, port))
self.listener.settimeout(listener_timeout)

# hack for IPv6 (the tuple can be longer than 2)
sockname = self.listener.getsockname()
self.host, self.port = sockname[0], sockname[1]
# hack for IPv6 (the tuple can be longer than 2)
sockname = self.listener.getsockname()
self.host, self.port = sockname[0], sockname[1]

if logger is None:
logger = logging.getLogger("%s/%d" % (self.service.get_service_name(), self.port))
logger = logging.getLogger("%s/%s" % (self.service.get_service_name(), self.port))
self.logger = logger
if "logger" not in self.protocol_config:
self.protocol_config["logger"] = self.logger
Expand Down Expand Up @@ -141,7 +148,7 @@ def accept(self):
return

sock.setblocking(True)
self.logger.info("accepted %s:%s with fd %d", addrinfo[0], addrinfo[1], sock.fileno())
self.logger.info("accepted %s with fd %d", addrinfo, sock.fileno())
self.clients.add(sock)
self._accept_method(sock)

Expand All @@ -156,15 +163,13 @@ def _authenticate_and_serve_client(self, sock):
try:
if self.authenticator:
addrinfo = sock.getpeername()
h = addrinfo[0]
p = addrinfo[1]
try:
sock2, credentials = self.authenticator(sock)
except AuthenticationError:
self.logger.info("[%s]:%s failed to authenticate, rejecting connection", h, p)
self.logger.info("%s failed to authenticate, rejecting connection", addrinfo)
return
else:
self.logger.info("[%s]:%s authenticated successfully", h, p)
self.logger.info("%s authenticated successfully", addrinfo)
else:
credentials = None
sock2 = sock
Expand All @@ -183,12 +188,10 @@ def _authenticate_and_serve_client(self, sock):

def _serve_client(self, sock, credentials):
addrinfo = sock.getpeername()
h = addrinfo[0]
p = addrinfo[1]
if credentials:
self.logger.info("welcome [%s]:%s (%r)", h, p, credentials)
self.logger.info("welcome %s (%r)", addrinfo, credentials)
else:
self.logger.info("welcome [%s]:%s", h, p)
self.logger.info("welcome %s", addrinfo)
try:
config = dict(self.protocol_config, credentials = credentials,
endpoints = (sock.getsockname(), addrinfo), logger = self.logger)
Expand All @@ -197,7 +200,7 @@ def _serve_client(self, sock, credentials):
conn._init_service()
self._handle_connection(conn)
finally:
self.logger.info("goodbye [%s]:%s", h, p)
self.logger.info("goodbye %s", addrinfo)

def _handle_connection(self, conn):
"""This methoed should implement the server's logic."""
Expand Down
31 changes: 28 additions & 3 deletions tests/test_threaded_server.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,56 @@
import rpyc
import time
import tempfile
from rpyc.utils.server import ThreadedServer
from rpyc import SlaveService
import threading
import unittest

class BaseServerTest(object):

def _create_server(self):
raise NotImplementedError

def _create_client(self):
raise NotImplementedError

class Test_ThreadedServer(unittest.TestCase):
def setUp(self):
self.server = ThreadedServer(SlaveService, port=18878, auto_register=False)
self.server = self._create_server()
self.server.logger.quiet = False
t = threading.Thread(target=self.server.start)
t.setDaemon(True)
t.start()
time.sleep(0.5)

def tearDown(self):
self.server.close()

def test_conenction(self):
c = rpyc.classic.connect("localhost", port=18878)
c = self._create_client()
print( c.modules.sys )
print( c.modules["xml.dom.minidom"].parseString("<a/>") )
c.execute("x = 5")
self.assertEqual(c.namespace["x"], 5)
self.assertEqual(c.eval("1+x"), 6)
c.close()

class Test_ThreadedServer(BaseServerTest, unittest.TestCase):

def _create_server(self):
return ThreadedServer(SlaveService, port=18878, auto_register=False)

def _create_client(self):
return rpyc.classic.connect("localhost", port=18878)

class Test_ThreadedServerOverUnixSocket(BaseServerTest, unittest.TestCase):

socket_path = tempfile.mktemp()

def _create_server(self):
return ThreadedServer(SlaveService, socket_path=self.socket_path, auto_register=False)

def _create_client(self):
return rpyc.classic.unix_connect(self.socket_path)

if __name__ == "__main__":
unittest.main()

0 comments on commit d0cb0fb

Please sign in to comment.