diff --git a/snapshot/snapshot_test.go b/snapshot/snapshot_test.go new file mode 100644 index 00000000..bccf0bc5 --- /dev/null +++ b/snapshot/snapshot_test.go @@ -0,0 +1,72 @@ +package snapshot + +import ( + "bytes" + "io" + "testing" +) + +func Test_SnapshotNew(t *testing.T) { + // Create a new snapshot + s := NewSnapshot(nil) + if s == nil { + t.Errorf("expected snapshot to be created") + } +} + +// Test_SnapshotPersist_NilData tests that Persist does not error when +// given a nil data buffer. +func Test_SnapshotPersist_NilData(t *testing.T) { + compactedBuf := bytes.NewBuffer(nil) + s := NewSnapshot(io.NopCloser(compactedBuf)) + if s == nil { + t.Errorf("expected snapshot to be created") + } + + mrs := &mockRaftSink{} + err := s.Persist(mrs) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if len(mrs.buf.Bytes()) != 0 { + t.Errorf("expected %d, got %d", 0, len(mrs.buf.Bytes())) + } +} + +func Test_SnapshotPersist_SimpleData(t *testing.T) { + compactedBuf := bytes.NewBuffer([]byte("hello world")) + s := NewSnapshot(io.NopCloser(compactedBuf)) + if s == nil { + t.Errorf("expected snapshot to be created") + } + + mrs := &mockRaftSink{} + err := s.Persist(mrs) + if err != nil { + t.Errorf("expected no error, got %v", err) + } + if mrs.buf.String() != "hello world" { + t.Errorf("expected %s, got %s", "hello world", mrs.buf.String()) + } +} + +type mockRaftSink struct { + buf bytes.Buffer +} + +func (mrs *mockRaftSink) Write(p []byte) (n int, err error) { + return mrs.buf.Write(p) +} + +func (mrs *mockRaftSink) Close() error { + return nil +} + +// implement cancel +func (mrs *mockRaftSink) Cancel() error { + return nil +} + +func (mrs *mockRaftSink) ID() string { + return "" +} diff --git a/store/store.go b/store/store.go index 4fc3d17a..cd7ac488 100644 --- a/store/store.go +++ b/store/store.go @@ -1780,7 +1780,7 @@ func (s *Store) Snapshot() (raft.FSMSnapshot, error) { } else { compactedBuf := bytes.NewBuffer(nil) var err error - if pathExists(s.db.WALPath()) { + if pathExistsWithData(s.db.WALPath()) { walFD, err := os.Open(s.db.WALPath()) if err != nil { return nil, err @@ -2390,6 +2390,17 @@ func pathExists(p string) bool { return true } +// pathExistsWithData returns true if the given path exists and has data. +func pathExistsWithData(p string) bool { + if !pathExists(p) { + return false + } + if size, err := fileSize(p); err != nil || size == 0 { + return false + } + return true +} + func dirExists(path string) bool { stat, err := os.Stat(path) return err == nil && stat.IsDir()