[Refactor] Make CompositeDistribution a tensordict-exclusive class #1112

Merged: vmoens merged 2 commits into gh/vmoens/35/base from gh/vmoens/35/head on Nov 27, 2024

@vmoens vmoens commented Nov 26, 2024

Stack from ghstack (oldest at bottom):

In this PR, I propose a new vision for CompositeDistributions where in the future:

  • log_prob will return a tensordict of log-probs unless asked otherwise
  • the aggregate log_prob will not be computed. Users can compute it manually, for instance through
    lp = dist.log_prob(sample, inplace=False)
    lp = lp.sum(reduce=False)
  • the tensordict returned will not be the one used as input but another one.
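
To make the proposed contract concrete, here is a minimal pure-Python sketch that uses plain dicts in place of tensordicts. `ToyComposite` and `normal_log_prob` are hypothetical stand-ins written for illustration, not part of the library; the only point is that `log_prob` returns a new mapping of per-leaf values and leaves aggregation to the caller.

```python
import math


def normal_log_prob(x, loc, scale):
    """Log-density of a univariate Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


class ToyComposite:
    """Hypothetical stand-in for a composite distribution over named leaves."""

    def __init__(self, params):
        # params maps a leaf name to its (loc, scale) pair.
        self.params = params

    def log_prob(self, sample):
        # Build a fresh mapping: the input sample is left untouched and
        # no aggregate value is computed here.
        return {
            name + "_log_prob": normal_log_prob(sample[name], loc, scale)
            for name, (loc, scale) in self.params.items()
        }


dist = ToyComposite({"action": (0.0, 1.0), "duration": (1.0, 2.0)})
sample = {"action": 0.0, "duration": 1.0}

lp = dist.log_prob(sample)   # per-leaf log-probs, in a new mapping
total = sum(lp.values())     # aggregation is the caller's choice
```

With real tensordicts the aggregation step would be a reduction over the returned tensordict rather than Python's built-in `sum`.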

This implies:

  • reverting the announced deprecation (FutureWarning) for aggregate_probabilities (cc @albertbou92). We actually want aggregate_probabilities to be False, because this class should be exclusively tensordict-to-tensordict; with aggregate_probabilities=True as the default, log_prob and entropy would return tensors unless asked otherwise.
  • new parameters, inplace and include_sum, now control the respective behaviours. Since we want these behaviours to change in the future, we ask users to pass the values explicitly so that their code keeps working when we make these changes.
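
One way to picture the "pass the value explicitly" request is a sentinel default that warns when a flag is left out. The sketch below is an assumption-laden illustration on plain dicts: `write_log_probs`, `_resolve`, `_UNSET`, the `sum_key` name and the current defaults of `True` are all hypothetical choices made for the example, not the PR's actual implementation.

```python
import warnings

_UNSET = object()  # sentinel: tells "not passed" apart from an explicit value


def _resolve(flag, name, current_default):
    """Return the flag, warning when the caller relied on the default."""
    if flag is _UNSET:
        warnings.warn(
            f"{name} was not passed explicitly; falling back to "
            f"{current_default}, which may change in a future release.",
            category=DeprecationWarning,
            stacklevel=3,
        )
        return current_default
    return flag


def write_log_probs(log_probs, sample, *, inplace=_UNSET, include_sum=_UNSET,
                    sum_key="sample_log_prob"):
    """Hypothetical helper: store per-leaf log-probs under suffixed keys."""
    inplace = _resolve(inplace, "inplace", True)
    include_sum = _resolve(include_sum, "include_sum", True)
    out = sample if inplace else {}  # inplace=True writes into the input
    for name, lp in log_probs.items():
        out[name + "_log_prob"] = lp
    if include_sum:
        out[sum_key] = sum(log_probs.values())
    return out


sample = {"action": 0.3}
# Both flags are explicit, so no warning is emitted and the input is untouched.
out = write_log_probs({"action": -1.2}, sample, inplace=False, include_sum=False)
```

Callers who pass both flags keep the same behaviour whenever the defaults flip; callers who omit them get a warning first.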

cc @louisfaury

@facebook-github-bot added the "CLA Signed" label (managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed) Nov 26, 2024
github-actions bot commented Nov 26, 2024

$\color{#D29922}\textsf{\Large⚠\kern{0.2cm}\normalsize Warning}$ Result of CPU Benchmark Tests

Total Benchmarks: 217. Improved: $\large\color{#35bf28}23$. Worsened: $\large\color{#d91a1a}8$.

Name Max Mean Ops Ops on Repo HEAD Change
test_plain_set_nested 33.2630μs 17.9849μs 55.6022 KOps/s 54.2347 KOps/s $\color{#35bf28}+2.52\%$
test_plain_set_stack_nested 41.7690μs 18.0024μs 55.5483 KOps/s 53.7194 KOps/s $\color{#35bf28}+3.40\%$
test_plain_set_nested_inplace 58.3300μs 19.4107μs 51.5179 KOps/s 49.2110 KOps/s $\color{#35bf28}+4.69\%$
test_plain_set_stack_nested_inplace 48.9220μs 19.3203μs 51.7589 KOps/s 49.3015 KOps/s $\color{#35bf28}+4.98\%$
test_items 27.5620μs 4.1305μs 242.1009 KOps/s 244.3937 KOps/s $\color{#d91a1a}-0.94\%$
test_items_nested 0.4775ms 0.3989ms 2.5071 KOps/s 2.5323 KOps/s $\color{#d91a1a}-1.00\%$
test_items_nested_locked 0.5696ms 0.3975ms 2.5156 KOps/s 2.5386 KOps/s $\color{#d91a1a}-0.91\%$
test_items_nested_leaf 0.1436ms 71.7581μs 13.9357 KOps/s 14.1993 KOps/s $\color{#d91a1a}-1.86\%$
test_items_stack_nested 0.6880ms 0.3995ms 2.5033 KOps/s 2.5312 KOps/s $\color{#d91a1a}-1.10\%$
test_items_stack_nested_leaf 0.1416ms 74.4218μs 13.4369 KOps/s 13.8415 KOps/s $\color{#d91a1a}-2.92\%$
test_items_stack_nested_locked 0.5269ms 0.4014ms 2.4916 KOps/s 2.5276 KOps/s $\color{#d91a1a}-1.43\%$
test_keys 28.1940μs 3.4985μs 285.8397 KOps/s 278.2794 KOps/s $\color{#35bf28}+2.72\%$
test_keys_nested 0.2683ms 0.1384ms 7.2270 KOps/s 7.2652 KOps/s $\color{#d91a1a}-0.53\%$
test_keys_nested_locked 1.7874ms 0.1407ms 7.1048 KOps/s 7.0364 KOps/s $\color{#35bf28}+0.97\%$
test_keys_nested_leaf 0.2133ms 0.1163ms 8.5989 KOps/s 8.5116 KOps/s $\color{#35bf28}+1.03\%$
test_keys_stack_nested 0.2176ms 0.1342ms 7.4518 KOps/s 7.3529 KOps/s $\color{#35bf28}+1.35\%$
test_keys_stack_nested_leaf 0.2115ms 0.1132ms 8.8335 KOps/s 8.5314 KOps/s $\color{#35bf28}+3.54\%$
test_keys_stack_nested_locked 0.2258ms 0.1388ms 7.2029 KOps/s 7.0878 KOps/s $\color{#35bf28}+1.62\%$
test_values 4.6728μs 1.0567μs 946.3059 KOps/s 945.8371 KOps/s $\color{#35bf28}+0.05\%$
test_values_nested 0.1073ms 54.7067μs 18.2793 KOps/s 18.0852 KOps/s $\color{#35bf28}+1.07\%$
test_values_nested_locked 0.1184ms 54.8245μs 18.2400 KOps/s 18.0965 KOps/s $\color{#35bf28}+0.79\%$
test_values_nested_leaf 0.1131ms 59.2508μs 16.8774 KOps/s 16.7136 KOps/s $\color{#35bf28}+0.98\%$
test_values_stack_nested 0.1022ms 55.5238μs 18.0103 KOps/s 17.8857 KOps/s $\color{#35bf28}+0.70\%$
test_values_stack_nested_leaf 0.1186ms 59.2160μs 16.8873 KOps/s 16.6109 KOps/s $\color{#35bf28}+1.66\%$
test_values_stack_nested_locked 0.1047ms 56.0198μs 17.8508 KOps/s 18.1138 KOps/s $\color{#d91a1a}-1.45\%$
test_membership 1.9366μs 0.7166μs 1.3954 MOps/s 1.3170 MOps/s $\textbf{\color{#35bf28}+5.95\%}$
test_membership_nested 23.7650μs 2.9012μs 344.6881 KOps/s 347.6111 KOps/s $\color{#d91a1a}-0.84\%$
test_membership_nested_leaf 25.9480μs 2.9279μs 341.5436 KOps/s 346.1334 KOps/s $\color{#d91a1a}-1.33\%$
test_membership_stacked_nested 27.9320μs 2.9185μs 342.6475 KOps/s 349.1912 KOps/s $\color{#d91a1a}-1.87\%$
test_membership_stacked_nested_leaf 28.4140μs 2.9247μs 341.9126 KOps/s 348.4564 KOps/s $\color{#d91a1a}-1.88\%$
test_membership_nested_last 35.7170μs 4.2956μs 232.7958 KOps/s 237.6342 KOps/s $\color{#d91a1a}-2.04\%$
test_membership_nested_leaf_last 34.4050μs 4.3080μs 232.1256 KOps/s 239.2872 KOps/s $\color{#d91a1a}-2.99\%$
test_membership_stacked_nested_last 41.9390μs 13.2450μs 75.5004 KOps/s 242.0787 KOps/s $\textbf{\color{#d91a1a}-68.81\%}$
test_membership_stacked_nested_leaf_last 44.9850μs 13.0690μs 76.5169 KOps/s 238.2100 KOps/s $\textbf{\color{#d91a1a}-67.88\%}$
test_nested_getleaf 0.1432ms 10.9273μs 91.5140 KOps/s 91.6590 KOps/s $\color{#d91a1a}-0.16\%$
test_nested_get 41.3780μs 10.1566μs 98.4580 KOps/s 96.3632 KOps/s $\color{#35bf28}+2.17\%$
test_stacked_getleaf 46.9180μs 10.6082μs 94.2669 KOps/s 93.3225 KOps/s $\color{#35bf28}+1.01\%$
test_stacked_get 35.4770μs 10.0706μs 99.2994 KOps/s 96.7024 KOps/s $\color{#35bf28}+2.69\%$
test_nested_getitemleaf 35.4670μs 11.0375μs 90.6005 KOps/s 89.6672 KOps/s $\color{#35bf28}+1.04\%$
test_nested_getitem 47.5200μs 10.4266μs 95.9084 KOps/s 94.5165 KOps/s $\color{#35bf28}+1.47\%$
test_stacked_getitemleaf 35.0970μs 11.0269μs 90.6876 KOps/s 86.5849 KOps/s $\color{#35bf28}+4.74\%$
test_stacked_getitem 35.4160μs 10.5535μs 94.7557 KOps/s 93.8284 KOps/s $\color{#35bf28}+0.99\%$
test_lock_nested 2.6680ms 0.4421ms 2.2619 KOps/s 2.2906 KOps/s $\color{#d91a1a}-1.25\%$
test_lock_stack_nested 0.6546ms 0.4033ms 2.4797 KOps/s 2.4368 KOps/s $\color{#35bf28}+1.76\%$
test_unlock_nested 0.7564ms 0.3589ms 2.7864 KOps/s 2.7513 KOps/s $\color{#35bf28}+1.28\%$
test_unlock_stack_nested 0.3906ms 0.3197ms 3.1275 KOps/s 3.0301 KOps/s $\color{#35bf28}+3.22\%$
test_flatten_speed 0.1760ms 95.4035μs 10.4818 KOps/s 10.5116 KOps/s $\color{#d91a1a}-0.28\%$
test_unflatten_speed 0.8873ms 0.4895ms 2.0431 KOps/s 1.9908 KOps/s $\color{#35bf28}+2.62\%$
test_common_ops 4.3079ms 0.7809ms 1.2806 KOps/s 1.2314 KOps/s $\color{#35bf28}+3.99\%$
test_creation 18.5550μs 2.0945μs 477.4384 KOps/s 484.6501 KOps/s $\color{#d91a1a}-1.49\%$
test_creation_empty 29.9170μs 10.7858μs 92.7146 KOps/s 84.3006 KOps/s $\textbf{\color{#35bf28}+9.98\%}$
test_creation_nested_1 35.7380μs 13.4846μs 74.1586 KOps/s 67.7603 KOps/s $\textbf{\color{#35bf28}+9.44\%}$
test_creation_nested_2 40.4460μs 17.9285μs 55.7770 KOps/s 52.4013 KOps/s $\textbf{\color{#35bf28}+6.44\%}$
test_clone 77.2760μs 13.5560μs 73.7680 KOps/s 76.1272 KOps/s $\color{#d91a1a}-3.10\%$
test_getitem[int] 1.4458ms 12.6964μs 78.7622 KOps/s 80.4285 KOps/s $\color{#d91a1a}-2.07\%$
test_getitem[slice_int] 0.1511ms 24.8815μs 40.1905 KOps/s 40.5477 KOps/s $\color{#d91a1a}-0.88\%$
test_getitem[range] 0.1700ms 47.7529μs 20.9412 KOps/s 20.7066 KOps/s $\color{#35bf28}+1.13\%$
test_getitem[tuple] 0.1348ms 20.1779μs 49.5591 KOps/s 50.0935 KOps/s $\color{#d91a1a}-1.07\%$
test_getitem[list] 0.1718ms 42.9698μs 23.2721 KOps/s 22.6857 KOps/s $\color{#35bf28}+2.59\%$
test_setitem_dim[int] 47.5190μs 25.5636μs 39.1181 KOps/s 40.3298 KOps/s $\color{#d91a1a}-3.00\%$
test_setitem_dim[slice_int] 90.2300μs 51.6434μs 19.3636 KOps/s 18.9154 KOps/s $\color{#35bf28}+2.37\%$
test_setitem_dim[range] 0.1217ms 72.7126μs 13.7528 KOps/s 13.6224 KOps/s $\color{#35bf28}+0.96\%$
test_setitem_dim[tuple] 82.0240μs 41.1582μs 24.2965 KOps/s 24.4277 KOps/s $\color{#d91a1a}-0.54\%$
test_setitem 75.0210μs 20.0138μs 49.9656 KOps/s 47.5719 KOps/s $\textbf{\color{#35bf28}+5.03\%}$
test_set 73.1170μs 19.5903μs 51.0456 KOps/s 48.8373 KOps/s $\color{#35bf28}+4.52\%$
test_set_shared 3.6958ms 0.1716ms 5.8273 KOps/s 5.9031 KOps/s $\color{#d91a1a}-1.28\%$
test_update 0.1275ms 22.0722μs 45.3059 KOps/s 41.5109 KOps/s $\textbf{\color{#35bf28}+9.14\%}$
test_update_nested 0.1141ms 31.8241μs 31.4227 KOps/s 29.4353 KOps/s $\textbf{\color{#35bf28}+6.75\%}$
test_update__nested 0.6621ms 31.5195μs 31.7264 KOps/s 31.8965 KOps/s $\color{#d91a1a}-0.53\%$
test_set_nested 69.5510μs 21.5120μs 46.4858 KOps/s 44.7543 KOps/s $\color{#35bf28}+3.87\%$
test_set_nested_new 78.0160μs 26.3263μs 37.9849 KOps/s 36.5497 KOps/s $\color{#35bf28}+3.93\%$
test_select 0.1210ms 42.1883μs 23.7033 KOps/s 23.0177 KOps/s $\color{#35bf28}+2.98\%$
test_select_nested 0.1224ms 59.2645μs 16.8735 KOps/s 16.7705 KOps/s $\color{#35bf28}+0.61\%$
test_exclude_nested 0.1750ms 77.8555μs 12.8443 KOps/s 12.6870 KOps/s $\color{#35bf28}+1.24\%$
test_empty[True] 0.4982ms 0.3785ms 2.6419 KOps/s 2.5959 KOps/s $\color{#35bf28}+1.77\%$
test_empty[False] 7.2288μs 1.2328μs 811.1655 KOps/s 836.8637 KOps/s $\color{#d91a1a}-3.07\%$
test_unbind_speed 0.5311ms 0.2639ms 3.7896 KOps/s 3.8938 KOps/s $\color{#d91a1a}-2.68\%$
test_unbind_speed_stack0 0.3348ms 0.2538ms 3.9406 KOps/s 3.8720 KOps/s $\color{#35bf28}+1.77\%$
test_unbind_speed_stack1 97.2553ms 0.7396ms 1.3522 KOps/s 1.4247 KOps/s $\textbf{\color{#d91a1a}-5.09\%}$
test_split 0.1006s 1.7240ms 580.0613 Ops/s 582.3280 Ops/s $\color{#d91a1a}-0.39\%$
test_chunk 97.6427ms 1.7296ms 578.1631 Ops/s 581.9736 Ops/s $\color{#d91a1a}-0.65\%$
test_consolidate_njt[False-None] 8.6050ms 8.1478ms 122.7319 Ops/s 122.9008 Ops/s $\color{#d91a1a}-0.14\%$
test_creation[device0] 0.2218ms 90.9914μs 10.9900 KOps/s 10.9344 KOps/s $\color{#35bf28}+0.51\%$
test_creation_from_tensor 4.1764ms 94.7173μs 10.5577 KOps/s 10.4016 KOps/s $\color{#35bf28}+1.50\%$
test_add_one[memmap_tensor0] 0.1311ms 5.0545μs 197.8436 KOps/s 210.0095 KOps/s $\textbf{\color{#d91a1a}-5.79\%}$
test_contiguous[memmap_tensor0] 23.1940μs 0.5092μs 1.9638 MOps/s 1.9358 MOps/s $\color{#35bf28}+1.45\%$
test_stack[memmap_tensor0] 29.6450μs 3.5017μs 285.5784 KOps/s 295.1797 KOps/s $\color{#d91a1a}-3.25\%$
test_memmaptd_index 1.0604ms 0.2545ms 3.9298 KOps/s 4.2925 KOps/s $\textbf{\color{#d91a1a}-8.45\%}$
test_memmaptd_index_astensor 0.7244ms 0.3201ms 3.1236 KOps/s 3.2266 KOps/s $\color{#d91a1a}-3.19\%$
test_memmaptd_index_op 0.9456ms 0.5774ms 1.7319 KOps/s 1.7048 KOps/s $\color{#35bf28}+1.59\%$
test_serialize_model 0.1192s 0.1144s 8.7377 Ops/s 7.6011 Ops/s $\textbf{\color{#35bf28}+14.95\%}$
test_serialize_model_pickle 0.4415s 0.3894s 2.5680 Ops/s 2.5089 Ops/s $\color{#35bf28}+2.35\%$
test_serialize_weights 0.2050s 0.1267s 7.8929 Ops/s 8.6005 Ops/s $\textbf{\color{#d91a1a}-8.23\%}$
test_serialize_weights_returnearly 0.1717s 0.1570s 6.3704 Ops/s 6.4481 Ops/s $\color{#d91a1a}-1.20\%$
test_serialize_weights_pickle 0.4606s 0.3886s 2.5736 Ops/s 2.4383 Ops/s $\textbf{\color{#35bf28}+5.55\%}$
test_serialize_weights_filesystem 0.1498s 0.1430s 6.9930 Ops/s 6.3431 Ops/s $\textbf{\color{#35bf28}+10.25\%}$
test_serialize_model_filesystem 0.1588s 0.1529s 6.5413 Ops/s 6.7616 Ops/s $\color{#d91a1a}-3.26\%$
test_reshape_pytree 73.2280μs 26.5193μs 37.7084 KOps/s 38.0076 KOps/s $\color{#d91a1a}-0.79\%$
test_reshape_td 86.1120μs 32.3514μs 30.9105 KOps/s 30.4863 KOps/s $\color{#35bf28}+1.39\%$
test_view_pytree 64.8420μs 26.3876μs 37.8966 KOps/s 38.2040 KOps/s $\color{#d91a1a}-0.80\%$
test_view_td 93.9170μs 37.3281μs 26.7895 KOps/s 26.0776 KOps/s $\color{#35bf28}+2.73\%$
test_unbind_pytree 67.7470μs 29.7588μs 33.6035 KOps/s 30.4332 KOps/s $\textbf{\color{#35bf28}+10.42\%}$
test_unbind_td 0.3391ms 38.8686μs 25.7277 KOps/s 26.5385 KOps/s $\color{#d91a1a}-3.06\%$
test_split_pytree 86.3120μs 29.6328μs 33.7463 KOps/s 34.1643 KOps/s $\color{#d91a1a}-1.22\%$
test_split_td 0.2105ms 45.0402μs 22.2024 KOps/s 22.1524 KOps/s $\color{#35bf28}+0.23\%$
test_add_pytree 98.9160μs 37.0454μs 26.9939 KOps/s 28.0430 KOps/s $\color{#d91a1a}-3.74\%$
test_add_td 0.1372ms 52.1143μs 19.1886 KOps/s 17.6811 KOps/s $\textbf{\color{#35bf28}+8.53\%}$
test_compile_add_one_nested[tensordict-compile] 0.1383ms 62.8313μs 15.9156 KOps/s 16.1546 KOps/s $\color{#d91a1a}-1.48\%$
test_compile_add_one_nested[tensordict-eager] 0.3569ms 0.1591ms 6.2846 KOps/s 6.0755 KOps/s $\color{#35bf28}+3.44\%$
test_compile_add_one_nested[pytree-compile] 0.1191ms 45.6700μs 21.8962 KOps/s 22.2425 KOps/s $\color{#d91a1a}-1.56\%$
test_compile_add_one_nested[pytree-eager] 0.2253ms 0.1182ms 8.4623 KOps/s 8.6257 KOps/s $\color{#d91a1a}-1.89\%$
test_compile_copy_nested[tensordict-compile] 90.3680μs 26.3506μs 37.9498 KOps/s 37.2323 KOps/s $\color{#35bf28}+1.93\%$
test_compile_copy_nested[tensordict-eager] 0.1284ms 53.3241μs 18.7533 KOps/s 18.7608 KOps/s $\color{#d91a1a}-0.04\%$
test_compile_copy_nested[pytree-compile] 0.1539ms 77.5905μs 12.8882 KOps/s 12.8815 KOps/s $\color{#35bf28}+0.05\%$
test_compile_copy_nested[pytree-eager] 0.1274ms 67.1972μs 14.8816 KOps/s 15.0068 KOps/s $\color{#d91a1a}-0.83\%$
test_compile_add_one_flat[tensordict-compile] 0.2226ms 0.1051ms 9.5146 KOps/s 9.3939 KOps/s $\color{#35bf28}+1.29\%$
test_compile_add_one_flat[tensordict-eager] 0.3917ms 0.1991ms 5.0231 KOps/s 5.0451 KOps/s $\color{#d91a1a}-0.44\%$
test_compile_add_one_flat[tensorclass-compile] 0.1154ms 44.4510μs 22.4967 KOps/s 22.4246 KOps/s $\color{#35bf28}+0.32\%$
test_compile_add_one_flat[tensorclass-eager] 0.4866ms 61.6273μs 16.2266 KOps/s 16.0143 KOps/s $\color{#35bf28}+1.33\%$
test_compile_add_one_flat[pytree-compile] 0.1873ms 0.1024ms 9.7631 KOps/s 9.7222 KOps/s $\color{#35bf28}+0.42\%$
test_compile_add_one_flat[pytree-eager] 0.2824ms 0.2003ms 4.9915 KOps/s 5.0283 KOps/s $\color{#d91a1a}-0.73\%$
test_compile_add_self_flat[tensordict-eager] 0.4079ms 0.2093ms 4.7779 KOps/s 4.7947 KOps/s $\color{#d91a1a}-0.35\%$
test_compile_add_self_flat[tensordict-compile] 0.2268ms 0.1049ms 9.5336 KOps/s 9.3494 KOps/s $\color{#35bf28}+1.97\%$
test_compile_add_self_flat[tensorclass-eager] 0.1646ms 53.6788μs 18.6293 KOps/s 18.3235 KOps/s $\color{#35bf28}+1.67\%$
test_compile_add_self_flat[tensorclass-compile] 0.1123ms 47.5163μs 21.0454 KOps/s 21.8724 KOps/s $\color{#d91a1a}-3.78\%$
test_compile_add_self_flat[pytree-eager] 0.2553ms 0.1575ms 6.3491 KOps/s 6.3557 KOps/s $\color{#d91a1a}-0.10\%$
test_compile_add_self_flat[pytree-compile] 0.1989ms 0.1048ms 9.5375 KOps/s 9.7171 KOps/s $\color{#d91a1a}-1.85\%$
test_compile_copy_flat[tensordict-compile] 77.6860μs 21.1146μs 47.3606 KOps/s 47.0821 KOps/s $\color{#35bf28}+0.59\%$
test_compile_copy_flat[tensordict-eager] 0.1236ms 58.6223μs 17.0583 KOps/s 16.6006 KOps/s $\color{#35bf28}+2.76\%$
test_compile_copy_flat[pytree-compile] 0.1585ms 81.6792μs 12.2430 KOps/s 12.4286 KOps/s $\color{#d91a1a}-1.49\%$
test_compile_copy_flat[pytree-eager] 0.1462ms 68.8247μs 14.5297 KOps/s 14.5691 KOps/s $\color{#d91a1a}-0.27\%$
test_compile_assign_and_add[tensordict-compile] 0.2890ms 0.2082ms 4.8028 KOps/s 4.8504 KOps/s $\color{#d91a1a}-0.98\%$
test_compile_assign_and_add[tensordict-eager] 2.3569ms 1.2857ms 777.7803 Ops/s 769.0140 Ops/s $\color{#35bf28}+1.14\%$
test_compile_assign_and_add[pytree-compile] 0.2992ms 0.2048ms 4.8817 KOps/s 4.9285 KOps/s $\color{#d91a1a}-0.95\%$
test_compile_assign_and_add[pytree-eager] 1.3443ms 0.7787ms 1.2842 KOps/s 1.2655 KOps/s $\color{#35bf28}+1.48\%$
test_compile_assign_and_add_stack[compile] 0.5584ms 0.4610ms 2.1693 KOps/s 2.2237 KOps/s $\color{#d91a1a}-2.45\%$
test_compile_assign_and_add_stack[eager] 2.8630ms 2.5768ms 388.0730 Ops/s 377.9510 Ops/s $\color{#35bf28}+2.68\%$
test_compile_indexing[tensor-tensordict-compile] 99.1560μs 35.7724μs 27.9545 KOps/s 28.1310 KOps/s $\color{#d91a1a}-0.63\%$
test_compile_indexing[tensor-tensordict-eager] 0.5936ms 31.4142μs 31.8327 KOps/s 29.8791 KOps/s $\textbf{\color{#35bf28}+6.54\%}$
test_compile_indexing[tensor-tensorclass-compile] 82.6150μs 29.3236μs 34.1022 KOps/s 34.5503 KOps/s $\color{#d91a1a}-1.30\%$
test_compile_indexing[tensor-tensorclass-eager] 82.0540μs 23.0641μs 43.3574 KOps/s 43.0647 KOps/s $\color{#35bf28}+0.68\%$
test_compile_indexing[tensor-pytree-compile] 96.2010μs 29.7441μs 33.6201 KOps/s 33.3625 KOps/s $\color{#35bf28}+0.77\%$
test_compile_indexing[tensor-pytree-eager] 69.4400μs 23.0337μs 43.4147 KOps/s 43.8591 KOps/s $\color{#d91a1a}-1.01\%$
test_compile_indexing[slice-tensordict-compile] 0.1303ms 51.6890μs 19.3465 KOps/s 19.5033 KOps/s $\color{#d91a1a}-0.80\%$
test_compile_indexing[slice-tensordict-eager] 0.5301ms 19.5706μs 51.0971 KOps/s 49.1017 KOps/s $\color{#35bf28}+4.06\%$
test_compile_indexing[slice-tensorclass-compile] 0.1153ms 44.5426μs 22.4504 KOps/s 23.1412 KOps/s $\color{#d91a1a}-2.98\%$
test_compile_indexing[slice-tensorclass-eager] 53.1900μs 18.6144μs 53.7218 KOps/s 53.6036 KOps/s $\color{#35bf28}+0.22\%$
test_compile_indexing[slice-pytree-compile] 0.1014ms 45.5499μs 21.9540 KOps/s 22.0459 KOps/s $\color{#d91a1a}-0.42\%$
test_compile_indexing[slice-pytree-eager] 56.0160μs 18.5731μs 53.8414 KOps/s 53.5934 KOps/s $\color{#35bf28}+0.46\%$
test_compile_indexing[int-tensordict-compile] 0.1185ms 51.5310μs 19.4058 KOps/s 19.2251 KOps/s $\color{#35bf28}+0.94\%$
test_compile_indexing[int-tensordict-eager] 0.9901ms 19.4104μs 51.5188 KOps/s 49.4009 KOps/s $\color{#35bf28}+4.29\%$
test_compile_indexing[int-tensorclass-compile] 95.5390μs 45.1816μs 22.1329 KOps/s 22.6877 KOps/s $\color{#d91a1a}-2.45\%$
test_compile_indexing[int-tensorclass-eager] 75.8430μs 18.7214μs 53.4148 KOps/s 53.4740 KOps/s $\color{#d91a1a}-0.11\%$
test_compile_indexing[int-pytree-compile] 99.3770μs 45.3624μs 22.0447 KOps/s 22.6795 KOps/s $\color{#d91a1a}-2.80\%$
test_compile_indexing[int-pytree-eager] 61.2450μs 18.7184μs 53.4234 KOps/s 53.7596 KOps/s $\color{#d91a1a}-0.63\%$
test_mod_add[eager] 99.0960μs 33.9476μs 29.4572 KOps/s 28.8444 KOps/s $\color{#35bf28}+2.12\%$
test_mod_add[compile] 0.1182ms 46.8090μs 21.3634 KOps/s 21.3281 KOps/s $\color{#35bf28}+0.17\%$
test_mod_add[compile-overhead] 0.1058ms 47.9714μs 20.8458 KOps/s 21.0914 KOps/s $\color{#d91a1a}-1.16\%$
test_mod_wrap[eager] 0.3845ms 0.2245ms 4.4548 KOps/s 4.4231 KOps/s $\color{#35bf28}+0.72\%$
test_mod_wrap[compile] 0.4487ms 0.2060ms 4.8539 KOps/s 4.8740 KOps/s $\color{#d91a1a}-0.41\%$
test_mod_wrap[compile-overhead] 0.5559ms 0.2117ms 4.7236 KOps/s 4.9439 KOps/s $\color{#d91a1a}-4.46\%$
test_mod_wrap_and_backward[eager] 13.0287ms 10.6244ms 94.1232 Ops/s 91.9848 Ops/s $\color{#35bf28}+2.32\%$
test_mod_wrap_and_backward[compile] 11.7214ms 10.4745ms 95.4698 Ops/s 87.0173 Ops/s $\textbf{\color{#35bf28}+9.71\%}$
test_mod_wrap_and_backward[compile-overhead] 12.0177ms 10.4550ms 95.6483 Ops/s 75.5744 Ops/s $\textbf{\color{#35bf28}+26.56\%}$
test_seq_add[eager] 0.2311ms 0.1091ms 9.1685 KOps/s 8.6007 KOps/s $\textbf{\color{#35bf28}+6.60\%}$
test_seq_add[compile] 0.1144ms 61.2347μs 16.3306 KOps/s 16.2961 KOps/s $\color{#35bf28}+0.21\%$
test_seq_add[compile-overhead] 0.1284ms 59.2254μs 16.8846 KOps/s 16.1865 KOps/s $\color{#35bf28}+4.31\%$
test_seq_wrap[eager] 0.5836ms 0.4294ms 2.3287 KOps/s 2.2176 KOps/s $\textbf{\color{#35bf28}+5.01\%}$
test_seq_wrap[compile] 0.3332ms 0.2245ms 4.4549 KOps/s 4.2242 KOps/s $\textbf{\color{#35bf28}+5.46\%}$
test_seq_wrap[compile-overhead] 0.4730ms 0.2218ms 4.5081 KOps/s 4.3491 KOps/s $\color{#35bf28}+3.66\%$
test_func_call_runtime[False-eager] 0.9358ms 0.5584ms 1.7907 KOps/s 1.8051 KOps/s $\color{#d91a1a}-0.80\%$
test_func_call_runtime[False-compile] 0.8850ms 0.4211ms 2.3746 KOps/s 2.3583 KOps/s $\color{#35bf28}+0.69\%$
test_func_call_runtime[False-compile-overhead] 0.5333ms 0.4218ms 2.3708 KOps/s 2.3371 KOps/s $\color{#35bf28}+1.44\%$
test_func_call_runtime[True-eager] 0.9827ms 0.7662ms 1.3052 KOps/s 1.3247 KOps/s $\color{#d91a1a}-1.47\%$
test_func_call_runtime[True-compile] 0.8454ms 0.4607ms 2.1705 KOps/s 2.1584 KOps/s $\color{#35bf28}+0.56\%$
test_func_call_runtime[True-compile-overhead] 0.8868ms 0.4636ms 2.1569 KOps/s 2.1445 KOps/s $\color{#35bf28}+0.58\%$
test_func_call_cm_runtime[False-eager] 0.6905ms 0.5587ms 1.7899 KOps/s 1.8231 KOps/s $\color{#d91a1a}-1.82\%$
test_func_call_cm_runtime[False-compile] 0.5706ms 0.4195ms 2.3838 KOps/s 2.3483 KOps/s $\color{#35bf28}+1.51\%$
test_func_call_cm_runtime[False-compile-overhead] 0.8121ms 0.4245ms 2.3556 KOps/s 2.3554 KOps/s $+0.01\%$
test_func_call_cm_runtime[True-eager] 1.0853ms 0.8971ms 1.1147 KOps/s 1.1241 KOps/s $\color{#d91a1a}-0.84\%$
test_func_call_cm_runtime[True-compile] 0.5873ms 0.4832ms 2.0694 KOps/s 2.0359 KOps/s $\color{#35bf28}+1.65\%$
test_func_call_cm_runtime[True-compile-overhead] 0.6830ms 0.4838ms 2.0671 KOps/s 2.0468 KOps/s $\color{#35bf28}+0.99\%$
test_vmap_func_call_cm_runtime[eager] 2.6881ms 1.9035ms 525.3592 Ops/s 536.4337 Ops/s $\color{#d91a1a}-2.06\%$
test_vmap_func_call_cm_runtime[compile] 0.8727ms 0.5139ms 1.9459 KOps/s 1.9275 KOps/s $\color{#35bf28}+0.95\%$
test_vmap_func_call_cm_runtime[compile-overhead] 0.8995ms 0.5162ms 1.9371 KOps/s 1.9240 KOps/s $\color{#35bf28}+0.68\%$
test_distributed 0.2346ms 0.1230ms 8.1328 KOps/s 7.7074 KOps/s $\textbf{\color{#35bf28}+5.52\%}$
test_tdmodule 79.4000μs 26.3021μs 38.0198 KOps/s 36.2446 KOps/s $\color{#35bf28}+4.90\%$
test_tdmodule_dispatch 84.4390μs 48.1526μs 20.7673 KOps/s 19.7861 KOps/s $\color{#35bf28}+4.96\%$
test_tdseq 44.9350μs 25.4815μs 39.2441 KOps/s 37.6630 KOps/s $\color{#35bf28}+4.20\%$
test_tdseq_dispatch 94.7380μs 51.0710μs 19.5806 KOps/s 19.3236 KOps/s $\color{#35bf28}+1.33\%$
test_instantiation_functorch 3.6096ms 1.5540ms 643.5206 Ops/s 644.9473 Ops/s $\color{#d91a1a}-0.22\%$
test_exec_functorch 0.3403ms 0.1794ms 5.5744 KOps/s 5.4470 KOps/s $\color{#35bf28}+2.34\%$
test_exec_functional_call 0.3034ms 0.1764ms 5.6677 KOps/s 5.6458 KOps/s $\color{#35bf28}+0.39\%$
test_exec_td_decorator 0.4600ms 0.2304ms 4.3403 KOps/s 4.2387 KOps/s $\color{#35bf28}+2.40\%$
test_vmap_mlp_speed_decorator[True-True] 0.8123ms 0.6429ms 1.5554 KOps/s 1.5090 KOps/s $\color{#35bf28}+3.07\%$
test_vmap_mlp_speed_decorator[True-False] 0.8991ms 0.6395ms 1.5637 KOps/s 1.5451 KOps/s $\color{#35bf28}+1.20\%$
test_vmap_mlp_speed_decorator[False-True] 0.6947ms 0.5227ms 1.9130 KOps/s 1.8659 KOps/s $\color{#35bf28}+2.52\%$
test_vmap_mlp_speed_decorator[False-False] 0.8552ms 0.5216ms 1.9171 KOps/s 1.9363 KOps/s $\color{#d91a1a}-1.00\%$
test_to_module_speed[True] 1.5902ms 1.2732ms 785.4463 Ops/s 765.5498 Ops/s $\color{#35bf28}+2.60\%$
test_to_module_speed[False] 1.7780ms 1.2428ms 804.6076 Ops/s 792.8590 Ops/s $\color{#35bf28}+1.48\%$
test_tc_init 90.5800μs 45.3000μs 22.0751 KOps/s 21.8956 KOps/s $\color{#35bf28}+0.82\%$
test_tc_init_nested 0.1802ms 89.7667μs 11.1400 KOps/s 10.6963 KOps/s $\color{#35bf28}+4.15\%$
test_tc_first_layer_tensor 44.1130μs 1.5612μs 640.5296 KOps/s 676.9098 KOps/s $\textbf{\color{#d91a1a}-5.37\%}$
test_tc_first_layer_nontensor 27.0510μs 4.8111μs 207.8514 KOps/s 211.0246 KOps/s $\color{#d91a1a}-1.50\%$
test_tc_second_layer_tensor 50.8240μs 2.8143μs 355.3222 KOps/s 360.3567 KOps/s $\color{#d91a1a}-1.40\%$
test_tc_second_layer_nontensor 32.6710μs 6.1400μs 162.8661 KOps/s 166.7244 KOps/s $\color{#d91a1a}-2.31\%$
test_unbind 0.2159s 12.2421ms 81.6851 Ops/s 83.5344 Ops/s $\color{#d91a1a}-2.21\%$
test_full_like 7.9797ms 7.1593ms 139.6784 Ops/s 89.7497 Ops/s $\textbf{\color{#35bf28}+55.63\%}$
test_zeros_like 3.6185ms 2.8689ms 348.5651 Ops/s 142.7225 Ops/s $\textbf{\color{#35bf28}+144.23\%}$
test_ones_like 3.6443ms 3.2295ms 309.6461 Ops/s 126.0973 Ops/s $\textbf{\color{#35bf28}+145.56\%}$
test_clone 5.6892ms 5.0453ms 198.2053 Ops/s 106.3170 Ops/s $\textbf{\color{#35bf28}+86.43\%}$
test_squeeze 64.0910μs 11.6427μs 85.8905 KOps/s 84.8729 KOps/s $\color{#35bf28}+1.20\%$
test_unsqueeze 0.3355ms 88.5764μs 11.2897 KOps/s 11.2187 KOps/s $\color{#35bf28}+0.63\%$
test_split 0.3390ms 0.1877ms 5.3290 KOps/s 5.1156 KOps/s $\color{#35bf28}+4.17\%$
test_permute 0.3031ms 0.2170ms 4.6075 KOps/s 4.5408 KOps/s $\color{#35bf28}+1.47\%$
test_stack 32.4591ms 25.9396ms 38.5511 Ops/s 39.9710 Ops/s $\color{#d91a1a}-3.55\%$
test_cat 30.4513ms 25.8518ms 38.6821 Ops/s 41.0037 Ops/s $\textbf{\color{#d91a1a}-5.66\%}$

@vmoens added the "Refactor" label (Refactoring code - not a new feature) Nov 26, 2024
@vmoens vmoens merged commit ee052a8 into gh/vmoens/35/base Nov 27, 2024
47 of 50 checks passed
vmoens added a commit that referenced this pull request Nov 27, 2024
ghstack-source-id: 56c1dd2ad856a18613ec1a4c0ca70aedd28a52e3
Pull Request resolved: #1112
@vmoens vmoens deleted the gh/vmoens/35/head branch November 27, 2024 14:06
@louisfaury louisfaury left a comment

lgtm, thanks!


Keyword Args:
    aggregate_probabilities (bool, optional): if provided, overrides the default ``aggregate_probabilities``
        from the class.
    include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict.
        Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default).
@louisfaury louisfaury Nov 27, 2024

Suggested change
- Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default).
+ Defaults to ``self.include_sum`` which is set through the class constructor (``True`` by default).

@vmoens (Contributor, Author) replied:

ok correcting these typos in follow-up PR

        category=DeprecationWarning,
    )
if include_sum:
    slp = 0.0
d = {}
for name, dist in self.dists.items():
    d[_add_suffix(name, "_log_prob")] = lp = dist.log_prob(sample.get(name))

Question: if I am to use include_sum=False to go through a custom aggregation downstream, should I also be in charge of setting the key for the log-probs -- instead of the default _add_suffix(name, "_log_prob") ?

@vmoens (Contributor, Author) replied:


if you want different names and not the sum, I guess you could do inplace=False, include_sum=False and then rename?
Or we could allow you to pass a function that maps the name to a new log-prob name?
Open to suggestions
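
For illustration only, the second option (a function that maps each name to its log-prob key) could look like the sketch below. `key_fn` and `per_leaf_log_probs` are hypothetical names for the idea floated in this reply, not existing arguments, and `add_suffix` is a stand-in whose behaviour on nested (tuple) keys is an assumption about what `_add_suffix` does.

```python
def add_suffix(name, suffix):
    """Stand-in for _add_suffix: suffix the last element of a (nested) key."""
    if isinstance(name, tuple):
        return name[:-1] + (name[-1] + suffix,)
    return name + suffix


def per_leaf_log_probs(log_probs, key_fn=None):
    """Return a new mapping whose keys are produced by key_fn.

    key_fn is hypothetical: when omitted, the default suffixing is used.
    """
    if key_fn is None:
        def key_fn(name):
            return add_suffix(name, "_log_prob")
    return {key_fn(name): lp for name, lp in log_probs.items()}


raw = {"action": -1.2, ("agent", "move"): -0.7}

default_keys = per_leaf_log_probs(raw)
custom_keys = per_leaf_log_probs(raw, key_fn=lambda name: ("logp", name))
```

The first option (`inplace=False, include_sum=False`, then rename) needs no new argument, at the cost of a second pass over the output keys.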

@@ -320,9 +427,10 @@ def entropy_composite(self, samples_mc=1, include_sum=True) -> TensorDictBase:
x = dist.rsample((samples_mc,))
e = -dist.log_prob(x).mean(0)
d[_add_suffix(name, "_entropy")] = e

Ditto
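
For readers unfamiliar with the estimate in the hunk above (`e = -dist.log_prob(x).mean(0)` over `samples_mc` draws), here is a minimal pure-Python sketch of the same Monte Carlo idea for a single univariate Normal. The helper names, the seed and the sample count are illustrative assumptions; the closed-form Normal entropy is used only as a sanity check.

```python
import math
import random


def normal_log_prob(x, loc, scale):
    """Log-density of a univariate Normal(loc, scale) evaluated at x."""
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2 * math.pi)


def mc_entropy(loc, scale, samples_mc, rng):
    """Estimate entropy as the negated mean log-prob over samples_mc draws."""
    draws = [rng.gauss(loc, scale) for _ in range(samples_mc)]
    return -sum(normal_log_prob(x, loc, scale) for x in draws) / samples_mc


rng = random.Random(0)  # fixed seed so the sketch is reproducible
estimate = mc_entropy(0.0, 2.0, samples_mc=20_000, rng=rng)

# Closed-form entropy of Normal(loc, scale): 0.5 * log(2 * pi * e * scale**2)
exact = 0.5 * math.log(2 * math.pi * math.e * 2.0 ** 2)
```

The estimate is noisy for small `samples_mc`; with the default `samples_mc=1` shown in the signature it is a single-sample estimate.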

This method is called by the :meth:`~.log_prob` method when ``self.aggregate_probabilities`` is ``False``.
Keyword Args:
    include_sum (bool, optional): Whether to include the summed log-probability in the output TensorDict.
        Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default).
@louisfaury louisfaury Nov 27, 2024

Suggested change
- Defaults to ``self.inplace`` which is set through the class constructor (``True`` by default).
+ Defaults to ``self.include_sum`` which is set through the class constructor (``True`` by default).

@vmoens (Contributor, Author) replied:

ditto
