experimental[major]: CVE-2024-46946 fix #26783

Merged: 21 commits, Sep 24, 2024
@@ -42,6 +42,44 @@ class LLMSymbolicMathChain(Chain):
extra="forbid",
)

allow_dangerous_requests: bool # Assign no default.
"""Must be set by the user to allow dangerous requests or not.

We recommend a default of False to allow only pre-defined symbolic operations.

When set to True, the chain will allow any kind of input. This is
STRONGLY DISCOURAGED unless you fully trust the input (and believe that
the LLM itself cannot behave in a malicious way).
You should absolutely NOT be deploying this in a production environment
with allow_dangerous_requests=True. As this would allow a malicious actor
to execute arbitrary code on your system.
Use default=True at your own risk.


When set to False, the chain will only allow pre-defined symbolic operations.
If the some symbolic expressions are failing to evaluate, you can open a PR
to add them to extend the list of allowed operations.
"""

def __init__(self, **kwargs: Any) -> None:
if "allow_dangerous_requests" not in kwargs:
raise ValueError(
"LLMSymbolicMathChain requires allow_dangerous_requests to be set. "
"We recommend that you set `allow_dangerous_requests=False` to allow "
"only pre-defined symbolic operations. "
"If some symbolic expressions fail to evaluate, you can open a PR to "
"extend the list of allowed operations. "
"Alternatively, you can set `allow_dangerous_requests=True` to allow "
"any kind of input, but this is STRONGLY DISCOURAGED unless you "
"fully trust the input (and believe that the LLM itself cannot behave "
"in a malicious way). "
"You should absolutely NOT deploy this in a production environment "
"with `allow_dangerous_requests=True`, as this would allow a "
"malicious actor to execute arbitrary code on your system."
)
super().__init__(**kwargs)

@property
def input_keys(self) -> List[str]:
"""Expect input key.
@@ -65,8 +103,59 @@ def _evaluate_expression(self, expression: str) -> str:
raise ImportError(
"Unable to import sympy, please install it with `pip install sympy`."
) from e

try:
output = str(sympy.sympify(expression, evaluate=True))
if self.allow_dangerous_requests:
output = str(sympy.sympify(expression, evaluate=True))
else:
allowed_symbols = {
# Basic arithmetic and trigonometry
"sin": sympy.sin,
"cos": sympy.cos,
"tan": sympy.tan,
"cot": sympy.cot,
"sec": sympy.sec,
"csc": sympy.csc,
"asin": sympy.asin,
"acos": sympy.acos,
"atan": sympy.atan,
# Hyperbolic functions
"sinh": sympy.sinh,
"cosh": sympy.cosh,
"tanh": sympy.tanh,
"asinh": sympy.asinh,
"acosh": sympy.acosh,
"atanh": sympy.atanh,
# Exponentials and logarithms
"exp": sympy.exp,
"log": sympy.log,
"ln": sympy.log,  # natural log; sympy.log defaults to base e
"log10": lambda x: sympy.log(x, 10),  # log base 10 via sympy.log(x, 10)
# Powers and roots
"sqrt": sympy.sqrt,
"cbrt": lambda x: sympy.Pow(x, sympy.Rational(1, 3)),
# Combinatorics and other math functions
"factorial": sympy.factorial,
"binomial": sympy.binomial,
"gcd": sympy.gcd,
"lcm": sympy.lcm,
"abs": sympy.Abs,
"sign": sympy.sign,
"mod": sympy.Mod,
# Constants
"pi": sympy.pi,
"e": sympy.E,
"I": sympy.I,
"oo": sympy.oo,
"NaN": sympy.nan,
}

# Use parse_expr, resolving names through the whitelist above
output = str(
sympy.parse_expr(
expression, local_dict=allowed_symbols, evaluate=True
)
)
except Exception as e:
raise ValueError(
f'LLMSymbolicMathChain._evaluate("{expression}") raised error: {e}.'
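For context, here is a minimal standalone sketch (not part of this PR; the abbreviated whitelist and variable names are illustrative) of the restricted sympy.parse_expr call used on the non-dangerous path above:

import sympy

# Abbreviated whitelist mirroring the allowed_symbols mapping above.
allowed_symbols = {"sin": sympy.sin, "sqrt": sympy.sqrt, "pi": sympy.pi}

# Names in the expression are resolved through local_dict during parsing,
# and the result is evaluated symbolically.
result = sympy.parse_expr("sin(pi/2) + sqrt(4)", local_dict=allowed_symbols, evaluate=True)
print(result)  # 3

In the chain itself, any exception raised during parsing or evaluation is re-raised as the ValueError exercised by test_error and test_security_vulnerability below.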
libs/experimental/tests/unit_tests/test_llm_symbolic_math.py (24 additions, 1 deletion)
@@ -34,9 +34,20 @@ def fake_llm_symbolic_math_chain() -> LLMSymbolicMathChain:
question="What are the solutions to this equation x**2 - x?"
): "```text\nsolveset(x**2 - x, x)\n```",
_PROMPT_TEMPLATE.format(question="foo"): "foo",
_PROMPT_TEMPLATE.format(question="__import__('os')"): "__import__('os')",
}
fake_llm = FakeLLM(queries=queries)
return LLMSymbolicMathChain.from_llm(fake_llm, input_key="q", output_key="a")
return LLMSymbolicMathChain.from_llm(
fake_llm, input_key="q", output_key="a", allow_dangerous_requests=False
)


def test_require_allow_dangerous_requests_to_be_set() -> None:
"""Test that allow_dangerous_requests must be set."""
fake_llm = FakeLLM(queries={})

with pytest.raises(ValueError):
LLMSymbolicMathChain.from_llm(fake_llm, input_key="q", output_key="a")


def test_simple_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
@@ -80,3 +91,15 @@ def test_error(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
"""Test question that raises error."""
with pytest.raises(ValueError):
fake_llm_symbolic_math_chain.run("foo")


def test_security_vulnerability(
fake_llm_symbolic_math_chain: LLMSymbolicMathChain,
) -> None:
"""Test for potential security vulnerability with malicious input."""
# Example of a code injection attempt
malicious_input = "__import__('os')"

# Run the chain with the malicious input and ensure it raises an error
with pytest.raises(ValueError):
fake_llm_symbolic_math_chain.run(malicious_input)
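
A hedged usage sketch of the changed constructor contract, mirroring the tests above; the import paths and the FakeListLLM stand-in are assumptions for illustration, not part of this PR:

from langchain_core.language_models import FakeListLLM  # stand-in LLM; assumed import path
from langchain_experimental.llm_symbolic_math.base import LLMSymbolicMathChain  # assumed import path

# Canned response in the ```text ... ``` format the chain expects from its LLM.
llm = FakeListLLM(responses=["```text\nsqrt(4)\n```"])

# Omitting allow_dangerous_requests now fails at construction time.
try:
    LLMSymbolicMathChain.from_llm(llm)
except ValueError as err:
    print(f"rejected: {err}")

# Opting in to the restricted (whitelisted) evaluation path.
chain = LLMSymbolicMathChain.from_llm(llm, allow_dangerous_requests=False)
print(chain.run("What is the square root of 4?"))  # sqrt(4) evaluated via the whitelist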