Skip to content

Commit

Permalink
experimental[major]: CVE-2024-46946 fix (#26783)
Browse files Browse the repository at this point in the history
Description: Resolve CVE-2024-46946 by switching out sympify with
parse_expr with a very specific allowed set of operations.

https://nvd.nist.gov/vuln/detail/cve-2024-46946

Sympify uses eval which makes it vulnerable to code execution.
parse_expr is limited to specific expressions.

Bandit results

![image](https://github.com/user-attachments/assets/170a6376-7028-4e70-a7ef-9acfb49c1d8a)

---------

Co-authored-by: aqiu7 <[email protected]>
Co-authored-by: Eugene Yurtsev <[email protected]>
Co-authored-by: Eugene Yurtsev <[email protected]>
  • Loading branch information
4 people authored Sep 24, 2024
1 parent f9ef688 commit 0414be4
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,44 @@ class LLMSymbolicMathChain(Chain):
extra="forbid",
)

allow_dangerous_requests: bool # Assign no default.
"""Must be set by the user to allow dangerous requests or not.
We recommend a default of False to allow only pre-defined symbolic operations.
When set to True, the chain will allow any kind of input. This is
STRONGLY DISCOURAGED unless you fully trust the input (and believe that
the LLM itself cannot behave in a malicious way).
You should absolutely NOT be deploying this in a production environment
with allow_dangerous_requests=True. As this would allow a malicious actor
to execute arbitrary code on your system.
Use default=True at your own risk.
When set to False, the chain will only allow pre-defined symbolic operations.
If the some symbolic expressions are failing to evaluate, you can open a PR
to add them to extend the list of allowed operations.
"""

def __init__(self, **kwargs: Any) -> None:
if "allow_dangerous_requests" not in kwargs:
raise ValueError(
"LLMSymbolicMathChain requires allow_dangerous_requests to be set. "
"We recommend that you set `allow_dangerous_requests=False` to allow "
"only pre-defined symbolic operations. "
"If the some symbolic expressions are failing to evaluate, you can "
"open a PR to add them to extend the list of allowed operations. "
"Alternatively, you can set `allow_dangerous_requests=True` to allow "
"any kind of input but this is STRONGLY DISCOURAGED unless you "
"fully trust the input (and believe that the LLM itself cannot behave "
"in a malicious way)."
"You should absolutely NOT be deploying this in a production "
"environment with allow_dangerous_requests=True. As "
"this would allow a malicious actor to execute arbitrary code on "
"your system."
)
super().__init__(**kwargs)

@property
def input_keys(self) -> List[str]:
"""Expect input key.
Expand All @@ -65,8 +103,59 @@ def _evaluate_expression(self, expression: str) -> str:
raise ImportError(
"Unable to import sympy, please install it with `pip install sympy`."
) from e

try:
output = str(sympy.sympify(expression, evaluate=True))
if self.allow_dangerous_requests:
output = str(sympy.sympify(expression, evaluate=True))
else:
allowed_symbols = {
# Basic arithmetic and trigonometry
"sin": sympy.sin,
"cos": sympy.cos,
"tan": sympy.tan,
"cot": sympy.cot,
"sec": sympy.sec,
"csc": sympy.csc,
"asin": sympy.asin,
"acos": sympy.acos,
"atan": sympy.atan,
# Hyperbolic functions
"sinh": sympy.sinh,
"cosh": sympy.cosh,
"tanh": sympy.tanh,
"asinh": sympy.asinh,
"acosh": sympy.acosh,
"atanh": sympy.atanh,
# Exponentials and logarithms
"exp": sympy.exp,
"log": sympy.log,
"ln": sympy.log, # natural log sympy defaults to natural log
"log10": lambda x: sympy.log(x, 10), # log base 10 (use sympy.log)
# Powers and roots
"sqrt": sympy.sqrt,
"cbrt": lambda x: sympy.Pow(x, sympy.Rational(1, 3)),
# Combinatorics and other math functions
"factorial": sympy.factorial,
"binomial": sympy.binomial,
"gcd": sympy.gcd,
"lcm": sympy.lcm,
"abs": sympy.Abs,
"sign": sympy.sign,
"mod": sympy.Mod,
# Constants
"pi": sympy.pi,
"e": sympy.E,
"I": sympy.I,
"oo": sympy.oo,
"NaN": sympy.nan,
}

# Use parse_expr with strict settings
output = str(
sympy.parse_expr(
expression, local_dict=allowed_symbols, evaluate=True
)
)
except Exception as e:
raise ValueError(
f'LLMSymbolicMathChain._evaluate("{expression}") raised error: {e}.'
Expand Down
25 changes: 24 additions & 1 deletion libs/experimental/tests/unit_tests/test_llm_symbolic_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,20 @@ def fake_llm_symbolic_math_chain() -> LLMSymbolicMathChain:
question="What are the solutions to this equation x**2 - x?"
): "```text\nsolveset(x**2 - x, x)\n```",
_PROMPT_TEMPLATE.format(question="foo"): "foo",
_PROMPT_TEMPLATE.format(question="__import__('os')"): "__import__('os')",
}
fake_llm = FakeLLM(queries=queries)
return LLMSymbolicMathChain.from_llm(fake_llm, input_key="q", output_key="a")
return LLMSymbolicMathChain.from_llm(
fake_llm, input_key="q", output_key="a", allow_dangerous_requests=False
)


def test_require_allow_dangerous_requests_to_be_set() -> None:
"""Test that allow_dangerous_requests must be set."""
fake_llm = FakeLLM(queries={})

with pytest.raises(ValueError):
LLMSymbolicMathChain.from_llm(fake_llm, input_key="q", output_key="a")


def test_simple_question(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
Expand Down Expand Up @@ -80,3 +91,15 @@ def test_error(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
"""Test question that raises error."""
with pytest.raises(ValueError):
fake_llm_symbolic_math_chain.run("foo")


def test_security_vulnerability(
fake_llm_symbolic_math_chain: LLMSymbolicMathChain,
) -> None:
"""Test for potential security vulnerability with malicious input."""
# Example of a code injection attempt
malicious_input = "__import__('os')"

# Run the chain with the malicious input and ensure it raises an error
with pytest.raises(ValueError):
fake_llm_symbolic_math_chain.run(malicious_input)

0 comments on commit 0414be4

Please sign in to comment.