Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making the stream writer APIs goroutine-safe #959

Merged
merged 4 commits into from
Aug 2, 2019
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion stream_writer.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package badger

import (
"math"
"sync"

"github.com/dgraph-io/badger/pb"
"github.com/dgraph-io/badger/table"
Expand All @@ -40,6 +41,7 @@ const headStreamId uint32 = math.MaxUint32
// StreamWriter should not be called on in-use DB instances. It is designed only to bootstrap new
// DBs.
type StreamWriter struct {
writeLock sync.Mutex
db *DB
done func()
throttle *y.Throttle
Expand Down Expand Up @@ -68,13 +70,17 @@ func (db *DB) NewStreamWriter() *StreamWriter {
// calling Prepare, because it could result in permanent data loss. Not calling Prepare would result
// in a corrupt Badger instance.
func (sw *StreamWriter) Prepare() error {
sw.writeLock.Lock()
defer sw.writeLock.Unlock()

var err error
sw.done, err = sw.db.dropAll()
return err
}

// Write writes KVList to DB. Each KV within the list contains the stream id which StreamWriter
// would use to demux the writes. Write is not thread safe and it should NOT be called concurrently.
// would use to demux the writes. Write is thread safe and can be called concurrently by mulitple
// goroutines.
func (sw *StreamWriter) Write(kvs *pb.KVList) error {
if len(kvs.GetKv()) == 0 {
return nil
Expand Down Expand Up @@ -112,6 +118,9 @@ func (sw *StreamWriter) Write(kvs *pb.KVList) error {
for _, req := range streamReqs {
all = append(all, req)
}

sw.writeLock.Lock()
defer sw.writeLock.Unlock()
if err := sw.db.vlog.write(all); err != nil {
return err
}
Expand All @@ -130,6 +139,9 @@ func (sw *StreamWriter) Write(kvs *pb.KVList) error {
// Flush is called once we are done writing all the entries. It syncs DB directories. It also
// updates Oracle with maxVersion found in all entries (if DB is not managed).
func (sw *StreamWriter) Flush() error {
sw.writeLock.Lock()
defer sw.writeLock.Unlock()

defer sw.done()

sw.closer.SignalAndWait()
Expand Down