Skip to content

Commit

Permalink
add SSL_peek functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
mhils committed Aug 17, 2015
1 parent 308970f commit c8cfb99
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 11 deletions.
5 changes: 5 additions & 0 deletions ChangeLog
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
2015-08-17 Maximilian Hils <[email protected]>

* OpenSSL/SSL.py, : Add support for the ``MSG_PEEK`` flag to
``Connection.recv()`` and ``Connection.recv_into()``.

2015-05-27 Jim Shaver <[email protected]>

* OpenSSL/SSL.py, : Add ``get_protocol_version()`` and
Expand Down
19 changes: 13 additions & 6 deletions OpenSSL/SSL.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import socket
from sys import platform
from functools import wraps, partial
from itertools import count, chain
Expand Down Expand Up @@ -1311,12 +1312,15 @@ def recv(self, bufsiz, flags=None):
method again with the SAME buffer.
:param bufsiz: The maximum number of bytes to read
:param flags: (optional) Included for compatibility with the socket
API, the value is ignored
:param flags: (optional) The only supported flag is ``MSG_PEEK``,
all other flags are ignored.
:return: The string read from the Connection
"""
buf = _ffi.new("char[]", bufsiz)
result = _lib.SSL_read(self._ssl, buf, bufsiz)
if flags is not None and flags & socket.MSG_PEEK:
result = _lib.SSL_peek(self._ssl, buf, bufsiz)
else:
result = _lib.SSL_read(self._ssl, buf, bufsiz)
self._raise_ssl_error(self._ssl, result)
return _ffi.buffer(buf, result)[:]
read = recv
Expand All @@ -1332,8 +1336,8 @@ def recv_into(self, buffer, nbytes=None, flags=None):
buffer. If not present, defaults to the size of the buffer. If
larger than the size of the buffer, is reduced to the size of the
buffer.
:param flags: (optional) Included for compatibility with the socket
API, the value is ignored.
:param flags: (optional) The only supported flag is ``MSG_PEEK``,
all other flags are ignored.
:return: The number of bytes read into the buffer.
"""
if nbytes is None:
Expand All @@ -1345,7 +1349,10 @@ def recv_into(self, buffer, nbytes=None, flags=None):
# better if we could pass memoryviews straight into the SSL_read call,
# but right now we can't. Revisit this if CFFI gets that ability.
buf = _ffi.new("char[]", nbytes)
result = _lib.SSL_read(self._ssl, buf, nbytes)
if flags is not None and flags & socket.MSG_PEEK:
result = _lib.SSL_peek(self._ssl, buf, nbytes)
else:
result = _lib.SSL_read(self._ssl, buf, nbytes)
self._raise_ssl_error(self._ssl, result)

# This strange line is all to avoid a memory copy. The buffer protocol
Expand Down
21 changes: 20 additions & 1 deletion OpenSSL/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from gc import collect, get_referrers
from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN
from sys import platform, getfilesystemencoding
from socket import SHUT_RDWR, error, socket
from socket import MSG_PEEK, SHUT_RDWR, error, socket
from os import makedirs
from os.path import join
from unittest import main
Expand Down Expand Up @@ -2171,6 +2171,16 @@ def test_pending_wrong_args(self):
connection = Connection(Context(TLSv1_METHOD), None)
self.assertRaises(TypeError, connection.pending, None)

def test_peek(self):
"""
:py:obj:`Connection.recv` peeks into the connection if :py:obj:`socket.MSG_PEEK` is passed.
"""
server, client = self._loopback()
server.send(b('xy'))
self.assertEqual(client.recv(2, MSG_PEEK), b('xy'))
self.assertEqual(client.recv(2, MSG_PEEK), b('xy'))
self.assertEqual(client.recv(2), b('xy'))


def test_connect_wrong_args(self):
"""
Expand Down Expand Up @@ -2998,6 +3008,15 @@ def test_bytearray_really_doesnt_overfill(self):
"""
self._doesnt_overfill_test(bytearray)

def test_peek(self):

server, client = self._loopback()
server.send(b('xy'))

for _ in range(2):
output_buffer = bytearray(5)
self.assertEqual(client.recv_into(output_buffer, flags=MSG_PEEK), 2)
self.assertEqual(output_buffer, bytearray(b('xy\x00\x00\x00')))

try:
memoryview
Expand Down
8 changes: 4 additions & 4 deletions doc/api/ssl.rst
Original file line number Diff line number Diff line change
Expand Up @@ -669,20 +669,20 @@ Connection objects have the following methods:
(**not** the underlying transport buffer).


.. py:method:: Connection.recv(bufsize)
.. py:method:: Connection.recv(bufsize[, flags])
Receive data from the Connection. The return value is a string representing the
data received. The maximum amount of data to be received at once, is specified
by *bufsize*.
by *bufsize*. The only supported flag is ``MSG_PEEK``, all other flags are
ignored.


.. py:method:: Connection.recv_into(buffer[, nbytes[, flags]])
Receive data from the Connection and copy it directly into the provided
buffer. The return value is the number of bytes read from the connection.
The maximum amount of data to be received at once is specified by *nbytes*.
*flags* is accepted for compatibility with ``socket.recv_into`` but its
value is ignored.
The only supported flag is ``MSG_PEEK``, all other flags are ignored.

.. py:method:: Connection.bio_write(bytes)
Expand Down

0 comments on commit c8cfb99

Please sign in to comment.