Skip to content

Commit

Permalink
fix: support heredoc/EOF syntax in shell tool (#335)
Browse files Browse the repository at this point in the history
* fix: shell tool doesn't support heredoc syntax, stalls #108

* implement PR comment fixes

* implement PR comment fixes

* format: fixed formatting

---------

Co-authored-by: Jamesb <[email protected]>
Co-authored-by: Erik Bjäreholt <[email protected]>
  • Loading branch information
3 people authored Dec 15, 2024
1 parent 8e8f28e commit 6a1246f
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 30 deletions.
45 changes: 20 additions & 25 deletions gptme/tools/shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ def _run(self, command: str, output=True, tries=0) -> tuple[int | None, str, str
assert self.process.stdin

# run the command
full_command = f"{command}; echo ReturnCode:$? {self.delimiter}\n"
full_command = f"{command}\n"
full_command += f"echo ReturnCode:$? {self.delimiter}\n"
try:
self.process.stdin.write(full_command)
except BrokenPipeError:
Expand All @@ -177,10 +178,9 @@ def _run(self, command: str, output=True, tries=0) -> tuple[int | None, str, str

self.process.stdin.flush()

stdout = []
stderr = []
return_code = None
read_delimiter = False
stdout: list[str] = []
stderr: list[str] = []
return_code: int | None = None

while True:
rlist, _, _ = select.select([self.stdout_fd, self.stderr_fd], [], [])
Expand All @@ -190,13 +190,22 @@ def _run(self, command: str, output=True, tries=0) -> tuple[int | None, str, str
# 2**12 = 4096
# 2**16 = 65536
data = os.read(fd, 2**16).decode("utf-8")
lines = data.splitlines(keepends=True)
re_returncode = re.compile(r"ReturnCode:(\d+)")
for line in re.split(r"(\n)", data):
if match := re_returncode.match(line):
return_code = int(match.group(1))
if self.delimiter in line:
read_delimiter = True
continue
for line in lines:
if "ReturnCode:" in line and self.delimiter in line:
if match := re_returncode.search(line):
return_code = int(match.group(1))
# if command is cd and successful, we need to change the directory
if command.startswith("cd ") and return_code == 0:
ex, pwd, _ = self._run("pwd", output=False)
assert ex == 0
os.chdir(pwd.strip())
return (
return_code,
"".join(stdout).strip(),
"".join(stderr).strip(),
)
if fd == self.stdout_fd:
stdout.append(line)
if output:
Expand All @@ -205,20 +214,6 @@ def _run(self, command: str, output=True, tries=0) -> tuple[int | None, str, str
stderr.append(line)
if output:
print(line, end="", file=sys.stderr)
if read_delimiter:
break

# if command is cd and successful, we need to change the directory
if command.startswith("cd ") and return_code == 0:
ex, pwd, _ = self._run("pwd", output=False)
assert ex == 0
os.chdir(pwd.strip())

return (
return_code,
"".join(stdout).replace(f"ReturnCode:{return_code}", "").strip(),
"".join(stderr).strip(),
)

def close(self):
assert self.process.stdin
Expand Down
63 changes: 58 additions & 5 deletions tests/test_shell.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,38 @@ def test_echo(shell):


def test_echo_multiline(shell):
# tests multiline and trailing + leading whitespace
# Test multiline and trailing + leading whitespace
ret, out, err = shell.run("echo 'Line 1 \n Line 2'")
assert err.strip() == "" # Expecting no stderr
assert (
out.strip() == "Line 1 \n Line 2"
) # Expecting stdout to be "Line 1\nLine 2"
assert err.strip() == ""
assert out.strip() == "Line 1 \n Line 2"
assert ret == 0

# Test basic heredoc (<<)
ret, out, err = shell.run("""
cat << EOF
Hello
World
EOF
""")
assert err.strip() == ""
assert out.strip() == "Hello\nWorld"
assert ret == 0

# Test stripped heredoc (<<-)
ret, out, err = shell.run("""
cat <<- EOF
Hello
World
EOF
""")
assert err.strip() == ""
assert out.strip() == "Hello\nWorld"
assert ret == 0

# Test here-string (<<<)
ret, out, err = shell.run("cat <<< 'Hello World'")
assert err.strip() == ""
assert out.strip() == "Hello World"
assert ret == 0


Expand Down Expand Up @@ -80,6 +106,33 @@ def test_split_commands():
assert len(commands) == 1


def test_heredoc_complex(shell):
# Test nested heredocs
ret, out, err = shell.run("""
cat << OUTER
This is the outer heredoc
$(cat << INNER
This is the inner heredoc
INNER
)
OUTER
""")
assert err.strip() == ""
assert out.strip() == "This is the outer heredoc\nThis is the inner heredoc"
assert ret == 0

# Test heredoc with variable substitution
ret, out, err = shell.run("""
NAME="World"
cat << EOF
Hello, $NAME!
EOF
""")
assert err.strip() == ""
assert out.strip() == "Hello, World!"
assert ret == 0


def test_function():
script = """
function hello() {
Expand Down

0 comments on commit 6a1246f

Please sign in to comment.