feat: SpectralConv2d, SpectralConvTranspose2d #258
Introduced in ["Efficient Nonlinear Transforms for Lossy Image Compression"][Balle2018efficient] by Johannes Ballé, PCS 2018. Reparameterizes the weights to be derived from weights stored in the frequency domain. In the original paper, this is referred to as "spectral Adam" or "Sadam" due to its effect on the Adam optimizer update rule. The motivation behind representing the weights in the frequency domain is that optimizer updates/steps may now affect all frequencies to an equal amount. This improves the gradient conditioning, thus leading to faster convergence and increased stability at larger learning rates. For comparison, see the TensorFlow Compression implementations of [`SignalConv2D`] and [`RDFTParameter`]. They seem to use `SignalConv2d` in most of their provided architectures: https://github.com/search?q=repo%3Atensorflow%2Fcompression+Conv2D&type=code Furthermore, since this is a simple invertible transformation on the weights, it is trivial to convert any existing pretrained weights into this form via: ```python weight_transformed = self._to_transform_domain(weight) ``` To override `self.weight` as a property, I'm unregistering the module using `del self._parameters["weight"]` as shown in pytorch/pytorch#46886, and also [using the fact][property-descriptor-so] that `@property` [returns a descriptor object][property-descriptor-docs] so that `self.weight` "falls back" to the property. ```python def __init__(self, ...): self.weight_transformed = nn.Parameter(self._to_transform_domain(self.weight)) del self._parameters["weight"] # Unregister weight, and fallback to property. @Property def weight(self) -> Tensor: return self._from_transform_domain(self.weight_transformed) ``` [Balle2018efficient]: https://arxiv.org/abs/1802.00847 [`SignalConv2D`]: https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/signal_conv.py#L61 [`RDFTParameter`]: https://github.com/tensorflow/compression/blob/v2.14.0/tensorflow_compression/python/layers/parameters.py#L71 [property-descriptor-docs]: https://docs.python.org/3/howto/descriptor.html#properties [property-descriptor-so]: https://stackoverflow.com/a/17330273/365102 [`eval` mode]: https://stackoverflow.com/a/51433411/365102
### SpectralConv2d vs Conv2d mini-experiments

Below are a couple of example runs (with different randomly initialized kernels) to compare SpectralConv2d vs Conv2d. (Ignore the inaccurate titles; I accidentally set the title for everything to "smoothing kernel".)

Sinusoidal channel averaging:

```python
conv_kwargs = dict(in_channels=3, out_channels=2, kernel_size=11, padding=5)

def init_conv_target(conv):
    k = 0.05 * torch.linspace(0, 2 * torch.pi, conv.kernel_size[0], device=device).sin()
    conv.weight.data[:] = k
```

Simple depthwise cosine "smoothing" kernel:

```python
conv_kwargs = dict(in_channels=8, out_channels=8, kernel_size=11, padding=5)

def init_conv_target(conv):
    k = torch.linspace(0, torch.pi, conv.kernel_size[0], device=device).sin()
    k = k * k[:, None]
    k = k / k.sum()
    idx = torch.arange(conv.in_channels, device=device)
    conv.weight.data[:] = 0
    conv.weight.data[idx, idx, :, :] = k
```

Simple depthwise vertical edge-detector kernel:

```python
conv_kwargs = dict(in_channels=4, out_channels=4, kernel_size=5, padding=5, groups=4)

def init_conv_target(conv):
    K = conv.kernel_size[0]
    k = 2 * torch.linspace(-1, 1, K, device=device) / K**2
    conv.weight.data[:] = k
```

Random initialization:

```python
conv_kwargs = dict(in_channels=4, out_channels=4, kernel_size=5, padding=5)

def init_conv_target(conv):
    pass  # Maintain random initialization.
```

Surprisingly, SpectralConv2d is not worse...!

Figures generated via:

```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from compressai.layers import SpectralConv2d
device = "cuda"

def train(conv, conv_target, max_steps=1000, lr=1e-4, batch_size=16):
    losses = []
    for step in range(max_steps):
        conv.zero_grad()
        x = torch.rand((batch_size, conv.in_channels, 256, 256)).to(device)
        y = conv_target(x).detach()
        y_hat = conv(x)
        loss = ((y - y_hat) ** 2).mean()
        loss.backward()
        for param in conv.parameters():
            param.data -= param.grad * lr
        print(f"step {step:04d} loss {loss.item():.4f}")
        losses.append(loss.item())
    return losses

models = {
    "spectral": SpectralConv2d(**conv_kwargs),
    "regular": nn.Conv2d(**conv_kwargs),
}

# Initialize all models to exact same random initialization.
conv_rand = nn.Conv2d(**conv_kwargs).to(device)
init_weight = conv_rand.weight.data
init_bias = conv_rand.bias.data

for key, conv in models.items():
    conv.to(device)
    if isinstance(conv, SpectralConv2d):
        conv.weight_transformed.data = conv._to_transform_domain(init_weight).clone()
    else:
        conv.weight.data = init_weight.clone()
    conv.bias.data = init_bias.clone()

# Initialize "ideal"/target model kernels using the functions defined above.
conv_target = nn.Conv2d(**conv_kwargs).to(device)
init_conv_target(conv_target)

results = {key: train(conv, conv_target) for key, conv in models.items()}

fig, ax = plt.subplots()
for key, y in results.items():
    ax.plot(y, label=key)
ax.legend()
ax.set(xlabel="step", ylabel="loss", title="Fitting to a target kernel")
fig.savefig("fit_spectralconv2d.png")
```

Random musings:

- I wonder why no one's applied this to other image problems (classification/super-resolution/etc.)? Those problems should also be concerned with "frequency regularized" kernels (so to speak).
- I wonder if this could be turned into a short "paper" with more interesting experiments... Or if not, perhaps at least a "note" (1, 2) on arxiv.
- Also, another related thought that I had once upon a time: generate the kernel from a different basis, e.g. the first … I think I like Balle's insight of using a big DFT and calling it a day, though.

I guess one more thing we could do as an extension to Balle's "Spectral" reparameterization is something like:

```python
def _from_transform_domain(self, w: Tensor) -> Tensor:
    # Attenuate out high-frequencies.
    # Not sure if correctly written, but this is intended to keep the lower
    # frequencies, and push the others to 0.
    # d/dw is then hopefully larger for the lower frequencies, but I should
    # check if I wrote this correctly...
    yy = torch.linspace(1, 0.1, w.shape[-2], device=w.device)[:, None]
    xx = torch.linspace(1, 0.1, w.shape[-1], device=w.device)
    mask = yy * xx
    w = w * mask
    # Balle's original spectral reparametrization.
    return torch.fft.irfftn(w, s=self.kernel_size, dim=self.dim, norm="ortho")
```

Or if high frequencies are also important, do this "regularizing" reparameterization for only some portion of the weights. That way, the network is forced into a balance of both, inhabits a …

TODO: Create a test suite of target kernels like the above, or pulled from trained networks, and see how quickly different …
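As a quick numerical check of the "d/dw" comment above: since the map is linear, the sensitivity of the spatial kernel to each stored frequency bin should track the mask, up to a factor of sqrt(2) for bins duplicated by Hermitian symmetry. A standalone sketch reusing the mask construction from the snippet (shapes and names here are illustrative):

```python
import torch

kernel_size = (5, 5)
H, W = kernel_size[0], kernel_size[1] // 2 + 1  # rfftn half-spectrum shape.
w = torch.randn(H, W, dtype=torch.complex64)

# Same mask construction as above.
yy = torch.linspace(1, 0.1, H)[:, None]
xx = torch.linspace(1, 0.1, W)
mask = yy * xx


def from_transform_domain(w):
    return torch.fft.irfftn(w * mask, s=kernel_size, dim=(-2, -1), norm="ortho")


# Perturb each stored frequency bin and measure how far the spatial kernel moves.
eps = 1e-3
base = from_transform_domain(w)
for i in range(H):
    for j in range(W):
        e = torch.zeros_like(w)
        e[i, j] = eps
        sensitivity = (from_transform_domain(w + e) - base).norm() / eps
        print(f"bin ({i}, {j}): sensitivity {sensitivity:.3f}, mask {mask[i, j]:.3f}")
```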
### SpectralConv2d vs Conv2d experiments

#### First attempt
However, even though the inference/test loss is evidently much smaller, the position of the point on the RD plot is actually less "optimal" w.r.t. the default configuration:
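Concretely, plugging the two compress/decompress points from the contour-plot code below into the RD loss `J = R + lmbda * 255**2 * MSE` at `lmbda = 0.0067` shows the SpectralConv2d point landing on a slightly worse contour despite its higher PSNR:

```python
def psnr_to_mse(psnr, max_value=1.0):
    return 10 ** (-psnr / 10) * max_value**2

# (bpp, psnr-rgb) for the compress/decompress points plotted below.
points = {
    "Conv2d": (0.308173, 29.9231),
    "SpectralConv2d": (0.322696, 30.0040),
}

lmbda = 0.0067
for name, (bpp, psnr) in points.items():
    loss = bpp + lmbda * 255**2 * psnr_to_mse(psnr)
    print(f"{name}: J = {loss:.4f}")

# Conv2d:         J ≈ 0.7516
# SpectralConv2d: J ≈ 0.7580
```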
One possibility to mitigate this might be to initially train with the spectral transform enabled for the first 30 epochs, and then disable it for the remainder of the training (see the sketch below). Perhaps the "inaccurate" reconstruction loss (…)
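Since the reparameterization is invertible, disabling it mid-training could be as simple as converting the layer back. A rough sketch, with a hypothetical `to_regular_conv2d` helper that is not part of this PR:

```python
import torch
import torch.nn as nn


def to_regular_conv2d(spectral_conv) -> nn.Conv2d:
    """Replace a SpectralConv2d with an equivalent plain nn.Conv2d."""
    conv = nn.Conv2d(
        spectral_conv.in_channels,
        spectral_conv.out_channels,
        spectral_conv.kernel_size,
        stride=spectral_conv.stride,
        padding=spectral_conv.padding,
        dilation=spectral_conv.dilation,
        groups=spectral_conv.groups,
        bias=spectral_conv.bias is not None,
    )
    with torch.no_grad():
        # The `weight` property already yields the spatial-domain kernel.
        conv.weight.copy_(spectral_conv.weight)
        if spectral_conv.bias is not None:
            conv.bias.copy_(spectral_conv.bias)
    return conv
```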
Click to see code for generating loss contour plot.

```python
import json

import matplotlib.pyplot as plt
import numpy as np

RESULTS_DIR = "/home/mulhaq/code/research/compressai/master/results"

CODECS = [f"{RESULTS_DIR}/image/kodak/compressai-bmshj2018-factorized_mse_cuda.json"]


def mse_to_psnr(mse, max_value=1.0):
    return -10 * np.log10(mse / max_value**2)


def psnr_to_mse(psnr, max_value=1.0):
    return 10 ** (-psnr / 10) * max_value**2


def main():
    xlim = (0, 2.25)
    ylim = (26, 41)
    x = np.linspace(*xlim, 50)
    y = np.linspace(*ylim, 50)
    R = x
    D = psnr_to_mse(y)

    lmbdas = [0.0067]
    cmaps = ["Greys"]
    # lmbdas = [0.0018, 0.0035, 0.0067, 0.0130, 0.0250, 0.0483, 0.0932, 0.1800]
    # cmaps = ["Reds", "Oranges", "YlOrBr", "Greens", "Blues", "Purples", "Greys", "RdPu"]
    levels = np.logspace(-2, 0.5, 200)

    fig, ax = plt.subplots(figsize=(8, 6))

    # Loss contours.
    for lmbda, cmap in zip(lmbdas, cmaps):
        loss = R + lmbda * 255**2 * D[:, None]
        im = ax.contour(x, y, loss, levels=levels, cmap=cmap)
        cbar = fig.colorbar(im, ax=ax, fraction=0.08, pad=0.01)
        cbar.set_ticks([round(tick, 2) for tick in cbar.ax.get_yticks()])

    # RD curves.
    for codec in CODECS:
        with open(codec, "r") as f:
            data = json.load(f)
        ax.plot(
            data["results"]["bpp"],
            data["results"]["psnr-rgb"],
            ".-",
            label=data["name"],
        )

    # Custom points.
    ax_kwargs = dict(zorder=100, s=8)
    series = [
        dict(
            x=[0.308173],
            y=[29.9231],
            label="Conv2d (compress/decompress)",
            color="C1",
            marker="*",
        ),
        dict(
            x=[0.322696],
            y=[30.0040],
            label="SpectralConv2d (compress/decompress)",
            color="C6",
            marker="*",
        ),
        dict(
            x=[0.307627],
            y=[29.1858],
            label="Conv2d (forward) (no clamp)",
            color="C1",
        ),
        dict(
            x=[0.322088],
            y=[29.3400],
            label="SpectralConv2d (forward) (no clamp)",
            color="C6",
        ),
    ]
    for series_i in series:
        ax.scatter(**series_i, **ax_kwargs)

    # Finalize.
    ax.set(
        xlabel="Bit-rate [bpp]",
        ylabel="PSNR [dB]",
        title="Loss surface",
        xlim=xlim,
        ylim=ylim,
    )
    ax.legend(loc="lower right", fontsize="small")
    fig.savefig("loss_surface.png", dpi=300)
if __name__ == "__main__":
    main()
```

Still, though... a point with lower …

Code wars episodes I-III.

EDIT I (The …): …

EDIT II (The …): …

EDIT III (Revenge of the …): …

EDIT IV (A …):

```python
def forward(...):
    x_hat = self.g_s(y_hat)

def decompress(...):
    x_hat = self.g_s(y_hat).clamp_(0, 1)
```

By …
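Since the clamp appears only in `decompress`, part of the forward-vs-decompress PSNR gap is mechanical: clamping a reconstruction toward [0, 1] can only shrink the per-pixel error when the ground truth itself lies in [0, 1]. A minimal illustration (hypothetical tensors, not from the experiment):

```python
import torch


def psnr(a, b, max_value=1.0):
    mse = ((a - b) ** 2).mean()
    return -10 * torch.log10(mse / max_value**2)


x = torch.rand(1, 3, 64, 64)
x_hat = x + 0.05 * torch.randn_like(x)  # Reconstruction may stray outside [0, 1].

print(psnr(x, x_hat))              # forward(): unclamped.
print(psnr(x, x_hat.clamp(0, 1)))  # decompress(): clamped; never lower PSNR.
```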
EDIT V (The scheduler strikes back): According to the training losses, it looks like the scheduler:

```yaml
net:
  type: "ReduceLROnPlateau"
  mode: "min"
  factor: 0.1
  patience: 10
  threshold: 5e-4  # default is 1e-4
```

The CompressAI Trainer config is using the defaults. It's just my personal config that has the …
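For reference, a sketch of what the scheduler config above corresponds to in plain PyTorch, assuming the keys are forwarded to `torch.optim.lr_scheduler.ReduceLROnPlateau` (the model and optimizer here are stand-ins):

```python
import torch

# Stand-in model/optimizer for illustration.
model = torch.nn.Conv2d(3, 3, 3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# "net" scheduler from the config above: cut the LR by 10x after `patience`
# epochs without sufficient (> threshold, relative) improvement in the loss.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", factor=0.1, patience=10, threshold=5e-4
)

# Typically stepped once per validation epoch:
# scheduler.step(val_rd_loss)
```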
EDIT VI (Return of the multi-stage training): Perhaps best is to initially train with spectral-mode, then auto-schedule switching to regular-mode some time before the first LR drop occurs. That may (or may not?) require changing the training code, however, i.e., a new …

#### Second attempt

Much better. Multi-stage training might still be a good idea, though, since at around …
Merged together with #270.
Introduced in "Efficient Nonlinear Transforms for Lossy Image Compression" by Johannes Ballé, PCS 2018. Reparameterizes the weights to be derived from weights stored in the frequency domain. In the original paper, this is referred to as "spectral Adam" or "Sadam" due to its effect on the Adam optimizer update rule. The motivation behind representing the weights in the frequency domain is that optimizer updates/steps may now affect all frequencies to an equal amount. This improves the gradient conditioning, thus leading to faster convergence and increased stability at larger learning rates.
For comparison, see the TensorFlow Compression implementations of
SignalConv2D
andRDFTParameter
. They seem to useSignalConv2d
in most of their provided architectures:https://github.com/search?q=repo%3Atensorflow%2Fcompression+Conv2D&type=code
Furthermore, since this is a simple invertible transformation on the weights, it is trivial to convert any existing pretrained weights into this form via:
To override
self.weight
as a property, I'm unregistering the module usingdel self._parameters["weight"]
as shown in pytorch/pytorch#46886, and also using the fact that@property
returns a descriptor object so thatself.weight
"falls back" to the property.Checklist:
SpectralConvTransposed2d
is defined correctly... Sincenn.ConvTransposed2d
inherits from the same_ConvNd
asnn.Conv2d
with minor adjustments, I assumed that the weights are treated in the same way.load_state_dict
to be able to load/write checkpoints in a way that is compatible with the original (non-reparametrized) model definitions.torch.float32
instead oftorch.complex64
. Or maybe it's just the FFT which ruins things. It would be nice if there were a way to skip compilation for certain incompatible operations, particularly since we're merely taking a FFT of leaf tensors, not intermediate computations...CompressionModel.update()
. Or even easier: when theSpectralConv2d
module is ineval
mode, just cache the computed weight, and invalidate/clear the cache when the module switches back totrain
mode.SpectralConv2d
on Google, so this should be an OK name. Tensorflow Compression calls their conv layersSignalConv2d
. Perhaps the most "accurate" name would beSpectralReparametrizedConv2d
, but that's a bit wordy.
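For the weight-caching item, something along these lines might work: a hypothetical sketch on top of a `SpectralConv2d` with `weight_transformed` / `_from_transform_domain` as described above, not necessarily what this PR implements:

```python
from typing import Optional

import torch
from torch import Tensor


class CachedSpectralConv2d(SpectralConv2d):
    """SpectralConv2d that caches the computed spatial kernel while in eval mode."""

    _cached_weight: Optional[Tensor] = None

    def train(self, mode: bool = True):
        self._cached_weight = None  # Invalidate the cache on any train/eval switch.
        return super().train(mode)

    @property
    def weight(self) -> Tensor:
        if self.training:
            # Training: always recompute so that gradients flow to weight_transformed.
            return self._from_transform_domain(self.weight_transformed)
        if self._cached_weight is None:
            with torch.no_grad():  # Inference only; no graph needed for the cache.
                self._cached_weight = self._from_transform_domain(self.weight_transformed)
        return self._cached_weight
```

Since `eval()` is just `train(False)`, the cache is cleared on every mode switch and repopulated lazily on the first eval-mode forward pass.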