Skip to content

Commit

Permalink
llama : add pooling switch
Browse files Browse the repository at this point in the history
  • Loading branch information
ggerganov committed Mar 4, 2024
1 parent 9bbeb0f commit e66da35
Showing 1 changed file with 25 additions and 18 deletions.
43 changes: 25 additions & 18 deletions llama.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8113,7 +8113,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) {

for (int i = 0; i < n_tokens; ++i) {
const llama_seq_id seq_id = batch.seq_id[i][0];
const llama_pos pos = batch.pos[i];
const llama_pos pos = batch.pos[i];
if (pos == 0) {
data[seq_id] = i;
}
Expand Down Expand Up @@ -8379,10 +8379,17 @@ static int llama_decode_internal(
if (batch.logits[i] == 0) {
continue;
}
if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) {
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
} else {
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
switch (hparams.pooling_type) {
case LLAMA_POOLING_TYPE_CLS:
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*batch.seq_id[i][0])*sizeof(float), n_embd*sizeof(float));
break;
case LLAMA_POOLING_TYPE_MEAN:
case LLAMA_POOLING_TYPE_NONE:
ggml_backend_tensor_get_async(backend_embd, embd, embeddings_out.data() + (n_embd*i), (n_embd*i)*sizeof(float), n_embd*sizeof(float));
break;
default:
GGML_ASSERT(false && "unknown pooling type");
break;
}
}
}
Expand Down Expand Up @@ -8680,19 +8687,19 @@ static uint8_t llama_token_to_byte(const llama_vocab& vocab, llama_token id) {
GGML_ASSERT(llama_is_byte_token(vocab, id));
const auto& token_data = vocab.id_to_token.at(id);
switch (llama_vocab_get_type(vocab)) {
case LLAMA_VOCAB_TYPE_SPM: {
auto buf = token_data.text.substr(3, 2);
return strtol(buf.c_str(), NULL, 16);
}
case LLAMA_VOCAB_TYPE_BPE: {
GGML_ASSERT(false);
return unicode_to_bytes_bpe(token_data.text);
}
case LLAMA_VOCAB_TYPE_WPM: {
GGML_ASSERT(false);
}
default:
GGML_ASSERT(false);
case LLAMA_VOCAB_TYPE_SPM: {
auto buf = token_data.text.substr(3, 2);
return strtol(buf.c_str(), NULL, 16);
}
case LLAMA_VOCAB_TYPE_BPE: {
GGML_ASSERT(false);
return unicode_to_bytes_bpe(token_data.text);
}
case LLAMA_VOCAB_TYPE_WPM: {
GGML_ASSERT(false);
}
default:
GGML_ASSERT(false);
}
}

Expand Down

0 comments on commit e66da35

Please sign in to comment.