[BugFix] Fix state-dict #528
LGTM
Maybe one test where the loaded data is not all zeros, just to make sure.
We may also want some tests in torchrl to check for regressions when saving the state dict of loss modules.
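A minimal sketch of the kind of non-regression test suggested here, written against plain PyTorch only. It serializes a module's state dict to an in-memory buffer with `torch.save`, reads it back with `torch.load`, checks that the saved values are not all zeros, and loads them into a freshly initialized copy. The `nn.Sequential` model is a hypothetical stand-in for a torchrl loss module, and `roundtrip_state_dict` / `make_model` are illustrative helpers, not part of tensordict or torchrl.

```python
import io

import torch
from torch import nn


def roundtrip_state_dict(module: nn.Module) -> dict:
    """Hypothetical helper: serialize module.state_dict() to an in-memory
    buffer with torch.save and read it back with torch.load."""
    buffer = io.BytesIO()
    torch.save(module.state_dict(), buffer)
    buffer.seek(0)
    return torch.load(buffer)


def make_model() -> nn.Module:
    # Stand-in for a loss module: nested children produce prefixed keys such as "0.weight".
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 1))


torch.manual_seed(0)
original = make_model()
restored = roundtrip_state_dict(original)

# The saved values should be genuinely non-zero, not the trivial all-zeros case.
assert any(t.abs().sum() > 0 for t in restored.values())

fresh = make_model()  # built after `original`, so it gets a different random init
fresh.load_state_dict(restored)

# Every entry of the reloaded module must match the original exactly.
for key, value in original.state_dict().items():
    assert torch.equal(fresh.state_dict()[key], value), key
```

Going through an actual serialize/deserialize step, rather than passing the dict around in memory, is what would catch a regression in how the keys or tensors are written out.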
We load onto zeroed data, no?
On it already.
What do you mean? I meant something like getting the params from a linear module, saving them, reloading them, and checking against the original params.
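The test described here might look roughly like the following sketch, assuming plain `torch.nn.Linear` modules rather than any tensordict container. It loads onto zeroed parameters, as in the existing test mentioned above, and first checks that the saved parameters are not all zeros: with an all-zeros source, a load that silently did nothing would still pass.

```python
import copy

import torch
from torch import nn

torch.manual_seed(0)
source = nn.Linear(3, 2)

# "Save": snapshot the original parameters, cloned so later in-place edits cannot alias them.
saved = {k: v.clone() for k, v in source.state_dict().items()}

# Guard against the special case from the review: all-zero saved values would
# make a load onto zeroed parameters pass trivially.
assert all(t.abs().sum() > 0 for t in saved.values())

# Destination with the same architecture, zeroed in place.
target = copy.deepcopy(source)
with torch.no_grad():
    for p in target.parameters():
        p.zero_()

# "Reload" and compare against the original parameters.
target.load_state_dict(saved)
for k, v in target.state_dict().items():
    assert torch.equal(v, source.state_dict()[k]), k
```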

Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_plain_set_nested | 37.9010μs | 20.1606μs | 49.6016 KOps/s | 49.5104 KOps/s | |
test_plain_set_stack_nested | 0.2418ms | 0.1858ms | 5.3827 KOps/s | 5.3443 KOps/s | |
test_plain_set_nested_inplace | 53.7020μs | 23.6788μs | 42.2318 KOps/s | 42.2700 KOps/s | |
test_plain_set_stack_nested_inplace | 0.3111ms | 0.2228ms | 4.4874 KOps/s | 4.4588 KOps/s | |
test_items | 24.7010μs | 3.5974μs | 277.9764 KOps/s | 282.4074 KOps/s | |
test_items_nested | 2.2374ms | 0.3657ms | 2.7344 KOps/s | 2.7461 KOps/s | |
test_items_nested_locked | 0.7145ms | 0.3638ms | 2.7485 KOps/s | 2.7616 KOps/s | |
test_items_nested_leaf | 0.5186ms | 0.2222ms | 4.5014 KOps/s | 4.5427 KOps/s | |
test_items_stack_nested | 2.1838ms | 2.0241ms | 494.0403 Ops/s | 501.8802 Ops/s | |
test_items_stack_nested_leaf | 2.0531ms | 1.8459ms | 541.7432 Ops/s | 543.2395 Ops/s | |
test_items_stack_nested_locked | 1.1201ms | 0.9967ms | 1.0033 KOps/s | 990.5085 Ops/s | |
test_keys | 21.0000μs | 5.0778μs | 196.9345 KOps/s | 189.5071 KOps/s | |
test_keys_nested | 3.0336ms | 0.1856ms | 5.3869 KOps/s | 5.5038 KOps/s | |
test_keys_nested_locked | 4.4334ms | 0.1841ms | 5.4312 KOps/s | 5.5223 KOps/s | |
test_keys_nested_leaf | 0.3275ms | 0.1760ms | 5.6818 KOps/s | 5.3468 KOps/s | |
test_keys_stack_nested | 1.9965ms | 1.8605ms | 537.4762 Ops/s | 544.1983 Ops/s | |
test_keys_stack_nested_leaf | 1.9683ms | 1.8570ms | 538.4979 Ops/s | 543.4500 Ops/s | |
test_keys_stack_nested_locked | 0.9629ms | 0.8344ms | 1.1985 KOps/s | 1.2060 KOps/s | |
test_values | 14.7000μs | 1.5806μs | 632.6717 KOps/s | 628.2515 KOps/s | |
test_values_nested | 0.1347ms | 67.8477μs | 14.7389 KOps/s | 14.5010 KOps/s | |
test_values_nested_locked | 94.1020μs | 67.6789μs | 14.7756 KOps/s | 14.4703 KOps/s | |
test_values_nested_leaf | 0.1450ms | 59.5199μs | 16.8011 KOps/s | 16.7204 KOps/s | |
test_values_stack_nested | 1.7610ms | 1.6355ms | 611.4404 Ops/s | 618.5231 Ops/s | |
test_values_stack_nested_leaf | 1.7505ms | 1.6232ms | 616.0838 Ops/s | 625.7119 Ops/s | |
test_values_stack_nested_locked | 0.9013ms | 0.6580ms | 1.5196 KOps/s | 1.5433 KOps/s | |
test_membership | 18.0000μs | 1.8464μs | 541.5810 KOps/s | 536.9592 KOps/s | |
test_membership_nested | 35.4000μs | 3.6747μs | 272.1303 KOps/s | 266.0965 KOps/s | |
test_membership_nested_leaf | 72.0020μs | 3.6583μs | 273.3519 KOps/s | 267.8386 KOps/s | |
test_membership_stacked_nested | 31.2010μs | 14.4432μs | 69.2366 KOps/s | 69.3179 KOps/s | |
test_membership_stacked_nested_leaf | 70.9010μs | 14.3700μs | 69.5895 KOps/s | 69.0403 KOps/s | |
test_membership_nested_last | 66.4010μs | 7.5185μs | 133.0054 KOps/s | 131.2651 KOps/s | |
test_membership_nested_leaf_last | 34.3000μs | 7.5615μs | 132.2497 KOps/s | 129.4385 KOps/s | |
test_membership_stacked_nested_last | 0.3160ms | 0.2277ms | 4.3922 KOps/s | 4.4023 KOps/s | |
test_membership_stacked_nested_leaf_last | 98.3030μs | 17.0462μs | 58.6641 KOps/s | 59.2988 KOps/s | |
test_nested_getleaf | 65.9020μs | 15.7657μs | 63.4288 KOps/s | 64.1225 KOps/s | |
test_nested_get | 86.1020μs | 14.8332μs | 67.4166 KOps/s | 67.5209 KOps/s | |
test_stacked_getleaf | 1.0313ms | 0.8941ms | 1.1184 KOps/s | 1.1408 KOps/s | |
test_stacked_get | 0.9881ms | 0.8557ms | 1.1687 KOps/s | 1.1806 KOps/s | |
test_nested_getitemleaf | 54.5020μs | 15.5704μs | 64.2244 KOps/s | 63.7315 KOps/s | |
test_nested_getitem | 43.3010μs | 14.7641μs | 67.7317 KOps/s | 67.2432 KOps/s | |
test_stacked_getitemleaf | 1.0492ms | 0.9040ms | 1.1062 KOps/s | 1.1401 KOps/s | |
test_stacked_getitem | 0.9864ms | 0.8563ms | 1.1679 KOps/s | 1.1910 KOps/s | |
test_lock_nested | 74.9286ms | 1.5434ms | 647.9339 Ops/s | 693.8950 Ops/s | |
test_lock_stack_nested | 93.8042ms | 20.3892ms | 49.0455 Ops/s | 52.6179 Ops/s | |
test_unlock_nested | 1.6823ms | 1.4762ms | 677.4066 Ops/s | 649.9003 Ops/s | |
test_unlock_stack_nested | 95.7312ms | 19.5932ms | 51.0382 Ops/s | 51.4140 Ops/s | |
test_flatten_speed | 1.1318ms | 1.0553ms | 947.5680 Ops/s | 981.2656 Ops/s | |
test_unflatten_speed | 1.9978ms | 1.8663ms | 535.8080 Ops/s | 536.8174 Ops/s | |
test_common_ops | 4.8305ms | 1.1244ms | 889.3953 Ops/s | 910.3334 Ops/s | |
test_creation | 38.9010μs | 6.4162μs | 155.8559 KOps/s | 159.1616 KOps/s | |
test_creation_empty | 98.4020μs | 14.0490μs | 71.1795 KOps/s | 72.1194 KOps/s | |
test_creation_nested_1 | 45.1010μs | 25.4060μs | 39.3607 KOps/s | 40.2938 KOps/s | |
test_creation_nested_2 | 52.3010μs | 27.7381μs | 36.0514 KOps/s | 36.4377 KOps/s | |
test_clone | 0.1835ms | 25.2698μs | 39.5729 KOps/s | 39.4362 KOps/s | |
test_getitem[int] | 49.9010μs | 28.0235μs | 35.6843 KOps/s | 35.4795 KOps/s | |
test_getitem[slice_int] | 0.1033ms | 55.1507μs | 18.1321 KOps/s | 18.2307 KOps/s | |
test_getitem[range] | 0.2174ms | 83.3892μs | 11.9920 KOps/s | 12.1503 KOps/s | |
test_getitem[tuple] | 69.6020μs | 45.7011μs | 21.8813 KOps/s | 21.9788 KOps/s | |
test_getitem[list] | 0.3861ms | 78.8473μs | 12.6827 KOps/s | 12.9613 KOps/s | |
test_setitem_dim[int] | 53.6010μs | 33.1261μs | 30.1877 KOps/s | 30.2169 KOps/s | |
test_setitem_dim[slice_int] | 0.1700ms | 59.0216μs | 16.9430 KOps/s | 17.2185 KOps/s | |
test_setitem_dim[range] | 0.1049ms | 80.8013μs | 12.3760 KOps/s | 12.6575 KOps/s | |
test_setitem_dim[tuple] | 70.7020μs | 49.2037μs | 20.3237 KOps/s | 20.6322 KOps/s | |
test_setitem | 0.2269ms | 33.5309μs | 29.8232 KOps/s | 30.4830 KOps/s | |
test_set | 0.2060ms | 32.1533μs | 31.1010 KOps/s | 31.5505 KOps/s | |
test_set_shared | 0.3930ms | 0.1817ms | 5.5028 KOps/s | 5.6098 KOps/s | |
test_update | 0.2099ms | 36.6450μs | 27.2889 KOps/s | 27.9654 KOps/s | |
test_update_nested | 0.2303ms | 53.7877μs | 18.5916 KOps/s | 19.0587 KOps/s | |
test_set_nested | 0.2160ms | 35.4237μs | 28.2296 KOps/s | 28.7609 KOps/s | |
test_set_nested_new | 0.2268ms | 54.3061μs | 18.4141 KOps/s | 18.5638 KOps/s | |
test_select | 0.2783ms | 98.7559μs | 10.1260 KOps/s | 10.2050 KOps/s | |
test_unbind_speed | 0.7072ms | 0.6530ms | 1.5314 KOps/s | 1.5585 KOps/s | |
test_unbind_speed_stack0 | 85.0311ms | 9.5882ms | 104.2952 Ops/s | 114.5765 Ops/s | |
test_unbind_speed_stack1 | 16.7000μs | 1.1708μs | 854.0837 KOps/s | 1.0831 MOps/s | |
test_creation[device0] | 0.5812ms | 0.4541ms | 2.2024 KOps/s | 2.2506 KOps/s | |
test_creation_from_tensor | 2.3356ms | 0.5135ms | 1.9476 KOps/s | 2.0064 KOps/s | |
test_add_one[memmap_tensor0] | 1.9026ms | 33.6271μs | 29.7379 KOps/s | 30.7713 KOps/s | |
test_contiguous[memmap_tensor0] | 27.0010μs | 8.6867μs | 115.1188 KOps/s | 115.5022 KOps/s | |
test_stack[memmap_tensor0] | 0.1097ms | 26.4832μs | 37.7597 KOps/s | 37.5821 KOps/s | |
test_memmaptd_index | 0.4132ms | 0.3174ms | 3.1507 KOps/s | 3.1524 KOps/s | |
test_memmaptd_index_astensor | 2.4745ms | 1.3924ms | 718.1972 Ops/s | 727.2473 Ops/s | |
test_memmaptd_index_op | 2.8996ms | 2.7550ms | 362.9706 Ops/s | 378.7316 Ops/s | |
test_reshape_pytree | 99.6020μs | 38.0298μs | 26.2951 KOps/s | 26.5706 KOps/s | |
test_reshape_td | 86.9020μs | 46.3159μs | 21.5909 KOps/s | 22.3257 KOps/s | |
test_view_pytree | 0.1227ms | 35.4849μs | 28.1810 KOps/s | 28.5663 KOps/s | |
test_view_td | 28.6010μs | 9.0082μs | 111.0101 KOps/s | 112.6618 KOps/s | |
test_unbind_pytree | 78.8020μs | 38.5669μs | 25.9290 KOps/s | 25.8983 KOps/s | |
test_unbind_td | 0.2451ms | 96.9284μs | 10.3169 KOps/s | 10.3506 KOps/s | |
test_split_pytree | 88.0030μs | 45.2766μs | 22.0865 KOps/s | 22.5184 KOps/s | |
test_split_td | 0.9138ms | 0.1164ms | 8.5910 KOps/s | 8.6059 KOps/s | |
test_add_pytree | 0.1093ms | 47.9550μs | 20.8529 KOps/s | 20.8092 KOps/s | |
test_add_td | 0.1690ms | 79.8872μs | 12.5177 KOps/s | 13.1372 KOps/s | |
test_distributed | 26.2010μs | 9.0230μs | 110.8285 KOps/s | 113.1816 KOps/s | |
test_tdmodule | 0.1896ms | 29.1561μs | 34.2982 KOps/s | 34.0776 KOps/s | |
test_tdmodule_dispatch | 0.2849ms | 56.6069μs | 17.6657 KOps/s | 17.8036 KOps/s | |
test_tdseq | 0.5313ms | 33.9600μs | 29.4464 KOps/s | 29.3728 KOps/s | |
test_tdseq_dispatch | 0.5427ms | 68.9302μs | 14.5074 KOps/s | 14.7743 KOps/s | |
test_instantiation_functorch | 1.7849ms | 1.6486ms | 606.5606 Ops/s | 614.7188 Ops/s | |
test_instantiation_td | 2.0235ms | 1.3761ms | 726.6948 Ops/s | 678.2709 Ops/s | |
test_exec_functorch | 0.3240ms | 0.1905ms | 5.2482 KOps/s | 5.3068 KOps/s | |
test_exec_td | 0.3177ms | 0.1851ms | 5.4011 KOps/s | 5.5926 KOps/s | |
test_vmap_mlp_speed[True-True] | 6.9113ms | 1.2387ms | 807.2943 Ops/s | 832.5643 Ops/s | |
test_vmap_mlp_speed[True-False] | 3.7076ms | 0.6298ms | 1.5878 KOps/s | 1.6319 KOps/s | |
test_vmap_mlp_speed[False-True] | 4.0657ms | 1.0568ms | 946.2450 Ops/s | 974.6047 Ops/s | |
test_vmap_mlp_speed[False-False] | 6.4094ms | 0.4763ms | 2.0997 KOps/s | 2.1796 KOps/s | |
test_vmap_transformer_speed[True-True] | 16.1997ms | 13.8033ms | 72.4463 Ops/s | 73.5982 Ops/s | |
test_vmap_transformer_speed[True-False] | 12.1654ms | 8.6892ms | 115.0853 Ops/s | 117.6561 Ops/s | |
test_vmap_transformer_speed[False-True] | 19.2753ms | 13.4779ms | 74.1955 Ops/s | 75.4009 Ops/s | |
test_vmap_transformer_speed[False-False] | 14.7826ms | 8.5578ms | 116.8518 Ops/s | 120.1168 Ops/s | |

Ah, but we do that. What is the problem with doing it with zeros?
No problem, it just seemed like a special case, but this is a nit.
Description
Fixes calls to state_dict so that they are compatible with nn.Module.
Introduces related tests and a docstring.
cc @matteobettini
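For context on what compatibility with nn.Module involves: `nn.Module.state_dict` recurses into child modules by calling `child.state_dict(destination=..., prefix=..., keep_vars=...)` with keyword arguments, so any submodule that overrides `state_dict` has to accept those arguments and write prefixed keys into the shared destination. The sketch below illustrates that contract with two hypothetical modules (`ExtraState` and `Broken`); it is not the tensordict implementation from this PR.

```python
import torch
from torch import nn


class ExtraState(nn.Module):
    """Hypothetical module that keeps tensors in a plain dict (not registered
    as parameters or buffers) and exposes them through state_dict."""

    def __init__(self):
        super().__init__()
        self.store = {"scale": torch.ones(3)}

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        # Let nn.Module collect registered params/buffers, forwarding the
        # shared destination and prefix so nested keys stay consistent.
        destination = super().state_dict(
            destination=destination, prefix=prefix, keep_vars=keep_vars
        )
        for name, tensor in self.store.items():
            destination[prefix + name] = tensor if keep_vars else tensor.detach()
        return destination


class Broken(nn.Module):
    """Hypothetical override that ignores the nn.Module signature."""

    def state_dict(self):
        return {}


# The parent passes destination/prefix/keep_vars as keywords to each child.
parent = nn.Sequential(nn.Linear(2, 3), ExtraState())
sd = parent.state_dict()
print(sorted(sd))  # ['0.bias', '0.weight', '1.scale']

try:
    nn.Sequential(Broken()).state_dict()
    broken_error = None
except TypeError as err:
    # Raised because Broken.state_dict() has no `destination` parameter.
    broken_error = err
print("TypeError:", broken_error)
```

Forwarding `prefix` is what turns the child's local key `scale` into `1.scale` in the parent's dict; dropping it would produce colliding, unprefixed keys once the module is nested.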