Skip to content

Commit

Permalink
implement CidFromReader
Browse files Browse the repository at this point in the history
And reuse a CidFromBytes test for it, which includes both CIDv0 and
CIDv1 cases as inputs.

Fixes #126.
  • Loading branch information
mvdan committed Jul 2, 2021
1 parent 8e9280d commit e3b3357
Show file tree
Hide file tree
Showing 2 changed files with 165 additions and 0 deletions.
138 changes: 138 additions & 0 deletions cid.go
Original file line number Diff line number Diff line change
Expand Up @@ -680,3 +680,141 @@ func CidFromBytes(data []byte) (int, Cid, error) {

return l, Cid{string(data[0:l])}, nil
}

func toBufByteReader(r io.Reader, dst []byte) *bufByteReader {
// If the reader already implements ByteReader, use it directly.
// Otherwise, use a fallback that does 1-byte Reads.
if br, ok := r.(io.ByteReader); ok {
return &bufByteReader{direct: br, dst: dst, alreadyRead: len(dst)}
}
return &bufByteReader{fallback: r, dst: dst, alreadyRead: len(dst)}
}

type bufByteReader struct {
direct io.ByteReader
fallback io.Reader

alreadyRead int
dst []byte
}

func (r *bufByteReader) ReadByte() (byte, error) {
// We still have some of the initial bytes to use.
if r.alreadyRead > 0 {
b := r.dst[len(r.dst)-r.alreadyRead]
r.alreadyRead--
return b, nil
}

// The underlying reader has ReadByte; use it.
if br := r.direct; br != nil {
b, err := br.ReadByte()
if err != nil {
return 0, err
}
r.dst = append(r.dst, b)
return b, nil
}

// Fall back to a one-byte Read.
var p [1]byte
if _, err := io.ReadFull(r.fallback, p[:]); err != nil {
return 0, err
}
r.dst = append(r.dst, p[0])
return p[0], nil
}

// CidFromReader reads a precise number of bytes for a CID from a given reader.
// It returns the number of bytes read, the CID, and any error encountered.
//
// It's recommended to supply a reader that buffers and implements io.ByteReader,
// as CidFromReader has to do many single-byte reads to decode varints.
// If the argument only implements io.Reader, single-byte Read calls are used instead.
func CidFromReader(r io.Reader) (int, Cid, error) {
// 64 bytes is enough for any CIDv0,
// and it's enough for most CIDv1s in practice.
// If the digest is too long, we'll allocate more.
buf := make([]byte, 0, 64)

// We read two bytes, to tell if this is a CIDv0 or a CIDv1.
v0head := buf[:2]
if _, err := io.ReadFull(r, v0head[:]); err != nil {
return 0, Undef, err
}

// If we have a CIDv0, read the rest of the bytes and cast the buffer.
if v0head[0] == mh.SHA2_256 && v0head[1] == 32 {
if _, err := io.ReadFull(r, buf[2:34]); err != nil {
return 0, Undef, err
}

h, err := mh.Cast(buf[:34])
if err != nil {
return 0, Undef, err
}

return 34, Cid{string(h)}, nil
}

// The varint package wants a io.ByteReader, so we must wrap our io.Reader.
// Note that we already read two bytes, so bufByteReader uses those first.
// After those two bytes, bufByteReader appends the read bytes to br.dst.
br := toBufByteReader(r, buf[:2])
vers, err := varint.ReadUvarint(br)
if err != nil {
return 0, Undef, err
}

if vers != 1 {
return 0, Undef, fmt.Errorf("expected 1 as the cid version number, got: %d", vers)
}

// CID block encoding multicodec.
_, err = varint.ReadUvarint(br)
if err != nil {
return 0, Undef, err
}

// TODO: we could replace all this with multihash.MHFromReader
mhStart := len(br.dst)

// Multihash hash function code.
_, err = varint.ReadUvarint(br)
if err != nil {
return 0, Undef, err
}

// Multihash digest length.
mhl, err := varint.ReadUvarint(br)
if err != nil {
return 0, Undef, err
}
_ = mhl

// Update buf's length.
// We're not reading single bytes beyond this point.
buf = br.dst
br = nil

// Multihash digest; might be too long, so allocate.
// Refuse to make large allocations to prevent OOMs due to bugs.
// TODO: reuse buf if it has enough space
const maxDigestAlloc = 32 << 20 // 32MiB
if mhl > maxDigestAlloc {
return 0, Undef, fmt.Errorf("refusing to allocate %d bytes for a digest", mhl)
}
digest := make([]byte, int(mhl))
if _, err := io.ReadFull(r, digest); err != nil {
return 0, Undef, err
}
buf = append(buf, digest...)

// This simply ensures the multihash is valid.
_, _, err = mh.MHFromBytes(buf[mhStart:])
if err != nil {
return 0, Undef, err
}

return len(buf), Cid{string(buf)}, nil
}
27 changes: 27 additions & 0 deletions cid_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"bytes"
"encoding/json"
"fmt"
"io"
"math/rand"
"reflect"
"strings"
"testing"
"testing/iotest"

mbase "github.com/multiformats/go-multibase"
mh "github.com/multiformats/go-multihash"
Expand Down Expand Up @@ -692,6 +694,31 @@ func TestReadCidsFromBuffer(t *testing.T) {
if cur != len(buf) {
t.Fatal("had trailing bytes")
}

// The same, but now with CidFromReader.
// In multiple forms, to catch more io interface bugs.
for _, r := range []io.Reader{
// implements io.ByteReader
bytes.NewReader(buf),

// tiny reads, no io.ByteReader
iotest.OneByteReader(bytes.NewReader(buf)),
} {
cur = 0
for _, expc := range cids {
n, c, err := CidFromReader(r)
if err != nil {
t.Fatal(err)
}
if c != expc {
t.Fatal("cids mismatched")
}
cur += n
}
if cur != len(buf) {
t.Fatal("had trailing bytes")
}
}
}

func TestBadCidFromBytes(t *testing.T) {
Expand Down

0 comments on commit e3b3357

Please sign in to comment.