Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: SpectralConv2d, SpectralConvTranspose2d
Introduced in ["Efficient Nonlinear Transforms for Lossy Image Compression"][Balle2018efficient] by Johannes Ballé, PCS 2018. Reparameterizes the weights to be derived from weights stored in the frequency domain. In the original paper, this is referred to as "spectral Adam" or "Sadam" due to its effect on the Adam optimizer update rule. The motivation behind representing the weights in the frequency domain is that optimizer updates/steps may now affect all frequencies to an equal amount. This improves the gradient conditioning, thus leading to faster convergence and increased stability at larger learning rates. For comparison, see the TensorFlow Compression implementations of [`SignalConv2D`] and [`RDFTParameter`]. They seem to use `SignalConv2d` in most of their provided architectures: https://github.com/search?q=repo%3Atensorflow%2Fcompression+Conv2D&type=code Furthermore, since this is a simple invertible transformation on the weights, it is trivial to convert any existing pretrained weights into this form via: ```python self.weight_freq = nn.Parameter(self._to_transform_domain(self.weight)) self.weight = property(self._get_weight) ``` To override `self.weight` as a property, I'm unregistering the module using `del self._parameters["weight"]` as shown in pytorch/pytorch#46886, and also [using the fact][property-descriptor-so] that `@property` [returns a descriptor object][property-descriptor-docs] so that `self.weight` "falls back" to the property. ```python def __init__(self, ...): self.weight_freq = nn.Parameter(self._to_transform_domain(self.weight)) del self._parameters["weight"] # Unregister weight, and fallback to property. @Property def weight(self) -> Tensor: return self._from_transform_domain(self.weight_freq) ``` [Balle2018efficient]: https://arxiv.org/abs/1802.00847 [`SignalConv2D`]: https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/signal_conv.py#L61 [`RDFTParameter`]: https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/parameters.py#L71 [property-descriptor-docs]: https://docs.python.org/3/howto/descriptor.html#properties [property-descriptor-so]: https://stackoverflow.com/a/17330273/365102 [`eval` mode]: https://stackoverflow.com/a/51433411/365102
- Loading branch information