TPU performance issues and potential fixes #15884
Comments
@Liyang90 I agree that we should insert the `mark_step()` call. Contributions for the XLA strategy are very welcome btw! There are many improvements we should make; Lightning has fallen a bit behind, and we should do more testing as well to make sure performance does not degrade.
Thanks @Liyang90! This is great. Can you cc Steven and me in this kind of issue? I didn't find a good way to subscribe to labels :D
Yea, a label notifier would be nice. Steven's GitHub is @steventk-g
Thanks @awaelchli, can you also add @Liyang90 to the tpu label notifier?
@Liyang90 Thanks again for your thorough debugging. All your suggestions make sense. Feel free to open PRs to address each of these issues.
FYI, Issue 2 should be resolved now. We recently removed the running loss tracking.
That's great! How does the loss tracking work now?
We no longer compute a running average of the loss, so there is no tracking of any kind. The loss returned from the user's `training_step` gets used for the backward pass and nothing else. So I think this part should be well compatible with XLA's graph.
Re: Issue 3, I don't know why we move the metrics to CPU. It was added in https://github.com/Lightning-AI/lightning/pull/5743/files#diff-51f8c4fefbb7ae230000c1b9a474c8f87086a82cb4b49d9b57a3e77e7cb2ebdfR151. We can try removing https://github.com/Lightning-AI/lightning/blob/97a61868fb5ff605a857b266a9b2a44b2330da71/src/lightning/pytorch/trainer/connectors/logger_connector/result.py#L370-L372 to fix it.
Bug description
Issue 1
Usually, `mark_step()` happens at the beginning of the next iteration, when the `MpDeviceLoader`-wrapped dataloader is iterated. However, PyTorch-Lightning may insert multiple callbacks at the end of a batch iteration, such as progress bar refreshing, logging, metrics tracking, and running loss updating. Users can also add user-defined end-of-batch callbacks. These callbacks could access lazy tensors' values and trigger early evaluations (extra compilations and computations). So, as an easy fix, we can materialize all lazy tensors after the optimizer step with an `xm.mark_step()` call, just before all the callbacks access the tensor values.

On top of the original code:
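A minimal sketch of this first variant, using an illustrative toy model and optimizer rather than Lightning's actual internals:

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(4, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(8, 4, device=device)).sum()
loss.backward()

# barrier=True makes xm.optimizer_step() run a mark_step() right after the
# optimizer update, so lazy tensors are materialized before any end-of-batch
# callbacks (progress bar, logging, metric tracking) read their values.
xm.optimizer_step(optimizer, barrier=True)
```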
Here a `barrier=True` argument is added to the `xm.optimizer_step()` call. This would trigger a `mark_step()` after the optimizer step.

On top of the changes proposed in #15878:
Here an `xm.mark_step()` call is added after the optimizer step.

Issue 2
This patch moves the running loss tracking to the CPU in the case of TPU.

The running loss is tracked in a fixed-length tensor `memory` (size 20 by default). In every iteration, the new loss tensor is inserted at an incrementing index in `memory`: `self.memory[self.current_idx] = x`. If the running loss is tracked on TPU as a lazy tensor, this in-place update would be an `xla::update_slice()` op with a different `base_indices` argument in each iteration, and the inserted loss tensor (`x`) would be a lazy tensor with a huge graph, essentially the graph of the whole forward pass leading to this loss tensor.
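A hypothetical toy loop showing the pattern being described here (the buffer, model, and loop are illustrative, not Lightning's actual running-loss code):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(4, 2).to(device)
memory = torch.zeros(20, device=device)  # fixed-length running-loss buffer as a lazy XLA tensor

for step in range(100):
    with torch.no_grad():
        loss = model(torch.randn(8, 4, device=device)).sum()
    # An in-place write at a different index each step becomes an
    # xla::update_slice op with different base_indices, and it drags the
    # whole forward graph of `loss` into `memory`'s pending graph.
    memory[step % 20] = loss
    xm.mark_step()
```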
During training iterations, the `memory` tensor is somehow not considered a live tensor that needs to be synced and materialized by `mark_step()`, so it is not materialized. Then, when the `memory` value is finally accessed during `teardown()`, all the losses ever inserted into it, and their graphs, are replayed.

(Update: the bug in PT/XLA has been fixed recently. The `memory` tensor can now be included in the graph being materialized when `mark_step()` is called. However, the patch is still necessary, because an `xla::update_slice()` op with a different `base_indices` argument in each iteration would lead to recompilation even though the rest of the graph, the real model training work, is identical. The patch is also necessary for users on torch_xla < 1.13.)

Given the simple purpose of the running loss tensor, we can trade more server-to-host communication for much simpler compilations and computations, by sending the loss tensor to CPU and tracking the running loss on CPU. With the patch for Issue 1, the loss tensor is already materialized at this point, so this would not trigger an early evaluation; it would simply be a server-to-host transfer.
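A minimal sketch of the CPU-side alternative, not the actual patch; the `memory` and `current_idx` names and the window size of 20 follow the description above:

```python
import torch

class CPURunningLoss:
    """Sketch of a fixed-length running-loss buffer kept on the CPU."""

    def __init__(self, window_length: int = 20):
        self.memory = torch.zeros(window_length)  # plain CPU tensor, never a lazy XLA tensor
        self.current_idx = 0
        self.filled = 0

    def append(self, loss: torch.Tensor) -> None:
        # With the Issue 1 patch the loss is already materialized, so .cpu()
        # is only a server-to-host transfer, not an early graph evaluation.
        self.memory[self.current_idx] = loss.detach().cpu()
        self.current_idx = (self.current_idx + 1) % len(self.memory)
        self.filled = min(self.filled + 1, len(self.memory))

    def mean(self) -> torch.Tensor:
        return self.memory[: self.filled].mean()

# Usage inside the training loop:
#   running = CPURunningLoss()
#   running.append(loss)      # loss is the TPU loss tensor
#   display_value = running.mean()
```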
Issue 3
PyTorch-Lightning moves the logged metrics from TPU to CPU according to this line of code. But they are then moved back to TPU (unintended, I assume) several lines below, because `self.device` still points to the XLA device. This leads to additional compilation and transfer from server to host when the metrics are accessed. So the patch below keeps the `_ResultCollection` object on CPU even though the training module is on TPU, to avoid moving the logged metrics back to TPU.

Issue #15743 might be related to this.
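A hypothetical illustration of the round trip described here; `ResultsSketch` is a toy stand-in, not the real `_ResultCollection` API:

```python
import torch

class ResultsSketch:
    """Toy stand-in for a results collection that stores logged metrics."""

    def __init__(self, device: torch.device):
        self.device = device  # XLA device in the buggy case, CPU with the patch
        self.metrics = {}

    def log(self, name: str, value: torch.Tensor) -> None:
        value = value.cpu()                         # metrics get moved to CPU here...
        self.metrics[name] = value.to(self.device)  # ...but this sends them back to self.device

# Keeping the collection's device on CPU avoids moving logged metrics back to
# the XLA device, even though the model itself trains on TPU.
results = ResultsSketch(device=torch.device("cpu"))
results.log("train_loss", torch.tensor(0.5))
```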
How to reproduce the bug
No response
Error messages and logs
Environment
More info
No response