Skip to content

Commit

Permalink
Replace ShellUtils arbitrary command execution (#442)
Browse files Browse the repository at this point in the history
* Replace ShellUtils arbitrary command execution
  • Loading branch information
emerkle826 authored Mar 21, 2024
1 parent 4e2a424 commit c219d81
Show file tree
Hide file tree
Showing 7 changed files with 151 additions and 206 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Changelog for Management API, new PRs should update the `main / unreleased` sect
* [FEATURE] [#453](https://github.com/k8ssandra/management-api-for-apache-cassandra/issues/453) Use a longer driver timeout for drain
* [FEATURE] [#455](https://github.com/k8ssandra/management-api-for-apache-cassandra/issues/455) Add DSE 6.8.43 to the build matrix
* [FEATURE] [#458](https://github.com/k8ssandra/management-api-for-apache-cassandra/issues/458) Update MCAC to v0.3.5
* [ENHANCEMENT] [#441](https://github.com/k8ssandra/management-api-for-apache-cassandra/issues/441) Replace ShellUtils arbitrary command execution

## v0.1.73 (2024-02-20)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,100 +6,33 @@
package com.datastax.mgmtapi.util;

import java.io.BufferedReader;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.StringReader;
import java.util.ArrayList;
import java.io.InputStreamReader;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

// From fallout
public class ShellUtils {
private static final Logger logger = LoggerFactory.getLogger(ShellUtils.class);

/**
* This is from http://stackoverflow.com/a/20725050/322152. We're not using org.apache.commons
* .exec.CommandLine because it fails to parse "run 'echo "foo"'" correctly (v1.3 misses off the
* final ')
*/
public static String[] split(CharSequence string) {
List<String> tokens = new ArrayList<>();
boolean escaping = false;
char quoteChar = ' ';
boolean quoting = false;
StringBuilder current = new StringBuilder();
for (int i = 0; i < string.length(); i++) {
char c = string.charAt(i);
if (escaping) {
current.append(c);
escaping = false;
} else if (c == '\\' && !(quoting && quoteChar == '\'')) {
escaping = true;
} else if (quoting && c == quoteChar) {
quoting = false;
} else if (!quoting && (c == '\'' || c == '"')) {
quoting = true;
quoteChar = c;
} else if (!quoting && Character.isWhitespace(c)) {
if (current.length() > 0) {
tokens.add(current.toString());
current = new StringBuilder();
}
} else {
current.append(c);
}
}
if (current.length() > 0) {
tokens.add(current.toString());
}
return tokens.toArray(new String[] {});
}
private static final long PS_WAIT_FOR_TIMEOUT_S = 600;
private static final long PS_MAX_LINES_COLLECTED = 10000;

public static String escape(String param) {
return escape(param, false);
}

public static String escape(String param, boolean forceQuote) {
String escapedQuotesParam = param.replaceAll("'", "'\"'\"'");

return forceQuote || escapedQuotesParam.contains(" ")
? "'" + escapedQuotesParam + "'"
: escapedQuotesParam;
}

public static List<String> wrapCommandWithBash(String command, boolean remoteCommand) {
List<String> fullCmd = new ArrayList<>();
fullCmd.add("/bin/bash");
fullCmd.add("-o");
fullCmd.add("pipefail"); // pipe returns first non-zero exit code
if (remoteCommand) {
// Remote commands should be run in a login shell, since they'll need the environment
// to be set up correctly. Local commands should already be in this situation,
// since fallout should have been run with the correct environment already in place.
fullCmd.add("-l");
}
fullCmd.add("-c"); // execute following command
if (remoteCommand) {
// Remote commands need to be quoted again, to prevent expansion as they're passed to ssh.
String escapedCmd = ShellUtils.escape(command, true);
fullCmd.add(escapedCmd);
} else {
fullCmd.add(command);
public static Process executeShell(
ProcessBuilder processBuilder, Map<String, String> environment) {
if (logger.isTraceEnabled()) {
String cmd = String.join(" ", processBuilder.command());
logger.trace("Executing locally: {}, Env {}", cmd, environment);
}
return fullCmd;
}

public static Process executeShell(String command, Map<String, String> environment) {
List<String> cmds = wrapCommandWithBash(command, false);
logger.trace("Executing locally: {}, Env {}", String.join(" ", cmds), environment);
ProcessBuilder pb = new ProcessBuilder(cmds);
pb.environment().putAll(environment);
processBuilder.environment().putAll(environment);
try {
return pb.start();
return processBuilder.start();
} catch (IOException e) {
throw new RuntimeException(e);
}
Expand All @@ -109,73 +42,47 @@ public interface ThrowingBiFunction<ArgType, Arg2Type, ResType> {
ResType apply(ArgType t, Arg2Type t2) throws IOException;
}

public static <T> T executeShellWithHandlers(
String command,
ThrowingBiFunction<BufferedReader, BufferedReader, T> handler,
ThrowingBiFunction<Integer, BufferedReader, T> errorHandler)
public static <T> T executeWithHandlers(
ProcessBuilder processBuilder,
ThrowingBiFunction<Stream<String>, Stream<String>, T> handler,
ThrowingBiFunction<Integer, Stream<String>, T> errorHandler)
throws IOException {
return executeShellWithHandlers(command, handler, errorHandler, Collections.emptyMap());
return executeWithHandlers(processBuilder, handler, errorHandler, Collections.emptyMap());
}

public static <T> T executeShellWithHandlers(
String command,
ThrowingBiFunction<BufferedReader, BufferedReader, T> handler,
ThrowingBiFunction<Integer, BufferedReader, T> errorHandler,
public static <T> T executeWithHandlers(
ProcessBuilder processBuilder,
ThrowingBiFunction<Stream<String>, Stream<String>, T> handler,
ThrowingBiFunction<Integer, Stream<String>, T> errorHandler,
Map<String, String> environment)
throws IOException {
Process ps = ShellUtils.executeShell(command, environment);

// We need to read _everything_ from stdin + stderr as buffering will interfere with large
// amounts of
// data. If we don't read everything here, the process might just get blocked, because of a
// buffer
// being full.
ByteArrayOutputStream stdinBuffer = new ByteArrayOutputStream();
ByteArrayOutputStream stderrBuffer = new ByteArrayOutputStream();
int ec;
try (InputStream stdin = ps.getInputStream();
InputStream stderr = ps.getErrorStream()) {
byte[] buf = new byte[1024];
while (true) {
int avStdin;
int avStderr;
int rd;
Process ps = ShellUtils.executeShell(processBuilder, environment);

avStdin = Math.min(stdin.available(), buf.length);
if (avStdin > 0) {
rd = stdin.read(buf, 0, avStdin);
if (rd > 0) stdinBuffer.write(buf, 0, rd);
}
avStderr = Math.min(stderr.available(), buf.length);
if (avStderr > 0) {
rd = stderr.read(buf, 0, avStderr);
if (rd > 0) stderrBuffer.write(buf, 0, rd);
}

if (avStdin == 0 && avStderr == 0) {
try {
ec = ps.exitValue();
break;
} catch (IllegalThreadStateException ignore) {
try {
Thread.sleep(50L);
} catch (InterruptedException e) {
// ignore
}
}
}
}
}
return runProcessWithHandlers(ps, handler, errorHandler);
}

private static <T> T runProcessWithHandlers(
Process process,
ThrowingBiFunction<Stream<String>, Stream<String>, T> handler,
ThrowingBiFunction<Integer, Stream<String>, T> errorHandler)
throws IOException {
try (BufferedReader input =
new BufferedReader(new StringReader(stdinBuffer.toString("UTF-8")));
new BufferedReader(new InputStreamReader(process.getInputStream()));
BufferedReader error =
new BufferedReader(new StringReader(stderrBuffer.toString("UTF-8")))) {
if (ec != 0) {
return errorHandler.apply(ps.exitValue(), error);
new BufferedReader(new InputStreamReader(process.getErrorStream()))) {
// we need to read all the output and error for the ps to finish
List<String> inputLines =
input.lines().limit(PS_MAX_LINES_COLLECTED).collect(Collectors.toList());
List<String> errorLines =
error.lines().limit(PS_MAX_LINES_COLLECTED).collect(Collectors.toList());
process.waitFor(PS_WAIT_FOR_TIMEOUT_S, TimeUnit.SECONDS);
if (process.exitValue() != 0) {
return errorHandler.apply(process.exitValue(), errorLines.stream());
} else {
return handler.apply(inputLines.stream(), errorLines.stream());
}

return handler.apply(input, error);
} catch (InterruptedException t) {
throw new RuntimeException(t);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -302,15 +302,15 @@ private void checkDbCmd() {
String dbCmd = "cassandra";
try {
boolean isDse = isDse();
dbCmd = isDse ? "dse" : "cassandra";

String dbHomeEnv = isDse ? "DSE_HOME" : "CASSANDRA_HOME";
if (db_home != null) {
dbHomeDir = new File(db_home);
} else if (System.getenv(dbHomeEnv) != null) {
dbHomeDir = new File(System.getenv(dbHomeEnv));
}

Optional<File> exe = UnixCmds.which(dbCmd);
Optional<File> exe = isDse ? UnixCmds.whichDse() : UnixCmds.whichCassandra();
exe.ifPresent(file -> dbCmdFile = file);

if (dbHomeDir != null && (!dbHomeDir.exists() || !dbHomeDir.isDirectory())) dbHomeDir = null;
Expand All @@ -328,13 +328,13 @@ private void checkDbCmd() {
// Verify Cassandra/DSE cmd works
List<String> errorOutput = new ArrayList<>();
String version =
ShellUtils.executeShellWithHandlers(
dbCmdFile.getAbsolutePath() + " -v",
(input, err) -> input.readLine(),
ShellUtils.executeWithHandlers(
new ProcessBuilder(dbCmdFile.getAbsolutePath(), "-v"),
(input, err) -> input.findFirst().orElse(null),
(exitCode, err) -> {
String s;
errorOutput.add("'" + dbCmdFile.getAbsolutePath() + " -v' exit code: " + exitCode);
while ((s = err.readLine()) != null) errorOutput.add(s);
while ((s = err.findFirst().orElse(null)) != null) errorOutput.add(s);
return null;
});

Expand All @@ -360,7 +360,7 @@ private void checkDbCmd() {
private boolean isDse() {
try {
// first check if dse cmd is already on the path
if (UnixCmds.which("dse").isPresent()) {
if (UnixCmds.whichDse().isPresent()) {
return true;
}
} catch (IOException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,45 +8,72 @@
import com.datastax.mgmtapi.util.ShellUtils;
import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/** Ugly methods for doing Unix commands */
public class UnixCmds {
private static final Logger logger = LoggerFactory.getLogger(UnixCmds.class);

public static Optional<File> which(String exeStr) throws IOException {
return ShellUtils.executeShellWithHandlers(
String.format("/bin/which %s", exeStr),
(input, err) -> {
File exe = new File(input.readLine().toLowerCase());
if (exe.canExecute()) return Optional.of(exe);
private static final String PS_CMD = "/bin/ps";
private static final String KILL_CMD = "/bin/kill";
private static final String WHICH_CMD = "/bin/which";

return Optional.empty();
},
public static Optional<File> whichCassandra() throws IOException {
return which("cassandra");
}

public static Optional<File> whichDse() throws IOException {
return which("dse");
}

private static Optional<File> which(String exeStr) throws IOException {
return ShellUtils.executeWithHandlers(
new ProcessBuilder(WHICH_CMD, exeStr),
(input, err) ->
input.findFirst().map(path -> new File(path.toLowerCase())).filter(File::canExecute),
(exitCode, err) -> Optional.empty());
}

public static Optional<Integer> findDbProcessWithMatchingArg(String filterStr)
throws IOException {
return ShellUtils.executeShellWithHandlers(
"/bin/ps -eo pid,command= | grep Dcassandra.server_process",

ProcessBuilder psListPb = new ProcessBuilder(PS_CMD, "-eo", "pid,command");
return ShellUtils.executeWithHandlers(
psListPb,
(input, err) -> {
Integer pid = null;
String line;
while ((line = input.readLine()) != null) {
if (line.contains(filterStr)) {
if (pid != null)
throw new RuntimeException("Found more than 1 pid for: " + filterStr);

logger.debug("Match found on {}", line);
pid = Integer.valueOf(line.trim().split("\\s")[0]);
}
List<String> match =
input.filter(x -> x.contains(filterStr)).collect(Collectors.toList());

if (match.isEmpty()) {
logger.debug("No process found for filtering criteria: {}", filterStr);
return Optional.empty();
}

if (match.size() > 1) {
throw new RuntimeException("Found more than 1 pid for: " + filterStr);
}

return Optional.ofNullable(pid);
int pid = Integer.parseInt(match.get(0).trim().split("\\s")[0]);
return Optional.of(pid);
},
(exitCode, err) -> Optional.empty());
}

public static boolean terminateProcess(int pid) throws IOException {
return ShellUtils.executeWithHandlers(
new ProcessBuilder(KILL_CMD, String.valueOf(pid)),
(input, err) -> true,
(exitCode, err) -> false);
}

public static boolean killProcess(int pid) throws IOException {
return ShellUtils.executeWithHandlers(
new ProcessBuilder(KILL_CMD, "-9", String.valueOf(pid)),
(input, err) -> true,
(exitCode, err) -> false);
}
}
Loading

0 comments on commit c219d81

Please sign in to comment.