feat(chart-data-api): download multiple csvs as zip (#18618)
* feat(chart-data-api): download multiple csvs as zip

* break out util

* check for empty request
villebro authored Feb 8, 2022
1 parent 9c08bc0 commit 125be78
Showing 3 changed files with 63 additions and 7 deletions.
28 changes: 21 additions & 7 deletions superset/charts/data/api.py
@@ -21,7 +21,7 @@
 from typing import Any, Dict, Optional, TYPE_CHECKING
 
 import simplejson
-from flask import g, make_response, request
+from flask import current_app, g, make_response, request, Response
 from flask_appbuilder.api import expose, protect
 from flask_babel import gettext as _
 from marshmallow import ValidationError
@@ -44,13 +44,11 @@
 from superset.exceptions import QueryObjectValidationError
 from superset.extensions import event_logger
 from superset.utils.async_query_manager import AsyncQueryTokenException
-from superset.utils.core import json_int_dttm_ser
+from superset.utils.core import create_zip, json_int_dttm_ser
 from superset.views.base import CsvResponse, generate_download_headers
 from superset.views.base_api import statsd_metrics
 
 if TYPE_CHECKING:
-    from flask import Response
-
     from superset.common.query_context import QueryContext
 
 logger = logging.getLogger(__name__)
@@ -350,9 +348,25 @@ def _send_chart_response(
             if not security_manager.can_access("can_csv", "Superset"):
                 return self.response_403()
 
-            # return the first result
-            data = result["queries"][0]["data"]
-            return CsvResponse(data, headers=generate_download_headers("csv"))
+            if not result["queries"]:
+                return self.response_400(_("Empty query result"))
+
+            if len(result["queries"]) == 1:
+                # return single query results csv format
+                data = result["queries"][0]["data"]
+                return CsvResponse(data, headers=generate_download_headers("csv"))
+
+            # return multi-query csv results bundled as a zip file
+            encoding = current_app.config["CSV_EXPORT"].get("encoding", "utf-8")
+            files = {
+                f"query_{idx + 1}.csv": result["data"].encode(encoding)
+                for idx, result in enumerate(result["queries"])
+            }
+            return Response(
+                create_zip(files),
+                headers=generate_download_headers("zip"),
+                mimetype="application/zip",
+            )
 
         if result_format == ChartDataResultFormat.JSON:
             response_data = simplejson.dumps(
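
The multi-query branch encodes each CSV with the encoding configured under the CSV_EXPORT setting, falling back to "utf-8" as the diff shows. A minimal sketch of overriding it in superset_config.py; this is an assumption about a typical deployment, not part of this commit:

CSV_EXPORT = {
    # only the "encoding" key is read by the new zip branch; each query's CSV
    # text is encoded with it before being added to the archive
    "encoding": "utf-8",
}
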
12 changes: 12 additions & 0 deletions superset/utils/core.py
@@ -40,6 +40,7 @@
 from email.mime.text import MIMEText
 from email.utils import formatdate
 from enum import Enum, IntEnum
+from io import BytesIO
 from timeit import default_timer
 from types import TracebackType
 from typing import (
@@ -61,6 +62,7 @@
     Union,
 )
 from urllib.parse import unquote_plus
+from zipfile import ZipFile
 
 import bleach
 import markdown as md
@@ -1788,3 +1790,13 @@ def apply_max_row_limit(limit: int, max_limit: Optional[int] = None,) -> int:
     if limit != 0:
         return min(max_limit, limit)
     return max_limit
+
+
+def create_zip(files: Dict[str, Any]) -> BytesIO:
+    buf = BytesIO()
+    with ZipFile(buf, "w") as bundle:
+        for filename, contents in files.items():
+            with bundle.open(filename, "w") as fp:
+                fp.write(contents)
+    buf.seek(0)
+    return buf
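
A quick usage sketch of the new helper, assuming Superset is importable; the filenames and CSV payloads are made up for illustration:

from zipfile import ZipFile

from superset.utils.core import create_zip

# values must already be bytes, mirroring the API's result["data"].encode(encoding)
files = {
    "query_1.csv": "name,num\nAaron,100\n".encode("utf-8"),
    "query_2.csv": "name,num\nAmy,200\n".encode("utf-8"),
}
buf = create_zip(files)  # BytesIO rewound to the start of the archive
assert ZipFile(buf).namelist() == ["query_1.csv", "query_2.csv"]
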
30 changes: 30 additions & 0 deletions tests/integration_tests/charts/data/api_tests.py
@@ -20,8 +20,11 @@
 import unittest
 import copy
 from datetime import datetime
+from io import BytesIO
 from typing import Optional
 from unittest import mock
+from zipfile import ZipFile
+
 from flask import Response
 from tests.integration_tests.conftest import with_feature_flags
 from superset.models.sql_lab import Query
@@ -235,6 +238,16 @@ def test_with_query_result_type__200(self):
         rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
         assert rv.status_code == 200
 
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_empty_request_with_csv_result_format(self):
+        """
+        Chart data API: Test empty chart data with CSV result format
+        """
+        self.query_context_payload["result_format"] = "csv"
+        self.query_context_payload["queries"] = []
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+        assert rv.status_code == 400
+
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_with_csv_result_format(self):
         """
@@ -243,6 +256,22 @@ def test_with_csv_result_format(self):
         self.query_context_payload["result_format"] = "csv"
         rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
         assert rv.status_code == 200
+        assert rv.mimetype == "text/csv"
+
+    @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
+    def test_with_multi_query_csv_result_format(self):
+        """
+        Chart data API: Test chart data with multi-query CSV result format
+        """
+        self.query_context_payload["result_format"] = "csv"
+        self.query_context_payload["queries"].append(
+            self.query_context_payload["queries"][0]
+        )
+        rv = self.post_assert_metric(CHART_DATA_URI, self.query_context_payload, "data")
+        assert rv.status_code == 200
+        assert rv.mimetype == "application/zip"
+        zipfile = ZipFile(BytesIO(rv.data), "r")
+        assert zipfile.namelist() == ["query_1.csv", "query_2.csv"]
 
     @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices")
     def test_with_csv_result_format_when_actor_not_permitted_for_csv__403(self):
@@ -766,6 +795,7 @@ def test_chart_data_get(self):
             }
         )
         rv = self.get_assert_metric(f"api/v1/chart/{chart.id}/data/", "get_data")
+        assert rv.mimetype == "application/json"
         data = json.loads(rv.data.decode("utf-8"))
         assert data["result"][0]["status"] == "success"
         assert data["result"][0]["rowcount"] == 2
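
From a client's point of view the endpoint is unchanged except for the response mimetype once more than one query is submitted. A hedged sketch using requests; the host, bearer token, and payload fields below are placeholders rather than anything defined by this commit:

import io
import zipfile

import requests

payload = {
    "datasource": {"id": 1, "type": "table"},  # placeholder datasource
    "result_format": "csv",
    "result_type": "full",
    "queries": [
        {"columns": ["name"], "metrics": ["count"], "row_limit": 10},
        {"columns": ["name"], "metrics": ["count"], "row_limit": 10},
    ],
}
resp = requests.post(
    "http://localhost:8088/api/v1/chart/data",
    json=payload,
    headers={"Authorization": "Bearer <token>"},
)
if resp.headers["Content-Type"].startswith("application/zip"):
    # two or more queries: one query_<n>.csv per query, bundled in a zip
    with zipfile.ZipFile(io.BytesIO(resp.content)) as bundle:
        print(bundle.namelist())
else:
    # single query: plain text/csv body, same behaviour as before
    print(resp.text)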
