Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

experimental[major]: CVE-2024-46946 fix #26783

Merged
merged 21 commits into from
Sep 24, 2024
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,52 @@ def _evaluate_expression(self, expression: str) -> str:
"Unable to import sympy, please install it with `pip install sympy`."
) from e
try:
output = str(sympy.sympify(expression, evaluate=True))
allowed_symbols = {
# Basic arithmetic and trigonometry
"sin": sympy.sin,
"cos": sympy.cos,
"tan": sympy.tan,
"cot": sympy.cot,
"sec": sympy.sec,
"csc": sympy.csc,
"asin": sympy.asin,
"acos": sympy.acos,
"atan": sympy.atan,
# Hyperbolic functions
"sinh": sympy.sinh,
"cosh": sympy.cosh,
"tanh": sympy.tanh,
"asinh": sympy.asinh,
"acosh": sympy.acosh,
"atanh": sympy.atanh,
# Exponentials and logarithms
"exp": sympy.exp,
"log": sympy.log,
"ln": sympy.log, # natural log (alias)
"log10": sympy.log, # log base 10 (use sympy.log)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

both log10 and ln are aliased to sympy.log

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed so that `log` takes a user-defined base (default 10), `ln` is the regular `sympy.log` (which defaults to the natural logarithm), and `log10` is fixed to base 10.

# Powers and roots
"sqrt": sympy.sqrt,
"cbrt": lambda x: sympy.Pow(x, sympy.Rational(1, 3)),
# Combinatorics and other math functions
"factorial": sympy.factorial,
"binomial": sympy.binomial,
"gcd": sympy.gcd,
"lcm": sympy.lcm,
"abs": sympy.Abs,
"sign": sympy.sign,
"mod": sympy.Mod,
# Constants
"pi": sympy.pi,
"e": sympy.E,
"I": sympy.I,
"oo": sympy.oo,
"NaN": sympy.nan,
}

# Use parse_expr with strict settings
output = str(
sympy.parse_expr(expression, local_dict=allowed_symbols, evaluate=True)
)
except Exception as e:
raise ValueError(
f'LLMSymbolicMathChain._evaluate("{expression}") raised error: {e}.'
Expand Down
15 changes: 15 additions & 0 deletions libs/experimental/tests/unit_tests/test_llm_symbolic_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def fake_llm_symbolic_math_chain() -> LLMSymbolicMathChain:
question="What are the solutions to this equation x**2 - x?"
): "```text\nsolveset(x**2 - x, x)\n```",
_PROMPT_TEMPLATE.format(question="foo"): "foo",
_PROMPT_TEMPLATE.format(
question="__import__('os').system('rm -rf /')"
): "__import__('os').system('rm -rf /')",
}
fake_llm = FakeLLM(queries=queries)
return LLMSymbolicMathChain.from_llm(fake_llm, input_key="q", output_key="a")
Expand Down Expand Up @@ -80,3 +83,15 @@ def test_error(fake_llm_symbolic_math_chain: LLMSymbolicMathChain) -> None:
"""Test question that raises error."""
with pytest.raises(ValueError):
fake_llm_symbolic_math_chain.run("foo")


def test_security_vulnerability(
    fake_llm_symbolic_math_chain: LLMSymbolicMathChain,
) -> None:
    """Verify that a code-injection payload makes the chain raise ``ValueError``.

    The fixture's fake LLM is configured to return this same payload as the
    expression to evaluate, so the string reaches the chain's expression
    evaluation step unchanged.
    """
    # Classic injection attempt: import ``os`` and run a shell command.
    injection_payload = "__import__('os').system('rm -rf /')"
    chain = fake_llm_symbolic_math_chain

    # The chain must surface an error for this input.
    with pytest.raises(ValueError):
        chain.run(injection_payload)