
Better coverage for float32 tests #6780

Merged
4 commits from fix-6779 merged into main on Jun 22, 2023

Conversation

@ferrine (Member) commented Jun 15, 2023:

fix #6779

Bugfixes

  • Fix the issue where float32 mode breaks the Dirichlet distribution
  • Enforce that float32 tests do not create float64 variables

📚 Documentation preview 📚: https://pymc--6780.org.readthedocs.build/en/6780/

@ferrine ferrine requested a review from ricardoV94 June 15, 2023 15:18
@ferrine ferrine changed the title from "Fix 6779" to "Fix 6779 Dirichlet distribution float32 mode" on Jun 15, 2023
@ferrine ferrine changed the title from "Fix 6779 Dirichlet distribution float32 mode" to "Fix #6779 Dirichlet distribution float32 mode" on Jun 15, 2023
@codecov (bot) commented Jun 15, 2023:

Codecov Report

Merging #6780 (55d6c31) into main (f91dd1c) will increase coverage by 0.00%.
The diff coverage is 100.00%.


@@           Coverage Diff           @@
##             main    #6780   +/-   ##
=======================================
  Coverage   91.89%   91.89%           
=======================================
  Files          95       95           
  Lines       16181    16185    +4     
=======================================
+ Hits        14870    14874    +4     
  Misses       1311     1311           
Impacted Files Coverage Δ
pymc/logprob/transforms.py 94.28% <100.00%> (+0.04%) ⬆️

@ferrine (Member, Author) commented Jun 15, 2023:

@ricardoV94 some errors seem weird to me, any idea how they could emerge?

import pymc as pm

with pm.Model() as model:
    c = pm.floatX([1, 1, 1])
    pm.Dirichlet("a", c)
model.point_logps()
Review comment (Member):
Should it be enough to request model.logp()? There is no point in compiling and evaluating if you're not checking the results.
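A sketch of what that suggestion amounts to (using pm.Model.logp to build the joint log-probability graph without compiling or evaluating it; not necessarily the exact test that landed):

import pymc as pm

with pm.Model() as model:
    c = pm.floatX([1, 1, 1])
    pm.Dirichlet("a", c)

# Building the logp graph already exercises the Simplex transform (and its dtype
# behaviour); there is no need to compile and evaluate as point_logps() does.
model.logp()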

@ricardoV94 (Member) commented Jun 15, 2023:

@ricardoV94 some errors seem weird to me, any idea how they could emerge?

I think it is some change in the latest PyTensor release, probably the __repr__ changes, given the failure.

@ricardoV94 ricardoV94 changed the title from "Fix #6779 Dirichlet distribution float32 mode" to "Don't upcast to float64 in Simplex transform" on Jun 15, 2023
@ricardoV94 ricardoV94 changed the title from "Don't upcast to float64 in Simplex transform" to "Don't force upcast to float64 in Simplex transform" on Jun 15, 2023
@ricardoV94 (Member) commented:
This is actually a change to the SimplexTransform; shouldn't we test that directly, and not via the Dirichlet? Another transform that I think will upcast to float64 is the ZeroSumNormal, because it uses the shape of the value as well.
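For illustration, a minimal PyTensor sketch (not PyMC code) of how quantities derived from a value's int64 shape can silently upcast a float32 graph, and how casting them back avoids it:

import pytensor
import pytensor.tensor as pt

pytensor.config.floatX = "float32"

x = pt.vector("x")                        # float32 under floatX="float32"
n = x.shape[-1]                           # symbolic shape, dtype int64
y = x / pt.sqrt(n)                        # sqrt of an int64 is float64, so y is upcast to float64
y_safe = x / pt.sqrt(n).astype(x.dtype)   # casting the shape term keeps the graph in float32
print(y.dtype, y_safe.dtype)              # float64 float32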

@ferrine (Member, Author) commented Jun 16, 2023:

This is actually a change to the SimplexTransform; shouldn't we test that directly, and not via the Dirichlet? Another transform that I think will upcast to float64 is the ZeroSumNormal, because it uses the shape of the value as well.

I'll change that in the test then

@ferrine ferrine changed the title Don't force upcast to float64 in Simplex transform Better coverage for float32 tests Jun 21, 2023
@ferrine (Member, Author) commented Jun 21, 2023:

@ricardoV94 I looked into how transforms are checked in float32 mode: they are not. The existing float32 tests do not check the float32 condition properly; they still allow float64 subgraphs. I think the float32 tests should be stricter than they are now.

@ricardoV94 (Member) commented:
@ricardoV94 I looked into how transforms are checked in float32 mode: they are not. The existing float32 tests do not check the float32 condition properly; they still allow float64 subgraphs. I think the float32 tests should be stricter than they are now.

Feel free to make the float32 job more restrictive. Just keep in mind that we shouldn't throw random things into those float32 jobs, only tests that actually matter.
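One way to make the job stricter (a sketch, not the PR's code; it assumes pytensor.graph.basic.ancestors is available, and the helper name assert_no_float64 is made up) is to walk the logp graph and fail on any float64 variable:

from pytensor.graph.basic import ancestors

def assert_no_float64(outputs):
    # Every variable feeding into `outputs` that carries a float64 dtype is a violation.
    offending = [
        var for var in ancestors(outputs) if getattr(var, "dtype", None) == "float64"
    ]
    assert not offending, f"float64 variables found in the graph: {offending}"

# Example use inside a float32 test: assert_no_float64([model.logp()])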

@ferrine (Member, Author) commented Jun 21, 2023:

The real cause of the issue was actually here:

[screenshot omitted: the transform code responsible for the upcast]

The transforms did not respect the original dtype.

@ferrine (Member, Author) commented Jun 21, 2023:

I added a check that backward(forward(x)) keeps the original tensor type. It seems to catch the intended bug.
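A minimal sketch of that round-trip check (it assumes SimplexTransform is importable from pymc.logprob.transforms, as the Codecov report above suggests, and compares only the dtype; the actual test may compare the full TensorType):

import numpy as np
import pytensor
import pytensor.tensor as pt
from pymc.logprob.transforms import SimplexTransform

transform = SimplexTransform()
x = pt.as_tensor_variable(np.array([0.2, 0.3, 0.5], dtype=pytensor.config.floatX))

# backward(forward(x)) should come back with the dtype x started with,
# i.e. float32 when floatX="float32".
round_trip = transform.backward(transform.forward(x))
assert round_trip.dtype == x.dtype, (round_trip.dtype, x.dtype)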

@ferrine ferrine force-pushed the fix-6779 branch 2 times, most recently from 44711fe to a45a0b6 Compare June 21, 2023 13:55
@ricardoV94 (Member) left a review:

Looks good.

The first failing test seems unrelated; we can open an issue and rerun, or see if it can be seeded.

@ricardoV94 (Member) commented Jun 21, 2023:

The job is also failing with the new NumPy deprecation; it should be a simple fix to replace np.product with np.prod (but that should be a distinct commit or, if not, a separate PR).
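The replacement itself is mechanical; np.product is just a deprecated alias of np.prod in recent NumPy releases:

import numpy as np

shape = (3, 4, 5)
n_elem = np.prod(shape)  # preferred; np.product emits a DeprecationWarning and was later removed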

@@ -44,10 +44,10 @@
 # some transforms (stick breaking) require addition of small slack in order to be numerically
 # stable. The minimal addable slack for float32 is higher, thus we need to be less strict
-tol = 1e-7 if pytensor.config.floatX == "float64" else 1e-6
+tol = 1e-7 if pytensor.config.floatX == "float64" else 1e-5
@ferrine (Member, Author) replied:
Just in case: float32 was not checked in the CI, so the previous tolerance was never taken into account.

@ferrine ferrine merged commit 14e673f into main Jun 22, 2023
Development

Successfully merging this pull request may close these issues.

BUG: Dirichlet is not tolerant to floatX=float32