From e7ab1389c9e302324413248da669f7246416256b Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 16 Dec 2023 11:26:05 -0500 Subject: [PATCH] Periodically log boot progres --- auto/backup/uploader.go | 19 ++--- http/service.go | 1 + progress/progress.go | 134 +++++++++++++++++++++++++++++++ progress/progress_test.go | 165 ++++++++++++++++++++++++++++++++++++++ store/store.go | 13 ++- 5 files changed, 316 insertions(+), 16 deletions(-) create mode 100644 progress/progress.go create mode 100644 progress/progress_test.go diff --git a/auto/backup/uploader.go b/auto/backup/uploader.go index 90b9e005..bf3e733e 100644 --- a/auto/backup/uploader.go +++ b/auto/backup/uploader.go @@ -9,6 +9,8 @@ import ( "log" "os" "time" + + "github.com/rqlite/rqlite/progress" ) // StorageClient is an interface for uploading data to a storage service. @@ -159,7 +161,7 @@ func (u *Uploader) upload(ctx context.Context) error { } defer fd.Close() - cr := &countingReader{reader: fd} + cr := progress.NewCountingReader(fd) startTime := time.Now() err = u.storageClient.Upload(ctx, cr) if err != nil { @@ -167,8 +169,8 @@ func (u *Uploader) upload(ctx context.Context) error { } else { u.lastSum = sum stats.Add(numUploadsOK, 1) - stats.Add(totalUploadBytes, cr.count) - stats.Get(lastUploadBytes).(*expvar.Int).Set(cr.count) + stats.Add(totalUploadBytes, cr.Count()) + stats.Get(lastUploadBytes).(*expvar.Int).Set(cr.Count()) u.lastUploadTime = time.Now() u.lastUploadDuration = time.Since(startTime) } @@ -218,17 +220,6 @@ func compressFromTo(from, to string) error { return nil } -type countingReader struct { - reader io.Reader - count int64 -} - -func (c *countingReader) Read(p []byte) (int, error) { - n, err := c.reader.Read(p) - c.count += int64(n) - return n, err -} - func tempFilename() (string, error) { f, err := os.CreateTemp("", "rqlite-upload") if err != nil { diff --git a/http/service.go b/http/service.go index 0fdf41ae..f760857f 100644 --- a/http/service.go +++ b/http/service.go @@ -765,6 +765,7 @@ func (s *Service) handleBoot(w http.ResponseWriter, r *http.Request, qp QueryPar return } + s.logger.Printf("starting boot process") _, err = s.store.ReadFrom(bufReader) if err != nil { http.Error(w, err.Error(), http.StatusServiceUnavailable) diff --git a/progress/progress.go b/progress/progress.go new file mode 100644 index 00000000..8dd7be1c --- /dev/null +++ b/progress/progress.go @@ -0,0 +1,134 @@ +package progress + +import ( + "context" + "io" + "sync" + "time" +) + +const ( + countingMonitorInterval = 10 * time.Second +) + +// CountingReader is an io.Reader that counts the number of bytes read. +type CountingReader struct { + reader io.Reader + + mu sync.RWMutex + count int64 +} + +// NewCountingReader returns a new CountingReader. +func NewCountingReader(reader io.Reader) *CountingReader { + return &CountingReader{reader: reader} +} + +// Read reads from the underlying reader, and counts the number of bytes read. +func (c *CountingReader) Read(p []byte) (int, error) { + n, err := c.reader.Read(p) + c.mu.Lock() + defer c.mu.Unlock() + c.count += int64(n) + return n, err +} + +// Count returns the number of bytes read. +func (c *CountingReader) Count() int64 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.count +} + +// CountingWriter is an io.Writer that counts the number of bytes written. +type CountingWriter struct { + writer io.Writer + + mu sync.RWMutex + count int64 +} + +// NewCountingWriter returns a new CountingWriter. +func NewCountingWriter(writer io.Writer) *CountingWriter { + return &CountingWriter{writer: writer} +} + +// Write writes to the underlying writer, and counts the number of bytes written. +func (c *CountingWriter) Write(p []byte) (int, error) { + n, err := c.writer.Write(p) + c.mu.Lock() + defer c.mu.Unlock() + c.count += int64(n) + return n, err +} + +// Count returns the number of bytes written. +func (c *CountingWriter) Count() int64 { + c.mu.RLock() + defer c.mu.RUnlock() + return c.count +} + +// LoggerFunc is a function that can be used to log the current count. +type LoggerFunc func(n int64) + +// Counter is an interface that can be used to get the current count. +type Counter interface { + Count() int64 +} + +// CountingMonitor is a monitor that periodically logs the current count. +type CountingMonitor struct { + loggerFn LoggerFunc + ctr Counter + + once sync.Once + cancel func() + doneCh chan struct{} +} + +// StartCountingMonitor starts a CountingMonitor. +func StartCountingMonitor(loggerFn LoggerFunc, ctr Counter) *CountingMonitor { + ctx, cancel := context.WithCancel(context.Background()) + + m := &CountingMonitor{ + loggerFn: loggerFn, + ctr: ctr, + cancel: cancel, + doneCh: make(chan struct{}), + } + go m.run(ctx) + return m +} + +func (cm *CountingMonitor) run(ctx context.Context) { + defer close(cm.doneCh) + + ticker := time.NewTicker(countingMonitorInterval) + defer ticker.Stop() + + ranOnce := false + for { + select { + case <-ctx.Done(): + if !ranOnce { + cm.runOnce() + } + return + case <-ticker.C: + cm.runOnce() + ranOnce = true + } + } +} + +func (cm *CountingMonitor) runOnce() { + cm.loggerFn(cm.ctr.Count()) +} + +func (m *CountingMonitor) StopAndWait() { + m.once.Do(func() { + m.cancel() + <-m.doneCh + }) +} diff --git a/progress/progress_test.go b/progress/progress_test.go new file mode 100644 index 00000000..d2058c24 --- /dev/null +++ b/progress/progress_test.go @@ -0,0 +1,165 @@ +package progress + +import ( + "bytes" + "io" + "strings" + "sync" + "testing" +) + +func TestCountingReader_Read(t *testing.T) { + tests := []struct { + name string + input string + readBufferSize int + wantCount int64 + }{ + { + name: "Read all at once", + input: "Hello, world!", + readBufferSize: 13, // Size of "Hello, world!" + wantCount: 13, + }, + { + name: "Read in small chunks", + input: "Hello, world!", + readBufferSize: 2, // Read in chunks of 2 bytes + wantCount: 13, + }, + { + name: "Empty input", + input: "", + readBufferSize: 10, + wantCount: 0, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + reader := strings.NewReader(tt.input) + countingReader := NewCountingReader(reader) + + buf := make([]byte, tt.readBufferSize) + var totalRead int64 + for { + n, err := countingReader.Read(buf) + totalRead += int64(n) + if err == io.EOF { + break + } + if err != nil { + t.Fatalf("Read() error = %v", err) + } + } + + if totalRead != tt.wantCount { + t.Fatalf("Total bytes read = %v, want %v", totalRead, tt.wantCount) + } + + if got := countingReader.Count(); got != tt.wantCount { + t.Fatalf("CountingReader.Count() = %v, want %v", got, tt.wantCount) + } + }) + } +} + +func TestCountingReader_ConcurrentReads(t *testing.T) { + input := "Concurrent reading test string" + reader := strings.NewReader(input) + countingReader := NewCountingReader(reader) + + var wg sync.WaitGroup + readFunc := func() { + defer wg.Done() + buf := make([]byte, 5) // Read in chunks of 5 bytes + for { + if _, err := countingReader.Read(buf); err == io.EOF { + break + } else if err != nil { + t.Errorf("Read() error = %v", err) + } + } + } + + // Simulate concurrent reads + for i := 0; i < 3; i++ { + wg.Add(1) + go readFunc() + } + wg.Wait() + + wantCount := int64(len(input)) + if got := countingReader.Count(); got != wantCount { + t.Fatalf("CountingReader.Count() after concurrent reads = %v, want %v", got, wantCount) + } +} + +func TestCountingWriter_Write(t *testing.T) { + tests := []struct { + name string + input string + wantCount int64 + }{ + { + name: "Write string", + input: "Hello, world!", + wantCount: 13, // Length of "Hello, world!" + }, + { + name: "Write empty string", + input: "", + wantCount: 0, + }, + { + name: "Write long string", + input: "This is a longer test string", + wantCount: 28, // Length of the string + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var buf bytes.Buffer + countingWriter := NewCountingWriter(&buf) + + n, err := countingWriter.Write([]byte(tt.input)) + if err != nil { + t.Fatalf("Write() error = %v", err) + } + if int64(n) != tt.wantCount { + t.Errorf("Written bytes = %v, want %v", n, tt.wantCount) + } + + if got := countingWriter.Count(); got != tt.wantCount { + t.Errorf("CountingWriter.Count() = %v, want %v", got, tt.wantCount) + } + }) + } +} + +func TestCountingWriter_ConcurrentWrites(t *testing.T) { + var buf bytes.Buffer + countingWriter := NewCountingWriter(&buf) + input := "Concurrent write test string" + wantCount := int64(len(input) * 3) // 3 goroutines writing the same string + + var wg sync.WaitGroup + writeFunc := func() { + defer wg.Done() + if _, err := countingWriter.Write([]byte(input)); err != nil { + t.Errorf("Write() error = %v", err) + } + } + + // Perform concurrent writes + for i := 0; i < 3; i++ { + wg.Add(1) + go writeFunc() + } + wg.Wait() + + if got := countingWriter.Count(); got != wantCount { + t.Errorf("CountingWriter.Count() after concurrent writes = %v, want %v", got, wantCount) + } +} diff --git a/store/store.go b/store/store.go index d2560a1f..46754486 100644 --- a/store/store.go +++ b/store/store.go @@ -25,6 +25,7 @@ import ( sql "github.com/rqlite/rqlite/db" wal "github.com/rqlite/rqlite/db/wal" rlog "github.com/rqlite/rqlite/log" + "github.com/rqlite/rqlite/progress" "github.com/rqlite/rqlite/snapshot" ) @@ -1282,11 +1283,19 @@ func (s *Store) ReadFrom(r io.Reader) (int64, error) { } defer f.Close() defer os.Remove(f.Name()) - n, err := io.Copy(f, r) + + cw := progress.NewCountingWriter(f) + cm := progress.StartCountingMonitor(func(n int64) { + s.logger.Printf("installed %d bytes", n) + }, cw) + n, err := func() (int64, error) { + defer cm.StopAndWait() + defer f.Close() + return io.Copy(cw, r) + }() if err != nil { return n, err } - f.Close() // Confirm the data is a valid SQLite database. if !sql.IsValidSQLiteFile(f.Name()) {