diff --git a/ChangeLog b/ChangeLog index b5b392272..2e59a1dac 100644 --- a/ChangeLog +++ b/ChangeLog @@ -1,3 +1,8 @@ +2015-08-17 Maximilian Hils + + * OpenSSL/SSL.py, : Add support for the ``MSG_PEEK`` flag to + ``Connection.recv()`` and ``Connection.recv_into()``. + 2015-05-27 Jim Shaver * OpenSSL/SSL.py, : Add ``get_protocol_version()`` and diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py index 8c87c349b..9b2701368 100644 --- a/OpenSSL/SSL.py +++ b/OpenSSL/SSL.py @@ -1,3 +1,4 @@ +import socket from sys import platform from functools import wraps, partial from itertools import count, chain @@ -1311,12 +1312,15 @@ def recv(self, bufsiz, flags=None): method again with the SAME buffer. :param bufsiz: The maximum number of bytes to read - :param flags: (optional) Included for compatibility with the socket - API, the value is ignored + :param flags: (optional) The only supported flag is ``MSG_PEEK``, + all other flags are ignored. :return: The string read from the Connection """ buf = _ffi.new("char[]", bufsiz) - result = _lib.SSL_read(self._ssl, buf, bufsiz) + if flags is not None and flags & socket.MSG_PEEK: + result = _lib.SSL_peek(self._ssl, buf, bufsiz) + else: + result = _lib.SSL_read(self._ssl, buf, bufsiz) self._raise_ssl_error(self._ssl, result) return _ffi.buffer(buf, result)[:] read = recv @@ -1332,8 +1336,8 @@ def recv_into(self, buffer, nbytes=None, flags=None): buffer. If not present, defaults to the size of the buffer. If larger than the size of the buffer, is reduced to the size of the buffer. - :param flags: (optional) Included for compatibility with the socket - API, the value is ignored. + :param flags: (optional) The only supported flag is ``MSG_PEEK``, + all other flags are ignored. :return: The number of bytes read into the buffer. """ if nbytes is None: @@ -1345,7 +1349,10 @@ def recv_into(self, buffer, nbytes=None, flags=None): # better if we could pass memoryviews straight into the SSL_read call, # but right now we can't. Revisit this if CFFI gets that ability. buf = _ffi.new("char[]", nbytes) - result = _lib.SSL_read(self._ssl, buf, nbytes) + if flags is not None and flags & socket.MSG_PEEK: + result = _lib.SSL_peek(self._ssl, buf, nbytes) + else: + result = _lib.SSL_read(self._ssl, buf, nbytes) self._raise_ssl_error(self._ssl, result) # This strange line is all to avoid a memory copy. The buffer protocol diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py index e586537f9..7eba493e2 100644 --- a/OpenSSL/test/test_ssl.py +++ b/OpenSSL/test/test_ssl.py @@ -8,7 +8,7 @@ from gc import collect, get_referrers from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN from sys import platform, getfilesystemencoding -from socket import SHUT_RDWR, error, socket +from socket import MSG_PEEK, SHUT_RDWR, error, socket from os import makedirs from os.path import join from unittest import main @@ -2171,6 +2171,16 @@ def test_pending_wrong_args(self): connection = Connection(Context(TLSv1_METHOD), None) self.assertRaises(TypeError, connection.pending, None) + def test_peek(self): + """ + :py:obj:`Connection.recv` peeks into the connection if :py:obj:`socket.MSG_PEEK` is passed. + """ + server, client = self._loopback() + server.send(b('xy')) + self.assertEqual(client.recv(2, MSG_PEEK), b('xy')) + self.assertEqual(client.recv(2, MSG_PEEK), b('xy')) + self.assertEqual(client.recv(2), b('xy')) + def test_connect_wrong_args(self): """ @@ -2998,6 +3008,15 @@ def test_bytearray_really_doesnt_overfill(self): """ self._doesnt_overfill_test(bytearray) + def test_peek(self): + + server, client = self._loopback() + server.send(b('xy')) + + for _ in range(2): + output_buffer = bytearray(5) + self.assertEqual(client.recv_into(output_buffer, flags=MSG_PEEK), 2) + self.assertEqual(output_buffer, bytearray(b('xy\x00\x00\x00'))) try: memoryview diff --git a/doc/api/ssl.rst b/doc/api/ssl.rst index 89ae6a1c2..054867802 100644 --- a/doc/api/ssl.rst +++ b/doc/api/ssl.rst @@ -669,11 +669,12 @@ Connection objects have the following methods: (**not** the underlying transport buffer). -.. py:method:: Connection.recv(bufsize) +.. py:method:: Connection.recv(bufsize[, flags]) Receive data from the Connection. The return value is a string representing the data received. The maximum amount of data to be received at once, is specified - by *bufsize*. + by *bufsize*. The only supported flag is ``MSG_PEEK``, all other flags are + ignored. .. py:method:: Connection.recv_into(buffer[, nbytes[, flags]]) @@ -681,8 +682,7 @@ Connection objects have the following methods: Receive data from the Connection and copy it directly into the provided buffer. The return value is the number of bytes read from the connection. The maximum amount of data to be received at once is specified by *nbytes*. - *flags* is accepted for compatibility with ``socket.recv_into`` but its - value is ignored. + The only supported flag is ``MSG_PEEK``, all other flags are ignored. .. py:method:: Connection.bio_write(bytes)