Skip to content

Commit

Permalink
disable csrf checking on all exception views unless explicitly turned on
Browse files Browse the repository at this point in the history
  • Loading branch information
mmerickel committed Apr 18, 2016
1 parent 8840437 commit 6f524a9
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 9 deletions.
59 changes: 58 additions & 1 deletion pyramid/tests/test_viewderivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,64 @@ def inner_view(request):
result = view(None, request)
self.assertTrue(result is response)

def test_csrf_view_skipped_by_default_on_exception_view(self):
from pyramid.request import Request
def view(request):
raise ValueError
def excview(request):
return 'hello'
self.config.add_settings({'pyramid.require_default_csrf': 'yes'})
self.config.set_session_factory(
lambda request: DummySession({'csrf_token': 'foo'}))
self.config.add_view(view, name='foo', require_csrf=False)
self.config.add_view(excview, context=ValueError, renderer='string')
app = self.config.make_wsgi_app()
request = Request.blank('/foo', base_url='http://example.com')
request.method = 'POST'
response = request.get_response(app)
self.assertTrue(b'hello' in response.body)

def test_csrf_view_failed_on_explicit_exception_view(self):
from pyramid.exceptions import BadCSRFToken
from pyramid.request import Request
def view(request):
raise ValueError
def excview(request): pass
self.config.add_settings({'pyramid.require_default_csrf': 'yes'})
self.config.set_session_factory(
lambda request: DummySession({'csrf_token': 'foo'}))
self.config.add_view(view, name='foo', require_csrf=False)
self.config.add_view(excview, context=ValueError, renderer='string',
require_csrf=True)
app = self.config.make_wsgi_app()
request = Request.blank('/foo', base_url='http://example.com')
request.method = 'POST'
try:
request.get_response(app)
except BadCSRFToken:
pass
else: # pragma: no cover
raise AssertionError

def test_csrf_view_passed_on_explicit_exception_view(self):
from pyramid.request import Request
def view(request):
raise ValueError
def excview(request):
return 'hello'
self.config.add_settings({'pyramid.require_default_csrf': 'yes'})
self.config.set_session_factory(
lambda request: DummySession({'csrf_token': 'foo'}))
self.config.add_view(view, name='foo', require_csrf=False)
self.config.add_view(excview, context=ValueError, renderer='string',
require_csrf=True)
app = self.config.make_wsgi_app()
request = Request.blank('/foo', base_url='http://example.com')
request.method = 'POST'
request.headers['X-CSRF-Token'] = 'foo'
response = request.get_response(app)
self.assertTrue(b'hello' in response.body)


class TestDerivationOrder(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -1554,7 +1612,6 @@ def _getViewCallable(self, config, ctx_iface=None, request_iface=None,
from pyramid.interfaces import IRequest
from pyramid.interfaces import IView
from pyramid.interfaces import IViewClassifier
from pyramid.interfaces import IExceptionViewClassifier
classifier = IViewClassifier
if ctx_iface is None:
ctx_iface = Interface
Expand Down
24 changes: 16 additions & 8 deletions pyramid/viewderivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,21 +483,29 @@ def csrf_view(view, info):
default_val = _parse_csrf_setting(
info.settings.get('pyramid.require_default_csrf'),
'Config setting "pyramid.require_default_csrf"')
val = _parse_csrf_setting(
explicit_val = _parse_csrf_setting(
info.options.get('require_csrf'),
'View option "require_csrf"')
if (val is True and default_val) or val is None:
val = default_val
if val is True:
val = 'csrf_token'
resolved_val = explicit_val
if (explicit_val is True and default_val) or explicit_val is None:
resolved_val = default_val
if resolved_val is True:
resolved_val = 'csrf_token'
wrapped_view = view
if val:
if resolved_val:
def csrf_view(context, request):
# Assume that anything not defined as 'safe' by RFC2616 needs
# protection
if request.method not in SAFE_REQUEST_METHODS:
if (
request.method not in SAFE_REQUEST_METHODS and
(
# skip exception views unless value is explicitly defined
getattr(request, 'exception', None) is None or
explicit_val is not None
)
):
check_csrf_origin(request, raises=True)
check_csrf_token(request, val, raises=True)
check_csrf_token(request, resolved_val, raises=True)
return view(context, request)
wrapped_view = csrf_view
return wrapped_view
Expand Down

0 comments on commit 6f524a9

Please sign in to comment.