From a9f50fae94b49a4e83acb131c04a34731a4a6604 Mon Sep 17 00:00:00 2001 From: Ahmed Abualsaud <65791736+ahmedabu98@users.noreply.github.com> Date: Tue, 17 Dec 2024 16:30:21 -0500 Subject: [PATCH] Python ExternalTransformProvider improvements (#33359) --- .../python/apache_beam/transforms/external.py | 14 ++-- .../transforms/external_transform_provider.py | 65 +++++++++++++------ 2 files changed, 51 insertions(+), 28 deletions(-) diff --git a/sdks/python/apache_beam/transforms/external.py b/sdks/python/apache_beam/transforms/external.py index e44f7482dc61..fb37a8fd974d 100644 --- a/sdks/python/apache_beam/transforms/external.py +++ b/sdks/python/apache_beam/transforms/external.py @@ -962,14 +962,14 @@ def __init__( self, path_to_jar, extra_args=None, classpath=None, append_args=None): if extra_args and append_args: raise ValueError('Only one of extra_args or append_args may be provided') - self._path_to_jar = path_to_jar + self.path_to_jar = path_to_jar self._extra_args = extra_args self._classpath = classpath or [] self._service_count = 0 self._append_args = append_args or [] def is_existing_service(self): - return subprocess_server.is_service_endpoint(self._path_to_jar) + return subprocess_server.is_service_endpoint(self.path_to_jar) @staticmethod def _expand_jars(jar): @@ -997,7 +997,7 @@ def _expand_jars(jar): def _default_args(self): """Default arguments to be used by `JavaJarExpansionService`.""" - to_stage = ','.join([self._path_to_jar] + sum(( + to_stage = ','.join([self.path_to_jar] + sum(( JavaJarExpansionService._expand_jars(jar) for jar in self._classpath or []), [])) args = ['{{PORT}}', f'--filesToStage={to_stage}'] @@ -1009,8 +1009,8 @@ def _default_args(self): def __enter__(self): if self._service_count == 0: - self._path_to_jar = subprocess_server.JavaJarServer.local_jar( - self._path_to_jar) + self.path_to_jar = subprocess_server.JavaJarServer.local_jar( + self.path_to_jar) if self._extra_args is None: self._extra_args = self._default_args() + self._append_args # Consider memoizing these servers (with some timeout). @@ -1018,7 +1018,7 @@ def __enter__(self): 'Starting a JAR-based expansion service from JAR %s ' + ( 'and with classpath: %s' % self._classpath if self._classpath else ''), - self._path_to_jar) + self.path_to_jar) classpath_urls = [ subprocess_server.JavaJarServer.local_jar(path) for jar in self._classpath @@ -1026,7 +1026,7 @@ def __enter__(self): ] self._service_provider = subprocess_server.JavaJarServer( ExpansionAndArtifactRetrievalStub, - self._path_to_jar, + self.path_to_jar, self._extra_args, classpath=classpath_urls) self._service = self._service_provider.__enter__() diff --git a/sdks/python/apache_beam/transforms/external_transform_provider.py b/sdks/python/apache_beam/transforms/external_transform_provider.py index 117c7f7c9b93..b22cd4b24cb6 100644 --- a/sdks/python/apache_beam/transforms/external_transform_provider.py +++ b/sdks/python/apache_beam/transforms/external_transform_provider.py @@ -26,6 +26,7 @@ from apache_beam.transforms import PTransform from apache_beam.transforms.external import BeamJarExpansionService +from apache_beam.transforms.external import JavaJarExpansionService from apache_beam.transforms.external import SchemaAwareExternalTransform from apache_beam.transforms.external import SchemaTransformsConfig from apache_beam.typehints.schemas import named_tuple_to_schema @@ -133,37 +134,57 @@ class ExternalTransformProvider: (see the `urn_pattern` parameter). These classes are generated when :class:`ExternalTransformProvider` is - initialized. We need to give it one or more expansion service addresses that - are already up and running: - >>> provider = ExternalTransformProvider(["localhost:12345", - ... "localhost:12121"]) - We can also give it the gradle target of a standard Beam expansion service: - >>> provider = ExternalTransform(BeamJarExpansionService( - ... "sdks:java:io:google-cloud-platform:expansion-service:shadowJar")) - Let's take a look at the output of :func:`get_available()` to know the - available transforms in the expansion service(s) we provided: + initialized. You can give it an expansion service address that is already + up and running: + + >>> provider = ExternalTransformProvider("localhost:12345") + + Or you can give it the path to an expansion service Jar file: + + >>> provider = ExternalTransformProvider(JavaJarExpansionService( + "path/to/expansion-service.jar")) + + Or you can give it the gradle target of a standard Beam expansion service: + + >>> provider = ExternalTransformProvider(BeamJarExpansionService( + "sdks:java:io:google-cloud-platform:expansion-service:shadowJar")) + + Note that you can provide a list of these services: + + >>> provider = ExternalTransformProvider([ + "localhost:12345", + JavaJarExpansionService("path/to/expansion-service.jar"), + BeamJarExpansionService( + "sdks:java:io:google-cloud-platform:expansion-service:shadowJar")]) + + The output of :func:`get_available()` provides a list of available transforms + in the provided expansion service(s): + >>> provider.get_available() [('JdbcWrite', 'beam:schematransform:org.apache.beam:jdbc_write:v1'), ('BigtableRead', 'beam:schematransform:org.apache.beam:bigtable_read:v1'), ...] - Then retrieve a transform by :func:`get()`, :func:`get_urn()`, or by directly - accessing it as an attribute of :class:`ExternalTransformProvider`. - All of the following commands do the same thing: + You can retrieve a transform with :func:`get()`, :func:`get_urn()`, or by + directly accessing it as an attribute. The following lines all do the same + thing: + >>> provider.get('BigqueryStorageRead') >>> provider.get_urn( - ... 'beam:schematransform:org.apache.beam:bigquery_storage_read:v1') + 'beam:schematransform:org.apache.beam:bigquery_storage_read:v1') >>> provider.BigqueryStorageRead - You can inspect the transform's documentation to know more about it. This - returns some documentation only IF the underlying SchemaTransform - implementation provides any. + You can inspect the transform's documentation for more details. The following + returns the documentation provided by the underlying SchemaTransform. If no + such documentation is provided, this will be empty. + >>> import inspect >>> inspect.getdoc(provider.BigqueryStorageRead) Similarly, you can inspect the transform's signature to know more about its parameters, including their names, types, and any documentation that the underlying SchemaTransform may provide: + >>> inspect.signature(provider.BigqueryStorageRead) (query: 'typing.Union[str, NoneType]: The SQL query to be executed to...', row_restriction: 'typing.Union[str, NoneType]: Read only rows that match...', @@ -178,8 +199,6 @@ class ExternalTransformProvider: query=query, row_restriction=restriction) | 'Some processing' >> beam.Map(...)) - - Experimental; no backwards compatibility guarantees. """ def __init__(self, expansion_services, urn_pattern=STANDARD_URN_PATTERN): f"""Initialize an ExternalTransformProvider @@ -188,6 +207,7 @@ def __init__(self, expansion_services, urn_pattern=STANDARD_URN_PATTERN): A list of expansion services to discover transforms from. Supported forms: * a string representing the expansion service address + * a :attr:`JavaJarExpansionService` pointing to the path of a Java Jar * a :attr:`BeamJarExpansionService` pointing to a gradle target :param urn_pattern: The regular expression used to match valid transforms. In addition to @@ -213,11 +233,14 @@ def _create_wrappers(self): target = service if isinstance(service, BeamJarExpansionService): target = service.gradle_target + if isinstance(service, JavaJarExpansionService): + target = service.path_to_jar try: schematransform_configs = SchemaAwareExternalTransform.discover(service) except Exception as e: logging.exception( - "Encountered an error while discovering expansion service %s:\n%s", + "Encountered an error while discovering " + "expansion service at '%s':\n%s", target, e) continue @@ -249,7 +272,7 @@ def _create_wrappers(self): if skipped_urns: logging.info( - "Skipped URN(s) in %s that don't follow the pattern \"%s\": %s", + "Skipped URN(s) in '%s' that don't follow the pattern \"%s\": %s", target, self._urn_pattern, skipped_urns) @@ -262,7 +285,7 @@ def get_available(self) -> List[Tuple[str, str]]: return list(self._name_to_urn.items()) def get_all(self) -> Dict[str, ExternalTransform]: - """Get all ExternalTransform""" + """Get all ExternalTransforms""" return self._transforms def get(self, name) -> ExternalTransform: