Skip to content

Commit

Permalink
remove use of reflection
Browse files Browse the repository at this point in the history
Signed-off-by: kwall <[email protected]>
  • Loading branch information
k-wall committed Nov 13, 2024
1 parent 5ae0385 commit 8ad8720
Show file tree
Hide file tree
Showing 3 changed files with 245 additions and 23 deletions.
263 changes: 245 additions & 18 deletions kafka-server/src/main/java/com/ozangunalp/kafka/server/ZkUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,25 @@
import org.apache.kafka.common.metadata.UserScramCredentialRecord;
import org.apache.kafka.common.security.JaasUtils;
import org.apache.kafka.common.security.scram.internals.ScramCredentialUtils;
import org.apache.kafka.common.security.scram.internals.ScramFormatter;
import org.apache.kafka.common.utils.Time;
import org.apache.kafka.metadata.bootstrap.BootstrapMetadata;
import org.apache.kafka.metadata.storage.Formatter;
import org.apache.kafka.metadata.storage.FormatterException;
import org.apache.kafka.server.common.ApiMessageAndVersion;
import org.apache.kafka.server.common.MetadataVersion;
import org.apache.kafka.server.config.ZkConfigs;
import org.apache.zookeeper.client.ZKClientConfig;
import scala.Option;

import java.lang.reflect.InvocationTargetException;
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Base64;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalInt;
import java.util.TreeMap;
import java.util.stream.Collectors;

public class ZkUtils {

Expand All @@ -27,9 +35,7 @@ private ZkUtils() {

public static void createScramUsersInZookeeper(KafkaConfig config, List<String> scramCredentials) {
if (!scramCredentials.isEmpty()) {

var scramCredentialRecords = buildUserScramCredentialRecords(scramCredentials);

ZKClientConfig zkClientConfig = KafkaServer.zkClientConfigFromKafkaConfig(config, false);
try (var zkClient = createZkClient("Kafka native", Time.SYSTEM, config, zkClientConfig)) {
var adminZkClient = new AdminZkClient(zkClient, Option.empty());
Expand Down Expand Up @@ -72,25 +78,246 @@ private static KafkaZkClient createZkClient(String name, Time time, KafkaConfig
"kafka.server", "SessionExpireListener", false, false);
}

private static List<UserScramCredentialRecord> buildUserScramCredentialRecords(List<String> scramCredentials) {
private static List<UserScramCredentialRecord> buildUserScramCredentialRecords(List<String> scramCredentials) {
try {
// Kafka's API don't expose a mechanism to generate UserScramCredentialRecord directly.
// Best we can do it to use the KRaft's storage formatter and extract the records it would generate.
var storageFormatter = new Formatter();
storageFormatter.setReleaseVersion(MetadataVersion.LATEST_PRODUCTION);
storageFormatter.setScramArguments(scramCredentials);
var calcMetadataMethod = storageFormatter.getClass().getDeclaredMethod("calculateBootstrapMetadata");
calcMetadataMethod.setAccessible(true);
var metadata = (BootstrapMetadata) calcMetadataMethod.invoke(storageFormatter);
return metadata.records().stream()
return ScramParser.parse(scramCredentials)
.stream()
.map(ApiMessageAndVersion::message)
.filter(UserScramCredentialRecord.class::isInstance)
.map(UserScramCredentialRecord.class::cast)
.toList();
} catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
throw new RuntimeException("Failed to generate UserScramCredentialRecords", e);
} catch (Exception e) {
throw new RuntimeException("Failed to build UserScramCredentialRecord", e);
}
}


/**
* Copied from org.apache.kafka.metadata.storage.ScramParser (3.9) as the #parse method is
* inaccessiable.
*/
public static class ScramParser {

static List<ApiMessageAndVersion> parse(List<String> arguments) throws Exception {
List<ApiMessageAndVersion> records = new ArrayList<>();
for (String argument : arguments) {
Map.Entry<org.apache.kafka.common.security.scram.internals.ScramMechanism, String> entry = parsePerMechanismArgument(argument);
PerMechanismData data = new PerMechanismData(entry.getKey(), entry.getValue());
records.add(new ApiMessageAndVersion(data.toRecord(), (short) 0));
}
return records;
}

static Map.Entry<org.apache.kafka.common.security.scram.internals.ScramMechanism, String> parsePerMechanismArgument(String input) {
input = input.trim();
int equalsIndex = input.indexOf('=');
if (equalsIndex < 0) {
throw new FormatterException("Failed to find equals sign in SCRAM " +
"argument '" + input + "'");
}
String mechanismString = input.substring(0, equalsIndex);
String configString = input.substring(equalsIndex + 1);
org.apache.kafka.common.security.scram.internals.ScramMechanism mechanism = org.apache.kafka.common.security.scram.internals.ScramMechanism.forMechanismName(mechanismString);
if (mechanism == null) {
throw new FormatterException("The add-scram mechanism " + mechanismString +
" is not supported.");
}
if (!configString.startsWith("[")) {
throw new FormatterException("Expected configuration string to start with [");
}
if (!configString.endsWith("]")) {
throw new FormatterException("Expected configuration string to end with ]");
}
return new AbstractMap.SimpleImmutableEntry<>(mechanism,
configString.substring(1, configString.length() - 1));
}
static final class PerMechanismData {

private final org.apache.kafka.common.security.scram.internals.ScramMechanism mechanism;
private final String configuredName;
private final Optional<byte[]> configuredSalt;
private final OptionalInt configuredIterations;
private final Optional<String> configuredPasswordString;
private final Optional<byte[]> configuredSaltedPassword;

PerMechanismData(
org.apache.kafka.common.security.scram.internals.ScramMechanism mechanism,
String configuredName,
Optional<byte[]> configuredSalt,
OptionalInt configuredIterations,
Optional<String> configuredPasswordString,
Optional<byte[]> configuredSaltedPassword
) {
this.mechanism = mechanism;
this.configuredName = configuredName;
this.configuredSalt = configuredSalt;
this.configuredIterations = configuredIterations;
this.configuredPasswordString = configuredPasswordString;
this.configuredSaltedPassword = configuredSaltedPassword;
}

PerMechanismData(
org.apache.kafka.common.security.scram.internals.ScramMechanism mechanism,
String configString
) {
this.mechanism = mechanism;
String[] configComponents = configString.split(",");
Map<String, String> components = new TreeMap<>();
for (String configComponent : configComponents) {
Map.Entry<String, String> entry = splitTrimmedConfigStringComponent(configComponent);
components.put(entry.getKey(), entry.getValue());
}
this.configuredName = components.remove("name");
if (this.configuredName == null) {
throw new FormatterException("You must supply 'name' to add-scram");
}

String saltString = components.remove("salt");
if (saltString == null) {
this.configuredSalt = Optional.empty();
} else {
try {
this.configuredSalt = Optional.of(Base64.getDecoder().decode(saltString));
} catch (IllegalArgumentException e) {
throw new FormatterException("Failed to decode given salt: " + saltString, e);
}
}
String iterationsString = components.remove("iterations");
if (iterationsString == null) {
this.configuredIterations = OptionalInt.empty();
} else {
try {
this.configuredIterations = OptionalInt.of(Integer.parseInt(iterationsString));
} catch (NumberFormatException e) {
throw new FormatterException("Failed to parse iterations count: " + iterationsString, e);
}
}
String passwordString = components.remove("password");
String saltedPasswordString = components.remove("saltedpassword");
if (passwordString == null) {
if (saltedPasswordString == null) {
throw new FormatterException("You must supply one of 'password' or 'saltedpassword' " +
"to add-scram");
} else if (!configuredSalt.isPresent()) {
throw new FormatterException("You must supply 'salt' with 'saltedpassword' to add-scram");
}
try {
this.configuredPasswordString = Optional.empty();
this.configuredSaltedPassword = Optional.of(Base64.getDecoder().decode(saltedPasswordString));
} catch (IllegalArgumentException e) {
throw new FormatterException("Failed to decode given saltedPassword: " +
saltedPasswordString, e);
}
} else {
this.configuredPasswordString = Optional.of(passwordString);
this.configuredSaltedPassword = Optional.empty();
}
if (!components.isEmpty()) {
throw new FormatterException("Unknown SCRAM configurations: " +
components.keySet().stream().collect(Collectors.joining(", ")));
}
}

byte[] salt() throws Exception {
if (configuredSalt.isPresent()) {
return configuredSalt.get();
}
return new ScramFormatter(mechanism).secureRandomBytes();
}

int iterations() {
if (configuredIterations.isPresent()) {
return configuredIterations.getAsInt();
}
return 4096;
}

byte[] saltedPassword(byte[] salt, int iterations) throws Exception {
if (configuredSaltedPassword.isPresent()) {
return configuredSaltedPassword.get();
}
return new ScramFormatter(mechanism).saltedPassword(
configuredPasswordString.get(),
salt,
iterations);
}

UserScramCredentialRecord toRecord() throws Exception {
ScramFormatter formatter = new ScramFormatter(mechanism);
byte[] salt = salt();
int iterations = iterations();
if (iterations < mechanism.minIterations()) {
throw new FormatterException("The 'iterations' value must be >= " +
mechanism.minIterations() + " for add-scram using " + mechanism);
}
if (iterations > mechanism.maxIterations()) {
throw new FormatterException("The 'iterations' value must be <= " +
mechanism.maxIterations() + " for add-scram using " + mechanism);
}
byte[] saltedPassword = saltedPassword(salt, iterations);
return new UserScramCredentialRecord().
setName(configuredName).
setMechanism(mechanism.type()).
setSalt(salt).
setStoredKey(formatter.storedKey(formatter.clientKey(saltedPassword))).
setServerKey(formatter.serverKey(saltedPassword)).
setIterations(iterations);
}

@Override
public boolean equals(Object o) {
if (o == null || (!(o.getClass().equals(PerMechanismData.class)))) return false;
PerMechanismData other = (PerMechanismData) o;
return mechanism.equals(other.mechanism) &&
configuredName.equals(other.configuredName) &&
Arrays.equals(configuredSalt.orElseGet(() -> null),
other.configuredSalt.orElseGet(() -> null)) &&
configuredIterations.equals(other.configuredIterations) &&
configuredPasswordString.equals(other.configuredPasswordString) &&
Arrays.equals(configuredSaltedPassword.orElseGet(() -> null),
other.configuredSaltedPassword.orElseGet(() -> null));
}

@Override
public int hashCode() {
return Objects.hash(mechanism,
configuredName,
configuredSalt,
configuredIterations,
configuredPasswordString,
configuredSaltedPassword);
}

@Override
public String toString() {
return "PerMechanismData" +
"(mechanism=" + mechanism +
", configuredName=" + configuredName +
", configuredSalt=" + configuredSalt.map(v -> Arrays.toString(v)) +
", configuredIterations=" + configuredIterations +
", configuredPasswordString=" + configuredPasswordString +
", configuredSaltedPassword=" + configuredSaltedPassword.map(v -> Arrays.toString(v)) +
")";
}
}

static Map.Entry<String, String> splitTrimmedConfigStringComponent(String input) {
int i;
for (i = 0; i < input.length(); i++) {
if (input.charAt(i) == '=') {
break;
}
}
if (i == input.length()) {
throw new FormatterException("No equals sign found in SCRAM component: " + input);
}
String value = input.substring(i + 1);
if (value.length() >= 2) {
if (value.startsWith("\"") && value.endsWith("\"")) {
value = value.substring(1, value.length() - 1);
}
}
return new AbstractMap.SimpleImmutableEntry<>(input.substring(0, i), value);
}
}
}
4 changes: 0 additions & 4 deletions quarkus-kafka-server-extension/deployment/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,6 @@
<artifactId>quarkus-junit5-internal</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.apache.kafka</groupId>
<artifactId>kafka-metadata</artifactId>
</dependency>
</dependencies>
<build>
<plugins>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ void build(BuildProducer<ReflectiveClassBuildItem> reflectiveClass,
producer.produce(new RuntimeInitializedClassBuildItem("kafka.server.DelayedDeleteRecordsMetrics$"));

reflectiveClass.produce(ReflectiveClassBuildItem.builder(org.apache.kafka.common.metrics.JmxReporter.class).build());
reflectiveClass.produce(ReflectiveClassBuildItem.builder(org.apache.kafka.metadata.storage.Formatter.class).build());
}

@BuildStep
Expand Down

0 comments on commit 8ad8720

Please sign in to comment.