From 4ed320c3fb39b900ce1fffad64add7729f40cf11 Mon Sep 17 00:00:00 2001
From: johnjcasey <95318300+johnjcasey@users.noreply.github.com>
Date: Wed, 18 May 2022 12:15:37 -0400
Subject: [PATCH] [BEAM-10529] update KafkaIO Xlang integration test to publish
 and receive null keys (#17319)

* [BEAM-10529] update test to publish and receive null keys

* [BEAM-10529] add a test with a populated key to the Kafka test file xlang_kafkaio_it_test.py
---
 .../io/external/xlang_kafkaio_it_test.py      | 29 ++++++++++++++-----
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py b/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py
index 7f75e2bf3de1..39a83d6b6949 100644
--- a/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py
+++ b/sdks/python/apache_beam/io/external/xlang_kafkaio_it_test.py
@@ -64,9 +64,11 @@ def process(
 
 
 class CrossLanguageKafkaIO(object):
-  def __init__(self, bootstrap_servers, topic, expansion_service=None):
+  def __init__(
+      self, bootstrap_servers, topic, null_key, expansion_service=None):
     self.bootstrap_servers = bootstrap_servers
     self.topic = topic
+    self.null_key = null_key
     self.expansion_service = expansion_service
     self.sum_counter = Metrics.counter('source', 'elements_sum')
 
@@ -74,9 +76,9 @@ def build_write_pipeline(self, pipeline):
     _ = (
         pipeline
         | 'Generate' >> beam.Create(range(NUM_RECORDS))  # pylint: disable=bad-option-value
-        | 'MakeKV' >> beam.Map(lambda x:
-                               (b'', str(x).encode())).with_output_types(
-                                   typing.Tuple[bytes, bytes])
+        | 'MakeKV' >> beam.Map(
+            lambda x: (None if self.null_key else b'key', str(x).encode())).
+        with_output_types(typing.Tuple[typing.Optional[bytes], bytes])
         | 'WriteToKafka' >> WriteToKafka(
             producer_config={'bootstrap.servers': self.bootstrap_servers},
             topic=self.topic,
@@ -112,13 +114,26 @@ def run_xlang_kafkaio(self, pipeline):
     os.environ.get('LOCAL_KAFKA_JAR'),
     "LOCAL_KAFKA_JAR environment var is not provided.")
 class CrossLanguageKafkaIOTest(unittest.TestCase):
-  def test_kafkaio(self):
-    kafka_topic = 'xlang_kafkaio_test_{}'.format(uuid.uuid4())
+  def test_kafkaio_populated_key(self):
+    kafka_topic = 'xlang_kafkaio_test_populated_key_{}'.format(uuid.uuid4())
     local_kafka_jar = os.environ.get('LOCAL_KAFKA_JAR')
     with self.local_kafka_service(local_kafka_jar) as kafka_port:
       bootstrap_servers = '{}:{}'.format(
           self.get_platform_localhost(), kafka_port)
-      pipeline_creator = CrossLanguageKafkaIO(bootstrap_servers, kafka_topic)
+      pipeline_creator = CrossLanguageKafkaIO(
+          bootstrap_servers, kafka_topic, False)
+
+      self.run_kafka_write(pipeline_creator)
+      self.run_kafka_read(pipeline_creator)
+
+  def test_kafkaio_null_key(self):
+    kafka_topic = 'xlang_kafkaio_test_null_key_{}'.format(uuid.uuid4())
+    local_kafka_jar = os.environ.get('LOCAL_KAFKA_JAR')
+    with self.local_kafka_service(local_kafka_jar) as kafka_port:
+      bootstrap_servers = '{}:{}'.format(
+          self.get_platform_localhost(), kafka_port)
+      pipeline_creator = CrossLanguageKafkaIO(
+          bootstrap_servers, kafka_topic, True)
 
       self.run_kafka_write(pipeline_creator)
       self.run_kafka_read(pipeline_creator)