From b50221057990f0624862b5355343feba31913b1a Mon Sep 17 00:00:00 2001 From: donderom <274926+donderom@users.noreply.github.com> Date: Fri, 24 May 2024 13:14:29 +0200 Subject: [PATCH] Make new lines penalization optional --- src/main/scala/com/donderom/llm4s/Params.scala | 3 ++- src/main/scala/com/donderom/llm4s/SlincLlm.scala | 7 +++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/main/scala/com/donderom/llm4s/Params.scala b/src/main/scala/com/donderom/llm4s/Params.scala index b60f080..eaf7e5b 100644 --- a/src/main/scala/com/donderom/llm4s/Params.scala +++ b/src/main/scala/com/donderom/llm4s/Params.scala @@ -60,7 +60,8 @@ final case class ContextParams( final case class Penalty( repeat: Float = 1.10f, frequency: Float = .0f, - presence: Float = .0f + presence: Float = .0f, + penalizeNewLines: Boolean = true ) final case class Dynatemp( diff --git a/src/main/scala/com/donderom/llm4s/SlincLlm.scala b/src/main/scala/com/donderom/llm4s/SlincLlm.scala index 3ae2502..c886a14 100644 --- a/src/main/scala/com/donderom/llm4s/SlincLlm.scala +++ b/src/main/scala/com/donderom/llm4s/SlincLlm.scala @@ -116,6 +116,7 @@ private class SlincLlm private[llm4s] (private[llm4s] val ctx: Llama.Ctx): lazy val addBos: Boolean = if addBosToken != -1 then addBosToken != 0 else llama.llama_vocab_type(model) == Llama.VocabType.SPM + lazy val newLineToken: Int = llama.llama_token_nl(model) def keepGenerating(token: Int): Boolean = !llama.llama_token_is_eog(model, token) @@ -220,6 +221,12 @@ private class SlincLlm private[llm4s] (private[llm4s] val ctx: Llama.Ctx): penalty_present = sampling.penalty.presence ) + if !sampling.penalty.penalizeNewLines then + val newLineLogit = logits(newLineToken) + val newLineIndex = tokenData.indexWhere(_.id == newLineToken) + if newLineIndex != -1 then + !data(newLineIndex) = (!data(newLineIndex)).copy(logit = newLineLogit) + val tokenId = sampling match case Greedy(_, _, logprobs) => if logprobs > 0 then