Skip to content

Commit

Permalink
zstd: Handle zero sized frame content size stricter (#521)
Browse files Browse the repository at this point in the history
`0` was used for "unknown". Instead use max uint64 so fcs == 0 is properly respected.

Conforms to reference decoder.
  • Loading branch information
klauspost authored Mar 9, 2022
1 parent 531d692 commit 96d0db7
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 15 deletions.
12 changes: 6 additions & 6 deletions zstd/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -348,10 +348,10 @@ func (d *Decoder) DecodeAll(input, dst []byte) ([]byte, error) {
frame.history.setDict(&dict)
}

if frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) {
if frame.FrameContentSize != fcsUnknown && frame.FrameContentSize > d.o.maxDecodedSize-uint64(len(dst)) {
return dst, ErrDecoderSizeExceeded
}
if frame.FrameContentSize > 0 && frame.FrameContentSize < 1<<30 {
if frame.FrameContentSize < 1<<30 {
// Never preallocate more than 1 GB up front.
if cap(dst)-len(dst) < int(frame.FrameContentSize) {
dst2 := make([]byte, len(dst), len(dst)+int(frame.FrameContentSize))
Expand Down Expand Up @@ -514,7 +514,7 @@ func (d *Decoder) nextBlockSync() (ok bool) {

// Check frame size (before CRC)
d.syncStream.decodedFrame += uint64(len(d.current.b))
if d.frame.FrameContentSize > 0 && d.syncStream.decodedFrame > d.frame.FrameContentSize {
if d.syncStream.decodedFrame > d.frame.FrameContentSize {
if debugDecoder {
printf("DecodedFrame (%d) > FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize)
}
Expand All @@ -523,7 +523,7 @@ func (d *Decoder) nextBlockSync() (ok bool) {
}

// Check FCS
if d.current.d.Last && d.frame.FrameContentSize > 0 && d.syncStream.decodedFrame != d.frame.FrameContentSize {
if d.current.d.Last && d.frame.FrameContentSize != fcsUnknown && d.syncStream.decodedFrame != d.frame.FrameContentSize {
if debugDecoder {
printf("DecodedFrame (%d) != FrameContentSize (%d)\n", d.syncStream.decodedFrame, d.frame.FrameContentSize)
}
Expand Down Expand Up @@ -811,11 +811,11 @@ func (d *Decoder) startStreamDecoder(ctx context.Context, r io.Reader, output ch
}
if !hasErr {
decodedFrame += uint64(len(do.b))
if fcs > 0 && decodedFrame > fcs {
if decodedFrame > fcs {
println("fcs exceeded", block.Last, fcs, decodedFrame)
do.err = ErrFrameSizeExceeded
hasErr = true
} else if block.Last && fcs > 0 && decodedFrame != fcs {
} else if block.Last && fcs != fcsUnknown && decodedFrame != fcs {
do.err = ErrFrameSizeMismatch
hasErr = true
} else {
Expand Down
4 changes: 4 additions & 0 deletions zstd/decoder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1142,6 +1142,10 @@ func testDecoderFileBad(t *testing.T, fn string, newDec func() (*Decoder, error)
return
}
got, err := ioutil.ReadAll(dec)
if err == ErrCRCMismatch && !strings.Contains(tt.Name, "badsum") {
t.Error(err)
return
}
if err == nil {
want := errMap[tt.Name]
if want == "" {
Expand Down
13 changes: 4 additions & 9 deletions zstd/framedec.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,7 @@ func (d *frameDec) reset(br byteBuffer) error {
default:
fcsSize = 1 << v
}
d.FrameContentSize = 0
d.FrameContentSize = fcsUnknown
if fcsSize > 0 {
b, err := br.readSmall(fcsSize)
if err != nil {
Expand Down Expand Up @@ -343,26 +343,21 @@ func (d *frameDec) runDecoder(dst []byte, dec *blockDec) ([]byte, error) {
err = ErrDecoderSizeExceeded
break
}
if d.SingleSegment && uint64(len(d.history.b)) > d.o.maxDecodedSize {
println("runDecoder: single segment and", uint64(len(d.history.b)), ">", d.o.maxDecodedSize)
err = ErrFrameSizeExceeded
break
}
if d.FrameContentSize > 0 && uint64(len(d.history.b)-crcStart) > d.FrameContentSize {
if uint64(len(d.history.b)-crcStart) > d.FrameContentSize {
println("runDecoder: FrameContentSize exceeded", uint64(len(d.history.b)-crcStart), ">", d.FrameContentSize)
err = ErrFrameSizeExceeded
break
}
if dec.Last {
break
}
if debugDecoder && d.FrameContentSize > 0 {
if debugDecoder {
println("runDecoder: FrameContentSize", uint64(len(d.history.b)-crcStart), "<=", d.FrameContentSize)
}
}
dst = d.history.b
if err == nil {
if d.FrameContentSize > 0 && uint64(len(d.history.b)-crcStart) != d.FrameContentSize {
if d.FrameContentSize != fcsUnknown && uint64(len(d.history.b)-crcStart) != d.FrameContentSize {
err = ErrFrameSizeMismatch
} else if d.HasCheckSum {
var n int
Expand Down
Binary file modified zstd/testdata/bad.zip
Binary file not shown.
3 changes: 3 additions & 0 deletions zstd/zstd.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,9 @@ const zstdMinMatch = 3
// Reset the buffer offset when reaching this.
const bufferReset = math.MaxInt32 - MaxWindowSize

// fcsUnknown is used for unknown frame content size.
const fcsUnknown = math.MaxUint64

var (
// ErrReservedBlockType is returned when a reserved block type is found.
// Typically this indicates wrong or corrupted input.
Expand Down

0 comments on commit 96d0db7

Please sign in to comment.