1
0
Fork 0

Queue unit tests pass with sequence numbers

master
Philip O'Toole 2 years ago
parent c268fa851c
commit 487a459441

@ -1,25 +1,79 @@
package queue
import (
"sync"
"time"
"github.com/rqlite/rqlite/command"
)
// FlushChannel is the type passed to the Queue, if caller wants
// to know when a specific set of statements has been processed.
type FlushChannel chan bool
// Request represents a batch of statements to be processed.
type Request struct {
SequenceNumber int64
Statements []*command.Statement
flushChans []FlushChannel
}
// Close closes a request, closing any associated flush channels.
func (r *Request) Close() {
for _, c := range r.flushChans {
close(c)
}
}
type queuedStatements struct {
SequenceNumber int64
Statements []*command.Statement
flushChan FlushChannel
}
func mergeQueued(qs []*queuedStatements) *Request {
if qs == nil {
return nil
}
var o *Request
for i := range qs {
if o == nil {
o = &Request{
SequenceNumber: qs[i].SequenceNumber,
Statements: make([]*command.Statement, 0),
flushChans: make([]FlushChannel, 0),
}
} else {
if o.SequenceNumber < qs[i].SequenceNumber {
o.SequenceNumber = qs[i].SequenceNumber
}
}
o.Statements = append(o.Statements, qs[i].Statements...)
if qs[i].flushChan != nil {
o.flushChans = append(o.flushChans, qs[i].flushChan)
}
}
return o
}
// Queue is a batching queue with a timeout.
type Queue struct {
maxSize int
batchSize int
timeout time.Duration
batchCh chan *command.Statement
sendCh chan []*command.Statement
C <-chan []*command.Statement
batchCh chan *queuedStatements
sendCh chan *Request
C <-chan *Request
done chan struct{}
closed chan struct{}
flush chan struct{}
seqMu sync.Mutex
seqNum int64
// Whitebox unit-testing
numTimeouts int
}
@ -30,11 +84,12 @@ func New(maxSize, batchSize int, t time.Duration) *Queue {
maxSize: maxSize,
batchSize: batchSize,
timeout: t,
batchCh: make(chan *command.Statement, maxSize),
sendCh: make(chan []*command.Statement, maxSize),
batchCh: make(chan *queuedStatements, maxSize),
sendCh: make(chan *Request, 1),
done: make(chan struct{}),
closed: make(chan struct{}),
flush: make(chan struct{}),
seqNum: time.Now().UnixNano(),
}
q.C = q.sendCh
@ -42,13 +97,24 @@ func New(maxSize, batchSize int, t time.Duration) *Queue {
return q
}
// Write queues a request.
func (q *Queue) Write(stmt *command.Statement) error {
if stmt == nil {
return nil
// Write queues a request, and returns a monotonically incrementing
// sequence number associated with the slice of statements. If one
// slice has a larger sequence number than a number, the former slice
// will always be commited to Raft before the latter slice.
//
// c is an optional channel. If non-nil, it will be closed when the Request
// containing these statements is closed.
func (q *Queue) Write(stmts []*command.Statement, c FlushChannel) (int64, error) {
q.seqMu.Lock()
defer q.seqMu.Unlock()
q.seqNum++
q.batchCh <- &queuedStatements{
SequenceNumber: q.seqNum,
Statements: stmts,
flushChan: c,
}
q.batchCh <- stmt
return nil
return q.seqNum, nil
}
// Flush flushes the queue
@ -68,7 +134,7 @@ func (q *Queue) Close() error {
return nil
}
// Depth returns the number of queue requests
// Depth returns the number of queued requests
func (q *Queue) Depth() int {
return len(q.batchCh)
}
@ -84,27 +150,25 @@ func (q *Queue) Stats() (map[string]interface{}, error) {
func (q *Queue) run() {
defer close(q.closed)
var stmts []*command.Statement
queuedStmts := make([]*queuedStatements, 0)
timer := time.NewTimer(q.timeout)
timer.Stop()
writeFn := func() {
newStmts := make([]*command.Statement, len(stmts))
copy(newStmts, stmts)
q.sendCh <- newStmts
stmts = nil
q.sendCh <- mergeQueued(queuedStmts)
queuedStmts = nil
timer.Stop()
}
for {
select {
case s := <-q.batchCh:
stmts = append(stmts, s)
if len(stmts) == 1 {
queuedStmts = append(queuedStmts, s)
if len(queuedStmts) == 1 {
timer.Reset(q.timeout)
}
if len(stmts) >= q.batchSize {
if len(queuedStmts) >= q.batchSize {
writeFn()
}
case <-timer.C:

@ -1,14 +1,101 @@
package queue
import (
"reflect"
"testing"
"time"
"github.com/rqlite/rqlite/command"
)
var testStmt = &command.Statement{
Sql: "SELECT * FROM foo",
var (
testStmtFoo = &command.Statement{Sql: "SELECT * FROM foo"}
testStmtBar = &command.Statement{Sql: "SELECT * FROM bar"}
testStmtQux = &command.Statement{Sql: "SELECT * FROM qux"}
testStmtsFoo = []*command.Statement{testStmtFoo}
testStmtsBar = []*command.Statement{testStmtBar}
testStmtsFooBar = []*command.Statement{testStmtFoo, testStmtBar}
testStmtsFooBarFoo = []*command.Statement{testStmtFoo, testStmtBar, testStmtFoo}
flushChan1 = make(FlushChannel)
flushChan2 = make(FlushChannel)
)
func Test_MergeQueuedStatements(t *testing.T) {
if mergeQueued(nil) != nil {
t.Fatalf("merging of nil failed")
}
tests := []struct {
qs []*queuedStatements
exp *Request
}{
{
qs: []*queuedStatements{
{1, testStmtsFoo, nil},
},
exp: &Request{1, testStmtsFoo, nil},
},
{
qs: []*queuedStatements{
{1, testStmtsFoo, nil},
{2, testStmtsBar, nil},
},
exp: &Request{2, testStmtsFooBar, nil},
},
{
qs: []*queuedStatements{
{1, testStmtsFooBar, nil},
{2, testStmtsFoo, nil},
},
exp: &Request{2, testStmtsFooBarFoo, nil},
},
{
qs: []*queuedStatements{
{1, testStmtsFooBar, nil},
{2, testStmtsFoo, nil},
},
exp: &Request{2, testStmtsFooBarFoo, nil},
},
{
qs: []*queuedStatements{
{1, testStmtsFooBar, flushChan1},
{2, testStmtsFoo, flushChan2},
},
exp: &Request{2, testStmtsFooBarFoo, []FlushChannel{flushChan1, flushChan2}},
},
{
qs: []*queuedStatements{
{1, testStmtsFooBar, nil},
{2, testStmtsFoo, flushChan2},
},
exp: &Request{2, testStmtsFooBarFoo, []FlushChannel{flushChan2}},
},
{
qs: []*queuedStatements{
{2, testStmtsFooBar, nil},
{1, testStmtsFoo, flushChan2},
},
exp: &Request{2, testStmtsFooBarFoo, []FlushChannel{flushChan2}},
},
}
for i, tt := range tests {
r := mergeQueued(tt.qs)
if got, exp := r.SequenceNumber, tt.exp.SequenceNumber; got != exp {
t.Fatalf("incorrect sequence number for test %d, exp %d, got %d", i, exp, got)
}
if !reflect.DeepEqual(r.Statements, tt.exp.Statements) {
t.Fatalf("statements don't match for test %d", i)
}
if len(r.flushChans) != len(tt.exp.flushChans) {
t.Fatalf("incorrect number of flush channels for test %d", i)
for i := range r.flushChans {
if r.flushChans[i] != tt.exp.flushChans[i] {
t.Fatalf("wrong channel for test %d", i)
}
}
}
}
}
func Test_NewQueue(t *testing.T) {
@ -23,7 +110,7 @@ func Test_NewQueueWriteNil(t *testing.T) {
q := New(1, 1, 60*time.Second)
defer q.Close()
if err := q.Write(nil); err != nil {
if _, err := q.Write(nil, nil); err != nil {
t.Fatalf("failing to write nil: %s", err.Error())
}
}
@ -32,16 +119,16 @@ func Test_NewQueueWriteBatchSizeSingle(t *testing.T) {
q := New(1024, 1, 60*time.Second)
defer q.Close()
if err := q.Write(testStmt); err != nil {
if _, err := q.Write(testStmtsFoo, nil); err != nil {
t.Fatalf("failed to write: %s", err.Error())
}
select {
case stmts := <-q.C:
if len(stmts) != 1 {
t.Fatalf("received wrong length slice")
case req := <-q.C:
if exp, got := 1, len(req.Statements); exp != got {
t.Fatalf("received wrong length slice, exp %d, got %d", exp, got)
}
if stmts[0].Sql != "SELECT * FROM foo" {
if req.Statements[0].Sql != "SELECT * FROM foo" {
t.Fatalf("received wrong SQL")
}
case <-time.NewTimer(5 * time.Second).C:
@ -55,13 +142,13 @@ func Test_NewQueueWriteBatchSizeMulti(t *testing.T) {
// Write a batch size and wait for it.
for i := 0; i < 5; i++ {
if err := q.Write(testStmt); err != nil {
if _, err := q.Write(testStmtsFoo, nil); err != nil {
t.Fatalf("failed to write: %s", err.Error())
}
}
select {
case stmts := <-q.C:
if len(stmts) != 5 {
case req := <-q.C:
if len(req.Statements) != 5 {
t.Fatalf("received wrong length slice")
}
if q.numTimeouts != 0 {
@ -73,13 +160,13 @@ func Test_NewQueueWriteBatchSizeMulti(t *testing.T) {
// Write one more than a batch size, should still get a batch.
for i := 0; i < 6; i++ {
if err := q.Write(testStmt); err != nil {
if _, err := q.Write(testStmtsBar, nil); err != nil {
t.Fatalf("failed to write: %s", err.Error())
}
}
select {
case stmts := <-q.C:
if len(stmts) < 5 {
case req := <-q.C:
if len(req.Statements) < 5 {
t.Fatalf("received too-short slice")
}
if q.numTimeouts != 0 {
@ -94,16 +181,16 @@ func Test_NewQueueWriteTimeout(t *testing.T) {
q := New(1024, 10, 1*time.Second)
defer q.Close()
if err := q.Write(testStmt); err != nil {
if _, err := q.Write(testStmtsFoo, nil); err != nil {
t.Fatalf("failed to write: %s", err.Error())
}
select {
case stmts := <-q.C:
if len(stmts) != 1 {
case req := <-q.C:
if len(req.Statements) != 1 {
t.Fatalf("received wrong length slice")
}
if stmts[0].Sql != "SELECT * FROM foo" {
if req.Statements[0].Sql != "SELECT * FROM foo" {
t.Fatalf("received wrong SQL")
}
if q.numTimeouts != 1 {
@ -120,15 +207,15 @@ func Test_NewQueueWriteTimeoutMulti(t *testing.T) {
q := New(1024, 10, 1*time.Second)
defer q.Close()
if err := q.Write(testStmt); err != nil {
if _, err := q.Write(testStmtsFoo, nil); err != nil {
t.Fatalf("failed to write: %s", err.Error())
}
select {
case stmts := <-q.C:
if len(stmts) != 1 {
case req := <-q.C:
if len(req.Statements) != 1 {
t.Fatalf("received wrong length slice")
}
if stmts[0].Sql != "SELECT * FROM foo" {
if req.Statements[0].Sql != "SELECT * FROM foo" {
t.Fatalf("received wrong SQL")
}
if q.numTimeouts != 1 {
@ -138,15 +225,15 @@ func Test_NewQueueWriteTimeoutMulti(t *testing.T) {
t.Fatalf("timed out waiting for first statement")
}
if err := q.Write(testStmt); err != nil {
if _, err := q.Write(testStmtsFoo, nil); err != nil {
t.Fatalf("failed to write: %s", err.Error())
}
select {
case stmts := <-q.C:
if len(stmts) != 1 {
case req := <-q.C:
if len(req.Statements) != 1 {
t.Fatalf("received wrong length slice")
}
if stmts[0].Sql != "SELECT * FROM foo" {
if req.Statements[0].Sql != "SELECT * FROM foo" {
t.Fatalf("received wrong SQL")
}
if q.numTimeouts != 2 {
@ -163,16 +250,16 @@ func Test_NewQueueWriteTimeoutBatch(t *testing.T) {
q := New(1024, 2, 1*time.Second)
defer q.Close()
if err := q.Write(testStmt); err != nil {
if _, err := q.Write(testStmtsFoo, nil); err != nil {
t.Fatalf("failed to write: %s", err.Error())
}
select {
case stmts := <-q.C:
if len(stmts) != 1 {
case req := <-q.C:
if len(req.Statements) != 1 {
t.Fatalf("received wrong length slice")
}
if stmts[0].Sql != "SELECT * FROM foo" {
if req.Statements[0].Sql != "SELECT * FROM foo" {
t.Fatalf("received wrong SQL")
}
if q.numTimeouts != 1 {
@ -182,19 +269,19 @@ func Test_NewQueueWriteTimeoutBatch(t *testing.T) {
t.Fatalf("timed out waiting for statement")
}
if err := q.Write(testStmt); err != nil {
if _, err := q.Write(testStmtsFoo, nil); err != nil {
t.Fatalf("failed to write: %s", err.Error())
}
if err := q.Write(testStmt); err != nil {
if _, err := q.Write(testStmtsFoo, nil); err != nil {
t.Fatalf("failed to write: %s", err.Error())
}
select {
case stmts := <-q.C:
case req := <-q.C:
// Should happen before the timeout expires.
if len(stmts) != 2 {
if len(req.Statements) != 2 {
t.Fatalf("received wrong length slice")
}
if stmts[0].Sql != "SELECT * FROM foo" {
if req.Statements[0].Sql != "SELECT * FROM foo" {
t.Fatalf("received wrong SQL")
}
if q.numTimeouts != 1 {

Loading…
Cancel
Save