How to load the pretrained safetensors checkpoint and continue training? #13

Open
JunyuanDeng opened this issue Jun 19, 2024 · 4 comments

Comments

@JunyuanDeng

Hello, thanks for sharing the code!

I am now trying to train stage 2 with the provided vista.safetensors.

So I changed the command to the following:

torchrun \
    --nnodes=1 \
    --nproc_per_node=8 \
    train.py \
    --base configs/training/vista_phase2_stage2.yaml \
    --finetune ${PATH_TO_STAGE1_CKPT}/vista.safetensors \
    --num_nodes 1 \
    --n_devices 8

But there are lots of missing keys like:
[screenshot of the missing-key warnings]

I expected the loss to be low, but that is not what I observe:
[screenshot of the training loss]

I downloaded the sampled video "samples_mp4_epoch00_batch0000_step000001.mp4":

[video: samples_mp4_epoch00_batch0000_step000001.mp4]

What should I do to use the provided weights to start the phase 2 stage 2 training?
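A minimal sketch for seeing exactly which names differ between the model and the checkpoint, assuming you already have both key lists, for example from model.state_dict().keys() and from the dictionary returned by safetensors.torch.load_file(path). It compares names only (shape mismatches are not checked), and the key names in the example are hypothetical.

```python
def diff_keys(model_keys, ckpt_keys):
    """Return (missing, unexpected) key lists, compared by name only.

    missing:    keys the model expects but the checkpoint does not contain
    unexpected: keys the checkpoint contains but the model does not know
    """
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    missing = sorted(model_set - ckpt_set)
    unexpected = sorted(ckpt_set - model_set)
    return missing, unexpected


# Hypothetical key names, for illustration only.
model_keys = ["encoder.conv.weight", "encoder.conv.bias", "new_branch.weight"]
ckpt_keys = ["model.encoder.conv.weight", "model.encoder.conv.bias"]
missing, unexpected = diff_keys(model_keys, ckpt_keys)
print("missing:", missing)
print("unexpected:", unexpected)
```

If the two lists contain near-identical names that differ only by a prefix, the keys are probably misnamed rather than genuinely absent.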

@Little-Podi
Collaborator

Sorry for the trouble. I haven't verified this resuming feature yet. It seems that some weights are still random after initialization, so make sure the newly added weights are initialized to zeros. In addition, if any "unexpected" keys are reported when loading the checkpoint, make sure each of them is remapped to the corresponding "missing" key. This can be done by renaming the keys in the state dictionary and loading the dictionary into the model again.
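The renaming step described above could look roughly like this. remap_keys and the {"model.": ""} prefix rule are hypothetical; the real mapping has to be read off the "unexpected" and "missing" lists that your own load reports. The remapped dictionary would then be passed to model.load_state_dict(..., strict=False) again.

```python
def remap_keys(state_dict, prefix_map):
    """Return a new dict whose keys have old prefixes replaced by new ones.

    prefix_map is a hypothetical {old_prefix: new_prefix} mapping built by
    comparing the "unexpected" and "missing" key lists. The first matching
    prefix wins; keys with no matching prefix are kept unchanged.
    """
    remapped = {}
    for key, value in state_dict.items():
        new_key = key
        for old, new in prefix_map.items():
            if key.startswith(old):
                new_key = new + key[len(old):]
                break
        # Refuse to silently overwrite a weight if two keys collapse into one.
        if new_key in remapped:
            raise KeyError(f"remapping produced a duplicate key: {new_key}")
        remapped[new_key] = value
    return remapped


# Hypothetical checkpoint whose keys carry an extra "model." prefix.
ckpt = {"model.encoder.conv.weight": 1, "model.encoder.conv.bias": 2}
fixed = remap_keys(ckpt, {"model.": ""})
print(sorted(fixed))  # ['encoder.conv.bias', 'encoder.conv.weight']
```

The duplicate check is a deliberate choice: a rule that maps two checkpoint keys onto the same model key would otherwise drop one of the weights without any warning.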

@zhoujiawei3

@JunyuanDeng
Hi, have you resolved this issue? Could you please share how you did it? Thank you!

@zhoujiawei3

zhoujiawei3 commented Nov 10, 2024

@Little-Podi Hi, I want to make sure I understand: do you mean we need to change the code so that the missing keys are initialized to zeros in this case? When I set the values of these missing keys to zero, samples_mp4_epoch00_batch0000_step000001.mp4 still looks just as strange.
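One way to zero-initialize whatever the checkpoint does not cover, assuming a plain PyTorch nn.Module and a state dict whose keys already match the model: load with strict=False, then zero the parameters listed in missing_keys. load_and_zero_missing is a hypothetical helper, not part of the Vista code, and where it would hook into train.py is not shown here.

```python
import torch
from torch import nn


def load_and_zero_missing(model: nn.Module, state_dict: dict):
    """Load weights non-strictly, then zero every parameter the checkpoint lacked.

    Only parameters are zeroed; missing buffers (e.g. running statistics)
    keep their default initialization.
    """
    result = model.load_state_dict(state_dict, strict=False)
    params = dict(model.named_parameters())
    with torch.no_grad():
        for name in result.missing_keys:
            if name in params:
                params[name].zero_()
    return result


# Toy example: a checkpoint that only covers the first of two layers.
model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))
partial = {"0." + k: v for k, v in nn.Linear(4, 4).state_dict().items()}
result = load_and_zero_missing(model, partial)
print(result.missing_keys)     # ['1.weight', '1.bias']
print(result.unexpected_keys)  # []
```

Checking that unexpected_keys is empty after the load is a quick way to confirm that no checkpoint weight was silently skipped because of a naming mismatch.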

@jywu511

jywu511 commented Dec 25, 2024

@Little-Podi Hi, thanks a lot for sharing this great work! I ran into the same problem. Could you share the checkpoint after stage 1 so that we can continue training? Thanks a lot!


4 participants