I notice that my training is about 2× slower when using the InvertibleModuleWrapper on an nn.Module whose inverse requires no computation.
What I Did
I define
import torch
import torch.nn as nn

class Fwd(nn.Module):
    def __init__(self, function):
        super().__init__()
        self.function = function

    def forward(self, x):
        return x + self.function(x)

    def inverse(self, x):
        return torch.zeros_like(x)
Then I define a neural net of depth d from a list `functions` of modules:
self.nets = [
    InvertibleModuleWrapper(
        Fwd(functions[i]),
        num_bwd_passes=1,
        disable=False,
    )
    for i in range(depth)
]
The resulting model is about 2× slower to train, whereas I would expect a similar running time. Any idea how to solve this? Many thanks.
Hi, just a small comment on your example code: the inverse of the Fwd module looks wrong, since returning zeros isn't the inverse of x + f(x). I assume you used this only for testing purposes. In that case, I would suggest using something like the following, which has a proper inverse method:
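(The code example appears to be missing from this thread; as a minimal sketch of what such a module could look like, here is an illustrative affine map whose inverse method exactly undoes its forward. The class name and constants are my own, not from the thread.)

```python
import torch
import torch.nn as nn

class InvertibleScale(nn.Module):
    """Invertible affine map y = a * x + b with an exact inverse."""
    def __init__(self, a=2.0, b=1.0):
        super().__init__()
        self.a = a
        self.b = b

    def forward(self, x):
        return self.a * x + self.b

    def inverse(self, y):
        # Exactly undoes forward: x = (y - b) / a
        return (y - self.b) / self.a
```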
If you want to take arbitrary non-invertible PyTorch functions like the Fwd module and make them invertible, you should use one of the provided couplings (Additive or Sigmoid) before wrapping them inside the InvertibleModuleWrapper:
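(To illustrate the idea behind such couplings, here is a hand-rolled sketch of an additive coupling, not memcnn's actual implementation: the input channels are split in two halves so that an arbitrary, non-invertible module fm can be undone exactly by subtraction.)

```python
import torch
import torch.nn as nn

class AdditiveCouplingSketch(nn.Module):
    """Additive coupling: wraps an arbitrary module fm so that the
    whole block is exactly invertible, regardless of what fm computes."""
    def __init__(self, fm):
        super().__init__()
        self.fm = fm

    def forward(self, x):
        # Split channels; only one half is transformed.
        x1, x2 = torch.chunk(x, 2, dim=1)
        y1 = x1
        y2 = x2 + self.fm(x1)
        return torch.cat([y1, y2], dim=1)

    def inverse(self, y):
        # Exact inverse: subtract the same fm output from the second half.
        y1, y2 = torch.chunk(y, 2, dim=1)
        x1 = y1
        x2 = y2 - self.fm(y1)
        return torch.cat([x1, x2], dim=1)
```

A block like this (or memcnn's provided coupling classes) is what you would hand to InvertibleModuleWrapper instead of the Fwd module, so the wrapper can reconstruct inputs during the backward pass without storing them.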
Let me know if you still experience such a drop in performance after you change the example. A bit of a performance drop is expected, since extra operations are required to recompute the inverses; this overhead should be comparable to using checkpointing.