From 5e9750f7a6bbaa4bf0155b36c9573703bbebc3f7 Mon Sep 17 00:00:00 2001 From: David Norton Date: Mon, 20 Oct 2014 13:12:59 -0400 Subject: [PATCH] Revert to older snapshots if the latest raft snapshot is corrupted Fix #1040 --- _vendor/raft/server.go | 98 +++++++++++++++-------------------- _vendor/raft/snapshot.go | 51 ++++++++++++++++++ _vendor/raft/snapshot_test.go | 32 ++++++++++++ coordinator/raft_server.go | 6 ++- 4 files changed, 131 insertions(+), 56 deletions(-) diff --git a/_vendor/raft/server.go b/_vendor/raft/server.go index 8a9d05c15e6..620eed37e0f 100644 --- a/_vendor/raft/server.go +++ b/_vendor/raft/server.go @@ -4,10 +4,10 @@ import ( "encoding/json" "errors" "fmt" - "hash/crc32" "io/ioutil" "os" "path" + "path/filepath" "sort" "sync" "time" @@ -1253,6 +1253,31 @@ func (s *server) saveSnapshot() error { return nil } +// Returns a list of available snapshot names sorted newest to oldest +func (s *server) SnapshotList() ([]string, error) { + // Get FileInfo for everything in the snapshot dir + ssdir := path.Join(s.path, "snapshot") + finfos, err := ioutil.ReadDir(ssdir) + + if err != nil { + return nil, err + } + + // Build a list of snapshot file names + var ssnames []string + for _, finfo := range finfos { + fname := finfo.Name() + if finfo.Mode().IsRegular() && filepath.Ext(fname) == ".ss" { + ssnames = append(ssnames, fname) + } + } + + // Sort snapshot names from newest to oldest + sort.Sort(sort.Reverse(sort.StringSlice(ssnames))) + + return ssnames, nil +} + // Retrieves the log path for the server. func (s *server) SnapshotPath(lastIndex uint64, lastTerm uint64) string { return path.Join(s.path, "snapshot", fmt.Sprintf("%v_%v.ss", lastTerm, lastIndex)) @@ -1314,80 +1339,43 @@ func (s *server) processSnapshotRecoveryRequest(req *SnapshotRecoveryRequest) *S // Load a snapshot at restart func (s *server) LoadSnapshot() error { - // Open snapshot/ directory. - dir, err := os.OpenFile(path.Join(s.path, "snapshot"), os.O_RDONLY, 0) - if err != nil { - s.debugln("cannot.open.snapshot: ", err) - return err - } - - // Retrieve a list of all snapshots. - filenames, err := dir.Readdirnames(-1) - if err != nil { - dir.Close() - panic(err) - } - dir.Close() - - if len(filenames) == 0 { - s.debugln("no.snapshot.to.load") - return nil - } - - // Grab the latest snapshot. - sort.Strings(filenames) - snapshotPath := path.Join(s.path, "snapshot", filenames[len(filenames)-1]) - - // Read snapshot data. - file, err := os.OpenFile(snapshotPath, os.O_RDONLY, 0) + sslist, err := s.SnapshotList() if err != nil { return err } - defer file.Close() - // Check checksum. - var checksum uint32 - n, err := fmt.Fscanf(file, "%08x\n", &checksum) - if err != nil { - return err - } else if n != 1 { - return errors.New("checksum.err: bad.snapshot.file") + // Load most recent snapshot (falling back to older snapshots if needed) + var ss *Snapshot + for _, ssname := range sslist { + ssFullPath := path.Join(s.path, "snapshot", ssname) + ss, err = loadSnapshot(ssFullPath) + if err == nil { + break + } + s.debugln(err) } - // Load remaining snapshot contents. - b, err := ioutil.ReadAll(file) if err != nil { - return err + return err // couldn't load any of the snapshots } - // Generate checksum. - byteChecksum := crc32.ChecksumIEEE(b) - if uint32(checksum) != byteChecksum { - s.debugln(checksum, " ", byteChecksum) - return errors.New("bad snapshot file") - } - - // Decode snapshot. - if err = json.Unmarshal(b, &s.snapshot); err != nil { - s.debugln("unmarshal.snapshot.error: ", err) - return err - } + s.snapshot = ss // Recover snapshot into state machine. - if err = s.stateMachine.Recovery(s.snapshot.State); err != nil { + if err = s.stateMachine.Recovery(ss.State); err != nil { s.debugln("recovery.snapshot.error: ", err) return err } // Recover cluster configuration. - for _, peer := range s.snapshot.Peers { + for _, peer := range ss.Peers { s.AddPeer(peer.Name, peer.ConnectionString) } // Update log state. - s.log.startTerm = s.snapshot.LastTerm - s.log.startIndex = s.snapshot.LastIndex - s.log.updateCommitIndex(s.snapshot.LastIndex) + s.log.startTerm = ss.LastTerm + s.log.startIndex = ss.LastIndex + s.log.updateCommitIndex(ss.LastIndex) return err } diff --git a/_vendor/raft/snapshot.go b/_vendor/raft/snapshot.go index 41a4dac946d..07d2812069c 100644 --- a/_vendor/raft/snapshot.go +++ b/_vendor/raft/snapshot.go @@ -2,6 +2,7 @@ package raft import ( "encoding/json" + "errors" "fmt" "hash/crc32" "io" @@ -9,6 +10,7 @@ import ( "os" "code.google.com/p/gogoprotobuf/proto" + "github.com/dgnorton/goback" "github.com/influxdb/influxdb/_vendor/raft/protobuf" ) @@ -51,13 +53,60 @@ type SnapshotResponse struct { Success bool `json:"success"` } +// loadSnapshot reads a snapshot from file +func loadSnapshot(filename string) (*Snapshot, error) { + file, err := os.OpenFile(filename, os.O_RDONLY, 0) + if err != nil { + return nil, err + } + defer file.Close() + + // Check checksum. + var checksum uint32 + n, err := fmt.Fscanf(file, "%08x\n", &checksum) + if err != nil { + return nil, err + } else if n != 1 { + return nil, errors.New("checksum.err: bad.snapshot.file") + } + + // Load remaining snapshot contents. + b, err := ioutil.ReadAll(file) + if err != nil { + return nil, err + } + + // Generate checksum. + byteChecksum := crc32.ChecksumIEEE(b) + if uint32(checksum) != byteChecksum { + return nil, errors.New("bad snapshot file") + } + + // Decode snapshot. + var ss Snapshot + if err = json.Unmarshal(b, &ss); err != nil { + return nil, err + } + + return &ss, nil +} + // save writes the snapshot to file. func (ss *Snapshot) save() error { + tx := goback.Begin() + defer tx.Rollback() + // Open the file for writing. file, err := os.OpenFile(ss.Path, os.O_CREATE|os.O_WRONLY, 0600) if err != nil { return err } + tx.Exec(func() error { + if err := os.Remove(ss.Path); err != nil { + panic(err) + } + return nil + }) defer file.Close() // Serialize to JSON. @@ -82,6 +131,8 @@ func (ss *Snapshot) save() error { return err } + tx.Commit() + return nil } diff --git a/_vendor/raft/snapshot_test.go b/_vendor/raft/snapshot_test.go index 5d6eecb43d6..ac840e9559d 100644 --- a/_vendor/raft/snapshot_test.go +++ b/_vendor/raft/snapshot_test.go @@ -1,6 +1,10 @@ package raft import ( + "fmt" + "io/ioutil" + "os" + "path" "testing" "time" @@ -8,6 +12,34 @@ import ( "github.com/stretchr/testify/mock" ) +func TestSnapshotSave(t *testing.T) { + tstPath, err := ioutil.TempDir("", "snapshot_save") + if err != nil { + t.Errorf("Failed to create a temporary directory: %s", err) + } + defer os.RemoveAll(tstPath) + + index := uint64(1) + term := uint64(2) + ssfile := path.Join(tstPath, fmt.Sprintf("%v_%v.ss", term, index)) + + ss := &Snapshot{ + LastIndex: index, + LastTerm: term, + Path: ssfile, + } + + err = ss.save() + + if err != nil { + t.Error(err) + } + + if _, err := os.Stat(ssfile); os.IsNotExist(err) { + t.Errorf("Failed to create snapshot: %s", ssfile) + } +} + // Ensure that a snapshot occurs when there are existing logs. func TestSnapshot(t *testing.T) { runServerWithMockStateMachine(Leader, func(s Server, m *mock.Mock) { diff --git a/coordinator/raft_server.go b/coordinator/raft_server.go index c2d14a1e9ec..bb121a35241 100644 --- a/coordinator/raft_server.go +++ b/coordinator/raft_server.go @@ -391,7 +391,11 @@ func (s *RaftServer) startRaft() error { } s.raftServer.SetElectionTimeout(s.config.RaftTimeout.Duration) - s.raftServer.LoadSnapshot() // ignore errors + + err = s.raftServer.LoadSnapshot() + if err != nil { + panic(err) + } s.raftServer.AddEventListener(raft.StateChangeEventType, s.raftEventHandler)