-
Notifications
You must be signed in to change notification settings - Fork 392
Commit
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -501,6 +501,50 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep) | |
candidates->size = last_idx; | ||
} | ||
|
||
// XTC ("exclude top choices") sampler.
//
// With probability `xtc_probability`, suppresses every candidate whose softmax
// probability reaches `xtc_threshold` EXCEPT the least likely of them. All
// candidates below the threshold are left untouched, so the rest of the
// distribution stays available to the final token pick.
//
// candidates      - token candidates; logits of excluded tokens are overwritten
//                   and `sorted` is cleared so downstream code re-evaluates them.
// xtc_threshold   - minimum probability for a token to count as a "top choice".
//                   Values <= 0 disable the sampler; values > 0.5 can never be
//                   met by two tokens at once, so they are a no-op as well.
// xtc_probability - chance (0..1) that the sampler is applied on this call.
// rng             - random source for the activation roll.
// min_keep        - minimum number of candidates that must be left unsuppressed.
void sample_xtc(llama_token_data_array * candidates, float xtc_threshold, float xtc_probability, std::mt19937 & rng, size_t min_keep)
{
    // Two probabilities cannot both exceed 0.5, so such a threshold never
    // yields the two or more top choices that XTC needs in order to act.
    if (xtc_threshold <= 0.0f || xtc_threshold > 0.5f || xtc_probability <= 0.0f || candidates->size <= 1) {
        return;
    }

    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
    const float roll = dist(rng);
    if (roll >= xtc_probability) { // dice roll failed, skip xtc for this token
        return;
    }

    // NOTE(review): the scan below assumes llama_sample_softmax fills `.p` and
    // leaves the candidates ordered by descending probability - confirm.
    llama_sample_softmax(nullptr, candidates);

    // Count the leading run of tokens whose probability meets the threshold.
    size_t above_count = candidates->size;
    for (size_t i = 0; i < candidates->size; ++i) {
        if (candidates->data[i].p < xtc_threshold) {
            above_count = i;
            break;
        }
    }

    if (above_count < 2) {
        return; // fewer than two viable top choices: xtc does not do anything
    }

    // Exclude every top choice except the least likely one, but never leave
    // fewer than min_keep candidates unsuppressed.
    size_t remove_count = above_count - 1;
    const size_t max_removable = candidates->size > min_keep ? candidates->size - min_keep : 0;
    if (remove_count > max_removable) {
        remove_count = max_removable;
    }
    if (remove_count == 0) {
        return;
    }

    for (size_t i = 0; i < remove_count; ++i) {
        candidates->data[i].logit = -999.0f; // infinity gets wonky results downstream, this hack works well enough
    }
    candidates->sorted = false;
}
|
||
void sample_dry(int n_ctx, int penalty_range, float penalty_multiplier, float penalty_base, int allowed_length, const std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>>& restart_sequences, llama_token_data_array * candidates) { | ||
if (penalty_multiplier <= 0.0f || penalty_base <= 0.0f) { | ||
return; | ||
|
@@ -822,7 +866,8 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar | |
} | ||
|
||
int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float temp, std::mt19937 & rng, | ||
int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor) | ||
int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, float xtc_threshold, float xtc_probability, | ||
const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor) | ||
{ | ||
int id = 0; | ||
std::vector<llama_token_data> candidates; | ||
|
@@ -843,6 +888,7 @@ int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, floa | |
sample_grammar(file_format, n_vocab, &candidates_p, grammar); | ||
} | ||
|
||
//dry always first as logits cannot be resorted | ||
sample_dry(n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p); | ||
|
||
//prefilter to top 5k tokens for improved speed | ||
|
@@ -909,6 +955,8 @@ int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, floa | |
break; | ||
} | ||
} | ||
//xtc always last | ||
sample_xtc(&candidates_p, xtc_threshold, xtc_probability, rng, 1); | ||
id = sample_token(&candidates_p, rng); | ||
} | ||
|
||
|
@@ -2088,6 +2136,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs) | |
kcpp_params->dry_base = inputs.dry_base; | ||
kcpp_params->dry_allowed_length = inputs.dry_allowed_length; | ||
kcpp_params->dry_penalty_last_n = inputs.dry_penalty_last_n; | ||
kcpp_params->xtc_threshold = inputs.xtc_threshold; | ||
kcpp_params->xtc_probability = inputs.xtc_probability; | ||
kcpp_params->dynatemp_range = inputs.dynatemp_range; | ||
kcpp_params->dynatemp_exponent = inputs.dynatemp_exponent; | ||
kcpp_params->n_ctx = inputs.max_context_length; | ||
|
@@ -2662,7 +2712,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs) | |
top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng, | ||
kcpp_params->mirostat, kcpp_params->mirostat_tau, kcpp_params->mirostat_eta, | ||
kcpp_params->dry_multiplier, kcpp_params->dry_base, | ||
kcpp_params->dry_allowed_length, kcpp_params->dry_penalty_last_n, | ||
kcpp_params->dry_allowed_length, kcpp_params->dry_penalty_last_n, kcpp_params->xtc_threshold, kcpp_params->xtc_probability, | ||
sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor); | ||
|
||
if (grammar != nullptr) { | ||
|
You can add `xtc_threshold > 0.5f` to those exclusion criteria, as such a threshold cannot be met by more than one token.

You might also consider removing the `xtc_threshold <= 0.0f` criterion. The sampler can already be disabled by setting the probability to 0, and there is no reason why a threshold of 0 shouldn't be valid (and result in the removal of all but the least probable token, which matches the semantics of larger values, and could be interesting for experiments).