-
-
Notifications
You must be signed in to change notification settings - Fork 563
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add a few more expression simplifications #2211
Merged
Merged
Changes from 10 commits
Commits
Show all changes
12 commits
Select commit
Hold shift + click to select a range
5c12173
add a few more expression simplifications
valentinsulzer 610f0dc
renaming
valentinsulzer fbb3954
changelog
valentinsulzer be6d786
fix 1+1D bug
valentinsulzer 3aa7e22
Merge branch 'develop' into more-expression-simplifications
valentinsulzer 3fcf9d8
comment out simplification for matmuls, possibly failing because of p…
valentinsulzer 1e1a2c9
Merge branch 'develop' into more-expression-simplifications
valentinsulzer b2e4f4d
flake8
valentinsulzer 2fa60bd
merge develop
valentinsulzer 95e5626
merge develop
valentinsulzer 4668bda
merge develop, rob comments
valentinsulzer e28b196
style: pre-commit fixes
pre-commit-ci[bot] File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -837,7 +837,7 @@ def simplified_addition(left, right): | |
|
||
# Return constant if both sides are constant | ||
if left.is_constant() and right.is_constant(): | ||
return pybamm.simplify_if_constant(pybamm.Addition(left, right)) | ||
return pybamm.simplify_if_constant(Addition(left, right)) | ||
|
||
# Simplify A @ c + B @ c to (A + B) @ c if (A + B) is constant | ||
# This is a common construction that appears from discretisation of spatial | ||
|
@@ -852,10 +852,10 @@ def simplified_addition(left, right): | |
new_left = l_left + r_left | ||
if new_left.is_constant(): | ||
new_sum = new_left @ l_right | ||
new_sum.copy_domains(pybamm.Addition(left, right)) | ||
new_sum.copy_domains(Addition(left, right)) | ||
return new_sum | ||
|
||
if isinstance(right, pybamm.Addition) and left.is_constant(): | ||
if isinstance(right, Addition) and left.is_constant(): | ||
# Simplify a + (b + c) to (a + b) + c if (a + b) is constant | ||
if right.left.is_constant(): | ||
r_left, r_right = right.orphans | ||
|
@@ -864,7 +864,7 @@ def simplified_addition(left, right): | |
elif right.right.is_constant(): | ||
r_left, r_right = right.orphans | ||
return (left + r_right) + r_left | ||
elif isinstance(right, pybamm.Subtraction) and left.is_constant(): | ||
elif isinstance(right, Subtraction) and left.is_constant(): | ||
# Simplify a + (b - c) to (a + b) - c if (a + b) is constant | ||
if right.left.is_constant(): | ||
r_left, r_right = right.orphans | ||
|
@@ -873,7 +873,7 @@ def simplified_addition(left, right): | |
elif right.right.is_constant(): | ||
r_left, r_right = right.orphans | ||
return (left - r_right) + r_left | ||
if isinstance(left, pybamm.Addition) and right.is_constant(): | ||
if isinstance(left, Addition) and right.is_constant(): | ||
# Simplify (a + b) + c to a + (b + c) if (b + c) is constant | ||
if left.right.is_constant(): | ||
l_left, l_right = left.orphans | ||
|
@@ -882,7 +882,7 @@ def simplified_addition(left, right): | |
elif left.left.is_constant(): | ||
l_left, l_right = left.orphans | ||
return (l_left + right) + l_right | ||
elif isinstance(left, pybamm.Subtraction) and right.is_constant(): | ||
elif isinstance(left, Subtraction) and right.is_constant(): | ||
# Simplify (a - b) + c to a + (c - b) if (c - b) is constant | ||
if left.right.is_constant(): | ||
l_left, l_right = left.orphans | ||
|
@@ -892,7 +892,7 @@ def simplified_addition(left, right): | |
l_left, l_right = left.orphans | ||
return (l_left + right) - l_right | ||
|
||
return pybamm.simplify_if_constant(pybamm.Addition(left, right)) | ||
return pybamm.simplify_if_constant(Addition(left, right)) | ||
|
||
|
||
def simplified_subtraction(left, right): | ||
|
@@ -953,7 +953,7 @@ def simplified_subtraction(left, right): | |
if left == right: | ||
return pybamm.zeros_like(left) | ||
|
||
if isinstance(right, pybamm.Addition) and left.is_constant(): | ||
if isinstance(right, Addition) and left.is_constant(): | ||
# Simplify a - (b + c) to (a - b) - c if (a - b) is constant | ||
if right.left.is_constant(): | ||
r_left, r_right = right.orphans | ||
|
@@ -962,7 +962,7 @@ def simplified_subtraction(left, right): | |
elif right.right.is_constant(): | ||
r_left, r_right = right.orphans | ||
return (left - r_right) - r_left | ||
elif isinstance(right, pybamm.Subtraction) and left.is_constant(): | ||
elif isinstance(right, Subtraction) and left.is_constant(): | ||
# Simplify a - (b - c) to (a - b) + c if (a - b) is constant | ||
if right.left.is_constant(): | ||
r_left, r_right = right.orphans | ||
|
@@ -971,7 +971,7 @@ def simplified_subtraction(left, right): | |
elif right.right.is_constant(): | ||
r_left, r_right = right.orphans | ||
return (left + r_right) - r_left | ||
if isinstance(left, pybamm.Addition) and right.is_constant(): | ||
if isinstance(left, Addition) and right.is_constant(): | ||
# Simplify (a + b) - c to a + (b - c) if (b - c) is constant | ||
if left.right.is_constant(): | ||
l_left, l_right = left.orphans | ||
|
@@ -980,7 +980,7 @@ def simplified_subtraction(left, right): | |
elif left.left.is_constant(): | ||
l_left, l_right = left.orphans | ||
return (l_left - right) + l_right | ||
elif isinstance(left, pybamm.Subtraction) and right.is_constant(): | ||
elif isinstance(left, Subtraction) and right.is_constant(): | ||
# Simplify (a - b) - c to a - (c + b) if (c + b) is constant | ||
if left.right.is_constant(): | ||
l_left, l_right = left.orphans | ||
|
@@ -990,7 +990,7 @@ def simplified_subtraction(left, right): | |
l_left, l_right = left.orphans | ||
return (l_left - right) - l_right | ||
|
||
return pybamm.simplify_if_constant(pybamm.Subtraction(left, right)) | ||
return pybamm.simplify_if_constant(Subtraction(left, right)) | ||
|
||
|
||
def simplified_multiplication(left, right): | ||
|
@@ -1011,7 +1011,7 @@ def simplified_multiplication(left, right): | |
|
||
# if one of the children is a zero matrix, we have to be careful about shapes | ||
if pybamm.is_matrix_zero(left) or pybamm.is_matrix_zero(right): | ||
return pybamm.zeros_like(pybamm.Multiplication(left, right)) | ||
return pybamm.zeros_like(Multiplication(left, right)) | ||
|
||
# anything multiplied by a scalar one returns itself | ||
if pybamm.is_scalar_one(left): | ||
|
@@ -1027,7 +1027,7 @@ def simplified_multiplication(left, right): | |
|
||
# Return constant if both sides are constant | ||
if left.is_constant() and right.is_constant(): | ||
return pybamm.simplify_if_constant(pybamm.Multiplication(left, right)) | ||
return pybamm.simplify_if_constant(Multiplication(left, right)) | ||
|
||
# anything multiplied by a matrix one returns itself if | ||
# - the shapes are the same | ||
|
@@ -1121,11 +1121,7 @@ def simplified_multiplication(left, right): | |
# operators | ||
# Also do this for cases like a * (b @ c + d) where (a * b) is constant | ||
elif isinstance(right, (Addition, Subtraction)): | ||
mul_classes = ( | ||
pybamm.Multiplication, | ||
pybamm.MatrixMultiplication, | ||
pybamm.Division, | ||
) | ||
mul_classes = (Multiplication, MatrixMultiplication, Division) | ||
if ( | ||
right.left.is_constant() | ||
or right.right.is_constant() | ||
|
@@ -1152,7 +1148,7 @@ def simplified_multiplication(left, right): | |
# Simplify a * (-b) to (-a) * b if (-a) is constant | ||
return (-left) * right.orphans[0] | ||
|
||
return pybamm.Multiplication(left, right) | ||
return Multiplication(left, right) | ||
|
||
|
||
def simplified_division(left, right): | ||
|
@@ -1169,7 +1165,7 @@ def simplified_division(left, right): | |
|
||
# matrix zero divided by anything returns matrix zero (i.e. itself) | ||
if pybamm.is_matrix_zero(left): | ||
return pybamm.zeros_like(pybamm.Division(left, right)) | ||
return pybamm.zeros_like(Division(left, right)) | ||
|
||
# anything divided by zero raises error | ||
if pybamm.is_scalar_zero(right): | ||
|
@@ -1204,7 +1200,7 @@ def simplified_division(left, right): | |
|
||
# Return constant if both sides are constant | ||
if left.is_constant() and right.is_constant(): | ||
return pybamm.simplify_if_constant(pybamm.Division(left, right)) | ||
return pybamm.simplify_if_constant(Division(left, right)) | ||
|
||
# Simplify (B @ c) / a to (B / a) @ c if (B / a) is constant | ||
# This is a common construction that appears from discretisation of averages | ||
|
@@ -1257,6 +1253,16 @@ def simplified_division(left, right): | |
r_left, r_right = right.orphans | ||
return (left * r_right) / r_left | ||
|
||
# Cancelling out common terms | ||
if ( | ||
isinstance(left, Multiplication) | ||
and isinstance(right, Multiplication) | ||
and left.left == right.left | ||
): | ||
_, l_right = left.orphans | ||
_, r_right = right.orphans | ||
return l_right / r_right | ||
|
||
# Negation simplifications | ||
if isinstance(left, pybamm.Negate) and isinstance(right, pybamm.Negate): | ||
# Double negation cancels out | ||
|
@@ -1269,13 +1275,13 @@ def simplified_division(left, right): | |
# Simplify a / (-b) to (-a) / b if (-a) is constant | ||
return (-left) / right.orphans[0] | ||
|
||
return pybamm.simplify_if_constant(pybamm.Division(left, right)) | ||
return pybamm.simplify_if_constant(Division(left, right)) | ||
|
||
|
||
def simplified_matrix_multiplication(left, right): | ||
left, right = preprocess_binary(left, right) | ||
if pybamm.is_matrix_zero(left) or pybamm.is_matrix_zero(right): | ||
return pybamm.zeros_like(pybamm.MatrixMultiplication(left, right)) | ||
return pybamm.zeros_like(MatrixMultiplication(left, right)) | ||
|
||
if isinstance(right, Multiplication) and left.is_constant(): | ||
# Simplify A @ (b * c) to (A * b) @ c if (A * b) is constant | ||
|
@@ -1309,18 +1315,35 @@ def simplified_matrix_multiplication(left, right): | |
new_mul.copy_domains(right) | ||
return new_mul | ||
|
||
# Simplify A @ (b + c) to (A @ b) + (A @ c) if (A @ b) or (A @ c) is constant | ||
# This is a common construction that appears from discretisation of spatial | ||
# operators | ||
# Don't do this if either b or c is a number as this will lead to matmul errors | ||
elif isinstance(right, Addition): | ||
if (right.left.is_constant() or right.right.is_constant()) and not ( | ||
elif isinstance(right, (Addition, Subtraction)): | ||
# Simplify A @ (b +- c) to (A @ b) +- (A @ c) if (A @ b) or (A @ c) is constant | ||
# This is a common construction that appears from discretisation of spatial | ||
# operators | ||
# Or simplify A @ (B @ b +- C @ c) to (A @ B @ b) +- (A @ C @ c) if (A @ B) | ||
# and (A @ C) are constant | ||
# Don't do this if either b or c is a number as this will lead to matmul errors | ||
if ( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need A to be constant here too? I.e. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yeah, good point |
||
(right.left.is_constant() or right.right.is_constant()) | ||
# these lines should work but don't, possibly because of poorly | ||
# conditioned model? | ||
# or ( | ||
# isinstance(right.left, MatrixMultiplication) | ||
# and right.left.left.is_constant() | ||
# and isinstance(right.right, MatrixMultiplication) | ||
# and right.right.left.is_constant() | ||
# ) | ||
) and not ( | ||
right.left.size_for_testing == 1 or right.right.size_for_testing == 1 | ||
): | ||
r_left, r_right = right.orphans | ||
return (left @ r_left) + (left @ r_right) | ||
|
||
return pybamm.simplify_if_constant(pybamm.MatrixMultiplication(left, right)) | ||
r_left.domains = right.domains | ||
r_right.domains = right.domains | ||
if isinstance(right, Addition): | ||
return (left @ r_left) + (left @ r_right) | ||
elif isinstance(right, Subtraction): | ||
return (left @ r_left) - (left @ r_right) | ||
|
||
return pybamm.simplify_if_constant(MatrixMultiplication(left, right)) | ||
|
||
|
||
def minimum(left, right): | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Aren't there more simplifications here? I think this catches things like
(a*b)/(a*c)
but wouldn't catch(b*a)/(c*a)
? Is it much overhead to check for these?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yeah, just adding things as I see them come up in expression trees