diff --git a/sanic/static.py b/sanic/static.py index bb37e56121..256da6da50 100644 --- a/sanic/static.py +++ b/sanic/static.py @@ -1,3 +1,4 @@ +from functools import partial, wraps from mimetypes import guess_type from os import path from re import sub @@ -15,6 +16,91 @@ from sanic.response import HTTPResponse, file, file_stream +async def _static_request_handler( + file_or_directory, + use_modified_since, + use_content_range, + stream_large_files, + request, + content_type=None, + file_uri=None +): + # Using this to determine if the URL is trying to break out of the path + # served. os.path.realpath seems to be very slow + if file_uri and "../" in file_uri: + raise InvalidUsage("Invalid URL") + # Merge served directory and requested file if provided + # Strip all / that in the beginning of the URL to help prevent python + # from herping a derp and treating the uri as an absolute path + root_path = file_path = file_or_directory + if file_uri: + file_path = path.join( + file_or_directory, sub("^[/]*", "", file_uri) + ) + + # URL decode the path sent by the browser otherwise we won't be able to + # match filenames which got encoded (filenames with spaces etc) + file_path = path.abspath(unquote(file_path)) + if not file_path.startswith(path.abspath(unquote(root_path))): + raise FileNotFound( + "File not found", path=file_or_directory, relative_url=file_uri + ) + try: + headers = {} + # Check if the client has been sent this file before + # and it has not been modified since + stats = None + if use_modified_since: + stats = await stat_async(file_path) + modified_since = strftime( + "%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime) + ) + if request.headers.get("If-Modified-Since") == modified_since: + return HTTPResponse(status=304) + headers["Last-Modified"] = modified_since + _range = None + if use_content_range: + _range = None + if not stats: + stats = await stat_async(file_path) + headers["Accept-Ranges"] = "bytes" + headers["Content-Length"] = str(stats.st_size) + if request.method != "HEAD": + try: + _range = ContentRangeHandler(request, stats) + except HeaderNotFound: + pass + else: + del headers["Content-Length"] + for key, value in _range.headers.items(): + headers[key] = value + headers["Content-Type"] = ( + content_type or guess_type(file_path)[0] or "text/plain" + ) + if request.method == "HEAD": + return HTTPResponse(headers=headers) + else: + if stream_large_files: + if type(stream_large_files) == int: + threshold = stream_large_files + else: + threshold = 1024 * 1024 + + if not stats: + stats = await stat_async(file_path) + if stats.st_size >= threshold: + return await file_stream( + file_path, headers=headers, _range=_range + ) + return await file(file_path, headers=headers, _range=_range) + except ContentRangeError: + raise + except Exception: + raise FileNotFound( + "File not found", path=file_or_directory, relative_url=file_uri + ) + + def register( app, uri, @@ -56,86 +142,21 @@ def register( if not path.isfile(file_or_directory): uri += "" - async def _handler(request, file_uri=None): - # Using this to determine if the URL is trying to break out of the path - # served. os.path.realpath seems to be very slow - if file_uri and "../" in file_uri: - raise InvalidUsage("Invalid URL") - # Merge served directory and requested file if provided - # Strip all / that in the beginning of the URL to help prevent python - # from herping a derp and treating the uri as an absolute path - root_path = file_path = file_or_directory - if file_uri: - file_path = path.join( - file_or_directory, sub("^[/]*", "", file_uri) - ) - - # URL decode the path sent by the browser otherwise we won't be able to - # match filenames which got encoded (filenames with spaces etc) - file_path = path.abspath(unquote(file_path)) - if not file_path.startswith(path.abspath(unquote(root_path))): - raise FileNotFound( - "File not found", path=file_or_directory, relative_url=file_uri - ) - try: - headers = {} - # Check if the client has been sent this file before - # and it has not been modified since - stats = None - if use_modified_since: - stats = await stat_async(file_path) - modified_since = strftime( - "%a, %d %b %Y %H:%M:%S GMT", gmtime(stats.st_mtime) - ) - if request.headers.get("If-Modified-Since") == modified_since: - return HTTPResponse(status=304) - headers["Last-Modified"] = modified_since - _range = None - if use_content_range: - _range = None - if not stats: - stats = await stat_async(file_path) - headers["Accept-Ranges"] = "bytes" - headers["Content-Length"] = str(stats.st_size) - if request.method != "HEAD": - try: - _range = ContentRangeHandler(request, stats) - except HeaderNotFound: - pass - else: - del headers["Content-Length"] - for key, value in _range.headers.items(): - headers[key] = value - headers["Content-Type"] = ( - content_type or guess_type(file_path)[0] or "text/plain" - ) - if request.method == "HEAD": - return HTTPResponse(headers=headers) - else: - if stream_large_files: - if type(stream_large_files) == int: - threshold = stream_large_files - else: - threshold = 1024 * 1024 - - if not stats: - stats = await stat_async(file_path) - if stats.st_size >= threshold: - return await file_stream( - file_path, headers=headers, _range=_range - ) - return await file(file_path, headers=headers, _range=_range) - except ContentRangeError: - raise - except Exception: - raise FileNotFound( - "File not found", path=file_or_directory, relative_url=file_uri - ) - # special prefix for static files if not name.startswith("_static_"): name = f"_static_{name}" + _handler = wraps(_static_request_handler)( + partial( + _static_request_handler, + file_or_directory, + use_modified_since, + use_content_range, + stream_large_files, + content_type=content_type + ) + ) + app.route( uri, methods=["GET", "HEAD"], diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py index 3cd60e5630..ef1db52f40 100644 --- a/tests/test_multiprocessing.py +++ b/tests/test_multiprocessing.py @@ -87,3 +87,14 @@ def test_pickle_app_with_bp(app, protocol): request, response = up_p_app.test_client.get("/") assert up_p_app.is_request_stream is False assert response.text == "Hello" + +@pytest.mark.parametrize("protocol", [3, 4]) +def test_pickle_app_with_static(app, protocol): + app.route("/")(handler) + app.static('/static', "/tmp/static") + p_app = pickle.dumps(app, protocol=protocol) + del app + up_p_app = pickle.loads(p_app) + assert up_p_app + request, response = up_p_app.test_client.get("/static/missing.txt") + assert response.status == 404