TPU performance issues and potential fixes #15884

Closed
Liyang90 opened this issue Dec 1, 2022 · 10 comments · Fixed by #17572
Labels: accelerator: tpu (Tensor Processing Unit) · bug (Something isn't working)

Comments

@Liyang90
Contributor

Liyang90 commented Dec 1, 2022

Bug description

Issue 1

Usually mark_step() happens at the beginning of the next iteration, when the MpDeviceLoader-wrapped dataloader is iterated. However, PyTorch Lightning may run multiple callbacks at the end of a batch iteration, such as progress-bar refreshing, logging, metrics tracking, and running-loss updating, and users can add their own end-of-batch callbacks as well. These callbacks can access lazy tensors' values and trigger early evaluations (extra compilations and computations). As an easy fix, we can materialize all lazy tensors right after the optimizer step with an xm.mark_step() call, just before the callbacks access the tensor values.
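
For illustration, here is a minimal, self-contained sketch of where the extra mark_step() would go relative to the end-of-batch value accesses (hypothetical loop code, not Lightning's internals; it assumes only that torch and torch_xla are installed):

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

for step in range(3):
    optimizer.zero_grad()
    x = torch.randn(4, 8, device=device)
    loss = model(x).mean()        # lazy tensor, nothing has executed yet
    loss.backward()
    optimizer.step()

    # Materialize the step's graph here, before any end-of-batch callback
    # reads tensor values.
    xm.mark_step()

    # With the mark_step() above this is just a device-to-host transfer;
    # without it, it would force an early evaluation of a partial graph.
    print(f"step {step}: loss={loss.item():.4f}")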

On top of the original code:

diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py
index efa61dd8f..7624a7626 100644
--- a/src/pytorch_lightning/plugins/precision/tpu.py
+++ b/src/pytorch_lightning/plugins/precision/tpu.py
@@ -40,7 +40,7 @@ class TPUPrecisionPlugin(PrecisionPlugin):
        import torch_xla.core.xla_model as xm
        closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
+        closure_result = xm.optimizer_step(optimizer, barrier=True, optimizer_args={"closure": closure, **kwargs})
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if model.automatic_optimization and skipped_backward:

Here a barrier=True argument is added to the xm.optimizer_step() call. This would trigger a mark_step() after the optimizer step.
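
For reference, xm.optimizer_step(optimizer, barrier=True) behaves roughly like the sketch below (simplified; see torch_xla's xla_model.py for the actual implementation):

def optimizer_step_sketch(optimizer, barrier=False, optimizer_args=None):
    import torch_xla.core.xla_model as xm

    optimizer_args = optimizer_args or {}
    xm.reduce_gradients(optimizer)            # all-reduce gradients across replicas
    loss = optimizer.step(**optimizer_args)   # runs the closure if one is passed
    if barrier:
        xm.mark_step()                        # cut and execute the pending graph
    return loss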

On top of the changes proposed in #15878:

diff --git a/src/pytorch_lightning/plugins/precision/tpu.py b/src/pytorch_lightning/plugins/precision/tpu.py
index efa61dd8f..3f1b59059 100644
--- a/src/pytorch_lightning/plugins/precision/tpu.py
+++ b/src/pytorch_lightning/plugins/precision/tpu.py
@@ -29,6 +29,13 @@ class TPUPrecisionPlugin(PrecisionPlugin):
            raise ModuleNotFoundError(str(_XLA_AVAILABLE))
        super().__init__(*args, **kwargs)
+    def _tpu_wrap_closure(self, optimizer, closure: Callable[[], Any]) -> Any:
+        import torch_xla.core.xla_model as xm
+
+        closure_result = closure()
+        xm.reduce_gradients(optimizer)
+        return closure_result
+
    def optimizer_step(  # type: ignore[override]
        self,
        optimizer: Optimizable,
@@ -39,8 +46,10 @@ class TPUPrecisionPlugin(PrecisionPlugin):
    ) -> Any:
        import torch_xla.core.xla_model as xm
+        closure = partial(self._tpu_wrap_closure, optimizer, closure)
        closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
-        closure_result = xm.optimizer_step(optimizer, optimizer_args={"closure": closure, **kwargs})
+        closure_result = optimizer.step(closure=closure, **kwargs)
+        xm.mark_step()
        skipped_backward = closure_result is None
        # in manual optimization, the closure does not return a value
        if model.automatic_optimization and skipped_backward:

Here an xm.mark_step() call is added after the optimizer step, since optimizer.step() is now called directly instead of through xm.optimizer_step() (the gradient reduction is handled by the new _tpu_wrap_closure).
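
Putting the two pieces together, the effective per-step flow after this patch is roughly the following (illustrative sketch, not the literal Lightning code):

import torch_xla.core.xla_model as xm

def patched_optimizer_step(optimizer, closure, **kwargs):
    # closure runs forward + backward; everything is still lazy at this point
    closure_result = closure()
    # all-reduce the gradients across replicas (done in _tpu_wrap_closure)
    xm.reduce_gradients(optimizer)
    # parameter update, then cut and execute the whole step's graph
    optimizer.step(**kwargs)
    xm.mark_step()
    return closure_result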

Issue 2

diff --git a/src/pytorch_lightning/trainer/supporters.py b/src/pytorch_lightning/trainer/supporters.py
index 59856f12e..a9275a486 100644
--- a/src/pytorch_lightning/trainer/supporters.py
+++ b/src/pytorch_lightning/trainer/supporters.py
@@ -72,6 +72,8 @@ class TensorRunningAccum:
    def append(self, x: torch.Tensor) -> None:
        """Add an element to the accumulator."""
+        if x.device.type == "xla":
+            x = x.cpu()
        if self.memory is None:
            # tradeoff memory for speed by keeping the memory on device
            self.memory = torch.zeros(self.window_length, *x.shape, device=x.device, dtype=x.dtype)

This patch moves the running-loss tracking to the CPU when running on TPU.

The running loss is tracked in a fixed-length tensor, memory (size 20 by default). In every iteration, the new loss tensor is inserted at an incrementing index in memory: self.memory[self.current_idx] = x. If the running loss is tracked on the TPU as a lazy tensor, this in-place update becomes an xla::update_slice() op with a different base_indices argument in each iteration, and the inserted loss tensor (x) is a lazy tensor with a huge graph, essentially the graph of the entire forward pass that produced it.

During training iterations, memory is somehow not considered a live tensor that needs to be synced and materialized by mark_step(), so it is never materialized. Then, when its value is finally accessed during teardown(), all the losses ever inserted into it, together with their graphs, are replayed.

(Update: this bug in PT/XLA has been fixed recently, and the memory tensor can now be included in the graph that is materialized when mark_step() is called. However, the patch is still necessary, because an xla::update_slice() op with a different base_indices argument in each iteration leads to recompilation even though the rest of the graph, the real model-training work, is identical. The patch is also needed by users on torch_xla < 1.13.)

Given the simple purpose of the running-loss tensor, we can trade a few more server-to-host transfers for much simpler compilations and computations by sending the loss tensor to the CPU and tracking the running loss there. With the patch for Issue 1, the loss tensor is already materialized at this point, so accessing it does not trigger an early evaluation; it is simply a server-to-host transfer.
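
To make the trade-off concrete, here is a small standalone comparison (hypothetical code, not Lightning's) of on-device tracking versus the CPU tracking that the patch switches to:

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# On-device tracking: memory[idx] = x lowers to xla::update_slice() with a
# different base_indices value each iteration, so every step traces a slightly
# different graph and can trigger a fresh compilation.
memory_xla = torch.zeros(20, device=device)
for idx in range(60):
    loss = torch.randn((), device=device)   # stand-in for the training loss
    memory_xla[idx % 20] = loss
    xm.mark_step()

# CPU tracking (what the patch does): one server-to-host transfer per step,
# and no extra XLA graphs to compile for the running-loss bookkeeping.
memory_cpu = torch.zeros(20)
for idx in range(60):
    loss = torch.randn((), device=device)
    xm.mark_step()                           # loss is already materialized
    memory_cpu[idx % 20] = loss.cpu()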

Issue 3

PyTorch Lightning moves the logged metrics from TPU to CPU according to this line of code, but they are then moved back to TPU (unintentionally, I assume) a few lines below, because self.device still points to the XLA device. This leads to an additional compilation and a server-to-host transfer when the metrics are accessed. The patch below therefore keeps the _ResultCollection object on the CPU even though the training module is on the TPU, so the logged metrics are not moved back to the TPU.
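
The round trip looks roughly like this (hypothetical illustration, not the actual logger-connector code):

import torch
import torch_xla.core.xla_model as xm

xla_device = xm.xla_device()
metrics = {"train_loss": torch.tensor(0.123, device=xla_device)}

# Step 1: the logged metrics are moved to the CPU for logging ...
metrics_cpu = {k: v.cpu() for k, v in metrics.items()}

# Step 2: ... but a later _ResultCollection.to(device=self.device) call moves
# them back, because self.device still points at the XLA device. Reading them
# afterwards costs another compilation plus a server-to-host transfer.
metrics_back = {k: v.to(xla_device) for k, v in metrics_cpu.items()}

# The patch maps any XLA device to "cpu" in _ResultCollection, so step 2
# becomes a no-op and the metrics stay on the host.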

Issue #15743 might be related to this.

diff --git a/src/pytorch_lightning/trainer/connectors/logger_connector/result.py b/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
index d4c74f306..383125a41 100644
--- a/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
+++ b/src/pytorch_lightning/trainer/connectors/logger_connector/result.py
@@ -398,7 +398,11 @@ class _ResultCollection(dict):
    def __init__(self, training: bool, device: Optional[Union[str, torch.device]] = None) -> None:
        super().__init__()
        self.training = training
-        self.device: Optional[Union[str, torch.device]] = device
+        if device:
+            device = torch.device(device)
+            if device.type == "xla":
+                device = torch.device("cpu")
+        self.device: Optional[torch.device] = device
        self.batch: Optional[Any] = None
        self.batch_size: Optional[int] = None
        self.dataloader_idx: Optional[int] = None
@@ -635,10 +639,14 @@ class _ResultCollection(dict):
    def to(self, *args: Any, **kwargs: Any) -> "_ResultCollection":
        """Move all data to the given device."""
-        self.update(apply_to_collection(dict(self), (Tensor, Metric), move_data_to_device, *args, **kwargs))
-
        if "device" in kwargs:
-            self.device = kwargs["device"]
+            device = torch.device(kwargs["device"])
+            if device.type == "xla":
+                kwargs["device"] = "cpu"
+                device = torch.device(kwargs["device"])
+            self.device = device
+
+        self.update(apply_to_collection(dict(self), (Tensor, Metric), move_data_to_device, *args, **kwargs))
        return self
    def cpu(self) -> "_ResultCollection":

How to reproduce the bug

No response

Error messages and logs


No response

Environment


No response

More info

No response

@Liyang90 Liyang90 added the needs triage Waiting to be triaged by maintainers label Dec 1, 2022
@awaelchli awaelchli added bug Something isn't working accelerator: tpu Tensor Processing Unit and removed needs triage Waiting to be triaged by maintainers labels Dec 5, 2022
@awaelchli awaelchli added this to the v1.8.x milestone Dec 5, 2022
@awaelchli
Contributor

@Liyang90 I agree that we should insert the mark_step at the appropriate place. Your suggestion looks good and we can test this.

Contributions to the XLA strategy are very welcome, by the way! There are many improvements we should make; Lightning has fallen a bit behind here, and we should also add more testing to make sure performance does not degrade.

@JackCaoG

JackCaoG commented Dec 6, 2022

Thanks @Liyang90! This is great. Can you cc Steven and me on this kind of issue? I didn't find a good way to subscribe to labels :D

@awaelchli
Contributor

Absolutely will do, thanks @JackCaoG! What is Steven's GitHub handle?

We also have a label notifier #10530. If you want, I can add you there to the label accelerator: tpu.

@JackCaoG

JackCaoG commented Dec 7, 2022

> Absolutely will do, thanks @JackCaoG! What is Steven's GitHub handle?
>
> We also have a label notifier #10530. If you want, I can add you there to the label accelerator: tpu.

Yeah, a label notifier would be nice. Steven's GitHub is @steventk-g.

@JackCaoG

JackCaoG commented Dec 7, 2022

Thanks @awaelchli , can you also add @Liyang90 to the tpu label notifier?

@carmocca
Contributor

@Liyang90 Thanks again for your thorough debugging. All your suggestions make sense. Feel free to open PRs to address each of these issues.

@Liyang90 Liyang90 mentioned this issue Dec 12, 2022
12 tasks
@Borda Borda modified the milestones: v1.8.x, v1.9 Jan 6, 2023
@Borda Borda modified the milestones: v1.9, v1.9.x Jan 16, 2023
@awaelchli
Contributor

FYI, issue 2 should be resolved now. We recently removed TensorRunningAccum and the associated code that moved the tensor to cpu.

@Liyang90
Contributor Author

> We recently removed TensorRunningAccum and the associated code that moved the tensor to cpu.

That's great! How does the loss tracking work now?

@awaelchli
Contributor

We no longer compute a running average of the loss, so there is no tracking of any kind. The loss returned from the user's training_step is used for backward and nothing else, so I think this part should now be compatible with how XLA builds its graphs.
