diff --git a/CHANGELOG.md b/CHANGELOG.md index 2536331e..1b64a89e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,6 +1,7 @@ ## 7.5.0 (unreleased) ### New features - [PR #1019](https://github.com/rqlite/rqlite/pull/1019): CLI supports restoring from SQLite database files. +- [PR #1024](https://github.com/rqlite/rqlite/pull/1024): Add support for Queued Writes. Fixes [issue #1020](https://github.com/rqlite/rqlite/issues/1020). ## 7.4.0 (May 10th 2022) With this release rqlite supports restoring a node from an actual SQLite file, which is very much faster than restoring using the SQL dump representation of the same SQLite database. diff --git a/DOC/DATA_API.md b/DOC/DATA_API.md index e5233275..477ad011 100644 --- a/DOC/DATA_API.md +++ b/DOC/DATA_API.md @@ -198,7 +198,10 @@ $ curl -v -G 'localhost:4003/db/query?pretty&timings&redirect' --data-urlencode * Connection #0 to host localhost left intact ``` +## Queued Writes API +You can learn about the Queued Writes API [here](https://github.com/rqlite/rqlite/blob/master/DOC/QUEUED_WRITES.md). + ## Bulk API -You can learn about the bulk API [here](https://github.com/rqlite/rqlite/blob/master/DOC/BULK.md). +You can learn about the Bulk API [here](https://github.com/rqlite/rqlite/blob/master/DOC/BULK.md). diff --git a/DOC/PERFORMANCE.md b/DOC/PERFORMANCE.md index e4a0129e..1238ab67 100644 --- a/DOC/PERFORMANCE.md +++ b/DOC/PERFORMANCE.md @@ -23,6 +23,9 @@ The more SQLite statements you can include in a single request to a rqlite node, By using the [bulk API](https://github.com/rqlite/rqlite/blob/master/DOC/BULK.md), transactions, or both, throughput will increase significantly, often by 2 orders of magnitude. This speed-up is due to the way Raft and SQLite work. So for high throughput, execute as many operations as possible within a single transaction. 
+## Queued Writes +If you can tolerate the risk of some data loss in the event that a node crashes, you could consider using the [Queued Writes API](https://github.com/rqlite/rqlite/blob/master/DOC/QUEUED_WRITES.md). + ## Use more powerful hardware Obviously running rqlite on better disks, better networks, or both, will improve performance. diff --git a/DOC/QUEUED_WRITES.md b/DOC/QUEUED_WRITES.md new file mode 100644 index 00000000..ca7c6f53 --- /dev/null +++ b/DOC/QUEUED_WRITES.md @@ -0,0 +1,22 @@ +# Queued Writes API +> :warning: **This functionality was introduced in version 7.5. It does not exist in earlier releases.** + +rqlite exposes a special API, which will queue up write-requests and execute them in bulk. This allows clients to send multiple distinct requests to a rqlite node, and have rqlite automatically do the batching and bulk insert for the client, without the client doing any extra work. This functionality is best illustrated by an example, showing two requests being queued. +```bash +curl -XPOST 'localhost:4001/db/execute/queue/_default' -H "Content-Type: application/json" -d '[ + ["INSERT INTO foo(name) VALUES(?)", "fiona"], + ["INSERT INTO foo(name) VALUES(?)", "sinead"] +]' +curl -XPOST 'localhost:4001/db/execute/queue/_default' -H "Content-Type: application/json" -d '[ + ["INSERT INTO foo(name) VALUES(?)", "declan"] +]' +``` +rqlite will merge these requests, and execute them as though they had both been contained in a single request. For the same reason that using the [Bulk API](https://github.com/rqlite/rqlite/blob/master/DOC/BULK.md) results in much higher write performance, using the _Queued Writes_ API will also result in much higher write performance. + +The behaviour of the queue rqlite uses to batch the requests is configurable at rqlite launch time. Pass `-h` to `rqlited` to see the queue defaults, and list all configuration options. 
+ +## Caveats +Because the API returns immediately after queuing the requests **but before the data is committed to the SQLite database** there is a risk of data loss in the event the node crashes before queued data is persisted. + +Like most databases there is a trade-off to be made between write-performance and durability. In addition, when the API returns `HTTP 200 OK`, that simply acknowledges that the data has been queued correctly. It does not indicate that the SQL statements will actually be applied successfully to the database. Be sure to check the node's logs if you have any concerns about failed queued writes. + diff --git a/cmd/rqlited/flags.go b/cmd/rqlited/flags.go index 2b964845..5193287f 100644 --- a/cmd/rqlited/flags.go +++ b/cmd/rqlited/flags.go @@ -163,6 +163,16 @@ type Config struct { // a full database re-sync during recovery. RaftNoFreelistSync bool + // WriteQueueCap is the default capacity of Execute queues + WriteQueueCap int + + // WriteQueueBatchSz is the default batch size for Execute queues + WriteQueueBatchSz int + + // WriteQueueTimeout is the default time after which any data will be sent on + // Execute queues, if a batch size has not been reached. 
+ WriteQueueTimeout time.Duration + // CompressionSize sets request query size for compression attempt CompressionSize int @@ -352,6 +362,9 @@ func ParseFlags(name, desc string, build *BuildInfo) (*Config, error) { flag.BoolVar(&config.RaftShutdownOnRemove, "raft-remove-shutdown", false, "Shutdown Raft if node removed") flag.BoolVar(&config.RaftNoFreelistSync, "raft-no-freelist-sync", false, "Do not sync Raft log database freelist to disk") flag.StringVar(&config.RaftLogLevel, "raft-log-level", "INFO", "Minimum log level for Raft module") + flag.IntVar(&config.WriteQueueCap, "write-queue-capacity", 1024, "Write queue capacity") + flag.IntVar(&config.WriteQueueBatchSz, "write-queue-batch-size", 64, "Write queue batch size") + flag.DurationVar(&config.WriteQueueTimeout, "write-queue-timeout", 100*time.Millisecond, "Write queue timeout") flag.IntVar(&config.CompressionSize, "compression-size", 150, "Request query size for compression attempt") flag.IntVar(&config.CompressionBatch, "compression-batch", 5, "Request batch threshold for compression attempt") flag.StringVar(&config.CPUProfile, "cpu-profile", "", "Path to file for CPU profiling information") diff --git a/cmd/rqlited/main.go b/cmd/rqlited/main.go index ed40bb9f..ec062b38 100644 --- a/cmd/rqlited/main.go +++ b/cmd/rqlited/main.go @@ -255,6 +255,9 @@ func startHTTPService(cfg *Config, str *store.Store, cltr *cluster.Client, credS s.TLS1011 = cfg.TLS1011 s.Expvar = cfg.Expvar s.Pprof = cfg.PprofEnabled + s.DefaultQueueCap = cfg.WriteQueueCap + s.DefaultQueueBatchSz = cfg.WriteQueueBatchSz + s.DefaultQueueTimeout = cfg.WriteQueueTimeout s.BuildInfo = map[string]interface{}{ "commit": cmd.Commit, "branch": cmd.Branch, diff --git a/http/service.go b/http/service.go index 826f352b..fdbb18d3 100644 --- a/http/service.go +++ b/http/service.go @@ -25,6 +25,7 @@ import ( "github.com/rqlite/rqlite/auth" "github.com/rqlite/rqlite/command" "github.com/rqlite/rqlite/command/encoding" + "github.com/rqlite/rqlite/queue" 
"github.com/rqlite/rqlite/store" ) @@ -143,19 +144,22 @@ type Response struct { var stats *expvar.Map const ( - numLeaderNotFound = "leader_not_found" - numExecutions = "executions" - numQueries = "queries" - numRemoteExecutions = "remote_executions" - numRemoteQueries = "remote_queries" - numReadyz = "num_readyz" - numStatus = "num_status" - numBackups = "backups" - numLoad = "loads" - numJoins = "joins" - numNotifies = "notifies" - numAuthOK = "authOK" - numAuthFail = "authFail" + numLeaderNotFound = "leader_not_found" + numExecutions = "executions" + numQueuedExecutions = "queued_executions" + numQueuedExecutionsOK = "queued_executions_ok" + numQueuedExecutionsFailed = "queued_executions_failed" + numQueries = "queries" + numRemoteExecutions = "remote_executions" + numRemoteQueries = "remote_queries" + numReadyz = "num_readyz" + numStatus = "num_status" + numBackups = "backups" + numLoad = "loads" + numJoins = "joins" + numNotifies = "notifies" + numAuthOK = "authOK" + numAuthFail = "authFail" // Default timeout for cluster communications. defaultTimeout = 30 * time.Second @@ -192,6 +196,9 @@ func init() { stats = expvar.NewMap("http") stats.Add(numLeaderNotFound, 0) stats.Add(numExecutions, 0) + stats.Add(numQueuedExecutions, 0) + stats.Add(numQueuedExecutionsOK, 0) + stats.Add(numQueuedExecutionsFailed, 0) stats.Add(numQueries, 0) stats.Add(numRemoteExecutions, 0) stats.Add(numRemoteQueries, 0) @@ -221,11 +228,15 @@ func NewResponse() *Response { // Service provides HTTP service. type Service struct { - addr string // Bind address of the HTTP service. - ln net.Listener // Service listener + closeCh chan struct{} + addr string // Bind address of the HTTP service. + ln net.Listener // Service listener store Store // The Raft-backed database store. + queueDone chan struct{} + stmtQueue *queue.Queue // Queue for queued executes + cluster Cluster // The Cluster service. start time.Time // Start up time. 
@@ -239,6 +250,10 @@ type Service struct { KeyFile string // Path to SSL private key. TLS1011 bool // Whether older, deprecated TLS should be supported. + DefaultQueueCap int + DefaultQueueBatchSz int + DefaultQueueTimeout time.Duration + credentialStore CredentialStore Expvar bool @@ -253,13 +268,16 @@ type Service struct { // the service performs no authentication and authorization checks. func New(addr string, store Store, cluster Cluster, credentials CredentialStore) *Service { return &Service{ - addr: addr, - store: store, - cluster: cluster, - start: time.Now(), - statuses: make(map[string]StatusReporter), - credentialStore: credentials, - logger: log.New(os.Stderr, "[http] ", log.LstdFlags), + addr: addr, + store: store, + DefaultQueueCap: 1024, + DefaultQueueBatchSz: 128, + DefaultQueueTimeout: 100 * time.Millisecond, + cluster: cluster, + start: time.Now(), + statuses: make(map[string]StatusReporter), + credentialStore: credentials, + logger: log.New(os.Stderr, "[http] ", log.LstdFlags), } } @@ -289,6 +307,14 @@ func (s *Service) Start() error { } s.ln = ln + s.closeCh = make(chan struct{}) + s.queueDone = make(chan struct{}) + + s.stmtQueue = queue.New(s.DefaultQueueCap, s.DefaultQueueBatchSz, s.DefaultQueueTimeout) + go s.runQueue() + s.logger.Printf("execute queue processing started with capacity %d, batch size %d, timeout %s", + s.DefaultQueueCap, s.DefaultQueueBatchSz, s.DefaultQueueTimeout.String()) + go func() { err := server.Serve(s.ln) if err != nil { @@ -302,6 +328,15 @@ func (s *Service) Start() error { // Close closes the service. 
func (s *Service) Close() { + s.stmtQueue.Close() + + select { + case <-s.queueDone: + default: + close(s.closeCh) + } + <-s.queueDone + s.ln.Close() return } @@ -318,6 +353,9 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) { switch { case r.URL.Path == "/" || r.URL.Path == "": http.Redirect(w, r, "/status", http.StatusFound) + case strings.HasPrefix(r.URL.Path, "/db/execute/queue/_default"): + stats.Add(numQueuedExecutions, 1) + s.handleQueuedExecute(w, r) case strings.HasPrefix(r.URL.Path, "/db/execute"): stats.Add(numExecutions, 1) s.handleExecute(w, r) @@ -682,10 +720,20 @@ func (s *Service) handleStatus(w http.ResponseWriter, r *http.Request) { oss["hostname"] = hostname } + qs, err := s.stmtQueue.Stats() + if err != nil { + http.Error(w, fmt.Sprintf("queue stats: %s", err.Error()), + http.StatusInternalServerError) + return + } + queueStats := map[string]interface{}{ + "_default": qs, + } httpStatus := map[string]interface{}{ "bind_addr": s.Addr().String(), "auth": prettyEnabled(s.credentialStore != nil), "cluster": clusterStatus, + "queue": queueStats, } nodeStatus := map[string]interface{}{ @@ -892,6 +940,61 @@ func (s *Service) handleReadyz(w http.ResponseWriter, r *http.Request) { w.Write([]byte("[+]node ok\n[+]leader does not exist")) } +// handleQueuedExecute handles queued queries that modify the database. +func (s *Service) handleQueuedExecute(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + + if !s.CheckRequestPerm(r, PermExecute) { + w.WriteHeader(http.StatusUnauthorized) + return + } + + if r.Method != "POST" { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + // Perform a leader check, unless disabled. This prevents generating queued writes on + // a node that does not appear to be connected to a cluster (even a single-node cluster). 
+ noLeader, err := noLeader(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + if !noLeader { + addr, err := s.store.LeaderAddr() + if err != nil || addr == "" { + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) + return + } + } + + resp := NewResponse() + + b, err := ioutil.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + r.Body.Close() + + stmts, err := ParseRequest(b) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + for i := range stmts { + if err := s.stmtQueue.Write(stmts[i]); err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + } + + s.writeResponse(w, r, resp) + return +} + // handleExecute handles queries that modify the database. func (s *Service) handleExecute(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "application/json; charset=utf-8") @@ -1167,6 +1270,48 @@ func (s *Service) LeaderAPIAddr() string { return apiAddr } +func (s *Service) runQueue() { + defer close(s.queueDone) + retryDelay := time.Second + + var err error + for { + select { + case <-s.closeCh: + return + case stmts := <-s.stmtQueue.C: + er := &command.ExecuteRequest{ + Request: &command.Request{ + Statements: stmts, + }, + } + for { + _, err = s.store.Execute(er) + if err != nil { + if err == store.ErrNotLeader { + addr, err := s.store.LeaderAddr() + if err != nil || addr == "" { + s.logger.Println("execute queue can't find leader") + stats.Add(numQueuedExecutionsFailed, 1) + time.Sleep(retryDelay) + continue + } + _, err = s.cluster.Execute(er, addr, defaultTimeout) + if err != nil { + s.logger.Printf("execute queue write failed: %s", err.Error()) + time.Sleep(retryDelay) + continue + } + stats.Add(numRemoteExecutions, 1) + } + } + stats.Add(numQueuedExecutionsOK, 1) + break + } + } + } +} + type checkNodesResponse struct { apiAddr 
string reachable bool diff --git a/http/service_test.go b/http/service_test.go index 9cf88896..e3b74b63 100644 --- a/http/service_test.go +++ b/http/service_test.go @@ -486,6 +486,7 @@ func Test_401Routes_NoBasicAuth(t *testing.T) { client := &http.Client{} for _, path := range []string{ + "/db/execute/queue/_default", "/db/execute", "/db/query", "/db/backup", @@ -528,6 +529,7 @@ func Test_401Routes_BasicAuthBadPassword(t *testing.T) { client := &http.Client{} for _, path := range []string{ + "/db/execute/queue/_default", "/db/execute", "/db/query", "/db/backup", @@ -575,6 +577,7 @@ func Test_401Routes_BasicAuthBadPerm(t *testing.T) { client := &http.Client{} for _, path := range []string{ + "/db/execute/queue/_default", "/db/execute", "/db/query", "/db/backup", diff --git a/queue/queue.go b/queue/queue.go new file mode 100644 index 00000000..d57a18a8 --- /dev/null +++ b/queue/queue.go @@ -0,0 +1,120 @@ +package queue + +import ( + "time" + + "github.com/rqlite/rqlite/command" +) + +// Queue is a batching queue with a timeout. +type Queue struct { + maxSize int + batchSize int + timeout time.Duration + + batchCh chan *command.Statement + sendCh chan []*command.Statement + C <-chan []*command.Statement + + done chan struct{} + closed chan struct{} + flush chan struct{} + + // Whitebox unit-testing + numTimeouts int +} + +// New returns a instance of a Queue +func New(maxSize, batchSize int, t time.Duration) *Queue { + q := &Queue{ + maxSize: maxSize, + batchSize: batchSize, + timeout: t, + batchCh: make(chan *command.Statement, maxSize), + sendCh: make(chan []*command.Statement, maxSize), + done: make(chan struct{}), + closed: make(chan struct{}), + flush: make(chan struct{}), + } + + q.C = q.sendCh + go q.run() + return q +} + +// Write queues a request. 
+func (q *Queue) Write(stmt *command.Statement) error { + if stmt == nil { + return nil + } + q.batchCh <- stmt + return nil +} + +// Flush flushes the queue +func (q *Queue) Flush() error { + q.flush <- struct{}{} + return nil +} + +// Close closes the queue. A closed queue should not be used. +func (q *Queue) Close() error { + select { + case <-q.done: + default: + close(q.done) + <-q.closed + } + return nil +} + +// Depth returns the number of queue requests +func (q *Queue) Depth() int { + return len(q.batchCh) +} + +// Stats returns stats on this queue. +func (q *Queue) Stats() (map[string]interface{}, error) { + return map[string]interface{}{ + "max_size": q.maxSize, + "batch_size": q.batchSize, + "timeout": q.timeout, + }, nil +} + +func (q *Queue) run() { + defer close(q.closed) + var stmts []*command.Statement + timer := time.NewTimer(q.timeout) + timer.Stop() + + writeFn := func() { + newStmts := make([]*command.Statement, len(stmts)) + copy(newStmts, stmts) + q.sendCh <- newStmts + + stmts = nil + timer.Stop() + } + + for { + select { + case s := <-q.batchCh: + stmts = append(stmts, s) + if len(stmts) == 1 { + timer.Reset(q.timeout) + } + if len(stmts) >= q.batchSize { + writeFn() + } + case <-timer.C: + q.numTimeouts++ + writeFn() + case <-q.flush: + writeFn() + case <-q.done: + timer.Stop() + return + } + } +} diff --git a/queue/queue_test.go b/queue/queue_test.go new file mode 100644 index 00000000..0a8770b2 --- /dev/null +++ b/queue/queue_test.go @@ -0,0 +1,206 @@ +package queue + +import ( + "testing" + "time" + + "github.com/rqlite/rqlite/command" +) + +var testStmt = &command.Statement{ + Sql: "SELECT * FROM foo", +} + +func Test_NewQueue(t *testing.T) { + q := New(1, 1, 100*time.Millisecond) + if q == nil { + t.Fatalf("failed to create new Queue") + } + defer q.Close() +} + +func Test_NewQueueWriteNil(t *testing.T) { + q := New(1, 1, 60*time.Second) + defer q.Close() + + if err := q.Write(nil); err != nil { + t.Fatalf("failing to write nil: %s", 
err.Error()) + } +} + +func Test_NewQueueWriteBatchSizeSingle(t *testing.T) { + q := New(1024, 1, 60*time.Second) + defer q.Close() + + if err := q.Write(testStmt); err != nil { + t.Fatalf("failed to write: %s", err.Error()) + } + + select { + case stmts := <-q.C: + if len(stmts) != 1 { + t.Fatalf("received wrong length slice") + } + if stmts[0].Sql != "SELECT * FROM foo" { + t.Fatalf("received wrong SQL") + } + case <-time.NewTimer(5 * time.Second).C: + t.Fatalf("timed out waiting for statement") + } +} + +func Test_NewQueueWriteBatchSizeMulti(t *testing.T) { + q := New(1024, 5, 60*time.Second) + defer q.Close() + + // Write a batch size and wait for it. + for i := 0; i < 5; i++ { + if err := q.Write(testStmt); err != nil { + t.Fatalf("failed to write: %s", err.Error()) + } + } + select { + case stmts := <-q.C: + if len(stmts) != 5 { + t.Fatalf("received wrong length slice") + } + if q.numTimeouts != 0 { + t.Fatalf("queue timeout expired?") + } + case <-time.NewTimer(5 * time.Second).C: + t.Fatalf("timed out waiting for first statements") + } + + // Write one more than a batch size, should still get a batch. 
+ for i := 0; i < 6; i++ { + if err := q.Write(testStmt); err != nil { + t.Fatalf("failed to write: %s", err.Error()) + } + } + select { + case stmts := <-q.C: + if len(stmts) < 5 { + t.Fatalf("received too-short slice") + } + if q.numTimeouts != 0 { + t.Fatalf("queue timeout expired?") + } + case <-time.NewTimer(5 * time.Second).C: + t.Fatalf("timed out waiting for second statements") + } +} + +func Test_NewQueueWriteTimeout(t *testing.T) { + q := New(1024, 10, 1*time.Second) + defer q.Close() + + if err := q.Write(testStmt); err != nil { + t.Fatalf("failed to write: %s", err.Error()) + } + + select { + case stmts := <-q.C: + if len(stmts) != 1 { + t.Fatalf("received wrong length slice") + } + if stmts[0].Sql != "SELECT * FROM foo" { + t.Fatalf("received wrong SQL") + } + if q.numTimeouts != 1 { + t.Fatalf("queue timeout didn't expire") + } + case <-time.NewTimer(5 * time.Second).C: + t.Fatalf("timed out waiting for statement") + } +} + +// Test_NewQueueWriteTimeoutMulti ensures that timer expiring +// twice in a row works fine. 
+func Test_NewQueueWriteTimeoutMulti(t *testing.T) { + q := New(1024, 10, 1*time.Second) + defer q.Close() + + if err := q.Write(testStmt); err != nil { + t.Fatalf("failed to write: %s", err.Error()) + } + select { + case stmts := <-q.C: + if len(stmts) != 1 { + t.Fatalf("received wrong length slice") + } + if stmts[0].Sql != "SELECT * FROM foo" { + t.Fatalf("received wrong SQL") + } + if q.numTimeouts != 1 { + t.Fatalf("queue timeout didn't expire") + } + case <-time.NewTimer(5 * time.Second).C: + t.Fatalf("timed out waiting for first statement") + } + + if err := q.Write(testStmt); err != nil { + t.Fatalf("failed to write: %s", err.Error()) + } + select { + case stmts := <-q.C: + if len(stmts) != 1 { + t.Fatalf("received wrong length slice") + } + if stmts[0].Sql != "SELECT * FROM foo" { + t.Fatalf("received wrong SQL") + } + if q.numTimeouts != 2 { + t.Fatalf("queue timeout didn't expire") + } + case <-time.NewTimer(5 * time.Second).C: + t.Fatalf("timed out waiting for second statement") + } +} + +// Test_NewQueueWriteTimeoutBatch ensures that timer expiring +// followed by a batch, works fine. +func Test_NewQueueWriteTimeoutBatch(t *testing.T) { + q := New(1024, 2, 1*time.Second) + defer q.Close() + + if err := q.Write(testStmt); err != nil { + t.Fatalf("failed to write: %s", err.Error()) + } + + select { + case stmts := <-q.C: + if len(stmts) != 1 { + t.Fatalf("received wrong length slice") + } + if stmts[0].Sql != "SELECT * FROM foo" { + t.Fatalf("received wrong SQL") + } + if q.numTimeouts != 1 { + t.Fatalf("queue timeout didn't expire") + } + case <-time.NewTimer(5 * time.Second).C: + t.Fatalf("timed out waiting for statement") + } + + if err := q.Write(testStmt); err != nil { + t.Fatalf("failed to write: %s", err.Error()) + } + if err := q.Write(testStmt); err != nil { + t.Fatalf("failed to write: %s", err.Error()) + } + select { + case stmts := <-q.C: + // Should happen before the timeout expires. 
+ if len(stmts) != 2 { + t.Fatalf("received wrong length slice") + } + if stmts[0].Sql != "SELECT * FROM foo" { + t.Fatalf("received wrong SQL") + } + if q.numTimeouts != 1 { + t.Fatalf("queue timeout expired?") + } + case <-time.NewTimer(5 * time.Second).C: + t.Fatalf("timed out waiting for statement") + } +} diff --git a/system_test/full_system_test.py b/system_test/full_system_test.py index 10c7dfb7..b82b4468 100644 --- a/system_test/full_system_test.py +++ b/system_test/full_system_test.py @@ -462,6 +462,18 @@ class Node(object): raise_for_status(r) return r.json() + def execute_queued(self, statement, params=None, queue='_default'): + body = [statement] + if params is not None: + try: + body = body + params + except TypeError: + # Presumably not a list, so append as an object. + body.append(params) + r = requests.post(self._execute_queued_url(queue), data=json.dumps([body])) + raise_for_status(r) + return r.json() + def backup(self, file): with open(file, 'wb') as fd: r = requests.get(self._backup_url()) @@ -514,6 +526,8 @@ class Node(object): if redirect: rd = "?redirect" return 'http://' + self.APIAddr() + '/db/execute' + rd + def _execute_queued_url(self, queue): + return 'http://' + self.APIAddr() + '/db/execute/queue/' + queue def _backup_url(self): return 'http://' + self.APIAddr() + '/db/backup' def _load_url(self): @@ -1159,6 +1173,31 @@ class TestRequestForwarding(unittest.TestCase): j = f.query('SELECT * FROM foo', level="strong") self.assertEqual(j, d_("{'results': [{'values': [[1, 'fiona']], 'types': ['integer', 'text'], 'columns': ['id', 'name']}]}")) + def test_execute_queued_forward(self): + l = self.cluster.wait_for_leader() + j = l.execute('CREATE TABLE foo (id INTEGER NOT NULL PRIMARY KEY, name TEXT)') + self.assertEqual(j, d_("{'results': [{}]}")) + + f = self.cluster.followers()[0] + j = f.execute('INSERT INTO foo(name) VALUES("fiona")') + self.assertEqual(j, d_("{'results': [{'last_insert_id': 1, 'rows_affected': 1}]}")) + fsmIdx = 
l.wait_for_all_fsm() + + j = f.execute_queued('INSERT INTO foo(name) VALUES("declan")') + self.assertEqual(j, d_("{'results': []}")) + + # Wait for queued write to happen. + timeout = 10 + t = 0 + while True: + j = l.query('SELECT * FROM foo') + if j == d_("{'results': [{'values': [[1, 'fiona'], [2, 'declan']], 'types': ['integer', 'text'], 'columns': ['id', 'name']}]}"): + break + if t > timeout: + raise Exception('timeout', nSnaps) + time.sleep(1) + t+=1 + class TestEndToEndNonVoter(unittest.TestCase): def setUp(self): self.leader = Node(RQLITED_PATH, '0') diff --git a/system_test/helpers.go b/system_test/helpers.go index 27fe815a..d67f081f 100644 --- a/system_test/helpers.go +++ b/system_test/helpers.go @@ -88,7 +88,7 @@ func (n *Node) ExecuteMulti(stmts []string) (string, error) { return n.postExecute(string(j)) } -// ExecuteParameterized executes a single paramterized query against the ndoe +// ExecuteParameterized executes a single parameterized query against the node func (n *Node) ExecuteParameterized(stmt []interface{}) (string, error) { m := make([][]interface{}, 1) m[0] = stmt @@ -100,6 +100,20 @@ func (n *Node) ExecuteParameterized(stmt []interface{}) (string, error) { return n.postExecute(string(j)) } +// ExecuteQueued sends a single statement to the node's Execute queue +func (n *Node) ExecuteQueued(stmt string) (string, error) { + return n.ExecuteQueuedMulti([]string{stmt}) +} + +// ExecuteQueuedMulti sends multiple statements to the node's Execute queue +func (n *Node) ExecuteQueuedMulti(stmts []string) (string, error) { + j, err := json.Marshal(stmts) + if err != nil { + return "", err + } + return n.postExecuteQueued(string(j)) +} + // Query runs a single query against the node. 
func (n *Node) Query(stmt string) (string, error) { v, _ := url.Parse("http://" + n.APIAddr + "/db/query") @@ -321,6 +335,19 @@ func (n *Node) postExecute(stmt string) (string, error) { return string(body), nil } +func (n *Node) postExecuteQueued(stmt string) (string, error) { + resp, err := http.Post("http://"+n.APIAddr+"/db/execute/queue/_default", "application/json", strings.NewReader(stmt)) + if err != nil { + return "", err + } + defer resp.Body.Close() + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + return "", err + } + return string(body), nil +} + func (n *Node) postQuery(stmt string) (string, error) { resp, err := http.Post("http://"+n.APIAddr+"/db/query", "application/json", strings.NewReader(stmt)) if err != nil { diff --git a/system_test/request_forwarding_test.go b/system_test/request_forwarding_test.go index f569320a..3365c4ca 100644 --- a/system_test/request_forwarding_test.go +++ b/system_test/request_forwarding_test.go @@ -180,6 +180,76 @@ func Test_MultiNodeClusterRequestForwardOK(t *testing.T) { } } +// Test_MultiNodeClusterQueuedRequestForwardOK tests that queued writes are forwarded +// correctly. +func Test_MultiNodeClusterQueuedRequestForwardOK(t *testing.T) { + node1 := mustNewLeaderNode() + defer node1.Deprovision() + + node2 := mustNewNode(false) + defer node2.Deprovision() + if err := node2.Join(node1); err != nil { + t.Fatalf("node failed to join leader: %s", err.Error()) + } + _, err := node2.WaitForLeader() + if err != nil { + t.Fatalf("failed waiting for leader: %s", err.Error()) + } + + // Get the new leader, in case it changed. + c := Cluster{node1, node2} + leader, err := c.Leader() + if err != nil { + t.Fatalf("failed to find cluster leader: %s", err.Error()) + } + + // Create table and confirm its existence. 
+ res, err := leader.Execute(`CREATE TABLE foo (id integer not null primary key, name text)`) + if err != nil { + t.Fatalf("failed to create table: %s", err.Error()) + } + if exp, got := `{"results":[{}]}`, res; exp != got { + t.Fatalf("got incorrect response from follower exp: %s, got: %s", exp, got) + } + rows, err := leader.Query(`SELECT COUNT(*) FROM foo`) + if err != nil { + t.Fatalf("failed to query for count: %s", err.Error()) + } + if exp, got := `{"results":[{"columns":["COUNT(*)"],"types":[""],"values":[[0]]}]}`, rows; exp != got { + t.Fatalf("got incorrect response from follower exp: %s, got: %s", exp, got) + } + + // Write a request to a follower's queue, checking it's eventually sent to the leader. + followers, err := c.Followers() + if err != nil { + t.Fatalf("failed to get followers: %s", err.Error()) + } + if len(followers) != 1 { + t.Fatalf("got incorrect number of followers: %d", len(followers)) + } + res, err = followers[0].ExecuteQueued(`INSERT INTO foo(name) VALUES("fiona")`) + if err != nil { + t.Fatalf("failed to insert record: %s", err.Error()) + } + + ticker := time.NewTicker(10 * time.Millisecond) + timer := time.NewTimer(5 * time.Second) + for { + select { + case <-ticker.C: + r, err := leader.Query(`SELECT COUNT(*) FROM foo`) + if err != nil { + t.Fatalf("failed to query for count: %s", err.Error()) + } + if r == `{"results":[{"columns":["COUNT(*)"],"types":[""],"values":[[1]]}]}` { + return + } + case <-timer.C: + t.Fatalf("timed out waiting for queued writes") + } + } +} + func executeRequestFromString(s string) *command.ExecuteRequest { return executeRequestFromStrings([]string{s}) } diff --git a/system_test/single_node_test.go b/system_test/single_node_test.go index 6dd14391..9b6309de 100644 --- a/system_test/single_node_test.go +++ b/system_test/single_node_test.go @@ -333,6 +333,44 @@ func Test_SingleNodeParameterizedNamed(t *testing.T) { } } +func Test_SingleNodeQueued(t *testing.T) { + node := mustNewLeaderNode() + defer 
node.Deprovision() + + _, err := node.Execute(`CREATE TABLE foo (id integer not null primary key, name text)`) + if err != nil { + t.Fatalf(`CREATE TABLE failed: %s`, err.Error()) + } + + qWrites := []string{ + `INSERT INTO foo(name) VALUES("fiona")`, + `INSERT INTO foo(name) VALUES("fiona")`, + `INSERT INTO foo(name) VALUES("fiona")`, + } + _, err = node.ExecuteQueuedMulti(qWrites) + if err != nil { + t.Fatalf(`queued write failed: %s`, err.Error()) + } + + ticker := time.NewTicker(10 * time.Millisecond) + timer := time.NewTimer(5 * time.Second) + for { + select { + case <-ticker.C: + r, err := node.Query(`SELECT COUNT(*) FROM foo`) + if err != nil { + t.Fatalf(`query failed: %s`, err.Error()) + } + if r == `{"results":[{"columns":["COUNT(*)"],"types":[""],"values":[[3]]}]}` { + return + } + case <-timer.C: + t.Fatalf("timed out waiting for queued writes") + + } + } +} + // Test_SingleNodeSQLInjection demonstrates that using the non-parameterized API is vulnerable to // SQL injection attacks. func Test_SingleNodeSQLInjection(t *testing.T) {