Smooth MaxPool2D rule #181

Open
rachtibat opened this issue Apr 5, 2023 · 2 comments
Labels
enhancement New feature or request

Comments

@rachtibat
Contributor

Hey,

we'd like to add a new rule that smooths the MaxPool2D operation by replacing its backward pass with an AveragePool2D backward pass:

import torch
import torch.nn.functional as F

from zennit.core import BasicHook, Stabilizer


class SmoothMaxPool2dRule(BasicHook):

    def __init__(self, epsilon=1e-6, zero_params=None):
        stabilizer_fn = Stabilizer.ensure(epsilon)
        super().__init__(
            # divide the incoming relevance by the (stabilized) output of the surrogate forward pass
            gradient_mapper=(lambda out_grad, outputs: out_grad / stabilizer_fn(outputs[0])),
            # relevance is the input times the resulting gradient
            reducer=(lambda inputs, gradients: inputs[0] * gradients[0]),
        )

    def backward(self, module, grad_input, grad_output):
        '''Backward hook to compute LRP based on the class attributes.'''
        original_input = self.stored_tensors['input'][0].clone()
        inputs, outputs = [], []
        kernel_size = module.kernel_size
        stride = module.stride
        padding = module.padding

        # surrogate forward pass: avg_pool2d with the MaxPool2d hyperparameters
        input = original_input.requires_grad_()
        with torch.autograd.enable_grad():
            output = F.avg_pool2d(input, kernel_size, stride, padding, ceil_mode=False, count_include_pad=True, divisor_override=None)
        inputs.append(input)
        outputs.append(output)

        grad_outputs = self.gradient_mapper(grad_output[0], outputs)
        gradients = torch.autograd.grad(
            outputs,
            inputs,
            grad_outputs=grad_outputs,
            create_graph=grad_output[0].requires_grad,
        )
        relevance = self.reducer(inputs, gradients)
        return tuple(relevance if original.shape == relevance.shape else None for original in grad_input)

You can test the code with

import torch
import torch.nn as nn
import torch.nn.functional as F

from zennit.rules import Norm
from zennit.core import BasicHook, Stabilizer

if __name__ == "__main__":
    
    input = torch.linspace(0, 35, 36).view(1, 1, 6, 6).requires_grad_()

    layer = nn.MaxPool2d(2, 2, 0)
    norm_rule = Norm()
    h = norm_rule.register(layer)

    output = layer(input)
    grad, = torch.autograd.grad(output, input, torch.ones_like(output))
    h.remove()

    print(input)
    print(output)
    print(grad)

    print("###")

    rule = SmoothMaxPool2dRule()
    h = rule.register(layer)

    output = layer(input)
    grad, = torch.autograd.grad(output, input, torch.ones_like(output))
    h.remove()

    print(input)
    print(output)
    print(grad)

Do you think that's fine? I can create a pull request if you want.

Best,
Reduan

@chr5tphr
Owner

chr5tphr commented Apr 5, 2023

Hey Reduan,

thanks for the issue as always!

I think having a way to use the AvgPool2d gradient for MaxPool2d layers is a must-have.
I have some proof-of-concept code, which I implemented back in the day, that directly and explicitly computes the avg-pool gradient with the MaxPool parameters using transposed convolutions.
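
For illustration, a minimal sketch of that idea (not the actual proof-of-concept code; avgpool_gradient is a placeholder name), which spreads each output gradient value evenly over its pooling window via a per-channel transposed convolution with a uniform kernel:

import torch
import torch.nn.functional as F

def avgpool_gradient(grad_output, kernel_size, stride, padding=0):
    # uniform 1/(k*k) kernel, applied per channel (groups=channels), so each
    # output gradient value is distributed equally over its pooling window
    channels = grad_output.shape[1]
    weight = torch.full(
        (channels, 1, kernel_size, kernel_size),
        1.0 / (kernel_size * kernel_size),
        dtype=grad_output.dtype,
        device=grad_output.device,
    )
    return F.conv_transpose2d(grad_output, weight, stride=stride, padding=padding, groups=channels)

Note that this only recovers the exact input shape when the pooling windows tile the input without remainder.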

While going over your code and seeing the BasicHook.backward structure copied, I had the idea that we could also add a layer of abstraction above ParamMod: a ModuleMod or FuncMod, which is a general modifier of the forward function.
This way, one could add very flexible custom rules based on BasicHook, not only limited to the parameters of the module, which would be especially useful for parameter-less modules like MaxPool.
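
Just to illustrate the idea (purely hypothetical, none of these names exist in zennit): such a modifier would map a module to a replacement forward function instead of to modified parameters, e.g.

import torch.nn.functional as F

class FuncMod:
    # hypothetical sketch of a forward-function modifier: instead of modifying
    # parameters (like ParamMod), it builds a replacement forward callable
    def __init__(self, make_forward):
        self.make_forward = make_forward

    def __call__(self, module):
        return self.make_forward(module)

# e.g. use avg_pool2d with MaxPool2d's own hyperparameters as the forward
avgpool_mod = FuncMod(
    lambda module: (lambda x: F.avg_pool2d(x, module.kernel_size, module.stride, module.padding))
)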

I have a different approach to attributing MaxPool in the pipeline, which could also benefit from such an abstraction. Do you maybe know of another use-case for an arbitrary function override? Or maybe @sebastian-lapuschkin ?

If it is only for MaxPool, implementing an explicit rule based on Hook may be better, where we could instead use my existing proof-of-concept code. Although, and I guess that's why you based this on BasicHook rather than Hook, the stabilizer would not automatically be part of the rule, which I think may not be necessary for pooling anyway.

As for the name, maybe it's better to call it something like AvgPoolRule, since for AvgPool this would also be correct, although one could just use the EpsilonRule there.

@rachtibat
Contributor Author

Hey,

thank you for your prompt and thoughtful response as always.
I like the idea of adding a FuncMod.

I asked Sebastian, and he told me that another use-case would be to change the 1x1 CNN downsample layer with stride=2 in ResNets, which also creates such a checkerboard pattern.
See:
[image: checkerboard relevance pattern of the ResNet downsample layer]
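
A quick way to see that checkerboard effect (illustrative only):

import torch
import torch.nn as nn

# a 1x1 convolution with stride 2, as in the ResNet downsample branch, routes
# gradient (and hence relevance) only to every second pixel in each direction
downsample = nn.Conv2d(1, 1, kernel_size=1, stride=2, bias=False)
x = torch.ones(1, 1, 6, 6, requires_grad=True)
y = downsample(x)
grad, = torch.autograd.grad(y, x, torch.ones_like(y))
print(grad[0, 0])  # non-zero only on the stride-2 grid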

The question is whether we should implement it with a FuncMod.

A spontaneous idea that would change the backward-pass function instead (a rough sketch follows the list):

  1. Compute relevance normally. As a result, only every second or fourth pixel gets relevance; the others are zero.
  2. Take only those relevance pixels and write them into a smaller image of 1/4 the size.
  3. Average-upsample the image back to the original size.
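
One possible reading of these three steps as a sketch (the function name and the assumption of a regular stride-2 grid are mine):

import torch.nn.functional as F

def redistribute_relevance(relevance, stride=2):
    # 1./2. keep only the strided positions that actually received relevance,
    # giving a map of 1/stride**2 the original size
    compact = relevance[..., ::stride, ::stride]
    # 3. upsample back to the original size, spreading each value uniformly over
    # its stride x stride block so that the total relevance is conserved
    return F.interpolate(compact, scale_factor=stride, mode='nearest') / (stride * stride)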

With a FuncMod, we could do the following instead (a rough sketch follows the list):

  1. Take the pixels that would be selected by the downsample layer and repeat them 4 times, up to the original input size, overwriting the ignored pixels.
  2. Do a 2x2 downsample with a 4-times-bigger kernel whose values are 1/4 of the original.
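
A sketch of what that replacement forward could look like for a 1x1 stride-2 downsample convolution (hypothetical helper; weight is the conv's (out_channels, in_channels, 1, 1) kernel):

import torch.nn.functional as F

def smooth_downsample(x, weight, stride=2):
    # 1. repeat the pixels the strided conv would select over their
    #    stride x stride blocks, overwriting the ignored pixels
    selected = x[..., ::stride, ::stride]
    repeated = F.interpolate(selected, scale_factor=stride, mode='nearest')
    # 2. convolve with a stride x stride kernel that is the original 1x1 kernel
    #    repeated and scaled by 1/stride**2; the forward output is unchanged
    #    (ignoring bias), but the backward pass now spreads over all pixels
    big_weight = weight.repeat(1, 1, stride, stride) / (stride * stride)
    return F.conv2d(repeated, big_weight, stride=stride)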

Best

chr5tphr added the enhancement label on Aug 11, 2023