Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: SpectralConv2d, SpectralConvTranspose2d #258

Closed

Conversation

YodaEmbedding
Copy link
Contributor

@YodaEmbedding YodaEmbedding commented Oct 18, 2023

Introduced in "Efficient Nonlinear Transforms for Lossy Image Compression" by Johannes Ballé, PCS 2018. Reparameterizes the weights to be derived from weights stored in the frequency domain. In the original paper, this is referred to as "spectral Adam" or "Sadam" due to its effect on the Adam optimizer update rule. The motivation behind representing the weights in the frequency domain is that optimizer updates/steps may now affect all frequencies to an equal amount. This improves the gradient conditioning, thus leading to faster convergence and increased stability at larger learning rates.

spectral_adam

spectral_adam_rd_curves

For comparison, see the TensorFlow Compression implementations of SignalConv2D and RDFTParameter. They seem to use SignalConv2d in most of their provided architectures:
https://github.com/search?q=repo%3Atensorflow%2Fcompression+Conv2D&type=code

Furthermore, since this is a simple invertible transformation on the weights, it is trivial to convert any existing pretrained weights into this form via:

weight_transformed = self._to_transform_domain(weight)

To override self.weight as a property, I'm unregistering the module using del self._parameters["weight"] as shown in pytorch/pytorch#46886, and also using the fact that @property returns a descriptor object so that self.weight "falls back" to the property.

    def __init__(self, ...):
        self.weight_transformed = nn.Parameter(self._to_transform_domain(self.weight))
        del self._parameters["weight"]  # Unregister weight, and fallback to property.

    @property
    def weight(self) -> Tensor:
        return self._from_transform_domain(self.weight_transformed)

Checklist:

  • Verify: See if training from scratch actually improves performance or training times by appropriately replacing conv/deconv in existing models.
  • Verify: Check that SpectralConvTransposed2d is defined correctly... Since nn.ConvTransposed2d inherits from the same _ConvNd as nn.Conv2d with minor adjustments, I assumed that the weights are treated in the same way.
  • Ergonomics: If we choose to offer reparametrized versions of already created models, they should override load_state_dict to be able to load/write checkpoints in a way that is compatible with the original (non-reparametrized) model definitions.
  • Performance (medium): See if PyTorch 2.0 Dynamo's model compilation works when we manually use 2x torch.float32 instead of torch.complex64. Or maybe it's just the FFT which ruins things. It would be nice if there were a way to skip compilation for certain incompatible operations, particularly since we're merely taking a FFT of leaf tensors, not intermediate computations...
  • Performance (minor): During inference time, the convolution weights can be precomputed, instead of having to run a FFT on them repeatedly. Perhaps this can be handled inside CompressionModel.update(). Or even easier: when the SpectralConv2d module is in eval mode, just cache the computed weight, and invalidate/clear the cache when the module switches back to train mode.
  • Naming: There's not too many collisions with existing usages of SpectralConv2d on Google, so this should be an OK name. Tensorflow Compression calls their conv layers SignalConv2d. Perhaps the most "accurate" name would be SpectralReparametrizedConv2d, but that's a bit wordy.

@YodaEmbedding YodaEmbedding marked this pull request as draft October 18, 2023 01:19
Introduced in
["Efficient Nonlinear Transforms for Lossy Image Compression"][Balle2018efficient]
by Johannes Ballé, PCS 2018.
Reparameterizes the weights to be derived from weights stored in the
frequency domain.
In the original paper, this is referred to as "spectral Adam" or "Sadam"
due to its effect on the Adam optimizer update rule.
The motivation behind representing the weights in the frequency domain
is that optimizer updates/steps may now affect all frequencies to an
equal amount.
This improves the gradient conditioning, thus leading to faster
convergence and increased stability at larger learning rates.

For comparison, see the TensorFlow Compression implementations of
[`SignalConv2D`] and [`RDFTParameter`].
They seem to use `SignalConv2d` in most of their provided architectures:
https://github.com/search?q=repo%3Atensorflow%2Fcompression+Conv2D&type=code

Furthermore, since this is a simple invertible transformation on the
weights, it is trivial to convert any existing pretrained weights into
this form via:

```python
weight_transformed = self._to_transform_domain(weight)
```

To override `self.weight` as a property, I'm unregistering the module
using `del self._parameters["weight"]` as shown in
pytorch/pytorch#46886, and also
[using the fact][property-descriptor-so] that `@property`
[returns a descriptor object][property-descriptor-docs]
so that `self.weight` "falls back" to the property.

```python
    def __init__(self, ...):
        self.weight_transformed = nn.Parameter(self._to_transform_domain(self.weight))
        del self._parameters["weight"]  # Unregister weight, and fallback to property.

    @Property
    def weight(self) -> Tensor:
        return self._from_transform_domain(self.weight_transformed)
```

[Balle2018efficient]: https://arxiv.org/abs/1802.00847
[`SignalConv2D`]: https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/signal_conv.py#L61
[`RDFTParameter`]: https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/parameters.py#L71
[property-descriptor-docs]: https://docs.python.org/3/howto/descriptor.html#properties
[property-descriptor-so]: https://stackoverflow.com/a/17330273/365102
[`eval` mode]: https://stackoverflow.com/a/51433411/365102
@YodaEmbedding
Copy link
Contributor Author

YodaEmbedding commented Oct 18, 2023

SpectralConv2d vs Conv2d mini-experiments

Below are a couple of example runs (with different randomly initialized kernels) to compare SpectralConv2d vs Conv2d. (Ignore the inaccurate titles; I accidentally set the title for everything to "smoothing kernel".)


sinusoidal channel averaging:

conv_kwargs = dict(in_channels=3, out_channels=2, kernel_size=11, padding=5)

def init_conv_target(conv):
    k = 0.05 * torch.linspace(0, 2 * torch.pi, conv.kernel_size[0], device=device).sin()
    conv.weight.data[:] = k

fit_spectralconv2d

fit_spectralconv2d

fit_spectralconv2d

fit_spectralconv2d


Simple depthwise cosine "smoothing" kernel:

conv_kwargs = dict(in_channels=8, out_channels=8, kernel_size=11, padding=5)

def init_conv_target(conv):
    k = torch.linspace(0, torch.pi, conv.kernel_size[0], device=device).sin()
    k = k * k[:, None]
    k = k / k.sum()
    idx = torch.arange(conv.in_channels, device=device)
    conv.weight.data[:] = 0
    conv.weight.data[idx, idx, :, :] = k

fit_spectralconv2d


Simple depthwise vertical edge-detector kernel:

conv_kwargs = dict(in_channels=4, out_channels=4, kernel_size=5, padding=5, groups=4)

def init_conv_target(conv):
    K = conv.kernel_size[0]
    k = 2 * torch.linspace(-1, 1, K, device=device) / K**2
    conv.weight.data[:] = k

fit_spectralconv2d


Random initialization:

conv_kwargs = dict(in_channels=4, out_channels=4, kernel_size=5, padding=5)

def init_conv_target(conv):
    pass  # Maintain random initialization.

fit_spectralconv2d

Surprisingly, SpectralConv2d is not worse...!


Figures generated via:

import torch
import torch.nn as nn
import matplotlib.pyplot as plt

from compressai.layers import SpectralConv2d

device = "cuda"


def train(conv, conv_target, max_steps=1000, lr=1e-4, batch_size=16):
    losses = []
    for step in range(max_steps):
        conv.zero_grad()
        x = torch.rand((batch_size, conv.in_channels, 256, 256)).to(device)
        y = conv_target(x).detach()
        y_hat = conv(x)
        loss = ((y - y_hat) ** 2).mean()
        loss.backward()
        for param in conv.parameters():
            param.data -= param.grad * lr
        print(f"step {step:04d}  loss {loss.item():.4f}")
        losses.append(loss.item())
    return losses


models = {
    "spectral": SpectralConv2d(**conv_kwargs),
    "regular": nn.Conv2d(**conv_kwargs),
}

# Initialize all models to exact same random initialization.
conv_rand = nn.Conv2d(**conv_kwargs).to(device)
init_weight = conv_rand.weight.data
init_bias = conv_rand.bias.data

for key, conv in models.items():
    conv.to(device)
    if isinstance(conv, SpectralConv2d):
        conv.weight_transformed.data = conv._to_transform_domain(init_weight).clone()
    else:
        conv.weight.data = init_weight.clone()
    conv.bias.data = init_bias.clone()

# Initialize "ideal"/target model kernels using the functions defined above.
conv_target = nn.Conv2d(**conv_kwargs).to(device)
init_conv_target(conv_target)

results = {key: train(conv, conv_target) for key, conv in models.items()}

fig, ax = plt.subplots()
for key, y in results.items():
    ax.plot(y, label=key)
ax.legend()
ax.set(xlabel="step", ylabel="loss", title="Fitting to a target kernel")
fig.savefig("fit_spectralconv2d.png")

Random musings:

I wonder why no one's applied this to other image problems (classification/superresolution/etc)? Those problems should also be concerned with "frequency regularized" kernels (so to speak).

I wonder if this could be turned into a short "paper" with more interesting experiments... Or if not, perhaps at least a "note" (1, 2) on arxiv.

Also, another related thought that I had once upon a time:

Generate the kernel from a different basis, e.g. the first $K=4$ 2D DCT basis elements. Assuming a single input and output channel, y = conv(x, weight=k(w)), where $w = (w_1, \ldots, w_K)$ are trainable weights and $k(w)$ generates the kernel via the weighted summation $k(w) = \sum_i w_i e_i$. ...Then, even large 7x7 or 9x9 kernels can be expressed by only a few parameters. Extend appropriately to multiple input/output channels. I guess this looks very related to "dynamic convolution" / "CondConv", but with 10x fewer parameters instead of 4x more parameters.

I think I like Balle's insight of using a big DFT and calling it a day, though. I guess one more thing we could do as an extension to Balle's "Spectral" reparameterization is something like:

    def _from_transform_domain(self, w: Tensor) -> Tensor:
        # Attenuate out high-frequencies.
        # Not sure if correctly written, but this is intended to keep the lower frequencies, and push the others to 0.
        # d/dw is then [hopefully larger for the lower frequencies, but I should check if I wrote this correctly...].
        yy = torch.linspace(1, 0.1, w.shape[-2], device=w.device)[None, :]
        xx = torch.linspace(1, 0.1, w.shape[-1], device=w.device)
        mask = yy * xx
        w = w * mask

        # Balle's original spectral reparametrization.
        return torch.fft.irfftn(w, s=self.kernel_size, dim=self.dim, norm="ortho")

Or if high frequencies are also important, do this "regularizing" reparameterization for only some portion of the weights. That way, the network is forced into a balance of both, inhabits a lower dimensional space well-conditioned/"regularized" space (c.f. dynamic conv's affine constraint), and is roughly just as expressible as an unparameterized network.

TODO: Create a test suite of target kernels like the above, or pulled from trained networks, and see how quickly different _from_transform_domain functions converge... Is there something that fits our networks better than Balle's unmodified spectral conv?

@YodaEmbedding
Copy link
Contributor Author

YodaEmbedding commented Oct 25, 2023

SpectralConv2d vs Conv2d experiments

First attempt

++model.name="bmshj2018-factorized" ++criterion.lmbda=0.0067 ++scheduler.net.threshold=5e-4

Training forward()-loss curves comparison
Blue: Conv2d. Green: SpectralConv2d.
Validation forward()-loss curves comparison
Blue: Conv2d. Green: SpectralConv2d. No clamp(0, 1) applied to x_hat.
Inference/test forward()-loss curves comparison
Blue: Conv2d. Green: SpectralConv2d. No clamp(0, 1) applied to x_hat.

However, even though the inference/test loss is evidently much smaller, the position of the point on the RD plot is actually less "optimal" w.r.t. the default configuration:

Conv2d SpectralConv2d

One possibility to mitigate this might be to initially train with the spectral transform enabled for the first 30 epochs, and then disable it for the remainder of the training.

Perhaps the "inaccurate" reconstruction loss ($D$) estimation is only a problem when "noise" quantization is used during training. Not sure. Maybe modifying the covariance structure (see Balle's paper, pg 2) also makes it easier for $g_s(y + \mathcal{N}(0, 1))$ to shift the distribution and generate relatively more "optimistic" $\hat{x}$ than rounding would...? Maybe STE might help...

Contour lines of equal loss for $\lambda=0.0067$
Click to see code for generating loss contour plot.
import json

import matplotlib.pyplot as plt
import numpy as np

RESULTS_DIR = "/home/mulhaq/code/research/compressai/master/results"

CODECS = [f"{RESULTS_DIR}/image/kodak/compressai-bmshj2018-factorized_mse_cuda.json"]


def mse_to_psnr(mse, max_value=1.0):
    return -10 * np.log10(mse / max_value**2)


def psnr_to_mse(psnr, max_value=1.0):
    return 10 ** (-psnr / 10) * max_value**2


def main():
    xlim = (0, 2.25)
    ylim = (26, 41)

    x = np.linspace(*xlim, 50)
    y = np.linspace(*ylim, 50)

    R = x
    D = psnr_to_mse(y)

    lmbdas = [0.0067]
    cmaps = ["Greys"]
    # lmbdas = [0.0018, 0.0035, 0.0067, 0.0130, 0.0250, 0.0483, 0.0932, 0.1800]
    # cmaps = ["Reds", "Oranges", "YlOrBr", "Greens", "Blues", "Purples", "Greys", "RdPu"]

    levels = np.logspace(-2, 0.5, 200)

    fig, ax = plt.subplots(figsize=(8, 6))

    for lmbda, cmap in zip(lmbdas, cmaps):
        loss = R + lmbda * 255**2 * D[:, None]
        im = ax.contour(x, y, loss, levels=levels, cmap=cmap)
        cbar = fig.colorbar(im, ax=ax, fraction=0.08, pad=0.01)
        cbar.set_ticks([round(tick, 2) for tick in cbar.ax.get_yticks()])

    # RD curves.
    for codec in CODECS:
        with open(codec, "r") as f:
            data = json.load(f)

        ax.plot(
            data["results"]["bpp"],
            data["results"]["psnr-rgb"],
            ".-",
            label=data["name"],
        )

    # Custom points.
    ax_kwargs = dict(zorder=100, s=8)

    series = [
        dict(
            x=[0.308173],
            y=[29.9231],
            label="Conv2d (compress/decompress)",
            color="C1",
            marker="*",
        ),
        dict(
            x=[0.322696],
            y=[30.0040],
            label="SpectralConv2d (compress/decompress)",
            color="C6",
            marker="*",
        ),
        dict(
            x=[0.307627],
            y=[29.1858],
            label="Conv2d (forward) (no clamp)",
            color="C1",
        ),
        dict(
            x=[0.322088],
            y=[29.3400],
            label="SpectralConv2d (forward) (no clamp)",
            color="C6",
        ),
    ]

    for series_i in series:
        ax.scatter(**series_i, **ax_kwargs)

    # Finalize.
    ax.set(
        xlabel="Bit-rate [bpp]",
        ylabel="PSNR [dB]",
        title="Loss surface",
        xlim=xlim,
        ylim=ylim,
    )

    ax.legend(loc="lower right", fontsize="small")

    fig.savefig("loss_surface.png", dpi=300)


if __name__ == "__main__":
    main()

Still, though... a point with lower $L$ actually has worse $(R,D)$ relative to the optimal achievable $(R, D)$ for a given model architecture?!

Code wars episodes I-III.

EDIT I (The eval menace): The plotted PSNR was measured using compress/decompress, whereas the loss was measured using forward's mse_loss, which is somehow much worse. What may be happening is that the loss shown in the first figure was probably being measured on the noise-quantized y_hat, rather than round-quantized y_hat. But doesn't eval mode disable noise quantization?

EDIT II (The .clone() wars): Eval mode is certainly set, according to https://github.com/catalyst-team/catalyst/blob/v22.04/catalyst/core/runner.py#L312 which calls model.train(mode=False) (which recursively updates all submodules to self.training=False). It is correctly using "dequantize" for valid/infer rather than "noise", so still I wonder what could be causing the worse mse_loss...

EDIT III (Revenge of the x_hat): For forward and compress/decompress, both y and y_hat are exactly the same, but x_hat isn't. Thus, the problem only occurs during g_s(y_hat).

EDIT IV (A new hope): Found it! There's no .clamp_(0, 1) in the forward, which makes sense during training.

def forward(...):
        x_hat = self.g_s(y_hat)

def decompress(...):
        x_hat = self.g_s(y_hat).clamp_(0, 1)

By clamp_ing the x_hat when not in training mode, both methods now give the exact same results. It's quite surprising that simple clamping causes such a big gain in PSNR. (Between 0.66 dB to 0.74 dB.) It's also weird that the SpectralConv2d-trained model only gets a smaller 0.66 dB jump but the Conv2d somehow enjoys 0.74 dB. But maybe that's just luck.

EDIT V (The scheduler strikes back): According to the training losses, it looks like the SpectralConv2d-trained model dropped LR earlier, whereas the Conv2d-trained model happily went on for an additional 100 epochs before the first LR drop! Here were the settings I used:

scheduler:
  net:
    type: "ReduceLROnPlateau"
    mode: "min"
    factor: 0.1
    patience: 10
    threshold: 5e-4  # default is 1e-4

The CompressAI Trainer config is using the defaults. It's just my personal config that has the 5e-4. Upon further reflection, it looks like that reduction in loss was slowing down anyways, so maybe scheduler.net.threshold=5e-4 is probably not that bad an idea. I'll give 1e-4 a try anyways.

EDIT VI (Return of the multi-stage training): Perhaps best is to initially train with spectral-mode, then auto-schedule switching to regular-mode some time before the first LR drop occurs.

That may (or may not?) require changing the training code, however, i.e., a new Runner? I wonder if PyTorch has schedulers that can work with regular modules, and not just optimizers. Otherwise, I guess I could create a pseudo-optimizer that doesn't actually optimize anything, but merely switches conv.enable_transform=False for all the convs when a pseudo-LR drop occurs... That's a convoluted way of avoiding adding a new runner. Maybe a callback is less convoluted. But then I have to write my own pseudo-LR drop logic...?!


Second attempt

++model.name="bmshj2018-factorized" ++criterion.lmbda=0.0067 (with default scheduler.net.threshold=1e-4)

Inference/test forward()-loss curves comparison
Green: Conv2d. Blue: SpectralConv2d. No clamp(0, 1) applied to x_hat.
Conv2d SpectralConv2d

Much better.

Multi-stage training might still be a good idea, though, since at around loss=0.655 (i.e. 40 epochs for SpectralConv2d) it looks like Conv2d starts converging more quickly than SpectralConv2d.

@fracape
Copy link
Collaborator

fracape commented Feb 28, 2024

merged together with #270

@fracape fracape closed this Feb 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants