Skip to content

Commit

Permalink
fix(kafka): Add Kraft to Kafka containers (#611)
Browse files Browse the repository at this point in the history
Following a similar strategy as several other testcontainers
implementations, this PR introduces the possibility to run Kafka in
KRAft mode.

```py
with KafkaContainer().with_kraft() as container:
    # Test something with/on KRaft mode
```
  • Loading branch information
jfmlima authored Jun 21, 2024
1 parent 090bd0d commit 762d2a2
Show file tree
Hide file tree
Showing 4 changed files with 202 additions and 9 deletions.
30 changes: 30 additions & 0 deletions core/testcontainers/core/version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
from typing import Callable

from packaging.version import Version


class ComparableVersion:
def __init__(self, version):
self.version = Version(version)

def __lt__(self, other: str):
return self._apply_op(other, lambda x, y: x < y)

def __le__(self, other: str):
return self._apply_op(other, lambda x, y: x <= y)

def __eq__(self, other: str):
return self._apply_op(other, lambda x, y: x == y)

def __ne__(self, other: str):
return self._apply_op(other, lambda x, y: x != y)

def __gt__(self, other: str):
return self._apply_op(other, lambda x, y: x > y)

def __ge__(self, other: str):
return self._apply_op(other, lambda x, y: x >= y)

def _apply_op(self, other: str, op: Callable[[Version, Version], bool]):
other = Version(other)
return op(self.version, other)
78 changes: 78 additions & 0 deletions core/tests/test_version.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import pytest
from packaging.version import InvalidVersion

from testcontainers.core.version import ComparableVersion


@pytest.fixture
def version():
return ComparableVersion("1.0.0")


@pytest.mark.parametrize("other_version, expected", [("0.9.0", False), ("1.0.0", False), ("1.1.0", True)])
def test_lt(version, other_version, expected):
assert (version < other_version) == expected


@pytest.mark.parametrize("other_version, expected", [("0.9.0", False), ("1.0.0", True), ("1.1.0", True)])
def test_le(version, other_version, expected):
assert (version <= other_version) == expected


@pytest.mark.parametrize("other_version, expected", [("0.9.0", False), ("1.0.0", True), ("1.1.0", False)])
def test_eq(version, other_version, expected):
assert (version == other_version) == expected


@pytest.mark.parametrize("other_version, expected", [("0.9.0", True), ("1.0.0", False), ("1.1.0", True)])
def test_ne(version, other_version, expected):
assert (version != other_version) == expected


@pytest.mark.parametrize("other_version, expected", [("0.9.0", True), ("1.0.0", False), ("1.1.0", False)])
def test_gt(version, other_version, expected):
assert (version > other_version) == expected


@pytest.mark.parametrize("other_version, expected", [("0.9.0", True), ("1.0.0", True), ("1.1.0", False)])
def test_ge(version, other_version, expected):
assert (version >= other_version) == expected


@pytest.mark.parametrize(
"invalid_version",
[
"invalid",
"1..0",
],
)
def test_invalid_version_raises_error(invalid_version):
with pytest.raises(InvalidVersion):
ComparableVersion(invalid_version)


@pytest.mark.parametrize(
"invalid_version",
[
"invalid",
"1..0",
],
)
def test_comparison_with_invalid_version_raises_error(version, invalid_version):
with pytest.raises(InvalidVersion):
assert version < invalid_version

with pytest.raises(InvalidVersion):
assert version <= invalid_version

with pytest.raises(InvalidVersion):
assert version == invalid_version

with pytest.raises(InvalidVersion):
assert version != invalid_version

with pytest.raises(InvalidVersion):
assert version > invalid_version

with pytest.raises(InvalidVersion):
assert version >= invalid_version
97 changes: 88 additions & 9 deletions modules/kafka/testcontainers/kafka/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,11 @@
from io import BytesIO
from textwrap import dedent

from typing_extensions import Self

from testcontainers.core.container import DockerContainer
from testcontainers.core.utils import raise_for_deprecated_parameter
from testcontainers.core.version import ComparableVersion
from testcontainers.core.waiting_utils import wait_for_logs
from testcontainers.kafka._redpanda import RedpandaContainer

Expand All @@ -26,18 +29,29 @@ class KafkaContainer(DockerContainer):
>>> with KafkaContainer() as kafka:
... connection = kafka.get_bootstrap_server()
# Using KRaft protocol
>>> with KafkaContainer().with_kraft() as kafka:
... connection = kafka.get_bootstrap_server()
"""

TC_START_SCRIPT = "/tc-start.sh"
MIN_KRAFT_TAG = "7.0.0"

def __init__(self, image: str = "confluentinc/cp-kafka:7.6.0", port: int = 9093, **kwargs) -> None:
raise_for_deprecated_parameter(kwargs, "port_to_expose", "port")
super().__init__(image, **kwargs)
self.port = port
self.kraft_enabled = False
self.wait_for = r".*\[KafkaServer id=\d+\] started.*"
self.boot_command = ""
self.cluster_id = "MkU3OEVBNTcwNTJENDM2Qk"
self.listeners = f"PLAINTEXT://0.0.0.0:{self.port},BROKER://0.0.0.0:9092"
self.security_protocol_map = "BROKER:PLAINTEXT,PLAINTEXT:PLAINTEXT"

self.with_exposed_ports(self.port)
listeners = f"PLAINTEXT://0.0.0.0:{self.port},BROKER://0.0.0.0:9092"
self.with_env("KAFKA_LISTENERS", listeners)
self.with_env("KAFKA_LISTENER_SECURITY_PROTOCOL_MAP", "BROKER:PLAINTEXT,PLAINTEXT:PLAINTEXT")
self.with_env("KAFKA_LISTENERS", self.listeners)
self.with_env("KAFKA_LISTENER_SECURITY_PROTOCOL_MAP", self.security_protocol_map)
self.with_env("KAFKA_INTER_BROKER_LISTENER_NAME", "BROKER")

self.with_env("KAFKA_BROKER_ID", "1")
Expand All @@ -46,6 +60,74 @@ def __init__(self, image: str = "confluentinc/cp-kafka:7.6.0", port: int = 9093,
self.with_env("KAFKA_LOG_FLUSH_INTERVAL_MESSAGES", "10000000")
self.with_env("KAFKA_GROUP_INITIAL_REBALANCE_DELAY_MS", "0")

def with_kraft(self) -> Self:
self._verify_min_kraft_version()
self.kraft_enabled = True
return self

def _verify_min_kraft_version(self):
actual_version = self.image.split(":")[-1]

if ComparableVersion(actual_version) < self.MIN_KRAFT_TAG:
raise ValueError(
f"Provided Confluent Platform's version {actual_version} "
f"is not supported in Kraft mode"
f" (must be {self.MIN_KRAFT_TAG} or above)"
)

def with_cluster_id(self, cluster_id: str) -> Self:
self.cluster_id = cluster_id
return self

def configure(self):
if self.kraft_enabled:
self._configure_kraft()
else:
self._configure_zookeeper()

def _configure_kraft(self) -> None:
self.wait_for = r".*Kafka Server started.*"

self.with_env("CLUSTER_ID", self.cluster_id)
self.with_env("KAFKA_NODE_ID", 1)
self.with_env(
"KAFKA_LISTENER_SECURITY_PROTOCOL_MAP",
f"{self.security_protocol_map},CONTROLLER:PLAINTEXT",
)
self.with_env(
"KAFKA_LISTENERS",
f"{self.listeners},CONTROLLER://0.0.0.0:9094",
)
self.with_env("KAFKA_PROCESS_ROLES", "broker,controller")

network_alias = self._get_network_alias()
controller_quorum_voters = f"1@{network_alias}:9094"
self.with_env("KAFKA_CONTROLLER_QUORUM_VOTERS", controller_quorum_voters)
self.with_env("KAFKA_CONTROLLER_LISTENER_NAMES", "CONTROLLER")

self.boot_command = f"""
sed -i '/KAFKA_ZOOKEEPER_CONNECT/d' /etc/confluent/docker/configure
echo 'kafka-storage format --ignore-formatted -t {self.cluster_id} -c /etc/kafka/kafka.properties' >> /etc/confluent/docker/configure
"""

def _get_network_alias(self):
if self._network:
return next(
iter(self._network_aliases or [self._network.name or self._kwargs.get("network", [])]),
None,
)

return "localhost"

def _configure_zookeeper(self) -> None:
self.boot_command = """
echo 'clientPort=2181' > zookeeper.properties
echo 'dataDir=/var/lib/zookeeper/data' >> zookeeper.properties
echo 'dataLogDir=/var/lib/zookeeper/log' >> zookeeper.properties
zookeeper-server-start zookeeper.properties &
export KAFKA_ZOOKEEPER_CONNECT='localhost:2181'
"""

def get_bootstrap_server(self) -> str:
host = self.get_container_host_ip()
port = self.get_exposed_port(self.port)
Expand All @@ -59,11 +141,7 @@ def tc_start(self) -> None:
dedent(
f"""
#!/bin/bash
echo 'clientPort=2181' > zookeeper.properties
echo 'dataDir=/var/lib/zookeeper/data' >> zookeeper.properties
echo 'dataLogDir=/var/lib/zookeeper/log' >> zookeeper.properties
zookeeper-server-start zookeeper.properties &
export KAFKA_ZOOKEEPER_CONNECT='localhost:2181'
{self.boot_command}
export KAFKA_ADVERTISED_LISTENERS={listeners}
. /etc/confluent/docker/bash-config
/etc/confluent/docker/configure
Expand All @@ -78,10 +156,11 @@ def tc_start(self) -> None:
def start(self, timeout=30) -> "KafkaContainer":
script = KafkaContainer.TC_START_SCRIPT
command = f'sh -c "while [ ! -f {script} ]; do sleep 0.1; done; sh {script}"'
self.configure()
self.with_command(command)
super().start()
self.tc_start()
wait_for_logs(self, r".*\[KafkaServer id=\d+\] started.*", timeout=timeout)
wait_for_logs(self, self.wait_for, timeout=timeout)
return self

def create_file(self, content: bytes, path: str) -> None:
Expand Down
6 changes: 6 additions & 0 deletions modules/kafka/tests/test_kafka.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ def test_kafka_producer_consumer():
produce_and_consume_kafka_message(container)


def test_kafka_with_kraft_producer_consumer():
with KafkaContainer().with_kraft() as container:
assert container.kraft_enabled
produce_and_consume_kafka_message(container)


def test_kafka_producer_consumer_custom_port():
with KafkaContainer(port=9888) as container:
assert container.port == 9888
Expand Down

0 comments on commit 762d2a2

Please sign in to comment.