diff --git a/proxyreader_test.go b/proxyreader_test.go index ab7236a2..eda0aa62 100644 --- a/proxyreader_test.go +++ b/proxyreader_test.go @@ -27,32 +27,22 @@ func (r *testReader) Read(p []byte) (n int, err error) { return r.Reader.Read(p) } -type testWriterTo struct { - *testReader - wt io.WriterTo -} - -func (wt testWriterTo) WriteTo(w io.Writer) (n int64, err error) { - wt.called = true - return wt.wt.WriteTo(w) -} - func TestProxyReader(t *testing.T) { p := mpb.New(mpb.WithOutput(io.Discard)) - tReader := &testReader{strings.NewReader(content), false} + reader := &testReader{strings.NewReader(content), false} bar := p.AddBar(int64(len(content))) var buf bytes.Buffer - _, err := io.Copy(&buf, bar.ProxyReader(tReader)) + _, err := io.Copy(&buf, bar.ProxyReader(reader)) if err != nil { t.Errorf("Error copying from reader: %+v\n", err) } p.Wait() - if !tReader.called { + if !reader.called { t.Error("Read not called") } @@ -61,23 +51,60 @@ func TestProxyReader(t *testing.T) { } } +type testReadCloser struct { + io.Reader + called bool +} + +func (r *testReadCloser) Close() error { + r.called = true + return nil +} + +func TestProxyReadCloser(t *testing.T) { + p := mpb.New(mpb.WithOutput(io.Discard)) + + reader := &testReadCloser{strings.NewReader(content), false} + + bar := p.AddBar(int64(len(content))) + + rc := bar.ProxyReader(reader) + _, _ = io.Copy(io.Discard, rc) + _ = rc.Close() + + p.Wait() + + if !reader.called { + t.Error("Close not called") + } +} + +type testWriterTo struct { + io.Reader + called bool +} + +func (wt *testWriterTo) WriteTo(w io.Writer) (n int64, err error) { + wt.called = true + return wt.Reader.(io.WriterTo).WriteTo(w) +} + func TestProxyWriterTo(t *testing.T) { p := mpb.New(mpb.WithOutput(io.Discard)) - var reader io.Reader = strings.NewReader(content) - tWriterTo := testWriterTo{&testReader{reader, false}, reader.(io.WriterTo)} + writerTo := &testWriterTo{strings.NewReader(content), false} bar := p.AddBar(int64(len(content))) var buf bytes.Buffer - _, err := io.Copy(&buf, bar.ProxyReader(tWriterTo)) + _, err := io.Copy(&buf, bar.ProxyReader(writerTo)) if err != nil { t.Errorf("Error copying from reader: %+v\n", err) } p.Wait() - if !tWriterTo.called { + if !writerTo.called { t.Error("WriteTo not called") }