diff --git a/px/px_commandline.py b/px/px_commandline.py index e658129..37a1f59 100644 --- a/px/px_commandline.py +++ b/px/px_commandline.py @@ -255,14 +255,42 @@ def get_python_command(commandline: str) -> Optional[str]: if len(array) == 1: return python - if not array[1].startswith("-"): - return os.path.basename(array[1]) - if len(array) > 2: if array[1] == "-m" and not array[2].startswith("-"): return os.path.basename(array[2]) - return None + if array[1].startswith("-"): + return None + + if os.path.basename(array[1]) == "aws": + # Drop the "python" part of "python aws" + return get_aws_command(array[1:]) + + return os.path.basename(array[1]) + + +def get_aws_command(args: List[str]) -> Optional[str]: + '''Extract "aws command subcommand" from a command line starting with "aws"''' + result = ["aws"] + for arg in args[1:]: + if arg.startswith("--profile="): + continue + if arg.startswith("--region="): + continue + if arg.startswith("-"): + break + if os.path.sep in arg: + break + + result.append(arg) + if len(result) >= 4: + # Got "aws command subcommand" + break + + if len(result) == 4 and result[-1] != "help": + del result[-1] + + return " ".join(result) def get_sudo_command(commandline: str) -> Optional[str]: diff --git a/tests/px_commandline_test.py b/tests/px_commandline_test.py index d0024be..8cec4d7 100644 --- a/tests/px_commandline_test.py +++ b/tests/px_commandline_test.py @@ -36,6 +36,36 @@ def test_get_command_python(): assert px_commandline.get_command("python ") == "python" +def test_get_command_aws(): + assert px_commandline.get_command("Python /usr/local/bin/aws") == "aws" + assert px_commandline.get_command("python aws s3") == "aws s3" + assert px_commandline.get_command("python3 aws s3 help") == "aws s3 help" + assert ( + px_commandline.get_command("/wherever/python3 aws s3 help flaska") + == "aws s3 help" + ) + assert px_commandline.get_command("python aws s3 sync help") == "aws s3 sync help" + assert px_commandline.get_command("python aws s3 sync nothelp") == "aws s3 sync" + assert ( + px_commandline.get_command( + " ".join( + [ + "python3", + "/usr/local/bin/aws", + "--profile=system-admin-prod", + "--region=eu-west-1", + "s3", + "sync", + "--only-show-errors", + "s3://xxxxxx", + "./xxxxxx", + ] + ) + ) + == "aws s3 sync" + ) + + def test_get_command_java(): assert px_commandline.get_command("java") == "java" assert px_commandline.get_command("java -version") == "java"