Skip to content

Commit

Permalink
protect against malformed frames
Browse files Browse the repository at this point in the history
  • Loading branch information
mhr3 committed May 28, 2023
1 parent 34afa95 commit b90fa65
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
12 changes: 7 additions & 5 deletions gozstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,8 @@ import (
// DefaultCompressionLevel is the default compression level.
const DefaultCompressionLevel = 3 // Obtained from ZSTD_CLEVEL_DEFAULT.

const maxFrameContentSize = 256 << 20 // 256 MB

// Compress appends compressed src to dst and returns the result.
func Compress(dst, src []byte) []byte {
return compressDictLevel(dst, src, nil, DefaultCompressionLevel)
Expand Down Expand Up @@ -257,14 +259,14 @@ func decompress(dctx, dctxDict *dctxWrapper, dst, src []byte, dd *DDict) ([]byte

// Slow path - resize dst to fit decompressed data.
srcHdr := (*reflect.SliceHeader)(unsafe.Pointer(&src))
decompressBound := int(C.ZSTD_getFrameContentSize_wrapper(unsafe.Pointer(srcHdr.Data), C.size_t(len(src))))
switch uint64(decompressBound) {
case uint64(C.ZSTD_CONTENTSIZE_UNKNOWN):
contentSize := C.ZSTD_getFrameContentSize_wrapper(unsafe.Pointer(srcHdr.Data), C.size_t(len(src)))
switch {
case contentSize == C.ZSTD_CONTENTSIZE_UNKNOWN || contentSize > maxFrameContentSize:
return streamDecompress(dst, src, dd)
case uint64(C.ZSTD_CONTENTSIZE_ERROR):
case contentSize == C.ZSTD_CONTENTSIZE_ERROR:
return dst, fmt.Errorf("cannot decompress invalid src")
}
decompressBound++
decompressBound := int(contentSize) + 1

if n := dstLen + decompressBound - cap(dst); n > 0 {
// This should be optimized since go 1.11 - see https://golang.org/doc/go1.11#performance-compiler.
Expand Down
8 changes: 8 additions & 0 deletions gozstd_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,14 @@ func TestDecompressSmallBlockWithoutSingleSegmentFlag(t *testing.T) {
})
}

func TestDecompressTooLarge(t *testing.T) {
src := []byte{40, 181, 47, 253, 228, 122, 118, 105, 67, 140, 234, 85, 20, 159, 67}
_, err := Decompress(nil, src)
if err == nil {
t.Fatalf("expecting error when decompressing malformed frame")
}
}

func mustUnhex(dataHex string) []byte {
data, err := hex.DecodeString(dataHex)
if err != nil {
Expand Down

0 comments on commit b90fa65

Please sign in to comment.