Skip to content

Commit

Permalink
Merge pull request #11 from braheezy/weight-dampening
Browse files Browse the repository at this point in the history
Improve weight dampening
  • Loading branch information
braheezy authored Feb 17, 2024
2 parents 3992d2f + 60baf2d commit 08a07bc
Showing 1 changed file with 32 additions and 17 deletions.
49 changes: 32 additions & 17 deletions pkg/qoa/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@ func (q *QOA) encodeFrame(sampleData []int16, frameLen uint32, bytes []byte) uin
p += 8

for c := uint32(0); c < channels; c++ {
/* If the weights have grown too large, reset them to 0. This may happen
with certain high-frequency sounds. This is a last resort and will
introduce quite a bit of noise, but should at least prevent pops/clicks */
weightsSum :=
int(q.LMS[c].Weights[0]*q.LMS[c].Weights[0] +
q.LMS[c].Weights[1]*q.LMS[c].Weights[1] +
q.LMS[c].Weights[2]*q.LMS[c].Weights[2] +
q.LMS[c].Weights[3]*q.LMS[c].Weights[3])
if weightsSum > 0x2fffffff {
q.LMS[c].Weights[0] = 0
q.LMS[c].Weights[1] = 0
q.LMS[c].Weights[2] = 0
q.LMS[c].Weights[3] = 0
}
// /* If the weights have grown too large, reset them to 0. This may happen
// with certain high-frequency sounds. This is a last resort and will
// introduce quite a bit of noise, but should at least prevent pops/clicks */
// weightsSum :=
// int(q.LMS[c].Weights[0]*q.LMS[c].Weights[0] +
// q.LMS[c].Weights[1]*q.LMS[c].Weights[1] +
// q.LMS[c].Weights[2]*q.LMS[c].Weights[2] +
// q.LMS[c].Weights[3]*q.LMS[c].Weights[3])
// if weightsSum > 0x2fffffff {
// q.LMS[c].Weights[0] = 0
// q.LMS[c].Weights[1] = 0
// q.LMS[c].Weights[2] = 0
// q.LMS[c].Weights[3] = 0
// }

// Write the current LMS state
history := uint64(0)
Expand All @@ -73,6 +73,7 @@ func (q *QOA) encodeFrame(sampleData []int16, frameLen uint32, bytes []byte) uin
16 scaleFactors, encode all samples for the current slice and
measure the total squared error. */
bestError := -1
bestRank := -1
var bestSlice uint64
var bestLMS qoaLMS
var bestScaleFactor int
Expand All @@ -88,6 +89,7 @@ func (q *QOA) encodeFrame(sampleData []int16, frameLen uint32, bytes []byte) uin
state when encoding. */
lms := q.LMS[c]
slice := uint64(scaleFactor)
currentRank := uint64(0)
currentError := uint64(0)

for si := sliceStart; si < sliceEnd; si += channels {
Expand All @@ -101,17 +103,30 @@ func (q *QOA) encodeFrame(sampleData []int16, frameLen uint32, bytes []byte) uin
dequantized := qoaDequantTable[scaleFactor][quantized]
reconstructed := clampS16(predicted + int(dequantized))

// If the weights have grown too large, we introduce a penalty here. This prevents pops/clicks
// in certain problem cases.
weightsPenalty := (int(q.LMS[c].Weights[0]*q.LMS[c].Weights[0]+
q.LMS[c].Weights[1]*q.LMS[c].Weights[1]+
q.LMS[c].Weights[2]*q.LMS[c].Weights[2]+
q.LMS[c].Weights[3]*q.LMS[c].Weights[3]) >> 18) - 0x8ff
if weightsPenalty < 0 {
weightsPenalty = 0
}

errDelta := int64(sample - int(reconstructed))
currentError += uint64(errDelta * errDelta)
if currentError > uint64(bestError) {
errorSquared := uint64(errDelta * errDelta)
currentRank += errorSquared + uint64(weightsPenalty)*uint64(weightsPenalty)
currentError += errorSquared
if currentError > uint64(bestRank) {
break
}

lms.update(reconstructed, dequantized)
slice = (slice << 3) | uint64(quantized)
}

if currentError < uint64(bestError) {
if currentError < uint64(bestRank) {
bestRank = int(currentRank)
bestError = int(currentError)
bestSlice = slice
bestLMS = lms
Expand Down

0 comments on commit 08a07bc

Please sign in to comment.