1
0
Fork 0

Merge pull request #1462 from rqlite/refactor-redirect

Refactor redirect logic in HTTP service
master
Philip O'Toole 10 months ago committed by GitHub
commit dcc23a7968
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,6 +2,7 @@
### Implementation changes and bug fixes
- [PR #1456](https://github.com/rqlite/rqlite/pull/1459): Standardize on chunk size.
- [PR #1456](https://github.com/rqlite/rqlite/pull/1459): Set `TrailingLogs=0` to truncate log during user-initiated Snapshotting.
- [PR #1462](https://github.com/rqlite/rqlite/pull/1462): Refactor redirect logic in HTTP service
## 8.0.1 (December 8th 2023)
This release fixes an edge case issue during restore-from-SQLite. It's possible if a rqlite system crashes shortly after restoring from SQLite it may not have loaded the data correctly.

@ -502,12 +502,6 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) {
return
}
redirect, err := isRedirect(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
b, err := io.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
@ -543,16 +537,7 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) {
err = s.store.Remove(rn)
if err != nil {
if err == store.ErrNotLeader {
if redirect {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
redirect := s.FormRedirect(r, leaderAPIAddr)
http.Redirect(w, r, redirect, http.StatusMovedPermanently)
if s.DoRedirect(w, r) {
return
}
@ -608,11 +593,6 @@ func (s *Service) handleBackup(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
redirect, err := isRedirect(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
format, err := backupFormat(w, r)
if err != nil {
@ -641,16 +621,7 @@ func (s *Service) handleBackup(w http.ResponseWriter, r *http.Request) {
err = s.store.Backup(br, w)
if err != nil {
if err == store.ErrNotLeader {
if redirect {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
redirect := s.FormRedirect(r, leaderAPIAddr)
http.Redirect(w, r, redirect, http.StatusMovedPermanently)
if s.DoRedirect(w, r) {
return
}
@ -722,12 +693,6 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
return
}
redirect, err := isRedirect(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
chunkSz, err := chunkSizeParam(r, defaultChunkSize)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
@ -767,16 +732,9 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
results, err := s.store.Execute(er)
if err != nil {
if err == store.ErrNotLeader {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
if s.DoRedirect(w, r) {
return
}
redirect := s.FormRedirect(r, leaderAPIAddr)
http.Redirect(w, r, redirect, http.StatusMovedPermanently)
return
}
resp.Error = err.Error()
} else {
@ -788,7 +746,7 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
for {
chunk, err := chunker.Next()
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
err = s.store.LoadChunk(chunk)
@ -796,16 +754,7 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
} else if err != nil && err == store.ErrNotLeader {
if redirect {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
redirect := s.FormRedirect(r, leaderAPIAddr)
http.Redirect(w, r, redirect, http.StatusMovedPermanently)
if s.DoRedirect(w, r) {
return
}
@ -1233,7 +1182,7 @@ func (s *Service) queuedExecute(w http.ResponseWriter, r *http.Request) {
func (s *Service) execute(w http.ResponseWriter, r *http.Request) {
resp := NewResponse()
timeout, isTx, timings, redirect, noRewriteRandom, err := reqParams(r, defaultTimeout)
timeout, isTx, timings, noRewriteRandom, err := reqParams(r, defaultTimeout)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@ -1267,15 +1216,7 @@ func (s *Service) execute(w http.ResponseWriter, r *http.Request) {
results, resultsErr := s.store.Execute(er)
if resultsErr != nil && resultsErr == store.ErrNotLeader {
if redirect {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
loc := s.FormRedirect(r, leaderAPIAddr)
http.Redirect(w, r, loc, http.StatusMovedPermanently)
if s.DoRedirect(w, r) {
return
}
@ -1331,7 +1272,7 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) {
return
}
timeout, frsh, lvl, isTx, timings, redirect, noRewriteRandom, isAssoc, err := queryReqParams(r, defaultTimeout)
timeout, frsh, lvl, isTx, timings, noRewriteRandom, isAssoc, err := queryReqParams(r, defaultTimeout)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@ -1369,15 +1310,7 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) {
results, resultsErr := s.store.Query(qr)
if resultsErr != nil && resultsErr == store.ErrNotLeader {
if redirect {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
loc := s.FormRedirect(r, leaderAPIAddr)
http.Redirect(w, r, loc, http.StatusMovedPermanently)
if s.DoRedirect(w, r) {
return
}
@ -1430,7 +1363,7 @@ func (s *Service) handleRequest(w http.ResponseWriter, r *http.Request) {
return
}
timeout, frsh, lvl, isTx, timings, redirect, noRewriteRandom, isAssoc, err := executeQueryReqParams(r, defaultTimeout)
timeout, frsh, lvl, isTx, timings, noRewriteRandom, isAssoc, err := executeQueryReqParams(r, defaultTimeout)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
@ -1470,15 +1403,7 @@ func (s *Service) handleRequest(w http.ResponseWriter, r *http.Request) {
results, resultErr := s.store.Request(eqr)
if resultErr != nil && resultErr == store.ErrNotLeader {
if redirect {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
loc := s.FormRedirect(r, leaderAPIAddr)
http.Redirect(w, r, loc, http.StatusMovedPermanently)
if s.DoRedirect(w, r) {
return
}
@ -1566,13 +1491,41 @@ func (s *Service) Addr() net.Addr {
return s.ln.Addr()
}
// DoRedirect checks if the request is a redirect, and if so, performs the redirect.
// Returns true caller can consider the request handled. Returns false if the request
// was not a redirect and the caller should continue processing the request.
func (s *Service) DoRedirect(w http.ResponseWriter, r *http.Request) bool {
redirect, err := isRedirect(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return true
}
if !redirect {
return false
}
rd, err := s.FormRedirect(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
} else {
http.Redirect(w, r, rd, http.StatusMovedPermanently)
}
return true
}
// FormRedirect returns the value for the "Location" header for a 301 response.
func (s *Service) FormRedirect(r *http.Request, url string) string {
func (s *Service) FormRedirect(r *http.Request) (string, error) {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
return "", ErrLeaderNotFound
}
rq := r.URL.RawQuery
if rq != "" {
rq = fmt.Sprintf("?%s", rq)
}
return fmt.Sprintf("%s%s%s", url, r.URL.Path, rq)
return fmt.Sprintf("%s%s%s", leaderAPIAddr, r.URL.Path, rq), nil
}
// CheckRequestPerm checks if the request is authenticated and authorized
@ -1636,7 +1589,6 @@ func (s *Service) LeaderAPIAddr() string {
}
apiAddr, err := s.cluster.GetNodeAPIAddr(nodeAddr, defaultTimeout)
if err != nil {
return ""
}
@ -1932,61 +1884,57 @@ func isQueue(req *http.Request) (bool, error) {
// reqParams is a convenience function to get a bunch of query params
// in one function call.
func reqParams(req *http.Request, def time.Duration) (timeout time.Duration, tx, timings, redirect, noRwRandom bool, err error) {
func reqParams(req *http.Request, def time.Duration) (timeout time.Duration, tx, timings, noRwRandom bool, err error) {
timeout, err = timeoutParam(req, def)
if err != nil {
return 0, false, false, false, true, err
return 0, false, false, true, err
}
tx, err = isTx(req)
if err != nil {
return 0, false, false, false, true, err
return 0, false, false, true, err
}
timings, err = isTimings(req)
if err != nil {
return 0, false, false, false, true, err
}
redirect, err = isRedirect(req)
if err != nil {
return 0, false, false, false, true, err
return 0, false, false, true, err
}
noRwRandom, err = noRewriteRandom(req)
if err != nil {
return 0, false, false, false, true, err
return 0, false, false, true, err
}
return timeout, tx, timings, redirect, noRwRandom, nil
return timeout, tx, timings, noRwRandom, nil
}
// queryReqParams is a convenience function to get a bunch of query params
// in one function call.
func queryReqParams(req *http.Request, def time.Duration) (timeout, frsh time.Duration, lvl command.QueryRequest_Level, isTx, timings, redirect, noRwRandom, isAssoc bool, err error) {
timeout, isTx, timings, redirect, noRwRandom, err = reqParams(req, defaultTimeout)
func queryReqParams(req *http.Request, def time.Duration) (timeout, frsh time.Duration, lvl command.QueryRequest_Level, isTx, timings, noRwRandom, isAssoc bool, err error) {
timeout, isTx, timings, noRwRandom, err = reqParams(req, defaultTimeout)
if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err
}
lvl, err = level(req)
if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err
}
frsh, err = freshness(req)
if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err
}
isAssoc, err = isAssociative(req)
if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err
}
return
}
func executeQueryReqParams(req *http.Request, def time.Duration) (timeout, frsh time.Duration, lvl command.QueryRequest_Level, isTx, timings, redirect, noRwRandom, isAssoc bool, err error) {
timeout, frsh, lvl, isTx, timings, redirect, noRwRandom, isAssoc, err = queryReqParams(req, defaultTimeout)
func executeQueryReqParams(req *http.Request, def time.Duration) (timeout, frsh time.Duration, lvl command.QueryRequest_Level, isTx, timings, noRwRandom, isAssoc bool, err error) {
timeout, frsh, lvl, isTx, timings, noRwRandom, isAssoc, err = queryReqParams(req, defaultTimeout)
if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err
}
return timeout, frsh, lvl, isTx, timings, redirect, noRwRandom, isAssoc, nil
return timeout, frsh, lvl, isTx, timings, noRwRandom, isAssoc, nil
}
// noLeader returns whether processing should skip the leader check.

@ -8,6 +8,7 @@ import (
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
@ -826,36 +827,90 @@ func Test_RegisterStatus(t *testing.T) {
}
func Test_FormRedirect(t *testing.T) {
m := &MockStore{}
c := &mockClusterService{}
m := &MockStore{
leaderAddr: "foo:4002",
}
c := &mockClusterService{
apiAddr: "http://foo:4001",
}
s := New("127.0.0.1:0", m, c, nil)
req := mustNewHTTPRequest("http://qux:4001")
if rd := s.FormRedirect(req, "http://foo:4001"); rd != "http://foo:4001" {
t.Fatal("failed to form redirect for simple URL")
rd, err := s.FormRedirect(req)
if err != nil {
t.Fatalf("failed to form redirect: %s", err.Error())
}
if exp, got := "http://foo:4001", rd; exp != got {
t.Fatalf("incorrect redirect, exp: %s, got: %s", exp, got)
}
}
func Test_FormRedirectParam(t *testing.T) {
m := &MockStore{}
c := &mockClusterService{}
m := &MockStore{
leaderAddr: "foo:4002",
}
c := &mockClusterService{
apiAddr: "http://foo:4001",
}
s := New("127.0.0.1:0", m, c, nil)
req := mustNewHTTPRequest("http://qux:4001/db/query?x=y")
if rd := s.FormRedirect(req, "http://foo:4001"); rd != "http://foo:4001/db/query?x=y" {
t.Fatal("failed to form redirect for URL")
rd, err := s.FormRedirect(req)
if err != nil {
t.Fatalf("failed to form redirect: %s", err.Error())
}
if exp, got := "http://foo:4001/db/query?x=y", rd; rd != got {
t.Fatalf("incorrect redirect, exp: %s, got: %s", exp, got)
}
}
func Test_FormRedirectHTTPS(t *testing.T) {
m := &MockStore{}
c := &mockClusterService{}
m := &MockStore{
leaderAddr: "foo:4002",
}
c := &mockClusterService{
apiAddr: "https://foo:4001",
}
s := New("127.0.0.1:0", m, c, nil)
req := mustNewHTTPRequest("http://qux:4001")
rd, err := s.FormRedirect(req)
if err != nil {
t.Fatalf("failed to form redirect: %s", err.Error())
}
if exp, got := "https://foo:4001", rd; exp != got {
t.Fatalf("incorrect redirect, exp: %s, got: %s", exp, got)
}
}
func Test_DoRedirect(t *testing.T) {
m := &MockStore{
leaderAddr: "foo:4002",
}
c := &mockClusterService{
apiAddr: "https://foo:4001",
}
s := New("127.0.0.1:0", m, c, nil)
req := mustNewHTTPRequest("http://qux:4001")
if rd := s.FormRedirect(req, "https://foo:4001"); rd != "https://foo:4001" {
t.Fatal("failed to form redirect for simple URL")
if s.DoRedirect(nil, req) {
t.Fatalf("incorrectly redirected")
}
req = mustNewHTTPRequest("http://qux:4001/db/query?redirect")
w := httptest.NewRecorder()
if !s.DoRedirect(w, req) {
t.Fatalf("incorrectly not redirected")
}
if exp, got := http.StatusMovedPermanently, w.Code; exp != got {
t.Fatalf("incorrect redirect code, exp: %d, got: %d", exp, got)
}
// check location header
if exp, got := "https://foo:4001/db/query?redirect", w.Header().Get("Location"); exp != got {
t.Fatalf("incorrect redirect location, exp: %s, got: %s", exp, got)
}
}

Loading…
Cancel
Save