About AffineCoupling and AdditiveCoupling #63
Comments
@xuedue Hi, thanks for using MemCNN. Whereas AdditiveCoupling can use Fm and Gm directly, AffineCoupling additionally needs an adapter that turns the output of Fm/Gm into a scale and a shift term. For example, for your code:

```python
self.invertible_module = memcnn.AffineCoupling(
    Fm=MLP(...),
    Gm=MLP(...),
    adapter=memcnn.AffineAdapterSigmoid
)
```

I hope this solves your issue.
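A minimal, self-contained sketch along these lines (the MLP definition, the 100-feature input size, and the choice of memcnn.AffineAdapterNaive are only illustrative assumptions; the naive adapter is used here because it keeps the sub-network's output shape equal to its input shape, whereas, as far as I recall, AffineAdapterSigmoid expects the sub-network to emit twice the number of channels):

```python
import torch
import torch.nn as nn
import memcnn


class MLP(nn.Module):
    """Small sub-network whose output shape matches its input shape."""
    def __init__(self, features):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(features, features),
            nn.ReLU(),
            nn.Linear(features, features),
        )

    def forward(self, x):
        return self.net(x)


# The coupling chunks its 100-feature input into two halves of 50,
# so Fm and Gm each map 50 features to 50 features.
invertible_module = memcnn.AffineCoupling(
    Fm=MLP(50),
    Gm=MLP(50),
    adapter=memcnn.AffineAdapterNaive,  # derives a (scale, shift) pair from the sub-network output
)

x = torch.rand(8, 100)
y = invertible_module(x)
x_rec = invertible_module.inverse(y)
print(torch.allclose(x, x_rec, atol=1e-4))  # expected: True

# Optional: wrap it with memcnn.InvertibleModuleWrapper to trade compute for memory during training.
wrapper = memcnn.InvertibleModuleWrapper(fn=invertible_module,
                                         keep_input=True, keep_input_inverse=True)
```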
Fm and Gm need to have input x and output y of the same shape. If I want to implement a reversible MLP with different input and output channels, what should I do? For example, the input of the MLP is 100 dimensions, and the output is 2 dimensions.
I assume here that by dimensions you mean the number of channels in the dimension you split on. A simple way of doing so is to make a reversible MLP with the same input and output shape, and then at the end extract only the first 2 channels using slicing:

```python
output = self.invertible_module(x)
output_reduced = output[:, :2, :]  # depends a bit on the shape of your output
```

A more complex way, if you want to keep it fully invertible, is to create a custom coupling that pads and crops the channels (the crop-and-pad strategy discussed further below).
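A small self-contained version of the slicing route (the HalfMLP module, the 100-feature width, and the batch size are placeholder assumptions; for a 2-D (batch, features) tensor the slice only needs two indices):

```python
import torch
import torch.nn as nn
import memcnn


class HalfMLP(nn.Module):
    """Sub-network for one half of the split input (50 -> 50 features)."""
    def __init__(self, features=50):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(features, features), nn.ReLU(),
                                 nn.Linear(features, features))

    def forward(self, x):
        return self.net(x)


# Reversible 100 -> 100 MLP: the additive coupling keeps the shape unchanged.
reversible_mlp = memcnn.AdditiveCoupling(Fm=HalfMLP(), Gm=HalfMLP())

x = torch.rand(8, 100)
output = reversible_mlp(x)        # shape (8, 100); reversible_mlp.inverse(output) recovers x
output_reduced = output[:, :2]    # shape (8, 2); this final slicing step is NOT invertible
print(output_reduced.shape)
```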
Why? Could you elaborate? Doesn't this work for your use case? Alternatively, you could also take the mean over the first 50 channels and over the second 50 channels to reduce your output from (batch_size, 100) to (batch_size, 2), if you're more comfortable with that (very similar to an average pooling operation):

```python
output_100_channels = reversible_mlp_network.forward(input_vector)
print(output_100_channels.shape)  # output shape should be (batch_size, 100)

output_2_channels = torch.cat(
    (torch.mean(output_100_channels[:, :50], dim=1, keepdim=True),
     torch.mean(output_100_channels[:, 50:], dim=1, keepdim=True)),
    dim=1,
)
print(output_2_channels.shape)  # output shape should be (batch_size, 2)
```

Could you maybe tell me a little bit more about which parts of your network you want to make reversible?
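For completeness, a runnable stand-alone version of that reduction (the reversible network is stubbed out with a random tensor here, purely to show the shapes):

```python
import torch

# Stand-in for the (batch_size, 100) output of the reversible MLP.
output_100_channels = torch.rand(8, 100)

# Average the first 50 and the last 50 features separately -> (batch_size, 2).
output_2_channels = torch.cat(
    (output_100_channels[:, :50].mean(dim=1, keepdim=True),
     output_100_channels[:, 50:].mean(dim=1, keepdim=True)),
    dim=1,
)
print(output_2_channels.shape)  # torch.Size([8, 2])
```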
Ok, thanks for clarifying your question. First, I would suggest making layers 1-6 invertible. The final layer (7): does it have to be 1024 output features? Why not increase the previous layers (1-6) to 1024 features, or decrease the out_features to 512? Both of those approaches would be the simplest in your case, since you could then just wrap it like you did with layers 1-6. Otherwise, you could duplicate your output once, but I doubt that is what you would want. Making your first layer (1) invertible as well seems to be the most problematic part. If this is really desirable, maybe consider giving it a matching number of input and output features. An alternative for the first layer would be to try the very experimental crop-and-pad idea mentioned above.
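A rough sketch of what that mixing could look like (all module names, layer counts, and feature sizes here are assumptions about the architecture; only layers with matching input and output features get wrapped, while the shape-changing 9216 -> 1024 layer stays a plain nn.Linear):

```python
import torch
import torch.nn as nn
import memcnn


class HalfBlock(nn.Module):
    """Sub-network for one half of the 1024 features (512 -> 512)."""
    def __init__(self, features=512):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(features, features), nn.ReLU())

    def forward(self, x):
        return self.net(x)


def invertible_block():
    # keep_input=False would free inputs during training for the memory savings;
    # the defaults used here are just the safest for a quick test.
    return memcnn.InvertibleModuleWrapper(
        fn=memcnn.AdditiveCoupling(Fm=HalfBlock(), Gm=HalfBlock()),
        keep_input=True, keep_input_inverse=True,
    )


model = nn.Sequential(
    nn.Linear(9216, 1024),                     # non-invertible layer that changes the shape
    *[invertible_block() for _ in range(6)],   # invertible 1024 -> 1024 blocks
)
model.eval()  # eval mode so the test input below does not need gradients

x = torch.rand(4, 9216)
print(model(x).shape)  # torch.Size([4, 1024])
```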
Thank you very much for your reply. What I want to explain is that I need to maintain the reversibility of the entire network, so that I can reversibly recover the 9216-dimensional input from the 1024-dimensional output. Here, the 9216-dimensional input is fixed and must be included in the reversible network. My most important goal is to recover this 9216-dimensional input through a reversible network. Of course, if I change the 1024-dimensional output to a 9216-dimensional output, the problem is solved, but reducing the output to 1024 dimensions is needed for the follow-up operations. Can the network be reversible if the input must be a 9216-dimensional vector and the output a 1024-dimensional vector? If the input and output must have the same dimension, that would be a fairly large limitation in practical applications.
As far as I know, this can't be made reversible in the general case. That's because you still require all the elements of the final output to reconstruct the input (all 9216 features). So you can't throw anything away from the 9216-dimensional output vector without throwing away some information about how to reconstruct the input. I think this is because of the mathematical nature of the reversible couplings; the outputs also partially encode the inputs, after all. I just tried to work it out using the crop-and-pad strategy I described earlier, but this still won't work, for the above reasons: after applying the padding inside the coupling the output is invertible, but after cropping it would no longer be the case. Of course, there are some non-practical special cases for which it would work, e.g. if the input is a 1024-feature vector duplicated 9 times, or a 1024-feature vector with the other 8192 features all being zero.
It depends on what you want. In the literature (and in practice), people tend to mix invertible operations with normal non-invertible operations for things that change output dimensions and shapes, like pooling. MemCNN was designed with this mixing strategy in mind, so that you can easily make invertible modules memory-efficient using the InvertibleModuleWrapper. There are also some special techniques for changing dimensions without losing information, like the invertible down-sampling operation illustrated in https://arxiv.org/pdf/1802.07088.pdf. You might be able to do something similar by splitting your output of 9216 elements over 9 batches of 1024.
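A toy illustration of that last idea (sizes assumed; the point is only that a reshape regroups features without discarding anything, so it stays invertible, unlike slicing or averaging):

```python
import torch

x = torch.rand(4, 9216)

# Regroup the 9216 features into 9 groups of 1024; no information is lost,
# so the original vector can always be recovered exactly.
x_grouped = x.view(4, 9, 1024)
x_restored = x_grouped.view(4, 9216)
print(torch.equal(x, x_restored))  # True
```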
Thank you very much for your explanation.
Description
I want to implement an invertible MLP. When using AdditiveCoupling, it works,
but when I change to AffineCoupling, it fails.
In addition, I would like to ask: to achieve a reversible MLP, do I need to halve the input and output channels in Fm and Gm, as I wrote above?
Looking forward to your answer, thanks!
What I Did