diff --git a/queue/queue.go b/queue/queue.go index 8f7f24a4..29053cf6 100644 --- a/queue/queue.go +++ b/queue/queue.go @@ -1,6 +1,7 @@ package queue import ( + "errors" "expvar" "sync" "time" @@ -58,16 +59,16 @@ type queuedStatements struct { func mergeQueued(qs []*queuedStatements) *Request { var o *Request + if len(qs) > 0 { + o = &Request{ + SequenceNumber: qs[0].SequenceNumber, + flushChans: make([]FlushChannel, 0), + } + } + for i := range qs { - if o == nil { - o = &Request{ - SequenceNumber: qs[i].SequenceNumber, - flushChans: make([]FlushChannel, 0), - } - } else { - if o.SequenceNumber < qs[i].SequenceNumber { - o.SequenceNumber = qs[i].SequenceNumber - } + if o.SequenceNumber < qs[i].SequenceNumber { + o.SequenceNumber = qs[i].SequenceNumber } o.Statements = append(o.Statements, qs[i].Statements...) if qs[i].flushChan != nil { @@ -126,6 +127,12 @@ func New(maxSize, batchSize int, t time.Duration) *Queue { // c is an optional channel. If non-nil, it will be closed when the Request // containing these statements is closed. func (q *Queue) Write(stmts []*command.Statement, c FlushChannel) (int64, error) { + select { + case <-q.done: + return 0, errors.New("queue is closed") + default: + } + q.seqMu.Lock() defer q.seqMu.Unlock() q.seqNum++ diff --git a/queue/queue_test.go b/queue/queue_test.go index 980efe80..f2b63bd5 100644 --- a/queue/queue_test.go +++ b/queue/queue_test.go @@ -120,6 +120,17 @@ func Test_NewQueue(t *testing.T) { defer q.Close() } +func Test_NewQueueClosedWrite(t *testing.T) { + q := New(1, 1, 100*time.Millisecond) + if q == nil { + t.Fatalf("failed to create new Queue") + } + q.Close() + if _, err := q.Write(testStmtsFoo, nil); err == nil { + t.Fatalf("failed to detect closed queue") + } +} + func Test_NewQueueWriteNil(t *testing.T) { q := New(1, 1, 60*time.Second) defer q.Close()