diff --git a/v3/io.go b/v3/io.go index 6ad5abc..a07b800 100644 --- a/v3/io.go +++ b/v3/io.go @@ -1,6 +1,7 @@ package pb import ( + "fmt" "io" ) @@ -17,6 +18,16 @@ func (r *Reader) Read(p []byte) (n int, err error) { return } +// Seek the wrapped reader when it implements io.Seeker +func (r *Reader) Seek(offset int64, whence int) (n int64, err error) { + if seeker, ok := r.Reader.(io.Seeker); ok { + n, err = seeker.Seek(offset, whence) + r.bar.SetCurrent(n) + return n, err + } + return 0, fmt.Errorf("wrapped io.Reader does not implement io.Seeker") +} + // Close the wrapped reader when it implements io.Closer func (r *Reader) Close() (err error) { r.bar.Finish() diff --git a/v3/io_test.go b/v3/io_test.go index bfd8a06..8bb0171 100644 --- a/v3/io_test.go +++ b/v3/io_test.go @@ -1,22 +1,59 @@ package pb import ( + "io" + "math/rand" "testing" + "time" ) +func init() { + rand.Seed(time.Now().UnixNano()) +} + func TestPBProxyReader(t *testing.T) { bar := new(ProgressBar) if bar.GetBool(Bytes) { t.Errorf("By default bytes must be false") } - testReader := new(testReaderWriterCloser) + testReader := new(testReaderWriterSeekerCloser) + testReader.size = 1000000 proxyReader := bar.NewProxyReader(testReader) if !bar.GetBool(Bytes) { t.Errorf("Bytes must be true after call NewProxyReader") } + for i := 0; i < 10; i++ { + // pick a random offset up to half the size of the Reader in either direction. + offset := rand.Int63n(testReader.size) - (testReader.size / 2) + expected := testReader.offset + offset + if expected < 0 { + expected = 0 + } + if expected > testReader.size { + expected = testReader.size + } + position, err := proxyReader.Seek(offset, io.SeekCurrent) + if err != nil { + t.Errorf("Proxy reader failed to seek: %v", err) + } + if position != testReader.offset { + t.Errorf("Proxy reader offset doesn't match reported offset: %d vs %d", testReader.offset, position) + } + if position != expected { + t.Errorf("Proxy reader returned unexpected position: %d instead of %d / %d", position, expected, testReader.size) + } + if proxyReader.bar.Current() != expected { + t.Errorf("Proxy reader bar returned incorrect position: %d vs %d", proxyReader.bar.Current(), expected) + } + } + offset, err := proxyReader.Seek(0, io.SeekStart) + if err != nil || offset != 0 || proxyReader.bar.Current() != 0 { + t.Errorf("Proxy reader failed to reset seek position: %d, %d, %v", offset, proxyReader.bar.Current(), err) + } + for i := 0; i < 10; i++ { buf := make([]byte, 10) n, e := proxyReader.Read(buf) @@ -49,7 +86,7 @@ func TestPBProxyWriter(t *testing.T) { t.Errorf("By default bytes must be false") } - testWriter := new(testReaderWriterCloser) + testWriter := new(testReaderWriterSeekerCloser) proxyReader := bar.NewProxyWriter(testWriter) if !bar.GetBool(Bytes) { @@ -77,24 +114,44 @@ func TestPBProxyWriter(t *testing.T) { proxyReader.Close() } -type testReaderWriterCloser struct { +type testReaderWriterSeekerCloser struct { + size int64 + offset int64 closed bool data []byte } -func (tr *testReaderWriterCloser) Read(p []byte) (n int, err error) { +func (tr *testReaderWriterSeekerCloser) Read(p []byte) (n int, err error) { for i := range p { p[i] = 'f' } + tr.offset += int64(len(p)) return len(p), nil } -func (tr *testReaderWriterCloser) Write(p []byte) (n int, err error) { +func (tr *testReaderWriterSeekerCloser) Seek(offset int64, whence int) (n int64, err error) { + if whence == io.SeekStart { + tr.offset = offset + } else if whence == io.SeekEnd { + tr.offset = tr.size - offset + } else if whence == io.SeekCurrent { + tr.offset += offset + } + + if tr.offset >= tr.size { + tr.offset = tr.size + } else if tr.offset < 0 { + tr.offset = 0 + } + return tr.offset, err +} + +func (tr *testReaderWriterSeekerCloser) Write(p []byte) (n int, err error) { tr.data = append(tr.data, p...) return len(p), nil } -func (tr *testReaderWriterCloser) Close() (err error) { +func (tr *testReaderWriterSeekerCloser) Close() (err error) { tr.closed = true return } diff --git a/v3/pb.go b/v3/pb.go index 17f3750..595654f 100644 --- a/v3/pb.go +++ b/v3/pb.go @@ -399,8 +399,8 @@ func (pb *ProgressBar) SetTemplate(tmpl ProgressBarTemplate) *ProgressBar { return pb.SetTemplateString(string(tmpl)) } -// NewProxyReader creates a wrapper for given reader, but with progress handle -// Takes io.Reader or io.ReadCloser +// NewProxyReadSeeker creates a wrapper for given ReadSeeker, but with progress handle +// Takes io.Reader, io.ReadSeeker or io.ReadCloser // Also, it automatically switches progress bar to handle units as bytes func (pb *ProgressBar) NewProxyReader(r io.Reader) *Reader { pb.Set(Bytes, true)