diff --git a/stream_writer.go b/stream_writer.go index 1cad64ce7..e696a6d3b 100644 --- a/stream_writer.go +++ b/stream_writer.go @@ -104,6 +104,14 @@ func (sw *StreamWriter) PrepareIncremental() error { } sw.done = func() { once.Do(done) } + mts, decr := sw.db.getMemTables() + defer decr() + for _, m := range mts { + if !m.sl.Empty() { + return fmt.Errorf("Unable to do incremental writes because MemTable has data") + } + } + isEmptyDB := true for _, level := range sw.db.Levels() { if level.NumTables > 0 { @@ -117,7 +125,10 @@ func (sw *StreamWriter) PrepareIncremental() error { return nil } if sw.prevLevel == 0 { - return fmt.Errorf("Unable to do incremental writes because L0 has data") + if err := sw.db.Flatten(3); err != nil { + return errors.Wrapf(err, "error during flatten in StreamWriter") + } + sw.prevLevel = len(sw.db.Levels()) - 1 } return nil } diff --git a/stream_writer_test.go b/stream_writer_test.go index da4bde736..43abc6ce8 100644 --- a/stream_writer_test.go +++ b/stream_writer_test.go @@ -602,7 +602,7 @@ func TestStreamWriterWithLargeValue(t *testing.T) { } func TestStreamWriterIncremental(t *testing.T) { - addIncremtal := func(t *testing.T, db *DB, keys [][]byte) { + addIncremental := func(t *testing.T, db *DB, keys [][]byte) { buf := z.NewBuffer(10<<20, "test") defer func() { require.NoError(t, buf.Release()) }() for _, key := range keys { @@ -633,7 +633,7 @@ func TestStreamWriterIncremental(t *testing.T) { require.NoError(t, sw.Write(buf), "sw.Write() failed") require.NoError(t, sw.Flush(), "sw.Flush() failed") - addIncremtal(t, db, [][]byte{[]byte("key-2")}) + addIncremental(t, db, [][]byte{[]byte("key-2")}) txn := db.NewTransaction(false) defer txn.Discard() @@ -646,7 +646,7 @@ func TestStreamWriterIncremental(t *testing.T) { t.Run("incremental on empty DB", func(t *testing.T) { runBadgerTest(t, nil, func(t *testing.T, db *DB) { - addIncremtal(t, db, [][]byte{[]byte("key-1")}) + addIncremental(t, db, [][]byte{[]byte("key-1")}) txn := db.NewTransaction(false) defer txn.Discard() _, err := txn.Get([]byte("key-1")) @@ -656,9 +656,9 @@ func TestStreamWriterIncremental(t *testing.T) { t.Run("multiple incremental", func(t *testing.T) { runBadgerTest(t, nil, func(t *testing.T, db *DB) { - addIncremtal(t, db, [][]byte{[]byte("a1"), []byte("c1")}) - addIncremtal(t, db, [][]byte{[]byte("a2"), []byte("c2")}) - addIncremtal(t, db, [][]byte{[]byte("a3"), []byte("c3")}) + addIncremental(t, db, [][]byte{[]byte("a1"), []byte("c1")}) + addIncremental(t, db, [][]byte{[]byte("a2"), []byte("c2")}) + addIncremental(t, db, [][]byte{[]byte("a3"), []byte("c3")}) txn := db.NewTransaction(false) defer txn.Discard() _, err := txn.Get([]byte("a1")) @@ -675,4 +675,79 @@ func TestStreamWriterIncremental(t *testing.T) { require.NoError(t, err) }) }) + + t.Run("write between incremental writes", func(t *testing.T) { + runBadgerTest(t, nil, func(t *testing.T, db *DB) { + addIncremental(t, db, [][]byte{[]byte("a1"), []byte("c1")}) + require.NoError(t, db.Update(func(txn *Txn) error { + return txn.Set([]byte("a3"), []byte("c3")) + })) + + sw := db.NewStreamWriter() + defer sw.Cancel() + require.EqualError(t, sw.PrepareIncremental(), "Unable to do incremental writes because MemTable has data") + + txn := db.NewTransaction(false) + defer txn.Discard() + _, err := txn.Get([]byte("a1")) + require.NoError(t, err) + _, err = txn.Get([]byte("c1")) + require.NoError(t, err) + _, err = txn.Get([]byte("a3")) + require.NoError(t, err) + }) + }) + + t.Run("incremental writes > #levels", func(t *testing.T) { + runBadgerTest(t, nil, func(t *testing.T, db *DB) { + addIncremental(t, db, [][]byte{[]byte("a1"), []byte("c1")}) + addIncremental(t, db, [][]byte{[]byte("a2"), []byte("c2")}) + addIncremental(t, db, [][]byte{[]byte("a3"), []byte("c3")}) + addIncremental(t, db, [][]byte{[]byte("a4"), []byte("c4")}) + addIncremental(t, db, [][]byte{[]byte("a5"), []byte("c5")}) + addIncremental(t, db, [][]byte{[]byte("a6"), []byte("c6")}) + addIncremental(t, db, [][]byte{[]byte("a7"), []byte("c7")}) + addIncremental(t, db, [][]byte{[]byte("a8"), []byte("c8")}) + addIncremental(t, db, [][]byte{[]byte("a9"), []byte("c9")}) + + txn := db.NewTransaction(false) + defer txn.Discard() + _, err := txn.Get([]byte("a1")) + require.NoError(t, err) + _, err = txn.Get([]byte("c1")) + require.NoError(t, err) + _, err = txn.Get([]byte("a2")) + require.NoError(t, err) + _, err = txn.Get([]byte("c2")) + require.NoError(t, err) + _, err = txn.Get([]byte("a3")) + require.NoError(t, err) + _, err = txn.Get([]byte("c3")) + require.NoError(t, err) + _, err = txn.Get([]byte("a4")) + require.NoError(t, err) + _, err = txn.Get([]byte("c4")) + require.NoError(t, err) + _, err = txn.Get([]byte("a5")) + require.NoError(t, err) + _, err = txn.Get([]byte("c5")) + require.NoError(t, err) + _, err = txn.Get([]byte("a6")) + require.NoError(t, err) + _, err = txn.Get([]byte("c6")) + require.NoError(t, err) + _, err = txn.Get([]byte("a7")) + require.NoError(t, err) + _, err = txn.Get([]byte("c7")) + require.NoError(t, err) + _, err = txn.Get([]byte("a8")) + require.NoError(t, err) + _, err = txn.Get([]byte("c8")) + require.NoError(t, err) + _, err = txn.Get([]byte("a9")) + require.NoError(t, err) + _, err = txn.Get([]byte("c9")) + require.NoError(t, err) + }) + }) }