
Stable versions of torchrl/tensordict still getting internal dynamo error #14

Open
StoneT2000 opened this issue Oct 31, 2024 · 5 comments

StoneT2000 commented Oct 31, 2024

I am currently getting the same issue as in #10.

I have torch 2.5.1, torchrl 0.6.0, and tensordict 0.6.0 at the moment. I am running a slightly modified version of the original code. I can run with cudagraphs or compile, but not both together. That said, with cudagraphs alone things are working great!

Trace:

python leanrl/ppo_continuous_action_torchcompile.py --num-envs 1 --num-steps 64 --total-timesteps 256 --compile --cudagraphs
/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/tyro/_fields.py:181: UserWarning: The field target_kl is annotated with type <class 'float'>, but the default value None has type <class 'NoneType'>. We'll try to handle this gracefully, but it may cause unexpected behavior.
  warnings.warn(
wandb: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.
wandb: Currently logged in as: stonet2000. Use `wandb login --relogin` to force relogin
wandb: Tracking run with wandb version 0.18.5
wandb: Run data is saved locally in /home/stao/work/external/leanrl/wandb/run-20241030_192316-jgwpim8y
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run ppo_continuous_action_torchcompile-HalfCheetah-v4__ppo_continuous_action_torchcompile__1__True__True
wandb: ⭐️ View project at https://wandb.ai/stonet2000/ppo_continuous_action
wandb: 🚀 View run at https://wandb.ai/stonet2000/ppo_continuous_action/runs/jgwpim8y
/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/tensordict/nn/cudagraphs.py:194: UserWarning: Tensordict is registered in PyTree. This is incompatible with CudaGraphModule. Removing TDs from PyTree. To silence this warning, call tensordict.nn.functional_module._exclude_td_from_pytree().set() or set the environment variable `EXCLUDE_TD_FROM_PYTREE=1`. This operation is irreversible.
  warnings.warn(
  0%|                                                                                                                                 | 0/4 [00:00<?, ?it/s]/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/torch/_inductor/compile_fx.py:167: UserWarning: TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. Consider setting `torch.set_float32_matmul_precision('high')` for better performance.
  warnings.warn(
W1030 19:23:26.159230 2648226 site-packages/torch/_logging/_internal.py:1081] [11/0] Profiler function <class 'torch.autograd.profiler.record_function'> will be ignored
 25%|██████████████████████████████▎                                                                                          | 1/4 [00:10<00:31, 10.50s/it]/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/torch/cuda/graphs.py:84: UserWarning: The CUDA Graph is empty. This usually means that the graph was attempted to be captured on wrong device or stream. (Triggered internally at ../aten/src/ATen/cuda/CUDAGraph.cpp:208.)
  super().capture_end()
 25%|██████████████████████████████▎                                                                                          | 1/4 [00:10<00:31, 10.65s/it]
Traceback (most recent call last):
  File "/home/stao/work/external/leanrl/leanrl/ppo_continuous_action_torchcompile.py", line 358, in <module>
    container = gae(next_obs, next_done, container)
  File "/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/tensordict/nn/cudagraphs.py", line 439, in __call__
    return self._call_func(*args, **kwargs)
  File "/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/tensordict/nn/cudagraphs.py", line 345, in _call
    out = self.module(*self._args, **self._kwargs)
  File "/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 465, in _fn
    return fn(*args, **kwargs)
  File "/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 1269, in __call__
    return self._torchdynamo_orig_callable(
  File "/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 526, in __call__
    return _compile(
  File "/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 952, in _compile
    raise InternalTorchDynamoError(
  File "/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 924, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 666, in compile_inner
    return _compile_inner(code, one_graph, hooks, transform)
  File "/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/torch/_utils_internal.py", line 87, in wrapper_function
    return function(*args, **kwargs)
  File "/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 699, in _compile_inner
    out_code = transform_code_object(code, transform)
  File "/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
    transformations(instructions, code_options)
  File "/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 208, in _fn
    cuda_rng_state = torch.cuda.get_rng_state()
  File "/home/stao/miniforge3/envs/leanrl/lib/python3.10/site-packages/torch/cuda/random.py", line 42, in get_rng_state
    return default_generator.get_state()
torch._dynamo.exc.InternalTorchDynamoError: RuntimeError: Cannot call CUDAGeneratorImpl::current_seed during CUDA graph capture. If you need this call to be captured, please file an issue. Current cudaStreamCaptureStatus: cudaStreamCaptureStatusActive


You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True

I'm currently trying to add LeanRL tricks to ManiSkill; update times have massively improved!

vmoens commented Nov 6, 2024

Thanks for reporting!
Is there code I can look at?
Perhaps let's try with the nightlies (I know there were a few CUDA graph fixes that were not included in 2.5.1).
Also, you could run with compile(..., mode="reduce-overhead") and check whether a warning is raised; it could be that there's an op that isn't accounted for. CudaGraphModule doesn't do that check for you, whereas "reduce-overhead" does, though it's generally slower to execute.

StoneT2000 commented Nov 6, 2024

This occurred when using the official code in the repo. I can give the nightlies and the other suggestion a try later.

@vmoens
Copy link
Contributor

vmoens commented Nov 6, 2024

Sounds good! And sorry for the trouble; I wish things were a bit easier to land.
If you have a repro, I'd love to give it a shot.

@dennismalmgren

I get this in WSL2 when using --compile --cudagraphs. Using either flag on its own works.

python leanrl/ppo_continuous_action_torchcompile.py --seed 1 --total-timesteps 50000 --cudagraphs --compile
I tried with stable torch as well as nightly versions of torch, tensordict, etc.

dennismalmgren commented Nov 28, 2024

I've written a few test cases that narrow it down a bit. The error seems to occur for tensordict inputs to cudagraph+compiled functions, on the final warmup round. Here are three tests, using Python 3.12 on WSL2 with the requirements.txt from this repo:

def test_cudagraph_compile_tensordict_pre_warmup():
    '''Passes'''
    import torch
    from tensordict import TensorDict
    from tensordict.nn import CudaGraphModule

    n_warmup = 10
    device = torch.device("cuda:0")

    def eval_fun(td: TensorDict):
        td["output"] = torch.zeros_like(td["input"])
        return td

    td = TensorDict(device=device)
    td["input"] = torch.zeros(1, device=device)

    eval_fun = torch.compile(eval_fun, fullgraph=True)
    eval_fun_cgm = CudaGraphModule(eval_fun, warmup=n_warmup)

    # Stop one step short of exhausting the warmup rounds.
    for i in range(n_warmup - 1):
        td_out = eval_fun_cgm(td.clone())


def test_cudagraph_compile_tensordict_post_warmup():
    '''Fails'''
    import torch
    from tensordict import TensorDict
    from tensordict.nn import CudaGraphModule

    n_warmup = 2
    device = torch.device("cuda:0")

    def eval_fun(td: TensorDict):
        td["output"] = torch.zeros_like(td["input"])
        return td

    td = TensorDict(device=device)
    td["input"] = torch.zeros(1, device=device)

    eval_fun = torch.compile(eval_fun, fullgraph=True)
    eval_fun_cgm = CudaGraphModule(eval_fun, warmup=n_warmup)

    # Stepping one step further than the pre-warmup test triggers the error.
    for i in range(n_warmup):
        td_out = eval_fun_cgm(td.clone())

def test_cudagraph_compile_tensordict_post_warmup_with_clone():
    '''Passes'''
    import torch
    from tensordict import TensorDict
    from tensordict.nn import CudaGraphModule

    n_warmup = 2
    device = torch.device("cuda:0")

    def eval_fun(td: TensorDict):
        td = td.clone()  # workaround: clone the input inside the function
        td["output"] = torch.zeros_like(td["input"])
        return td

    td = TensorDict(device=device)
    td["input"] = torch.zeros(1, device=device)

    eval_fun = torch.compile(eval_fun, fullgraph=True)
    eval_fun_cgm = CudaGraphModule(eval_fun, warmup=n_warmup)

    # Stepping one step further than the pre-warmup test now passes.
    for i in range(n_warmup):
        td_out = eval_fun_cgm(td.clone())

Another workaround that works is wrapping gae in a TensorDictModule (like update), with next_obs and next_done propagated through the container tensordict.
