Skip to content

Commit

Permalink
Cadlabs/python tool sanitization (#4754)
Browse files Browse the repository at this point in the history
Co-authored-by: BenSchZA <[email protected]>
  • Loading branch information
dev2049 and BenSchZA authored May 18, 2023
1 parent 0dc304c commit e28bdf4
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 4 deletions.
16 changes: 13 additions & 3 deletions langchain/tools/python/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,17 @@ def _get_default_python_repl() -> PythonREPL:
return PythonREPL(_globals=globals(), _locals=None)


_MD_PY_BLOCK = "```python"


def sanitize_input(query: str) -> str:
query = query.strip()
if query[: len(_MD_PY_BLOCK)] == _MD_PY_BLOCK:
query = query[len(_MD_PY_BLOCK) :].strip()
query = query.strip("`").strip()
return query


class PythonREPLTool(BaseTool):
"""A tool for running python code in a REPL."""

Expand All @@ -39,7 +50,7 @@ def _run(
) -> Any:
"""Use the tool."""
if self.sanitize_input:
query = query.strip().strip("```")
query = sanitize_input(query)
return self.python_repl.run(query)

async def _arun(
Expand Down Expand Up @@ -84,8 +95,7 @@ def _run(
"""Use the tool."""
try:
if self.sanitize_input:
# Remove the triple backticks from the query.
query = query.strip().strip("```")
query = sanitize_input(query)
tree = ast.parse(query)
module = ast.Module(tree.body[:-1], type_ignores=[])
exec(ast.unparse(module), self.globals, self.locals) # type: ignore
Expand Down
33 changes: 32 additions & 1 deletion tests/unit_tests/tools/python/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@

import pytest

from langchain.tools.python.tool import PythonAstREPLTool, PythonREPLTool
from langchain.tools.python.tool import (
PythonAstREPLTool,
PythonREPLTool,
sanitize_input,
)


def test_python_repl_tool_single_input() -> None:
Expand All @@ -21,3 +25,30 @@ def test_python_ast_repl_tool_single_input() -> None:
tool = PythonAstREPLTool()
assert tool.is_single_input
assert tool.run("1 + 1") == 2


def test_sanitize_input() -> None:
query = """
```
p = 5
```
"""
expected = "p = 5"
actual = sanitize_input(query)
assert expected == actual

query = """
```python
p = 5
```
"""
expected = "p = 5"
actual = sanitize_input(query)
assert expected == actual

query = """
p = 5
"""
expected = "p = 5"
actual = sanitize_input(query)
assert expected == actual

0 comments on commit e28bdf4

Please sign in to comment.