Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enforce Datetime Type for Expires on Set-Cookie #1484

Merged
merged 4 commits into from
Feb 6, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 11 additions & 10 deletions sanic/cookies.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import re
import string

from datetime import datetime


DEFAULT_MAX_AGE = 0

Expand Down Expand Up @@ -108,6 +110,11 @@ def __setitem__(self, key, value):
if key.lower() == "max-age":
if not str(value).isdigit():
value = DEFAULT_MAX_AGE
elif key.lower() == "expires":
if not isinstance(value, datetime):
raise TypeError(
"Cookie 'expires' property must be a datetime"
)
return super().__setitem__(key, value)

def encode(self, encoding):
Expand All @@ -131,16 +138,10 @@ def encode(self, encoding):
except TypeError:
output.append("%s=%s" % (self._keys[key], value))
elif key == "expires":
try:
output.append(
"%s=%s"
% (
self._keys[key],
value.strftime("%a, %d-%b-%Y %T GMT"),
)
)
except AttributeError:
output.append("%s=%s" % (self._keys[key], value))
output.append(
"%s=%s"
% (self._keys[key], value.strftime("%a, %d-%b-%Y %T GMT"))
)
elif key in self._flags and self[key]:
output.append(self._keys[key])
else:
Expand Down
19 changes: 10 additions & 9 deletions tests/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
def uvloop_installed():
try:
import uvloop

return True
except ImportError:
return False
Expand All @@ -27,28 +28,28 @@ async def handler(request):
assert response.text == "pass"


@pytest.mark.skipif(sys.version_info < (3, 7),
reason="requires python3.7 or higher")
@pytest.mark.skipif(
sys.version_info < (3, 7), reason="requires python3.7 or higher"
)
def test_create_asyncio_server(app):
if not uvloop_installed():
loop = asyncio.get_event_loop()
asyncio_srv_coro = app.create_server(
return_asyncio_server=True)
asyncio_srv_coro = app.create_server(return_asyncio_server=True)
assert isawaitable(asyncio_srv_coro)
srv = loop.run_until_complete(asyncio_srv_coro)
assert srv.is_serving() is True


@pytest.mark.skipif(sys.version_info < (3, 7),
reason="requires python3.7 or higher")
@pytest.mark.skipif(
sys.version_info < (3, 7), reason="requires python3.7 or higher"
)
def test_asyncio_server_start_serving(app):
if not uvloop_installed():
loop = asyncio.get_event_loop()
asyncio_srv_coro = app.create_server(
return_asyncio_server=True,
asyncio_server_kwargs=dict(
start_serving=False
))
asyncio_server_kwargs=dict(start_serving=False),
)
srv = loop.run_until_complete(asyncio_srv_coro)
assert srv.is_serving() is False

Expand Down
24 changes: 13 additions & 11 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def test_config_custom_defaults():
custom_defaults = {
"REQUEST_MAX_SIZE": 1,
"KEEP_ALIVE": False,
"ACCESS_LOG": False
"ACCESS_LOG": False,
}
conf = Config(defaults=custom_defaults)
for key, value in DEFAULT_CONFIG.items():
Expand All @@ -182,13 +182,13 @@ def test_config_custom_defaults_with_env():
custom_defaults = {
"REQUEST_MAX_SIZE123": 1,
"KEEP_ALIVE123": False,
"ACCESS_LOG123": False
"ACCESS_LOG123": False,
}

environ_defaults = {
"SANIC_REQUEST_MAX_SIZE123": "2",
"SANIC_KEEP_ALIVE123": "True",
"SANIC_ACCESS_LOG123": "False"
"SANIC_ACCESS_LOG123": "False",
}

for key, value in environ_defaults.items():
Expand All @@ -201,8 +201,8 @@ def test_config_custom_defaults_with_env():
try:
value = int(value)
except ValueError:
if value in ['True', 'False']:
value = value == 'True'
if value in ["True", "False"]:
value = value == "True"

assert getattr(conf, key) == value

Expand All @@ -213,7 +213,7 @@ def test_config_custom_defaults_with_env():
def test_config_access_log_passing_in_run(app):
assert app.config.ACCESS_LOG == True

@app.listener('after_server_start')
@app.listener("after_server_start")
async def _request(sanic, loop):
app.stop()

Expand All @@ -227,16 +227,18 @@ async def _request(sanic, loop):
async def test_config_access_log_passing_in_create_server(app):
assert app.config.ACCESS_LOG == True

@app.listener('after_server_start')
@app.listener("after_server_start")
async def _request(sanic, loop):
app.stop()

await app.create_server(port=1341, access_log=False,
return_asyncio_server=True)
await app.create_server(
port=1341, access_log=False, return_asyncio_server=True
)
assert app.config.ACCESS_LOG == False

await app.create_server(port=1342, access_log=True,
return_asyncio_server=True)
await app.create_server(
port=1342, access_log=True, return_asyncio_server=True
)
assert app.config.ACCESS_LOG == True


Expand Down
13 changes: 9 additions & 4 deletions tests/test_cookies.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,10 +160,7 @@ def handler(request):
assert response.cookies["test"]["max-age"] == str(DEFAULT_MAX_AGE)


@pytest.mark.parametrize(
"expires",
[datetime.now() + timedelta(seconds=60), "Fri, 21-Dec-2018 15:30:00 GMT"],
)
@pytest.mark.parametrize("expires", [datetime.now() + timedelta(seconds=60)])
def test_cookie_expires(app, expires):
cookies = {"test": "wait"}

Expand All @@ -183,3 +180,11 @@ def handler(request):
expires = expires.strftime("%a, %d-%b-%Y %T GMT")

assert response.cookies["test"]["expires"] == expires


@pytest.mark.parametrize("expires", ["Fri, 21-Dec-2018 15:30:00 GMT"])
def test_cookie_expires_illegal_instance_type(expires):
c = Cookie("test_cookie", "value")
with pytest.raises(expected_exception=TypeError) as e:
c["expires"] = expires
assert e.message == "Cookie 'expires' property must be a datetime"
12 changes: 4 additions & 8 deletions tests/test_keep_alive_timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,8 @@
from sanic.testing import SanicTestClient, HOST, PORT


CONFIG_FOR_TESTS = {
"KEEP_ALIVE_TIMEOUT": 2,
"KEEP_ALIVE": True
}
CONFIG_FOR_TESTS = {"KEEP_ALIVE_TIMEOUT": 2, "KEEP_ALIVE": True}


class ReuseableTCPConnector(TCPConnector):
def __init__(self, *args, **kwargs):
Expand Down Expand Up @@ -51,9 +49,7 @@ def _sanic_endpoint_test(
uri="/",
gather_request=True,
debug=False,
server_kwargs={
"return_asyncio_server": True,
},
server_kwargs={"return_asyncio_server": True},
*request_args,
**request_kwargs
):
Expand Down Expand Up @@ -147,7 +143,7 @@ async def _collect_response(loop):
# loop, so the changes above are required too.
async def _local_request(self, method, uri, cookies=None, *args, **kwargs):
request_keepalive = kwargs.pop(
"request_keepalive", CONFIG_FOR_TESTS['KEEP_ALIVE_TIMEOUT']
"request_keepalive", CONFIG_FOR_TESTS["KEEP_ALIVE_TIMEOUT"]
)
if uri.startswith(("http:", "https:", "ftp:", "ftps://" "//")):
url = uri
Expand Down
12 changes: 4 additions & 8 deletions tests/test_logo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@


def test_logo_base(app, caplog):
server = app.create_server(
debug=True, return_asyncio_server=True)
server = app.create_server(debug=True, return_asyncio_server=True)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
Expand All @@ -32,8 +31,7 @@ def test_logo_base(app, caplog):
def test_logo_false(app, caplog):
app.config.LOGO = False

server = app.create_server(
debug=True, return_asyncio_server=True)
server = app.create_server(debug=True, return_asyncio_server=True)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
Expand All @@ -52,8 +50,7 @@ def test_logo_false(app, caplog):
def test_logo_true(app, caplog):
app.config.LOGO = True

server = app.create_server(
debug=True, return_asyncio_server=True)
server = app.create_server(debug=True, return_asyncio_server=True)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
Expand All @@ -72,8 +69,7 @@ def test_logo_true(app, caplog):
def test_logo_custom(app, caplog):
app.config.LOGO = "My Custom Logo"

server = app.create_server(
debug=True, return_asyncio_server=True)
server = app.create_server(debug=True, return_asyncio_server=True)
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop._stopping = False
Expand Down
4 changes: 1 addition & 3 deletions tests/test_request_timeout.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,7 @@ async def _local_request(self, method, uri, cookies=None, *args, **kwargs):
host=HOST, port=self.port, uri=uri
)
conn = DelayableTCPConnector(
pre_request_delay=self._request_delay,
ssl=False,
loop=self._loop,
pre_request_delay=self._request_delay, ssl=False, loop=self._loop
)
async with aiohttp.ClientSession(
cookies=cookies, connector=conn, loop=self._loop
Expand Down
3 changes: 1 addition & 2 deletions tests/test_server_events.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,7 @@ class MySanicDb:
async def init_db(app, loop):
app.db = MySanicDb()

await app.create_server(
debug=True, return_asyncio_server=True)
await app.create_server(debug=True, return_asyncio_server=True)

assert hasattr(app, "db")
assert isinstance(app.db, MySanicDb)
31 changes: 17 additions & 14 deletions tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,28 @@ def gunicorn_worker():
worker.kill()


@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def gunicorn_worker_with_access_logs():
command = (
'gunicorn '
'--bind 127.0.0.1:1338 '
'--worker-class sanic.worker.GunicornWorker '
'examples.simple_server:app'
"gunicorn "
"--bind 127.0.0.1:1338 "
"--worker-class sanic.worker.GunicornWorker "
"examples.simple_server:app"
)
worker = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE)
time.sleep(2)
return worker


@pytest.fixture(scope='module')
@pytest.fixture(scope="module")
def gunicorn_worker_with_env_var():
command = (
'env SANIC_ACCESS_LOG="False" '
'gunicorn '
'--bind 127.0.0.1:1339 '
'--worker-class sanic.worker.GunicornWorker '
'--log-level info '
'examples.simple_server:app'
"gunicorn "
"--bind 127.0.0.1:1339 "
"--worker-class sanic.worker.GunicornWorker "
"--log-level info "
"examples.simple_server:app"
)
worker = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE)
time.sleep(2)
Expand All @@ -62,7 +62,7 @@ def test_gunicorn_worker_no_logs(gunicorn_worker_with_env_var):
"""
if SANIC_ACCESS_LOG was set to False do not show access logs
"""
with urllib.request.urlopen('http://localhost:1339/') as _:
with urllib.request.urlopen("http://localhost:1339/") as _:
gunicorn_worker_with_env_var.kill()
assert not gunicorn_worker_with_env_var.stdout.read()

Expand All @@ -71,9 +71,12 @@ def test_gunicorn_worker_with_logs(gunicorn_worker_with_access_logs):
"""
default - show access logs
"""
with urllib.request.urlopen('http://localhost:1338/') as _:
with urllib.request.urlopen("http://localhost:1338/") as _:
gunicorn_worker_with_access_logs.kill()
assert b"(sanic.access)[INFO][127.0.0.1" in gunicorn_worker_with_access_logs.stdout.read()
assert (
b"(sanic.access)[INFO][127.0.0.1"
in gunicorn_worker_with_access_logs.stdout.read()
)


class GunicornTestWorker(GunicornWorker):
Expand Down