diff --git a/stream_writer.go b/stream_writer.go index 3d2a7992e..877d2df25 100644 --- a/stream_writer.go +++ b/stream_writer.go @@ -18,6 +18,7 @@ package badger import ( "math" + "sync" "github.com/dgraph-io/badger/pb" "github.com/dgraph-io/badger/table" @@ -40,6 +41,7 @@ const headStreamId uint32 = math.MaxUint32 // StreamWriter should not be called on in-use DB instances. It is designed only to bootstrap new // DBs. type StreamWriter struct { + writeLock sync.Mutex db *DB done func() throttle *y.Throttle @@ -68,13 +70,17 @@ func (db *DB) NewStreamWriter() *StreamWriter { // calling Prepare, because it could result in permanent data loss. Not calling Prepare would result // in a corrupt Badger instance. func (sw *StreamWriter) Prepare() error { + sw.writeLock.Lock() + defer sw.writeLock.Unlock() + var err error sw.done, err = sw.db.dropAll() return err } // Write writes KVList to DB. Each KV within the list contains the stream id which StreamWriter -// would use to demux the writes. Write is not thread safe and it should NOT be called concurrently. +// would use to demux the writes. Write is thread safe and can be called concurrently by multiple +// goroutines. func (sw *StreamWriter) Write(kvs *pb.KVList) error { if len(kvs.GetKv()) == 0 { return nil @@ -112,6 +118,9 @@ func (sw *StreamWriter) Write(kvs *pb.KVList) error { for _, req := range streamReqs { all = append(all, req) } + + sw.writeLock.Lock() + defer sw.writeLock.Unlock() if err := sw.db.vlog.write(all); err != nil { return err } @@ -130,6 +139,9 @@ func (sw *StreamWriter) Write(kvs *pb.KVList) error { // Flush is called once we are done writing all the entries. It syncs DB directories. It also // updates Oracle with maxVersion found in all entries (if DB is not managed). func (sw *StreamWriter) Flush() error { + sw.writeLock.Lock() + defer sw.writeLock.Unlock() + defer sw.done() sw.closer.SignalAndWait()