Skip to content

Commit

Permalink
Making the stream writer APIs goroutine-safe (#959)
Browse files Browse the repository at this point in the history
(cherry picked from commit f59246c)
  • Loading branch information
Lucas Wang authored and Ibrahim Jarif committed Mar 10, 2020
1 parent 424742d commit 473a2f5
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion stream_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package badger

import (
"math"
"sync"

"github.com/dgraph-io/badger/pb"
"github.com/dgraph-io/badger/table"
Expand All @@ -40,6 +41,7 @@ const headStreamId uint32 = math.MaxUint32
// StreamWriter should not be called on in-use DB instances. It is designed only to bootstrap new
// DBs.
type StreamWriter struct {
writeLock sync.Mutex
db *DB
done func()
throttle *y.Throttle
Expand Down Expand Up @@ -68,13 +70,17 @@ func (db *DB) NewStreamWriter() *StreamWriter {
// calling Prepare, because it could result in permanent data loss. Not calling Prepare would result
// in a corrupt Badger instance.
func (sw *StreamWriter) Prepare() error {
sw.writeLock.Lock()
defer sw.writeLock.Unlock()

var err error
sw.done, err = sw.db.dropAll()
return err
}

// Write writes KVList to DB. Each KV within the list contains the stream id which StreamWriter
// would use to demux the writes. Write is not thread safe and it should NOT be called concurrently.
// would use to demux the writes. Write is thread safe and can be called concurrently by mulitple
// goroutines.
func (sw *StreamWriter) Write(kvs *pb.KVList) error {
if len(kvs.GetKv()) == 0 {
return nil
Expand Down Expand Up @@ -112,6 +118,9 @@ func (sw *StreamWriter) Write(kvs *pb.KVList) error {
for _, req := range streamReqs {
all = append(all, req)
}

sw.writeLock.Lock()
defer sw.writeLock.Unlock()
if err := sw.db.vlog.write(all); err != nil {
return err
}
Expand All @@ -130,6 +139,9 @@ func (sw *StreamWriter) Write(kvs *pb.KVList) error {
// Flush is called once we are done writing all the entries. It syncs DB directories. It also
// updates Oracle with maxVersion found in all entries (if DB is not managed).
func (sw *StreamWriter) Flush() error {
sw.writeLock.Lock()
defer sw.writeLock.Unlock()

defer sw.done()

sw.closer.SignalAndWait()
Expand Down

0 comments on commit 473a2f5

Please sign in to comment.