Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-43837: [Go][IPC] Consolidate StreamWriter and FileWriter, ensuring that EOS indicator is written in file #43890

Merged
merged 1 commit into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions go/arrow/ipc/file_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,17 @@
package ipc_test

import (
"bytes"
"fmt"
"os"
"testing"

"github.com/apache/arrow/go/v18/arrow/array"
"github.com/apache/arrow/go/v18/arrow/internal/arrdata"
"github.com/apache/arrow/go/v18/arrow/internal/flatbuf"
"github.com/apache/arrow/go/v18/arrow/ipc"
"github.com/apache/arrow/go/v18/arrow/memory"
"github.com/stretchr/testify/require"
)

func TestFile(t *testing.T) {
Expand Down Expand Up @@ -75,3 +79,39 @@ func TestFileCompressed(t *testing.T) {
}
}
}

// TestFileEmbedsStream verifies that the output of ipc.NewFileWriter contains
// a well-formed IPC stream: after skipping the leading 8 bytes (file magic
// plus padding), a stream reader must yield exactly the records that were
// written and then stop cleanly instead of running into the file footer.
func TestFileEmbedsStream(t *testing.T) {
	mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
	defer mem.AssertSize(t, 0)

	recs := arrdata.Records["primitives"]
	// Guard the recs[0] access below so a missing fixture fails the test
	// with a clear message rather than an index-out-of-range panic.
	require.NotEmpty(t, recs, "no records available for \"primitives\"")
	schema := recs[0].Schema()

	var buf bytes.Buffer
	w, err := ipc.NewFileWriter(&buf, ipc.WithSchema(schema), ipc.WithAllocator(mem))
	require.NoError(t, err)
	// Safety net for early test failures; the explicit Close below is the
	// one whose error is checked.
	defer w.Close()

	for _, rec := range recs {
		require.NoError(t, w.Write(rec))
	}

	require.NoError(t, w.Close())

	// we should be able to read a valid ipc stream within the ipc file

	// create an ipc stream reader, skipping the file magic+padding bytes
	rdr, err := ipc.NewReader(bytes.NewReader(buf.Bytes()[8:]), ipc.WithSchema(schema), ipc.WithAllocator(mem))
	require.NoError(t, err)
	defer rdr.Release()

	// the stream reader should know to stop before the footer if the EOS indicator is properly written
	var i int
	for rdr.Next() {
		// Fail with a message instead of panicking on recs[i] if the reader
		// produces more batches than were written.
		require.Lessf(t, i, len(recs), "stream reader returned more than the %d records written", len(recs))

		rec := rdr.Record()
		require.Truef(t, array.RecordEqual(rec, recs[i]), "records[%d] differ", i)
		i++
	}

	require.NoError(t, rdr.Err())
	// Without this check the test would pass even if the reader yielded no
	// records at all.
	require.Equalf(t, len(recs), i, "stream reader returned %d of %d records", i, len(recs))
}
82 changes: 19 additions & 63 deletions go/arrow/ipc/file_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,23 +37,17 @@ type PayloadWriter interface {
Close() error
}

// NOTE(review): this span looks like a rendered diff without +/- markers; the
// pwriter header and its w/pos fields appear to be the pre-change lines and
// the fileWriter header with the embedded streamWriter their replacement —
// confirm against the actual file.
//
// fileWriter embeds streamWriter and additionally holds the schema together
// with the dictionary (dicts) and record-batch (recs) fileBlock lists; these
// three values are what Close passes to writeFileFooter.
type pwriter struct {
w io.WriteSeeker
pos int64
type fileWriter struct {
streamWriter

schema *arrow.Schema
dicts []fileBlock
recs []fileBlock
}

func (w *pwriter) Start() error {
func (w *fileWriter) Start() error {
var err error

err = w.updatePos()
if err != nil {
return fmt.Errorf("arrow/ipc: could not update position while in start: %w", err)
}

// only necessary to align to 8-byte boundary at the start of the file
_, err = w.Write(Magic)
if err != nil {
Expand All @@ -65,10 +59,10 @@ func (w *pwriter) Start() error {
return fmt.Errorf("arrow/ipc: could not align start block: %w", err)
}

return err
return w.streamWriter.Start()
}

func (w *pwriter) WritePayload(p Payload) error {
func (w *fileWriter) WritePayload(p Payload) error {
blk := fileBlock{Offset: w.pos, Meta: 0, Body: p.size}
n, err := writeIPCPayload(w, p)
if err != nil {
Expand All @@ -77,11 +71,6 @@ func (w *pwriter) WritePayload(p Payload) error {

blk.Meta = int32(n)

err = w.updatePos()
if err != nil {
return fmt.Errorf("arrow/ipc: could not update position while in write-payload: %w", err)
}

switch flatbuf.MessageHeader(p.msg) {
case flatbuf.MessageHeaderDictionaryBatch:
w.dicts = append(w.dicts, blk)
Expand All @@ -92,27 +81,18 @@ func (w *pwriter) WritePayload(p Payload) error {
return nil
}

func (w *pwriter) Close() error {
func (w *fileWriter) Close() error {
var err error

// write file footer
err = w.updatePos()
if err != nil {
return fmt.Errorf("arrow/ipc: could not update position while in close: %w", err)
if err = w.streamWriter.Close(); err != nil {
return err
}

pos := w.pos
err = writeFileFooter(w.schema, w.dicts, w.recs, w)
if err != nil {
if err = writeFileFooter(w.schema, w.dicts, w.recs, w); err != nil {
return fmt.Errorf("arrow/ipc: could not write file footer: %w", err)
}

// write file footer length
err = w.updatePos() // not strictly needed as we passed w to writeFileFooter...
if err != nil {
return fmt.Errorf("arrow/ipc: could not compute file footer length: %w", err)
}

size := w.pos - pos
if size <= 0 {
return fmt.Errorf("arrow/ipc: invalid file footer size (size=%d)", size)
Expand All @@ -133,13 +113,7 @@ func (w *pwriter) Close() error {
return nil
}

// updatePos refreshes w.pos with the underlying writer's current offset,
// obtained via Seek(0, io.SeekCurrent), and returns the error from that Seek.
// w.pos is overwritten with whatever Seek returned even when it fails.
func (w *pwriter) updatePos() error {
	offset, seekErr := w.w.Seek(0, io.SeekCurrent)
	w.pos = offset
	return seekErr
}

func (w *pwriter) align(align int32) error {
func (w *fileWriter) align(align int32) error {
remainder := paddedLength(w.pos, align) - w.pos
if remainder == 0 {
return nil
Expand All @@ -149,12 +123,6 @@ func (w *pwriter) align(align int32) error {
return err
}

// Write forwards p to the underlying writer and advances w.pos by the number
// of bytes reported as written, including on a partial write that also
// returns an error.
func (w *pwriter) Write(p []byte) (int, error) {
	written, writeErr := w.w.Write(p)
	w.pos += int64(written)
	return written, writeErr
}

func writeIPCPayload(w io.Writer, p Payload) (int, error) {
n, err := writeMessage(p.meta, kArrowIPCAlignment, w)
if err != nil {
Expand Down Expand Up @@ -259,18 +227,12 @@ func (ps payloads) Release() {

// FileWriter is an Arrow file writer.
type FileWriter struct {
w io.WriteSeeker
w io.Writer

mem memory.Allocator

header struct {
started bool
offset int64
}

footer struct {
written bool
}
headerStarted bool
footerWritten bool

pw PayloadWriter

Expand All @@ -289,15 +251,15 @@ type FileWriter struct {
}

// NewFileWriter opens an Arrow file using the provided writer w.
func NewFileWriter(w io.WriteSeeker, opts ...Option) (*FileWriter, error) {
func NewFileWriter(w io.Writer, opts ...Option) (*FileWriter, error) {
var (
cfg = newConfig(opts...)
err error
)

f := FileWriter{
w: w,
pw: &pwriter{w: w, schema: cfg.schema, pos: -1},
pw: &fileWriter{streamWriter: streamWriter{w: w}, schema: cfg.schema},
mem: cfg.alloc,
schema: cfg.schema,
codec: cfg.codec,
Expand All @@ -306,12 +268,6 @@ func NewFileWriter(w io.WriteSeeker, opts ...Option) (*FileWriter, error) {
compressors: make([]compressor, cfg.compressNP),
}

pos, err := f.w.Seek(0, io.SeekCurrent)
if err != nil {
return nil, fmt.Errorf("arrow/ipc: could not seek current position: %w", err)
}
f.header.offset = pos

return &f, err
}

Expand All @@ -321,15 +277,15 @@ func (f *FileWriter) Close() error {
return fmt.Errorf("arrow/ipc: could not write empty file: %w", err)
}

if f.footer.written {
if f.footerWritten {
return nil
}

err = f.pw.Close()
if err != nil {
return fmt.Errorf("arrow/ipc: could not close payload writer: %w", err)
}
f.footer.written = true
f.footerWritten = true

return nil
}
Expand Down Expand Up @@ -367,14 +323,14 @@ func (f *FileWriter) Write(rec arrow.Record) error {
}

// checkStarted calls f.start() and returns its result when the started flag
// is not yet set; if the flag is already set it returns nil.
//
// NOTE(review): the two consecutive `if` lines look like the old
// (f.header.started) and new (f.headerStarted) versions of a single line in a
// rendered diff — confirm against the actual file.
func (f *FileWriter) checkStarted() error {
if !f.header.started {
if !f.headerStarted {
return f.start()
}
return nil
}

func (f *FileWriter) start() error {
f.header.started = true
f.headerStarted = true
err := f.pw.Start()
if err != nil {
return err
Expand Down
12 changes: 6 additions & 6 deletions go/arrow/ipc/writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,26 +37,26 @@ import (
"github.com/apache/arrow/go/v18/internal/utils"
)

// streamWriter wraps an io.Writer (w) and carries a byte counter (pos).
//
// NOTE(review): the swriter and streamWriter header lines look like the old
// and new names of the same type in a rendered diff — confirm against the
// actual file.
type swriter struct {
type streamWriter struct {
w io.Writer
pos int64
}

// Start does nothing and always returns nil.
//
// Close writes the kEOS bytes through w.Write and returns the error from
// that write; the byte count is discarded.
//
// NOTE(review): each method header appears twice (swriter, then streamWriter),
// which looks like the old and new receiver names in a rendered diff —
// confirm against the actual file.
func (w *swriter) Start() error { return nil }
func (w *swriter) Close() error {
func (w *streamWriter) Start() error { return nil }
func (w *streamWriter) Close() error {
_, err := w.Write(kEOS[:])
return err
}

// WritePayload hands p to writeIPCPayload with w as the destination writer.
// The returned byte count is discarded; any error is returned unchanged and
// nil is returned on success.
//
// NOTE(review): the header appears twice (swriter, then streamWriter), which
// looks like the old and new receiver names in a rendered diff — confirm
// against the actual file.
func (w *swriter) WritePayload(p Payload) error {
func (w *streamWriter) WritePayload(p Payload) error {
_, err := writeIPCPayload(w, p)
if err != nil {
return err
}
return nil
}

func (w *swriter) Write(p []byte) (int, error) {
func (w *streamWriter) Write(p []byte) (int, error) {
n, err := w.w.Write(p)
w.pos += int64(n)
return n, err
Expand Down Expand Up @@ -118,7 +118,7 @@ func NewWriter(w io.Writer, opts ...Option) *Writer {
return &Writer{
w: w,
mem: cfg.alloc,
pw: &swriter{w: w},
pw: &streamWriter{w: w},
schema: cfg.schema,
codec: cfg.codec,
emitDictDeltas: cfg.emitDictDeltas,
Expand Down
Loading