1
0
Fork 0

Refactor query parameters into own code

Makes rest of HTTP code cleaner.
master
Philip O'Toole 9 months ago
parent a84429df6a
commit 384ebab5f7

@ -0,0 +1,173 @@
package http
import (
"net/http"
"net/url"
"strconv"
"strings"
"time"
"github.com/rqlite/rqlite/command"
)
// QueryParams represents the query parameters passed in an HTTP request.
type QueryParams map[string]string
// NewQueryParams returns a new QueryParams from the given HTTP request.
func NewQueryParams(r *http.Request) (QueryParams, error) {
qp := make(QueryParams)
values, err := url.ParseQuery(strings.ToLower(r.URL.RawQuery))
if err != nil {
return nil, err
}
for k, v := range values {
qp[k] = v[0]
}
for _, k := range []string{"timeout", "freshness"} {
t, ok := qp[k]
if ok {
_, err := time.ParseDuration(t)
if err != nil {
return nil, err
}
}
}
sz, ok := qp["chunk_kb"]
if ok {
_, err := strconv.Atoi(sz)
if err != nil {
return nil, err
}
}
return qp, nil
}
// Timings returns true if the query parameters indicate timings should be returned.
func (qp QueryParams) Timings() bool {
return qp.HasKey("timings")
}
// Tx returns true if the query parameters indicate the query should be executed in a transaction.
func (qp QueryParams) Tx() bool {
return qp.HasKey("transaction")
}
// Query returns true if the query parameters request queued operation
func (qp QueryParams) Queue() bool {
return qp.HasKey("queue")
}
// Pretty returns true if the query parameters indicate pretty-printing should be returned.
func (qp QueryParams) Pretty() bool {
return qp.HasKey("pretty")
}
// Bypass returns true if the query parameters indicate bypass mode.
func (qp QueryParams) Bypass() bool {
return qp.HasKey("bypass")
}
// Wait returns true if the query parameters indicate the query should wait.
func (qp QueryParams) Wait() bool {
return qp.HasKey("wait")
}
// ChunkKB returns the requested chunk size.
func (qp QueryParams) ChunkKB(defSz int) int {
s, ok := qp["chunk_kb"]
if !ok {
return defSz
}
sz, _ := strconv.Atoi(s)
return sz * 1024
}
// Associative returns true if the query parameters request associative results.
func (qp QueryParams) Associative() bool {
return qp.HasKey("associative")
}
// NoRewrite returns true if the query parameters request no rewriting of queries.
func (qp QueryParams) NoRewriteRandom() bool {
return qp.HasKey("norwrandom")
}
// NonVoters returns true if the query parameters request non-voters to be included in results.
func (qp QueryParams) NonVoters() bool {
return qp.HasKey("nonvoters")
}
// NoLeader returns true if the query parameters request no leader mode
func (qp QueryParams) NoLeader() bool {
return qp.HasKey("noleader")
}
// Redirect returns true if the query parameters request redirect mode
func (qp QueryParams) Redirect() bool {
return qp.HasKey("redirect")
}
// Vacuum returns true if the query parameters request vacuum mode
func (qp QueryParams) Vacuum() bool {
return qp.HasKey("vacuum")
}
// Level returns the requested consistency level.
func (qp QueryParams) Level() command.QueryRequest_Level {
lvl := qp["level"]
switch strings.ToLower(lvl) {
case "none":
return command.QueryRequest_QUERY_REQUEST_LEVEL_NONE
case "weak":
return command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK
case "strong":
return command.QueryRequest_QUERY_REQUEST_LEVEL_STRONG
default:
return command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK
}
}
// BackupFormat returns the requested backup format.
func (qp QueryParams) BackupFormat() command.BackupRequest_Format {
f := qp["fmt"]
switch f {
case "sql":
return command.BackupRequest_BACKUP_REQUEST_FORMAT_SQL
default:
return command.BackupRequest_BACKUP_REQUEST_FORMAT_BINARY
}
}
// Query returns the requested query.
func (qp QueryParams) Query() string {
return qp["q"]
}
// Freshness returns the requested freshness duration.
func (qp QueryParams) Freshness() time.Duration {
f := qp["freshness"]
d, _ := time.ParseDuration(f)
return d
}
// Timeout returns the requested timeout duration.
func (qp QueryParams) Timeout(def time.Duration) time.Duration {
t, ok := qp["timeout"]
if !ok {
return def
}
d, _ := time.ParseDuration(t)
return d
}
// Version returns the requested version.
func (qp QueryParams) Version() string {
return qp["version"]
}
// HasKey returns true if the given key is present in the query parameters.
func (qp QueryParams) HasKey(k string) bool {
_, ok := qp[k]
return ok
}

@ -16,7 +16,6 @@ import (
"net/http/pprof"
"os"
"runtime"
"strconv"
"strings"
"sync"
"time"
@ -445,38 +444,44 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return
}
params, err := NewQueryParams(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
switch {
case r.URL.Path == "/" || r.URL.Path == "":
http.Redirect(w, r, "/status", http.StatusFound)
case strings.HasPrefix(r.URL.Path, "/db/execute"):
stats.Add(numExecutions, 1)
s.handleExecute(w, r)
s.handleExecute(w, r, params)
case strings.HasPrefix(r.URL.Path, "/db/query"):
stats.Add(numQueries, 1)
s.handleQuery(w, r)
s.handleQuery(w, r, params)
case strings.HasPrefix(r.URL.Path, "/db/request"):
stats.Add(numRequests, 1)
s.handleRequest(w, r)
s.handleRequest(w, r, params)
case strings.HasPrefix(r.URL.Path, "/db/backup"):
stats.Add(numBackups, 1)
s.handleBackup(w, r)
s.handleBackup(w, r, params)
case strings.HasPrefix(r.URL.Path, "/db/load"):
stats.Add(numLoad, 1)
s.handleLoad(w, r)
s.handleLoad(w, r, params)
case strings.HasPrefix(r.URL.Path, "/remove"):
s.handleRemove(w, r)
s.handleRemove(w, r, params)
case strings.HasPrefix(r.URL.Path, "/status"):
stats.Add(numStatus, 1)
s.handleStatus(w, r)
s.handleStatus(w, r, params)
case strings.HasPrefix(r.URL.Path, "/nodes"):
s.handleNodes(w, r)
s.handleNodes(w, r, params)
case strings.HasPrefix(r.URL.Path, "/readyz"):
stats.Add(numReadyz, 1)
s.handleReadyz(w, r)
s.handleReadyz(w, r, params)
case r.URL.Path == "/debug/vars":
s.handleExpvar(w, r)
s.handleExpvar(w, r, params)
case strings.HasPrefix(r.URL.Path, "/debug/pprof"):
s.handlePprof(w, r)
s.handlePprof(w, r, params)
default:
w.WriteHeader(http.StatusNotFound)
}
@ -496,7 +501,7 @@ func (s *Service) RegisterStatus(key string, stat StatusReporter) error {
}
// handleRemove handles cluster-remove requests.
func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) {
func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request, qp QueryParams) {
if !s.CheckRequestPerm(r, auth.PermRemove) {
w.WriteHeader(http.StatusUnauthorized)
return
@ -529,12 +534,6 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) {
return
}
timeout, err := timeoutParam(r, defaultTimeout)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
rn := &command.RemoveNodeRequest{
Id: remoteID,
}
@ -542,7 +541,7 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) {
err = s.store.Remove(rn)
if err != nil {
if err == store.ErrNotLeader {
if s.DoRedirect(w, r) {
if s.DoRedirect(w, r, qp) {
return
}
@ -564,7 +563,7 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) {
}
w.Header().Add(ServedByHTTPHeader, addr)
removeErr := s.cluster.RemoveNode(rn, addr, makeCredentials(username, password), timeout)
removeErr := s.cluster.RemoveNode(rn, addr, makeCredentials(username, password), qp.Timeout(defaultTimeout))
if removeErr != nil {
if removeErr.Error() == "unauthorized" {
http.Error(w, "remote remove node not authorized", http.StatusUnauthorized)
@ -582,7 +581,7 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) {
}
// handleBackup returns the consistent database snapshot.
func (s *Service) handleBackup(w http.ResponseWriter, r *http.Request) {
func (s *Service) handleBackup(w http.ResponseWriter, r *http.Request, qp QueryParams) {
if !s.CheckRequestPerm(r, auth.PermBackup) {
w.WriteHeader(http.StatusUnauthorized)
return
@ -593,40 +592,16 @@ func (s *Service) handleBackup(w http.ResponseWriter, r *http.Request) {
return
}
noLeader, err := noLeader(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
format, err := backupFormat(w, r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
vacuum, err := isVacuum(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
timeout, err := timeoutParam(r, defaultTimeout)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
br := &command.BackupRequest{
Format: format,
Leader: !noLeader,
Vacuum: vacuum,
Format: qp.BackupFormat(),
Leader: !qp.NoLeader(),
Vacuum: qp.Vacuum(),
}
err = s.store.Backup(br, w)
err := s.store.Backup(br, w)
if err != nil {
if err == store.ErrNotLeader {
if s.DoRedirect(w, r) {
if s.DoRedirect(w, r, qp) {
return
}
@ -648,7 +623,7 @@ func (s *Service) handleBackup(w http.ResponseWriter, r *http.Request) {
}
w.Header().Add(ServedByHTTPHeader, addr)
backupErr := s.cluster.Backup(br, addr, makeCredentials(username, password), timeout, w)
backupErr := s.cluster.Backup(br, addr, makeCredentials(username, password), qp.Timeout(defaultTimeout), w)
if backupErr != nil {
if backupErr.Error() == "unauthorized" {
http.Error(w, "remote backup not authorized", http.StatusUnauthorized)
@ -671,7 +646,7 @@ func (s *Service) handleBackup(w http.ResponseWriter, r *http.Request) {
}
// handleLoad loads the database from the given SQLite database file or SQLite dump.
func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request, qp QueryParams) {
startTime := time.Now()
if !s.CheckRequestPerm(r, auth.PermLoad) {
@ -686,18 +661,6 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
resp := NewResponse()
timings, err := isTimings(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
chunkSz, err := chunkSizeParam(r, defaultChunkSize)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// Peek at the incoming bytes so we can determine if this is a SQLite database
validSQLite := false
bufReader := bufio.NewReader(r.Body)
@ -725,12 +688,12 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
r.Body.Close()
queries := []string{string(b)}
er := executeRequestFromStrings(queries, timings, false)
er := executeRequestFromStrings(queries, qp.Timings(), false)
results, err := s.store.Execute(er)
if err != nil {
if err == store.ErrNotLeader {
if s.DoRedirect(w, r) {
if s.DoRedirect(w, r, qp) {
return
}
}
@ -740,7 +703,7 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
}
resp.end = time.Now()
} else {
chunker := chunking.NewChunker(bufReader, int64(chunkSz))
chunker := chunking.NewChunker(bufReader, int64(qp.ChunkKB(defaultChunkSize)))
for {
chunk, err := chunker.Next()
if err != nil {
@ -752,11 +715,11 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
} else if err != nil && err == store.ErrNotLeader {
if s.DoRedirect(w, r) {
if s.DoRedirect(w, r, qp) {
return
}
addr, err := s.loadClusterChunk(r, chunk)
addr, err := s.loadClusterChunk(r, qp, chunk)
if err != nil {
if err == ErrRemoteLoadNotAuthorized {
http.Error(w, err.Error(), http.StatusUnauthorized)
@ -765,7 +728,7 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
} else {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
s.loadClusterChunk(r, chunker.Abort())
s.loadClusterChunk(r, qp, chunker.Abort())
return
}
w.Header().Add(ServedByHTTPHeader, addr)
@ -787,10 +750,10 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
}
s.logger.Printf("load request finished in %s", time.Since(startTime).String())
s.writeResponse(w, r, resp)
s.writeResponse(w, r, qp, resp)
}
func (s *Service) loadClusterChunk(r *http.Request, chunk *command.LoadChunkRequest) (string, error) {
func (s *Service) loadClusterChunk(r *http.Request, qp QueryParams, chunk *command.LoadChunkRequest) (string, error) {
addr, err := s.store.LeaderAddr()
if err != nil {
return "", err
@ -803,11 +766,7 @@ func (s *Service) loadClusterChunk(r *http.Request, chunk *command.LoadChunkRequ
if !ok {
username = ""
}
timeout, err := timeoutParam(r, defaultTimeout)
if err != nil {
return "", err
}
err = s.cluster.LoadChunk(chunk, addr, makeCredentials(username, password), timeout)
err = s.cluster.LoadChunk(chunk, addr, makeCredentials(username, password), qp.Timeout(defaultTimeout))
if err != nil {
if err.Error() == "unauthorized" {
return "", ErrRemoteLoadNotAuthorized
@ -819,7 +778,7 @@ func (s *Service) loadClusterChunk(r *http.Request, chunk *command.LoadChunkRequ
}
// handleStatus returns status on the system.
func (s *Service) handleStatus(w http.ResponseWriter, r *http.Request) {
func (s *Service) handleStatus(w http.ResponseWriter, r *http.Request, qp QueryParams) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
if !s.CheckRequestPerm(r, auth.PermStatus) {
@ -925,9 +884,8 @@ func (s *Service) handleStatus(w http.ResponseWriter, r *http.Request) {
}
}()
pretty, _ := isPretty(r)
var b []byte
if pretty {
if qp.Pretty() {
b, err = json.MarshalIndent(status, "", " ")
} else {
b, err = json.Marshal(status)
@ -957,7 +915,7 @@ func (s *Service) handleStatus(w http.ResponseWriter, r *http.Request) {
// handleNodes returns status on the other voting nodes in the system.
// This attempts to contact all the nodes in the cluster, so may take
// some time to return.
func (s *Service) handleNodes(w http.ResponseWriter, r *http.Request) {
func (s *Service) handleNodes(w http.ResponseWriter, r *http.Request, qp QueryParams) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
if !s.CheckRequestPerm(r, auth.PermStatus) {
@ -970,12 +928,6 @@ func (s *Service) handleNodes(w http.ResponseWriter, r *http.Request) {
return
}
timeout, err := timeoutParam(r, defaultTimeout)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
includeNonVoters, err := nonVoters(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -1004,16 +956,10 @@ func (s *Service) handleNodes(w http.ResponseWriter, r *http.Request) {
http.StatusInternalServerError)
return
}
nodes.Test(s.cluster, lAddr, timeout)
nodes.Test(s.cluster, lAddr, qp.Timeout(defaultTimeout))
ver, err := verParam(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
enc := NewNodesRespEncoder(w, ver != "2")
pretty, _ := isPretty(r)
if pretty {
enc := NewNodesRespEncoder(w, qp.Version() != "2")
if qp.Pretty() {
enc.SetIndent("", " ")
}
err = enc.Encode(nodes)
@ -1024,7 +970,7 @@ func (s *Service) handleNodes(w http.ResponseWriter, r *http.Request) {
}
// handleReadyz returns whether the node is ready.
func (s *Service) handleReadyz(w http.ResponseWriter, r *http.Request) {
func (s *Service) handleReadyz(w http.ResponseWriter, r *http.Request, qp QueryParams) {
if !s.CheckRequestPerm(r, auth.PermReady) {
w.WriteHeader(http.StatusUnauthorized)
return
@ -1047,12 +993,6 @@ func (s *Service) handleReadyz(w http.ResponseWriter, r *http.Request) {
return
}
timeout, err := timeoutParam(r, defaultTimeout)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
lAddr, err := s.store.LeaderAddr()
if err != nil {
http.Error(w, fmt.Sprintf("leader address: %s", err.Error()),
@ -1066,7 +1006,7 @@ func (s *Service) handleReadyz(w http.ResponseWriter, r *http.Request) {
return
}
_, err = s.cluster.GetNodeAPIAddr(lAddr, timeout)
_, err = s.cluster.GetNodeAPIAddr(lAddr, qp.Timeout(defaultTimeout))
if err != nil {
w.WriteHeader(http.StatusServiceUnavailable)
w.Write([]byte(fmt.Sprintf("[+]node ok\n[+]leader not contactable: %s", err.Error())))
@ -1083,7 +1023,7 @@ func (s *Service) handleReadyz(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("[+]node ok\n[+]leader ok\n[+]store ok"))
}
func (s *Service) handleExecute(w http.ResponseWriter, r *http.Request) {
func (s *Service) handleExecute(w http.ResponseWriter, r *http.Request, qp QueryParams) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
if !s.CheckRequestPerm(r, auth.PermExecute) {
@ -1096,22 +1036,16 @@ func (s *Service) handleExecute(w http.ResponseWriter, r *http.Request) {
return
}
queue, err := isQueue(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if queue {
if qp.Queue() {
stats.Add(numQueuedExecutions, 1)
s.queuedExecute(w, r)
s.queuedExecute(w, r, qp)
} else {
s.execute(w, r)
s.execute(w, r, qp)
}
}
// queuedExecute handles queued queries that modify the database.
func (s *Service) queuedExecute(w http.ResponseWriter, r *http.Request) {
func (s *Service) queuedExecute(w http.ResponseWriter, r *http.Request, qp QueryParams) {
resp := NewResponse()
// Perform a leader check, unless disabled. This prevents generating queued writes on
@ -1160,12 +1094,6 @@ func (s *Service) queuedExecute(w http.ResponseWriter, r *http.Request) {
return
}
timeout, err := timeoutParam(r, defaultTimeout)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
var fc queue.FlushChannel
if wait {
stats.Add(numQueuedExecutionsWait, 1)
@ -1184,26 +1112,19 @@ func (s *Service) queuedExecute(w http.ResponseWriter, r *http.Request) {
select {
case <-fc:
break
case <-time.NewTimer(timeout).C:
case <-time.NewTimer(qp.Timeout(defaultTimeout)).C:
http.Error(w, "timeout", http.StatusRequestTimeout)
return
}
}
resp.end = time.Now()
s.writeResponse(w, r, resp)
s.writeResponse(w, r, qp, resp)
}
// execute handles queries that modify the database.
func (s *Service) execute(w http.ResponseWriter, r *http.Request) {
func (s *Service) execute(w http.ResponseWriter, r *http.Request, qp QueryParams) {
resp := NewResponse()
timeout, isTx, timings, noRewriteRandom, err := reqParams(r, defaultTimeout)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
b, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
@ -1217,22 +1138,22 @@ func (s *Service) execute(w http.ResponseWriter, r *http.Request) {
return
}
stats.Add(numExecuteStmtsRx, int64(len(stmts)))
if err := command.Rewrite(stmts, !noRewriteRandom); err != nil {
if err := command.Rewrite(stmts, !qp.NoRewriteRandom()); err != nil {
http.Error(w, fmt.Sprintf("SQL rewrite: %s", err.Error()), http.StatusInternalServerError)
return
}
er := &command.ExecuteRequest{
Request: &command.Request{
Transaction: isTx,
Transaction: qp.Tx(),
Statements: stmts,
},
Timings: timings,
Timings: qp.Timings(),
}
results, resultsErr := s.store.Execute(er)
if resultsErr != nil && resultsErr == store.ErrNotLeader {
if s.DoRedirect(w, r) {
if s.DoRedirect(w, r, qp) {
return
}
@ -1254,7 +1175,7 @@ func (s *Service) execute(w http.ResponseWriter, r *http.Request) {
}
w.Header().Add(ServedByHTTPHeader, addr)
results, resultsErr = s.cluster.Execute(er, addr, makeCredentials(username, password), timeout)
results, resultsErr = s.cluster.Execute(er, addr, makeCredentials(username, password), qp.Timeout(defaultTimeout))
if resultsErr != nil {
stats.Add(numRemoteExecutionsFailed, 1)
if resultsErr.Error() == "unauthorized" {
@ -1271,11 +1192,11 @@ func (s *Service) execute(w http.ResponseWriter, r *http.Request) {
resp.Results.ExecuteResult = results
}
resp.end = time.Now()
s.writeResponse(w, r, resp)
s.writeResponse(w, r, qp, resp)
}
// handleQuery handles queries that do not modify the database.
func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) {
func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request, qp QueryParams) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
if !s.CheckRequestPerm(r, auth.PermQuery) {
@ -1288,14 +1209,8 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) {
return
}
timeout, frsh, lvl, isTx, timings, noRewriteRandom, isAssoc, err := queryReqParams(r, defaultTimeout)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// Get the query statement(s), and do tx if necessary.
queries, err := requestQueries(r)
queries, err := requestQueries(r, qp)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@ -1304,29 +1219,29 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) {
// No point rewriting queries if they don't go through the Raft log, since they
// will never be replayed from the log anyway.
if lvl == command.QueryRequest_QUERY_REQUEST_LEVEL_STRONG {
if err := command.Rewrite(queries, noRewriteRandom); err != nil {
if qp.Level() == command.QueryRequest_QUERY_REQUEST_LEVEL_STRONG {
if err := command.Rewrite(queries, qp.NoRewriteRandom()); err != nil {
http.Error(w, fmt.Sprintf("SQL rewrite: %s", err.Error()), http.StatusInternalServerError)
return
}
}
resp := NewResponse()
resp.Results.AssociativeJSON = isAssoc
resp.Results.AssociativeJSON = qp.Associative()
qr := &command.QueryRequest{
Request: &command.Request{
Transaction: isTx,
Transaction: qp.Tx(),
Statements: queries,
},
Timings: timings,
Level: lvl,
Freshness: frsh.Nanoseconds(),
Timings: qp.Timings(),
Level: qp.Level(),
Freshness: qp.Freshness().Nanoseconds(),
}
results, resultsErr := s.store.Query(qr)
if resultsErr != nil && resultsErr == store.ErrNotLeader {
if s.DoRedirect(w, r) {
if s.DoRedirect(w, r, qp) {
return
}
@ -1346,7 +1261,7 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) {
}
w.Header().Add(ServedByHTTPHeader, addr)
results, resultsErr = s.cluster.Query(qr, addr, makeCredentials(username, password), timeout)
results, resultsErr = s.cluster.Query(qr, addr, makeCredentials(username, password), qp.Timeout(defaultTimeout))
if resultsErr != nil {
stats.Add(numRemoteQueriesFailed, 1)
if resultsErr.Error() == "unauthorized" {
@ -1363,10 +1278,10 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) {
resp.Results.QueryRows = results
}
resp.end = time.Now()
s.writeResponse(w, r, resp)
s.writeResponse(w, r, qp, resp)
}
func (s *Service) handleRequest(w http.ResponseWriter, r *http.Request) {
func (s *Service) handleRequest(w http.ResponseWriter, r *http.Request, qp QueryParams) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
if !s.CheckRequestPermAll(r, auth.PermQuery, auth.PermExecute) {
@ -1379,12 +1294,6 @@ func (s *Service) handleRequest(w http.ResponseWriter, r *http.Request) {
return
}
timeout, frsh, lvl, isTx, timings, noRewriteRandom, isAssoc, err := executeQueryReqParams(r, defaultTimeout)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
b, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
@ -1399,27 +1308,27 @@ func (s *Service) handleRequest(w http.ResponseWriter, r *http.Request) {
}
stats.Add(numRequestStmtsRx, int64(len(stmts)))
if err := command.Rewrite(stmts, noRewriteRandom); err != nil {
if err := command.Rewrite(stmts, qp.NoRewriteRandom()); err != nil {
http.Error(w, fmt.Sprintf("SQL rewrite: %s", err.Error()), http.StatusInternalServerError)
return
}
resp := NewResponse()
resp.Results.AssociativeJSON = isAssoc
resp.Results.AssociativeJSON = qp.Associative()
eqr := &command.ExecuteQueryRequest{
Request: &command.Request{
Transaction: isTx,
Transaction: qp.Tx(),
Statements: stmts,
},
Timings: timings,
Level: lvl,
Freshness: frsh.Nanoseconds(),
Timings: qp.Timings(),
Level: qp.Level(),
Freshness: qp.Freshness().Nanoseconds(),
}
results, resultErr := s.store.Request(eqr)
if resultErr != nil && resultErr == store.ErrNotLeader {
if s.DoRedirect(w, r) {
if s.DoRedirect(w, r, qp) {
return
}
@ -1439,7 +1348,7 @@ func (s *Service) handleRequest(w http.ResponseWriter, r *http.Request) {
}
w.Header().Add(ServedByHTTPHeader, addr)
results, resultErr = s.cluster.Request(eqr, addr, makeCredentials(username, password), timeout)
results, resultErr = s.cluster.Request(eqr, addr, makeCredentials(username, password), qp.Timeout(defaultTimeout))
if resultErr != nil {
stats.Add(numRemoteRequestsFailed, 1)
if resultErr.Error() == "unauthorized" {
@ -1456,11 +1365,11 @@ func (s *Service) handleRequest(w http.ResponseWriter, r *http.Request) {
resp.Results.ExecuteQueryResponse = results
}
resp.end = time.Now()
s.writeResponse(w, r, resp)
s.writeResponse(w, r, qp, resp)
}
// handleExpvar serves registered expvar information over HTTP.
func (s *Service) handleExpvar(w http.ResponseWriter, r *http.Request) {
func (s *Service) handleExpvar(w http.ResponseWriter, r *http.Request, qp QueryParams) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")
if !s.CheckRequestPerm(r, auth.PermStatus) {
w.WriteHeader(http.StatusUnauthorized)
@ -1484,7 +1393,7 @@ func (s *Service) handleExpvar(w http.ResponseWriter, r *http.Request) {
}
// handlePprof serves pprof information over HTTP.
func (s *Service) handlePprof(w http.ResponseWriter, r *http.Request) {
func (s *Service) handlePprof(w http.ResponseWriter, r *http.Request, qp QueryParams) {
if !s.CheckRequestPerm(r, auth.PermStatus) {
w.WriteHeader(http.StatusUnauthorized)
return
@ -1510,13 +1419,8 @@ func (s *Service) Addr() net.Addr {
// DoRedirect checks if the request is a redirect, and if so, performs the redirect.
// Returns true caller can consider the request handled. Returns false if the request
// was not a redirect and the caller should continue processing the request.
func (s *Service) DoRedirect(w http.ResponseWriter, r *http.Request) bool {
redirect, err := isRedirect(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return true
}
if !redirect {
func (s *Service) DoRedirect(w http.ResponseWriter, r *http.Request, qp QueryParams) bool {
if !qp.Redirect() {
return false
}
@ -1722,17 +1626,14 @@ func (s *Service) tlsStats() map[string]interface{} {
}
// writeResponse writes the given response to the given writer.
func (s *Service) writeResponse(w http.ResponseWriter, r *http.Request, j Responser) {
func (s *Service) writeResponse(w http.ResponseWriter, r *http.Request, qp QueryParams, j Responser) {
var b []byte
var err error
pretty, _ := isPretty(r)
timings, _ := isTimings(r)
if timings {
if qp.Timings() {
j.SetTime()
}
if pretty {
if qp.Pretty() {
b, err = json.MarshalIndent(j, "", " ")
} else {
b, err = json.Marshal(j)
@ -1748,15 +1649,11 @@ func (s *Service) writeResponse(w http.ResponseWriter, r *http.Request, j Respon
}
}
func requestQueries(r *http.Request) ([]*command.Statement, error) {
func requestQueries(r *http.Request, qp QueryParams) ([]*command.Statement, error) {
if r.Method == "GET" {
query, err := stmtParam(r)
if err != nil || query == "" {
return nil, errors.New("bad query GET request")
}
return []*command.Statement{
{
Sql: query,
Sql: qp.Query(),
},
}, nil
}
@ -1782,41 +1679,6 @@ func queryParam(req *http.Request, param string) (bool, error) {
return false, nil
}
// stmtParam returns the value for URL param 'q', if present.
func stmtParam(req *http.Request) (string, error) {
q := req.URL.Query()
stmt := strings.TrimSpace(q.Get("q"))
return stmt, nil
}
// fmtParam returns the value for URL param 'fmt', if present.
func fmtParam(req *http.Request) (string, error) {
q := req.URL.Query()
return strings.TrimSpace(q.Get("fmt")), nil
}
// verParam returns the requested version, if present.
func verParam(req *http.Request) (string, error) {
q := req.URL.Query()
return strings.TrimSpace(q.Get("ver")), nil
}
// isPretty returns whether the HTTP response body should be pretty-printed.
func isPretty(req *http.Request) (bool, error) {
return queryParam(req, "pretty")
}
// isVacuum returns whether the HTTP request is requesting a vacuum.
func isVacuum(req *http.Request) (bool, error) {
return queryParam(req, "vacuum")
}
// isRedirect returns whether the HTTP request is requesting a explicit
// redirect to the leader, if necessary.
func isRedirect(req *http.Request) (bool, error) {
return queryParam(req, "redirect")
}
func keyParam(req *http.Request) string {
q := req.URL.Query()
return strings.TrimSpace(q.Get("key"))
@ -1860,99 +1722,6 @@ func getSubJSON(jsonBlob []byte, keyString string) (json.RawMessage, error) {
return finalObjBytes, nil
}
// timeoutParam returns the value, if any, set for timeout. If not set, it
// returns the value passed in as a default.
func timeoutParam(req *http.Request, def time.Duration) (time.Duration, error) {
q := req.URL.Query()
timeout := strings.TrimSpace(q.Get("timeout"))
if timeout == "" {
return def, nil
}
t, err := time.ParseDuration(timeout)
if err != nil {
return 0, err
}
return t, nil
}
func chunkSizeParam(req *http.Request, defSz int) (int, error) {
q := req.URL.Query()
chunkSize := strings.TrimSpace(q.Get("chunk_kb"))
if chunkSize == "" {
return defSz, nil
}
sz, err := strconv.Atoi(chunkSize)
if err != nil {
return defSz, nil
}
return sz * 1024, nil
}
// isTx returns whether the HTTP request is requesting a transaction.
func isTx(req *http.Request) (bool, error) {
return queryParam(req, "transaction")
}
// isQueue returns whether the HTTP request is requesting a queue.
func isQueue(req *http.Request) (bool, error) {
return queryParam(req, "queue")
}
// reqParams is a convenience function to get a bunch of query params
// in one function call.
func reqParams(req *http.Request, def time.Duration) (timeout time.Duration, tx, timings, noRwRandom bool, err error) {
timeout, err = timeoutParam(req, def)
if err != nil {
return 0, false, false, true, err
}
tx, err = isTx(req)
if err != nil {
return 0, false, false, true, err
}
timings, err = isTimings(req)
if err != nil {
return 0, false, false, true, err
}
noRwRandom, err = noRewriteRandom(req)
if err != nil {
return 0, false, false, true, err
}
return timeout, tx, timings, noRwRandom, nil
}
// queryReqParams is a convenience function to get a bunch of query params
// in one function call.
func queryReqParams(req *http.Request, def time.Duration) (timeout, frsh time.Duration, lvl command.QueryRequest_Level, isTx, timings, noRwRandom, isAssoc bool, err error) {
timeout, isTx, timings, noRwRandom, err = reqParams(req, defaultTimeout)
if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err
}
lvl, err = level(req)
if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err
}
frsh, err = freshness(req)
if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err
}
isAssoc, err = isAssociative(req)
if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err
}
return
}
func executeQueryReqParams(req *http.Request, def time.Duration) (timeout, frsh time.Duration, lvl command.QueryRequest_Level, isTx, timings, noRwRandom, isAssoc bool, err error) {
timeout, frsh, lvl, isTx, timings, noRwRandom, isAssoc, err = queryReqParams(req, defaultTimeout)
if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err
}
return timeout, frsh, lvl, isTx, timings, noRwRandom, isAssoc, nil
}
// noLeader returns whether processing should skip the leader check.
func noLeader(req *http.Request) (bool, error) {
return queryParam(req, "noleader")
@ -2014,21 +1783,6 @@ func freshness(req *http.Request) (time.Duration, error) {
return d, nil
}
// backupFormat returns the request backup format, setting the response header
// accordingly.
func backupFormat(w http.ResponseWriter, r *http.Request) (command.BackupRequest_Format, error) {
fmt, err := fmtParam(r)
if err != nil {
return command.BackupRequest_BACKUP_REQUEST_FORMAT_BINARY, err
}
if fmt == "sql" {
w.Header().Set("Content-Type", "application/sql")
return command.BackupRequest_BACKUP_REQUEST_FORMAT_SQL, nil
}
w.Header().Set("Content-Type", "application/octet-stream")
return command.BackupRequest_BACKUP_REQUEST_FORMAT_BINARY, nil
}
func prettyEnabled(e bool) string {
if e {
return "enabled"

@ -895,14 +895,16 @@ func Test_DoRedirect(t *testing.T) {
}
s := New("127.0.0.1:0", m, c, nil)
req := mustNewHTTPRequest("http://qux:4001")
qp := mustGetQueryParams(req)
if s.DoRedirect(nil, req) {
if s.DoRedirect(nil, req, qp) {
t.Fatalf("incorrectly redirected")
}
req = mustNewHTTPRequest("http://qux:4001/db/query?redirect")
qp = mustGetQueryParams(req)
w := httptest.NewRecorder()
if !s.DoRedirect(w, req) {
if !s.DoRedirect(w, req, qp) {
t.Fatalf("incorrectly not redirected")
}
if exp, got := http.StatusMovedPermanently, w.Code; exp != got {
@ -1244,33 +1246,19 @@ func Test_timeoutVersionPrettyQueryParam(t *testing.T) {
if err != nil {
t.Fatalf("failed to create request: %s", err)
}
timeout, err := timeoutParam(req, def)
qp, err := NewQueryParams(req)
if err != nil {
if tt.err {
// Error is expected, all is OK.
continue
}
t.Fatalf("failed to get timeout as expected: %s", err)
}
if timeout != mustParseDuration(tt.dur) {
t.Fatalf("got wrong timeout, expected %s, got %s", mustParseDuration(tt.dur), timeout)
t.Fatalf("failed to parse query params: %s", err)
}
ver, err := verParam(req)
if err != nil {
t.Fatalf("failed to get version as expected: %s", err)
if got, exp := qp.Timeout(def), mustParseDuration(tt.dur); got != exp {
t.Fatalf("got wrong timeout, expected %s, got %s", exp, got)
}
if ver != tt.ver {
t.Fatalf("got wrong version, expected %s, got %s", tt.ver, ver)
}
pretty, err := isPretty(req)
if err != nil {
t.Fatalf("failed to get pretty as expected on test %d: %s", i, err)
if got, exp := qp.Version(), tt.ver; got != exp {
t.Fatalf("got wrong version, expected %s, got %s", exp, got)
}
if pretty != tt.pretty {
t.Fatalf("got wrong pretty on test %d, expected %t, got %t", i, tt.pretty, pretty)
if got, exp := qp.Pretty(), tt.pretty; got != exp {
t.Fatalf("got wrong pretty on test %d, expected %t, got %t", i, exp, got)
}
}
}
@ -1469,3 +1457,11 @@ func mustGunzip(b []byte) []byte {
return dec
}
func mustGetQueryParams(req *http.Request) QueryParams {
qp, err := NewQueryParams(req)
if err != nil {
panic("failed to get query params")
}
return qp
}

Loading…
Cancel
Save