Skip to content

Commit

Permalink
Cast size as int64 to remove overflow checks (#605)
Browse files Browse the repository at this point in the history
Three small fixes:
- [`io.ReadFull`](https://pkg.go.dev/io#ReadFull) on prefixes can't
return an `io.EOF` if `prefixBytes > 0`
- Cast `size` to `int64` over `int` to avoid the overflow check
- `prefixBytesRead` must equal 5 on success

---------

Co-authored-by: Akshay Shah <[email protected]>
  • Loading branch information
emcfarlane and akshayjshah authored Oct 12, 2023
1 parent ca44397 commit 96effed
Showing 1 changed file with 14 additions and 32 deletions.
46 changes: 14 additions & 32 deletions envelope.go
Original file line number Diff line number Diff line change
Expand Up @@ -226,21 +226,15 @@ func (r *envelopeReader) Unmarshal(message any) *Error {

func (r *envelopeReader) Read(env *envelope) *Error {
prefixes := [5]byte{}
prefixBytesRead, err := io.ReadFull(r.reader, prefixes[:])

switch {
case (err == nil || errors.Is(err, io.EOF)) &&
prefixBytesRead == 5 &&
isSizeZeroPrefix(prefixes):
// Successfully read prefix and expect no additional data.
env.Flags = prefixes[0]
return nil
case err != nil && errors.Is(err, io.EOF) && prefixBytesRead == 0:
// The stream ended cleanly. That's expected, but we need to propagate them
// to the user so that they know that the stream has ended. We shouldn't
// add any alarming text about protocol errors, though.
return NewError(CodeUnknown, err)
case err != nil:
// io.ReadFull reads the number of bytes requested, or returns an error.
// io.EOF will only be returned if no bytes were read.
if _, err := io.ReadFull(r.reader, prefixes[:]); err != nil {
if errors.Is(err, io.EOF) {
// The stream ended cleanly. That's expected, but we need to propagate an EOF
// to the user so that they know that the stream has ended. We shouldn't
// add any alarming text about protocol errors, though.
return NewError(CodeUnknown, err)
}
// Something else has gone wrong - the stream didn't end cleanly.
if connectErr, ok := asError(err); ok {
return connectErr
Expand All @@ -254,12 +248,9 @@ func (r *envelopeReader) Read(env *envelope) *Error {
"protocol error: incomplete envelope: %w", err,
)
}
size := int(binary.BigEndian.Uint32(prefixes[1:5]))
if size < 0 {
return errorf(CodeInvalidArgument, "message size %d overflowed uint32", size)
}
if r.readMaxBytes > 0 && size > r.readMaxBytes {
_, err := io.CopyN(io.Discard, r.reader, int64(size))
size := int64(binary.BigEndian.Uint32(prefixes[1:5]))
if r.readMaxBytes > 0 && size > int64(r.readMaxBytes) {
_, err := io.CopyN(io.Discard, r.reader, size)
if err != nil && !errors.Is(err, io.EOF) {
return errorf(CodeUnknown, "read enveloped message: %w", err)
}
Expand All @@ -270,7 +261,7 @@ func (r *envelopeReader) Read(env *envelope) *Error {
// length-prefixed messages may arrive in chunks, so we may need to read
// the request body past EOF. We also need to take care that we don't retry
// forever if the message is malformed.
remaining := int64(size)
remaining := size
for remaining > 0 {
bytesRead, err := io.CopyN(env.Data, r.reader, remaining)
if err != nil && !errors.Is(err, io.EOF) {
Expand All @@ -287,7 +278,7 @@ func (r *envelopeReader) Read(env *envelope) *Error {
CodeInvalidArgument,
"protocol error: promised %d bytes in enveloped message, got %d bytes",
size,
int64(size)-remaining,
size-remaining,
)
}
remaining -= bytesRead
Expand All @@ -296,12 +287,3 @@ func (r *envelopeReader) Read(env *envelope) *Error {
env.Flags = prefixes[0]
return nil
}

func isSizeZeroPrefix(prefix [5]byte) bool {
for i := 1; i < 5; i++ {
if prefix[i] != 0 {
return false
}
}
return true
}

0 comments on commit 96effed

Please sign in to comment.