Skip to content

Commit

Permalink
Prevent unwanted simplification of derivatives after SymPy upgrade
Browse files Browse the repository at this point in the history
The `simplify` method changed in SymPy version 1.5 due to commit
sympy/sympy@0fdc617
to evaluate unevaluated derivatives by default. This can trivially be
disabled which is what this commit does.
Add a test to catch this in future.
  • Loading branch information
jsharkey13 committed Jun 29, 2021
1 parent 6bdebd2 commit 5a5c5e4
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 1 deletion.
2 changes: 1 addition & 1 deletion checker/maths.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def symbolic_equality(test_expr, target_expr):
# it doesn't seem like much of an issue. Removing 'sympy.posify()' below will
# stop this.
try:
if sympy.simplify(sympy.posify(test_expr - target_expr)[0]) == 0:
if sympy.simplify(sympy.posify(test_expr - target_expr)[0], doit=False) == 0:
print("Symbolic match.")
print("INFO: Adding known pair ({0}, {1})".format(target_expr, test_expr))
KNOWN_PAIRS[(target_expr, test_expr)] = EqualityType.SYMBOLIC
Expand Down
14 changes: 14 additions & 0 deletions checker/tests/test_maths.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,6 +498,20 @@ def test_derivatives(self):
self.assertTrue(response["equality_type"] == "symbolic", 'For these expressions, expected "equality_type" to be "symbolic", got "{}"!'.format(response["equality_type"]))
print(" PASS ".center(75, "#"))

def test_derivatives_not_removed(self):
print("\n\n\n" + " Test Derivatives Not Evaluated ".center(75, "#"))
test_str = "0"
target_str = "Derivative(y, x)"
response = api.check(test_str, target_str, check_symbols=False)

self.assertTrue("error" not in response, 'Unexpected "error" in response!')
self.assertTrue("equal" in response, 'Key "equal" not in response!')
self.assertTrue(response["equal"] == "false", 'Expected "equal" to be "false", got "{}"!'.format(response["equal"]))
self.assertTrue("equality_type" in response, 'Key "equality_type" not in response!')
self.assertTrue(response["equality_type"] in EQUALITY_TYPES, 'Unexpected "equality_type": "{}"!'.format(response["equality_type"]))
self.assertTrue(response["equality_type"] == "numeric", 'For these expressions, expected "equality_type" to be "numeric", got "{}"!'.format(response["equality_type"]))
print(" PASS ".center(75, "#"))

def test_differential_equations(self):
print("\n\n\n" + " Test Differential Equations ".center(75, "#"))
api.SIMPLIFY_DERIVATIVES = True
Expand Down

0 comments on commit 5a5c5e4

Please sign in to comment.