diff --git a/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py b/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py index 7f75e2bf3de1..39a83d6b6949 100644 --- a/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py +++ b/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py @@ -64,9 +64,11 @@ def process( class CrossLanguageKafkaIO(object): - def __init__(self, bootstrap_servers, topic, expansion_service=None): + def __init__( + self, bootstrap_servers, topic, null_key, expansion_service=None): self.bootstrap_servers = bootstrap_servers self.topic = topic + self.null_key = null_key self.expansion_service = expansion_service self.sum_counter = Metrics.counter('source', 'elements_sum') @@ -74,9 +76,9 @@ def build_write_pipeline(self, pipeline): _ = ( pipeline | 'Generate' >> beam.Create(range(NUM_RECORDS)) # pylint: disable=bad-option-value - | 'MakeKV' >> beam.Map(lambda x: - (b'', str(x).encode())).with_output_types( - typing.Tuple[bytes, bytes]) + | 'MakeKV' >> beam.Map( + lambda x: (None if self.null_key else b'key', str(x).encode())). + with_output_types(typing.Tuple[typing.Optional[bytes], bytes]) | 'WriteToKafka' >> WriteToKafka( producer_config={'bootstrap.servers': self.bootstrap_servers}, topic=self.topic, @@ -112,13 +114,26 @@ def run_xlang_kafkaio(self, pipeline): os.environ.get('LOCAL_KAFKA_JAR'), "LOCAL_KAFKA_JAR environment var is not provided.") class CrossLanguageKafkaIOTest(unittest.TestCase): - def test_kafkaio(self): - kafka_topic = 'xlang_kafkaio_test_{}'.format(uuid.uuid4()) + def test_kafkaio_populated_key(self): + kafka_topic = 'xlang_kafkaio_test_populated_key_{}'.format(uuid.uuid4()) local_kafka_jar = os.environ.get('LOCAL_KAFKA_JAR') with self.local_kafka_service(local_kafka_jar) as kafka_port: bootstrap_servers = '{}:{}'.format( self.get_platform_localhost(), kafka_port) - pipeline_creator = CrossLanguageKafkaIO(bootstrap_servers, kafka_topic) + pipeline_creator = CrossLanguageKafkaIO( + bootstrap_servers, kafka_topic, False) + + self.run_kafka_write(pipeline_creator) + self.run_kafka_read(pipeline_creator) + + def test_kafkaio_null_key(self): + kafka_topic = 'xlang_kafkaio_test_null_key_{}'.format(uuid.uuid4()) + local_kafka_jar = os.environ.get('LOCAL_KAFKA_JAR') + with self.local_kafka_service(local_kafka_jar) as kafka_port: + bootstrap_servers = '{}:{}'.format( + self.get_platform_localhost(), kafka_port) + pipeline_creator = CrossLanguageKafkaIO( + bootstrap_servers, kafka_topic, True) self.run_kafka_write(pipeline_creator) self.run_kafka_read(pipeline_creator)