diff --git a/go/arrow/ipc/file_test.go b/go/arrow/ipc/file_test.go index dea63579cfea6..b9a4547a5126a 100644 --- a/go/arrow/ipc/file_test.go +++ b/go/arrow/ipc/file_test.go @@ -17,13 +17,17 @@ package ipc_test import ( + "bytes" "fmt" "os" "testing" + "github.com/apache/arrow/go/v18/arrow/array" "github.com/apache/arrow/go/v18/arrow/internal/arrdata" "github.com/apache/arrow/go/v18/arrow/internal/flatbuf" + "github.com/apache/arrow/go/v18/arrow/ipc" "github.com/apache/arrow/go/v18/arrow/memory" + "github.com/stretchr/testify/require" ) func TestFile(t *testing.T) { @@ -75,3 +79,39 @@ func TestFileCompressed(t *testing.T) { } } } + +func TestFileEmbedsStream(t *testing.T) { + mem := memory.NewCheckedAllocator(memory.NewGoAllocator()) + defer mem.AssertSize(t, 0) + + recs := arrdata.Records["primitives"] + schema := recs[0].Schema() + + var buf bytes.Buffer + w, err := ipc.NewFileWriter(&buf, ipc.WithSchema(schema), ipc.WithAllocator(mem)) + require.NoError(t, err) + defer w.Close() + + for _, rec := range recs { + require.NoError(t, w.Write(rec)) + } + + require.NoError(t, w.Close()) + + // we should be able to read a valid ipc stream within the ipc file + + // create an ipc stream reader, skipping the file magic+padding bytes + rdr, err := ipc.NewReader(bytes.NewReader(buf.Bytes()[8:]), ipc.WithSchema(schema), ipc.WithAllocator(mem)) + require.NoError(t, err) + defer rdr.Release() + + // the stream reader should know to stop before the footer if the EOS indicator is properly written + var i int + for rdr.Next() { + rec := rdr.Record() + require.Truef(t, array.RecordEqual(rec, recs[i]), "records[%d] differ", i) + i++ + } + + require.NoError(t, rdr.Err()) +} diff --git a/go/arrow/ipc/file_writer.go b/go/arrow/ipc/file_writer.go index 8582c81baf2fe..9a3d7d3dbeb02 100644 --- a/go/arrow/ipc/file_writer.go +++ b/go/arrow/ipc/file_writer.go @@ -37,23 +37,17 @@ type PayloadWriter interface { Close() error } -type pwriter struct { - w io.WriteSeeker - pos int64 +type fileWriter struct { + streamWriter schema *arrow.Schema dicts []fileBlock recs []fileBlock } -func (w *pwriter) Start() error { +func (w *fileWriter) Start() error { var err error - err = w.updatePos() - if err != nil { - return fmt.Errorf("arrow/ipc: could not update position while in start: %w", err) - } - // only necessary to align to 8-byte boundary at the start of the file _, err = w.Write(Magic) if err != nil { @@ -65,10 +59,10 @@ func (w *pwriter) Start() error { return fmt.Errorf("arrow/ipc: could not align start block: %w", err) } - return err + return w.streamWriter.Start() } -func (w *pwriter) WritePayload(p Payload) error { +func (w *fileWriter) WritePayload(p Payload) error { blk := fileBlock{Offset: w.pos, Meta: 0, Body: p.size} n, err := writeIPCPayload(w, p) if err != nil { @@ -77,11 +71,6 @@ func (w *pwriter) WritePayload(p Payload) error { blk.Meta = int32(n) - err = w.updatePos() - if err != nil { - return fmt.Errorf("arrow/ipc: could not update position while in write-payload: %w", err) - } - switch flatbuf.MessageHeader(p.msg) { case flatbuf.MessageHeaderDictionaryBatch: w.dicts = append(w.dicts, blk) @@ -92,27 +81,18 @@ func (w *pwriter) WritePayload(p Payload) error { return nil } -func (w *pwriter) Close() error { +func (w *fileWriter) Close() error { var err error - // write file footer - err = w.updatePos() - if err != nil { - return fmt.Errorf("arrow/ipc: could not update position while in close: %w", err) + if err = w.streamWriter.Close(); err != nil { + return err } pos := w.pos - err = writeFileFooter(w.schema, w.dicts, w.recs, w) - if err != nil { + if err = writeFileFooter(w.schema, w.dicts, w.recs, w); err != nil { return fmt.Errorf("arrow/ipc: could not write file footer: %w", err) } - // write file footer length - err = w.updatePos() // not strictly needed as we passed w to writeFileFooter... - if err != nil { - return fmt.Errorf("arrow/ipc: could not compute file footer length: %w", err) - } - size := w.pos - pos if size <= 0 { return fmt.Errorf("arrow/ipc: invalid file footer size (size=%d)", size) @@ -133,13 +113,7 @@ func (w *pwriter) Close() error { return nil } -func (w *pwriter) updatePos() error { - var err error - w.pos, err = w.w.Seek(0, io.SeekCurrent) - return err -} - -func (w *pwriter) align(align int32) error { +func (w *fileWriter) align(align int32) error { remainder := paddedLength(w.pos, align) - w.pos if remainder == 0 { return nil @@ -149,12 +123,6 @@ func (w *pwriter) align(align int32) error { return err } -func (w *pwriter) Write(p []byte) (int, error) { - n, err := w.w.Write(p) - w.pos += int64(n) - return n, err -} - func writeIPCPayload(w io.Writer, p Payload) (int, error) { n, err := writeMessage(p.meta, kArrowIPCAlignment, w) if err != nil { @@ -259,18 +227,12 @@ func (ps payloads) Release() { // FileWriter is an Arrow file writer. type FileWriter struct { - w io.WriteSeeker + w io.Writer mem memory.Allocator - header struct { - started bool - offset int64 - } - - footer struct { - written bool - } + headerStarted bool + footerWritten bool pw PayloadWriter @@ -289,7 +251,7 @@ type FileWriter struct { } // NewFileWriter opens an Arrow file using the provided writer w. -func NewFileWriter(w io.WriteSeeker, opts ...Option) (*FileWriter, error) { +func NewFileWriter(w io.Writer, opts ...Option) (*FileWriter, error) { var ( cfg = newConfig(opts...) err error @@ -297,7 +259,7 @@ func NewFileWriter(w io.WriteSeeker, opts ...Option) (*FileWriter, error) { f := FileWriter{ w: w, - pw: &pwriter{w: w, schema: cfg.schema, pos: -1}, + pw: &fileWriter{streamWriter: streamWriter{w: w}, schema: cfg.schema}, mem: cfg.alloc, schema: cfg.schema, codec: cfg.codec, @@ -306,12 +268,6 @@ func NewFileWriter(w io.WriteSeeker, opts ...Option) (*FileWriter, error) { compressors: make([]compressor, cfg.compressNP), } - pos, err := f.w.Seek(0, io.SeekCurrent) - if err != nil { - return nil, fmt.Errorf("arrow/ipc: could not seek current position: %w", err) - } - f.header.offset = pos - return &f, err } @@ -321,7 +277,7 @@ func (f *FileWriter) Close() error { return fmt.Errorf("arrow/ipc: could not write empty file: %w", err) } - if f.footer.written { + if f.footerWritten { return nil } @@ -329,7 +285,7 @@ func (f *FileWriter) Close() error { if err != nil { return fmt.Errorf("arrow/ipc: could not close payload writer: %w", err) } - f.footer.written = true + f.footerWritten = true return nil } @@ -367,14 +323,14 @@ func (f *FileWriter) Write(rec arrow.Record) error { } func (f *FileWriter) checkStarted() error { - if !f.header.started { + if !f.headerStarted { return f.start() } return nil } func (f *FileWriter) start() error { - f.header.started = true + f.headerStarted = true err := f.pw.Start() if err != nil { return err diff --git a/go/arrow/ipc/writer.go b/go/arrow/ipc/writer.go index 02c67635bb2fd..5a280fbf84a1f 100644 --- a/go/arrow/ipc/writer.go +++ b/go/arrow/ipc/writer.go @@ -37,18 +37,18 @@ import ( "github.com/apache/arrow/go/v18/internal/utils" ) -type swriter struct { +type streamWriter struct { w io.Writer pos int64 } -func (w *swriter) Start() error { return nil } -func (w *swriter) Close() error { +func (w *streamWriter) Start() error { return nil } +func (w *streamWriter) Close() error { _, err := w.Write(kEOS[:]) return err } -func (w *swriter) WritePayload(p Payload) error { +func (w *streamWriter) WritePayload(p Payload) error { _, err := writeIPCPayload(w, p) if err != nil { return err @@ -56,7 +56,7 @@ func (w *swriter) WritePayload(p Payload) error { return nil } -func (w *swriter) Write(p []byte) (int, error) { +func (w *streamWriter) Write(p []byte) (int, error) { n, err := w.w.Write(p) w.pos += int64(n) return n, err @@ -118,7 +118,7 @@ func NewWriter(w io.Writer, opts ...Option) *Writer { return &Writer{ w: w, mem: cfg.alloc, - pw: &swriter{w: w}, + pw: &streamWriter{w: w}, schema: cfg.schema, codec: cfg.codec, emitDictDeltas: cfg.emitDictDeltas,