I added 100 new tokens to the vocabulary along with their corresponding embeddings. I want to fine-tune only these new embeddings while keeping the original tokens' pretrained weights frozen. Following #4192, I use `safe_get_full_grad` and `safe_set_full_fp32_param` to read and modify the parameter's gradient, but all weights of the parameter end up being updated.
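For context, the new rows were appended to the embedding table roughly like this (a minimal sketch, not my exact code; `VOCAB_SIZE_SRC`, `EMBED_DIM`, and the init scheme are placeholders):

```python
import torch
import torch.nn as nn

VOCAB_SIZE_SRC = 32000   # placeholder: size of the original vocabulary
NUM_NEW_TOKENS = 100
EMBED_DIM = 4096         # placeholder: embedding dimension

old_embed = nn.Embedding(VOCAB_SIZE_SRC, EMBED_DIM)  # holds the pretrained weights
new_embed = nn.Embedding(VOCAB_SIZE_SRC + NUM_NEW_TOKENS, EMBED_DIM)

with torch.no_grad():
    # Copy the pretrained rows; the trailing 100 rows keep their fresh init.
    new_embed.weight[:VOCAB_SIZE_SRC] = old_embed.weight
```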
Here is my training-step code:

```python
import torch
from deepspeed.utils import (
    safe_get_full_grad,
    safe_get_full_optimizer_state,
    safe_set_full_fp32_param,
    safe_set_full_optimizer_state,
)

model_engine.backward(total_loss)

# Mask: 1 for the rows of the 100 new tokens, 0 for the original (frozen) rows.
textembeds_masks = torch.zeros_like(model_engine.in_adaptor.text_embed.weight).to(device=model_engine.local_rank)
textembeds_masks[VOCAB_SIZE_SRC:, :] = 1

with torch.no_grad():
    for p_name, param in model_engine.named_parameters():
        if "in_adaptor.text_embed.weight" in p_name:
            if param.grad is not None:
                # Gather the full fp32 gradient and Adam states across ZeRO partitions.
                hp_grad = safe_get_full_grad(param)
                exp_avg = safe_get_full_optimizer_state(param, "exp_avg")
                exp_avg_sq = safe_get_full_optimizer_state(param, "exp_avg_sq")
                # hp_grad.copy_(hp_grad.data * textembeds_masks)
                # Attempt to apply the mask so only the new rows are updated.
                safe_set_full_fp32_param(param, hp_grad.data * textembeds_masks)
                safe_set_full_optimizer_state(param, exp_avg.data * textembeds_masks, "exp_avg")
                safe_set_full_optimizer_state(param, exp_avg_sq.data * textembeds_masks, "exp_avg_sq")

model_engine.step()
scheduler(step)
```
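For comparison, without ZeRO partitioning the usual way to freeze the original rows is a gradient hook, since the full gradient then lives in `param.grad` (a minimal sketch assuming a plain PyTorch optimizer with `weight_decay=0`; under ZeRO the gradient is sharded across ranks, which is why the `safe_*` APIs are needed above):

```python
def mask_old_rows(grad):
    # Zero the gradient of the original rows so only the 100 new rows move.
    grad = grad.clone()
    grad[:VOCAB_SIZE_SRC] = 0
    return grad

model.in_adaptor.text_embed.weight.register_hook(mask_old_rows)
```

With zero gradients and zero weight decay, Adam's `exp_avg`/`exp_avg_sq` for the masked rows stay at zero, so those rows never receive an update.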
Despite this, after some checkpoints the original tokens' rows of `in_adaptor.text_embed.weight` differ from checkpoint to checkpoint. How should I change my code so that the original tokens' weights stay fixed and only the new tokens' embeddings are fine-tuned?
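This is how I checked the drift between two saved checkpoints (a sketch; the paths and the `"module"` state-dict key are placeholders for my setup):

```python
import torch

ck_a = torch.load("ckpt_step_1000/model_states.pt", map_location="cpu")  # placeholder path
ck_b = torch.load("ckpt_step_2000/model_states.pt", map_location="cpu")  # placeholder path

w_a = ck_a["module"]["in_adaptor.text_embed.weight"][:VOCAB_SIZE_SRC]
w_b = ck_b["module"]["in_adaptor.text_embed.weight"][:VOCAB_SIZE_SRC]

# The original rows should be identical if they were truly frozen.
print(torch.allclose(w_a, w_b), (w_a - w_b).abs().max())
```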