style: ruff format
z3z1ma committed Jul 1, 2024
1 parent f52531c commit 81046a4
Showing 7 changed files with 115 additions and 38 deletions.
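The changes shown below are mechanical: ruff format re-wraps statements that overrun the formatter's line-length limit, with no change in logic. As a minimal sketch of the pattern, based on the get_resolved_schema signature from core.py (the Example class, the _before name, and the empty bodies are placeholders for illustration, not code from this repository; the wrapping is consistent with Ruff's default 88-character limit, though the project's configured value is not visible on this page):

from typing import List

from google.cloud import bigquery  # assumes google-cloud-bigquery is installed


class Example:
    # Before: a single signature line roughly 100 characters wide, over the limit.
    def get_resolved_schema_before(self, apply_transforms: bool = False) -> List[bigquery.SchemaField]:
        return []

    # After ruff format: the parameters move to their own indented line, and the
    # closing parenthesis plus return annotation start a new line.
    def get_resolved_schema(
        self, apply_transforms: bool = False
    ) -> List[bigquery.SchemaField]:
        return []

The same reflow repeats throughout the diff: long calls, comprehensions, and boolean expressions are split across lines, leaving behavior unchanged.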
77 changes: 58 additions & 19 deletions target_bigquery/core.py
@@ -129,7 +129,9 @@ def get_escaped_name(self, suffix: str = "") -> str:
"""Returns the table name as as escaped SQL string."""
return f"`{self.project}`.`{self.dataset}`.`{self.name}{suffix}`"

def get_resolved_schema(self, apply_transforms: bool = False) -> List[bigquery.SchemaField]:
def get_resolved_schema(
self, apply_transforms: bool = False
) -> List[bigquery.SchemaField]:
"""Returns the schema for this table after factoring in the ingestion strategy."""
if self.ingestion_strategy is IngestionStrategy.FIXED:
return DEFAULT_SCHEMA
@@ -184,14 +186,17 @@ def create_table(
try:
self._dataset = client.get_dataset(self.as_dataset_ref())
except NotFound:
self._dataset = client.create_dataset(self.as_dataset(**kwargs["dataset"]))
self._dataset = client.create_dataset(
self.as_dataset(**kwargs["dataset"])
)
if not hasattr(self, "_table"):
try:
self._table = client.get_table(self.as_ref())
except NotFound:
self._table = client.create_table(
self.as_table(
apply_transforms and self.ingestion_strategy != IngestionStrategy.FIXED,
apply_transforms
and self.ingestion_strategy != IngestionStrategy.FIXED,
**kwargs["table"],
)
)
@@ -208,7 +213,9 @@ def default_table_options(self) -> Dict[str, Any]:
"Generated by target-bigquery.\nStream Schema\n{schema}\nBigQuery Ingestion"
" Strategy: {strategy}".format(
schema=(
(schema_dump[:16000] + "...") if len(schema_dump) > 16000 else schema_dump
(schema_dump[:16000] + "...")
if len(schema_dump) > 16000
else schema_dump
),
strategy=self.ingestion_strategy,
)
@@ -224,7 +231,9 @@ def default_dataset_options() -> Dict[str, Any]:
return {"location": "US"}

def __hash__(self) -> int:
return hash((self.name, self.dataset, self.project, json.dumps(self.jsonschema)))
return hash(
(self.name, self.dataset, self.project, json.dumps(self.jsonschema))
)


@dataclass
@@ -268,7 +277,9 @@ def run(self) -> None:

def serialize_exception(self, exc: Exception) -> str:
"""Serialize an exception to a string."""
msg = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__, chain=False))
msg = "".join(
traceback.format_exception(type(exc), exc, exc.__traceback__, chain=False)
)
msg += f"\nWorker ID: {self.ext_id}\n"
return msg

@@ -442,7 +453,9 @@ def generate_view(self) -> bool:
def _validate_and_parse(self, record: dict) -> dict:
return record

def preprocess_record(self, record: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
def preprocess_record(
self, record: Dict[str, Any], context: Dict[str, Any]
) -> Dict[str, Any]:
"""Preprocess a record before writing it to the sink."""
metadata = {
k: record.pop(k, None)
@@ -531,7 +544,9 @@ def clean_up(self) -> None:
ctas_tmp = f"CREATE OR REPLACE TEMP TABLE `{tmp}` AS {dedupe_query}"
merge_clause = (
f"MERGE `{self.merge_target}` AS target USING `{tmp or self.table}` AS source ON "
+ " AND ".join(f"target.`{f}` = source.`{f}`" for f in self.key_properties)
+ " AND ".join(
f"target.`{f}` = source.`{f}`" for f in self.key_properties
)
)
update_clause = "UPDATE SET " + ", ".join(
f"target.`{f.name}` = source.`{f.name}`" for f in target.schema
@@ -615,7 +630,9 @@ def augmented_syspath(new_paths: Optional[Iterable[str]] = None):
def bigquery_client_factory(creds: BigQueryCredentials) -> bigquery.Client:
"""Get a BigQuery client."""
if creds.path:
return bigquery.Client.from_service_account_json(creds.path, project=creds.project)
return bigquery.Client.from_service_account_json(
creds.path, project=creds.project
)
elif creds.json:
return bigquery.Client.from_service_account_info(
json.loads(creds.json), project=creds.project
@@ -627,7 +644,9 @@ def bigquery_client_factory(creds: BigQueryCredentials) -> bigquery.Client:
def gcs_client_factory(creds: BigQueryCredentials) -> storage.Client:
"""Get a GCS client."""
if creds.path:
return storage.Client.from_service_account_json(creds.path, project=creds.project)
return storage.Client.from_service_account_json(
creds.path, project=creds.project
)
elif creds.json:
return storage.Client.from_service_account_info(
json.loads(creds.json), project=creds.project
@@ -641,7 +660,9 @@ def storage_client_factory(
) -> bigquery_storage_v1.BigQueryWriteClient:
"""Get a BigQuery Storage Write client."""
if creds.path:
return bigquery_storage_v1.BigQueryWriteClient.from_service_account_file(str(creds.path))
return bigquery_storage_v1.BigQueryWriteClient.from_service_account_file(
str(creds.path)
)
elif creds.json:
return bigquery_storage_v1.BigQueryWriteClient.from_service_account_info(
json.loads(creds.json)
@@ -726,9 +747,13 @@ def generate_view_statement(self, table_name: BigQueryTable) -> str:
projection = ""
for field_ in self.translated_schema[:]:
if field_.mode == "REPEATED":
projection += indent(self._wrap_json_array(field_, path="$", depth=1), " " * 4)
projection += indent(
self._wrap_json_array(field_, path="$", depth=1), " " * 4
)
else:
projection += indent(self._bigquery_field_to_projection(field_).as_sql(), " " * 4)
projection += indent(
self._bigquery_field_to_projection(field_).as_sql(), " " * 4
)

return (
f"CREATE OR REPLACE VIEW {table_name.get_escaped_name('_view')} AS \nSELECT"
@@ -761,9 +786,13 @@ def _jsonschema_property_to_bigquery_column(
if "items" not in schema_property:
return SchemaField(name, "JSON", "REPEATED")
items_schema: dict = schema_property["items"]
items_type = bigquery_type(items_schema["type"], items_schema.get("format", None))
items_type = bigquery_type(
items_schema["type"], items_schema.get("format", None)
)
if items_type == "record":
return self._translate_record_to_bigquery_schema(name, items_schema, "REPEATED")
return self._translate_record_to_bigquery_schema(
name, items_schema, "REPEATED"
)
return SchemaField(name, items_type, "REPEATED")
elif "object" in property_type:
return self._translate_record_to_bigquery_schema(name, schema_property)
@@ -785,7 +814,10 @@ def _jsonschema_property_to_bigquery_column(
property_format = schema_property.get("format", None)

if "array" in property_type:
if "items" not in schema_property or "type" not in schema_property["items"]:
if (
"items" not in schema_property
or "type" not in schema_property["items"]
):
return SchemaField(name, "JSON", "REPEATED")
items_schema: dict = schema_property["items"]
if "patternProperties" in items_schema:
@@ -805,7 +837,9 @@ def _jsonschema_property_to_bigquery_column(
or "patternProperties" in schema_property
):
return SchemaField(name, "JSON", "NULLABLE")
return self._translate_record_to_bigquery_schema(name, schema_property)
return self._translate_record_to_bigquery_schema(
name, schema_property
)
else:
if "patternProperties" in schema_property:
return SchemaField(name, "JSON", "NULLABLE")
@@ -826,7 +860,10 @@ def _translate_record_to_bigquery_schema(
if len(properties) == 0:
return SchemaField(name, "JSON", mode)

fields = [self._jsonschema_property_to_bigquery_column(col, t) for col, t in properties]
fields = [
self._jsonschema_property_to_bigquery_column(col, t)
for col, t in properties
]
return SchemaField(name, "RECORD", mode, fields=fields)

def _bigquery_field_to_projection(
@@ -1016,7 +1053,9 @@ def __del__(self) -> None:


# pylint: disable=no-else-return,too-many-branches,too-many-return-statements
def bigquery_type(property_type: List[str], property_format: Optional[str] = None) -> str:
def bigquery_type(
property_type: List[str], property_format: Optional[str] = None
) -> str:
"""Convert a JSON Schema type to a BigQuery type."""
if property_format == "date-time":
return "timestamp"
12 changes: 9 additions & 3 deletions target_bigquery/gcs_stage.py
@@ -147,7 +149,9 @@ def __init__(
self.increment_jobs_enqueued = target.increment_jobs_enqueued

@staticmethod
def worker_cls_factory(worker_executor_cls: Type[Process], config: Dict[str, Any]) -> Type[
def worker_cls_factory(
worker_executor_cls: Type[Process], config: Dict[str, Any]
) -> Type[
Union[
GcsStagingThreadWorker,
GcsStagingProcessWorker,
@@ -223,7 +225,9 @@ def create_bucket_if_not_exists(self) -> storage.Bucket:
storage_class: Optional[str] = self.config.get("storage_class")
if storage_class:
kwargs["storage_class"] = storage_class
location: str = self.config.get("location", self.default_bucket_options()["location"])
location: str = self.config.get(
"location", self.default_bucket_options()["location"]
)

if not hasattr(self, "_gcs_bucket"):
self._gcs_bucket = self.client.get_bucket(self.as_bucket())
@@ -235,7 +239,9 @@ def create_bucket_if_not_exists(self) -> storage.Bucket:
f"specified location: {location}"
)
else:
self._gcs_bucket = self.client.create_bucket(self.as_bucket(), location=location)
self._gcs_bucket = self.client.create_bucket(
self.as_bucket(), location=location
)
else:
# Wait for eventual consistency
time.sleep(5)
12 changes: 9 additions & 3 deletions target_bigquery/proto_gen.py
@@ -32,7 +32,9 @@
}


def generate_field_v2(base: SchemaField, i: int = 1, pool: Optional[Any] = None) -> Dict[str, Any]:
def generate_field_v2(
base: SchemaField, i: int = 1, pool: Optional[Any] = None
) -> Dict[str, Any]:
"""Generate proto2 field properties from a SchemaField."""
name: str = base.name
typ: str = cast(str, base.field_type).upper()
@@ -72,7 +74,9 @@ def proto_schema_factory_v2(
for f in bigquery_schema:
fhash.update(hash(f).to_bytes(8, "big", signed=True))
fname = f"AnonymousProto_{fhash.hexdigest()}.proto"
clsname = f"net.proto2.python.public.target_bigquery.AnonymousProto_{fhash.hexdigest()}"
clsname = (
f"net.proto2.python.public.target_bigquery.AnonymousProto_{fhash.hexdigest()}"
)
factory = message_factory.MessageFactory(pool=pool)
try:
proto_descriptor = factory.pool.FindMessageTypeByName(clsname)
@@ -127,6 +131,8 @@ def proto_schema_factory(bigquery_schema: Iterable[SchemaField]) -> Type[proto.M
(proto.Message,),
{
name: f
for f, name in (generate_field(field, i + 1) for i, field in enumerate(bigquery_schema))
for f, name in (
generate_field(field, i + 1) for i, field in enumerate(bigquery_schema)
)
},
)
12 changes: 9 additions & 3 deletions target_bigquery/storage_write.py
@@ -63,7 +63,9 @@ def get_application_stream(client: BigQueryWriteClient, job: "Job") -> StreamCom
"""Get an application created stream for the parent. This stream must be finalized and committed."""
write_stream = types.WriteStream()
write_stream.type_ = types.WriteStream.Type.PENDING # type: ignore
write_stream = client.create_write_stream(parent=job.parent, write_stream=write_stream)
write_stream = client.create_write_stream(
parent=job.parent, write_stream=write_stream
)
job.template.write_stream = write_stream.name
append_rows_stream = writer.AppendRowsStream(client, job.template)
rv = (write_stream.name, append_rows_stream)
@@ -269,7 +271,9 @@ class BigQueryStorageWriteSink(BaseBigQuerySink):
WORKER_CREATION_MIN_INTERVAL = 1.0

@staticmethod
def worker_cls_factory(worker_executor_cls: Type[Process], config: Dict[str, Any]) -> Type[
def worker_cls_factory(
worker_executor_cls: Type[Process], config: Dict[str, Any]
) -> Type[
Union[
StorageWriteThreadStreamWorker,
StorageWriteProcessStreamWorker,
@@ -361,7 +365,9 @@ def commit_streams(self) -> None:
)
self.logger.info(f"Batch commit time: {write.commit_time}")
self.logger.info(f"Batch commit errors: {write.stream_errors}")
self.logger.info(f"Writes to streams: '{self.open_streams}' have been committed.")
self.logger.info(
f"Writes to streams: '{self.open_streams}' have been committed."
)
self.open_streams = set()

def clean_up(self) -> None:
12 changes: 9 additions & 3 deletions target_bigquery/streaming_insert.py
@@ -97,7 +97,9 @@ class BigQueryStreamingInsertSink(BaseBigQuerySink):
WORKER_CREATION_MIN_INTERVAL = 1.0

@staticmethod
def worker_cls_factory(worker_executor_cls: Type[Process], config: Dict[str, Any]) -> Type[
def worker_cls_factory(
worker_executor_cls: Type[Process], config: Dict[str, Any]
) -> Type[
Union[
StreamingInsertThreadWorker,
StreamingInsertProcessWorker,
@@ -119,10 +121,14 @@ def process_record(self, record: Dict[str, Any], context: Dict[str, Any]) -> Non
self.records_to_drain.append(record)

def process_batch(self, context: Dict[str, Any]) -> None:
self.global_queue.put(Job(table=self.table.as_ref(), records=self.records_to_drain.copy()))
self.global_queue.put(
Job(table=self.table.as_ref(), records=self.records_to_drain.copy())
)
self.increment_jobs_enqueued()
self.records_to_drain = []


class BigQueryStreamingInsertDenormalizedSink(Denormalized, BigQueryStreamingInsertSink):
class BigQueryStreamingInsertDenormalizedSink(
Denormalized, BigQueryStreamingInsertSink
):
pass
12 changes: 9 additions & 3 deletions target_bigquery/target.py
@@ -389,7 +389,9 @@ def increment_jobs_enqueued(self) -> None:
# We woulod approach this by adding a new ParType enum and interpreting the
# the Process, Pipe, and Queue classes as protocols which can be duck-typed.

def get_parallelization_components(self, default=ParType.THREAD) -> Tuple[
def get_parallelization_components(
self, default=ParType.THREAD
) -> Tuple[
Type["Process"],
Callable[[bool], Tuple["Connection", "Connection"]],
Callable[[], "Queue"],
@@ -422,7 +424,9 @@ def add_worker_predicate(self) -> bool:
"""Predicate determining when it is valid to add a worker to the pool."""
return (
self._jobs_enqueued
> getattr(self.get_sink_class(), "WORKER_CAPACITY_FACTOR", WORKER_CAPACITY_FACTOR)
> getattr(
self.get_sink_class(), "WORKER_CAPACITY_FACTOR", WORKER_CAPACITY_FACTOR
)
* (len(self.workers) + 1)
and len(self.workers)
< self.config.get("options", {}).get(
@@ -467,7 +471,9 @@ def resize_worker_pool(self) -> None:

# SDK overrides to inject our worker management logic and sink selection.

def get_sink_class(self, stream_name: Optional[str] = None) -> Type[BaseBigQuerySink]:
def get_sink_class(
self, stream_name: Optional[str] = None
) -> Type[BaseBigQuerySink]:
"""Returns the sink class to use for a given stream based on user config."""
_ = stream_name
method, denormalized = (
16 changes: 12 additions & 4 deletions target_bigquery/tests/test_sync.py
@@ -54,7 +54,9 @@
["batch_job", "streaming_insert", "storage_write_api", "gcs_stage"],
ids=["batch_job", "streaming_insert", "storage_write_api", "gcs_stage"],
)
@pytest.mark.parametrize("batch_mode", [False, True], ids=["no_batch_mode", "batch_mode"])
@pytest.mark.parametrize(
"batch_mode", [False, True], ids=["no_batch_mode", "batch_mode"]
)
def test_basic_sync(method, batch_mode):
OPTS = {
"method": method,
@@ -106,7 +108,9 @@ def test_basic_sync(method, batch_mode):
# target.get_sink_class().WORKER_CAPACITY_FACTOR = 1
# target.get_sink_class().WORKER_CREATION_MIN_INTERVAL = 1

client = bigquery_client_factory(BigQueryCredentials(json=target.config["credentials_json"]))
client = bigquery_client_factory(
BigQueryCredentials(json=target.config["credentials_json"])
)
stdout, stderr = target_sync_test(target, singer_input)
del stdout, stderr
time.sleep(5) # wait for the eventual consistency seen in LoadJob sinks
@@ -151,7 +155,9 @@ def test_basic_denorm_sync(method):

singer_input = io.StringIO()
singer_input.write(
BASIC_SINGER_STREAM.replace("{stream_name}", table_name).replace("{load_id}", load_id)
BASIC_SINGER_STREAM.replace("{stream_name}", table_name).replace(
"{load_id}", load_id
)
)
singer_input.seek(0)

@@ -171,7 +177,9 @@ def test_basic_denorm_sync(method):
# target.get_sink_class().WORKER_CAPACITY_FACTOR = 2
# target.get_sink_class().WORKER_CREATION_MIN_INTERVAL = 1

client = bigquery_client_factory(BigQueryCredentials(json=target.config["credentials_json"]))
client = bigquery_client_factory(
BigQueryCredentials(json=target.config["credentials_json"])
)
stdout, stderr = target_sync_test(target, singer_input)
del stdout, stderr
time.sleep(10) # wait for the eventual consistency seen in LoadJobs sinks