Skip to content

Commit

Permalink
back to sgd for lin vec.
Browse files Browse the repository at this point in the history
  • Loading branch information
orm011 committed Jun 17, 2021
1 parent 9c1a3f6 commit fd7f913
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion vsms/search_loop_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def fit_rank2(*, mod, X, y, batch_size, max_examples, valX=None, valy=None, logg

def adjust_vec(vec, Xt, yt, learning_rate, loss_margin, max_examples, minibatch_size):
vec = torch.from_numpy(vec).type(torch.float32)
mod = LookupVec(Xt.shape[1], margin=loss_margin, optimizer=torch.optim.Adam, learning_rate=learning_rate, init_vec=vec)
mod = LookupVec(Xt.shape[1], margin=loss_margin, optimizer=torch.optim.SGD, learning_rate=learning_rate, init_vec=vec)
fit_rank2(mod=mod, X=Xt.astype('float32'), y=yt.astype('float'),
max_examples=max_examples, batch_size=minibatch_size,max_epochs=1)
newvec = mod.vec.detach().numpy().reshape(1,-1)
Expand Down

0 comments on commit fd7f913

Please sign in to comment.