-
Notifications
You must be signed in to change notification settings - Fork 77
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Refactor] Make CompositeDistribution a tensordict-exclusive class #1112
Conversation
|
Name | Max | Mean | Ops | Ops on Repo HEAD
|
Change |
---|---|---|---|---|---|
test_plain_set_nested | 33.2630μs | 17.9849μs | 55.6022 KOps/s | 54.2347 KOps/s | |
test_plain_set_stack_nested | 41.7690μs | 18.0024μs | 55.5483 KOps/s | 53.7194 KOps/s | |
test_plain_set_nested_inplace | 58.3300μs | 19.4107μs | 51.5179 KOps/s | 49.2110 KOps/s | |
test_plain_set_stack_nested_inplace | 48.9220μs | 19.3203μs | 51.7589 KOps/s | 49.3015 KOps/s | |
test_items | 27.5620μs | 4.1305μs | 242.1009 KOps/s | 244.3937 KOps/s | |
test_items_nested | 0.4775ms | 0.3989ms | 2.5071 KOps/s | 2.5323 KOps/s | |
test_items_nested_locked | 0.5696ms | 0.3975ms | 2.5156 KOps/s | 2.5386 KOps/s | |
test_items_nested_leaf | 0.1436ms | 71.7581μs | 13.9357 KOps/s | 14.1993 KOps/s | |
test_items_stack_nested | 0.6880ms | 0.3995ms | 2.5033 KOps/s | 2.5312 KOps/s | |
test_items_stack_nested_leaf | 0.1416ms | 74.4218μs | 13.4369 KOps/s | 13.8415 KOps/s | |
test_items_stack_nested_locked | 0.5269ms | 0.4014ms | 2.4916 KOps/s | 2.5276 KOps/s | |
test_keys | 28.1940μs | 3.4985μs | 285.8397 KOps/s | 278.2794 KOps/s | |
test_keys_nested | 0.2683ms | 0.1384ms | 7.2270 KOps/s | 7.2652 KOps/s | |
test_keys_nested_locked | 1.7874ms | 0.1407ms | 7.1048 KOps/s | 7.0364 KOps/s | |
test_keys_nested_leaf | 0.2133ms | 0.1163ms | 8.5989 KOps/s | 8.5116 KOps/s | |
test_keys_stack_nested | 0.2176ms | 0.1342ms | 7.4518 KOps/s | 7.3529 KOps/s | |
test_keys_stack_nested_leaf | 0.2115ms | 0.1132ms | 8.8335 KOps/s | 8.5314 KOps/s | |
test_keys_stack_nested_locked | 0.2258ms | 0.1388ms | 7.2029 KOps/s | 7.0878 KOps/s | |
test_values | 4.6728μs | 1.0567μs | 946.3059 KOps/s | 945.8371 KOps/s | |
test_values_nested | 0.1073ms | 54.7067μs | 18.2793 KOps/s | 18.0852 KOps/s | |
test_values_nested_locked | 0.1184ms | 54.8245μs | 18.2400 KOps/s | 18.0965 KOps/s | |
test_values_nested_leaf | 0.1131ms | 59.2508μs | 16.8774 KOps/s | 16.7136 KOps/s | |
test_values_stack_nested | 0.1022ms | 55.5238μs | 18.0103 KOps/s | 17.8857 KOps/s | |
test_values_stack_nested_leaf | 0.1186ms | 59.2160μs | 16.8873 KOps/s | 16.6109 KOps/s | |
test_values_stack_nested_locked | 0.1047ms | 56.0198μs | 17.8508 KOps/s | 18.1138 KOps/s | |
test_membership | 1.9366μs | 0.7166μs | 1.3954 MOps/s | 1.3170 MOps/s | |
test_membership_nested | 23.7650μs | 2.9012μs | 344.6881 KOps/s | 347.6111 KOps/s | |
test_membership_nested_leaf | 25.9480μs | 2.9279μs | 341.5436 KOps/s | 346.1334 KOps/s | |
test_membership_stacked_nested | 27.9320μs | 2.9185μs | 342.6475 KOps/s | 349.1912 KOps/s | |
test_membership_stacked_nested_leaf | 28.4140μs | 2.9247μs | 341.9126 KOps/s | 348.4564 KOps/s | |
test_membership_nested_last | 35.7170μs | 4.2956μs | 232.7958 KOps/s | 237.6342 KOps/s | |
test_membership_nested_leaf_last | 34.4050μs | 4.3080μs | 232.1256 KOps/s | 239.2872 KOps/s | |
test_membership_stacked_nested_last | 41.9390μs | 13.2450μs | 75.5004 KOps/s | 242.0787 KOps/s | |
test_membership_stacked_nested_leaf_last | 44.9850μs | 13.0690μs | 76.5169 KOps/s | 238.2100 KOps/s | |
test_nested_getleaf | 0.1432ms | 10.9273μs | 91.5140 KOps/s | 91.6590 KOps/s | |
test_nested_get | 41.3780μs | 10.1566μs | 98.4580 KOps/s | 96.3632 KOps/s | |
test_stacked_getleaf | 46.9180μs | 10.6082μs | 94.2669 KOps/s | 93.3225 KOps/s | |
test_stacked_get | 35.4770μs | 10.0706μs | 99.2994 KOps/s | 96.7024 KOps/s | |
test_nested_getitemleaf | 35.4670μs | 11.0375μs | 90.6005 KOps/s | 89.6672 KOps/s | |
test_nested_getitem | 47.5200μs | 10.4266μs | 95.9084 KOps/s | 94.5165 KOps/s | |
test_stacked_getitemleaf | 35.0970μs | 11.0269μs | 90.6876 KOps/s | 86.5849 KOps/s | |
test_stacked_getitem | 35.4160μs | 10.5535μs | 94.7557 KOps/s | 93.8284 KOps/s | |
test_lock_nested | 2.6680ms | 0.4421ms | 2.2619 KOps/s | 2.2906 KOps/s | |
test_lock_stack_nested | 0.6546ms | 0.4033ms | 2.4797 KOps/s | 2.4368 KOps/s | |
test_unlock_nested | 0.7564ms | 0.3589ms | 2.7864 KOps/s | 2.7513 KOps/s | |
test_unlock_stack_nested | 0.3906ms | 0.3197ms | 3.1275 KOps/s | 3.0301 KOps/s | |
test_flatten_speed | 0.1760ms | 95.4035μs | 10.4818 KOps/s | 10.5116 KOps/s | |
test_unflatten_speed | 0.8873ms | 0.4895ms | 2.0431 KOps/s | 1.9908 KOps/s | |
test_common_ops | 4.3079ms | 0.7809ms | 1.2806 KOps/s | 1.2314 KOps/s | |
test_creation | 18.5550μs | 2.0945μs | 477.4384 KOps/s | 484.6501 KOps/s | |
test_creation_empty | 29.9170μs | 10.7858μs | 92.7146 KOps/s | 84.3006 KOps/s | |
test_creation_nested_1 | 35.7380μs | 13.4846μs | 74.1586 KOps/s | 67.7603 KOps/s | |
test_creation_nested_2 | 40.4460μs | 17.9285μs | 55.7770 KOps/s | 52.4013 KOps/s | |
test_clone | 77.2760μs | 13.5560μs | 73.7680 KOps/s | 76.1272 KOps/s | |
test_getitem[int] | 1.4458ms | 12.6964μs | 78.7622 KOps/s | 80.4285 KOps/s | |
test_getitem[slice_int] | 0.1511ms | 24.8815μs | 40.1905 KOps/s | 40.5477 KOps/s | |
test_getitem[range] | 0.1700ms | 47.7529μs | 20.9412 KOps/s | 20.7066 KOps/s | |
test_getitem[tuple] | 0.1348ms | 20.1779μs | 49.5591 KOps/s | 50.0935 KOps/s | |
test_getitem[list] | 0.1718ms | 42.9698μs | 23.2721 KOps/s | 22.6857 KOps/s | |
test_setitem_dim[int] | 47.5190μs | 25.5636μs | 39.1181 KOps/s | 40.3298 KOps/s | |
test_setitem_dim[slice_int] | 90.2300μs | 51.6434μs | 19.3636 KOps/s | 18.9154 KOps/s | |
test_setitem_dim[range] | 0.1217ms | 72.7126μs | 13.7528 KOps/s | 13.6224 KOps/s | |
test_setitem_dim[tuple] | 82.0240μs | 41.1582μs | 24.2965 KOps/s | 24.4277 KOps/s | |
test_setitem | 75.0210μs | 20.0138μs | 49.9656 KOps/s | 47.5719 KOps/s | |
test_set | 73.1170μs | 19.5903μs | 51.0456 KOps/s | 48.8373 KOps/s | |
test_set_shared | 3.6958ms | 0.1716ms | 5.8273 KOps/s | 5.9031 KOps/s | |
test_update | 0.1275ms | 22.0722μs | 45.3059 KOps/s | 41.5109 KOps/s | |
test_update_nested | 0.1141ms | 31.8241μs | 31.4227 KOps/s | 29.4353 KOps/s | |
test_update__nested | 0.6621ms | 31.5195μs | 31.7264 KOps/s | 31.8965 KOps/s | |
test_set_nested | 69.5510μs | 21.5120μs | 46.4858 KOps/s | 44.7543 KOps/s | |
test_set_nested_new | 78.0160μs | 26.3263μs | 37.9849 KOps/s | 36.5497 KOps/s | |
test_select | 0.1210ms | 42.1883μs | 23.7033 KOps/s | 23.0177 KOps/s | |
test_select_nested | 0.1224ms | 59.2645μs | 16.8735 KOps/s | 16.7705 KOps/s | |
test_exclude_nested | 0.1750ms | 77.8555μs | 12.8443 KOps/s | 12.6870 KOps/s | |
test_empty[True] | 0.4982ms | 0.3785ms | 2.6419 KOps/s | 2.5959 KOps/s | |
test_empty[False] | 7.2288μs | 1.2328μs | 811.1655 KOps/s | 836.8637 KOps/s | |
test_unbind_speed | 0.5311ms | 0.2639ms | 3.7896 KOps/s | 3.8938 KOps/s | |
test_unbind_speed_stack0 | 0.3348ms | 0.2538ms | 3.9406 KOps/s | 3.8720 KOps/s | |
test_unbind_speed_stack1 | 97.2553ms | 0.7396ms | 1.3522 KOps/s | 1.4247 KOps/s | |
test_split | 0.1006s | 1.7240ms | 580.0613 Ops/s | 582.3280 Ops/s | |
test_chunk | 97.6427ms | 1.7296ms | 578.1631 Ops/s | 581.9736 Ops/s | |
test_consolidate_njt[False-None] | 8.6050ms | 8.1478ms | 122.7319 Ops/s | 122.9008 Ops/s | |
test_creation[device0] | 0.2218ms | 90.9914μs | 10.9900 KOps/s | 10.9344 KOps/s | |
test_creation_from_tensor | 4.1764ms | 94.7173μs | 10.5577 KOps/s | 10.4016 KOps/s | |
test_add_one[memmap_tensor0] | 0.1311ms | 5.0545μs | 197.8436 KOps/s | 210.0095 KOps/s | |
test_contiguous[memmap_tensor0] | 23.1940μs | 0.5092μs | 1.9638 MOps/s | 1.9358 MOps/s | |
test_stack[memmap_tensor0] | 29.6450μs | 3.5017μs | 285.5784 KOps/s | 295.1797 KOps/s | |
test_memmaptd_index | 1.0604ms | 0.2545ms | 3.9298 KOps/s | 4.2925 KOps/s | |
test_memmaptd_index_astensor | 0.7244ms | 0.3201ms | 3.1236 KOps/s | 3.2266 KOps/s | |
test_memmaptd_index_op | 0.9456ms | 0.5774ms | 1.7319 KOps/s | 1.7048 KOps/s | |
test_serialize_model | 0.1192s | 0.1144s | 8.7377 Ops/s | 7.6011 Ops/s | |
test_serialize_model_pickle | 0.4415s | 0.3894s | 2.5680 Ops/s | 2.5089 Ops/s | |
test_serialize_weights | 0.2050s | 0.1267s | 7.8929 Ops/s | 8.6005 Ops/s | |
test_serialize_weights_returnearly | 0.1717s | 0.1570s | 6.3704 Ops/s | 6.4481 Ops/s | |
test_serialize_weights_pickle | 0.4606s | 0.3886s | 2.5736 Ops/s | 2.4383 Ops/s | |
test_serialize_weights_filesystem | 0.1498s | 0.1430s | 6.9930 Ops/s | 6.3431 Ops/s | |
test_serialize_model_filesystem | 0.1588s | 0.1529s | 6.5413 Ops/s | 6.7616 Ops/s | |
test_reshape_pytree | 73.2280μs | 26.5193μs | 37.7084 KOps/s | 38.0076 KOps/s | |
test_reshape_td | 86.1120μs | 32.3514μs | 30.9105 KOps/s | 30.4863 KOps/s | |
test_view_pytree | 64.8420μs | 26.3876μs | 37.8966 KOps/s | 38.2040 KOps/s | |
test_view_td | 93.9170μs | 37.3281μs | 26.7895 KOps/s | 26.0776 KOps/s | |
test_unbind_pytree | 67.7470μs | 29.7588μs | 33.6035 KOps/s | 30.4332 KOps/s | |
test_unbind_td | 0.3391ms | 38.8686μs | 25.7277 KOps/s | 26.5385 KOps/s | |
test_split_pytree | 86.3120μs | 29.6328μs | 33.7463 KOps/s | 34.1643 KOps/s | |
test_split_td | 0.2105ms | 45.0402μs | 22.2024 KOps/s | 22.1524 KOps/s | |
test_add_pytree | 98.9160μs | 37.0454μs | 26.9939 KOps/s | 28.0430 KOps/s | |
test_add_td | 0.1372ms | 52.1143μs | 19.1886 KOps/s | 17.6811 KOps/s | |
test_compile_add_one_nested[tensordict-compile] | 0.1383ms | 62.8313μs | 15.9156 KOps/s | 16.1546 KOps/s | |
test_compile_add_one_nested[tensordict-eager] | 0.3569ms | 0.1591ms | 6.2846 KOps/s | 6.0755 KOps/s | |
test_compile_add_one_nested[pytree-compile] | 0.1191ms | 45.6700μs | 21.8962 KOps/s | 22.2425 KOps/s | |
test_compile_add_one_nested[pytree-eager] | 0.2253ms | 0.1182ms | 8.4623 KOps/s | 8.6257 KOps/s | |
test_compile_copy_nested[tensordict-compile] | 90.3680μs | 26.3506μs | 37.9498 KOps/s | 37.2323 KOps/s | |
test_compile_copy_nested[tensordict-eager] | 0.1284ms | 53.3241μs | 18.7533 KOps/s | 18.7608 KOps/s | |
test_compile_copy_nested[pytree-compile] | 0.1539ms | 77.5905μs | 12.8882 KOps/s | 12.8815 KOps/s | |
test_compile_copy_nested[pytree-eager] | 0.1274ms | 67.1972μs | 14.8816 KOps/s | 15.0068 KOps/s | |
test_compile_add_one_flat[tensordict-compile] | 0.2226ms | 0.1051ms | 9.5146 KOps/s | 9.3939 KOps/s | |
test_compile_add_one_flat[tensordict-eager] | 0.3917ms | 0.1991ms | 5.0231 KOps/s | 5.0451 KOps/s | |
test_compile_add_one_flat[tensorclass-compile] | 0.1154ms | 44.4510μs | 22.4967 KOps/s | 22.4246 KOps/s | |
test_compile_add_one_flat[tensorclass-eager] | 0.4866ms | 61.6273μs | 16.2266 KOps/s | 16.0143 KOps/s | |
test_compile_add_one_flat[pytree-compile] | 0.1873ms | 0.1024ms | 9.7631 KOps/s | 9.7222 KOps/s | |
test_compile_add_one_flat[pytree-eager] | 0.2824ms | 0.2003ms | 4.9915 KOps/s | 5.0283 KOps/s | |
test_compile_add_self_flat[tensordict-eager] | 0.4079ms | 0.2093ms | 4.7779 KOps/s | 4.7947 KOps/s | |
test_compile_add_self_flat[tensordict-compile] | 0.2268ms | 0.1049ms | 9.5336 KOps/s | 9.3494 KOps/s | |
test_compile_add_self_flat[tensorclass-eager] | 0.1646ms | 53.6788μs | 18.6293 KOps/s | 18.3235 KOps/s | |
test_compile_add_self_flat[tensorclass-compile] | 0.1123ms | 47.5163μs | 21.0454 KOps/s | 21.8724 KOps/s | |
test_compile_add_self_flat[pytree-eager] | 0.2553ms | 0.1575ms | 6.3491 KOps/s | 6.3557 KOps/s | |
test_compile_add_self_flat[pytree-compile] | 0.1989ms | 0.1048ms | 9.5375 KOps/s | 9.7171 KOps/s | |
test_compile_copy_flat[tensordict-compile] | 77.6860μs | 21.1146μs | 47.3606 KOps/s | 47.0821 KOps/s | |
test_compile_copy_flat[tensordict-eager] | 0.1236ms | 58.6223μs | 17.0583 KOps/s | 16.6006 KOps/s | |
test_compile_copy_flat[pytree-compile] | 0.1585ms | 81.6792μs | 12.2430 KOps/s | 12.4286 KOps/s | |
test_compile_copy_flat[pytree-eager] | 0.1462ms | 68.8247μs | 14.5297 KOps/s | 14.5691 KOps/s | |
test_compile_assign_and_add[tensordict-compile] | 0.2890ms | 0.2082ms | 4.8028 KOps/s | 4.8504 KOps/s | |
test_compile_assign_and_add[tensordict-eager] | 2.3569ms | 1.2857ms | 777.7803 Ops/s | 769.0140 Ops/s | |
test_compile_assign_and_add[pytree-compile] | 0.2992ms | 0.2048ms | 4.8817 KOps/s | 4.9285 KOps/s | |
test_compile_assign_and_add[pytree-eager] | 1.3443ms | 0.7787ms | 1.2842 KOps/s | 1.2655 KOps/s | |
test_compile_assign_and_add_stack[compile] | 0.5584ms | 0.4610ms | 2.1693 KOps/s | 2.2237 KOps/s | |
test_compile_assign_and_add_stack[eager] | 2.8630ms | 2.5768ms | 388.0730 Ops/s | 377.9510 Ops/s | |
test_compile_indexing[tensor-tensordict-compile] | 99.1560μs | 35.7724μs | 27.9545 KOps/s | 28.1310 KOps/s | |
test_compile_indexing[tensor-tensordict-eager] | 0.5936ms | 31.4142μs | 31.8327 KOps/s | 29.8791 KOps/s | |
test_compile_indexing[tensor-tensorclass-compile] | 82.6150μs | 29.3236μs | 34.1022 KOps/s | 34.5503 KOps/s | |
test_compile_indexing[tensor-tensorclass-eager] | 82.0540μs | 23.0641μs | 43.3574 KOps/s | 43.0647 KOps/s | |
test_compile_indexing[tensor-pytree-compile] | 96.2010μs | 29.7441μs | 33.6201 KOps/s | 33.3625 KOps/s | |
test_compile_indexing[tensor-pytree-eager] | 69.4400μs | 23.0337μs | 43.4147 KOps/s | 43.8591 KOps/s | |
test_compile_indexing[slice-tensordict-compile] | 0.1303ms | 51.6890μs | 19.3465 KOps/s | 19.5033 KOps/s | |
test_compile_indexing[slice-tensordict-eager] | 0.5301ms | 19.5706μs | 51.0971 KOps/s | 49.1017 KOps/s | |
test_compile_indexing[slice-tensorclass-compile] | 0.1153ms | 44.5426μs | 22.4504 KOps/s | 23.1412 KOps/s | |
test_compile_indexing[slice-tensorclass-eager] | 53.1900μs | 18.6144μs | 53.7218 KOps/s | 53.6036 KOps/s | |
test_compile_indexing[slice-pytree-compile] | 0.1014ms | 45.5499μs | 21.9540 KOps/s | 22.0459 KOps/s | |
test_compile_indexing[slice-pytree-eager] | 56.0160μs | 18.5731μs | 53.8414 KOps/s | 53.5934 KOps/s | |
test_compile_indexing[int-tensordict-compile] | 0.1185ms | 51.5310μs | 19.4058 KOps/s | 19.2251 KOps/s | |
test_compile_indexing[int-tensordict-eager] | 0.9901ms | 19.4104μs | 51.5188 KOps/s | 49.4009 KOps/s | |
test_compile_indexing[int-tensorclass-compile] | 95.5390μs | 45.1816μs | 22.1329 KOps/s | 22.6877 KOps/s | |
test_compile_indexing[int-tensorclass-eager] | 75.8430μs | 18.7214μs | 53.4148 KOps/s | 53.4740 KOps/s | |
test_compile_indexing[int-pytree-compile] | 99.3770μs | 45.3624μs | 22.0447 KOps/s | 22.6795 KOps/s | |
test_compile_indexing[int-pytree-eager] | 61.2450μs | 18.7184μs | 53.4234 KOps/s | 53.7596 KOps/s | |
test_mod_add[eager] | 99.0960μs | 33.9476μs | 29.4572 KOps/s | 28.8444 KOps/s | |
test_mod_add[compile] | 0.1182ms | 46.8090μs | 21.3634 KOps/s | 21.3281 KOps/s | |
test_mod_add[compile-overhead] | 0.1058ms | 47.9714μs | 20.8458 KOps/s | 21.0914 KOps/s | |
test_mod_wrap[eager] | 0.3845ms | 0.2245ms | 4.4548 KOps/s | 4.4231 KOps/s | |
test_mod_wrap[compile] | 0.4487ms | 0.2060ms | 4.8539 KOps/s | 4.8740 KOps/s | |
test_mod_wrap[compile-overhead] | 0.5559ms | 0.2117ms | 4.7236 KOps/s | 4.9439 KOps/s | |
test_mod_wrap_and_backward[eager] | 13.0287ms | 10.6244ms | 94.1232 Ops/s | 91.9848 Ops/s | |
test_mod_wrap_and_backward[compile] | 11.7214ms | 10.4745ms | 95.4698 Ops/s | 87.0173 Ops/s | |
test_mod_wrap_and_backward[compile-overhead] | 12.0177ms | 10.4550ms | 95.6483 Ops/s | 75.5744 Ops/s | |
test_seq_add[eager] | 0.2311ms | 0.1091ms | 9.1685 KOps/s | 8.6007 KOps/s | |
test_seq_add[compile] | 0.1144ms | 61.2347μs | 16.3306 KOps/s | 16.2961 KOps/s | |
test_seq_add[compile-overhead] | 0.1284ms | 59.2254μs | 16.8846 KOps/s | 16.1865 KOps/s | |
test_seq_wrap[eager] | 0.5836ms | 0.4294ms | 2.3287 KOps/s | 2.2176 KOps/s | |
test_seq_wrap[compile] | 0.3332ms | 0.2245ms | 4.4549 KOps/s | 4.2242 KOps/s | |
test_seq_wrap[compile-overhead] | 0.4730ms | 0.2218ms | 4.5081 KOps/s | 4.3491 KOps/s | |
test_func_call_runtime[False-eager] | 0.9358ms | 0.5584ms | 1.7907 KOps/s | 1.8051 KOps/s | |
test_func_call_runtime[False-compile] | 0.8850ms | 0.4211ms | 2.3746 KOps/s | 2.3583 KOps/s | |
test_func_call_runtime[False-compile-overhead] | 0.5333ms | 0.4218ms | 2.3708 KOps/s | 2.3371 KOps/s | |
test_func_call_runtime[True-eager] | 0.9827ms | 0.7662ms | 1.3052 KOps/s | 1.3247 KOps/s | |
test_func_call_runtime[True-compile] | 0.8454ms | 0.4607ms | 2.1705 KOps/s | 2.1584 KOps/s | |
test_func_call_runtime[True-compile-overhead] | 0.8868ms | 0.4636ms | 2.1569 KOps/s | 2.1445 KOps/s | |
test_func_call_cm_runtime[False-eager] | 0.6905ms | 0.5587ms | 1.7899 KOps/s | 1.8231 KOps/s | |
test_func_call_cm_runtime[False-compile] | 0.5706ms | 0.4195ms | 2.3838 KOps/s | 2.3483 KOps/s | |
test_func_call_cm_runtime[False-compile-overhead] | 0.8121ms | 0.4245ms | 2.3556 KOps/s | 2.3554 KOps/s | |
test_func_call_cm_runtime[True-eager] | 1.0853ms | 0.8971ms | 1.1147 KOps/s | 1.1241 KOps/s | |
test_func_call_cm_runtime[True-compile] | 0.5873ms | 0.4832ms | 2.0694 KOps/s | 2.0359 KOps/s | |
test_func_call_cm_runtime[True-compile-overhead] | 0.6830ms | 0.4838ms | 2.0671 KOps/s | 2.0468 KOps/s | |
test_vmap_func_call_cm_runtime[eager] | 2.6881ms | 1.9035ms | 525.3592 Ops/s | 536.4337 Ops/s | |
test_vmap_func_call_cm_runtime[compile] | 0.8727ms | 0.5139ms | 1.9459 KOps/s | 1.9275 KOps/s | |
test_vmap_func_call_cm_runtime[compile-overhead] | 0.8995ms | 0.5162ms | 1.9371 KOps/s | 1.9240 KOps/s | |
test_distributed | 0.2346ms | 0.1230ms | 8.1328 KOps/s | 7.7074 KOps/s | |
test_tdmodule | 79.4000μs | 26.3021μs | 38.0198 KOps/s | 36.2446 KOps/s | |
test_tdmodule_dispatch | 84.4390μs | 48.1526μs | 20.7673 KOps/s | 19.7861 KOps/s | |
test_tdseq | 44.9350μs | 25.4815μs | 39.2441 KOps/s | 37.6630 KOps/s | |
test_tdseq_dispatch | 94.7380μs | 51.0710μs | 19.5806 KOps/s | 19.3236 KOps/s | |
test_instantiation_functorch | 3.6096ms | 1.5540ms | 643.5206 Ops/s | 644.9473 Ops/s | |
test_exec_functorch | 0.3403ms | 0.1794ms | 5.5744 KOps/s | 5.4470 KOps/s | |
test_exec_functional_call | 0.3034ms | 0.1764ms | 5.6677 KOps/s | 5.6458 KOps/s | |
test_exec_td_decorator | 0.4600ms | 0.2304ms | 4.3403 KOps/s | 4.2387 KOps/s | |
test_vmap_mlp_speed_decorator[True-True] | 0.8123ms | 0.6429ms | 1.5554 KOps/s | 1.5090 KOps/s | |
test_vmap_mlp_speed_decorator[True-False] | 0.8991ms | 0.6395ms | 1.5637 KOps/s | 1.5451 KOps/s | |
test_vmap_mlp_speed_decorator[False-True] | 0.6947ms | 0.5227ms | 1.9130 KOps/s | 1.8659 KOps/s | |
test_vmap_mlp_speed_decorator[False-False] | 0.8552ms | 0.5216ms | 1.9171 KOps/s | 1.9363 KOps/s | |
test_to_module_speed[True] | 1.5902ms | 1.2732ms | 785.4463 Ops/s | 765.5498 Ops/s | |
test_to_module_speed[False] | 1.7780ms | 1.2428ms | 804.6076 Ops/s | 792.8590 Ops/s | |
test_tc_init | 90.5800μs | 45.3000μs | 22.0751 KOps/s | 21.8956 KOps/s | |
test_tc_init_nested | 0.1802ms | 89.7667μs | 11.1400 KOps/s | 10.6963 KOps/s | |
test_tc_first_layer_tensor | 44.1130μs | 1.5612μs | 640.5296 KOps/s | 676.9098 KOps/s | |
test_tc_first_layer_nontensor | 27.0510μs | 4.8111μs | 207.8514 KOps/s | 211.0246 KOps/s | |
test_tc_second_layer_tensor | 50.8240μs | 2.8143μs | 355.3222 KOps/s | 360.3567 KOps/s | |
test_tc_second_layer_nontensor | 32.6710μs | 6.1400μs | 162.8661 KOps/s | 166.7244 KOps/s | |
test_unbind | 0.2159s | 12.2421ms | 81.6851 Ops/s | 83.5344 Ops/s | |
test_full_like | 7.9797ms | 7.1593ms | 139.6784 Ops/s | 89.7497 Ops/s | |
test_zeros_like | 3.6185ms | 2.8689ms | 348.5651 Ops/s | 142.7225 Ops/s | |
test_ones_like | 3.6443ms | 3.2295ms | 309.6461 Ops/s | 126.0973 Ops/s | |
test_clone | 5.6892ms | 5.0453ms | 198.2053 Ops/s | 106.3170 Ops/s | |
test_squeeze | 64.0910μs | 11.6427μs | 85.8905 KOps/s | 84.8729 KOps/s | |
test_unsqueeze | 0.3355ms | 88.5764μs | 11.2897 KOps/s | 11.2187 KOps/s | |
test_split | 0.3390ms | 0.1877ms | 5.3290 KOps/s | 5.1156 KOps/s | |
test_permute | 0.3031ms | 0.2170ms | 4.6075 KOps/s | 4.5408 KOps/s | |
test_stack | 32.4591ms | 25.9396ms | 38.5511 Ops/s | 39.9710 Ops/s | |
test_cat | 30.4513ms | 25.8518ms | 38.6821 Ops/s | 41.0037 Ops/s |
ghstack-source-id: 56c1dd2ad856a18613ec1a4c0ca70aedd28a52e3 Pull Request resolved: #1112
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm, thanks!
|
||
Keyword Args: | ||
aggregate_probabilities (bool, optional): if provided, overrides the default ``aggregate_probabilities`` | ||
from the class. | ||
include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict. | ||
Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default). | |
Defaults to ``self.include_sum`` which is set through the class constructor (``True`` by default). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok correcting these typos in follow-up PR
category=DeprecationWarning, | ||
) | ||
if include_sum: | ||
slp = 0.0 | ||
d = {} | ||
for name, dist in self.dists.items(): | ||
d[_add_suffix(name, "_log_prob")] = lp = dist.log_prob(sample.get(name)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Question: if I am to use include_sum=False
to go through a custom aggregation downstream, should I also be in charge of setting the key for the log-probs -- instead of the default _add_suffix(name, "_log_prob")
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if you want different names and not the sum, I guess you could do inplace=False
, include_sum=False
and then rename?
Or we could allow you to pass a function that maps the name to a new log-prob name?
Open to suggestions
@@ -320,9 +427,10 @@ def entropy_composite(self, samples_mc=1, include_sum=True) -> TensorDictBase: | |||
x = dist.rsample((samples_mc,)) | |||
e = -dist.log_prob(x).mean(0) | |||
d[_add_suffix(name, "_entropy")] = e |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ditto
This method is called by the :meth:`~.log_prob` method when ``self.aggregate_probabilities`` is ``False``. | ||
Keyword Args: | ||
include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict. | ||
Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default). | |
Defaults to ``self.include_sum`` which is set through the class constructor (``True`` by default). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ditto
Stack from ghstack (oldest at bottom):
In this PR, I propose a new vision for
CompositeDistributions
where in the future:This implies:
aggregate_probabilities
(cc @albertbou92). The idea is that we actually wantaggregate_probabilities
to be False because we want this class to be an exclusively tensordict-to-tensordict class andaggregate_probabilities=True
by default would mean that we get tensors from log-prob/entropy if not asked otherwise.inplace
andinclude_sum
now control the respective behaviours. Since we want these behaviours to change in the future, we ask users to pass the value explicitly to make sure everything will hold when we make these changes.cc @louisfaury