Issue about generated images #6
Unfortunately, I could not even train a model yet, because my compute resources are allocated elsewhere at the moment. Did you use your own codebase? If so, did you publish it somewhere? |
@wzmsltw I'm also building a custom model inspired by this paper on the CelebA dataset, and I found something similar happens. I think in my case it's still early in the training, and I get accuracies of around ~40%, which makes sense since during training the masking follows a cosine schedule (as the paper says) and the AUC of the cosine function is around 0.363, so 40% means it works slightly better than just guessing the unmasked tokens. What I find strange is what happens during the sampling. As the sampling process continues, it would seem like the model is starting to generate something that looks like a face :D BUT at some point, the sampling makes the face start to fade away... And we end up with an almost empty image in most cases :/ It looks like as I start to decrease the temperature during sampling, the samples start to collapse to an empty background or idk... Maybe something similar is happening to you... I don't know if this will be fixed with more training, or if it's a problem in the sampling procedure... wdyt? |
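For reference, the cosine masking schedule mentioned above could be sketched roughly like this in PyTorch (an untested sketch, not anyone's actual training code; it assumes a (batch, length) tensor of codebook indices and a dedicated mask_token_id):
import math
import torch

def cosine_mask_ratio(batch_size):
    # MaskGIT-style schedule: gamma(r) = cos(pi * r / 2) with r ~ U(0, 1),
    # so most training examples end up heavily masked.
    r = torch.rand(batch_size)
    return torch.cos(math.pi * r / 2)

def apply_random_mask(tokens, mask_token_id):
    # tokens: (B, N) codebook indices; mask a cosine-scheduled fraction of each row.
    b, n = tokens.shape
    num_mask = (cosine_mask_ratio(b) * n).ceil().clamp(min=1).long()  # tokens to mask per row
    scores = torch.rand(b, n)
    cutoff = scores.sort(dim=-1).values.gather(1, num_mask.unsqueeze(1) - 1)
    mask = scores <= cutoff  # the num_mask positions with the lowest random scores
    return torch.where(mask, torch.full_like(tokens, mask_token_id), tokens), mask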
Unfortunately no. I have the feeling that the issue is with the sampling method, but it might also be related to how the tokens are masked during training :/ |
May I ask if you used my implementation or your own? If you used mine then it might also be something implementation specific. However if we both get the same white images with different implementations, then yes there might be something wrong with the overall sampling strategy. |
I used my own implementation |
I still cannot reproduce the results in the paper after 300 epochs of training on ImageNet. However, I find that temperature annealing is key to the performance of MaskGIT and the diversity of generated samples. When I did not use temperature annealing, I got an FID of 32.87. However, when I used temperature annealing, which linearly decreases the temperature of the logits from 3.0 to 1.0, I got 20.26. When I train MaskGIT on FFHQ, very simple (but high-quality) images are generated. However, due to the simplicity of the generated images, the recall of the trained model is very low and the FID is over 100. I think many tricks are required to train MaskGIT, but the details are not described in the paper. In particular, temperature annealing is a very important trick for decreasing FID, but the authors did not describe the details.. How can I believe the scores in the original paper... |
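For what it's worth, the linear annealing described here could look something like this (a sketch under the assumption that the 3.0 → 1.0 decay above is applied per sampling step; the helper name is made up):
def linear_temperature(step, total_steps, t_start=3.0, t_end=1.0):
    # Linearly anneal the logit temperature from t_start to t_end over the sampling steps.
    ratio = step / max(total_steps - 1, 1)
    return t_start + (t_end - t_start) * ratio

# e.g. with 8 decoding steps: 3.0, 2.71, 2.43, ..., 1.0
temps = [linear_temperature(s, 8) for s in range(8)]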
The samples I shared used temperature annealing as well, but I still don't get very good results. |
@pabloppp @LeeDoYup Maybe if you are interested we can make a group on discord and report new findings. I would also be interested in your Transformer implementation. I guess mine is so simplistic. So if you are up to it you can add me on discord: dome#8231 Also the authors are referencing BEiT which uses a slightly different way of training. Even though the authors clearly described their way of training, maybe using the approach from BEiT could result in improvements. Have you tried anything like this? |
I have not tried anything like BEiT, in fact, my architecture is pretty different from the one proposed in the paper. What I tried to follow as close as possible were the losses, training schedules & sampling schedules. |
I do not have a Discord account, so I will try to create one soon. |
It shouldn't be relevant, but my model is conditional, so I add an identity embedding to the input. My goal was to be able to control the generation to some degree, so I can ask the model to generate a specific face instead of just a random one, plus it helps the model, since conditional generation is usually way easier for generative models. The model seems to be able to use that information up to a point, like generating male/female faces depending on the reference image, but it does a lot of random generation as well. |
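Just to illustrate the kind of conditioning described above, something like this would add an identity embedding to every token position (a hypothetical sketch, not @pabloppp's actual model; module and argument names are made up):
import torch
import torch.nn as nn

class ConditionalTokenEmbedding(nn.Module):
    # Token embedding plus an identity/condition embedding added to every position.
    def __init__(self, vocab_size, num_identities, dim):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, dim)
        self.id_emb = nn.Embedding(num_identities, dim)

    def forward(self, tokens, identity):
        # tokens: (B, N) codebook indices, identity: (B,) condition ids
        return self.tok_emb(tokens) + self.id_emb(identity).unsqueeze(1)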
Hello! Small update: I just tried adding typical filtering to the sampling code, and the results are still far from perfect, but I managed to go from a very high % of just plain colored images to a considerable % of face-ish results :D Here's an example without typical filtering: And here a couple of examples with typical filtering (with a mass of 0.2): For the filtering I just adapted the code from the official repo: Although the 'typical filter' is designed to follow some rules about how language works, allowing the model to pick from a large number of options when the expected information is high while reducing the pool of options when the expected information is low, it seems to also benefit image generation. I think it might even be related to the non-sequential nature of the sampling: at the beginning, when sampling the first pixels, the expected information is pretty high, so the model can pick from a wider variety of options, while as the image starts taking shape the options are reduced, since we already have a general sketch of the image... Or idk, it might be something completely different XD Anyway, hopefully this is useful for someone, and maybe we could even reach out to @cimeister to ask if they thought of this for image generation 🤔 |
@pabloppp Oh, did you use the typical sampling in the multinomial sampling process to predict the code of each position? I think it would help increase diversity, since typical sampling is known to resolve the de-generation problem in NLG. |
Yes, basically before calling multinomial sampling I do what the TypicalLogitsWarper function does to set the logits of the filtered tokens to -inf so the multinomial only samples from the filtered pool. I also keep the temperature decay and the sampling schedule for the number of tokens sampled each step untouched. |
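If you'd rather reuse the library class than re-implement it, something like this should be roughly equivalent (an untested sketch; as far as I know transformers' TypicalLogitsWarper only looks at the scores argument, so the input_ids passed here are just a dummy placeholder):
import torch
from transformers import TypicalLogitsWarper

def typical_then_multinomial(logits, temperature=1.0, mass=0.2):
    # logits: (num_positions, vocab_size), one row per masked position.
    warper = TypicalLogitsWarper(mass=mass)
    scaled = logits / temperature
    dummy_ids = torch.zeros(scaled.size(0), 1, dtype=torch.long)  # ignored by the warper
    filtered = warper(dummy_ids, scaled)  # filtered tokens are set to -inf
    probs = filtered.softmax(dim=-1)
    return torch.multinomial(probs, num_samples=1)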
Wow very cool! Thanks for sharing, @pabloppp. We hadn't tried typical sampling yet for image generation but it seems like a promising direction! |
The paper describes it as follows.
Here, I am confused about whether they use temperature annealing (TA) to randomly select the masking positions, or use TA in the multinomial sampling at each position. |
@LeeDoYup I'm pretty sure the temperature is applied to the logits before the softmax, thus affecting the multinomial sampling 🤔 but the things you mention are correlated: you change the temperature, so the probability of sampling some tokens varies, then you apply the multinomial and keep only a number of tokens based on their score, following the cosine schedule. |
@pabloppp When I think logically about the mask selection based on the algorithm, I also agree that the TA is used before the softmax of the logits. However, when I only read the sentence above, it says that they add randomness to the "mask selection", not to the token sampling. So I am very confused. When I use the random masking strategy, the performance on ImageNet is much improved. For example, when |
@LeeDoYup can you show some pseudo-code example for both of the cases you describe above? |
@dome272 I use the random masking strategy as follows:
That is, I randomly select the positions of unmasked tokens. |
In some way, it makes sense to choose them randomly, since during training you're masking them randomly, so the model is used to having to reconstruct random missing tokens, not necessarily starting with the highest scores and ending with the ones with the lowest score. But that said, in the paper they very explicitly say that they sample the highest-scored tokens.
I'm pretty sure when they say […] Anyway, if you've found that sampling randomly instead of taking the highest scores helps, it's worth a try 🙇 BTW, what are you doing exactly to get recall & precision from random sampling? 🤔 What do you compare your output to in order to get metrics? |
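For concreteness, here is a generic sketch of the two selection strategies being discussed (not anyone's actual code; sample_scores is assumed to be a (B, N) float tensor holding the score of the token sampled at each position, and the returned boolean mask marks the positions kept unmasked at this step):
import torch

def select_by_confidence(sample_scores, num_keep):
    # Keep the num_keep positions whose sampled token has the highest score (what the paper states).
    idx = sample_scores.topk(num_keep, dim=-1).indices
    keep = torch.zeros_like(sample_scores, dtype=torch.bool)
    return keep.scatter(-1, idx, torch.ones_like(idx, dtype=torch.bool))

def select_randomly(sample_scores, num_keep):
    # Keep num_keep positions chosen uniformly at random (the strategy described above).
    idx = torch.rand_like(sample_scores).topk(num_keep, dim=-1).indices
    keep = torch.zeros_like(sample_scores, dtype=torch.bool)
    return keep.scatter(-1, idx, torch.ones_like(idx, dtype=torch.bool))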
@pabloppp I totally agree with you. However, the most problematic fact is that the hyper-parameters of TA are not described in the paper, so it is hard to reproduce the results. When I evaluate the recall & precision on ImageNet, I generate 50K samples and use the protocol in this repository. |
@pabloppp
logits /= temperature
filtered_logits = transformers.top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
probs = torch.nn.functional.softmax(filtered_logits, dim=-1)
sample = torch.multinomial(probs, 1)
That's the normal top-k/top-p sampling. At which point do you call the TypicalLogitsWarper and with which input? |
I don't think you're doing it right. You're supposed to first sample for every masked token and then pick the top-K with the highest scores, since otherwise you don't really know the score of the token you sampled. I don't use the TypicalLogitsWarper class directly, but I use the same implementation. This is the core of my sampling implementation:
logits, _ = self(x, c, mask)
probs = logits.div(temp)  # temperature-scaled logits, shape (B, vocab, H, W)
probs_flat = probs.permute(0, 2, 3, 1).reshape(-1, probs.size(1))  # (B*H*W, vocab)
if typical_filtering:
    # Typical filtering (adapted from TypicalLogitsWarper): keep the tokens whose
    # information content is closest to the distribution's entropy, up to typical_mass.
    probs_flat_norm = torch.nn.functional.log_softmax(probs_flat, dim=-1)
    probs_flat_norm_p = torch.exp(probs_flat_norm)
    entropy = -(probs_flat_norm * probs_flat_norm_p).nansum(-1, keepdim=True)
    probs_flat_shifted = torch.abs((-probs_flat_norm) - entropy)
    probs_flat_sorted, probs_flat_indices = torch.sort(probs_flat_shifted, descending=False)
    probs_flat_cumsum = probs_flat.gather(-1, probs_flat_indices).softmax(dim=-1).cumsum(dim=-1)
    last_ind = (probs_flat_cumsum < typical_mass).sum(dim=-1)
    sorted_indices_to_remove = probs_flat_sorted > probs_flat_sorted.gather(1, last_ind.view(-1, 1))
    if typical_min_tokens > 1:
        sorted_indices_to_remove[..., :typical_min_tokens] = 0
    indices_to_remove = sorted_indices_to_remove.scatter(1, probs_flat_indices, sorted_indices_to_remove)
    probs_flat = probs_flat.masked_fill(indices_to_remove, -float("Inf"))  # drop filtered tokens
probs_flat = probs_flat.softmax(dim=-1)  # filtered tokens end up with probability 0
sample_indices = torch.multinomial(probs_flat, num_samples=1)  # one sample per position
sample_scores = torch.gather(probs_flat, 1, sample_indices)  # confidence of each sampled token
hey guys, today I reached out to the authors to ask if they would help us with our problem, and the first author replied and said that they are planning to release the code next week (or so). |
Guess the official repo is out: https://github.com/google-research/maskgit (although it seems to be in JAX) |
Yea please report if anyone finds the cliffhanger.... |
Small finding: so, step 0 basically just samples a token randomly, then the next step is also random but less random, etc... :/ I guess we should try this, but it seems like a lot of randomness XD |
Yes, when I look at https://github.com/google-research/maskgit/blob/cf615d448642942ddebaa7af1d1ed06a05720a91/maskgit/libml/parallel_decode.py#L49-L56, I conclude that the paper was wrong. Mixing in randomness is the key to the algorithm..... |
If any of you translated the sampling code to PyTorch, could you post it here? |
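Not a drop-in translation, but a rough PyTorch sketch of what the linked mask_by_random_topk does (untested; it assumes probs is the (B, N) confidence of the token sampled at each position, mask_len is a (B, 1) count of positions to re-mask, and temperature is the noise temperature, which the official code anneals toward 0 over the steps):
import torch

def mask_by_random_topk(mask_len, probs, temperature=1.0):
    # Add temperature-scaled Gumbel noise to the log-confidences, then re-mask the
    # mask_len positions with the lowest noisy confidence.
    u = torch.rand_like(probs)
    gumbel = -torch.log(-torch.log(u + 1e-20) + 1e-20)  # Gumbel(0, 1) noise
    confidence = torch.log(probs) + temperature * gumbel
    sorted_confidence, _ = torch.sort(confidence, dim=-1)
    cut_off = torch.gather(sorted_confidence, -1, mask_len.long())  # per-row threshold
    return confidence < cut_off  # True where the token gets re-masked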
Thanks for your share. I have also tried to train a transformer-based model on the COCO dataset for image generation, but got worse results. Could you share some natural images generated by your model trained on the dataset of natural images? I wonder if the number of training iterations is the key to training a good model, or if I am missing some details in my implementation. |
I finished training a 200M-parameter model on ImageNet for 300 epochs, but when I use the released technique, I get FID=21. When I use the temperature scale in predicting tokens (not masks), I get FID=11~12, which still does not reproduce the paper's result. By the way, I think they do not use temperature annealing to randomly select the positions for un-masking, since the code below fixes the temperature parameter as a scalar (=4.5). Is that right...? |
It's 4.5 * (1 - ratio) and ratio goes up from 0 to 1 with the sampling step, so the temperature gets annealed to 0 |
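i.e., roughly (a tiny sketch of that schedule; the 4.5 base is the value mentioned above):
def choice_temperature(ratio, base=4.5):
    # ratio goes from 0 to 1 over the sampling steps, so the noise
    # temperature decays linearly from base down to 0.
    return base * (1.0 - ratio)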
@pabloppp Oh, thank you I will try it ! |
@pabloppp so maybe training the model longer can get better results, like @LeeDoYup shared? Although the quantitative result is not as good as the paper states, I think the quality of generated images with FID=21 is much better than the images we got. Could you share any generated images here, @LeeDoYup? |
@GuoxingY The images are generated from ImageNet. I will share them in this thread soon. |
How much of the image do you mask for the reconstructed image? |
The formulation might have been a bit misleading in that context. The reconstruction is just encoding and decoding the image and has nothing to do with the transformer; I just put it in for my own understanding. |
Hi
I have also tried to reproduce MaskGIT recently. After training for 150 epochs on ImageNet, our model can only achieve 8.4% accuracy on token classification. During sampling, we find our model generates monochrome images (nearly white). Do you encounter a similar problem?