diff --git a/merge.go b/merge.go index aef1410..6e6b8dd 100644 --- a/merge.go +++ b/merge.go @@ -120,7 +120,12 @@ func (r *mergedRowGroupRows) ReadRows(rows []Row) (n int, err error) { r.err = err return n, err } + c := r.cursors[0] heap.Pop(r) + if err := c.close(); err != nil { + r.err = err + return n, err + } } else { heap.Fix(r, 0) } diff --git a/merge_test.go b/merge_test.go index 4d877d5..5be47a6 100644 --- a/merge_test.go +++ b/merge_test.go @@ -17,6 +17,75 @@ const ( rowsPerGroup = benchmarkNumRows ) +type wrappedRowGroup struct { + parquet.RowGroup + rowsCallback func(parquet.Rows) parquet.Rows +} + +func (r wrappedRowGroup) Rows() parquet.Rows { + return r.rowsCallback(r.RowGroup.Rows()) +} + +type wrappedRows struct { + parquet.Rows + closed bool +} + +func (r *wrappedRows) Close() error { + r.closed = true + return r.Rows.Close() +} + +func TestMergeRowGroupsCursorsAreClosed(t *testing.T) { + + type model struct { + A int + } + + schema := parquet.SchemaOf(model{}) + options := []parquet.RowGroupOption{ + parquet.SortingColumns( + parquet.Ascending(schema.Columns()[0]...), + ), + } + + prng := rand.New(rand.NewSource(0)) + rowGroups := make([]parquet.RowGroup, numRowGroups) + rows := make([]*wrappedRows, 0, numRowGroups) + + for i := range rowGroups { + rowGroups[i] = wrappedRowGroup{ + RowGroup: sortedRowGroup(options, randomRowsOf(prng, rowsPerGroup, model{})...), + rowsCallback: func(r parquet.Rows) parquet.Rows { + wrapped := &wrappedRows{Rows: r} + rows = append(rows, wrapped) + return wrapped + }, + } + } + + m, err := parquet.MergeRowGroups(rowGroups, options...) + if err != nil { + t.Fatal(err) + } + func() { + mergedRows := m.Rows() + defer mergedRows.Close() + + // Add 1 more slot to the buffer to force an io.EOF on the first read. + rbuf := make([]parquet.Row, (numRowGroups*rowsPerGroup)+1) + if _, err := mergedRows.ReadRows(rbuf); !errors.Is(err, io.EOF) { + t.Fatal(err) + } + }() + + for i, wrapped := range rows { + if !wrapped.closed { + t.Fatalf("RowGroup %d not closed", i) + } + } +} + func BenchmarkMergeRowGroups(b *testing.B) { for _, test := range readerTests { b.Run(test.scenario, func(b *testing.B) {