Skip to content

Commit

Permalink
added xtc sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
LostRuins committed Aug 21, 2024
1 parent 1a7ecd5 commit 5bf527a
Show file tree
Hide file tree
Showing 5 changed files with 121 additions and 8 deletions.
2 changes: 2 additions & 0 deletions common/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,8 @@ struct gpt_params {
int32_t dry_allowed_length = 2; // repeated sequences longer than this are penalized
int32_t dry_penalty_last_n = 0; // how many tokens to scan for repetitions (0 = entire context)
std::vector<std::string> dry_sequence_breakers; // DRY sequence breakers
float xtc_threshold = 0;
float xtc_probability = 0;

// DynaTemp!
float dynatemp_range = 0.0f; // enables DynaTemp if greater than 0. dynatemp_min = temperature - dt_range, dynatemp_max = temperature + dt_range
Expand Down
2 changes: 2 additions & 0 deletions expose.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,8 @@ struct generation_inputs
const int dry_allowed_length = 0;
const int dry_penalty_last_n = 0;
const char * dry_sequence_breakers[dry_seq_break_max] = {};
const float xtc_threshold = 0.0f;
const float xtc_probability = 0.0f;
const samplers sampler_order[KCPP_SAMPLER_MAX] = {};
const int sampler_len = 0;
const bool allow_eos_token = false;
Expand Down
54 changes: 52 additions & 2 deletions gpttype_adapter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -501,6 +501,50 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep)
candidates->size = last_idx;
}

void sample_xtc(llama_token_data_array * candidates, float xtc_threshold, float xtc_probability, std::mt19937 & rng, size_t min_keep)
{
if (xtc_threshold <= 0.0f || xtc_probability <= 0.0f || candidates->size <= 1) {

This comment has been minimized.

Copy link
@p-e-w

p-e-w Aug 22, 2024

You can add xtc_threshold > 0.5f to those exclusion criteria, as such a threshold cannot be met by more than one token.

You might also consider removing the xtc_threshold <= 0.0f criterion. The sampler can already be disabled by setting the probability to 0, and there is no reason why a threshold of 0 shouldn't be valid (and result in the removal of all but the least probable token, which matches the semantics of larger values, and could be interesting for experiments).

This comment has been minimized.

Copy link
@LostRuins

LostRuins Aug 22, 2024

Author Owner

Will do

return;
}

std::uniform_real_distribution<float> dist(0.0f, 1.0f);
float roll = dist(rng);
if(roll>=xtc_probability) //if dice roll fails, skip xtc
{
return;
}

llama_sample_softmax(nullptr, candidates);

//calculate how many tokens cross the xtc threshold
size_t last_idx = candidates->size;
for (size_t i = 0; i < candidates->size; ++i) {
// Go until we reach a value under the threshold
float checkprob = candidates->data[i].p;
if (checkprob < xtc_threshold && i >= min_keep) {

This comment has been minimized.

Copy link
@p-e-w

p-e-w Aug 22, 2024

I believe this use of min_keep is misleading or wrong. We are going to remove candidates before the index i, so min_keep should be checked against the number of remaining candidates after that index (those we keep).

This comment has been minimized.

Copy link
@LostRuins

LostRuins Aug 22, 2024

Author Owner

Alright, since XTC is no longer a truncation sampler and won't fire with less than 2 candidates, min_keep shouldn't be needed anymore.

last_idx = i;
break;
}
}

if(last_idx>1) //if there are 2 or more viable candidates
{
// drop all tokens except those above threshold

This comment has been minimized.

Copy link
@p-e-w

p-e-w Aug 22, 2024

No! This isn't how XTC works. Please see the chart in the original PR.

If we drop all tokens below the threshold, and then also drop all tokens above the threshold except the last one, we are left with exactly one token. This isn't what we want; the tail should be preserved.

Cutting the tail is the job of truncation samplers like Min-P. Make sure that XTC comes last in the sampling stack, so that all truncation has already happened (as explained in the original PR).

This comment has been minimized.

Copy link
@LostRuins

LostRuins Aug 22, 2024

Author Owner

Can you take a look at cca3c4c and see if it's alright now

This comment has been minimized.

Copy link
@p-e-w

p-e-w Aug 24, 2024

I left a small comment, other than that the implementation looks good now!

candidates->size = last_idx;

// then remove all other tokens EXCEPT the least likely one
for (size_t i = 0; i < candidates->size - 1; ++i) {
candidates->data[i].logit = -999.0f; //infinity gets wonky results downstream, this hack works well enough
}
candidates->sorted = false;

} //otherwise xtc does not do anything

// printf("\n\nCandidates: %d, Threshold: %f, LastIdx: %d",candidates->size,xtc_threshold,last_idx);
// printf("\nCandidates: %f %f %f %f\n",candidates->data[0].p,candidates->data[1].p,candidates->data[2].p,candidates->data[3].p);

}

void sample_dry(int n_ctx, int penalty_range, float penalty_multiplier, float penalty_base, int allowed_length, const std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>>& restart_sequences, llama_token_data_array * candidates) {
if (penalty_multiplier <= 0.0f || penalty_base <= 0.0f) {
return;
Expand Down Expand Up @@ -822,7 +866,8 @@ void sample_grammar(FileFormat file_format, int32_t n_vocab, llama_token_data_ar
}

int SampleLogits(const float * logits, int n_ctx, int n_vocab, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, float top_k, float top_a, float top_p, float min_p, float typical_p, float tfs, float temp, std::mt19937 & rng,
int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor)
int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, float dry_base, int dry_allowed_length, int dry_penalty_last_n, float xtc_threshold, float xtc_probability,
const std::vector<samplers> & sampler_order, llama_grammar * grammar, float dynatemp_range, float dynatemp_exponent, float smoothing_factor)
{
int id = 0;
std::vector<llama_token_data> candidates;
Expand All @@ -843,6 +888,7 @@ int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, floa
sample_grammar(file_format, n_vocab, &candidates_p, grammar);
}

//dry always first as logits cannot be resorted
sample_dry(n_ctx, dry_penalty_last_n, dry_multiplier, dry_base, dry_allowed_length, dry_sequence_breakers, &candidates_p);

//prefilter to top 5k tokens for improved speed
Expand Down Expand Up @@ -909,6 +955,8 @@ int mirostat, float mirostat_tau, float mirostat_eta, float dry_multiplier, floa
break;
}
}
//xtc always last
sample_xtc(&candidates_p, xtc_threshold, xtc_probability, rng, 1);
id = sample_token(&candidates_p, rng);
}

Expand Down Expand Up @@ -2088,6 +2136,8 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
kcpp_params->dry_base = inputs.dry_base;
kcpp_params->dry_allowed_length = inputs.dry_allowed_length;
kcpp_params->dry_penalty_last_n = inputs.dry_penalty_last_n;
kcpp_params->xtc_threshold = inputs.xtc_threshold;
kcpp_params->xtc_probability = inputs.xtc_probability;
kcpp_params->dynatemp_range = inputs.dynatemp_range;
kcpp_params->dynatemp_exponent = inputs.dynatemp_exponent;
kcpp_params->n_ctx = inputs.max_context_length;
Expand Down Expand Up @@ -2662,7 +2712,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs)
top_k, top_a, top_p, min_p, typical_p, tfs_z, temp, rng,
kcpp_params->mirostat, kcpp_params->mirostat_tau, kcpp_params->mirostat_eta,
kcpp_params->dry_multiplier, kcpp_params->dry_base,
kcpp_params->dry_allowed_length, kcpp_params->dry_penalty_last_n,
kcpp_params->dry_allowed_length, kcpp_params->dry_penalty_last_n, kcpp_params->xtc_threshold, kcpp_params->xtc_probability,
sampler_order, grammar, dynatemp_range, dynatemp_exponent, smoothing_factor);

if (grammar != nullptr) {
Expand Down
63 changes: 58 additions & 5 deletions klite.embd
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Current version indicated by LITEVER below.
-->

<script>
const LITEVER = 167;
const LITEVER = 169;
const urlParams = new URLSearchParams(window.location.search);
const localflag = true;
const STORAGE_PREFIX = (localflag?"e_":"")+"kaihordewebui_";
Expand Down Expand Up @@ -209,7 +209,7 @@ Current version indicated by LITEVER below.

.settingsbody
{
height: calc(82vh - 100px);
height: calc(86vh - 94px);
overflow-y: auto;
overflow-x: hidden;
}
Expand Down Expand Up @@ -4353,6 +4353,8 @@ Current version indicated by LITEVER below.
dry_base: 1.75,
dry_allowed_length: 2,
dry_sequence_breakers: ["\n", ":", "\"", "*"],
xtc_threshold: 0.2,
xtc_probability: 0.0,
sampler_order: [6, 0, 1, 3, 4, 2, 5],
};

Expand Down Expand Up @@ -5391,6 +5393,10 @@ Current version indicated by LITEVER below.
{
return (custom_kobold_endpoint!="" && koboldcpp_version && koboldcpp_version!="" && compare_version_str(koboldcpp_version, "1.70") >= 0);
}
function is_using_kcpp_with_xtc()
{
return (custom_kobold_endpoint!="" && koboldcpp_version && koboldcpp_version!="" && compare_version_str(koboldcpp_version, "1.74") >= 0);
}


//0 is none, 1 is pseudostreaming, 2 is true poll-streaming, 3 is sse-streaming
Expand Down Expand Up @@ -9557,6 +9563,8 @@ Current version indicated by LITEVER below.
document.getElementById("miro_tau").value = localsettings.miro_tau;
document.getElementById("miro_eta").value = localsettings.miro_eta;
document.getElementById("dry_multiplier").value = localsettings.dry_multiplier;
document.getElementById("xtc_threshold").value = localsettings.xtc_threshold;
document.getElementById("xtc_probability").value = localsettings.xtc_probability;
document.getElementById("dry_base").value = localsettings.dry_base;
document.getElementById("dry_allowed_length").value = localsettings.dry_allowed_length;
document.getElementById("token_count_multiplier").value = localsettings.token_count_multiplier;
Expand All @@ -9582,6 +9590,17 @@ Current version indicated by LITEVER below.
document.getElementById("drysupporteddiv").classList.add("hidden");
document.getElementById("dryunsupporteddiv").classList.remove("hidden");
}

if(is_using_kcpp_with_xtc())
{
document.getElementById("xtcsupporteddiv").classList.remove("hidden");
document.getElementById("xtcunsupporteddiv").classList.add("hidden");
}
else
{
document.getElementById("xtcsupporteddiv").classList.add("hidden");
document.getElementById("xtcunsupporteddiv").classList.remove("hidden");
}
pendingsequencebreakers = localsettings.dry_sequence_breakers;

document.getElementById("setgrammar").disabled = !is_using_kcpp_with_grammar();
Expand Down Expand Up @@ -9721,6 +9740,7 @@ Current version indicated by LITEVER below.
document.getElementById("tfs_s").value = found.tfs;
document.getElementById("miro_type").value = 0;
document.getElementById("dry_multiplier").value = 0;
document.getElementById("xtc_probability").value = 0;
document.getElementById("rep_pen").value = document.getElementById("rep_pen_slide").value = found.rep_pen;
document.getElementById("rep_pen_range").value = found.rep_pen_range;
document.getElementById("rep_pen_slope").value = found.rep_pen_slope;
Expand Down Expand Up @@ -9808,6 +9828,7 @@ Current version indicated by LITEVER below.
document.getElementById("tfs_s").value != found.tfs ||
document.getElementById("miro_type").value != 0 ||
document.getElementById("dry_multiplier").value != 0 ||
document.getElementById("xtc_probability").value != 0 ||
document.getElementById("rep_pen").value != found.rep_pen ||
document.getElementById("rep_pen_range").value != found.rep_pen_range ||
document.getElementById("rep_pen_slope").value != found.rep_pen_slope ||
Expand Down Expand Up @@ -9950,6 +9971,8 @@ Current version indicated by LITEVER below.
localsettings.dry_base = document.getElementById("dry_base").value;
localsettings.dry_allowed_length = document.getElementById("dry_allowed_length").value;
localsettings.dry_sequence_breakers = pendingsequencebreakers;
localsettings.xtc_threshold = document.getElementById("xtc_threshold").value;
localsettings.xtc_probability = document.getElementById("xtc_probability").value;
localsettings.token_count_multiplier = document.getElementById("token_count_multiplier").value;

localsettings.speech_synth = document.getElementById("ttsselect").value;
Expand Down Expand Up @@ -10050,6 +10073,8 @@ Current version indicated by LITEVER below.
localsettings.dry_multiplier = cleannum(localsettings.dry_multiplier, 0.0, 100.0);
localsettings.dry_base = cleannum(localsettings.dry_base, 0.0, 8.0);
localsettings.dry_allowed_length = cleannum(Math.floor(localsettings.dry_allowed_length), 0, 100);
localsettings.xtc_probability = cleannum(localsettings.xtc_probability, 0.0, 1.0);
localsettings.xtc_threshold = cleannum(localsettings.xtc_threshold, 0.0, 1.0);
localsettings.sampler_seed = cleannum(localsettings.sampler_seed, -1, 999999);
localsettings.token_count_multiplier = cleannum(localsettings.token_count_multiplier, 70, 130);
toggle_invert_colors();
Expand Down Expand Up @@ -12098,6 +12123,13 @@ Current version indicated by LITEVER below.
submit_payload.params.dry_penalty_last_n = localsettings.rep_pen_range;
submit_payload.params.dry_sequence_breakers = JSON.parse(JSON.stringify(localsettings.dry_sequence_breakers));
}

if(custom_kobold_endpoint != "" && is_using_kcpp_with_xtc() && localsettings.xtc_probability > 0)
{
submit_payload.params.xtc_threshold = localsettings.xtc_threshold;
submit_payload.params.xtc_probability = localsettings.xtc_probability;
}

//presence pen and logit bias for OAI and newer kcpp
if((custom_kobold_endpoint != "" && is_using_kcpp_with_mirostat()) || custom_oai_endpoint!="")
{
Expand Down Expand Up @@ -17531,7 +17563,8 @@ Current version indicated by LITEVER below.
</select>
<span class="color_green" style="font-weight: bold;">Please input Gemini or PaLM API Key.</span><br><br>
<input class="form-control" type="password" id="custom_palm_key" placeholder="PaLM/Gemini API Key (Required)" value="" onfocus="focus_api_keys()" onblur="blur_api_keys()"><br>
<input class="form-control" type="text" id="gemini_system_instruction" placeholder="(Enter System Instruction)" value=""><br>
<textarea class="form-control" rows="3" style="resize: vertical; line-height:1.1; padding:4px; display:inline; width: 100%" type="text" id="gemini_system_instruction" placeholder="(Enter System Instruction)"
value=""></textarea><br>
</div>
<div id="coherecustom" class="aidgpopuplistheader anotelabel hidden">
Uses Cohere's models through their own API.<br><br>
Expand Down Expand Up @@ -17638,7 +17671,7 @@ Current version indicated by LITEVER below.

<div class="popupcontainer flex hidden" id="settingscontainer">
<div class="popupbg flex"></div>
<div class="nspopup flexsize" style="margin-top: 6vh; background-color:#102840">
<div class="nspopup flexsize" style="margin-top: 4vh; background-color:#102840">
<div class="popuptitlebar">
<div class="popuptitletext" title="Settings Menu">Settings</div>
</div>
Expand Down Expand Up @@ -18074,7 +18107,28 @@ Current version indicated by LITEVER below.
<div id="dryunsupporteddiv" class="color_red" style="font-weight:bold;padding:3px;font-size:12px">DRY Not Supported</div>
</div>

<div class="settinglabel settingcell">
<div title="XTC" class="justifyleft settingsmall" style="width:100%">XTC <span class="helpicon">?<span class="helptext">
Enables Exclude Top Choices (XTC) Sampling. May not be available depending on backend, not supported on Horde.</span></span></div>
<div class="justifyleft settingsmall" style="width:100%">
<div id="xtcsupporteddiv" style="display:flex;">
<div class="settinglabel settingcell">
<div title="XTC Threshold" class="justifyleft settingsmall" style="width:100%">Threshold</div>
<div class="justifyleft settingsmall" style="width:100%">
<input title="XTC Threshold" class="settinglabel miniinput" type="text" inputmode="decimal" placeholder="0.0" value="0.0" id="xtc_threshold"></div>
</div>
<div class="settinglabel settingcell">
<div title="XTC Probability" class="justifyleft settingsmall" style="width:100%">Probability</div>
<div class="justifyleft settingsmall" style="width:100%">
<input title="XTC Probability" class="settinglabel miniinput" type="text" inputmode="decimal" placeholder="0.0" value="0.0" id="xtc_probability"></div>
</div>
</div>
<div id="xtcunsupporteddiv" class="color_red" style="font-weight:bold;padding:3px;font-size:12px">XTC Not Supported</div>
</div>
</div>
</div>

<div style="display:flex;width:100%;">
<div class="settinglabel settingcell">
<div title="Grammar" class="justifyleft settingsmall" style="width:100%">Grammar <span class="helpicon">?<span class="helptext">
Grammar Sampling (KCPP) - Allows you to constrain output to fit specific structures. Resets grammar state every generation unless Retain is checked.</span></span></div>
Expand All @@ -18086,7 +18140,6 @@ Current version indicated by LITEVER below.
</div>
</div>
</div>

</div>
</div>
</div>
Expand Down
8 changes: 7 additions & 1 deletion koboldcpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
modelbusy = threading.Lock()
requestsinqueue = 0
defaultport = 5001
KcppVersion = "1.73.1"
KcppVersion = "1.74"
showdebug = True
guimode = False
showsamplerwarning = True
Expand Down Expand Up @@ -154,6 +154,8 @@ class generation_inputs(ctypes.Structure):
("dry_allowed_length", ctypes.c_int),
("dry_penalty_last_n", ctypes.c_int),
("dry_sequence_breakers", ctypes.c_char_p * dry_seq_break_max),
("xtc_threshold", ctypes.c_float),
("xtc_probability", ctypes.c_float),
("sampler_order", ctypes.c_int * sampler_order_max),
("sampler_len", ctypes.c_int),
("allow_eos_token", ctypes.c_bool),
Expand Down Expand Up @@ -896,6 +898,8 @@ def generate(genparams, is_quiet=False, stream_flag=False):
dry_allowed_length = genparams.get('dry_allowed_length', 2)
dry_penalty_last_n = genparams.get('dry_penalty_last_n', 320)
dry_sequence_breakers = genparams.get('dry_sequence_breakers', [])
xtc_threshold = genparams.get('xtc_threshold', 0.2)
xtc_probability = genparams.get('xtc_probability', 0)
sampler_order = genparams.get('sampler_order', [6, 0, 1, 3, 4, 2, 5])
seed = tryparseint(genparams.get('sampler_seed', -1))
stop_sequence = genparams.get('stop_sequence', [])
Expand Down Expand Up @@ -964,6 +968,8 @@ def generate(genparams, is_quiet=False, stream_flag=False):
inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0
inputs.dry_multiplier = dry_multiplier
inputs.dry_base = dry_base
inputs.xtc_threshold = xtc_threshold
inputs.xtc_probability = xtc_probability
inputs.dry_allowed_length = dry_allowed_length
inputs.dry_penalty_last_n = dry_penalty_last_n
# Handle dry_sequence_breakers being passed as a json-encoded array of
Expand Down

0 comments on commit 5bf527a

Please sign in to comment.