Skip to content

Commit

Permalink
use batch.num.messages config in KafkaGroupReadableResource::Next (tensorflow#1460)
Browse files Browse the repository at this point in the history

* use max.poll.records config in KafkaGroupReadableResource

* switch to batch.num.messages

add test case for batch.num.messages

* fix failing kafka test
  • Loading branch information
kvignesh1420 authored Jun 19, 2021
1 parent 444ff41 commit d0383cd
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 5 deletions.
22 changes: 18 additions & 4 deletions tensorflow_io/core/kernels/kafka_kernels.cc
Original file line number Diff line number Diff line change
Expand Up @@ -904,6 +904,20 @@ class KafkaGroupReadableResource : public ResourceBase {
}
}

// set max.poll.records configuration
std::string batch_num_messages;
if ((result = conf->get("batch.num.messages", batch_num_messages)) !=
RdKafka::Conf::CONF_OK) {
batch_num_messages = "1024";
if ((result = conf->set("batch.num.messages", batch_num_messages,
errstr)) != RdKafka::Conf::CONF_OK) {
return errors::Internal("failed to set batch.num.messages [",
batch_num_messages, "]:", errstr);
}
}
sscanf(batch_num_messages.c_str(), "%d", &batch_num_messages_);
LOG(INFO) << "max num of messages per batch: " << batch_num_messages_;

// Always set enable.partition.eof=true
if ((result = conf->set("enable.partition.eof", "true", errstr)) !=
RdKafka::Conf::CONF_OK) {
Expand Down Expand Up @@ -947,16 +961,15 @@ class KafkaGroupReadableResource : public ResourceBase {

// Initialize necessary variables
int64 num_messages = 0;
int64 max_num_messages = 1024;
max_stream_timeout_polls_ = stream_timeout / message_poll_timeout;

// Allocate memory for message_value and key_value vectors
std::vector<string> message_value, key_value;
message_value.reserve(max_num_messages);
key_value.reserve(max_num_messages);
message_value.reserve(batch_num_messages_);
key_value.reserve(batch_num_messages_);

std::unique_ptr<RdKafka::Message> message;
while (consumer_.get() != nullptr && num_messages < max_num_messages) {
while (consumer_.get() != nullptr && num_messages < batch_num_messages_) {
if (!kafka_event_cb_.run()) {
return errors::Internal(
"failed to consume messages due to broker issue");
Expand Down Expand Up @@ -1022,6 +1035,7 @@ class KafkaGroupReadableResource : public ResourceBase {
KafkaRebalanceCb kafka_rebalance_cb_ = KafkaRebalanceCb();
int max_stream_timeout_polls_ = -1;
int stream_timeout_polls_ = -1;
int batch_num_messages_ = 1024;
};

class KafkaGroupReadableInitOp
Expand Down
34 changes: 33 additions & 1 deletion tests/test_kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,38 @@ def write_messages_background():
)


def test_kafka_mini_dataset_size():
    """Test the functionality of batch.num.messages property of
    KafkaBatchIODataset/KafkaGroupIODataset.

    Seeds the topic with fresh messages, then verifies that the first
    mini-batch emitted by KafkaBatchIODataset contains exactly the
    number of messages requested via `batch.num.messages`.
    """
    import tensorflow_io.kafka as kafka_io

    # Publish enough new messages so the topic can fill at least one
    # full mini-batch of the configured size.
    for index in range(200, 10000):
        kafka_io.write_kafka(
            message="D{}".format(index), topic="key-partition-test"
        )

    expected_size = 5000
    kafka_config = [
        "session.timeout.ms=7000",
        "max.poll.interval.ms=8000",
        "auto.offset.reset=earliest",
        "batch.num.messages={}".format(expected_size),
    ]
    dataset = tfio.experimental.streaming.KafkaBatchIODataset(
        topics=["key-partition-test"],
        group_id="cgminibatchsize",
        servers=None,
        stream_timeout=5000,
        configuration=kafka_config,
    )

    # Only the first mini-batch matters for this test: it must hold
    # exactly `expected_size` messages.
    for mini_batch in dataset:
        assert sum(1 for _ in mini_batch) == expected_size
        break


def test_kafka_batch_io_dataset():
"""Test the functionality of the KafkaBatchIODataset by training a model
directly on the incoming kafka message batch(of type tf.data.Dataset), in an
Expand All @@ -460,7 +492,7 @@ def test_kafka_batch_io_dataset():

dataset = tfio.experimental.streaming.KafkaBatchIODataset(
topics=["mini-batch-test"],
group_id="cgminibatch",
group_id="cgminibatchtrain",
servers=None,
stream_timeout=5000,
configuration=[
Expand Down

0 comments on commit d0383cd

Please sign in to comment.