1
0
Fork 0

/db/request compiles and all existing tests pass

master
Philip O'Toole 1 year ago
parent 2e140af463
commit 7ee6036476

@ -212,3 +212,16 @@ func gzUncompress(b []byte) ([]byte, error) {
}
return ub, nil
}
func MapConsistencyLevel(in QueryRequest_Level) ExecuteQueryRequest_Level {
switch in {
case QueryRequest_QUERY_REQUEST_LEVEL_NONE:
return ExecuteQueryRequest_QUERY_REQUEST_LEVEL_NONE
case QueryRequest_QUERY_REQUEST_LEVEL_WEAK:
return ExecuteQueryRequest_QUERY_REQUEST_LEVEL_WEAK
case QueryRequest_QUERY_REQUEST_LEVEL_STRONG:
return ExecuteQueryRequest_QUERY_REQUEST_LEVEL_STRONG
default:
return ExecuteQueryRequest_QUERY_REQUEST_LEVEL_WEAK
}
}

@ -54,6 +54,10 @@ type Database interface {
// is held on the database.
Query(qr *command.QueryRequest) ([]*command.QueryRows, error)
// Request processes a slice of requests, each of which can be either
// an Execute or Query request.
Request(eqr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, error)
// Load loads a SQLite file into the system
Load(lr *command.LoadRequest) error
}
@ -98,6 +102,9 @@ type Cluster interface {
// Query performs an Query Request on a remote node.
Query(qr *command.QueryRequest, nodeAddr string, creds *cluster.Credentials, timeout time.Duration) ([]*command.QueryRows, error)
// Request performs an ExecuteQuery Request on a remote node.
Request(eqr *command.ExecuteQueryRequest, nodeAddr string, creds *cluster.Credentials, timeout time.Duration) ([]*command.ExecuteQueryResponse, error)
// Backup retrieves a backup from a remote node and writes to the io.Writer.
Backup(br *command.BackupRequest, nodeAddr string, creds *cluster.Credentials, timeout time.Duration, w io.Writer) error
@ -122,10 +129,12 @@ type StatusReporter interface {
Stats() (map[string]interface{}, error)
}
// DBResults stores either an Execute result or a Query result
// DBResults stores either an Execute result, a Query result, or
// an ExecuteQuery result.
type DBResults struct {
ExecuteResult []*command.ExecuteResult
QueryRows []*command.QueryRows
ExecuteResult []*command.ExecuteResult
QueryRows []*command.QueryRows
ExecuteQueryResponse []*command.ExecuteQueryResponse
AssociativeJSON bool // Render in associative form
}
@ -145,6 +154,8 @@ func (d *DBResults) MarshalJSON() ([]byte, error) {
return enc.JSONMarshal(d.ExecuteResult)
} else if d.QueryRows != nil {
return enc.JSONMarshal(d.QueryRows)
} else if d.ExecuteQueryResponse != nil {
return enc.JSONMarshal(d.ExecuteQueryResponse)
}
return json.Marshal(make([]interface{}, 0))
}
@ -1407,25 +1418,7 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) {
return
}
timeout, isTx, timings, redirect, noRewriteRandom, err := reqParams(r, defaultTimeout)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
lvl, err := level(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
frsh, err := freshness(r)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
assoc, err := isAssociative(r)
timeout, frsh, lvl, isTx, timings, redirect, noRewriteRandom, isAssoc, err := queryReqParams(r, defaultTimeout)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@ -1448,7 +1441,7 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) {
}
resp := NewResponse()
resp.Results.AssociativeJSON = assoc
resp.Results.AssociativeJSON = isAssoc
qr := &command.QueryRequest{
Request: &command.Request{
@ -1514,7 +1507,94 @@ func (s *Service) handleRequest(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusUnauthorized)
return
}
return
if r.Method != "POST" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
timeout, frsh, lvl, isTx, timings, redirect, noRewriteRandom, isAssoc, err := executeQueryReqParams(r, defaultTimeout)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
b, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
r.Body.Close()
stmts, err := ParseRequest(b)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
if err := command.Rewrite(stmts, noRewriteRandom); err != nil {
http.Error(w, fmt.Sprintf("SQL rewrite: %s", err.Error()), http.StatusInternalServerError)
return
}
resp := NewResponse()
resp.Results.AssociativeJSON = isAssoc
eqr := &command.ExecuteQueryRequest{
Request: &command.Request{
Transaction: isTx,
Statements: stmts,
},
Timings: timings,
Level: lvl,
Freshness: frsh.Nanoseconds(),
}
results, resultErr := s.store.Request(eqr)
if resultErr != nil && resultErr == store.ErrNotLeader {
if redirect {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
loc := s.FormRedirect(r, leaderAPIAddr)
http.Redirect(w, r, loc, http.StatusMovedPermanently)
return
}
addr, err := s.store.LeaderAddr()
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
if addr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
username, password, ok := r.BasicAuth()
if !ok {
username = ""
}
w.Header().Add(ServedByHTTPHeader, addr)
results, resultErr = s.cluster.Request(eqr, addr, makeCredentials(username, password), timeout)
if resultErr != nil && resultErr.Error() == "unauthorized" {
http.Error(w, "remote request not authorized", http.StatusUnauthorized)
return
}
stats.Add(numRemoteRequests, 1)
}
if resultErr != nil {
resp.Error = resultErr.Error()
} else {
resp.Results.ExecuteQueryResponse = results
}
resp.end = time.Now()
s.writeResponse(w, r, resp)
}
// handleExpvar serves registered expvar information over HTTP.
@ -1920,6 +2000,39 @@ func reqParams(req *http.Request, def time.Duration) (timeout time.Duration, tx,
return timeout, tx, timings, redirect, noRwRandom, nil
}
// queryReqParams is a convenience function to get a bunch of query params
// in one function call.
func queryReqParams(req *http.Request, def time.Duration) (timeout, frsh time.Duration, lvl command.QueryRequest_Level, isTx, timings, redirect, noRwRandom, isAssoc bool, err error) {
timeout, isTx, timings, redirect, noRwRandom, err = reqParams(req, defaultTimeout)
if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err
}
lvl, err = level(req)
if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err
}
frsh, err = freshness(req)
if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err
}
isAssoc, err = isAssociative(req)
if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err
}
return
}
func executeQueryReqParams(req *http.Request, def time.Duration) (timeout, frsh time.Duration, lvl command.ExecuteQueryRequest_Level, isTx, timings, redirect, noRwRandom, isAssoc bool, err error) {
timeout, frsh, qLvl, isTx, timings, redirect, noRwRandom, isAssoc, err := queryReqParams(req, defaultTimeout)
if err != nil {
return 0, 0, command.ExecuteQueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err
}
return timeout, frsh, command.MapConsistencyLevel(qLvl), isTx, timings, redirect, noRwRandom, isAssoc, nil
}
// noLeader returns whether processing should skip the leader check.
func noLeader(req *http.Request) (bool, error) {
return queryParam(req, "noleader")

@ -1200,6 +1200,7 @@ func Test_timeoutQueryParam(t *testing.T) {
type MockStore struct {
executeFn func(er *command.ExecuteRequest) ([]*command.ExecuteResult, error)
queryFn func(qr *command.QueryRequest) ([]*command.QueryRows, error)
requestFn func(eqr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, error)
backupFn func(br *command.BackupRequest, dst io.Writer) error
loadFn func(lr *command.LoadRequest) error
leaderAddr string
@ -1220,6 +1221,13 @@ func (m *MockStore) Query(qr *command.QueryRequest) ([]*command.QueryRows, error
return nil, nil
}
func (m *MockStore) Request(eqr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, error) {
if m.requestFn != nil {
return m.requestFn(eqr)
}
return nil, nil
}
func (m *MockStore) Join(jr *command.JoinRequest) error {
return nil
}
@ -1266,6 +1274,7 @@ type mockClusterService struct {
apiAddr string
executeFn func(er *command.ExecuteRequest, addr string, t time.Duration) ([]*command.ExecuteResult, error)
queryFn func(qr *command.QueryRequest, addr string, t time.Duration) ([]*command.QueryRows, error)
requestFn func(eqr *command.ExecuteQueryRequest, nodeAddr string, timeout time.Duration) ([]*command.ExecuteQueryResponse, error)
backupFn func(br *command.BackupRequest, addr string, t time.Duration, w io.Writer) error
loadFn func(lr *command.LoadRequest, addr string, t time.Duration) error
removeNodeFn func(rn *command.RemoveNodeRequest, nodeAddr string, t time.Duration) error
@ -1289,6 +1298,13 @@ func (m *mockClusterService) Query(qr *command.QueryRequest, addr string, creds
return nil, nil
}
func (m *mockClusterService) Request(eqr *command.ExecuteQueryRequest, nodeAddr string, creds *cluster.Credentials, timeout time.Duration) ([]*command.ExecuteQueryResponse, error) {
if m.requestFn != nil {
return m.requestFn(eqr, nodeAddr, timeout)
}
return nil, nil
}
func (m *mockClusterService) Backup(br *command.BackupRequest, addr string, creds *cluster.Credentials, t time.Duration, w io.Writer) error {
if m.backupFn != nil {
return m.backupFn(br, addr, t, w)

Loading…
Cancel
Save