llamamodel: fix embedding crash for >512 tokens after #2310 (#2383)
Signed-off-by: Jared Van Bortel <[email protected]>
cebtenzzre authored May 29, 2024
1 parent f047f38 · commit e94177e
Showing 1 changed file with 2 additions and 1 deletion.
gpt4all-backend/llamamodel.cpp (3 changes: 2 additions, 1 deletion)

@@ -386,7 +386,8 @@ bool LLamaModel::loadModel(const std::string &modelPath, int n_ctx, int ngl)
     bool isEmbedding = is_embedding_arch(llama_model_arch(d_ptr->model));
     const int n_ctx_train = llama_n_ctx_train(d_ptr->model);
     if (isEmbedding) {
-        d_ptr->ctx_params.n_batch = n_ctx;
+        d_ptr->ctx_params.n_batch = n_ctx;
+        d_ptr->ctx_params.n_ubatch = n_ctx;
     } else {
         if (n_ctx > n_ctx_train) {
             std::cerr << "warning: model was trained on only " << n_ctx_train << " context tokens ("
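
For context, llama.cpp splits each decode call into micro-batches of n_ubatch tokens, and non-causal (embedding) models must see the entire prompt within a single micro-batch; with n_ubatch left at its default of 512, embedding a prompt longer than 512 tokens aborts inside llama_decode. Setting n_ubatch = n_ctx alongside n_batch avoids that. Below is a minimal, illustrative sketch of an embedding context set up this way; it is not the gpt4all code, and it assumes the llama.cpp C API of this period (llama_context_default_params(), the n_batch/n_ubatch/embeddings fields, and llama_new_context_with_model()).

#include "llama.h"

// Hypothetical helper: create a context suitable for embedding prompts of up
// to n_ctx tokens without hitting the default 512-token micro-batch limit.
static llama_context * make_embedding_context(llama_model * model, int n_ctx) {
    llama_context_params params = llama_context_default_params();
    params.n_ctx      = n_ctx;
    params.n_batch    = n_ctx; // logical batch: max tokens passed to llama_decode at once
    params.n_ubatch   = n_ctx; // physical micro-batch: must hold the whole prompt for non-causal models
    params.embeddings = true;  // request embeddings rather than logits
    return llama_new_context_with_model(model, params);
}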
