Skip to content

Commit

Permalink
Make new lines penalization optional
Browse files Browse the repository at this point in the history
  • Loading branch information
donderom committed May 24, 2024
1 parent be6f88f commit b502210
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/main/scala/com/donderom/llm4s/Params.scala
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ final case class ContextParams(
/** Penalty settings read by the sampler through `sampling.penalty`.
  *
  * NOTE(review): this is a rendered diff hunk with the +/- markers stripped, so the
  * pre-change and post-change `presence` lines both appear below. Presumably only the
  * second, comma-terminated one is part of the committed file — confirm against the repo.
  *
  * @param repeat defaults to 1.10f; presumably the repetition-penalty factor — TODO confirm
  * @param frequency defaults to .0f; presumably the frequency-penalty weight — TODO confirm
  * @param presence defaults to .0f; forwarded as `penalty_present` to the native penalty call
  * @param penalizeNewLines defaults to true; when false, the sampler overwrites the newline
  *                         token's candidate logit with the value read from `logits` after
  *                         penalties have been applied
  */
final case class Penalty(
repeat: Float = 1.10f,
frequency: Float = .0f,
presence: Float = .0f
presence: Float = .0f,
penalizeNewLines: Boolean = true
)

final case class Dynatemp(
Expand Down
7 changes: 7 additions & 0 deletions src/main/scala/com/donderom/llm4s/SlincLlm.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,7 @@ private class SlincLlm private[llm4s] (private[llm4s] val ctx: Llama.Ctx):
/** Resolves the `addBos` flag once, on first access.
  *
  * When `addBosToken` is -1 the result is whether the model's vocabulary type is SPM;
  * any other value is read as a flag, where non-zero means true.
  * NOTE(review): -1 presumably means "the model does not specify" — confirm against llama.cpp.
  */
lazy val addBos: Boolean =
  if addBosToken == -1 then llama.llama_vocab_type(model) == Llama.VocabType.SPM
  else addBosToken != 0
// Token id reported by `llama_token_nl` for this model, looked up lazily on first use.
// NOTE(review): presumably the id of the newline token — confirm against llama.cpp.
lazy val newLineToken: Int = llama.llama_token_nl(model)

/** Tells the caller whether generation may continue after `token`.
  *
  * @param token id of the token that was just sampled
  * @return false once `llama_token_is_eog` flags `token` for this model, true otherwise
  */
def keepGenerating(token: Int): Boolean =
  val flaggedAsEog = llama.llama_token_is_eog(model, token)
  !flaggedAsEog
Expand Down Expand Up @@ -220,6 +221,12 @@ private class SlincLlm private[llm4s] (private[llm4s] val ctx: Llama.Ctx):
penalty_present = sampling.penalty.presence
)

// When newline penalization is switched off, put the newline token's logit back to the
// value held in `logits`.
// NOTE(review): presumably `logits` still carries the pre-penalty value at this point — confirm.
if !sampling.penalty.penalizeNewLines then
  // Read the logit first, before searching the candidates (same order as before).
  val restoredLogit = logits(newLineToken)
  tokenData.indexWhere(_.id == newLineToken) match
    case -1 => () // newline token is not among the candidates: nothing to restore
    case slot => !data(slot) = (!data(slot)).copy(logit = restoredLogit)

val tokenId = sampling match
case Greedy(_, _, logprobs) =>
if logprobs > 0 then
Expand Down

0 comments on commit b502210

Please sign in to comment.