Finetune part of a whole embedding's parameters. #5231

Open
CongHan0808 opened this issue Mar 6, 2024 · 1 comment

Comments

@CongHan0808

I added 100 new tokens to the vocabulary, along with their corresponding embeddings. I want to finetune only these new embeddings and keep the original tokens' pretrained weights frozen. Following #4192, I use safe_get_full_grad and safe_set_full_fp32_param to read and modify the parameter's gradient, but all of the parameter's weights get updated.
Here is my code:

model_engine.backward(total_loss)

# Row mask over the embedding: 1 keeps a row, 0 zeroes it out.
textembeds_masks = torch.zeros_like(model_engine.in_adaptor.text_embed.weight).to(device=model_engine.local_rank)
textembeds_masks[VOCAB_SIZE_SRC + 1, :] = 1

with torch.no_grad():
    for p_name, param in model_engine.named_parameters():
        if "in_adaptor.text_embed.weight" in p_name:
            if param.grad is not None:
                # Gather the full fp32 gradient and Adam moments for this parameter.
                hp_grad = safe_get_full_grad(param)
                exp_avg = safe_get_full_optimizer_state(param, "exp_avg")
                exp_avg_sq = safe_get_full_optimizer_state(param, "exp_avg_sq")
                # hp_grad.copy_(hp_grad.data * textembeds_masks)

                # Push the masked tensors back through the DeepSpeed setters.
                safe_set_full_fp32_param(param, hp_grad.data * textembeds_masks)
                safe_set_full_optimizer_state(param, exp_avg.data * textembeds_masks, "exp_avg")
                safe_set_full_optimizer_state(param, exp_avg_sq.data * textembeds_masks, "exp_avg_sq")

model_engine.step()
scheduler(step)

After saving a few checkpoints, the original tokens' weights in in_adaptor.text_embed.weight differ from one checkpoint to the next. How should I change my code so that the original tokens' weights stay fixed and only the new tokens' weights are finetuned?
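For reference, one likely cause: safe_set_full_fp32_param writes a new value into the parameter's fp32 master weights, not into its gradient, so the snippet above overwrites the whole embedding with the masked gradient on every step (and textembeds_masks[VOCAB_SIZE_SRC + 1, :] = 1 marks only a single row, not all 100 new rows). A more robust pattern is to let the optimizer step run normally and then restore the pretrained rows from a snapshot, since even a correctly masked gradient can still move "frozen" rows through stale Adam momentum or weight decay. A minimal sketch, assuming embed_param refers to the in_adaptor.text_embed.weight parameter and VOCAB_SIZE_SRC is the number of pretrained rows, as above:

import torch
from deepspeed.utils import safe_get_full_fp32_param, safe_set_full_fp32_param

# One-time snapshot of the pretrained rows, taken before training starts.
with torch.no_grad():
    full_w = safe_get_full_fp32_param(embed_param)   # gathered full fp32 copy
    frozen_rows = full_w[:VOCAB_SIZE_SRC].detach().clone()

# Each iteration: step as usual, then copy the snapshot back over the
# pretrained rows so only the new tokens' rows actually change.
model_engine.backward(total_loss)
model_engine.step()
with torch.no_grad():
    full_w = safe_get_full_fp32_param(embed_param)
    full_w[:VOCAB_SIZE_SRC] = frozen_rows.to(full_w.device)
    safe_set_full_fp32_param(embed_param, full_w)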

@AuroraZengfh

Hi, I'm also looking for a way to finetune only part of an embedding in the DeepSpeed framework. Did you find a convenient way to solve the problem?
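If gathering and restoring the full weight is too heavy, another option (a sketch, assuming ZeRO stage 1/2 where the parameter tensor is not partitioned, and the same embed_param / VOCAB_SIZE_SRC names as above) is a backward hook that zeroes the pretrained rows' gradient before the optimizer ever sees it. Note that decoupled weight decay can still shrink rows whose gradient is zero, so combining this with the restore-after-step pattern above is safer:

import torch

# Register once before training; the hook runs during backward and its
# return value replaces the parameter's gradient.
grad_mask = torch.ones_like(embed_param)
grad_mask[:VOCAB_SIZE_SRC] = 0           # zero out grads of pretrained rows
embed_param.register_hook(lambda g: g * grad_mask.to(g.device))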
