Skip to content

Commit

Permalink
add ssl_peek functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
mhils committed Jul 25, 2015
1 parent e2a9ad3 commit 24093e7
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 4 deletions.
10 changes: 7 additions & 3 deletions OpenSSL/SSL.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import socket
from sys import platform
from functools import wraps, partial
from itertools import count, chain
Expand Down Expand Up @@ -1311,12 +1312,15 @@ def recv(self, bufsiz, flags=None):
method again with the SAME buffer.
:param bufsiz: The maximum number of bytes to read
:param flags: (optional) Included for compatibility with the socket
API, the value is ignored
:param flags: (optional) The only supported flag is
socket.MSG_PEEK. All other flags are ignored.
:return: The string read from the Connection
"""
buf = _ffi.new("char[]", bufsiz)
result = _lib.SSL_read(self._ssl, buf, bufsiz)
if flags and flags & socket.MSG_PEEK:
result = _lib.SSL_peek(self._ssl, buf, bufsiz)
else:
result = _lib.SSL_read(self._ssl, buf, bufsiz)
self._raise_ssl_error(self._ssl, result)
return _ffi.buffer(buf, result)[:]
read = recv
Expand Down
12 changes: 11 additions & 1 deletion OpenSSL/test/test_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from gc import collect, get_referrers
from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN
from sys import platform, getfilesystemencoding
from socket import SHUT_RDWR, error, socket
from socket import MSG_PEEK, SHUT_RDWR, error, socket
from os import makedirs
from os.path import join
from unittest import main
Expand Down Expand Up @@ -2171,6 +2171,16 @@ def test_pending_wrong_args(self):
connection = Connection(Context(TLSv1_METHOD), None)
self.assertRaises(TypeError, connection.pending, None)

def test_peek(self):
"""
:py:obj:`Connection.recv` peeks into the connect if :py:obj:`socket.MSG_PEEK` is passed.
"""
server, client = self._loopback()
server.send(b('xy'))
self.assertEquals(client.recv(2, MSG_PEEK), b('xy'))
self.assertEquals(client.recv(2, MSG_PEEK), b('xy'))
self.assertEquals(client.recv(2), b('xy'))


def test_connect_wrong_args(self):
"""
Expand Down

0 comments on commit 24093e7

Please sign in to comment.