
Latest accelerate 0.29.0 and 0.29.1 broke the gemma 7b lora tuning example #2629

Closed
3 of 4 tasks
brianchunkang opened this issue Apr 6, 2024 · 10 comments · Fixed by #2634
Comments

@brianchunkang

System Info

- `Accelerate` version: 0.29.1
- Platform: Linux-5.19.0-1030-gcp-x86_64-with-glibc2.2.5
- `accelerate` bash location: /usr/local/bin/accelerate
- Python version: 3.8.18
- Numpy version: 1.24.4
- PyTorch version (GPU?): 2.3.0 (False)
- PyTorch XPU available: False
- PyTorch NPU available: False
- PyTorch MLU available: False
- System RAM: 188.67 GB

The example is here: https://huggingface.co/google/gemma-7b/blob/main/examples/example_fsdp.py

And it works with the following libraries:
pip install transformers==4.38.2 -U
pip install datasets==2.18.0
pip install trl==0.8.1 peft==0.10.0
pip install accelerate==0.28.0

But 0.29.0 and 0.29.1 cause an error in Transformers, and I think it may be because the dataloader type changed: https://github.com/huggingface/transformers/blob/092f1fdaa4224fdd88c616dc9678e6fcb37bfffd/src/transformers/integrations/tpu.py#L24

Traceback is here:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/lib/python3.8/site-packages/trl/trainer/sft_trainer.py", line 360, in train
    output = super().train(*args, **kwargs)
  File "/usr/local/lib/python3.8/site-packages/transformers/trainer.py", line 1624, in train
    return inner_training_loop(
  File "/usr/local/lib/python3.8/site-packages/transformers/trainer.py", line 1655, in _inner_training_loop
    train_dataloader = tpu_spmd_dataloader(train_dataloader)
  File "/usr/local/lib/python3.8/site-packages/transformers/integrations/tpu.py", line 24, in tpu_spmd_dataloader
    assert isinstance(
AssertionError: The dataloader must be a `torch_xla.distributed.parallel_loader.MpDeviceLoader`.
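For what it's worth, here is a minimal probe I would use to confirm the hypothesis (an illustrative sketch only, not the official example script; it assumes a TPU/XLA environment where Accelerate wraps the dataloader):

```python
# Illustrative probe (not the official example): check whether the dataloader
# returned by Accelerate is the MpDeviceLoader that transformers'
# tpu_spmd_dataloader asserts on.
import torch
from torch.utils.data import DataLoader, TensorDataset

import torch_xla.distributed.parallel_loader as pl
from accelerate import Accelerator

accelerator = Accelerator()
loader = DataLoader(TensorDataset(torch.zeros(8, 4)), batch_size=2)
prepared = accelerator.prepare(loader)

print(type(prepared))
# Passes with accelerate==0.28.0 on TPU; appears to fail with 0.29.x,
# which is exactly what the assert above reports.
print(isinstance(prepared, pl.MpDeviceLoader))
```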

Information

  • The official example scripts
  • My own modified scripts

Tasks

  • One of the scripts in the examples/ folder of Accelerate or an officially supported no_trainer script in the examples folder of the transformers repo (such as run_no_trainer_glue.py)
  • My own task or dataset (give details below)

Reproduction

  1. Code snippet is here: https://huggingface.co/google/gemma-7b/blob/main/examples/example_fsdp.py
  2. pip install transformers==4.38.2 -U
    pip install datasets==2.18.0
    pip install trl==0.8.1 peft==0.10.0
    pip install accelerate==0.29.1
  3. Output is the same traceback as shown above.

Expected behavior

The dataloader type is recognized and the assert does not raise an error.

@brianchunkang
Author

@alanwaketan

@zorrofox

zorrofox commented Apr 7, 2024

I also have the same issue...

@muellerzr
Collaborator

Hi all, it would be helpful if you could show me the output of `print(trainer.accelerator.state)`.

@zorrofox

zorrofox commented Apr 7, 2024

The output of `print(trainer.accelerator.state)` is below:

Distributed environment: NO
Num processes: 1
Process index: 0
Local process index: 0
Device: xla:0

But I have set these environment variables in my OS:

export PJRT_DEVICE=TPU
export XLA_USE_SPMD=1
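
As a sanity check that torch_xla itself sees the TPU, I can run something like this (just a probe; the runtime helpers below exist in recent torch_xla releases, but names may differ by version):

```python
# Probe whether torch_xla sees the TPU even though Accelerate reports
# "Distributed environment: NO". Helper names may vary across torch_xla versions.
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

print(xm.xla_device())   # expect an XLA device backed by the TPU
print(xr.device_type())  # expect "TPU" when PJRT_DEVICE=TPU is honored
print(xr.world_size())   # number of participating processes
```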


@muellerzr
Collaborator

Thanks, this is likely from #2576. I'll try to fix it on Monday (unless one of you gets there first!) and then we can release another patch :)

@muellerzr
Collaborator

Hi all, can you try running with `pip install git+https://github.com/huggingface/accelerate@patchfixes`? Reporting back `AcceleratorState()` would be handy too (still working on setting up local GPU-XLA here, so flying blind-ish).
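
Something like this should be enough to report back (a minimal sketch; the state printed here should be the same shared state the Trainer ends up using):

```python
# Minimal sketch: confirm the installed Accelerate version and print the
# (shared) accelerator state after installing the patchfixes branch.
import accelerate
from accelerate import Accelerator

print(accelerate.__version__)
print(Accelerator().state)
```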

@zorrofox

zorrofox commented Apr 9, 2024

@muellerzr thanks a lot for your patch fixes. It works with transformers==4.38.2, datasets==2.18.0, trl==0.8.1, peft==0.10.0, but it is broken with transformers==4.39.3. The error stack is below:

Traceback (most recent call last):
  File "/home/altostrat_com/docker/example_fsdp.py", line 62, in <module>
    trainer.train()
  File "/home/altostrat_com/.local/lib/python3.10/site-packages/trl/trainer/sft_trainer.py", line 360, in train
    output = super().train(*args, **kwargs)
  File "/home/altostrat_com/.local/lib/python3.10/site-packages/transformers/trainer.py", line 1780, in train
    return inner_training_loop(
  File "/home/altostrat_com/.local/lib/python3.10/site-packages/transformers/trainer.py", line 2118, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/home/altostrat_com/.local/lib/python3.10/site-packages/transformers/trainer.py", line 3045, in training_step
    self.accelerator.backward(loss)
  File "/home/altostrat_com/.local/lib/python3.10/site-packages/accelerate/accelerator.py", line 2013, in backward
    loss.backward(**kwargs)
  File "/home/altostrat_com/.local/lib/python3.10/site-packages/torch/_tensor.py", line 534, in backward
    torch.autograd.backward(
  File "/home/altostrat_com/.local/lib/python3.10/site-packages/torch/autograd/__init__.py", line 267, in backward
    _engine_run_backward(
  File "/home/altostrat_com/.local/lib/python3.10/site-packages/torch/autograd/graph.py", line 767, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/altostrat_com/.local/lib/python3.10/site-packages/torch/autograd/function.py", line 300, in apply
    return user_fn(self, *args)
  File "/home/altostrat_com/.local/lib/python3.10/site-packages/torch_xla/utils/checkpoint.py", line 161, in backward
    outputs = ctx.run_function(*detached_inputs)
  File "/home/altostrat_com/.local/lib/python3.10/site-packages/torch_xla/distributed/fsdp/utils.py", line 21, in _xla_checkpointed_forward_no_kwargs
    return m._xla_checkpointed_forward_original(*args, **kwargs)
  File "/home/altostrat_com/.local/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 643, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/altostrat_com/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/home/altostrat_com/.local/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/altostrat_com/.local/lib/python3.10/site-packages/transformers/models/gemma/modeling_gemma.py", line 283, in forward
    attn_weights = attn_weights + causal_mask
RuntimeError: torch_xla/csrc/helpers.cpp:587 : Check failed: dim1 == dim2 || dim1 == 1 || dim2 == 1 || dim1 == xla::Shape::kUnboundedSize || dim2 == xla::Shape::kUnboundedSize 
*** Begin stack trace ***
        tsl::CurrentStackTrace()
        torch_xla::XlaHelpers::GetPromotedShape(xla::Shape const&, xla::Shape const&)
        torch_xla::XlaHelpers::PromoteShapes(xla::XlaOp, xla::XlaOp)
        torch_xla::XlaHelpers::Promote(xla::XlaOp, xla::XlaOp)
        torch_xla::BuildAdd(xla::XlaOp, xla::XlaOp, xla::XlaOp)

        torch_xla::InferOutputShape(absl::lts_20230802::Span<xla::Shape const>, std::function<xla::XlaOp (absl::lts_20230802::Span<xla::XlaOp const>)> const&)


        torch_xla::XlaNode::GetOpShape(std::function<xla::Shape ()> const&) const
        torch_xla::XlaNode::XlaNode(torch::lazy::OpKind, c10::ArrayRef<torch::lazy::Value>, std::function<xla::Shape ()> const&, unsigned long, torch::lazy::hash_t)
        torch_xla::Generic::Generic(torch::lazy::OpKind, c10::ArrayRef<torch::lazy::Value>, std::function<xla::Shape ()> const&, std::function<absl::lts_20230802::InlinedVector<xla::XlaOp, 1ul, std::allocator<xla::XlaOp> > (torch_xla::XlaNode const&, torch_xla::LoweringContext*)>, unsigned long, torch::lazy::hash_t)
        std::shared_ptr<torch::lazy::Node> torch::lazy::MakeNode<torch_xla::Generic, torch::lazy::OpKind, c10::ArrayRef<torch::lazy::Value>&, std::function<xla::Shape ()> const&, std::function<absl::lts_20230802::InlinedVector<xla::XlaOp, 1ul, std::allocator<xla::XlaOp> > (torch_xla::XlaNode const&, torch_xla::LoweringContext*)>, unsigned long&, torch::lazy::hash_t&>(torch::lazy::OpKind&&, c10::ArrayRef<torch::lazy::Value>&, std::function<xla::Shape ()> const&, std::function<absl::lts_20230802::InlinedVector<xla::XlaOp, 1ul, std::allocator<xla::XlaOp> > (torch_xla::XlaNode const&, torch_xla::LoweringContext*)>&&, unsigned long&, torch::lazy::hash_t&)
        torch_xla::Add(torch::lazy::Value const&, torch::lazy::Value const&, torch::lazy::Value const&)
        torch_xla::tensor_methods::add(c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, c10::intrusive_ptr<torch_xla::XLATensor, c10::detail::intrusive_target_default_null_type<torch_xla::XLATensor> > const&, c10::Scalar const&, std::optional<c10::ScalarType>)
        torch_xla::XLANativeFunctions::add(at::Tensor const&, at::Tensor const&, c10::Scalar const&)

        c10::Dispatcher::callBoxed(c10::OperatorHandle const&, std::vector<c10::IValue, std::allocator<c10::IValue> >*) const

        at::_ops::add_Tensor::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::Scalar const&)


        at::_ops::add_Tensor::call(at::Tensor const&, at::Tensor const&, c10::Scalar const&)






        PyNumber_Add
        _PyEval_EvalFrameDefault

        PyObject_Call
        _PyEval_EvalFrameDefault

        PyObject_Call
        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyObject_FastCallDictTstate
        _PyObject_Call_Prepend

        PyObject_Call
        _PyEval_EvalFrameDefault

        PyObject_Call
        _PyEval_EvalFrameDefault

        _PyEval_EvalFrameDefault
        _PyFunction_Vectorcall
        _PyEval_EvalFrameDefault

        torch::autograd::PyNode::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&)

        torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&)
        torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&)
        torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool)
        torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool)



*** End stack trace ***

@muellerzr
Collaborator

@zorrofox can you open an issue on transformers for that, since we've fixed it on the Accelerate side in that case? :)

@alanwaketan

Thanks everyone! @zorrofox Could you cc me on the transformers bug, in case we need to fix something on the torch-xla side?

@zorrofox

@alanwaketan Done. Thanks for your help!
