Skip to content

Commit

Permalink
implement CidFromReader
Browse files Browse the repository at this point in the history
And reuse two CidFromBytes tests for it, which includes both CIDv0 and
CIDv1 cases as inputs, as well as some inputs that should error.

Fixes #126.
  • Loading branch information
mvdan committed Jul 14, 2021
1 parent 8e9280d commit 8f4ec9e
Show file tree
Hide file tree
Showing 2 changed files with 226 additions and 41 deletions.
136 changes: 136 additions & 0 deletions cid.go
Original file line number Diff line number Diff line change
Expand Up @@ -680,3 +680,139 @@ func CidFromBytes(data []byte) (int, Cid, error) {

return l, Cid{string(data[0:l])}, nil
}

func toBufByteReader(r io.Reader, dst []byte) *bufByteReader {
// If the reader already implements ByteReader, use it directly.
// Otherwise, use a fallback that does 1-byte Reads.
if br, ok := r.(io.ByteReader); ok {
return &bufByteReader{direct: br, dst: dst}
}
return &bufByteReader{fallback: r, dst: dst}
}

type bufByteReader struct {
direct io.ByteReader
fallback io.Reader

dst []byte
}

func (r *bufByteReader) ReadByte() (byte, error) {
// The underlying reader has ReadByte; use it.
if br := r.direct; br != nil {
b, err := br.ReadByte()
if err != nil {
return 0, err
}
r.dst = append(r.dst, b)
return b, nil
}

// Fall back to a one-byte Read.
// TODO: consider reading straight into dst,
// once we have benchmarks and if they prove that to be faster.
var p [1]byte
if _, err := io.ReadFull(r.fallback, p[:]); err != nil {
return 0, err
}
r.dst = append(r.dst, p[0])
return p[0], nil
}

// CidFromReader reads a precise number of bytes for a CID from a given reader.
// It returns the number of bytes read, the CID, and any error encountered.
// The number of bytes read is accurate even if a non-nil error is returned.
//
// It's recommended to supply a reader that buffers and implements io.ByteReader,
// as CidFromReader has to do many single-byte reads to decode varints.
// If the argument only implements io.Reader, single-byte Read calls are used instead.
func CidFromReader(r io.Reader) (int, Cid, error) {
// 64 bytes is enough for any CIDv0,
// and it's enough for most CIDv1s in practice.
// If the digest is too long, we'll allocate more.
br := toBufByteReader(r, make([]byte, 0, 64))

// We read the first varint, to tell if this is a CIDv0 or a CIDv1.
// The varint package wants a io.ByteReader, so we must wrap our io.Reader.
vers, err := varint.ReadUvarint(br)
if err != nil {
return len(br.dst), Undef, err
}

// If we have a CIDv0, read the rest of the bytes and cast the buffer.
if vers == mh.SHA2_256 {
if n, err := io.ReadFull(r, br.dst[1:34]); err != nil {
return len(br.dst) + n, Undef, err
}

br.dst = br.dst[:34]
h, err := mh.Cast(br.dst)
if err != nil {
return len(br.dst), Undef, err
}

return len(br.dst), Cid{string(h)}, nil
}

if vers != 1 {
return len(br.dst), Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers)
}

// CID block encoding multicodec.
_, err = varint.ReadUvarint(br)
if err != nil {
return len(br.dst), Undef, err
}

// We could replace most of the code below with go-multihash's ReadMultihash.
// Note that it would save code, but prevent reusing buffers.
// Plus, we already have a ByteReader now.
mhStart := len(br.dst)

// Multihash hash function code.
_, err = varint.ReadUvarint(br)
if err != nil {
return len(br.dst), Undef, err
}

// Multihash digest length.
mhl, err := varint.ReadUvarint(br)
if err != nil {
return len(br.dst), Undef, err
}

// Refuse to make large allocations to prevent OOMs due to bugs.
const maxDigestAlloc = 32 << 20 // 32MiB
if mhl > maxDigestAlloc {
return len(br.dst), Undef, fmt.Errorf("refusing to allocate %d bytes for a digest", mhl)
}

// Fine to convert mhl to int, given maxDigestAlloc.
prefixLength := len(br.dst)
cidLength := prefixLength + int(mhl)
if cidLength > cap(br.dst) {
// If the multihash digest doesn't fit in our initial 64 bytes,
// efficiently extend the slice via append+make.
br.dst = append(br.dst, make([]byte, cidLength-cap(br.dst))...)
} else {
// The multihash digest fits inside our buffer,
// so just extend its capacity.
br.dst = br.dst[:cidLength]
}

if n, err := io.ReadFull(r, br.dst[prefixLength:cidLength]); err != nil {
// We can't use len(br.dst) here,
// as we've only read n bytes past prefixLength.
return prefixLength + n, Undef, err
}

// This simply ensures the multihash is valid.
// TODO: consider removing this bit, as it's probably redundant;
// for now, it helps ensure consistency with CidFromBytes.
_, _, err = mh.MHFromBytes(br.dst[mhStart:])
if err != nil {
return len(br.dst), Undef, err
}

return len(br.dst), Cid{string(br.dst)}, nil
}
131 changes: 90 additions & 41 deletions cid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"bytes"
"encoding/json"
"fmt"
"io"
"math/rand"
"reflect"
"strings"
"testing"
"testing/iotest"

mbase "github.com/multiformats/go-multibase"
mh "github.com/multiformats/go-multihash"
Expand Down Expand Up @@ -692,51 +694,98 @@ func TestReadCidsFromBuffer(t *testing.T) {
if cur != len(buf) {
t.Fatal("had trailing bytes")
}
}

func TestBadCidFromBytes(t *testing.T) {
l, c, err := CidFromBytes([]byte{mh.SHA2_256, 32, 0x00})
if err == nil {
t.Fatal("expected not-enough-bytes for V0 CidFromBytes")
}
if l != 0 {
t.Fatal("expected length=0 from bad CidFromBytes")
}
if c != Undef {
t.Fatal("expected Undef CID from bad CidFromBytes")
}
// The same, but now with CidFromReader.
// In multiple forms, to catch more io interface bugs.
for _, r := range []io.Reader{
// implements io.ByteReader
bytes.NewReader(buf),

c, err = Decode("bafkreie5qrjvaw64n4tjm6hbnm7fnqvcssfed4whsjqxzslbd3jwhsk3mm")
if err != nil {
t.Fatal(err)
}
byts := make([]byte, c.ByteLen())
copy(byts, c.Bytes())
byts[1] = 0x80 // bad codec varint
byts[2] = 0x00
l, c, err = CidFromBytes(byts)
if err == nil {
t.Fatal("expected not-enough-bytes for V1 CidFromBytes")
}
if l != 0 {
t.Fatal("expected length=0 from bad CidFromBytes")
}
if c != Undef {
t.Fatal("expected Undef CID from bad CidFromBytes")
// tiny reads, no io.ByteReader
iotest.OneByteReader(bytes.NewReader(buf)),
} {
cur = 0
for _, expc := range cids {
n, c, err := CidFromReader(r)
if err != nil {
t.Fatal(err)
}
if c != expc {
t.Fatal("cids mismatched")
}
cur += n
}
if cur != len(buf) {
t.Fatal("had trailing bytes")
}
}
}

copy(byts, c.Bytes())
byts[2] = 0x80 // bad multihash varint
byts[3] = 0x00
l, c, err = CidFromBytes(byts)
if err == nil {
t.Fatal("expected not-enough-bytes for V1 CidFromBytes")
}
if l != 0 {
t.Fatal("expected length=0 from bad CidFromBytes")
}
if c != Undef {
t.Fatal("expected Undef CID from bad CidFromBytes")
func TestBadCidInput(t *testing.T) {
for _, name := range []string{
"FromBytes",
"FromReader",
} {
t.Run(name, func(t *testing.T) {
usingReader := name == "FromReader"

fromBytes := CidFromBytes
if usingReader {
fromBytes = func(data []byte) (int, Cid, error) {
return CidFromReader(bytes.NewReader(data))
}
}

l, c, err := fromBytes([]byte{mh.SHA2_256, 32, 0x00})
if err == nil {
t.Fatal("expected not-enough-bytes for V0 CID")
}
if !usingReader && l != 0 {
t.Fatal("expected length==0 from bad CID")
} else if usingReader && l == 0 {
t.Fatal("expected length!=0 from bad CID")
}
if c != Undef {
t.Fatal("expected Undef CID from bad CID")
}

c, err = Decode("bafkreie5qrjvaw64n4tjm6hbnm7fnqvcssfed4whsjqxzslbd3jwhsk3mm")
if err != nil {
t.Fatal(err)
}
byts := make([]byte, c.ByteLen())
copy(byts, c.Bytes())
byts[1] = 0x80 // bad codec varint
byts[2] = 0x00
l, c, err = fromBytes(byts)
if err == nil {
t.Fatal("expected not-enough-bytes for V1 CID")
}
if !usingReader && l != 0 {
t.Fatal("expected length==0 from bad CID")
} else if usingReader && l == 0 {
t.Fatal("expected length!=0 from bad CID")
}
if c != Undef {
t.Fatal("expected Undef CID from bad CID")
}

copy(byts, c.Bytes())
byts[2] = 0x80 // bad multihash varint
byts[3] = 0x00
l, c, err = fromBytes(byts)
if err == nil {
t.Fatal("expected not-enough-bytes for V1 CID")
}
if !usingReader && l != 0 {
t.Fatal("expected length==0 from bad CID")
} else if usingReader && l == 0 {
t.Fatal("expected length!=0 from bad CID")
}
if c != Undef {
t.Fatal("expected Undef CID from bad CidFromBytes")
}
})
}
}

Expand Down

0 comments on commit 8f4ec9e

Please sign in to comment.