diff --git a/queue/queue.go b/queue/queue.go index d57a18a8..07dda485 100644 --- a/queue/queue.go +++ b/queue/queue.go @@ -1,25 +1,79 @@ package queue import ( + "sync" "time" "github.com/rqlite/rqlite/command" ) +// FlushChannel is the type passed to the Queue, if caller wants +// to know when a specific set of statements has been processed. +type FlushChannel chan bool + +// Request represents a batch of statements to be processed. +type Request struct { + SequenceNumber int64 + Statements []*command.Statement + flushChans []FlushChannel +} + +// Close closes a request, closing any associated flush channels. +func (r *Request) Close() { + for _, c := range r.flushChans { + close(c) + } +} + +type queuedStatements struct { + SequenceNumber int64 + Statements []*command.Statement + flushChan FlushChannel +} + +func mergeQueued(qs []*queuedStatements) *Request { + if qs == nil { + return nil + } + var o *Request + for i := range qs { + if o == nil { + o = &Request{ + SequenceNumber: qs[i].SequenceNumber, + Statements: make([]*command.Statement, 0), + flushChans: make([]FlushChannel, 0), + } + } else { + if o.SequenceNumber < qs[i].SequenceNumber { + o.SequenceNumber = qs[i].SequenceNumber + } + } + o.Statements = append(o.Statements, qs[i].Statements...) + if qs[i].flushChan != nil { + o.flushChans = append(o.flushChans, qs[i].flushChan) + } + } + return o +} + // Queue is a batching queue with a timeout. type Queue struct { maxSize int batchSize int timeout time.Duration - batchCh chan *command.Statement - sendCh chan []*command.Statement - C <-chan []*command.Statement + batchCh chan *queuedStatements + + sendCh chan *Request + C <-chan *Request done chan struct{} closed chan struct{} flush chan struct{} + seqMu sync.Mutex + seqNum int64 + // Whitebox unit-testing numTimeouts int } @@ -30,11 +84,12 @@ func New(maxSize, batchSize int, t time.Duration) *Queue { maxSize: maxSize, batchSize: batchSize, timeout: t, - batchCh: make(chan *command.Statement, maxSize), - sendCh: make(chan []*command.Statement, maxSize), + batchCh: make(chan *queuedStatements, maxSize), + sendCh: make(chan *Request, 1), done: make(chan struct{}), closed: make(chan struct{}), flush: make(chan struct{}), + seqNum: time.Now().UnixNano(), } q.C = q.sendCh @@ -42,13 +97,24 @@ func New(maxSize, batchSize int, t time.Duration) *Queue { return q } -// Write queues a request. -func (q *Queue) Write(stmt *command.Statement) error { - if stmt == nil { - return nil +// Write queues a request, and returns a monotonically incrementing +// sequence number associated with the slice of statements. If one +// slice has a larger sequence number than a number, the former slice +// will always be commited to Raft before the latter slice. +// +// c is an optional channel. If non-nil, it will be closed when the Request +// containing these statements is closed. +func (q *Queue) Write(stmts []*command.Statement, c FlushChannel) (int64, error) { + q.seqMu.Lock() + defer q.seqMu.Unlock() + q.seqNum++ + + q.batchCh <- &queuedStatements{ + SequenceNumber: q.seqNum, + Statements: stmts, + flushChan: c, } - q.batchCh <- stmt - return nil + return q.seqNum, nil } // Flush flushes the queue @@ -68,7 +134,7 @@ func (q *Queue) Close() error { return nil } -// Depth returns the number of queue requests +// Depth returns the number of queued requests func (q *Queue) Depth() int { return len(q.batchCh) } @@ -84,27 +150,25 @@ func (q *Queue) Stats() (map[string]interface{}, error) { func (q *Queue) run() { defer close(q.closed) - var stmts []*command.Statement + + queuedStmts := make([]*queuedStatements, 0) timer := time.NewTimer(q.timeout) timer.Stop() writeFn := func() { - newStmts := make([]*command.Statement, len(stmts)) - copy(newStmts, stmts) - q.sendCh <- newStmts - - stmts = nil + q.sendCh <- mergeQueued(queuedStmts) + queuedStmts = nil timer.Stop() } for { select { case s := <-q.batchCh: - stmts = append(stmts, s) - if len(stmts) == 1 { + queuedStmts = append(queuedStmts, s) + if len(queuedStmts) == 1 { timer.Reset(q.timeout) } - if len(stmts) >= q.batchSize { + if len(queuedStmts) >= q.batchSize { writeFn() } case <-timer.C: diff --git a/queue/queue_test.go b/queue/queue_test.go index 0a8770b2..a1afe086 100644 --- a/queue/queue_test.go +++ b/queue/queue_test.go @@ -1,14 +1,101 @@ package queue import ( + "reflect" "testing" "time" "github.com/rqlite/rqlite/command" ) -var testStmt = &command.Statement{ - Sql: "SELECT * FROM foo", +var ( + testStmtFoo = &command.Statement{Sql: "SELECT * FROM foo"} + testStmtBar = &command.Statement{Sql: "SELECT * FROM bar"} + testStmtQux = &command.Statement{Sql: "SELECT * FROM qux"} + testStmtsFoo = []*command.Statement{testStmtFoo} + testStmtsBar = []*command.Statement{testStmtBar} + testStmtsFooBar = []*command.Statement{testStmtFoo, testStmtBar} + testStmtsFooBarFoo = []*command.Statement{testStmtFoo, testStmtBar, testStmtFoo} + flushChan1 = make(FlushChannel) + flushChan2 = make(FlushChannel) +) + +func Test_MergeQueuedStatements(t *testing.T) { + if mergeQueued(nil) != nil { + t.Fatalf("merging of nil failed") + } + + tests := []struct { + qs []*queuedStatements + exp *Request + }{ + { + qs: []*queuedStatements{ + {1, testStmtsFoo, nil}, + }, + exp: &Request{1, testStmtsFoo, nil}, + }, + { + qs: []*queuedStatements{ + {1, testStmtsFoo, nil}, + {2, testStmtsBar, nil}, + }, + exp: &Request{2, testStmtsFooBar, nil}, + }, + { + qs: []*queuedStatements{ + {1, testStmtsFooBar, nil}, + {2, testStmtsFoo, nil}, + }, + exp: &Request{2, testStmtsFooBarFoo, nil}, + }, + { + qs: []*queuedStatements{ + {1, testStmtsFooBar, nil}, + {2, testStmtsFoo, nil}, + }, + exp: &Request{2, testStmtsFooBarFoo, nil}, + }, + { + qs: []*queuedStatements{ + {1, testStmtsFooBar, flushChan1}, + {2, testStmtsFoo, flushChan2}, + }, + exp: &Request{2, testStmtsFooBarFoo, []FlushChannel{flushChan1, flushChan2}}, + }, + { + qs: []*queuedStatements{ + {1, testStmtsFooBar, nil}, + {2, testStmtsFoo, flushChan2}, + }, + exp: &Request{2, testStmtsFooBarFoo, []FlushChannel{flushChan2}}, + }, + { + qs: []*queuedStatements{ + {2, testStmtsFooBar, nil}, + {1, testStmtsFoo, flushChan2}, + }, + exp: &Request{2, testStmtsFooBarFoo, []FlushChannel{flushChan2}}, + }, + } + + for i, tt := range tests { + r := mergeQueued(tt.qs) + if got, exp := r.SequenceNumber, tt.exp.SequenceNumber; got != exp { + t.Fatalf("incorrect sequence number for test %d, exp %d, got %d", i, exp, got) + } + if !reflect.DeepEqual(r.Statements, tt.exp.Statements) { + t.Fatalf("statements don't match for test %d", i) + } + if len(r.flushChans) != len(tt.exp.flushChans) { + t.Fatalf("incorrect number of flush channels for test %d", i) + for i := range r.flushChans { + if r.flushChans[i] != tt.exp.flushChans[i] { + t.Fatalf("wrong channel for test %d", i) + } + } + } + } } func Test_NewQueue(t *testing.T) { @@ -23,7 +110,7 @@ func Test_NewQueueWriteNil(t *testing.T) { q := New(1, 1, 60*time.Second) defer q.Close() - if err := q.Write(nil); err != nil { + if _, err := q.Write(nil, nil); err != nil { t.Fatalf("failing to write nil: %s", err.Error()) } } @@ -32,16 +119,16 @@ func Test_NewQueueWriteBatchSizeSingle(t *testing.T) { q := New(1024, 1, 60*time.Second) defer q.Close() - if err := q.Write(testStmt); err != nil { + if _, err := q.Write(testStmtsFoo, nil); err != nil { t.Fatalf("failed to write: %s", err.Error()) } select { - case stmts := <-q.C: - if len(stmts) != 1 { - t.Fatalf("received wrong length slice") + case req := <-q.C: + if exp, got := 1, len(req.Statements); exp != got { + t.Fatalf("received wrong length slice, exp %d, got %d", exp, got) } - if stmts[0].Sql != "SELECT * FROM foo" { + if req.Statements[0].Sql != "SELECT * FROM foo" { t.Fatalf("received wrong SQL") } case <-time.NewTimer(5 * time.Second).C: @@ -55,13 +142,13 @@ func Test_NewQueueWriteBatchSizeMulti(t *testing.T) { // Write a batch size and wait for it. for i := 0; i < 5; i++ { - if err := q.Write(testStmt); err != nil { + if _, err := q.Write(testStmtsFoo, nil); err != nil { t.Fatalf("failed to write: %s", err.Error()) } } select { - case stmts := <-q.C: - if len(stmts) != 5 { + case req := <-q.C: + if len(req.Statements) != 5 { t.Fatalf("received wrong length slice") } if q.numTimeouts != 0 { @@ -73,13 +160,13 @@ func Test_NewQueueWriteBatchSizeMulti(t *testing.T) { // Write one more than a batch size, should still get a batch. for i := 0; i < 6; i++ { - if err := q.Write(testStmt); err != nil { + if _, err := q.Write(testStmtsBar, nil); err != nil { t.Fatalf("failed to write: %s", err.Error()) } } select { - case stmts := <-q.C: - if len(stmts) < 5 { + case req := <-q.C: + if len(req.Statements) < 5 { t.Fatalf("received too-short slice") } if q.numTimeouts != 0 { @@ -94,16 +181,16 @@ func Test_NewQueueWriteTimeout(t *testing.T) { q := New(1024, 10, 1*time.Second) defer q.Close() - if err := q.Write(testStmt); err != nil { + if _, err := q.Write(testStmtsFoo, nil); err != nil { t.Fatalf("failed to write: %s", err.Error()) } select { - case stmts := <-q.C: - if len(stmts) != 1 { + case req := <-q.C: + if len(req.Statements) != 1 { t.Fatalf("received wrong length slice") } - if stmts[0].Sql != "SELECT * FROM foo" { + if req.Statements[0].Sql != "SELECT * FROM foo" { t.Fatalf("received wrong SQL") } if q.numTimeouts != 1 { @@ -120,15 +207,15 @@ func Test_NewQueueWriteTimeoutMulti(t *testing.T) { q := New(1024, 10, 1*time.Second) defer q.Close() - if err := q.Write(testStmt); err != nil { + if _, err := q.Write(testStmtsFoo, nil); err != nil { t.Fatalf("failed to write: %s", err.Error()) } select { - case stmts := <-q.C: - if len(stmts) != 1 { + case req := <-q.C: + if len(req.Statements) != 1 { t.Fatalf("received wrong length slice") } - if stmts[0].Sql != "SELECT * FROM foo" { + if req.Statements[0].Sql != "SELECT * FROM foo" { t.Fatalf("received wrong SQL") } if q.numTimeouts != 1 { @@ -138,15 +225,15 @@ func Test_NewQueueWriteTimeoutMulti(t *testing.T) { t.Fatalf("timed out waiting for first statement") } - if err := q.Write(testStmt); err != nil { + if _, err := q.Write(testStmtsFoo, nil); err != nil { t.Fatalf("failed to write: %s", err.Error()) } select { - case stmts := <-q.C: - if len(stmts) != 1 { + case req := <-q.C: + if len(req.Statements) != 1 { t.Fatalf("received wrong length slice") } - if stmts[0].Sql != "SELECT * FROM foo" { + if req.Statements[0].Sql != "SELECT * FROM foo" { t.Fatalf("received wrong SQL") } if q.numTimeouts != 2 { @@ -163,16 +250,16 @@ func Test_NewQueueWriteTimeoutBatch(t *testing.T) { q := New(1024, 2, 1*time.Second) defer q.Close() - if err := q.Write(testStmt); err != nil { + if _, err := q.Write(testStmtsFoo, nil); err != nil { t.Fatalf("failed to write: %s", err.Error()) } select { - case stmts := <-q.C: - if len(stmts) != 1 { + case req := <-q.C: + if len(req.Statements) != 1 { t.Fatalf("received wrong length slice") } - if stmts[0].Sql != "SELECT * FROM foo" { + if req.Statements[0].Sql != "SELECT * FROM foo" { t.Fatalf("received wrong SQL") } if q.numTimeouts != 1 { @@ -182,19 +269,19 @@ func Test_NewQueueWriteTimeoutBatch(t *testing.T) { t.Fatalf("timed out waiting for statement") } - if err := q.Write(testStmt); err != nil { + if _, err := q.Write(testStmtsFoo, nil); err != nil { t.Fatalf("failed to write: %s", err.Error()) } - if err := q.Write(testStmt); err != nil { + if _, err := q.Write(testStmtsFoo, nil); err != nil { t.Fatalf("failed to write: %s", err.Error()) } select { - case stmts := <-q.C: + case req := <-q.C: // Should happen before the timeout expires. - if len(stmts) != 2 { + if len(req.Statements) != 2 { t.Fatalf("received wrong length slice") } - if stmts[0].Sql != "SELECT * FROM foo" { + if req.Statements[0].Sql != "SELECT * FROM foo" { t.Fatalf("received wrong SQL") } if q.numTimeouts != 1 {