Skip to content

Commit

Permalink
make decoder more reliable
Browse files Browse the repository at this point in the history
The original Java code didn't perform any bounds checks.  Thus, the
original Go translation didn't either.  This patch updates the API to
return an error from Decompress(), and adds bound checks.
  • Loading branch information
dgryski committed Oct 13, 2015
1 parent 1716209 commit 6897f36
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 26 deletions.
6 changes: 2 additions & 4 deletions fuzz.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

package quicklz

import "encoding/binary"

func Fuzz(data []byte) int {

if len(data) < 5 {
Expand All @@ -16,12 +14,12 @@ func Fuzz(data []byte) int {

}

ln := binary.LittleEndian.Uint32(data[1:])
ln, _ := sizeDecompressed(data)
if ln > (1 << 21) {
return 0
}

if b := Decompress(data); b == nil {
if _, err := Decompress(data); err != nil {
return 0
}

Expand Down
97 changes: 76 additions & 21 deletions quicklz.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ Licensed under the GPL, like the original.
*/
package quicklz

import "errors"

const (
// Streaming mode not supported
QLZ_STREAMING_BUFFER = 0
Expand Down Expand Up @@ -36,27 +38,30 @@ func headerLen(source []byte) int {
return 3
}

func sizeDecompressed(source []byte) int {
func sizeDecompressed(source []byte) (int, error) {
if headerLen(source) == 9 {
return fastRead(source, 5, 4)
}
return fastRead(source, 2, 1)

}

func sizeCompressed(source []byte) int {
func sizeCompressed(source []byte) (int, error) {
if headerLen(source) == 9 {
return fastRead(source, 1, 4)
}
return fastRead(source, 1, 1)
}

func fastRead(a []byte, i, numbytes int) int {
func fastRead(a []byte, i, numbytes int) (int, error) {
l := 0
if len(a) < i+numbytes {
return 0, ErrCorrupt
}
for j := 0; j < numbytes; j++ {
l |= int(a[i+j]) << (uint(j) * 8)
}
return l
return l, nil
}

func fastWrite(a []byte, i, value, numbytes int) {
Expand Down Expand Up @@ -112,7 +117,7 @@ func Compress(source []byte, level int) []byte {
}

if src <= lastMatchStart {
fetch = fastRead(source, src, 3)
fetch, _ = fastRead(source, src, 3)
}

for src <= lastMatchStart {
Expand Down Expand Up @@ -180,7 +185,7 @@ func Compress(source []byte, level int) []byte {
}
}
lits = 0
fetch = fastRead(source, src, 3)
fetch, _ = fastRead(source, src, 3)
} else {
lits++
hashCounter[hash] = 1
Expand All @@ -191,7 +196,7 @@ func Compress(source []byte, level int) []byte {
fetch = (fetch>>8)&0xffff | int(source[src+2])<<16
}
} else {
fetch = fastRead(source, src, 3)
fetch, _ = fastRead(source, src, 3)

var o, offset2 int
var matchlen, k, m int
Expand Down Expand Up @@ -230,7 +235,7 @@ func Compress(source []byte, level int) []byte {
if matchlen >= 3 && src-o < 131071 {
offset := src - o
for u := 1; u < matchlen; u++ {
fetch = fastRead(source, src+u, 3)
fetch, _ = fastRead(source, src+u, 3)
hash = ((fetch >> 12) ^ fetch) & (HASH_VALUES - 1)
c = hashCounter[hash]
hashCounter[hash]++
Expand Down Expand Up @@ -289,8 +294,16 @@ func Compress(source []byte, level int) []byte {
return d2
}

func Decompress(source []byte) []byte {
size := sizeDecompressed(source)
var (
ErrCorrupt = errors.New("quicklz: corrupt document")
ErrInvalidVersion = errors.New("quicklz: unsupported compression version")
)

func Decompress(source []byte) ([]byte, error) {
size, err := sizeDecompressed(source)
if err != nil || size < 0 {
return nil, ErrCorrupt
}
src := headerLen(source)
var dst int
var cwordVal = 1
Expand All @@ -305,24 +318,35 @@ func Decompress(source []byte) []byte {
level := (source[0] >> 2) & 0x3

if level != 1 && level != 3 {
panic("Go version only supports level 1 and 3")
return nil, ErrInvalidVersion
}

if (source[0] & 1) != 1 {
d2 := make([]byte, size)
copy(d2, source[headerLen(source):])
return d2
l := headerLen(source)
if len(source) < l {
return nil, ErrCorrupt
}
copy(d2, source[l:])
return d2, nil
}

for {
if cwordVal == 1 {
cwordVal = fastRead(source, src, 4)
var err error
cwordVal, err = fastRead(source, src, 4)
if err != nil {
return nil, ErrCorrupt
}
src += 4
if dst <= lastMatchStart {
if level == 1 {
fetch = fastRead(source, src, 3)
fetch, err = fastRead(source, src, 3)
} else {
fetch = fastRead(source, src, 4)
fetch, err = fastRead(source, src, 4)
}
if err != nil {
return nil, ErrCorrupt
}
}
}
Expand All @@ -341,6 +365,9 @@ func Decompress(source []byte) []byte {
matchlen = (fetch & 0xf) + 2
src += 2
} else {
if len(source) <= src+2 {
return nil, ErrCorrupt
}
matchlen = int(source[src+2]) & 0xff
src += 3
}
Expand Down Expand Up @@ -371,6 +398,10 @@ func Decompress(source []byte) []byte {
offset2 = int(dst - offset)
}

if matchlen < 0 || offset2 < 0 || len(destination) <= dst+2 || len(destination) <= offset2+matchlen || len(destination) <= dst+matchlen {
return nil, ErrCorrupt
}

destination[dst+0] = destination[offset2+0]
destination[dst+1] = destination[offset2+1]
destination[dst+2] = destination[offset2+2]
Expand All @@ -381,17 +412,29 @@ func Decompress(source []byte) []byte {
dst += matchlen

if level == 1 {
fetch = fastRead(destination, lastHashed+1, 3) // destination[lastHashed + 1] | (destination[lastHashed + 2] << 8) | (destination[lastHashed + 3] << 16);
fetch, err = fastRead(destination, lastHashed+1, 3) // destination[lastHashed + 1] | (destination[lastHashed + 2] << 8) | (destination[lastHashed + 3] << 16);
if err != nil {
return nil, ErrCorrupt
}
for lastHashed < dst-matchlen {
lastHashed++
hash = ((fetch >> 12) ^ fetch) & (HASH_VALUES - 1)
hashtable[hash] = lastHashed
hashCounter[hash] = 1
if len(destination) <= lastHashed+3 {
return nil, ErrCorrupt
}
fetch = (fetch >> 8 & 0xffff) | (int(destination[lastHashed+3]) << 16)
}
fetch = fastRead(source, src, 3)
fetch, err = fastRead(source, src, 3)
if err != nil {
return nil, ErrCorrupt
}
} else {
fetch = fastRead(source, src, 4)
fetch, err = fastRead(source, src, 4)
if err != nil {
return nil, ErrCorrupt
}
}
lastHashed = dst - 1
} else {
Expand All @@ -404,13 +447,22 @@ func Decompress(source []byte) []byte {
if level == 1 {
for lastHashed < dst-3 {
lastHashed++
fetch2 := fastRead(destination, lastHashed, 3)
fetch2, err := fastRead(destination, lastHashed, 3)
if err != nil {
return nil, ErrCorrupt
}
hash = ((fetch2 >> 12) ^ fetch2) & (HASH_VALUES - 1)
hashtable[hash] = lastHashed
hashCounter[hash] = 1
}
if len(source) <= src+2 {
return nil, ErrCorrupt
}
fetch = fetch>>8&0xffff | int(source[src+2])<<16
} else {
if len(source) <= src+3 {
return nil, ErrCorrupt
}
fetch = fetch>>8&0xffff | int(source[src+2])<<16 | int(source[src+3])<<24
}
} else {
Expand All @@ -420,12 +472,15 @@ func Decompress(source []byte) []byte {
cwordVal = 0x80000000
}

if len(destination) <= dst || len(source) <= src {
return nil, ErrCorrupt
}
destination[dst] = source[src]
dst++
src++
cwordVal = cwordVal >> 1
}
return destination
return destination, nil
}
}
}
Expand Down
6 changes: 5 additions & 1 deletion quicklz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@ func TestCompress(t *testing.T) {

qz := Compress(in[:i], 1)

out := Decompress(qz)
out, err := Decompress(qz)
if err != nil {
t.Errorf("roundtrip error length %d: %v", i, err)
}

if !bytes.Equal(in[:i], out) {
offs := dump(t, "o", out, "i", in[:i])
t.Log("\n" + hex.Dump(qz))
Expand Down

0 comments on commit 6897f36

Please sign in to comment.