perf: allow @torch.compile by avoiding in-place operation via clone()
`@torch.compile` can speed up model training by 5%–200%. Simply use:

```python
model = torch.compile(model)
```

This commit resolves the error that comes up when compiling `_mask`:

```none
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
```
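
As a minimal sketch (not part of this commit), the same error appears whenever an in-place write targets a leaf tensor that requires grad, even in eager mode:

```python
import torch

y = torch.zeros(4, 4, requires_grad=True)  # a leaf Variable
y[0::2, 0::2] = 0  # RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
```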
YodaEmbedding authored and fracape committed Feb 2, 2024
1 parent 2489952 commit a494099
Showing 1 changed file with 10 additions and 5 deletions.
`compressai/latent_codecs/checkerboard.py`:

```diff
@@ -224,7 +224,7 @@ def _forward_twopass_step(
         # Keep only elements needed for current step.
         # It's not necessary to mask the rest out just yet, but it doesn't hurt.
         params_i = self._keep_only(params_i, step)
-        y_i = self._keep_only(y.clone(), step)
+        y_i = self._keep_only(y, step)
 
         # Determine y_hat for current step, and mask out the other pixels.
         _, means_i = self.latent_codec["y"]._chunk(params_i)
@@ -387,12 +387,17 @@ def _copy(self, dest: Tensor, src: Tensor, step: str) -> None:
         dest[..., 0::2, 1::2] = src[..., 0::2, 1::2]
         dest[..., 1::2, 0::2] = src[..., 1::2, 0::2]
 
-    def _keep_only(self, y: Tensor, step: str) -> Tensor:
+    def _keep_only(self, y: Tensor, step: str, inplace: bool = False) -> Tensor:
         """Keep only pixels in the current step, and zero out the rest."""
-        parity = self.non_anchor_parity if step == "anchor" else self.anchor_parity
-        return self._mask(y, parity)
+        return self._mask(
+            y,
+            parity=self.non_anchor_parity if step == "anchor" else self.anchor_parity,
+            inplace=inplace,
+        )
 
-    def _mask(self, y: Tensor, parity: str) -> Tensor:
+    def _mask(self, y: Tensor, parity: str, inplace: bool = False) -> Tensor:
+        if not inplace:
+            y = y.clone()
         if parity == "even":
             y[..., 0::2, 0::2] = 0
             y[..., 1::2, 1::2] = 0
```
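
A standalone sketch (hypothetical `mask_even` helper, not from the repository) of the clone-before-mask pattern: it runs under `torch.compile` because the in-place writes hit the clone rather than the caller's leaf tensor:

```python
import torch

def mask_even(y: torch.Tensor, inplace: bool = False) -> torch.Tensor:
    # Out-of-place by default: clone first so the caller's tensor
    # (possibly a leaf requiring grad) is never mutated.
    if not inplace:
        y = y.clone()
    y[..., 0::2, 0::2] = 0
    y[..., 1::2, 1::2] = 0
    return y

y = torch.randn(1, 192, 8, 8, requires_grad=True)  # leaf tensor
out = torch.compile(mask_even)(y)  # no RuntimeError; y itself is untouched
```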
