
Runtime error #14

Open
biaoyanf opened this issue Apr 9, 2024 · 2 comments

Comments

@biaoyanf

biaoyanf commented Apr 9, 2024

Hi,

I ran the training code:

```shell
python train.py --seed 2022 --batch-size 32 \
--num-epoch 3 --devices 0 \
--model-name roberta-large --ckpt-save-path ./ckpt/ \
--data-path ./data/training/ \
--max-samples-per-dataset 500000 --trainin-datasets mnli
```

and encountered this error:

```
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
```

How can we solve this?

Thanks in advance!

@pradeepiisc

Were you able to fix the error? @biaoyanf

@pradeepiisc

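After going through the code and searching online, I believe the error is caused by gradient computation being disabled somewhere during the training step.
In the training step, the code uses the AdamW optimizer from the third-party transformers library, whose `step()` method (in transformers' `optimization.py` module) is decorated with `@torch.no_grad()`. Lightning passes a closure that runs the forward and backward passes to `optimizer.step()`, so that closure also runs with gradients disabled, the loss has no `grad_fn`, and `loss.backward()` fails.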
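The failure can be reproduced without Lightning or transformers. The sketch below uses a hypothetical toy optimizer, `NoGradStepSGD`, that stands in for the pattern described above (a `step()` decorated with `@torch.no_grad()` that calls the closure directly); it is an illustration under that assumption, not transformers' actual code. Inside the closure, grad mode is off, so the loss is built without a graph and `backward()` raises the reported error.

```python
import torch
import torch.nn.functional as F


class NoGradStepSGD(torch.optim.Optimizer):
    """Hypothetical toy optimizer: step() runs under @torch.no_grad() and
    calls the closure without re-enabling grad mode."""

    def __init__(self, params, lr=0.1):
        super().__init__(params, defaults={"lr": lr})

    @torch.no_grad()
    def step(self, closure=None):
        loss = None
        if closure is not None:
            loss = closure()  # the closure inherits the disabled grad mode
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group["lr"])  # plain SGD update
        return loss


torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
x, y = torch.randn(8, 3), torch.randn(8, 1)
opt = NoGradStepSGD(model.parameters())

state = {"grad_mode": None}
caught_message = None


def closure():
    state["grad_mode"] = torch.is_grad_enabled()  # records grad mode inside step()
    opt.zero_grad()
    loss = F.mse_loss(model(x), y)  # computed without a graph: no grad_fn
    loss.backward()  # raises RuntimeError
    return loss


try:
    opt.step(closure)
except RuntimeError as err:
    caught_message = str(err)

print(state["grad_mode"])  # False
print(caught_message)
```

For comparison, PyTorch's built-in optimizers in `torch.optim` call the closure inside `with torch.enable_grad():`, which is why the same training loop does not fail with them.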

Some links to support my arguments:

https://github.com/Lightning-AI/pytorch-lightning/issues/18222
https://github.com/Lightning-AI/pytorch-lightning/issues/18254

The fix is to explicitly re-enable gradient computation with `torch.enable_grad()`, used either as a decorator or as a context manager. Lightning's fix for this:
https://github.com/Lightning-AI/pytorch-lightning/pull/18268/files
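A minimal sketch of both forms, again assuming a hypothetical toy optimizer (`NoGradStepSGD`) whose `step()` calls the closure under `@torch.no_grad()`. Because `torch.enable_grad()` takes precedence over the enclosing `no_grad` for the code it wraps, the loss gets a `grad_fn`, `backward()` succeeds, and the parameters are updated.

```python
import torch
import torch.nn.functional as F


class NoGradStepSGD(torch.optim.Optimizer):
    """Hypothetical toy optimizer: step() runs under @torch.no_grad()
    and calls the closure without re-enabling grad mode."""

    def __init__(self, params, lr=0.1):
        super().__init__(params, defaults={"lr": lr})

    @torch.no_grad()
    def step(self, closure=None):
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is not None:
                    p.add_(p.grad, alpha=-group["lr"])  # plain SGD update
        return loss


torch.manual_seed(0)
model = torch.nn.Linear(3, 1)
x, y = torch.randn(8, 3), torch.randn(8, 1)
opt = NoGradStepSGD(model.parameters())


# Option 1: decorator on the closure, the same approach as the Lightning fix.
@torch.enable_grad()
def closure_decorated():
    opt.zero_grad()
    loss = F.mse_loss(model(x), y)
    loss.backward()
    return loss


# Option 2: context manager around the forward and backward pass.
def closure_ctx():
    with torch.enable_grad():
        opt.zero_grad()
        loss = F.mse_loss(model(x), y)
        loss.backward()
    return loss


weight_before = model.weight.detach().clone()
loss1 = opt.step(closure_decorated)  # no RuntimeError
loss2 = opt.step(closure_ctx)  # no RuntimeError
weight_changed = not torch.equal(weight_before, model.weight.detach())
print(loss1.item(), loss2.item(), weight_changed)
```

The decorator form fits when you control the closure's definition (as in Lightning's loop); the context-manager form is handy when only part of a function should run with gradients enabled.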

Following that Lightning bug-fix PR, I added the decorator to the closure in `optimizer_loop.py` myself instead of upgrading PyTorch/Lightning:

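```python
@torch.enable_grad()  # re-enable autograd even when the optimizer's step() runs under no_grad
def closure(self, *args: Any, **kwargs: Any) -> ClosureResult:
    step_output = self._step_fn()
    ...  # rest of the method unchanged
```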

With this change, the RuntimeError no longer occurs.
