From ba0be0dd4148ddd2ddb80e2148b70496c77035ec Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Mon, 16 May 2022 22:41:13 -0400 Subject: [PATCH] Batcher emits using channel --- queue/queue.go | 46 ++++++----------- queue/queue_test.go | 118 +++++++++++++++++++------------------------- 2 files changed, 66 insertions(+), 98 deletions(-) diff --git a/queue/queue.go b/queue/queue.go index f7175847..2a6794d5 100644 --- a/queue/queue.go +++ b/queue/queue.go @@ -6,35 +6,33 @@ import ( "github.com/rqlite/rqlite/command" ) -type Execer interface { - Execute(er *command.ExecuteRequest) ([]*command.ExecuteResult, error) -} - type Queue struct { maxSize int batchSize int timeout time.Duration - store Execer - c chan *command.Statement + batchCh chan *command.Statement + sendCh chan []*command.Statement + C <-chan []*command.Statement done chan struct{} closed chan struct{} flush chan struct{} } -func New(maxSize, batchSize int, t time.Duration, e Execer) *Queue { +func New(maxSize, batchSize int, t time.Duration) *Queue { q := &Queue{ maxSize: maxSize, batchSize: batchSize, timeout: t, - store: e, - c: make(chan *command.Statement, maxSize), + batchCh: make(chan *command.Statement, maxSize), + sendCh: make(chan []*command.Statement, maxSize), done: make(chan struct{}), closed: make(chan struct{}), flush: make(chan struct{}), } + q.C = q.sendCh go q.run() return q } @@ -45,7 +43,7 @@ func (q *Queue) Write(stmt *command.Statement) error { if stmt == nil { return nil } - q.c <- stmt + q.batchCh <- stmt return nil } @@ -61,7 +59,7 @@ func (q *Queue) Close() error { } func (q *Queue) Depth() int { - return len(q.c) + return len(q.batchCh) } func (q *Queue) run() { @@ -71,19 +69,22 @@ func (q *Queue) run() { timer.Stop() writeFn := func(stmts []*command.Statement) { - q.exec(stmts) + newStmts := make([]*command.Statement, len(stmts)) + copy(newStmts, stmts) + q.sendCh <- newStmts + stmts = nil timer.Stop() } for { select { - case s := <-q.c: + case s := <-q.batchCh: stmts = append(stmts, s) if len(stmts) == 1 { timer.Reset(q.timeout) } - if len(stmts) == q.batchSize { + if len(stmts) >= q.batchSize { writeFn(stmts) } case <-timer.C: @@ -96,20 +97,3 @@ func (q *Queue) run() { } } } - -func (q *Queue) exec(stmts []*command.Statement) error { - if stmts == nil || len(stmts) == 0 { - return nil - } - - er := &command.ExecuteRequest{ - Request: &command.Request{ - Statements: stmts, - }, - } - - // Doesn't handle leader-redirect, transparent forwarding, etc. - // Would need a "wrapped" store which handles it. - _, err := q.store.Execute(er) - return err -} diff --git a/queue/queue_test.go b/queue/queue_test.go index 7814e02a..4f8f23bb 100644 --- a/queue/queue_test.go +++ b/queue/queue_test.go @@ -1,7 +1,6 @@ package queue import ( - "sync" "testing" "time" @@ -13,7 +12,7 @@ var testStmt = &command.Statement{ } func Test_NewQueue(t *testing.T) { - q := New(1, 1, 100*time.Millisecond, nil) + q := New(1, 1, 100*time.Millisecond) if q == nil { t.Fatalf("failed to create new Queue") } @@ -21,8 +20,7 @@ func Test_NewQueue(t *testing.T) { } func Test_NewQueueWriteNil(t *testing.T) { - m := &MockExecer{} - q := New(1, 1, 60*time.Second, m) + q := New(1, 1, 60*time.Second) defer q.Close() if err := q.Write(nil); err != nil { @@ -30,93 +28,79 @@ func Test_NewQueueWriteNil(t *testing.T) { } } -func Test_NewQueueWriteBatchSize(t *testing.T) { - m := &MockExecer{} - q := New(1024, 1, 60*time.Second, m) +func Test_NewQueueWriteBatchSizeSingle(t *testing.T) { + q := New(1024, 1, 60*time.Second) defer q.Close() - var wg sync.WaitGroup - var numExecs int - wg.Add(1) - m.execFn = func(er *command.ExecuteRequest) ([]*command.ExecuteResult, error) { - wg.Done() - numExecs++ - return nil, nil - } - if err := q.Write(testStmt); err != nil { t.Fatalf("failed to write: %s", err.Error()) } - wg.Wait() - if exp, got := 1, numExecs; exp != got { - t.Fatalf("exec not called correct number of times, exp %d got %d", exp, got) + select { + case stmts := <-q.C: + if len(stmts) != 1 { + t.Fatalf("received wrong length slice") + } + if stmts[0].Sql != "SELECT * FROM foo" { + t.Fatalf("received wrong SQL") + } + case <-time.NewTimer(5 * time.Second).C: + t.Fatalf("timed out waiting for statement") } } -func Test_NewQueueWriteFlush(t *testing.T) { - m := &MockExecer{} - q := New(1024, 10, 60*time.Second, m) +func Test_NewQueueWriteBatchSizeMulti(t *testing.T) { + q := New(1024, 5, 60*time.Second) defer q.Close() - var wg sync.WaitGroup - var numExecs int - wg.Add(1) - m.execFn = func(er *command.ExecuteRequest) ([]*command.ExecuteResult, error) { - wg.Done() - numExecs++ - return nil, nil + // Write a batch size and wait for it. + for i := 0; i < 5; i++ { + if err := q.Write(testStmt); err != nil { + t.Fatalf("failed to write: %s", err.Error()) + } } - - if err := q.Write(testStmt); err != nil { - t.Fatalf("failed to write: %s", err.Error()) + select { + case stmts := <-q.C: + if len(stmts) != 5 { + t.Fatalf("received wrong length slice") + } + case <-time.NewTimer(5 * time.Second).C: + t.Fatalf("timed out waiting for first statements") } - time.Sleep(1 * time.Second) - - if err := q.Flush(); err != nil { - t.Fatalf("failed to flush: %s", err.Error()) + // Write one more than a batch size, should still get a batch. + for i := 0; i < 6; i++ { + if err := q.Write(testStmt); err != nil { + t.Fatalf("failed to write: %s", err.Error()) + } } - wg.Wait() - - if exp, got := 1, numExecs; exp != got { - t.Fatalf("exec not called correct number of times, exp %d got %d", exp, got) + select { + case stmts := <-q.C: + if len(stmts) < 5 { + t.Fatalf("received too-short slice") + } + case <-time.NewTimer(5 * time.Second).C: + t.Fatalf("timed out waiting for second statements") } } func Test_NewQueueWriteTimeout(t *testing.T) { - m := &MockExecer{} - q := New(1024, 10, 1*time.Second, m) + q := New(1024, 10, 1*time.Second) defer q.Close() - var wg sync.WaitGroup - var numExecs int - wg.Add(1) - m.execFn = func(er *command.ExecuteRequest) ([]*command.ExecuteResult, error) { - wg.Done() - numExecs++ - return nil, nil - } - if err := q.Write(testStmt); err != nil { t.Fatalf("failed to write: %s", err.Error()) } - time.Sleep(time.Second) - - wg.Wait() - if exp, got := 1, numExecs; exp != got { - t.Fatalf("exec not called correct number of times, exp %d got %d", exp, got) - } -} - -type MockExecer struct { - execFn func(er *command.ExecuteRequest) ([]*command.ExecuteResult, error) -} - -func (m *MockExecer) Execute(er *command.ExecuteRequest) ([]*command.ExecuteResult, error) { - if m.execFn != nil { - return m.execFn(er) + select { + case stmts := <-q.C: + if len(stmts) != 1 { + t.Fatalf("received wrong length slice") + } + if stmts[0].Sql != "SELECT * FROM foo" { + t.Fatalf("received wrong SQL") + } + case <-time.NewTimer(5 * time.Second).C: + t.Fatalf("timed out waiting for statement") } - return nil, nil }