Force ICNN to adopt default initialization of its own layers #551

Merged
merged 1 commit into main from fixdefaultinit on Jun 25, 2024

Conversation

Algue-Rythme
Collaborator

The ICNN used to rely on initialization with normal matrices. Now it falls back to the default behavior of its own layers, i.e. `lecun_normal`, which scales the standard deviation of the weights with 1/sqrt(fan_in). This std is much smaller for wide networks.
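For reference, a minimal sketch (not the ICNN code itself) contrasting a plain normal initializer with `lecun_normal` from `jax.nn.initializers`; the layer shapes below are illustrative:

```python
import jax
import jax.numpy as jnp
from jax.nn import initializers

# Illustrative shapes: a "wide" layer with a large fan_in.
fan_in, fan_out = 1024, 128
key = jax.random.PRNGKey(0)

# Plain normal init: std stays fixed regardless of layer width.
w_normal = initializers.normal(stddev=1.0)(key, (fan_in, fan_out))

# lecun_normal: std ~ 1 / sqrt(fan_in), so it shrinks for wide layers.
w_lecun = initializers.lecun_normal()(key, (fan_in, fan_out))

print(jnp.std(w_normal))  # ≈ 1.0
print(jnp.std(w_lecun))   # ≈ 1 / sqrt(1024) ≈ 0.031
```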


codecov bot commented Jun 21, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 91.39%. Comparing base (787d4a9) to head (27c582a).
Report is 39 commits behind head on main.

Additional details and impacted files


@@            Coverage Diff             @@
##             main     #551      +/-   ##
==========================================
+ Coverage   91.38%   91.39%   +0.01%     
==========================================
  Files          69       69              
  Lines        7242     7242              
  Branches     1019     1018       -1     
==========================================
+ Hits         6618     6619       +1     
  Misses        472      472              
+ Partials      152      151       -1     
Files with missing lines | Coverage Δ
src/ott/neural/networks/icnn.py | 94.54% <100.00%> (ø)

... and 1 file with indirect coverage changes

@michalk8 michalk8 self-requested a review June 25, 2024 09:26
Collaborator

@michalk8 michalk8 left a comment


Thanks @Algue-Rythme, LGTM!

@michalk8 michalk8 merged commit d3b6c40 into main Jun 25, 2024
13 checks passed
@michalk8 michalk8 deleted the fixdefaultinit branch June 25, 2024 09:26
michalk8 added a commit that referenced this pull request Oct 16, 2024
* Start batched vmap

* Initial `batched_vmap` impl

* Nicer formatting

* Fix getting shape

* Remove private API usage

* Fix new args

* Add a TODO

* Canonicalize axes

* Add `batched_vmap` to docs

* Removed batched transport functions

* Remove `_norm_{x,y}` from `CostFn`

* Implement `apply_lse_kernel`

* Implement `apply_kernel`

* Implement `apply_cost`

* Remove old functions

* Make function private

* Refactor `apply_cost` to have consistent shapes

* Use `_apply_cost_to_vec` in `PointCloud`

* Remove TODO

* Formatting

* Simplify `_apply_sqeucl_cost`

* Fix `RecursionError`

* Remove docstring of a private method

* Fix `apply_lse_kernel`

* Squeeze only 1 axis of the cost

* Add TODO

* Rename function, make a property

* Remove unused helper function

* Compute mean summary online

* Compute mean online

* Compute max cost matrix

* Update error message

* Remove TODO

* Flatten out axes

* Fix missing cross terms in the costs

* Fix geom tests

* Fix dtype

* Start implementing transport functions

* Implement online transport functions

* Fix solver tests

* Fix Bures test

* Don't use `pairwise` in tests

* Update notebook that uses `norm`

* Fix bug in `UnbalancedBures`

* Rename `pairwise -> __call__`

* Remove old shape code

* Always instantiate the cost for online

* Remove old TODO

* Extract `_apply_cost_to_vec_fast`

* Update max cost in LRCGeom

* Fix test, use more `multi_dot`

* Remove `batch_size` from `LRCGeometry`

* Add better warning error

* Reorder properties

* Add docs to `batched_vmap`

* Start adding tests

* Reorder functions in test

* Fix axes, add a test

* Update test fn

* Move out assert

* Don't canonicalize `out_axes`

* Check max traces

* Test memory of batched vmap

* Install `typing_extensions`

* Remove `.` from description

* Add more `out_axes` tests

* Add `in_axes` test

* Fix negative axes

* Increase memory limit in the test

* Add in_axes pytree test

* Remove old warnings filters

* Update fixtures

* Update SqEucl cost.

* Update docstrings

* Remove unused imports from the docs

* Revert test pre-commits

* Fix ICNN init notebook

Was broken by #551

* Improve error message
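The squashed commit above centers on a `batched_vmap` utility, a memory-bounded variant of `jax.vmap` that processes inputs in chunks. Below is a minimal, hypothetical sketch of the chunking idea only; the actual utility also handles `in_axes`/`out_axes` and pytrees, which this illustration omits:

```python
import jax
import jax.numpy as jnp

def chunked_vmap(fn, batch_size):
    """Hypothetical helper: map `fn` over the leading axis in fixed-size chunks.

    Peak memory is bounded by one chunk instead of the full batch; this sketch
    assumes the leading dimension is divisible by `batch_size`.
    """
    def wrapped(xs):
        n = xs.shape[0]
        chunks = xs.reshape(n // batch_size, batch_size, *xs.shape[1:])
        # vmap within a chunk, lax.map (sequential) across chunks.
        out = jax.lax.map(jax.vmap(fn), chunks)
        return out.reshape(n, *out.shape[2:])
    return wrapped

# Example: squared norms of 10_000 points, materializing 100 rows at a time.
x = jnp.ones((10_000, 3))
norms = chunked_vmap(lambda v: jnp.sum(v ** 2), batch_size=100)(x)
print(norms.shape)  # (10000,)
```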