diff --git a/commitlog/commitlog.go b/commitlog/commitlog.go index 585e421b..45eb550b 100644 --- a/commitlog/commitlog.go +++ b/commitlog/commitlog.go @@ -12,6 +12,10 @@ import ( "github.com/pkg/errors" ) +var ( + ErrSegmentNotFound = errors.New("segment not found") +) + type CommitLog struct { Options name string @@ -138,6 +142,29 @@ func (l *CommitLog) DeleteAll() error { return os.RemoveAll(l.Path) } +func (l *CommitLog) TruncateTo(offset int64) error { + l.mu.Lock() + defer l.mu.Unlock() + var segments []*Segment + for _, segment := range l.segments { + if segment.BaseOffset < offset { + if err := segment.Delete(); err != nil { + return err + } + } else { + segments = append(segments, segment) + } + } + l.segments = segments + return nil +} + +func (l *CommitLog) Segments() []*Segment { + l.mu.Lock() + defer l.mu.Unlock() + return l.segments +} + func (l *CommitLog) checkSplit() bool { return l.activeSegment().IsFull() } diff --git a/commitlog/commitlog_test.go b/commitlog/commitlog_test.go index 5082050b..95794f6e 100644 --- a/commitlog/commitlog_test.go +++ b/commitlog/commitlog_test.go @@ -106,6 +106,52 @@ func TestNewCommitLogExisting(t *testing.T) { } } +func TestTruncateTo(t *testing.T) { + var err error + l0 := setup(t) + defer cleanup(t) + + for _, msgSet := range msgSets { + _, err = l0.Append(msgSet) + assert.NoError(t, err) + } + assert.Equal(t, int64(2), l0.NewestOffset()) + assert.Equal(t, 2, len(l0.Segments())) + + err = l0.TruncateTo(int64(1)) + assert.NoError(t, err) + assert.Equal(t, 1, len(l0.Segments())) + + maxBytes := msgSets[0].Size() + _, err = l0.NewReader(ReaderOptions{ + Offset: 0, + MaxBytes: maxBytes, + }) + assert.Error(t, err) + + r, err := l0.NewReader(ReaderOptions{ + Offset: 1, + MaxBytes: maxBytes, + }) + assert.NoError(t, err) + + for i, _ := range msgSets[1:] { + p := make([]byte, maxBytes) + _, err = r.Read(p) + assert.NoError(t, err) + + ms := MessageSet(p) + assert.Equal(t, int64(i+1), ms.Offset()) + + payload := ms.Payload() + var offset int + for _, msg := range msgs { + assert.Equal(t, []byte(msg), payload[offset:offset+len(msg)]) + offset += len(msg) + } + } +} + func check(t *testing.T, got, want []byte) { if !bytes.Equal(got, want) { t.Errorf("got = %s, want %s", string(got), string(want)) diff --git a/commitlog/index.go b/commitlog/index.go index f0f2db7b..78a70cec 100644 --- a/commitlog/index.go +++ b/commitlog/index.go @@ -100,7 +100,7 @@ func (idx *index) ReadEntry(e *Entry, offset int64) error { rel := &relEntry{} err := binary.Read(b, binary.BigEndian, rel) if err != nil { - return errors.Wrap(err, "binar read failed") + return errors.Wrap(err, "binary read failed") } idx.mu.RLock() rel.fill(e, idx.baseOffset) @@ -139,3 +139,7 @@ func (idx *index) Sync() error { func (idx *index) Close() (err error) { return idx.file.Close() } + +func (idx *index) Name() string { + return idx.file.Name() +} diff --git a/commitlog/reader.go b/commitlog/reader.go index fd40df8a..e871781c 100644 --- a/commitlog/reader.go +++ b/commitlog/reader.go @@ -3,26 +3,26 @@ package commitlog import ( "io" "sync" - - "github.com/pkg/errors" ) type Reader struct { ReaderOptions - segment *Segment - segments []*Segment - idx int - mu sync.Mutex - position int64 + commitlog *CommitLog + idx int + mu sync.Mutex + position int64 } func (r *Reader) Read(p []byte) (n int, err error) { r.mu.Lock() defer r.mu.Unlock() + segments := r.commitlog.Segments() + segment := segments[r.idx] + var readSize int for { - readSize, err = r.segment.ReadAt(p[n:], r.position) + readSize, err = segment.ReadAt(p[n:], r.position) n += readSize r.position += int64(readSize) if readSize != 0 && err == nil { @@ -31,12 +31,12 @@ func (r *Reader) Read(p []byte) (n int, err error) { if n == len(p) || err != io.EOF { break } - if len(r.segments) <= r.idx+1 { + if len(segments) <= r.idx+1 { err = io.EOF break } r.idx++ - r.segment = r.segments[r.idx] + segment = segments[r.idx] r.position = 0 } @@ -50,18 +50,16 @@ type ReaderOptions struct { } func (l *CommitLog) NewReader(options ReaderOptions) (r *Reader, err error) { - segment, idx := findSegment(l.segments, options.Offset) - entry, _ := segment.findEntry(options.Offset) - position := entry.Position - + segment, idx := findSegment(l.Segments(), options.Offset) if segment == nil { - return nil, errors.Wrap(err, "segment not found") + return nil, ErrSegmentNotFound } + entry, _ := segment.findEntry(options.Offset) + position := entry.Position return &Reader{ ReaderOptions: options, - segment: segment, - segments: l.segments, + commitlog: l, idx: idx, position: position, }, nil diff --git a/commitlog/segment.go b/commitlog/segment.go index ebb86da0..7d9c953f 100644 --- a/commitlog/segment.go +++ b/commitlog/segment.go @@ -165,3 +165,18 @@ func (s *Segment) findEntry(offset int64) (e *Entry, err error) { } return e, nil } + +func (s *Segment) Delete() error { + if err := s.Close(); err != nil { + return err + } + s.Lock() + defer s.Unlock() + if err := os.Remove(s.log.Name()); err != nil { + return err + } + if err := os.Remove(s.Index.Name()); err != nil { + return err + } + return nil +}