Stable versions of torchrl/tensordict still getting internal dynamo error #14
Thanks for reporting! |
This occurred when using the official code in the repo. I can give nightly a try later, along with the other suggestion. |
Sounds good! And sorry for the trouble, I wish things were a bit easier to land. |
I get this in WSL2 when using `--compile --cudagraphs`; using either flag independently works: `python leanrl/ppo_continuous_action_torchcompile.py --seed 1 --total-timesteps 50000 --cudagraphs --compile` |
I've written a few test cases that narrow the problem down a bit. It seems to occur for tensordict inputs to cudagraph+compiled functions, and only on the final warmup round. Here are three tests, using Python 3.12 on WSL2 with the requirements.txt from this repo:

```python
import torch
from tensordict import TensorDict
from tensordict.nn import CudaGraphModule


def test_cudagraph_compile_tensordict_pre_warmup():
    """Passes: stops one call short of the final warmup round."""
    n_warmup = 10
    device = torch.device("cuda:0")

    def eval_fun(td: TensorDict):
        td["output"] = torch.zeros_like(td["input"])
        return td

    td = TensorDict(device=device)
    td["input"] = torch.zeros(1, device=device)
    eval_fun = torch.compile(eval_fun, fullgraph=True)
    eval_fun_cgm = CudaGraphModule(eval_fun, warmup=n_warmup)
    for i in range(n_warmup - 1):
        td_out2 = eval_fun_cgm(td.clone())


def test_cudagraph_compile_tensordict_post_warmup():
    """Fails: same as the test above, but runs through the final warmup round."""
    n_warmup = 2
    device = torch.device("cuda:0")

    def eval_fun(td: TensorDict):
        td["output"] = torch.zeros_like(td["input"])
        return td

    td = TensorDict(device=device)
    td["input"] = torch.zeros(1, device=device)
    eval_fun = torch.compile(eval_fun, fullgraph=True)
    eval_fun_cgm = CudaGraphModule(eval_fun, warmup=n_warmup)
    # stepping one step longer than the pre-warmup test
    for i in range(n_warmup):
        td_out2 = eval_fun_cgm(td.clone())


def test_cudagraph_compile_tensordict_post_warmup_with_clone():
    """Passes: like the failing test, but the input is cloned inside the function."""
    n_warmup = 2
    device = torch.device("cuda:0")

    def eval_fun(td: TensorDict):
        td = td.clone()  # workaround: clone the input instead of mutating it in place
        td["output"] = torch.zeros_like(td["input"])
        return td

    td = TensorDict(device=device)
    td["input"] = torch.zeros(1, device=device)
    eval_fun = torch.compile(eval_fun, fullgraph=True)
    eval_fun_cgm = CudaGraphModule(eval_fun, warmup=n_warmup)
    # stepping one step longer than the pre-warmup test
    for i in range(n_warmup):
        td_out2 = eval_fun_cgm(td.clone())
```

Another workaround that works is wrapping gae in a TensorDictModule, like update, with next_obs and next_done being propagated through the container tensordict.
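For reference, that second workaround could look roughly like the sketch below. This is an assumption-laden illustration, not the actual leanrl code: the gae signature, key names, and placeholder body are guesses based on the discussion above.

```python
# Hedged sketch: wrap the GAE computation in a TensorDictModule so that
# next_obs and next_done travel inside the container tensordict instead of
# being passed as raw tensor arguments to the cudagraph+compiled function.
import torch
from tensordict.nn import CudaGraphModule, TensorDictModule


def gae(next_obs, next_done, rewards, dones, values):
    # Placeholder body; the real advantage computation lives in
    # leanrl/ppo_continuous_action_torchcompile.py.
    return torch.zeros_like(rewards)


gae_module = TensorDictModule(
    gae,
    in_keys=["next_obs", "next_done", "rewards", "dones", "values"],
    out_keys=["advantages"],
)
gae_module = torch.compile(gae_module, fullgraph=True)
gae_module = CudaGraphModule(gae_module)
```

The wrapped module then reads its inputs from, and writes "advantages" back into, whatever tensordict it is called with, so no raw tensors cross the cudagraph boundary. |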
I am currently getting the same issue as in #10.
I have torch 2.5.1, torchrl 0.6.0, and tensordict 0.6.0 at the moment, and I am running a slightly modified version of the original code. I can run with cudagraphs or compile, but not both. With cudagraphs alone, though, things are working great!
Trace:
Currently trying to add leanrl tricks to maniskill; update times have massively improved!