Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

disable csrf checking on all exception views unless explicitly turned on #2517

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 58 additions & 1 deletion pyramid/tests/test_viewderivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1297,6 +1297,64 @@ def inner_view(request):
result = view(None, request)
self.assertTrue(result is response)

def test_csrf_view_skipped_by_default_on_exception_view(self):
from pyramid.request import Request
def view(request):
raise ValueError
def excview(request):
return 'hello'
self.config.add_settings({'pyramid.require_default_csrf': 'yes'})
self.config.set_session_factory(
lambda request: DummySession({'csrf_token': 'foo'}))
self.config.add_view(view, name='foo', require_csrf=False)
self.config.add_view(excview, context=ValueError, renderer='string')
app = self.config.make_wsgi_app()
request = Request.blank('/foo', base_url='http://example.com')
request.method = 'POST'
response = request.get_response(app)
self.assertTrue(b'hello' in response.body)

def test_csrf_view_failed_on_explicit_exception_view(self):
from pyramid.exceptions import BadCSRFToken
from pyramid.request import Request
def view(request):
raise ValueError
def excview(request): pass
self.config.add_settings({'pyramid.require_default_csrf': 'yes'})
self.config.set_session_factory(
lambda request: DummySession({'csrf_token': 'foo'}))
self.config.add_view(view, name='foo', require_csrf=False)
self.config.add_view(excview, context=ValueError, renderer='string',
require_csrf=True)
app = self.config.make_wsgi_app()
request = Request.blank('/foo', base_url='http://example.com')
request.method = 'POST'
try:
request.get_response(app)
except BadCSRFToken:
pass
else: # pragma: no cover
raise AssertionError

def test_csrf_view_passed_on_explicit_exception_view(self):
from pyramid.request import Request
def view(request):
raise ValueError
def excview(request):
return 'hello'
self.config.add_settings({'pyramid.require_default_csrf': 'yes'})
self.config.set_session_factory(
lambda request: DummySession({'csrf_token': 'foo'}))
self.config.add_view(view, name='foo', require_csrf=False)
self.config.add_view(excview, context=ValueError, renderer='string',
require_csrf=True)
app = self.config.make_wsgi_app()
request = Request.blank('/foo', base_url='http://example.com')
request.method = 'POST'
request.headers['X-CSRF-Token'] = 'foo'
response = request.get_response(app)
self.assertTrue(b'hello' in response.body)


class TestDerivationOrder(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -1554,7 +1612,6 @@ def _getViewCallable(self, config, ctx_iface=None, request_iface=None,
from pyramid.interfaces import IRequest
from pyramid.interfaces import IView
from pyramid.interfaces import IViewClassifier
from pyramid.interfaces import IExceptionViewClassifier
classifier = IViewClassifier
if ctx_iface is None:
ctx_iface = Interface
Expand Down
24 changes: 16 additions & 8 deletions pyramid/viewderivers.py
Original file line number Diff line number Diff line change
Expand Up @@ -483,21 +483,29 @@ def csrf_view(view, info):
default_val = _parse_csrf_setting(
info.settings.get('pyramid.require_default_csrf'),
'Config setting "pyramid.require_default_csrf"')
val = _parse_csrf_setting(
explicit_val = _parse_csrf_setting(
info.options.get('require_csrf'),
'View option "require_csrf"')
if (val is True and default_val) or val is None:
val = default_val
if val is True:
val = 'csrf_token'
resolved_val = explicit_val
if (explicit_val is True and default_val) or explicit_val is None:
resolved_val = default_val
if resolved_val is True:
resolved_val = 'csrf_token'
wrapped_view = view
if val:
if resolved_val:
def csrf_view(context, request):
# Assume that anything not defined as 'safe' by RFC2616 needs
# protection
if request.method not in SAFE_REQUEST_METHODS:
if (
request.method not in SAFE_REQUEST_METHODS and
(
# skip exception views unless value is explicitly defined
getattr(request, 'exception', None) is None or
explicit_val is not None
)
):
check_csrf_origin(request, raises=True)
check_csrf_token(request, val, raises=True)
check_csrf_token(request, resolved_val, raises=True)
return view(context, request)
wrapped_view = csrf_view
return wrapped_view
Expand Down