[Feature] First class dim compatibility #525
Conversation
Name | Max | Mean | Ops | Ops on Repo HEAD | Change |
---|---|---|---|---|---|
test_plain_set_nested | 40.8030μs | 22.8378μs | 43.7870 KOps/s | 42.1301 KOps/s | |
test_plain_set_stack_nested | 0.2576ms | 0.2135ms | 4.6837 KOps/s | 4.5755 KOps/s | |
test_plain_set_nested_inplace | 0.1127ms | 27.1638μs | 36.8137 KOps/s | 35.4453 KOps/s | |
test_plain_set_stack_nested_inplace | 1.8214ms | 0.2543ms | 3.9322 KOps/s | 3.7469 KOps/s | |
test_items | 58.0050μs | 4.1970μs | 238.2634 KOps/s | 229.5948 KOps/s | |
test_items_nested | 0.5103ms | 0.4152ms | 2.4087 KOps/s | 2.3127 KOps/s | |
test_items_nested_locked | 0.5175ms | 0.4156ms | 2.4060 KOps/s | 2.3016 KOps/s | |
test_items_nested_leaf | 2.0650ms | 0.2752ms | 3.6337 KOps/s | 3.7637 KOps/s | |
test_items_stack_nested | 2.4735ms | 2.3262ms | 429.8859 Ops/s | 404.0665 Ops/s | |
test_items_stack_nested_leaf | 2.3999ms | 2.1087ms | 474.2295 Ops/s | 444.0169 Ops/s | |
test_items_stack_nested_locked | 3.1510ms | 1.2269ms | 815.0581 Ops/s | 819.9756 Ops/s | |
test_keys | 23.0020μs | 6.1172μs | 163.4735 KOps/s | 153.8631 KOps/s | |
test_keys_nested | 2.2918ms | 0.2147ms | 4.6579 KOps/s | 4.4352 KOps/s | |
test_keys_nested_locked | 0.2919ms | 0.2124ms | 4.7091 KOps/s | 4.5592 KOps/s | |
test_keys_nested_leaf | 0.3513ms | 0.2049ms | 4.8794 KOps/s | 4.3601 KOps/s | |
test_keys_stack_nested | 3.1491ms | 2.2531ms | 443.8280 Ops/s | 436.3837 Ops/s | |
test_keys_stack_nested_leaf | 2.2429ms | 2.1161ms | 472.5734 Ops/s | 443.0617 Ops/s | |
test_keys_stack_nested_locked | 1.0903ms | 0.9436ms | 1.0598 KOps/s | 938.1109 Ops/s | |
test_values | 36.3020μs | 1.8935μs | 528.1175 KOps/s | 504.4815 KOps/s | |
test_values_nested | 0.1407ms | 73.5073μs | 13.6041 KOps/s | 13.6382 KOps/s | |
test_values_nested_locked | 0.4006ms | 73.4395μs | 13.6166 KOps/s | 13.7244 KOps/s | |
test_values_nested_leaf | 0.1528ms | 65.7384μs | 15.2118 KOps/s | 15.2686 KOps/s | |
test_values_stack_nested | 2.4459ms | 1.8560ms | 538.8000 Ops/s | 534.7590 Ops/s | |
test_values_stack_nested_leaf | 2.0091ms | 1.8486ms | 540.9581 Ops/s | 537.0398 Ops/s | |
test_values_stack_nested_locked | 0.9060ms | 0.7596ms | 1.3164 KOps/s | 1.2777 KOps/s | |
test_membership | 19.8010μs | 2.1541μs | 464.2206 KOps/s | 451.5150 KOps/s | |
test_membership_nested | 38.4020μs | 4.2442μs | 235.6141 KOps/s | 232.5990 KOps/s | |
test_membership_nested_leaf | 76.3060μs | 4.1557μs | 240.6342 KOps/s | 230.1685 KOps/s | |
test_membership_stacked_nested | 50.6030μs | 16.9154μs | 59.1179 KOps/s | 55.5477 KOps/s | |
test_membership_stacked_nested_leaf | 96.7070μs | 16.8607μs | 59.3096 KOps/s | 55.3671 KOps/s | |
test_membership_nested_last | 24.6010μs | 8.7516μs | 114.2644 KOps/s | 107.4573 KOps/s | |
test_membership_nested_leaf_last | 32.1020μs | 8.7086μs | 114.8292 KOps/s | 107.4175 KOps/s | |
test_membership_stacked_nested_last | 0.3429ms | 0.2594ms | 3.8547 KOps/s | 3.6090 KOps/s | |
test_membership_stacked_nested_leaf_last | 0.1070ms | 19.6227μs | 50.9614 KOps/s | 47.2588 KOps/s | |
test_nested_getleaf | 93.4070μs | 17.7802μs | 56.2423 KOps/s | 51.5113 KOps/s | |
test_nested_get | 50.5030μs | 16.8795μs | 59.2436 KOps/s | 55.7017 KOps/s | |
test_stacked_getleaf | 1.1854ms | 1.0236ms | 976.9444 Ops/s | 911.3819 Ops/s | |
test_stacked_get | 1.1296ms | 0.9775ms | 1.0230 KOps/s | 988.3704 Ops/s | |
test_nested_getitemleaf | 96.4070μs | 17.7327μs | 56.3930 KOps/s | 52.8133 KOps/s | |
test_nested_getitem | 86.4060μs | 16.9241μs | 59.0872 KOps/s | 55.8261 KOps/s | |
test_stacked_getitemleaf | 2.9759ms | 1.0283ms | 972.4494 Ops/s | 912.9131 Ops/s | |
test_stacked_getitem | 1.1051ms | 0.9763ms | 1.0242 KOps/s | 968.1876 Ops/s | |
test_lock_nested | 73.2636ms | 1.8029ms | 554.6577 Ops/s | 594.2184 Ops/s | |
test_lock_stack_nested | 0.1016s | 23.0279ms | 43.4257 Ops/s | 45.8036 Ops/s | |
test_unlock_nested | 71.0487ms | 1.8132ms | 551.5093 Ops/s | 565.3265 Ops/s | |
test_unlock_stack_nested | 0.1032s | 23.7151ms | 42.1672 Ops/s | 45.1116 Ops/s | |
test_flatten_speed | 1.2824ms | 1.1718ms | 853.4100 Ops/s | 849.3135 Ops/s | |
test_unflatten_speed | 2.2633ms | 2.1058ms | 474.8786 Ops/s | 476.0059 Ops/s | |
test_common_ops | 5.8099ms | 1.2812ms | 780.4994 Ops/s | 783.2521 Ops/s | |
test_creation | 28.7020μs | 7.2687μs | 137.5767 KOps/s | 141.4394 KOps/s | |
test_creation_empty | 58.1040μs | 15.9956μs | 62.5173 KOps/s | 64.2385 KOps/s | |
test_creation_nested_1 | 0.1380ms | 29.1215μs | 34.3389 KOps/s | 35.6341 KOps/s | |
test_creation_nested_2 | 0.1228ms | 31.6025μs | 31.6431 KOps/s | 32.6988 KOps/s | |
test_clone | 0.1615ms | 28.3325μs | 35.2951 KOps/s | 35.3838 KOps/s | |
test_getitem[int] | 0.1232ms | 32.3697μs | 30.8931 KOps/s | 31.5161 KOps/s | |
test_getitem[slice_int] | 0.1468ms | 63.8013μs | 15.6737 KOps/s | 15.8951 KOps/s | |
test_getitem[range] | 0.1302ms | 96.5474μs | 10.3576 KOps/s | 10.5466 KOps/s | |
test_getitem[tuple] | 0.1160ms | 53.5406μs | 18.6774 KOps/s | 19.3390 KOps/s | |
test_getitem[list] | 0.3439ms | 92.5050μs | 10.8102 KOps/s | 11.1285 KOps/s | |
test_setitem_dim[int] | 55.8040μs | 38.8641μs | 25.7307 KOps/s | 26.2328 KOps/s | |
test_setitem_dim[slice_int] | 0.1135ms | 69.2660μs | 14.4371 KOps/s | 14.6669 KOps/s | |
test_setitem_dim[range] | 0.2208ms | 94.9757μs | 10.5290 KOps/s | 10.8262 KOps/s | |
test_setitem_dim[tuple] | 82.4050μs | 57.3171μs | 17.4468 KOps/s | 17.9345 KOps/s | |
test_setitem | 0.1728ms | 36.9297μs | 27.0785 KOps/s | 27.3960 KOps/s | |
test_set | 0.1562ms | 35.8199μs | 27.9175 KOps/s | 28.1877 KOps/s | |
test_set_shared | 2.9139ms | 0.2100ms | 4.7626 KOps/s | 4.8148 KOps/s | |
test_update | 0.2191ms | 40.2059μs | 24.8719 KOps/s | 24.9126 KOps/s | |
test_update_nested | 0.2248ms | 60.2317μs | 16.6026 KOps/s | 16.8783 KOps/s | |
test_set_nested | 0.1428ms | 39.3426μs | 25.4177 KOps/s | 25.7714 KOps/s | |
test_set_nested_new | 0.2376ms | 61.3428μs | 16.3018 KOps/s | 16.6495 KOps/s | |
test_select | 0.2717ms | 0.1139ms | 8.7824 KOps/s | 8.9222 KOps/s | |
test_unbind_speed | 0.8707ms | 0.7594ms | 1.3169 KOps/s | 1.3290 KOps/s | |
test_unbind_speed_stack0 | 8.5867ms | 8.3202ms | 120.1899 Ops/s | 94.7012 Ops/s | |
test_unbind_speed_stack1 | 17.5010μs | 1.3187μs | 758.3332 KOps/s | 777.0370 KOps/s | |
test_creation[device0] | 3.4052ms | 0.5331ms | 1.8759 KOps/s | 1.9014 KOps/s | |
test_creation_from_tensor | 0.6736ms | 0.5903ms | 1.6941 KOps/s | 1.6526 KOps/s | |
test_add_one[memmap_tensor0] | 1.9809ms | 38.7326μs | 25.8181 KOps/s | 26.5326 KOps/s | |
test_contiguous[memmap_tensor0] | 61.4040μs | 9.9932μs | 100.0676 KOps/s | 100.4496 KOps/s | |
test_stack[memmap_tensor0] | 85.4050μs | 31.3238μs | 31.9246 KOps/s | 32.0624 KOps/s | |
test_memmaptd_index | 0.4830ms | 0.3580ms | 2.7934 KOps/s | 2.7074 KOps/s | |
test_memmaptd_index_astensor | 1.6289ms | 1.5756ms | 634.6893 Ops/s | 624.1487 Ops/s | |
test_memmaptd_index_op | 3.2206ms | 3.1206ms | 320.4559 Ops/s | 330.7082 Ops/s | |
test_reshape_pytree | 0.1294ms | 44.0568μs | 22.6980 KOps/s | 22.8842 KOps/s | |
test_reshape_td | 81.9060μs | 53.9422μs | 18.5384 KOps/s | 19.3242 KOps/s | |
test_view_pytree | 0.1114ms | 41.6940μs | 23.9843 KOps/s | 24.4412 KOps/s | |
test_view_td | 77.4050μs | 10.3294μs | 96.8112 KOps/s | 95.7883 KOps/s | |
test_unbind_pytree | 93.8060μs | 44.6648μs | 22.3890 KOps/s | 21.7918 KOps/s | |
test_unbind_td | 0.2131ms | 0.1123ms | 8.9046 KOps/s | 9.0284 KOps/s | |
test_split_pytree | 68.2050μs | 46.3313μs | 21.5837 KOps/s | 19.4295 KOps/s | |
test_split_td | 0.8416ms | 0.1264ms | 7.9119 KOps/s | 7.4986 KOps/s | |
test_add_pytree | 0.1256ms | 54.7281μs | 18.2722 KOps/s | 18.1519 KOps/s | |
test_add_td | 0.2196ms | 88.8607μs | 11.2536 KOps/s | 11.5918 KOps/s | |
test_distributed | 33.3020μs | 10.8269μs | 92.3625 KOps/s | 94.0733 KOps/s | |
test_tdmodule | 0.2245ms | 33.4279μs | 29.9151 KOps/s | 29.7587 KOps/s | |
test_tdmodule_dispatch | 0.3167ms | 63.6643μs | 15.7074 KOps/s | 15.3457 KOps/s | |
test_tdseq | 0.6137ms | 36.9631μs | 27.0540 KOps/s | 25.8168 KOps/s | |
test_tdseq_dispatch | 0.2286ms | 76.7709μs | 13.0258 KOps/s | 12.6993 KOps/s | |
test_instantiation_functorch | 2.0226ms | 1.8947ms | 527.7792 Ops/s | 518.8319 Ops/s | |
test_instantiation_td | 2.3753ms | 1.5682ms | 637.6577 Ops/s | 626.4042 Ops/s | |
test_exec_functorch | 0.3001ms | 0.2196ms | 4.5528 KOps/s | 4.5480 KOps/s | |
test_exec_td | 0.2601ms | 0.2057ms | 4.8623 KOps/s | 4.7988 KOps/s | |
test_vmap_mlp_speed[True-True] | 8.9455ms | 1.3783ms | 725.5226 Ops/s | 710.4665 Ops/s | |
test_vmap_mlp_speed[True-False] | 4.4109ms | 0.7194ms | 1.3901 KOps/s | 1.3543 KOps/s | |
test_vmap_mlp_speed[False-True] | 9.4309ms | 1.1802ms | 847.3028 Ops/s | 836.5257 Ops/s | |
test_vmap_mlp_speed[False-False] | 1.1355ms | 0.5364ms | 1.8643 KOps/s | 1.7111 KOps/s | |
test_vmap_transformer_speed[True-True] | 25.3063ms | 16.0569ms | 62.2785 Ops/s | 32.4655 Ops/s | |
test_vmap_transformer_speed[True-False] | 14.0185ms | 10.9114ms | 91.6474 Ops/s | 84.7445 Ops/s | |
test_vmap_transformer_speed[False-True] | 24.9099ms | 16.8639ms | 59.2984 Ops/s | 61.2021 Ops/s | |
test_vmap_transformer_speed[False-False] | 19.7356ms | 10.6362ms | 94.0187 Ops/s | 91.4549 Ops/s | |
Cc @zdevito
Nice! What kind of feedback are you looking for?
For now nothing, it's just an FYI.
This is ready for review if anyone wants to give it a shot. One of my goals with this is to do what we did in the example:

```python
from functorch import dim as ftdim

module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 1))
# create batched params
params = TensorDict.from_module(module)
params = params.expand(3).clone().apply(lambda x: x.requires_grad_())
# FCD indexing
d0 = ftdim.dims(1)
# execute grad
x = torch.randn(3)
def func(params):
    params_batched = params[d0]
    params_batched.to_module(module)
    return module(x)._tensor.sum()
g = torch.func.grad(func)(params)
print(g)
print("params", params['0', 'weight'], "\ngrads", g['0', 'weight'])
```

which is cool. The issue I see currently is that we must set the parameters within the call to `grad` (i.e., inside `func`). Ideally, we would rather set them once and write something like:

```python
module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 1))
params = TensorDict.from_module(module)
# create a batch of params
params = params.expand(3).clone().apply(lambda x: nn.Parameter(x))
# reduce the first dim
d0 = dim.dims(1)
params = params[d0]
# repopulate module
params.to_module(module)
# run module
loss = loss_fn(module(x), target)
# backprop
loss.backward()
```

But because my tensordict is not full of plain tensors but of first-class-dim tensors, this does not work out of the box. Having to modify that logic makes FCD less attractive -- I might as well use plain vmap + functional calls if I have to repopulate my module at every call, which I believe will be faster to execute. Do you guys have any thought on that?
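For comparison, the "plain vmap + functional calls" route mentioned above could look roughly like the sketch below. This is an editorial illustration, not code from this PR; it assumes the `torch.func` ensembling helpers (`stack_module_state`, `functional_call`, `vmap`, `grad`).

```python
import copy

import torch
import torch.nn as nn
from torch.func import functional_call, grad, stack_module_state, vmap

# three independent copies of the same architecture
models = [nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 1)) for _ in range(3)]
params, buffers = stack_module_state(models)

# a stateless "template" module; its own weights are never used
base = copy.deepcopy(models[0]).to("meta")

x = torch.randn(3)

def loss_fn(p, b):
    # parameters are passed explicitly on every call
    return functional_call(base, {**p, **b}, (x,)).sum()

# per-model gradients, batched over the stacked (first) dimension
per_model_grads = vmap(grad(loss_fn), in_dims=(0, 0))(params, buffers)
```

The trade-off flagged above is visible here: the parameters are re-bound to the module on every call through `functional_call`, which is exactly the repopulation step the FCD approach tries to avoid.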
The semantics for adding first-class dimensions to tensordict make sense to me. It also makes sense to be able to install parameters that have first-class dimensions. I am less clear on how we would accomplish that, given how parameters in modules currently exist. First-class dims are implemented as their own objects that are not tensor subclasses, in order to run fast enough in eager mode without overhead. I think parameters need to be real tensors in a lot of APIs, in particular when trying to set them on a module.
So... we could relax the constraint here and say that it either has to be a torch.Tensor, or it can be one of the first-class dim objects. Is that enough? Might be!
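To make that concrete, a relaxed predicate could look roughly like the sketch below. This is an editorial illustration only; the way first-class-dim tensors are detected here is an assumption, not the actual patch in this PR or in #526.

```python
import torch

def _is_first_class_dim_tensor(value) -> bool:
    # assumption: FCD tensors live in functorch.dim and are not torch.Tensor subclasses
    return type(value).__module__.startswith("functorch.dim")

def acceptable_value(value) -> bool:
    # relaxed rule per the suggestion above: either a plain torch.Tensor,
    # or one of the first-class-dim objects
    return isinstance(value, torch.Tensor) or _is_first_class_dim_tensor(value)
```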
Actually, I read the example more carefully (and fixed up some stuff), and I see that the `backward()` call fails.
Back in the day, @zou3519 talked a lot about the potential of non-lexical functorch transforms, one of the use cases being this kind of backward() call. I think we concluded that in principle it would be possible, but we would have to implement it (which we never found time to do). So drat, I guess you can't actually replace the auto-batching behavior from tensordict with first-class dims, what a bummer.
Thanks @ezyang and @zdevito for looking into this!

```python
import torch
import torch.nn as nn
from tensordict import TensorDict
import functorch.dim as dim

module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 1))
params = TensorDict.from_module(module)
# create a batch of params
num_models = 5
params = params.expand(num_models).clone().apply(lambda x: nn.Parameter(x))
x = torch.randn(3)

def compute(params_vals):
    # reduce the first dim
    d0 = dim.dims(1)
    for key, val_source in zip(params.keys(include_nested=True, leaves_only=True), params_vals):
        params.set(key, val_source)
    params_dims = params[d0]
    # no vmap
    y = torch.func.functional_call(module, params_dims, (x,))
    return y._tensor.sum()  # or any other loss

grads = torch.func.grad(compute)(list(params.values(True, True)))
# put grads in a tensordict for clarity
grads = TensorDict(
    {key: grad for key, grad in zip(params.keys(include_nested=True, leaves_only=True), grads)},
    batch_size=params.batch_size,
)
print("grads of our 5 models", grads)
```

which gives you the gradients of our 5 models -- i.e., you can get your gradients using FCD and no apparent vmap.

Note that this requires some modifications in functorch and torch nn (see #526, which is based on this PR). I think this is pretty cool: what is happening is pretty apparent. It is not as easy as substituting the params with their FCD counterpart, but with some woodwork this could already look pretty nice. For instance, I already made it possible to write:

```python
# pass directly the params
def compute(params):
    d0 = dim.dims(1)
    params_dims = params[d0]
    y = torch.func.functional_call(module, params_dims, (x,))
    return y._tensor.sum()  # or any other loss

# pass directly the params
grads = torch.func.grad(compute)(params)
# grads are already in a tensordict
print("grads of our 5 models", grads)
```
Oh, this is the thing where you wanted to override how functorch detects batching on inputs passed to grad. @zou3519, I don't remember what your objection to this was?
I don't think I have an objection? The code in #525 (comment) looks reasonable to me.
That, and this too, which I think is more contentious:

```python
def func(tensordict_or_tensor):
    assert tensordict_or_tensor.shape == (1, 2, 4)
    return tensordict_or_tensor

tensor = torch.randn(1, 2, 3, 4)
vmap(func, (2,))(tensor)

tensordict = TensorDict({}, batch_size=[1, 2, 3, 4])
vmap(func, (2,))(tensordict)
```

Currently, with the monkey patch, this code runs. If we just register tensordict within pytree but do not patch, it will fail. Now, I totally get that this has ramifications that go beyond this PR (I guess we don't want to overload vmap with every new class that comes into sight). To me, any solution that makes vmap possible with tensordict (i.e., a tensordict passed through vmap sees its shape changed to a consistent shape within the vmap call) is great.
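To make the "register tensordict within pytree" option concrete, a naive registration might look like the sketch below. This is an editorial illustration; the flatten/unflatten helpers and the context layout are assumptions, not the PR's code, and on older PyTorch versions the registration function is named `_register_pytree_node`.

```python
import torch.utils._pytree as pytree
from tensordict import TensorDict

def _flatten_td(td):
    keys = list(td.keys())
    values = [td.get(key) for key in keys]
    # the context remembers how to rebuild the container
    return values, (keys, td.batch_size, td.device)

def _unflatten_td(values, context):
    keys, batch_size, device = context
    return TensorDict(dict(zip(keys, values)), batch_size=batch_size, device=device)

pytree.register_pytree_node(TensorDict, _flatten_td, _unflatten_td)
```

A naive registration like this also illustrates why it is not enough inside `vmap`: the leaves come back with the batched dimension removed while the `batch_size` stored in the context stays unchanged, which is the mismatch the monkey patch works around.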
LGTM
Description
WIP to make TensorDict compatible with FCD.
Example usage:
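A minimal sketch of the intended usage, assembled from the snippets in the conversation above (the `from_module`, FCD-indexing, and draft `to_module` steps are the ones this PR introduces; exact signatures may differ):

```python
import torch
import torch.nn as nn
from functorch import dim as ftdim
from tensordict import TensorDict

module = nn.Sequential(nn.Linear(3, 4), nn.Linear(4, 1))

# extract the parameters and build a batch of 3 parameter sets
params = TensorDict.from_module(module)
params = params.expand(3).clone().apply(lambda x: x.requires_grad_())

# index the whole TensorDict with a first-class dim
d0 = ftdim.dims(1)
params_fcd = params[d0]

# install the FCD-indexed parameters back onto the module (draft to_module)
params_fcd.to_module(module)

x = torch.randn(3)
y = module(x)  # the output carries the first-class dim d0
```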
This can be used with modules to batch operations across sets of parameters seamlessly. We extract the parameters; `to_module` is a draft of what we could use. The idea of this example is to avoid functional calls, through tensordict and functorch.dim, when working with model ensembles.
cc @ezyang @zou3519 @zdevito
@matteobettini for MARL MLP and ConvNets
@smorad for model ensembles