Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve weight dampening #11

Merged
merged 1 commit into from
Feb 17, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 32 additions & 17 deletions pkg/qoa/encode.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@ func (q *QOA) encodeFrame(sampleData []int16, frameLen uint32, bytes []byte) uin
p += 8

for c := uint32(0); c < channels; c++ {
/* If the weights have grown too large, reset them to 0. This may happen
with certain high-frequency sounds. This is a last resort and will
introduce quite a bit of noise, but should at least prevent pops/clicks */
weightsSum :=
int(q.LMS[c].Weights[0]*q.LMS[c].Weights[0] +
q.LMS[c].Weights[1]*q.LMS[c].Weights[1] +
q.LMS[c].Weights[2]*q.LMS[c].Weights[2] +
q.LMS[c].Weights[3]*q.LMS[c].Weights[3])
if weightsSum > 0x2fffffff {
q.LMS[c].Weights[0] = 0
q.LMS[c].Weights[1] = 0
q.LMS[c].Weights[2] = 0
q.LMS[c].Weights[3] = 0
}
// /* If the weights have grown too large, reset them to 0. This may happen
// with certain high-frequency sounds. This is a last resort and will
// introduce quite a bit of noise, but should at least prevent pops/clicks */
// weightsSum :=
// int(q.LMS[c].Weights[0]*q.LMS[c].Weights[0] +
// q.LMS[c].Weights[1]*q.LMS[c].Weights[1] +
// q.LMS[c].Weights[2]*q.LMS[c].Weights[2] +
// q.LMS[c].Weights[3]*q.LMS[c].Weights[3])
// if weightsSum > 0x2fffffff {
// q.LMS[c].Weights[0] = 0
// q.LMS[c].Weights[1] = 0
// q.LMS[c].Weights[2] = 0
// q.LMS[c].Weights[3] = 0
// }

// Write the current LMS state
history := uint64(0)
Expand All @@ -73,6 +73,7 @@ func (q *QOA) encodeFrame(sampleData []int16, frameLen uint32, bytes []byte) uin
16 scaleFactors, encode all samples for the current slice and
measure the total squared error. */
bestError := -1
bestRank := -1
var bestSlice uint64
var bestLMS qoaLMS
var bestScaleFactor int
Expand All @@ -88,6 +89,7 @@ func (q *QOA) encodeFrame(sampleData []int16, frameLen uint32, bytes []byte) uin
state when encoding. */
lms := q.LMS[c]
slice := uint64(scaleFactor)
currentRank := uint64(0)
currentError := uint64(0)

for si := sliceStart; si < sliceEnd; si += channels {
Expand All @@ -101,17 +103,30 @@ func (q *QOA) encodeFrame(sampleData []int16, frameLen uint32, bytes []byte) uin
dequantized := qoaDequantTable[scaleFactor][quantized]
reconstructed := clampS16(predicted + int(dequantized))

// If the weights have grown too large, we introduce a penalty here. This prevents pops/clicks
// in certain problem cases.
weightsPenalty := (int(q.LMS[c].Weights[0]*q.LMS[c].Weights[0]+
q.LMS[c].Weights[1]*q.LMS[c].Weights[1]+
q.LMS[c].Weights[2]*q.LMS[c].Weights[2]+
q.LMS[c].Weights[3]*q.LMS[c].Weights[3]) >> 18) - 0x8ff
if weightsPenalty < 0 {
weightsPenalty = 0
}

errDelta := int64(sample - int(reconstructed))
currentError += uint64(errDelta * errDelta)
if currentError > uint64(bestError) {
errorSquared := uint64(errDelta * errDelta)
currentRank += errorSquared + uint64(weightsPenalty)*uint64(weightsPenalty)
currentError += errorSquared
if currentError > uint64(bestRank) {
break
}

lms.update(reconstructed, dequantized)
slice = (slice << 3) | uint64(quantized)
}

if currentError < uint64(bestError) {
if currentError < uint64(bestRank) {
bestRank = int(currentRank)
bestError = int(currentError)
bestSlice = slice
bestLMS = lms
Expand Down
Loading