Initial read_gbq implementation (WIP) #1

Closed
wants to merge 16 commits
Changes from 3 commits
17 changes: 17 additions & 0 deletions .pre-commit-config.yaml
@@ -0,0 +1,17 @@
repos:
  - repo: https://github.com/psf/black
    rev: 20.8b1
    hooks:
      - id: black
        language_version: python3
        exclude: versioneer.py
  - repo: https://gitlab.com/pycqa/flake8
    rev: 3.8.3
    hooks:
      - id: flake8
        language_version: python3
  - repo: https://github.com/pycqa/isort
    rev: 5.8.0
    hooks:
      - id: isort
        language_version: python3
1 change: 1 addition & 0 deletions dask_bigquery/__init__.py
@@ -0,0 +1 @@
from .core import read_gbq
236 changes: 236 additions & 0 deletions dask_bigquery/core.py
@@ -0,0 +1,236 @@
from __future__ import annotations

import logging
import warnings
from collections.abc import Iterable
from contextlib import contextmanager
from functools import partial

import dask
import dask.dataframe as dd
import pandas as pd
import pyarrow
from google.cloud import bigquery, bigquery_storage


@contextmanager
def bigquery_client(project_id="dask-bigquery", with_storage_api=False):
Contributor

maybe link to googleapis/google-cloud-python#9457 and/or googleapis/gapic-generator-python#575 for context (no pun intended)

Contributor

also re: project_id, you probably ought to make it default to None and infer it from the global context; I don't remember off the top of my head where you grab that from
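For reference, a minimal sketch of how that could look, assuming the ambient project is picked up with google.auth.default(), which returns a (credentials, project) pair; the helper names here are made up for illustration and are not part of this PR:

import google.auth
from google.cloud import bigquery


def default_project():
    # google.auth.default() inspects the environment (GOOGLE_APPLICATION_CREDENTIALS,
    # gcloud config, GCE metadata, ...) and returns a (credentials, project) pair.
    _, project = google.auth.default()
    return project


def make_bq_client(project_id=None):
    # Fall back to the ambient project when the caller does not pass one explicitly.
    return bigquery.Client(project_id or default_project())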

Contributor Author

maybe link to googleapis/google-cloud-python#9457 and/or googleapis/gapic-generator-python#575 for context (no pun intended)

Do you mean add them as a comment?

Contributor

yeah, just bc eventually there will probably be one upstream that could be used but right now there's not

    # Ignore google auth credentials warning
Contributor
@bnaul Aug 5, 2021

probably delete this filterwarnings (although the warning is super annoying 🙃)

    warnings.filterwarnings(
        "ignore", "Your application has authenticated using end user credentials"
    )

    bq_storage_client = None
    bq_client = bigquery.Client(project_id)
    try:
        if with_storage_api:
            bq_storage_client = bigquery_storage.BigQueryReadClient(
                credentials=bq_client._credentials
            )
            yield bq_client, bq_storage_client
        else:
            yield bq_client
    finally:
        bq_client.close()


def _stream_to_dfs(bqs_client, stream_name, schema, timeout):
    """Given a Storage API client and a stream name, yield all dataframes."""
    return [
        pyarrow.ipc.read_record_batch(
            pyarrow.py_buffer(message.arrow_record_batch.serialized_record_batch),
            schema,
        ).to_pandas()
        for message in bqs_client.read_rows(name=stream_name, offset=0, timeout=timeout)
    ]


@dask.delayed
def _read_rows_arrow(
Member

If we use delayed (see comment below) then this name will be the name that shows up in the task stream, progress bars, etc. We may want to make it more clearly GBQ-related, like bigquery_read

    *,
    make_create_read_session_request: callable,
    partition_field: str = None,
    project_id: str,
    stream_name: str = None,
    timeout: int,
) -> pd.DataFrame:
    """Read a single batch of rows via the BQ Storage API, in Arrow binary format.

    Args:
        project_id: BigQuery project
        make_create_read_session_request: callable returning the request to pass to
            `bqs_client.create_read_session`
        partition_field: BigQuery field for partitions, to be used as the Dask index col for
            divisions
            NOTE: Please set if specifying `row_restriction` filters in TableReadOptions.
        stream_name: BigQuery Storage API Stream "name".
            NOTE: Please set if reading from the Storage API without any `row_restriction`.
            https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream

    NOTE: the `partition_field` and `stream_name` kwargs are mutually exclusive.

    Adapted from
    https://github.com/googleapis/python-bigquery-storage/blob/a0fc0af5b4447ce8b50c365d4d081b9443b8490e/google/cloud/bigquery_storage_v1/reader.py.
    """
    with bigquery_client(project_id, with_storage_api=True) as (bq_client, bqs_client):
        session = bqs_client.create_read_session(make_create_read_session_request())
        schema = pyarrow.ipc.read_schema(
            pyarrow.py_buffer(session.arrow_schema.serialized_schema)
        )

        if (partition_field is not None) and (stream_name is not None):
            raise ValueError(
                "The kwargs `partition_field` and `stream_name` are mutually exclusive."
            )

        elif partition_field is not None:
            shards = [
                df
                for stream in session.streams
                for df in _stream_to_dfs(
                    bqs_client, stream.name, schema, timeout=timeout
                )
            ]
            # NOTE: if no rows satisfy the row_restriction, then `shards` will be an empty list
            if len(shards) == 0:
                shards = [schema.empty_table().to_pandas()]
            shards = [shard.set_index(partition_field, drop=True) for shard in shards]

        elif stream_name is not None:
            shards = _stream_to_dfs(bqs_client, stream_name, schema, timeout=timeout)
            # NOTE: BQ Storage API can return empty streams
            if len(shards) == 0:
                shards = [schema.empty_table().to_pandas()]

        else:
            raise NotImplementedError(
                "Please specify either `partition_field` or `stream_name`."
            )

    return pd.concat(shards)


def read_gbq(
    project_id: str,
    dataset_id: str,
    table_id: str,
    partition_field: str = None,
    partitions: Iterable[str] = None,
    row_filter="",
    fields: list[str] = (),
    read_timeout: int = 3600,
):
    """Read a table as a dask dataframe using the BigQuery Storage API, in Arrow format.

    If `partition_field` and `partitions` are specified, then the resulting dask dataframe
    will be partitioned along the same boundaries. Otherwise, partitions will be approximately
    balanced according to BigQuery stream allocation logic.

    If `partition_field` is specified but not included in `fields` (either implicitly by requesting
    all fields, or explicitly by inclusion in the list `fields`), then it will still be included
    in the query in order to have it available for dask dataframe indexing.

    Args:
        project_id: BigQuery project
        dataset_id: BigQuery dataset within project
        table_id: BigQuery table within dataset
        partition_field: to specify filters of form "WHERE {partition_field} = ..."
        partitions: all values to select of `partition_field`
        row_filter: BigQuery row restriction, passed through to the Storage API read session
        fields: names of the fields (columns) to select (default is to select all, i.e. "SELECT *")
        read_timeout: number of seconds an individual read request has before timing out

    Returns:
        dask dataframe

    See https://github.com/dask/dask/issues/3121 for additional context.
    """
    if (partition_field is None) and (partitions is not None):
        raise ValueError("Specified `partitions` without `partition_field`.")

    # If `partition_field` is not part of the `fields` filter, fetch it anyway so it can be
    # set as the dask dataframe index. We want BQ partitioning, dask divisions, and pandas
    # index values to be consistent.
    if (partition_field is not None) and fields and (partition_field not in fields):
        fields = (partition_field, *fields)

    # These read tasks seem to cause deadlocks (or at least long-stuck workers out of touch with
Contributor

this annotate is maybe a bad idea, would be nice to have @jrbourbeau or someone weigh in; note that we observed this behavior with now-fairly old dask and bigquery_storage/pyarrow versions so I have no idea if it's still relevant

    # the scheduler), particularly when mixed with other tasks that execute C code. Anecdotally
    # annotating the tasks with a higher priority seems to help (but not fully solve) the issue at
    # the expense of higher cluster memory usage.
    with bigquery_client(project_id, with_storage_api=True) as (
        bq_client,
        bqs_client,
    ), dask.annotate(priority=1):
Member

We would definitely prefer to not have this annotation if possible. Data generation tasks should be *de-*prioritized if anything

        table_ref = bq_client.get_table(".".join((dataset_id, table_id)))
        if table_ref.table_type == "VIEW":
            # Materialize the view since the operations below don't work on views.
            logging.warning(
                "Materializing view in order to read into dask. This may be expensive."
            )
            query = f"SELECT * FROM `{full_id(table_ref)}`"
Contributor

this is just a shortcut we use

def full_id(table):
    return f"{table.project}.{table.dataset_id}.{table.table_id}"

your call whether to inline it or add the helper

also same comment re: this view behavior, not sure whether it'd be safer here to just raise instead of materializing

            table_ref, _, _ = execute_query(query)
Contributor Author

@bnaul what about the execute_query function, is this also a shortcut?

Contributor

ah yes, it's just a wrapper for bigquery.Client.query() that saves the result to a temporary table. but it needs a temporary dataset to store things in which not everyone would have configured, so again maybe it's better to just raise for VIEWs instead
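For anyone reading along, a rough sketch of such a wrapper, assuming a temporary dataset exists; the name execute_query, the dataset name, and the single-table return value are illustrative (the version referenced above apparently returns a 3-tuple), not code from this repo:

import uuid

from google.cloud import bigquery


def execute_query(bq_client, query, tmp_dataset="dask_bigquery_tmp"):
    # Run the query and write its result to a temporary destination table,
    # so the Storage API can then read it like any ordinary table.
    destination = f"{bq_client.project}.{tmp_dataset}.query_{uuid.uuid4().hex}"
    job_config = bigquery.QueryJobConfig(destination=destination)
    bq_client.query(query, job_config=job_config).result()  # block until done
    return bq_client.get_table(destination)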

Contributor Author
@ncclementi Aug 6, 2021

Oh I see, if we remove this part I'm assuming it'll raise an error on its own when we hit a "VIEW" case, or is there a need to do a custom raise like:

if table_ref.table_type == "VIEW":
   raise TypeError('Table type VIEW not supported')


        # The protobuf types can't be pickled (may be able to tweak w/ copyreg), so instead use a
        # generator func.
        def make_create_read_session_request(row_filter=""):
            return bigquery_storage.types.CreateReadSessionRequest(
                max_stream_count=100,  # 0 -> use as many streams as BQ Storage will provide
                parent=f"projects/{project_id}",
                read_session=bigquery_storage.types.ReadSession(
                    data_format=bigquery_storage.types.DataFormat.ARROW,
                    read_options=bigquery_storage.types.ReadSession.TableReadOptions(
                        row_restriction=row_filter,
                        selected_fields=fields,
                    ),
                    table=table_ref.to_bqstorage(),
                ),
            )
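        # Aside (not part of this diff): the copyreg route mentioned above could look roughly
        # like the sketch below. Proto-plus request messages expose serialize()/deserialize()
        # classmethods, so a reducer can round-trip them through bytes. This is a hedged
        # illustration of the alternative, not what the PR implements.
        #
        #   import copyreg
        #
        #   def _reduce_request(msg):
        #       cls = type(msg)
        #       return cls.deserialize, (cls.serialize(msg),)
        #
        #   copyreg.pickle(bigquery_storage.types.CreateReadSessionRequest, _reduce_request)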

        # Create a read session in order to detect the schema.
        # Read sessions are light weight and will be auto-deleted after 24 hours.
        session = bqs_client.create_read_session(
            make_create_read_session_request(row_filter=row_filter)
        )
        schema = pyarrow.ipc.read_schema(
            pyarrow.py_buffer(session.arrow_schema.serialized_schema)
        )
        meta = schema.empty_table().to_pandas()
        delayed_kwargs = dict(prefix=f"{dataset_id}.{table_id}-")

        if partition_field is not None:
            if row_filter:
                raise ValueError("Cannot pass both `partition_field` and `row_filter`")
            delayed_kwargs["meta"] = meta.set_index(partition_field, drop=True)

            if partitions is None:
                logging.info(
                    "Specified `partition_field` without `partitions`; reading full table."
                )
                partitions = pd.read_gbq(
                    f"SELECT DISTINCT {partition_field} FROM {dataset_id}.{table_id}",

This will do a complete scan of the table. Maybe consider using something like return [p for p in bq_client.list_partitions(f'{dataset_id}.{table_id}') if p != '__NULL__']

Contributor Author

@shaayohn thanks for the comment. I got a bit confused by the return; do you mean replacing this:

partitions = pd.read_gbq(
                    f"SELECT DISTINCT {partition_field} FROM {dataset_id}.{table_id}",
                    project_id=project_id,
                )[partition_field].tolist()

with

partitions = [p for p in bq_client.list_partitions(f'{dataset_id}.{table_id}') if p != '__NULL__']


Apologies, that's exactly what I meant! Silly copypasta :)

                    project_id=project_id,
                )[partition_field].tolist()
                # TODO generalize to ranges (as opposed to discrete values)

            partitions = sorted(partitions)
            delayed_kwargs["divisions"] = (*partitions, partitions[-1])
Contributor Author

@bnaul I noticed in the example I ran that this line causes the last partition to contain only 1 element, even though that element could have fit into the second-to-last partition. What is the reason you separate out the last partition?

Contributor
@bnaul Aug 11, 2021

not sure why it's not working correctly for you but the idea is that you need n+1 divisions for n partitions. seems to work OK here

import numpy as np
import pandas as pd

import dask.dataframe as dd
from dask import delayed

@delayed
def make_df(d):
    return pd.DataFrame({"date": d, "x": np.random.random(10)}).set_index("date")

dates = pd.date_range("2020-01-01", "2020-01-08")
ddf = dd.from_delayed([make_df(d) for d in dates], divisions=[*dates, dates[-1]])

ddf
Out[61]:
Dask DataFrame Structure:
                     x
npartitions=8
2020-01-01     float64
2020-01-02         ...
...                ...
2020-01-08         ...
2020-01-08         ...
Dask Name: from-delayed, 16 tasks


ddf.map_partitions(len).compute()
Out[62]:
0    10
1    10
2    10
3    10
4    10
5    10
6    10
7    10
dtype: int64

Contributor Author

I wonder if it's related to how the data is originally partitioned. For example, when I read one of the tables of the COVID public dataset that I copied into "my_project", I see this:

from dask_bigquery import read_gbq

ddf = read_gbq(
    project_id="my_project",
    dataset_id="covid19_public_forecasts",
    table_id="county_14d",
)

ddf.map_partitions(len).compute()

Notice the last two partitions...

0     3164
1     3164
2     3164
3     3164
4     3164
5     3164
6     3164
7     3164
8     3164
9     3164
10    3164
11    3164
12    3164
13    3164
14    3164
15    3164
16    3164
17    3164
18    3164
19    3164
20    3164
21    3164
22    3164
23    3164
24    3164
25    3164
26    3164
27    3164
28    3164
29    3164
30    3164
31    3164
32    3164
33    3164
34    3164
35    3164
36    3164
37    3164
38    3164
39    3164
40    3164
41    3163
42       1
dtype: int64

            row_filters = [
                f'{partition_field} = "{partition_value}"'
                for partition_value in partitions
            ]
            delayed_dfs = [
                _read_rows_arrow(
                    make_create_read_session_request=partial(
                        make_create_read_session_request, row_filter=row_filter
                    ),
                    partition_field=partition_field,
                    project_id=project_id,
                    timeout=read_timeout,
                )
                for row_filter in row_filters
            ]
Member

This is great for now, but at some point we may want to use raw task graphs. They're a bit cleaner in a few ways. Delayed is more designed for user code. If we have the time we prefer to use raw graphs in dev code.

For example, in some cases I wouldn't be surprised if each Delayed task produces a single TaskGroup, rather than having all of the tasks in a single TaskGroup. Sure, this will compute just fine, but other features (like the task group visualization, or coiled telemetry) may be sad.

Contributor Author

@jrbourbeau and I gave using HighLevelGraphs a try, and we realized this will require modifying the structure of _read_rows_arrow, since as it is now the inputs don't match the format required by DataFrameIOLayer
https://github.com/dask/dask/blob/95fb60a31a87c6b94b01ed75ab6533fa04d51f19/dask/layers.py#L1159-L1166

We might want to move this to a separate PR.
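For reference, a minimal sketch of the raw-graph approach being discussed, assuming a plain partition-reading function bigquery_read(project_id, row_filter, timeout); this is an illustration of the idea, not the PR's implementation or the eventual DataFrameIOLayer-based design:

import dask.dataframe as dd
from dask.base import tokenize
from dask.highlevelgraph import HighLevelGraph


def read_gbq_raw_graph(bigquery_read, row_filters, meta, project_id, timeout):
    # One task per row filter; all tasks share a single layer/name, so they land in
    # one TaskGroup instead of one group per Delayed object.
    name = "read-gbq-" + tokenize(project_id, row_filters)
    dsk = {
        (name, i): (bigquery_read, project_id, row_filter, timeout)
        for i, row_filter in enumerate(row_filters)
    }
    graph = HighLevelGraph.from_collections(name, dsk, dependencies=[])
    divisions = [None] * (len(row_filters) + 1)  # unknown divisions
    return dd.DataFrame(graph, name, meta, divisions)

Compared with dd.from_delayed, this keeps every partition under one layer name, which is what the task-group/telemetry comment above is getting at.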

        else:
            delayed_kwargs["meta"] = meta
            delayed_dfs = [
                _read_rows_arrow(
                    make_create_read_session_request=make_create_read_session_request,
                    project_id=project_id,
                    stream_name=stream.name,
                    timeout=read_timeout,
                )
                for stream in session.streams
            ]

        return dd.from_delayed(dfs=delayed_dfs, **delayed_kwargs)
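Based on the docstring above, a usage sketch of the two read modes (the project/dataset/table names and the partition column are made up for illustration):

from dask_bigquery import read_gbq

# Let BigQuery's stream allocation decide the partitioning:
ddf = read_gbq(
    project_id="my_project",
    dataset_id="my_dataset",
    table_id="my_table",
)

# Or partition along a column, one dask partition per value, so the result
# has known divisions and is indexed on that column:
ddf_by_date = read_gbq(
    project_id="my_project",
    dataset_id="my_dataset",
    table_id="my_table",
    partition_field="date",
    partitions=["2021-08-01", "2021-08-02", "2021-08-03"],
)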
7 changes: 7 additions & 0 deletions requirements.txt
@@ -0,0 +1,7 @@
dask
distributed
google-cloud-bigquery
google-cloud-bigquery-storage
pandas
pandas-gbq
pyarrow
4 changes: 4 additions & 0 deletions setup.cfg
@@ -0,0 +1,4 @@
[flake8]
exclude = __init__.py
max-line-length = 120
ignore = F811
21 changes: 21 additions & 0 deletions setup.py
@@ -0,0 +1,21 @@
#!/usr/bin/env python

from setuptools import setup

with open("README.md", "r", encoding="utf-8") as f:
    long_description = f.read()

setup(
    name="dask-bigquery",
    version="0.0.1",
    description="Dask + BigQuery integration",
    license="BSD",
    packages=["dask_bigquery"],
    long_description=long_description,
    long_description_content_type="text/markdown",
    python_requires=">=3.7",
    install_requires=open("requirements.txt").read().strip().split("\n"),
    extras_require={"test": ["pytest"]},
    include_package_data=True,
    zip_safe=False,
)