feat: SpectralConv2d, SpectralConvTranspose2d
Introduced in
["Efficient Nonlinear Transforms for Lossy Image Compression"][Balle2018efficient]
by Johannes Ballé, PCS 2018.
Reparameterizes the weights to be derived from weights stored in the
frequency domain.
In the original paper, this is referred to as "spectral Adam" or "Sadam"
due to its effect on the Adam optimizer update rule.
The motivation behind representing the weights in the frequency domain
is that optimizer steps then affect all frequencies equally.
This improves gradient conditioning, leading to faster convergence and
increased stability at larger learning rates.
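As a minimal sketch of the idea (tensor names here are illustrative, not part of the actual API): the trainable tensor lives in the frequency domain, the spatial kernel is recovered on the fly, and a gradient step on the frequency-domain tensor therefore touches every spatial tap at once.

```python
import torch

# Illustrative sketch: store the kernel in the real-FFT domain and
# recover the spatial kernel on demand.
kernel_size = (5, 5)
w_spatial = torch.randn(8, 3, *kernel_size)

# Trainable frequency-domain representation (complex-valued).
w_freq = torch.fft.rfftn(
    w_spatial, s=kernel_size, dim=(-2, -1), norm="ortho"
).requires_grad_()

# Recover the spatial kernel for the forward pass; gradients flow back
# through the inverse transform into every frequency bin.
w = torch.fft.irfftn(w_freq, s=kernel_size, dim=(-2, -1), norm="ortho")
w.sum().backward()
print(w_freq.grad.shape)  # torch.Size([8, 3, 5, 3])
```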

For comparison, see the TensorFlow Compression implementations of
[`SignalConv2D`] and [`RDFTParameter`].
They seem to use `SignalConv2D` in most of their provided architectures:
https://github.com/search?q=repo%3Atensorflow%2Fcompression+Conv2D&type=code

Furthermore, since this is a simple invertible transformation on the
weights, it is trivial to convert any existing pretrained weights into
this form via:

```python
self.weight_freq = nn.Parameter(self._to_transform_domain(self.weight))
del self._parameters["weight"]  # self.weight now resolves via the class-level property.
```
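Since `rfftn`/`irfftn` with matching `s` are exact inverses for real inputs, the conversion is lossless up to floating-point precision. A quick standalone check with torch alone (not CompressAI code):

```python
import torch

w = torch.randn(8, 3, 5, 5)
kwargs = dict(s=(5, 5), dim=(-2, -1), norm="ortho")

w_freq = torch.fft.rfftn(w, **kwargs)        # to the transform domain
w_back = torch.fft.irfftn(w_freq, **kwargs)  # and back

print(torch.allclose(w, w_back, atol=1e-6))  # True
```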

To override `self.weight` as a property, I'm unregistering the parameter
using `del self._parameters["weight"]` as shown in
pytorch/pytorch#46886, and also
[using the fact][property-descriptor-so] that `@property`
[returns a descriptor object][property-descriptor-docs]
so that `self.weight` "falls back" to the property.

```python
    def __init__(self, ...):
        self.weight_freq = nn.Parameter(self._to_transform_domain(self.weight))
        del self._parameters["weight"]  # Unregister weight, and fallback to property.

    @property
    def weight(self) -> Tensor:
        return self._from_transform_domain(self.weight_freq)
```
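Putting the pieces together, here is a self-contained toy version (a simplified stand-in for the actual class, so details may differ) showing that the trick behaves as expected. One subtlety: while `weight` is still registered, the property getter raises `AttributeError` (there is no `weight_freq` yet), which makes attribute lookup fall back to `nn.Module.__getattr__` and return the registered parameter.

```python
import torch
import torch.nn as nn

class ToySpectralConv2d(nn.Conv2d):
    """Toy stand-in illustrating the unregister-and-fallback trick."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Here, `self.weight` still resolves to the registered parameter:
        # the property raises AttributeError (no `weight_freq` yet), so
        # lookup falls back to `nn.Module.__getattr__`.
        # `.detach()` gives a fresh leaf tensor for the new parameter.
        self.weight_freq = nn.Parameter(
            torch.fft.rfftn(
                self.weight.detach(), s=self.kernel_size, dim=(-2, -1), norm="ortho"
            )
        )
        del self._parameters["weight"]  # Now `self.weight` hits the property.

    @property
    def weight(self) -> torch.Tensor:
        return torch.fft.irfftn(
            self.weight_freq, s=self.kernel_size, dim=(-2, -1), norm="ortho"
        )

conv = ToySpectralConv2d(3, 8, kernel_size=5)
names = dict(conv.named_parameters())
print("weight" in names, "weight_freq" in names)  # False True
y = conv(torch.randn(1, 3, 16, 16))
print(y.shape)  # torch.Size([1, 8, 12, 12])
```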

[Balle2018efficient]: https://arxiv.org/abs/1802.00847
[`SignalConv2D`]: https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/signal_conv.py#L61
[`RDFTParameter`]: https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/parameters.py#L71
[property-descriptor-docs]: https://docs.python.org/3/howto/descriptor.html#properties
[property-descriptor-so]: https://stackoverflow.com/a/17330273/365102
YodaEmbedding committed Oct 18, 2023
1 parent a4ae2ee commit f77792e
63 changes: 62 additions & 1 deletion compressai/layers/layers.py
@@ -27,7 +27,7 @@
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from typing import Any
from typing import Any, Tuple

import torch
import torch.nn as nn
@@ -43,12 +43,73 @@
"ResidualBlock",
"ResidualBlockUpsample",
"ResidualBlockWithStride",
"SpectralConv2d",
"SpectralConvTranspose2d",
"conv3x3",
"subpel_conv3x3",
"QReLU",
]


class _SpectralConvNdMixin:
    def __init__(self, dim: Tuple[int, ...]):
        self.dim = dim
        self.weight_freq = nn.Parameter(self._to_transform_domain(self.weight))
        del self._parameters["weight"]  # Unregister weight, and fallback to property.

    @property
    def weight(self) -> Tensor:
        return self._from_transform_domain(self.weight_freq)

    def _to_transform_domain(self, x: Tensor) -> Tensor:
        return torch.fft.rfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho")

    def _from_transform_domain(self, x: Tensor) -> Tensor:
        return torch.fft.irfftn(x, s=self.kernel_size, dim=self.dim, norm="ortho")


class SpectralConv2d(nn.Conv2d, _SpectralConvNdMixin):
    r"""Spectral 2D convolution.

    Introduced in [Balle2018efficient].
    Reparameterizes the weights to be derived from weights stored in the
    frequency domain.
    In the original paper, this is referred to as "spectral Adam" or
    "Sadam" due to its effect on the Adam optimizer update rule.
    The motivation behind representing the weights in the frequency
    domain is that optimizer steps then affect all frequencies equally.
    This improves gradient conditioning, leading to faster convergence
    and increased stability at larger learning rates.

    For comparison, see the TensorFlow Compression implementations of
    `SignalConv2D
    <https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/signal_conv.py#L61>`_
    and
    `RDFTParameter
    <https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/parameters.py#L71>`_.

    [Balle2018efficient]: `"Efficient Nonlinear Transforms for Lossy
    Image Compression" <https://arxiv.org/abs/1802.00847>`_,
    by Johannes Ballé, PCS 2018.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        _SpectralConvNdMixin.__init__(self, dim=(-2, -1))


class SpectralConvTranspose2d(nn.ConvTranspose2d, _SpectralConvNdMixin):
    r"""Spectral 2D transposed convolution.

    Transposed version of :class:`SpectralConv2d`.
    """

    def __init__(self, *args: Any, **kwargs: Any):
        super().__init__(*args, **kwargs)
        _SpectralConvNdMixin.__init__(self, dim=(-2, -1))


class MaskedConv2d(nn.Conv2d):
    r"""Masked 2D convolution implementation, mask future "unseen" pixels.
    Useful for building auto-regressive network components.
