diff --git a/target_bigquery/core.py b/target_bigquery/core.py
index 052dade..ed26f65 100644
--- a/target_bigquery/core.py
+++ b/target_bigquery/core.py
@@ -129,7 +129,9 @@ def get_escaped_name(self, suffix: str = "") -> str:
         """Returns the table name as as escaped SQL string."""
         return f"`{self.project}`.`{self.dataset}`.`{self.name}{suffix}`"
 
-    def get_resolved_schema(self, apply_transforms: bool = False) -> List[bigquery.SchemaField]:
+    def get_resolved_schema(
+        self, apply_transforms: bool = False
+    ) -> List[bigquery.SchemaField]:
         """Returns the schema for this table after factoring in the ingestion strategy."""
         if self.ingestion_strategy is IngestionStrategy.FIXED:
             return DEFAULT_SCHEMA
@@ -184,14 +186,17 @@ def create_table(
             try:
                 self._dataset = client.get_dataset(self.as_dataset_ref())
             except NotFound:
-                self._dataset = client.create_dataset(self.as_dataset(**kwargs["dataset"]))
+                self._dataset = client.create_dataset(
+                    self.as_dataset(**kwargs["dataset"])
+                )
         if not hasattr(self, "_table"):
             try:
                 self._table = client.get_table(self.as_ref())
             except NotFound:
                 self._table = client.create_table(
                     self.as_table(
-                        apply_transforms and self.ingestion_strategy != IngestionStrategy.FIXED,
+                        apply_transforms
+                        and self.ingestion_strategy != IngestionStrategy.FIXED,
                         **kwargs["table"],
                     )
                 )
@@ -208,7 +213,9 @@ def default_table_options(self) -> Dict[str, Any]:
                 "Generated by target-bigquery.\nStream Schema\n{schema}\nBigQuery Ingestion"
                 " Strategy: {strategy}".format(
                     schema=(
-                        (schema_dump[:16000] + "...") if len(schema_dump) > 16000 else schema_dump
+                        (schema_dump[:16000] + "...")
+                        if len(schema_dump) > 16000
+                        else schema_dump
                     ),
                     strategy=self.ingestion_strategy,
                 )
@@ -224,7 +231,9 @@ def default_dataset_options() -> Dict[str, Any]:
         return {"location": "US"}
 
     def __hash__(self) -> int:
-        return hash((self.name, self.dataset, self.project, json.dumps(self.jsonschema)))
+        return hash(
+            (self.name, self.dataset, self.project, json.dumps(self.jsonschema))
+        )
 
 
 @dataclass
@@ -268,7 +277,9 @@ def run(self) -> None:
 
     def serialize_exception(self, exc: Exception) -> str:
         """Serialize an exception to a string."""
-        msg = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__, chain=False))
+        msg = "".join(
+            traceback.format_exception(type(exc), exc, exc.__traceback__, chain=False)
+        )
         msg += f"\nWorker ID: {self.ext_id}\n"
         return msg
 
@@ -442,7 +453,9 @@ def generate_view(self) -> bool:
     def _validate_and_parse(self, record: dict) -> dict:
         return record
 
-    def preprocess_record(self, record: Dict[str, Any], context: Dict[str, Any]) -> Dict[str, Any]:
+    def preprocess_record(
+        self, record: Dict[str, Any], context: Dict[str, Any]
+    ) -> Dict[str, Any]:
         """Preprocess a record before writing it to the sink."""
         metadata = {
             k: record.pop(k, None)
@@ -531,7 +544,9 @@ def clean_up(self) -> None:
             ctas_tmp = f"CREATE OR REPLACE TEMP TABLE `{tmp}` AS {dedupe_query}"
             merge_clause = (
                 f"MERGE `{self.merge_target}` AS target USING `{tmp or self.table}` AS source ON "
-                + " AND ".join(f"target.`{f}` = source.`{f}`" for f in self.key_properties)
+                + " AND ".join(
+                    f"target.`{f}` = source.`{f}`" for f in self.key_properties
+                )
             )
             update_clause = "UPDATE SET " + ", ".join(
                 f"target.`{f.name}` = source.`{f.name}`" for f in target.schema
@@ -615,7 +630,9 @@ def augmented_syspath(new_paths: Optional[Iterable[str]] = None):
 def bigquery_client_factory(creds: BigQueryCredentials) -> bigquery.Client:
     """Get a BigQuery client."""
     if creds.path:
-        return bigquery.Client.from_service_account_json(creds.path, project=creds.project)
+        return bigquery.Client.from_service_account_json(
+            creds.path, project=creds.project
+        )
     elif creds.json:
         return bigquery.Client.from_service_account_info(
             json.loads(creds.json), project=creds.project
@@ -627,7 +644,9 @@ def bigquery_client_factory(creds: BigQueryCredentials) -> bigquery.Client:
 def gcs_client_factory(creds: BigQueryCredentials) -> storage.Client:
     """Get a GCS client."""
     if creds.path:
-        return storage.Client.from_service_account_json(creds.path, project=creds.project)
+        return storage.Client.from_service_account_json(
+            creds.path, project=creds.project
+        )
     elif creds.json:
         return storage.Client.from_service_account_info(
             json.loads(creds.json), project=creds.project
@@ -641,7 +660,9 @@ def storage_client_factory(
 ) -> bigquery_storage_v1.BigQueryWriteClient:
     """Get a BigQuery Storage Write client."""
     if creds.path:
-        return bigquery_storage_v1.BigQueryWriteClient.from_service_account_file(str(creds.path))
+        return bigquery_storage_v1.BigQueryWriteClient.from_service_account_file(
+            str(creds.path)
+        )
     elif creds.json:
         return bigquery_storage_v1.BigQueryWriteClient.from_service_account_info(
             json.loads(creds.json)
@@ -726,9 +747,13 @@ def generate_view_statement(self, table_name: BigQueryTable) -> str:
         projection = ""
         for field_ in self.translated_schema[:]:
             if field_.mode == "REPEATED":
-                projection += indent(self._wrap_json_array(field_, path="$", depth=1), " " * 4)
+                projection += indent(
+                    self._wrap_json_array(field_, path="$", depth=1), " " * 4
+                )
             else:
-                projection += indent(self._bigquery_field_to_projection(field_).as_sql(), " " * 4)
+                projection += indent(
+                    self._bigquery_field_to_projection(field_).as_sql(), " " * 4
+                )
 
         return (
             f"CREATE OR REPLACE VIEW {table_name.get_escaped_name('_view')} AS \nSELECT"
@@ -761,9 +786,13 @@ def _jsonschema_property_to_bigquery_column(
             if "items" not in schema_property:
                 return SchemaField(name, "JSON", "REPEATED")
             items_schema: dict = schema_property["items"]
-            items_type = bigquery_type(items_schema["type"], items_schema.get("format", None))
+            items_type = bigquery_type(
+                items_schema["type"], items_schema.get("format", None)
+            )
             if items_type == "record":
-                return self._translate_record_to_bigquery_schema(name, items_schema, "REPEATED")
+                return self._translate_record_to_bigquery_schema(
+                    name, items_schema, "REPEATED"
+                )
             return SchemaField(name, items_type, "REPEATED")
         elif "object" in property_type:
             return self._translate_record_to_bigquery_schema(name, schema_property)
@@ -785,7 +814,10 @@ def _jsonschema_property_to_bigquery_column(
             property_format = schema_property.get("format", None)
 
             if "array" in property_type:
-                if "items" not in schema_property or "type" not in schema_property["items"]:
+                if (
+                    "items" not in schema_property
+                    or "type" not in schema_property["items"]
+                ):
                     return SchemaField(name, "JSON", "REPEATED")
                 items_schema: dict = schema_property["items"]
                 if "patternProperties" in items_schema:
@@ -805,7 +837,9 @@ def _jsonschema_property_to_bigquery_column(
                         or "patternProperties" in schema_property
                     ):
                         return SchemaField(name, "JSON", "NULLABLE")
-                    return self._translate_record_to_bigquery_schema(name, schema_property)
+                    return self._translate_record_to_bigquery_schema(
+                        name, schema_property
+                    )
                 else:
                     if "patternProperties" in schema_property:
                         return SchemaField(name, "JSON", "NULLABLE")
@@ -826,7 +860,10 @@ def _translate_record_to_bigquery_schema(
 
         if len(properties) == 0:
             return SchemaField(name, "JSON", mode)
-        fields = [self._jsonschema_property_to_bigquery_column(col, t) for col, t in properties]
+        fields = [
+            self._jsonschema_property_to_bigquery_column(col, t)
+            for col, t in properties
+        ]
         return SchemaField(name, "RECORD", mode, fields=fields)
 
     def _bigquery_field_to_projection(
@@ -1016,7 +1053,9 @@ def __del__(self) -> None:
 
 
 # pylint: disable=no-else-return,too-many-branches,too-many-return-statements
-def bigquery_type(property_type: List[str], property_format: Optional[str] = None) -> str:
+def bigquery_type(
+    property_type: List[str], property_format: Optional[str] = None
+) -> str:
     """Convert a JSON Schema type to a BigQuery type."""
     if property_format == "date-time":
         return "timestamp"
diff --git a/target_bigquery/gcs_stage.py b/target_bigquery/gcs_stage.py
index 25e77be..0d87e5d 100644
--- a/target_bigquery/gcs_stage.py
+++ b/target_bigquery/gcs_stage.py
@@ -147,7 +147,9 @@ def __init__(
         self.increment_jobs_enqueued = target.increment_jobs_enqueued
 
     @staticmethod
-    def worker_cls_factory(worker_executor_cls: Type[Process], config: Dict[str, Any]) -> Type[
+    def worker_cls_factory(
+        worker_executor_cls: Type[Process], config: Dict[str, Any]
+    ) -> Type[
         Union[
             GcsStagingThreadWorker,
             GcsStagingProcessWorker,
         ]
@@ -223,7 +225,9 @@ def create_bucket_if_not_exists(self) -> storage.Bucket:
         storage_class: Optional[str] = self.config.get("storage_class")
         if storage_class:
             kwargs["storage_class"] = storage_class
-        location: str = self.config.get("location", self.default_bucket_options()["location"])
+        location: str = self.config.get(
+            "location", self.default_bucket_options()["location"]
+        )
 
         if not hasattr(self, "_gcs_bucket"):
             self._gcs_bucket = self.client.get_bucket(self.as_bucket())
@@ -235,7 +239,9 @@ def create_bucket_if_not_exists(self) -> storage.Bucket:
                     f"specified location: {location}"
                 )
             else:
-                self._gcs_bucket = self.client.create_bucket(self.as_bucket(), location=location)
+                self._gcs_bucket = self.client.create_bucket(
+                    self.as_bucket(), location=location
+                )
         else:
             # Wait for eventual consistency
             time.sleep(5)
diff --git a/target_bigquery/proto_gen.py b/target_bigquery/proto_gen.py
index aa1eb21..da23220 100644
--- a/target_bigquery/proto_gen.py
+++ b/target_bigquery/proto_gen.py
@@ -32,7 +32,9 @@
 }
 
 
-def generate_field_v2(base: SchemaField, i: int = 1, pool: Optional[Any] = None) -> Dict[str, Any]:
+def generate_field_v2(
+    base: SchemaField, i: int = 1, pool: Optional[Any] = None
+) -> Dict[str, Any]:
     """Generate proto2 field properties from a SchemaField."""
     name: str = base.name
     typ: str = cast(str, base.field_type).upper()
@@ -72,7 +74,9 @@ def proto_schema_factory_v2(
     for f in bigquery_schema:
         fhash.update(hash(f).to_bytes(8, "big", signed=True))
     fname = f"AnonymousProto_{fhash.hexdigest()}.proto"
-    clsname = f"net.proto2.python.public.target_bigquery.AnonymousProto_{fhash.hexdigest()}"
+    clsname = (
+        f"net.proto2.python.public.target_bigquery.AnonymousProto_{fhash.hexdigest()}"
+    )
     factory = message_factory.MessageFactory(pool=pool)
     try:
         proto_descriptor = factory.pool.FindMessageTypeByName(clsname)
@@ -127,6 +131,8 @@ def proto_schema_factory(bigquery_schema: Iterable[SchemaField]) -> Type[proto.M
         (proto.Message,),
         {
             name: f
-            for f, name in (generate_field(field, i + 1) for i, field in enumerate(bigquery_schema))
+            for f, name in (
+                generate_field(field, i + 1) for i, field in enumerate(bigquery_schema)
+            )
         },
     )
diff --git a/target_bigquery/storage_write.py b/target_bigquery/storage_write.py
index 8a8e411..8f66cd7 100644
--- a/target_bigquery/storage_write.py
+++ b/target_bigquery/storage_write.py
@@ -63,7 +63,9 @@ def get_application_stream(client: BigQueryWriteClient, job: "Job") -> StreamCom
     """Get an application created stream for the parent. This stream must be finalized and committed."""
     write_stream = types.WriteStream()
     write_stream.type_ = types.WriteStream.Type.PENDING  # type: ignore
-    write_stream = client.create_write_stream(parent=job.parent, write_stream=write_stream)
+    write_stream = client.create_write_stream(
+        parent=job.parent, write_stream=write_stream
+    )
     job.template.write_stream = write_stream.name
     append_rows_stream = writer.AppendRowsStream(client, job.template)
     rv = (write_stream.name, append_rows_stream)
@@ -269,7 +271,9 @@ class BigQueryStorageWriteSink(BaseBigQuerySink):
     WORKER_CREATION_MIN_INTERVAL = 1.0
 
     @staticmethod
-    def worker_cls_factory(worker_executor_cls: Type[Process], config: Dict[str, Any]) -> Type[
+    def worker_cls_factory(
+        worker_executor_cls: Type[Process], config: Dict[str, Any]
+    ) -> Type[
         Union[
             StorageWriteThreadStreamWorker,
             StorageWriteProcessStreamWorker,
         ]
@@ -361,7 +365,9 @@ def commit_streams(self) -> None:
             )
             self.logger.info(f"Batch commit time: {write.commit_time}")
             self.logger.info(f"Batch commit errors: {write.stream_errors}")
-        self.logger.info(f"Writes to streams: '{self.open_streams}' have been committed.")
+        self.logger.info(
+            f"Writes to streams: '{self.open_streams}' have been committed."
+        )
         self.open_streams = set()
 
     def clean_up(self) -> None:
diff --git a/target_bigquery/streaming_insert.py b/target_bigquery/streaming_insert.py
index 4072646..1c24b6f 100644
--- a/target_bigquery/streaming_insert.py
+++ b/target_bigquery/streaming_insert.py
@@ -97,7 +97,9 @@ class BigQueryStreamingInsertSink(BaseBigQuerySink):
     WORKER_CREATION_MIN_INTERVAL = 1.0
 
     @staticmethod
-    def worker_cls_factory(worker_executor_cls: Type[Process], config: Dict[str, Any]) -> Type[
+    def worker_cls_factory(
+        worker_executor_cls: Type[Process], config: Dict[str, Any]
+    ) -> Type[
         Union[
             StreamingInsertThreadWorker,
             StreamingInsertProcessWorker,
         ]
@@ -119,10 +121,14 @@ def process_record(self, record: Dict[str, Any], context: Dict[str, Any]) -> Non
         self.records_to_drain.append(record)
 
     def process_batch(self, context: Dict[str, Any]) -> None:
-        self.global_queue.put(Job(table=self.table.as_ref(), records=self.records_to_drain.copy()))
+        self.global_queue.put(
+            Job(table=self.table.as_ref(), records=self.records_to_drain.copy())
+        )
         self.increment_jobs_enqueued()
         self.records_to_drain = []
 
 
-class BigQueryStreamingInsertDenormalizedSink(Denormalized, BigQueryStreamingInsertSink):
+class BigQueryStreamingInsertDenormalizedSink(
+    Denormalized, BigQueryStreamingInsertSink
+):
     pass
diff --git a/target_bigquery/target.py b/target_bigquery/target.py
index 25cad28..9a6fbd1 100644
--- a/target_bigquery/target.py
+++ b/target_bigquery/target.py
@@ -389,7 +389,9 @@ def increment_jobs_enqueued(self) -> None:
 
     # We woulod approach this by adding a new ParType enum and interpreting the
     # the Process, Pipe, and Queue classes as protocols which can be duck-typed.
-    def get_parallelization_components(self, default=ParType.THREAD) -> Tuple[
+    def get_parallelization_components(
+        self, default=ParType.THREAD
+    ) -> Tuple[
         Type["Process"],
         Callable[[bool], Tuple["Connection", "Connection"]],
         Callable[[], "Queue"],
@@ -422,7 +424,9 @@ def add_worker_predicate(self) -> bool:
         """Predicate determining when it is valid to add a worker to the pool."""
         return (
             self._jobs_enqueued
-            > getattr(self.get_sink_class(), "WORKER_CAPACITY_FACTOR", WORKER_CAPACITY_FACTOR)
+            > getattr(
+                self.get_sink_class(), "WORKER_CAPACITY_FACTOR", WORKER_CAPACITY_FACTOR
+            )
             * (len(self.workers) + 1)
             and len(self.workers)
             < self.config.get("options", {}).get(
@@ -467,7 +471,9 @@ def resize_worker_pool(self) -> None:
 
     # SDK overrides to inject our worker management logic and sink selection.
 
-    def get_sink_class(self, stream_name: Optional[str] = None) -> Type[BaseBigQuerySink]:
+    def get_sink_class(
+        self, stream_name: Optional[str] = None
+    ) -> Type[BaseBigQuerySink]:
         """Returns the sink class to use for a given stream based on user config."""
         _ = stream_name
         method, denormalized = (
diff --git a/target_bigquery/tests/test_sync.py b/target_bigquery/tests/test_sync.py
index 39859c1..91f9934 100644
--- a/target_bigquery/tests/test_sync.py
+++ b/target_bigquery/tests/test_sync.py
@@ -54,7 +54,9 @@
     ["batch_job", "streaming_insert", "storage_write_api", "gcs_stage"],
     ids=["batch_job", "streaming_insert", "storage_write_api", "gcs_stage"],
 )
-@pytest.mark.parametrize("batch_mode", [False, True], ids=["no_batch_mode", "batch_mode"])
+@pytest.mark.parametrize(
+    "batch_mode", [False, True], ids=["no_batch_mode", "batch_mode"]
+)
 def test_basic_sync(method, batch_mode):
     OPTS = {
         "method": method,
@@ -106,7 +108,9 @@ def test_basic_sync(method, batch_mode):
     # target.get_sink_class().WORKER_CAPACITY_FACTOR = 1
     # target.get_sink_class().WORKER_CREATION_MIN_INTERVAL = 1
 
-    client = bigquery_client_factory(BigQueryCredentials(json=target.config["credentials_json"]))
+    client = bigquery_client_factory(
+        BigQueryCredentials(json=target.config["credentials_json"])
+    )
     stdout, stderr = target_sync_test(target, singer_input)
     del stdout, stderr
     time.sleep(5)  # wait for the eventual consistency seen in LoadJob sinks
@@ -151,7 +155,9 @@ def test_basic_denorm_sync(method):
 
     singer_input = io.StringIO()
     singer_input.write(
-        BASIC_SINGER_STREAM.replace("{stream_name}", table_name).replace("{load_id}", load_id)
+        BASIC_SINGER_STREAM.replace("{stream_name}", table_name).replace(
+            "{load_id}", load_id
+        )
     )
     singer_input.seek(0)
 
@@ -171,7 +177,9 @@ def test_basic_denorm_sync(method):
     # target.get_sink_class().WORKER_CAPACITY_FACTOR = 2
     # target.get_sink_class().WORKER_CREATION_MIN_INTERVAL = 1
 
-    client = bigquery_client_factory(BigQueryCredentials(json=target.config["credentials_json"]))
+    client = bigquery_client_factory(
+        BigQueryCredentials(json=target.config["credentials_json"])
+    )
     stdout, stderr = target_sync_test(target, singer_input)
     del stdout, stderr
     time.sleep(10)  # wait for the eventual consistency seen in LoadJobs sinks