diff --git a/pkg/qoa/encode.go b/pkg/qoa/encode.go index 13f5ec4..1ec1450 100644 --- a/pkg/qoa/encode.go +++ b/pkg/qoa/encode.go @@ -34,20 +34,20 @@ func (q *QOA) encodeFrame(sampleData []int16, frameLen uint32, bytes []byte) uin p += 8 for c := uint32(0); c < channels; c++ { - /* If the weights have grown too large, reset them to 0. This may happen - with certain high-frequency sounds. This is a last resort and will - introduce quite a bit of noise, but should at least prevent pops/clicks */ - weightsSum := - int(q.LMS[c].Weights[0]*q.LMS[c].Weights[0] + - q.LMS[c].Weights[1]*q.LMS[c].Weights[1] + - q.LMS[c].Weights[2]*q.LMS[c].Weights[2] + - q.LMS[c].Weights[3]*q.LMS[c].Weights[3]) - if weightsSum > 0x2fffffff { - q.LMS[c].Weights[0] = 0 - q.LMS[c].Weights[1] = 0 - q.LMS[c].Weights[2] = 0 - q.LMS[c].Weights[3] = 0 - } + // /* If the weights have grown too large, reset them to 0. This may happen + // with certain high-frequency sounds. This is a last resort and will + // introduce quite a bit of noise, but should at least prevent pops/clicks */ + // weightsSum := + // int(q.LMS[c].Weights[0]*q.LMS[c].Weights[0] + + // q.LMS[c].Weights[1]*q.LMS[c].Weights[1] + + // q.LMS[c].Weights[2]*q.LMS[c].Weights[2] + + // q.LMS[c].Weights[3]*q.LMS[c].Weights[3]) + // if weightsSum > 0x2fffffff { + // q.LMS[c].Weights[0] = 0 + // q.LMS[c].Weights[1] = 0 + // q.LMS[c].Weights[2] = 0 + // q.LMS[c].Weights[3] = 0 + // } // Write the current LMS state history := uint64(0) @@ -73,6 +73,7 @@ func (q *QOA) encodeFrame(sampleData []int16, frameLen uint32, bytes []byte) uin 16 scaleFactors, encode all samples for the current slice and measure the total squared error. */ bestError := -1 + bestRank := -1 var bestSlice uint64 var bestLMS qoaLMS var bestScaleFactor int @@ -88,6 +89,7 @@ func (q *QOA) encodeFrame(sampleData []int16, frameLen uint32, bytes []byte) uin state when encoding. */ lms := q.LMS[c] slice := uint64(scaleFactor) + currentRank := uint64(0) currentError := uint64(0) for si := sliceStart; si < sliceEnd; si += channels { @@ -101,9 +103,21 @@ func (q *QOA) encodeFrame(sampleData []int16, frameLen uint32, bytes []byte) uin dequantized := qoaDequantTable[scaleFactor][quantized] reconstructed := clampS16(predicted + int(dequantized)) + // If the weights have grown too large, we introduce a penalty here. This prevents pops/clicks + // in certain problem cases. + weightsPenalty := (int(q.LMS[c].Weights[0]*q.LMS[c].Weights[0]+ + q.LMS[c].Weights[1]*q.LMS[c].Weights[1]+ + q.LMS[c].Weights[2]*q.LMS[c].Weights[2]+ + q.LMS[c].Weights[3]*q.LMS[c].Weights[3]) >> 18) - 0x8ff + if weightsPenalty < 0 { + weightsPenalty = 0 + } + errDelta := int64(sample - int(reconstructed)) - currentError += uint64(errDelta * errDelta) - if currentError > uint64(bestError) { + errorSquared := uint64(errDelta * errDelta) + currentRank += errorSquared + uint64(weightsPenalty)*uint64(weightsPenalty) + currentError += errorSquared + if currentError > uint64(bestRank) { break } @@ -111,7 +125,8 @@ func (q *QOA) encodeFrame(sampleData []int16, frameLen uint32, bytes []byte) uin slice = (slice << 3) | uint64(quantized) } - if currentError < uint64(bestError) { + if currentError < uint64(bestRank) { + bestRank = int(currentRank) bestError = int(currentError) bestSlice = slice bestLMS = lms