Slower training with trivial inverse #77

Open
michaelsdr opened this issue Aug 12, 2023 · 2 comments
Comments

michaelsdr commented Aug 12, 2023

  • MemCNN version: 1.5.2
  • PyTorch version: 2.0.1
  • Python version: 3.8.3
  • Operating System: mac

Description

I notice that my training is 2 times slower when using the InvertibleModuleWrapper on an nn.Module fn whose inverse requires no computation.

What I Did

I define

import torch
import torch.nn as nn


class Fwd(nn.Module):
    def __init__(self, function):
        super().__init__()
        self.function = function

    def forward(self, x):
        return x + self.function(x)

    def inverse(self, x):
        return torch.zeros_like(x)

Then I define neural nets of depth d from a list functions of modules using:

        self.nets = [
            InvertibleModuleWrapper(
                Fwd(functions[i]),
                num_bwd_passes=1,
                disable=False,
            )
            for i in range(depth)
        ]
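
For reference, here is a rough sketch of the kind of timing comparison I am making (the layer type, depth, and batch size below are placeholders):

import time

import torch
import torch.nn as nn

from memcnn import InvertibleModuleWrapper

# placeholder sizes, just to make the comparison concrete
depth, dim, batch = 16, 256, 64
functions = [nn.Linear(dim, dim) for _ in range(depth)]

plain = nn.Sequential(*[Fwd(f) for f in functions])
wrapped = nn.Sequential(*[
    InvertibleModuleWrapper(Fwd(f), num_bwd_passes=1, disable=False)
    for f in functions
])

def step_time(model, steps=50):
    # average time per forward + backward pass
    start = time.perf_counter()
    for _ in range(steps):
        x = torch.randn(batch, dim)
        model.zero_grad()
        model(x).sum().backward()
    return (time.perf_counter() - start) / steps

print("plain:  ", step_time(plain))
print("wrapped:", step_time(wrapped))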

The obtained model is 2x slower to train, whereas I would expect a similar running time. Any idea on how to solve this? Many thanks

@silvandeleemput (Owner) commented

Hi, just a small comment on your example code: the inverse of the Fwd module looks wrong, since 0 isn't the inverse of x + f(x). I assume you just used this for testing purposes. In that case I would suggest you use something like the following, which has a proper inverse method:

class MultiplicationInverse(torch.nn.Module):
    def __init__(self, factor=2):
        super(MultiplicationInverse, self).__init__()
        self.factor = torch.nn.Parameter(torch.ones(1) * factor)

    def forward(self, x):
        return x * self.factor

    def inverse(self, y):
        return y / self.factor
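
For completeness, this is roughly how such a module can then be wrapped and checked (the input shape and keep_input settings below are only illustrative):

import torch
import memcnn

module = MultiplicationInverse(factor=2)

# keep_input=True here only so the original input stays around for the comparison;
# use keep_input=False during training to get the memory savings
wrapper = memcnn.InvertibleModuleWrapper(
    fn=module, keep_input=True, keep_input_inverse=True
)

x = torch.randn(4, 10)
y = wrapper.forward(x)
x_rec = wrapper.inverse(y)
print(torch.allclose(x, x_rec, atol=1e-6))  # should print True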

If you want to take arbitrary non-invertible PyTorch modules like Fwd and make them invertible, you should use one of the provided couplings (Additive or Affine) before wrapping them inside the InvertibleModuleWrapper:

import memcnn
import torch

invertible_coupling = memcnn.AdditiveCoupling(
    Fm=Fwd(functions[i]),
    Gm=Fwd(functions[i])
)
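
and then wrap the coupling as usual (a rough sketch; note that the coupling splits its input over the channel dimension, so the number of channels must be even and Fm/Gm each only see half of them):

invertible_block = memcnn.InvertibleModuleWrapper(
    fn=invertible_coupling,
    keep_input=False,  # free the input and recompute it on the backward pass
    num_bwd_passes=1,
)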

Let me know if you still experience such a drop in performance after you change the example. Some performance drop is expected, since extra operations are required to recompute the inputs via the inverse during the backward pass; the overhead should be comparable to using gradient checkpointing.

@silvandeleemput silvandeleemput self-assigned this Sep 14, 2023
@silvandeleemput silvandeleemput added the question Further information is requested label Sep 14, 2023
@michaelsdr (Author) commented

Hi, many thanks for your help. I will let you know if I run into any trouble.
