[BugFix] Fix state-dict #528

Merged: 7 commits, Sep 14, 2023


@vmoens (Contributor) commented Sep 14, 2023

Description

Fixes calls to state_dict to make them compatible with nn.Module.

Introduces related tests and a docstring.

cc @matteobettini
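For context: when a parent nn.Module builds its state dict, it calls state_dict(destination=..., prefix=..., keep_vars=...) on each child, so any class that overrides state_dict has to accept those keyword arguments and write into the destination it is given. The sketch below illustrates that general requirement with hypothetical Child and Parent classes; it is not the code changed in this PR.

```python
import torch
from torch import nn


class Child(nn.Module):
    """Hypothetical submodule that overrides state_dict."""

    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(2, 2))

    def state_dict(self, *args, destination=None, prefix="", keep_vars=False):
        # A parent module calls this with destination/prefix/keep_vars as
        # keyword arguments; forwarding them unchanged makes our entries land
        # in the parent's dict under the parent's prefix (here "child.").
        return super().state_dict(
            *args, destination=destination, prefix=prefix, keep_vars=keep_vars
        )


class Parent(nn.Module):
    def __init__(self):
        super().__init__()
        self.child = Child()


parent = Parent()
sd = parent.state_dict()
print(list(sd))  # ['child.weight']

# The nested entry loads back into a fresh Parent without missing keys.
clone = Parent()
clone.load_state_dict(sd)
assert torch.equal(clone.child.weight, parent.child.weight)
```

If Child.state_dict ignored destination and returned a fresh dict instead, the parent would discard that return value and its own state dict would be missing child.weight.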

@facebook-github-bot added the CLA Signed label Sep 14, 2023
@vmoens added the bug label Sep 14, 2023
@vmoens marked this pull request as ready for review September 14, 2023 14:02
@matteobettini (Contributor) left a comment


LGTM

maybe one test where the loaded data is not all zeros just to make sure

@matteobettini (Contributor) commented Sep 14, 2023

we also may want some tests in torchrl to test the non regression of saving the state dict of loss modules
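One way such a non-regression test could look, sketched in plain PyTorch rather than torchrl: ToyLoss is a hypothetical module standing in for a real loss module (it is not a torchrl class), and round-tripping through torch.save / torch.load on an in-memory buffer is an assumption about how the check might be written.

```python
import io

import torch
from torch import nn


class ToyLoss(nn.Module):
    """Hypothetical stand-in for a loss module that owns a value network
    and a target copy of it. Not an actual torchrl class."""

    def __init__(self):
        super().__init__()
        self.value_network = nn.Linear(4, 1)
        self.target_value_network = nn.Linear(4, 1)


def roundtrip(module: nn.Module) -> dict:
    """Serialize a module's state_dict to an in-memory buffer and read it back."""
    buffer = io.BytesIO()
    torch.save(module.state_dict(), buffer)
    buffer.seek(0)
    return torch.load(buffer)


source = ToyLoss()
restored = ToyLoss()

# load_state_dict is strict by default: it raises if keys are missing or
# unexpected, so a change in the saved key layout makes this test fail.
restored.load_state_dict(roundtrip(source))

# Values must survive the save/load cycle unchanged.
restored_sd = restored.state_dict()
for name, tensor in source.state_dict().items():
    assert torch.equal(tensor, restored_sd[name]), name
```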

@vmoens (Contributor, Author) commented Sep 14, 2023

> LGTM
>
> maybe one test where the loaded data is not all zeros just to make sure

We load onto zeroed data no?

@vmoens (Contributor, Author) commented Sep 14, 2023

> we also may want some tests in torchrl to test the non regression of saving the state dict of loss modules

on it already

@matteobettini (Contributor)

> > LGTM
> > maybe one test where the loaded data is not all zeros just to make sure
>
> We load onto zeroed data no?

What do you mean?

I meant something like getting the params from a linear module, saving, and reloading and checking against the original params
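The suggestion can be sketched in plain PyTorch as below. This is an illustrative test written for this page, not the test added in the PR: the source layer holds random (non-zero) parameters, and an explicit check confirms the destination differs before loading, so the final comparison cannot pass merely because both sides happen to be zero.

```python
import torch
from torch import nn

torch.manual_seed(0)  # make the random source parameters reproducible

# Source: a freshly initialized linear layer, so its parameters are random
# rather than all zeros.
source = nn.Linear(3, 2)

# Destination: same architecture, parameters explicitly zeroed.
dest = nn.Linear(3, 2)
with torch.no_grad():
    for p in dest.parameters():
        p.zero_()

# Guard against a vacuous test: if the source were also all zeros, a
# load_state_dict that silently did nothing would still pass the check below.
assert any(
    not torch.equal(p_src, p_dst)
    for p_src, p_dst in zip(source.parameters(), dest.parameters())
)

dest.load_state_dict(source.state_dict())

# Every parameter of the destination should now match the original.
dest_params = dict(dest.named_parameters())
for name, p_src in source.named_parameters():
    assert torch.equal(p_src, dest_params[name]), name
```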

@github-actions bot commented Sep 14, 2023

⚠ Warning: Result of CPU Benchmark Tests

Total Benchmarks: 109. Improved: 2. Worsened: 4.

| Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
|---|---|---|---|---|---|
| test_plain_set_nested | 37.9010μs | 20.1606μs | 49.6016 KOps/s | 49.5104 KOps/s | +0.18% |
| test_plain_set_stack_nested | 0.2418ms | 0.1858ms | 5.3827 KOps/s | 5.3443 KOps/s | +0.72% |
| test_plain_set_nested_inplace | 53.7020μs | 23.6788μs | 42.2318 KOps/s | 42.2700 KOps/s | -0.09% |
| test_plain_set_stack_nested_inplace | 0.3111ms | 0.2228ms | 4.4874 KOps/s | 4.4588 KOps/s | +0.64% |
| test_items | 24.7010μs | 3.5974μs | 277.9764 KOps/s | 282.4074 KOps/s | -1.57% |
| test_items_nested | 2.2374ms | 0.3657ms | 2.7344 KOps/s | 2.7461 KOps/s | -0.42% |
| test_items_nested_locked | 0.7145ms | 0.3638ms | 2.7485 KOps/s | 2.7616 KOps/s | -0.48% |
| test_items_nested_leaf | 0.5186ms | 0.2222ms | 4.5014 KOps/s | 4.5427 KOps/s | -0.91% |
| test_items_stack_nested | 2.1838ms | 2.0241ms | 494.0403 Ops/s | 501.8802 Ops/s | -1.56% |
| test_items_stack_nested_leaf | 2.0531ms | 1.8459ms | 541.7432 Ops/s | 543.2395 Ops/s | -0.28% |
| test_items_stack_nested_locked | 1.1201ms | 0.9967ms | 1.0033 KOps/s | 990.5085 Ops/s | +1.29% |
| test_keys | 21.0000μs | 5.0778μs | 196.9345 KOps/s | 189.5071 KOps/s | +3.92% |
| test_keys_nested | 3.0336ms | 0.1856ms | 5.3869 KOps/s | 5.5038 KOps/s | -2.12% |
| test_keys_nested_locked | 4.4334ms | 0.1841ms | 5.4312 KOps/s | 5.5223 KOps/s | -1.65% |
| test_keys_nested_leaf | 0.3275ms | 0.1760ms | 5.6818 KOps/s | 5.3468 KOps/s | **+6.27%** |
| test_keys_stack_nested | 1.9965ms | 1.8605ms | 537.4762 Ops/s | 544.1983 Ops/s | -1.24% |
| test_keys_stack_nested_leaf | 1.9683ms | 1.8570ms | 538.4979 Ops/s | 543.4500 Ops/s | -0.91% |
| test_keys_stack_nested_locked | 0.9629ms | 0.8344ms | 1.1985 KOps/s | 1.2060 KOps/s | -0.62% |
| test_values | 14.7000μs | 1.5806μs | 632.6717 KOps/s | 628.2515 KOps/s | +0.70% |
| test_values_nested | 0.1347ms | 67.8477μs | 14.7389 KOps/s | 14.5010 KOps/s | +1.64% |
| test_values_nested_locked | 94.1020μs | 67.6789μs | 14.7756 KOps/s | 14.4703 KOps/s | +2.11% |
| test_values_nested_leaf | 0.1450ms | 59.5199μs | 16.8011 KOps/s | 16.7204 KOps/s | +0.48% |
| test_values_stack_nested | 1.7610ms | 1.6355ms | 611.4404 Ops/s | 618.5231 Ops/s | -1.15% |
| test_values_stack_nested_leaf | 1.7505ms | 1.6232ms | 616.0838 Ops/s | 625.7119 Ops/s | -1.54% |
| test_values_stack_nested_locked | 0.9013ms | 0.6580ms | 1.5196 KOps/s | 1.5433 KOps/s | -1.53% |
| test_membership | 18.0000μs | 1.8464μs | 541.5810 KOps/s | 536.9592 KOps/s | +0.86% |
| test_membership_nested | 35.4000μs | 3.6747μs | 272.1303 KOps/s | 266.0965 KOps/s | +2.27% |
| test_membership_nested_leaf | 72.0020μs | 3.6583μs | 273.3519 KOps/s | 267.8386 KOps/s | +2.06% |
| test_membership_stacked_nested | 31.2010μs | 14.4432μs | 69.2366 KOps/s | 69.3179 KOps/s | -0.12% |
| test_membership_stacked_nested_leaf | 70.9010μs | 14.3700μs | 69.5895 KOps/s | 69.0403 KOps/s | +0.80% |
| test_membership_nested_last | 66.4010μs | 7.5185μs | 133.0054 KOps/s | 131.2651 KOps/s | +1.33% |
| test_membership_nested_leaf_last | 34.3000μs | 7.5615μs | 132.2497 KOps/s | 129.4385 KOps/s | +2.17% |
| test_membership_stacked_nested_last | 0.3160ms | 0.2277ms | 4.3922 KOps/s | 4.4023 KOps/s | -0.23% |
| test_membership_stacked_nested_leaf_last | 98.3030μs | 17.0462μs | 58.6641 KOps/s | 59.2988 KOps/s | -1.07% |
| test_nested_getleaf | 65.9020μs | 15.7657μs | 63.4288 KOps/s | 64.1225 KOps/s | -1.08% |
| test_nested_get | 86.1020μs | 14.8332μs | 67.4166 KOps/s | 67.5209 KOps/s | -0.15% |
| test_stacked_getleaf | 1.0313ms | 0.8941ms | 1.1184 KOps/s | 1.1408 KOps/s | -1.96% |
| test_stacked_get | 0.9881ms | 0.8557ms | 1.1687 KOps/s | 1.1806 KOps/s | -1.01% |
| test_nested_getitemleaf | 54.5020μs | 15.5704μs | 64.2244 KOps/s | 63.7315 KOps/s | +0.77% |
| test_nested_getitem | 43.3010μs | 14.7641μs | 67.7317 KOps/s | 67.2432 KOps/s | +0.73% |
| test_stacked_getitemleaf | 1.0492ms | 0.9040ms | 1.1062 KOps/s | 1.1401 KOps/s | -2.97% |
| test_stacked_getitem | 0.9864ms | 0.8563ms | 1.1679 KOps/s | 1.1910 KOps/s | -1.94% |
| test_lock_nested | 74.9286ms | 1.5434ms | 647.9339 Ops/s | 693.8950 Ops/s | **-6.62%** |
| test_lock_stack_nested | 93.8042ms | 20.3892ms | 49.0455 Ops/s | 52.6179 Ops/s | **-6.79%** |
| test_unlock_nested | 1.6823ms | 1.4762ms | 677.4066 Ops/s | 649.9003 Ops/s | +4.23% |
| test_unlock_stack_nested | 95.7312ms | 19.5932ms | 51.0382 Ops/s | 51.4140 Ops/s | -0.73% |
| test_flatten_speed | 1.1318ms | 1.0553ms | 947.5680 Ops/s | 981.2656 Ops/s | -3.43% |
| test_unflatten_speed | 1.9978ms | 1.8663ms | 535.8080 Ops/s | 536.8174 Ops/s | -0.19% |
| test_common_ops | 4.8305ms | 1.1244ms | 889.3953 Ops/s | 910.3334 Ops/s | -2.30% |
| test_creation | 38.9010μs | 6.4162μs | 155.8559 KOps/s | 159.1616 KOps/s | -2.08% |
| test_creation_empty | 98.4020μs | 14.0490μs | 71.1795 KOps/s | 72.1194 KOps/s | -1.30% |
| test_creation_nested_1 | 45.1010μs | 25.4060μs | 39.3607 KOps/s | 40.2938 KOps/s | -2.32% |
| test_creation_nested_2 | 52.3010μs | 27.7381μs | 36.0514 KOps/s | 36.4377 KOps/s | -1.06% |
| test_clone | 0.1835ms | 25.2698μs | 39.5729 KOps/s | 39.4362 KOps/s | +0.35% |
| test_getitem[int] | 49.9010μs | 28.0235μs | 35.6843 KOps/s | 35.4795 KOps/s | +0.58% |
| test_getitem[slice_int] | 0.1033ms | 55.1507μs | 18.1321 KOps/s | 18.2307 KOps/s | -0.54% |
| test_getitem[range] | 0.2174ms | 83.3892μs | 11.9920 KOps/s | 12.1503 KOps/s | -1.30% |
| test_getitem[tuple] | 69.6020μs | 45.7011μs | 21.8813 KOps/s | 21.9788 KOps/s | -0.44% |
| test_getitem[list] | 0.3861ms | 78.8473μs | 12.6827 KOps/s | 12.9613 KOps/s | -2.15% |
| test_setitem_dim[int] | 53.6010μs | 33.1261μs | 30.1877 KOps/s | 30.2169 KOps/s | -0.10% |
| test_setitem_dim[slice_int] | 0.1700ms | 59.0216μs | 16.9430 KOps/s | 17.2185 KOps/s | -1.60% |
| test_setitem_dim[range] | 0.1049ms | 80.8013μs | 12.3760 KOps/s | 12.6575 KOps/s | -2.22% |
| test_setitem_dim[tuple] | 70.7020μs | 49.2037μs | 20.3237 KOps/s | 20.6322 KOps/s | -1.50% |
| test_setitem | 0.2269ms | 33.5309μs | 29.8232 KOps/s | 30.4830 KOps/s | -2.16% |
| test_set | 0.2060ms | 32.1533μs | 31.1010 KOps/s | 31.5505 KOps/s | -1.42% |
| test_set_shared | 0.3930ms | 0.1817ms | 5.5028 KOps/s | 5.6098 KOps/s | -1.91% |
| test_update | 0.2099ms | 36.6450μs | 27.2889 KOps/s | 27.9654 KOps/s | -2.42% |
| test_update_nested | 0.2303ms | 53.7877μs | 18.5916 KOps/s | 19.0587 KOps/s | -2.45% |
| test_set_nested | 0.2160ms | 35.4237μs | 28.2296 KOps/s | 28.7609 KOps/s | -1.85% |
| test_set_nested_new | 0.2268ms | 54.3061μs | 18.4141 KOps/s | 18.5638 KOps/s | -0.81% |
| test_select | 0.2783ms | 98.7559μs | 10.1260 KOps/s | 10.2050 KOps/s | -0.77% |
| test_unbind_speed | 0.7072ms | 0.6530ms | 1.5314 KOps/s | 1.5585 KOps/s | -1.74% |
| test_unbind_speed_stack0 | 85.0311ms | 9.5882ms | 104.2952 Ops/s | 114.5765 Ops/s | **-8.97%** |
| test_unbind_speed_stack1 | 16.7000μs | 1.1708μs | 854.0837 KOps/s | 1.0831 MOps/s | **-21.14%** |
| test_creation[device0] | 0.5812ms | 0.4541ms | 2.2024 KOps/s | 2.2506 KOps/s | -2.14% |
| test_creation_from_tensor | 2.3356ms | 0.5135ms | 1.9476 KOps/s | 2.0064 KOps/s | -2.93% |
| test_add_one[memmap_tensor0] | 1.9026ms | 33.6271μs | 29.7379 KOps/s | 30.7713 KOps/s | -3.36% |
| test_contiguous[memmap_tensor0] | 27.0010μs | 8.6867μs | 115.1188 KOps/s | 115.5022 KOps/s | -0.33% |
| test_stack[memmap_tensor0] | 0.1097ms | 26.4832μs | 37.7597 KOps/s | 37.5821 KOps/s | +0.47% |
| test_memmaptd_index | 0.4132ms | 0.3174ms | 3.1507 KOps/s | 3.1524 KOps/s | -0.06% |
| test_memmaptd_index_astensor | 2.4745ms | 1.3924ms | 718.1972 Ops/s | 727.2473 Ops/s | -1.24% |
| test_memmaptd_index_op | 2.8996ms | 2.7550ms | 362.9706 Ops/s | 378.7316 Ops/s | -4.16% |
| test_reshape_pytree | 99.6020μs | 38.0298μs | 26.2951 KOps/s | 26.5706 KOps/s | -1.04% |
| test_reshape_td | 86.9020μs | 46.3159μs | 21.5909 KOps/s | 22.3257 KOps/s | -3.29% |
| test_view_pytree | 0.1227ms | 35.4849μs | 28.1810 KOps/s | 28.5663 KOps/s | -1.35% |
| test_view_td | 28.6010μs | 9.0082μs | 111.0101 KOps/s | 112.6618 KOps/s | -1.47% |
| test_unbind_pytree | 78.8020μs | 38.5669μs | 25.9290 KOps/s | 25.8983 KOps/s | +0.12% |
| test_unbind_td | 0.2451ms | 96.9284μs | 10.3169 KOps/s | 10.3506 KOps/s | -0.33% |
| test_split_pytree | 88.0030μs | 45.2766μs | 22.0865 KOps/s | 22.5184 KOps/s | -1.92% |
| test_split_td | 0.9138ms | 0.1164ms | 8.5910 KOps/s | 8.6059 KOps/s | -0.17% |
| test_add_pytree | 0.1093ms | 47.9550μs | 20.8529 KOps/s | 20.8092 KOps/s | +0.21% |
| test_add_td | 0.1690ms | 79.8872μs | 12.5177 KOps/s | 13.1372 KOps/s | -4.72% |
| test_distributed | 26.2010μs | 9.0230μs | 110.8285 KOps/s | 113.1816 KOps/s | -2.08% |
| test_tdmodule | 0.1896ms | 29.1561μs | 34.2982 KOps/s | 34.0776 KOps/s | +0.65% |
| test_tdmodule_dispatch | 0.2849ms | 56.6069μs | 17.6657 KOps/s | 17.8036 KOps/s | -0.77% |
| test_tdseq | 0.5313ms | 33.9600μs | 29.4464 KOps/s | 29.3728 KOps/s | +0.25% |
| test_tdseq_dispatch | 0.5427ms | 68.9302μs | 14.5074 KOps/s | 14.7743 KOps/s | -1.81% |
| test_instantiation_functorch | 1.7849ms | 1.6486ms | 606.5606 Ops/s | 614.7188 Ops/s | -1.33% |
| test_instantiation_td | 2.0235ms | 1.3761ms | 726.6948 Ops/s | 678.2709 Ops/s | **+7.14%** |
| test_exec_functorch | 0.3240ms | 0.1905ms | 5.2482 KOps/s | 5.3068 KOps/s | -1.10% |
| test_exec_td | 0.3177ms | 0.1851ms | 5.4011 KOps/s | 5.5926 KOps/s | -3.42% |
| test_vmap_mlp_speed[True-True] | 6.9113ms | 1.2387ms | 807.2943 Ops/s | 832.5643 Ops/s | -3.04% |
| test_vmap_mlp_speed[True-False] | 3.7076ms | 0.6298ms | 1.5878 KOps/s | 1.6319 KOps/s | -2.70% |
| test_vmap_mlp_speed[False-True] | 4.0657ms | 1.0568ms | 946.2450 Ops/s | 974.6047 Ops/s | -2.91% |
| test_vmap_mlp_speed[False-False] | 6.4094ms | 0.4763ms | 2.0997 KOps/s | 2.1796 KOps/s | -3.66% |
| test_vmap_transformer_speed[True-True] | 16.1997ms | 13.8033ms | 72.4463 Ops/s | 73.5982 Ops/s | -1.57% |
| test_vmap_transformer_speed[True-False] | 12.1654ms | 8.6892ms | 115.0853 Ops/s | 117.6561 Ops/s | -2.19% |
| test_vmap_transformer_speed[False-True] | 19.2753ms | 13.4779ms | 74.1955 Ops/s | 75.4009 Ops/s | -1.60% |
| test_vmap_transformer_speed[False-False] | 14.7826ms | 8.5578ms | 116.8518 Ops/s | 120.1168 Ops/s | -2.72% |

@vmoens (Contributor, Author) commented Sep 14, 2023

> I meant something like getting the params from a linear module, saving, and reloading and checking against the original params

Ah but we do that, what is the problem with doing it with zero?

@matteobettini (Contributor)

> > I meant something like getting the params from a linear module, saving, and reloading and checking against the original params
>
> Ah but we do that, what is the problem with doing it with zero?

No problem; it just seemed like a special case, but this is a nit.

@vmoens merged commit 0006b91 into main Sep 14, 2023
@vmoens deleted the fix_state_dict branch September 14, 2023 16:23
Labels: bug, CLA Signed
3 participants