-
Notifications
You must be signed in to change notification settings - Fork 315
/
SocketInterface.py
222 lines (193 loc) · 7.52 KB
/
SocketInterface.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
from __future__ import absolute_import, print_function
import json
import socket
import struct
import threading
import traceback
import dill
import six
from six.moves import input
from six.moves.queue import Queue
if six.PY2:
    # Python 2 has no built-in ConnectionAbortedError, so define a
    # placeholder with that name. This keeps the
    # `except ConnectionAbortedError` clause in serversocket._accept
    # resolvable on both major versions.
    class ConnectionAbortedError(Exception):
        pass
# TODO - Implement a cleaner shutdown for server socket
# see: https://stackoverflow.com/a/1148237
class serversocket:
    """
    A server socket to receive and process string messages
    from client sockets to a central queue

    The listening socket is bound to `localhost` on a port chosen by the
    operating system. Each accepted connection is served by its own daemon
    thread, which decodes incoming messages and puts them on `self.queue`.
    """

    def __init__(self, name=None, verbose=False):
        """
        Create the listening socket and the message queue.

        `name` is an optional suffix appended to the listener thread's name.
        `verbose` enables diagnostic printing.
        """
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.sock.bind(('localhost', 0))
        self.sock.listen(10)  # queue a max of n connect requests
        self.verbose = verbose
        self.name = name
        self.queue = Queue()
        if self.verbose:
            print("Server bound to: " + str(self.sock.getsockname()))

    def start_accepting(self):
        """ Start the listener thread """
        thread = threading.Thread(target=self._accept, args=())
        thread.daemon = True  # stops from blocking shutdown
        if self.name is not None:
            thread.name = thread.name + "-" + self.name
        thread.start()

    def _accept(self):
        """ Listen for connections and pass handling to a new thread """
        while True:
            try:
                (client, address) = self.sock.accept()
                thread = threading.Thread(target=self._handle_conn,
                                          args=(client, address))
                thread.daemon = True
                thread.start()
            except ConnectionAbortedError:
                # Workaround for #278
                print("A connection establish request was performed "
                      "on a closed socket")
                return

    def _handle_conn(self, client, address):
        """
        Receive messages and pass to queue. Messages are prefixed with
        a 4-byte integer to specify the message length and 1-byte character
        to indicate the type of serialization applied to the message.

        Supported serialization formats:
            'n' : no serialization
            'u' : Unicode string in UTF-8
            'd' : dill pickle
            'j' : json

        A message that has an unknown serialization type, or that cannot be
        de-serialized, is reported on stdout and skipped; the connection
        stays open. The loop ends when the peer closes the connection, and
        the client socket is always closed on the way out.
        """
        if self.verbose:
            print("Thread: %s connected to: %s" %
                  (threading.current_thread(), address))
        try:
            while True:
                header = self.receive_msg(client, 5)
                msglen, serialization = struct.unpack('>Lc', header)
                if self.verbose:
                    print("Received message, length %d, serialization %r"
                          % (msglen, serialization))
                msg = self.receive_msg(client, msglen)
                if serialization != b'n':
                    try:
                        if serialization == b'd':  # dill serialization
                            # NOTE(security): un-pickling can execute
                            # arbitrary code. Only acceptable because the
                            # server listens on localhost; never expose
                            # this to untrusted peers.
                            msg = dill.loads(msg)
                        elif serialization == b'j':  # json serialization
                            msg = json.loads(msg.decode('utf-8'))
                        elif serialization == b'u':  # utf-8 serialization
                            msg = msg.decode('utf-8')
                        else:
                            print("Unrecognized serialization type: %r"
                                  % serialization)
                            continue
                    except Exception:
                        # Un-pickling may raise far more than ValueError
                        # (EOFError, AttributeError, ImportError, ...), and
                        # one bad payload must not kill the connection.
                        # format_exc() takes no exception argument: it
                        # reads the exception currently being handled.
                        print("Error de-serializing message: %s \n %s" % (
                            msg, traceback.format_exc()))
                        continue
                self.queue.put(msg)
        except RuntimeError:
            # Raised by receive_msg when the peer closed the connection
            if self.verbose:
                print("Client socket: " + str(address) + " closed")
        finally:
            # Release the per-connection socket however the loop ended
            client.close()

    def receive_msg(self, client, msglen):
        """
        Read exactly `msglen` bytes from `client` and return them.

        Raises RuntimeError("socket connection broken") if the peer closes
        the connection before `msglen` bytes have arrived.
        """
        chunks = []
        remaining = msglen
        while remaining > 0:
            chunk = client.recv(remaining)
            if not chunk:
                raise RuntimeError("socket connection broken")
            chunks.append(chunk)
            remaining -= len(chunk)
        # Join once at the end to avoid re-copying the buffer per chunk
        return b''.join(chunks)

    def close(self):
        """ Close the listening socket """
        self.sock.close()
class clientsocket:
    """A client socket for sending messages"""

    # Builtin equivalent of six.text_type (`unicode` on Python 2, `str` on
    # Python 3), so `send` needs no function-level `import six`.
    _TEXT_TYPE = type(u"")

    def __init__(self, serialization='json', verbose=False):
        """ `serialization` specifies the type of serialization to use for
        non-string messages. Supported formats:
            * 'json' uses the json module. Cross-language support. (default)
            * 'dill' uses the dill pickle module. Python only.

        Raises ValueError for any other `serialization` value.
        `verbose` enables diagnostic printing.
        """
        # Validate before creating the socket so that a bad argument does
        # not leave an open socket behind.
        if serialization not in ('json', 'dill'):
            raise ValueError(
                "Unsupported serialization type: %s" % serialization)
        self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.serialization = serialization
        self.verbose = verbose

    def connect(self, host, port):
        """ Connect the underlying socket to `host`:`port` """
        if self.verbose:
            print("Connecting to: %s:%i" % (host, port))
        self.sock.connect((host, port))

    def send(self, msg):
        """
        Sends an arbitrary python object to the connected socket.

        Byte strings are sent as-is and text strings are encoded as UTF-8.
        Any other object is serialized with the format chosen at
        construction time (json or dill). The payload is prefixed with its
        length (4 bytes, big-endian) and the serialization type (1 byte).

        Raises ValueError if `self.serialization` is not a supported
        format, and RuntimeError if the connection is broken while sending.
        """
        if isinstance(msg, bytes):
            serialization = b'n'
        elif isinstance(msg, self._TEXT_TYPE):
            serialization = b'u'
            msg = msg.encode('utf-8')
        elif self.serialization == 'dill':
            msg = dill.dumps(msg, dill.HIGHEST_PROTOCOL)
            serialization = b'd'
        elif self.serialization == 'json':
            msg = json.dumps(msg).encode('utf-8')
            serialization = b'j'
        else:
            # Only reachable if the attribute was changed after __init__.
            # Report the configured value; the local `serialization` is
            # not assigned on this path.
            raise ValueError("Unsupported serialization type set: %s"
                             % self.serialization)
        if self.verbose:
            print("Sending message with serialization %s" % serialization)

        # prepend with message length
        msg = struct.pack('>Lc', len(msg), serialization) + msg
        totalsent = 0
        while totalsent < len(msg):
            sent = self.sock.send(msg[totalsent:])
            if sent == 0:
                raise RuntimeError("socket connection broken")
            totalsent = totalsent + sent

    def close(self):
        """ Close the underlying socket """
        self.sock.close()
def main():
    """
    Interactive harness for manual testing.

    Run with `s` as the first argument to start a verbose server that
    accepts connections until enter is pressed. Run with `c` to start a
    client that prompts for host, port and serialization type, then sends
    every line typed until `quit` is entered (the `quit` line itself is
    sent too). The inputs `tuple`, `list`, `dict` and `function` send a
    predefined object of that kind instead of the literal text.

    With a missing or unknown mode a usage line is printed to stderr.
    """
    import sys

    # Just for testing
    mode = sys.argv[1] if len(sys.argv) > 1 else None

    if mode == 's':
        sock = serversocket(verbose=True)
        try:
            sock.start_accepting()
            input("Press enter to exit...")
        finally:
            sock.close()
    elif mode == 'c':
        host = input("Enter the host name:\n")
        port = input("Enter the port:\n")
        serialization = input(
            "Enter the serialization type (default: 'json'):\n")
        if serialization == '':
            serialization = 'json'
        sock = clientsocket(serialization=serialization)
        try:
            sock.connect(host, int(port))

            def function_msg(x):
                return x

            # some predefined messages, selected by typing their name
            predefined = {
                'tuple': ('hello', 'world'),
                'list': ['hello', 'world'],
                'dict': {'hello': 'world'},
                'function': function_msg,
            }

            # read user input
            msg = None
            while msg != "quit":
                msg = input("Enter a message to send:\n")
                # Anything that is not a predefined name is sent verbatim
                sock.send(predefined.get(msg, msg))
        finally:
            # Close the socket even if sending fails or input hits EOF
            sock.close()
    else:
        print("Usage: %s s|c  (s: run test server, c: run test client)"
              % sys.argv[0], file=sys.stderr)
# Run the interactive test harness only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()