From 8bfa54d7009ac1a1e3dd3b3e255bd357b7360829 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Tue, 2 Jan 2024 23:38:04 -0500 Subject: [PATCH] Unit test CAS --- store/cas.go | 37 +++++++++++++++++++++++++++++++ store/cas_test.go | 28 ++++++++++++++++++++++++ store/store.go | 56 ++++++++++++++++++++++++++++++++++------------- 3 files changed, 106 insertions(+), 15 deletions(-) create mode 100644 store/cas.go create mode 100644 store/cas_test.go diff --git a/store/cas.go b/store/cas.go new file mode 100644 index 00000000..6fab20e9 --- /dev/null +++ b/store/cas.go @@ -0,0 +1,37 @@ +package store + +import ( + "errors" + "sync/atomic" +) + +var ( + // ErrCASConflict is returned when a CAS operation fails. + ErrCASConflict = errors.New("cas conflict") +) + +// CheckAndSet is a simple concurrency control mechanism that allows +// only one goroutine to execute a critical section at a time. +type CheckAndSet struct { + state atomic.Int32 +} + +// NewCheckAndSet creates a new CheckAndSet instance. +func NewCheckAndSet() *CheckAndSet { + return &CheckAndSet{} +} + +// Begin attempts to enter the critical section. If another goroutine +// is already in the critical section, Begin returns an error. +func (c *CheckAndSet) Begin() error { + if c.state.CompareAndSwap(0, 1) { + return nil + } else { + return ErrCASConflict + } +} + +// End exits the critical section. +func (c *CheckAndSet) End() { + c.state.Store(0) +} diff --git a/store/cas_test.go b/store/cas_test.go new file mode 100644 index 00000000..4f4a0d05 --- /dev/null +++ b/store/cas_test.go @@ -0,0 +1,28 @@ +package store + +import "testing" + +func Test_NewCAS(t *testing.T) { + cas := NewCheckAndSet() + if exp, got := int32(0), cas.state.Load(); exp != got { + t.Fatalf("expected %d, got %d", exp, got) + } +} + +func Test_CASBegin(t *testing.T) { + cas := NewCheckAndSet() + if err := cas.Begin(); err != nil { + t.Fatalf("expected nil, got %v", err) + } + + // Begin again, should fail + if err := cas.Begin(); err != ErrCASConflict { + t.Fatalf("expected %v, got %v", ErrCASConflict, err) + } + + // End, another begin should succeed + cas.End() + if err := cas.Begin(); err != nil { + t.Fatalf("expected nil, got %v", err) + } +} diff --git a/store/store.go b/store/store.go index af41f64e..5cc3fe3f 100644 --- a/store/store.go +++ b/store/store.go @@ -227,8 +227,6 @@ type Store struct { dbDir string // Path to directory containing SQLite file. db *sql.DB // The underlying SQLite store. - queryTxMu sync.RWMutex - dbAppliedIndexMu sync.RWMutex dbAppliedIndex uint64 appliedIdxUpdateDone chan struct{} @@ -245,6 +243,10 @@ type Store struct { snapshotWClose chan struct{} snapshotWDone chan struct{} + // Snapshotting syncronization + queryTxMu sync.RWMutex + snapshotCAS *CheckAndSet + // Latest log entry index actually reflected by the FSM. Due to Raft code // this value is not updated after a Snapshot-restore. fsmIndex uint64 @@ -348,6 +350,7 @@ func New(ly Layer, c *Config) *Store { logger: logger, notifyingNodes: make(map[string]*Server), ApplyTimeout: applyTimeout, + snapshotCAS: NewCheckAndSet(), } } @@ -1196,26 +1199,44 @@ func (s *Store) Backup(br *proto.BackupRequest, dst io.Writer) (retErr error) { } if br.Format == proto.BackupRequest_BACKUP_REQUEST_FORMAT_BINARY { - f, err := os.CreateTemp(s.dbDir, backupScatchPattern) - if err != nil { + // Snapshot to ensure the main SQLite file has all the latest data. + if err := s.Snapshot(0); err != nil { return err } - if err := f.Close(); err != nil { - return err - } - defer os.Remove(f.Name()) - if err := s.db.Backup(f.Name(), br.Vacuum); err != nil { + // Pause any snapshotting and which will allow us to read the SQLite + // file without it changing underneath us. Any new writes will be + // sent to the WAL. + if err := s.snapshotCAS.Begin(); err != nil { return err } + defer s.snapshotCAS.End() - of, err := os.Open(f.Name()) - if err != nil { - return err + var srcFD *os.File + var err error + if br.Vacuum { + // Vacuum requested, so need an intermediate file. + srcFD, err = os.CreateTemp(s.dbDir, backupScatchPattern) + if err != nil { + return err + } + defer os.Remove(srcFD.Name()) + defer srcFD.Close() + if err := s.db.Backup(srcFD.Name(), br.Vacuum); err != nil { + return err + } + if _, err := srcFD.Seek(0, io.SeekStart); err != nil { + return err + } + } else { + // Fast path -- direct copy. + srcFD, err := os.Open(s.dbPath) + if err != nil { + return err + } + defer srcFD.Close() } - defer of.Close() - - _, err = io.Copy(dst, of) + _, err = io.Copy(dst, srcFD) return err } else if br.Format == proto.BackupRequest_BACKUP_REQUEST_FORMAT_SQL { return s.db.Dump(dst) @@ -1706,6 +1727,11 @@ func (s *Store) Database(leader bool) ([]byte, error) { // http://sqlite.org/howtocorrupt.html states it is safe to copy or serialize the // database as long as no writes to the database are in progress. func (s *Store) fsmSnapshot() (fSnap raft.FSMSnapshot, retErr error) { + if err := s.snapshotCAS.Begin(); err != nil { + return nil, err + } + defer s.snapshotCAS.End() + startT := time.Now() defer func() { if retErr != nil {