Skip to content

Commit

Permalink
Check address family to fill wsgi env properly
Browse files Browse the repository at this point in the history
  • Loading branch information
Sergey Skripnick committed Jan 4, 2016
1 parent 35b1a0a commit 6e48b17
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 26 deletions.
35 changes: 22 additions & 13 deletions aiohttp/wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import inspect
import io
import os
import socket
import sys
from urllib.parse import urlsplit

Expand Down Expand Up @@ -90,20 +91,28 @@ def create_wsgi_environ(self, message, payload):
# which this request is received from the client.
# http://www.ietf.org/rfc/rfc3875

remote = self.transport.get_extra_info('peername')
if remote:
environ['REMOTE_ADDR'] = remote[0]
environ['REMOTE_PORT'] = remote[1]
_host, port = self.transport.get_extra_info('sockname')
environ['SERVER_PORT'] = str(port)
host = message.headers.get("HOST", None)
# SERVER_NAME should be set to value of Host header, but this
# header is not required. In this case we shoud set it to local
# address of socket
environ['SERVER_NAME'] = host.split(":")[0] if host else _host
family = self.transport.get_extra_info('socket').family
if family in (socket.AF_INET, socket.AF_INET6):
peername = self.transport.get_extra_info('peername')
environ['REMOTE_ADDR'] = peername[0]
environ['REMOTE_PORT'] = str(peername[1])
http_host = message.headers.get("HOST", None)
if http_host:
hostport = http_host.split(":")
environ['SERVER_NAME'] = hostport[0]
if len(hostport) > 1:
environ['SERVER_PORT'] = str(hostport[1])
else:
environ['SERVER_PORT'] = '80'
else:
# SERVER_NAME should be set to value of Host header, but this
# header is not required. In this case we shoud set it to local
# address of socket
sockname = self.transport.get_extra_info('sockname')
environ['SERVER_NAME'] = sockname[0]
environ['SERVER_PORT'] = str(sockname[1])
else:
# Dealing with unix socket, so request was received from client by
# upstream server and this data may be found in the headers
# We are behind reverse proxy, so get all vars from headers
for header in ('REMOTE_ADDR', 'REMOTE_PORT',
'SERVER_NAME', 'SERVER_PORT'):
environ[header] = message.headers.get(header, '')
Expand Down
38 changes: 25 additions & 13 deletions tests/test_wsgi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import io
import asyncio
import socket
import unittest
import unittest.mock

Expand All @@ -22,8 +23,10 @@ def setUp(self):
self.writer = unittest.mock.Mock()
self.writer.drain.return_value = ()
self.transport = unittest.mock.Mock()
self.transport.get_extra_info.side_effect = [('1.2.3.4', 1234),
('2.3.4.5', 80)]
self.transport.get_extra_info.side_effect = [
unittest.mock.Mock(family=socket.AF_INET),
('1.2.3.4', 1234),
('2.3.4.5', 80)]

self.headers = multidict.MultiDict({"HOST": "python.org"})
self.message = protocol.RawRequestMessage(
Expand Down Expand Up @@ -78,26 +81,20 @@ def test_environ_headers(self):
self.assertEqual(environ['SERVER_PORT'], '80')
get_extra_info_calls = self.transport.get_extra_info.mock_calls
expected_calls = [
unittest.mock.call('socket'),
unittest.mock.call('peername'),
unittest.mock.call('sockname'),
]
self.assertEqual(expected_calls, get_extra_info_calls)

def test_environ_host_header_alternate_port(self):
self.transport.get_extra_info = unittest.mock.Mock(
side_effect=[('1.2.3.4', 1234), ('3.4.5.6', 82)]
)
self.headers.update({'HOST': 'example.com:9999'})
environ = self._make_one()
self.assertEqual(environ['SERVER_PORT'], '82')
self.assertEqual(environ['SERVER_PORT'], '9999')

def test_environ_host_header_alternate_port_ssl(self):
self.transport.get_extra_info = unittest.mock.Mock(
side_effect=[('1.2.3.4', 1234), ('3.4.5.6', 82)]
)
self.headers.update({'HOST': 'example.com:9999'})
environ = self._make_one(is_ssl=True)
self.assertEqual(environ['SERVER_PORT'], '82')
self.assertEqual(environ['SERVER_PORT'], '9999')

def test_wsgi_response(self):
srv = self._make_srv()
Expand Down Expand Up @@ -276,8 +273,23 @@ def test_http_1_0_no_host(self):
self.assertEqual(environ['SERVER_NAME'], '2.3.4.5')
self.assertEqual(environ['SERVER_PORT'], '80')

def test_unix_socket(self):
self.transport.get_extra_info = unittest.mock.Mock(return_value=None)
def test_family_inet6(self):
self.transport.get_extra_info.side_effect = [
unittest.mock.Mock(family=socket.AF_INET6),
("::", 1122, 0, 0),
('2.3.4.5', 80)]
self.message = protocol.RawRequestMessage(
'GET', '/', (1, 0), self.headers, True, 'deflate')
environ = self._make_one()
self.assertEqual(environ['SERVER_NAME'], 'python.org')
self.assertEqual(environ['SERVER_PORT'], '80')
self.assertEqual(environ['REMOTE_ADDR'], '::')
self.assertEqual(environ['REMOTE_PORT'], '1122')

def test_family_unix(self):
fake_socket = unittest.mock.Mock(family=socket.AF_UNIX)
self.transport.get_extra_info = unittest.mock.Mock(
return_value=fake_socket)
headers = multidict.MultiDict({
'SERVER_NAME': '1.2.3.4', 'SERVER_PORT': '5678',
'REMOTE_ADDR': '4.3.2.1', 'REMOTE_PORT': '8765'})
Expand Down

0 comments on commit 6e48b17

Please sign in to comment.