diff --git a/vsms/search_loop_models.py b/vsms/search_loop_models.py index 4a0a97d..b322fed 100644 --- a/vsms/search_loop_models.py +++ b/vsms/search_loop_models.py @@ -301,7 +301,7 @@ def fit_rank2(*, mod, X, y, batch_size, max_examples, valX=None, valy=None, logg def adjust_vec(vec, Xt, yt, learning_rate, loss_margin, max_examples, minibatch_size): vec = torch.from_numpy(vec).type(torch.float32) - mod = LookupVec(Xt.shape[1], margin=loss_margin, optimizer=torch.optim.Adam, learning_rate=learning_rate, init_vec=vec) + mod = LookupVec(Xt.shape[1], margin=loss_margin, optimizer=torch.optim.SGD, learning_rate=learning_rate, init_vec=vec) fit_rank2(mod=mod, X=Xt.astype('float32'), y=yt.astype('float'), max_examples=max_examples, batch_size=minibatch_size,max_epochs=1) newvec = mod.vec.detach().numpy().reshape(1,-1)