Fix convergence for dolly+stage3 training #17685

Merged: 6 commits into main on Oct 7, 2023

Conversation


@pengwa pengwa commented Sep 25, 2023

Fix convergence for dolly+stage3 training

In [ZeROOffloadSubscriber](https://github.com/microsoft/onnxruntime/blob/216214b7d302cb504d1e5a647f65b6fe49c22dbb/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py#L359C7-L359C28), we define some PythonOps that take an input and return it in place, for example:
https://github.com/microsoft/onnxruntime/blob/216214b7d302cb504d1e5a647f65b6fe49c22dbb/orttraining/orttraining/python/training/utils/hooks/_zero_offload_subscriber.py#L223C20-L223C20. That is valid in PyTorch, but when ORT runs such a PythonOp it releases the input OrtValue as soon as the op completes, allowing the underlying data to be erased or overwritten. The OrtValue returned by the PythonOp still points at that address, so reading from or writing to it can produce wrong results or even undefined behavior.
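For illustration, here is a minimal PyTorch-only sketch of that aliasing pattern, using a stand-in `PassThrough` function rather than the real `ORTZeROOffloadPreForwardFunction`: the forward simply returns its input, so the output tensor shares the input's storage.

```python
# Minimal sketch of the aliasing hazard, using a stand-in autograd.Function
# (not the actual ZeRO offload subscriber code).
import torch


class PassThrough(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        # Return the input as-is: no new storage is allocated for the output.
        return x

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output


x = torch.randn(4, requires_grad=True)
y = PassThrough.apply(x)

# The output aliases the input buffer. If the runtime that owns the input
# (here ORT, via the input OrtValue) frees or reuses that buffer once the op
# completes, the output is left pointing at reclaimed memory.
assert y.data_ptr() == x.data_ptr()
```

In the failing run below, that stale-buffer access shows up as the loss collapsing to 0.0 right after the first step, until DeepSpeed aborts because its loss scale cannot be decreased any further.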

```
/bert_ort/pengwa/py38/lib/python3.8/site-packages/onnxruntime/training/ortmodule/_custom_autograd_function_runner.py:28: UserWarning: .rank-0: onnxruntime.training.utils.hooks._zero_offload_subscriber.ORTZeROOffloadPreForwardFunction->Backward: ONNX Op attribute 'tensor_reuse_map' doesn't indicate 8-th output is reusing any input, but detected inplace_map indicates it is reusing some input index. A clone will be done before returning to ORT, to align with ORT's NO Buffer reuse plan. Please update inplace_map explicitly to avoid such a copy.
  warnings.warn(f".rank-{get_rank()}: {message}")
  0%|▏                                                                                                                                                                                                                                               | 1/1000 [00:04<1:15:08,  4.51s/it][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,023 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 14.1406, 'learning_rate': 0, 'epoch': 0.0}
  0%|▏                                                                                                                                                                                                                                               | 1/1000 [00:04<1:15:08,  4.51s/it]Invalidate trace cache @ step 5: expected module 6, but got module 7
  0%|▍                                                                                                                                                                                                                                                 | 2/1000 [00:04<31:53,  1.92s/it][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,124 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  0%|▋                                                                                                                                                                                                                                                 | 3/1000 [00:04<18:05,  1.09s/it][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,227 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  0%|▋                                                                                                                                                                                                                                                 | 3/1000 [00:04<18:05,  1.09s/it][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,326 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  0%|█▏                                                                                                                                                                                                                                                | 5/1000 [00:04<08:44,  1.90it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,419 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  0%|█▏                                                                                                                                                                                                                                                | 5/1000 [00:04<08:44,  1.90it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,505 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  1%|█▋                                                                                                                                                                                                                                                | 7/1000 [00:05<05:28,  3.02it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,597 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  1%|█▋                                                                                                                                                                                                                                                | 7/1000 [00:05<05:28,  3.02it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,690 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  1%|██▏                                                                                                                                                                                                                                               | 9/1000 [00:05<03:57,  4.17it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,791 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  1%|██▏                                                                                                                                                                                                                                               | 9/1000 [00:05<03:57,  4.17it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,889 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  1%|██▋                                                                                                                                                                                                                                              | 11/1000 [00:05<03:06,  5.32it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:44,981 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.0}
  1%|██▋                                                                                                                                                                                                                                              | 11/1000 [00:05<03:06,  5.32it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,073 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01}
  1%|███▏                                                                                                                                                                                                                                             | 13/1000 [00:05<02:33,  6.42it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,166 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01}
  1%|███▏                                                                                                                                                                                                                                             | 13/1000 [00:05<02:33,  6.42it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,256 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01}
  2%|███▌                                                                                                                                                                                                                                             | 15/1000 [00:05<02:12,  7.43it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,348 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01}
  2%|███▌                                                                                                                                                                                                                                             | 15/1000 [00:05<02:12,  7.43it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,439 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01}
  2%|████                                                                                                                                                                                                                                             | 17/1000 [00:06<01:59,  8.22it/s][WARNING|trainer_pt_utils.py:849] 2023-09-25 08:30:45,535 >> tried to get lr value before scheduler/optimizer started stepping, returning lr=0
{'loss': 0.0, 'learning_rate': 0, 'epoch': 0.01}
  2%|████                                                                                                                                                                                                                                             | 17/1000 [00:06<01:59,  8.22it/s]Traceback (most recent call last):
  File "examples/onnxruntime/training/language-modeling/run_clm.py", line 600, in <module>
    main()
  File "examples/onnxruntime/training/language-modeling/run_clm.py", line 548, in main
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
  File "/bert_ort/pengwa/optimum/optimum/onnxruntime/trainer.py", line 457, in train
    return inner_training_loop(
  File "/bert_ort/pengwa/optimum/optimum/onnxruntime/trainer.py", line 781, in _inner_training_loop
    self.deepspeed.step()
  File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/engine.py", line 2084, in step
    self._take_model_step(lr_kwargs)
  File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/engine.py", line 1990, in _take_model_step
    self.optimizer.step()
  File "/bert_ort/pengwa/deepspeed/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/stage3.py", line 1854, in step
    if self._overflow_check_and_loss_scale_update():
  File "/bert_ort/pengwa/deepspeed/deepspeed/utils/nvtx.py", line 15, in wrapped_fn
    ret_val = func(*args, **kwargs)
  File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/stage3.py", line 1788, in _overflow_check_and_loss_scale_update
    self._update_scale(self.overflow)
  File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/zero/stage3.py", line 2132, in _update_scale
    self.loss_scaler.update_scale(has_overflow)
  File "/bert_ort/pengwa/deepspeed/deepspeed/runtime/fp16/loss_scaler.py", line 175, in update_scale
    raise Exception(
Exception: Current loss scale already at minimum - cannot decrease scale anymore. Exiting run.
  2%|████                                                                                                                                                                                                                                             | 17/1000 [00:06<06:07,  2.67it/s]
[2023-09-25 08:30:51,075] torch.distributed.elastic.multiprocessing.api: [ERROR] failed (exitcode: 1) local_rank: 0 (pid: 1065120) of binary: /bert_ort/pengwa/py38/bin/python
Traceback (most recent call last):
  File "/bert_ort/pengwa/py38/bin/torchrun", line 8, in <module>
    sys.exit(main())
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 346, in wrapper
    return f(*args, **kwargs)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/run.py", line 806, in main
    run(args)
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/run.py", line 797, in run
    elastic_launch(
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 134, in __call__
    return launch_agent(self._config, self._entrypoint, list(args))
  File "/bert_ort/pengwa/py38/lib/python3.8/site-packages/torch/distributed/launcher/api.py", line 264, in launch_agent
    raise ChildFailedError(
torch.distributed.elastic.multiprocessing.errors.ChildFailedError:
============================================================
examples/onnxruntime/training/language-modeling/run_clm.py FAILED
------------------------------------------------------------
Failures:
  <NO_OTHER_FAILURES>
------------------------------------------------------------
Root Cause (first observed failure):
[0]:
  time      : 2023-09-25_08:30:51
  host      : orttrainingdev10.internal.cloudapp.net
  rank      : 0 (local_rank: 0)
  exitcode  : 1 (pid: 1065120)
  error_file: <N/A>
  traceback : To enable traceback see: https://pytorch.org/docs/stable/elastic/errors.html
============================================================
(/bert_ort/pengwa/py38) [email protected]@orttrainingdev10:/bert_ort/pengwa/optim
```

## The Fix

For outputs that reuse an input buffer without ORT being aware of it, we detect the reuse on the fly (in the first iteration, by comparing the output tensor addresses with the input tensor addresses) and then make an implicit copy before setting the tensor as the PythonOp's output.
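A minimal sketch of that first-iteration detection, using hypothetical names (`detect_and_clone`, `declared_reuse_map`) rather than the actual `_custom_autograd_function_runner` code:

```python
from typing import Dict, List

import torch


def detect_and_clone(
    inputs: List[torch.Tensor],
    outputs: List[torch.Tensor],
    declared_reuse_map: Dict[int, int],
) -> List[torch.Tensor]:
    """Clone any output that aliases an input buffer without the reuse being declared.

    declared_reuse_map maps output index -> input index for reuses ORT already
    knows about; those need no copy.
    """
    input_addresses = {t.data_ptr(): i for i, t in enumerate(inputs)}
    fixed_outputs = []
    for out_idx, out in enumerate(outputs):
        reused_input = input_addresses.get(out.data_ptr())
        if reused_input is not None and declared_reuse_map.get(out_idx) != reused_input:
            # ORT plans no buffer reuse for this output, but it aliases an input
            # whose buffer ORT will release; copy it into fresh storage first.
            out = out.clone()
        fixed_outputs.append(out)
    return fixed_outputs
```

Outputs whose reuse is already declared (via the `tensor_reuse_map` attribute mentioned in the warning above) are returned untouched; only undeclared aliases pay the one-time clone.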

With this fix (left: PyTorch, right: ORT):

![image](https://github.com/microsoft/onnxruntime/assets/10530022/0d72f431-2abd-4e52-af99-19974b85edde)

@pengwa pengwa added the training label (issues related to ONNX Runtime training; typically submitted using template) Sep 25, 2023
@pengwa pengwa requested review from askhade and ajindal1 September 25, 2023 09:06

@ajindal1 ajindal1 left a comment

LGTM

@pengwa pengwa commented Oct 7, 2023

Thank you @ajindal1 !

@pengwa pengwa merged commit 7201def into main Oct 7, 2023
@pengwa pengwa deleted the pengwa/inplace_pythonop branch October 7, 2023 00:40
yf711 added a commit that referenced this pull request Oct 11, 2023
askhade pushed a commit that referenced this pull request Oct 11, 2023
### Support inplace update for PythonOp/Grad

This PR is based on the #17685 branch, to make it easier to review.

With PR #17685, all PythonOp inputs/outputs are by default assumed not to be updated in place. If an inplace update is detected during a run (by checking the output data addresses against all input data addresses), a clone is added before setting the tensor as the PythonOp/Grad's output. The results are correct in that case, but implicit copy overhead is introduced.

This PR allows users to define an output-input reuse map, so ORT knows how buffers are reused and can avoid such unnecessary copies.
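As a rough illustration of the reuse-map idea (the list format and names below are hypothetical, not the actual PythonOp registration API), each entry maps an output index to the input index whose buffer it reuses, with -1 meaning the output owns fresh memory:

```python
# Hypothetical reuse map: output index -> input index it reuses (-1 = no reuse).
reuse_map = [0, -1, 2]

# With the map declared up front, ORT can keep the reused input buffers alive
# for as long as the corresponding outputs live, so the first-iteration address
# check and defensive clone() from #17685 are no longer needed for those outputs.
for out_idx, in_idx in enumerate(reuse_map):
    if in_idx >= 0:
        print(f"output {out_idx} reuses input {in_idx}: no clone required")
    else:
        print(f"output {out_idx} allocates its own buffer")
```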
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024
kleiti pushed a commit to kleiti/onnxruntime that referenced this pull request Mar 22, 2024