1
0
Fork 0

Merge pull request #1462 from rqlite/refactor-redirect

Refactor redirect logic in HTTP service
master
Philip O'Toole 10 months ago committed by GitHub
commit dcc23a7968
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,6 +2,7 @@
### Implementation changes and bug fixes ### Implementation changes and bug fixes
- [PR #1456](https://github.com/rqlite/rqlite/pull/1459): Standardize on chunk size. - [PR #1456](https://github.com/rqlite/rqlite/pull/1459): Standardize on chunk size.
- [PR #1456](https://github.com/rqlite/rqlite/pull/1459): Set `TrailingLogs=0` to truncate log during user-initiated Snapshotting. - [PR #1456](https://github.com/rqlite/rqlite/pull/1459): Set `TrailingLogs=0` to truncate log during user-initiated Snapshotting.
- [PR #1462](https://github.com/rqlite/rqlite/pull/1462): Refactor redirect logic in HTTP service
## 8.0.1 (December 8th 2023) ## 8.0.1 (December 8th 2023)
This release fixes an edge case issue during restore-from-SQLite. It's possible if a rqlite system crashes shortly after restoring from SQLite it may not have loaded the data correctly. This release fixes an edge case issue during restore-from-SQLite. It's possible if a rqlite system crashes shortly after restoring from SQLite it may not have loaded the data correctly.

@ -502,12 +502,6 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) {
return return
} }
redirect, err := isRedirect(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
b, err := io.ReadAll(r.Body) b, err := io.ReadAll(r.Body)
if err != nil { if err != nil {
w.WriteHeader(http.StatusBadRequest) w.WriteHeader(http.StatusBadRequest)
@ -543,16 +537,7 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) {
err = s.store.Remove(rn) err = s.store.Remove(rn)
if err != nil { if err != nil {
if err == store.ErrNotLeader { if err == store.ErrNotLeader {
if redirect { if s.DoRedirect(w, r) {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
redirect := s.FormRedirect(r, leaderAPIAddr)
http.Redirect(w, r, redirect, http.StatusMovedPermanently)
return return
} }
@ -608,11 +593,6 @@ func (s *Service) handleBackup(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
redirect, err := isRedirect(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
format, err := backupFormat(w, r) format, err := backupFormat(w, r)
if err != nil { if err != nil {
@ -641,16 +621,7 @@ func (s *Service) handleBackup(w http.ResponseWriter, r *http.Request) {
err = s.store.Backup(br, w) err = s.store.Backup(br, w)
if err != nil { if err != nil {
if err == store.ErrNotLeader { if err == store.ErrNotLeader {
if redirect { if s.DoRedirect(w, r) {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
redirect := s.FormRedirect(r, leaderAPIAddr)
http.Redirect(w, r, redirect, http.StatusMovedPermanently)
return return
} }
@ -722,12 +693,6 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
return return
} }
redirect, err := isRedirect(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
chunkSz, err := chunkSizeParam(r, defaultChunkSize) chunkSz, err := chunkSizeParam(r, defaultChunkSize)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
@ -767,16 +732,9 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
results, err := s.store.Execute(er) results, err := s.store.Execute(er)
if err != nil { if err != nil {
if err == store.ErrNotLeader { if err == store.ErrNotLeader {
leaderAPIAddr := s.LeaderAPIAddr() if s.DoRedirect(w, r) {
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return return
} }
redirect := s.FormRedirect(r, leaderAPIAddr)
http.Redirect(w, r, redirect, http.StatusMovedPermanently)
return
} }
resp.Error = err.Error() resp.Error = err.Error()
} else { } else {
@ -788,7 +746,7 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
for { for {
chunk, err := chunker.Next() chunk, err := chunker.Next()
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} }
err = s.store.LoadChunk(chunk) err = s.store.LoadChunk(chunk)
@ -796,16 +754,7 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
} else if err != nil && err == store.ErrNotLeader { } else if err != nil && err == store.ErrNotLeader {
if redirect { if s.DoRedirect(w, r) {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
redirect := s.FormRedirect(r, leaderAPIAddr)
http.Redirect(w, r, redirect, http.StatusMovedPermanently)
return return
} }
@ -1233,7 +1182,7 @@ func (s *Service) queuedExecute(w http.ResponseWriter, r *http.Request) {
func (s *Service) execute(w http.ResponseWriter, r *http.Request) { func (s *Service) execute(w http.ResponseWriter, r *http.Request) {
resp := NewResponse() resp := NewResponse()
timeout, isTx, timings, redirect, noRewriteRandom, err := reqParams(r, defaultTimeout) timeout, isTx, timings, noRewriteRandom, err := reqParams(r, defaultTimeout)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
@ -1267,15 +1216,7 @@ func (s *Service) execute(w http.ResponseWriter, r *http.Request) {
results, resultsErr := s.store.Execute(er) results, resultsErr := s.store.Execute(er)
if resultsErr != nil && resultsErr == store.ErrNotLeader { if resultsErr != nil && resultsErr == store.ErrNotLeader {
if redirect { if s.DoRedirect(w, r) {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
loc := s.FormRedirect(r, leaderAPIAddr)
http.Redirect(w, r, loc, http.StatusMovedPermanently)
return return
} }
@ -1331,7 +1272,7 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) {
return return
} }
timeout, frsh, lvl, isTx, timings, redirect, noRewriteRandom, isAssoc, err := queryReqParams(r, defaultTimeout) timeout, frsh, lvl, isTx, timings, noRewriteRandom, isAssoc, err := queryReqParams(r, defaultTimeout)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
@ -1369,15 +1310,7 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) {
results, resultsErr := s.store.Query(qr) results, resultsErr := s.store.Query(qr)
if resultsErr != nil && resultsErr == store.ErrNotLeader { if resultsErr != nil && resultsErr == store.ErrNotLeader {
if redirect { if s.DoRedirect(w, r) {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
loc := s.FormRedirect(r, leaderAPIAddr)
http.Redirect(w, r, loc, http.StatusMovedPermanently)
return return
} }
@ -1430,7 +1363,7 @@ func (s *Service) handleRequest(w http.ResponseWriter, r *http.Request) {
return return
} }
timeout, frsh, lvl, isTx, timings, redirect, noRewriteRandom, isAssoc, err := executeQueryReqParams(r, defaultTimeout) timeout, frsh, lvl, isTx, timings, noRewriteRandom, isAssoc, err := executeQueryReqParams(r, defaultTimeout)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
@ -1470,15 +1403,7 @@ func (s *Service) handleRequest(w http.ResponseWriter, r *http.Request) {
results, resultErr := s.store.Request(eqr) results, resultErr := s.store.Request(eqr)
if resultErr != nil && resultErr == store.ErrNotLeader { if resultErr != nil && resultErr == store.ErrNotLeader {
if redirect { if s.DoRedirect(w, r) {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
loc := s.FormRedirect(r, leaderAPIAddr)
http.Redirect(w, r, loc, http.StatusMovedPermanently)
return return
} }
@ -1566,13 +1491,41 @@ func (s *Service) Addr() net.Addr {
return s.ln.Addr() return s.ln.Addr()
} }
// DoRedirect checks if the request is a redirect, and if so, performs the redirect.
// Returns true caller can consider the request handled. Returns false if the request
// was not a redirect and the caller should continue processing the request.
func (s *Service) DoRedirect(w http.ResponseWriter, r *http.Request) bool {
redirect, err := isRedirect(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return true
}
if !redirect {
return false
}
rd, err := s.FormRedirect(r)
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
} else {
http.Redirect(w, r, rd, http.StatusMovedPermanently)
}
return true
}
// FormRedirect returns the value for the "Location" header for a 301 response. // FormRedirect returns the value for the "Location" header for a 301 response.
func (s *Service) FormRedirect(r *http.Request, url string) string { func (s *Service) FormRedirect(r *http.Request) (string, error) {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
stats.Add(numLeaderNotFound, 1)
return "", ErrLeaderNotFound
}
rq := r.URL.RawQuery rq := r.URL.RawQuery
if rq != "" { if rq != "" {
rq = fmt.Sprintf("?%s", rq) rq = fmt.Sprintf("?%s", rq)
} }
return fmt.Sprintf("%s%s%s", url, r.URL.Path, rq) return fmt.Sprintf("%s%s%s", leaderAPIAddr, r.URL.Path, rq), nil
} }
// CheckRequestPerm checks if the request is authenticated and authorized // CheckRequestPerm checks if the request is authenticated and authorized
@ -1636,7 +1589,6 @@ func (s *Service) LeaderAPIAddr() string {
} }
apiAddr, err := s.cluster.GetNodeAPIAddr(nodeAddr, defaultTimeout) apiAddr, err := s.cluster.GetNodeAPIAddr(nodeAddr, defaultTimeout)
if err != nil { if err != nil {
return "" return ""
} }
@ -1932,61 +1884,57 @@ func isQueue(req *http.Request) (bool, error) {
// reqParams is a convenience function to get a bunch of query params // reqParams is a convenience function to get a bunch of query params
// in one function call. // in one function call.
func reqParams(req *http.Request, def time.Duration) (timeout time.Duration, tx, timings, redirect, noRwRandom bool, err error) { func reqParams(req *http.Request, def time.Duration) (timeout time.Duration, tx, timings, noRwRandom bool, err error) {
timeout, err = timeoutParam(req, def) timeout, err = timeoutParam(req, def)
if err != nil { if err != nil {
return 0, false, false, false, true, err return 0, false, false, true, err
} }
tx, err = isTx(req) tx, err = isTx(req)
if err != nil { if err != nil {
return 0, false, false, false, true, err return 0, false, false, true, err
} }
timings, err = isTimings(req) timings, err = isTimings(req)
if err != nil { if err != nil {
return 0, false, false, false, true, err return 0, false, false, true, err
}
redirect, err = isRedirect(req)
if err != nil {
return 0, false, false, false, true, err
} }
noRwRandom, err = noRewriteRandom(req) noRwRandom, err = noRewriteRandom(req)
if err != nil { if err != nil {
return 0, false, false, false, true, err return 0, false, false, true, err
} }
return timeout, tx, timings, redirect, noRwRandom, nil return timeout, tx, timings, noRwRandom, nil
} }
// queryReqParams is a convenience function to get a bunch of query params // queryReqParams is a convenience function to get a bunch of query params
// in one function call. // in one function call.
func queryReqParams(req *http.Request, def time.Duration) (timeout, frsh time.Duration, lvl command.QueryRequest_Level, isTx, timings, redirect, noRwRandom, isAssoc bool, err error) { func queryReqParams(req *http.Request, def time.Duration) (timeout, frsh time.Duration, lvl command.QueryRequest_Level, isTx, timings, noRwRandom, isAssoc bool, err error) {
timeout, isTx, timings, redirect, noRwRandom, err = reqParams(req, defaultTimeout) timeout, isTx, timings, noRwRandom, err = reqParams(req, defaultTimeout)
if err != nil { if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err
} }
lvl, err = level(req) lvl, err = level(req)
if err != nil { if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err
} }
frsh, err = freshness(req) frsh, err = freshness(req)
if err != nil { if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err
} }
isAssoc, err = isAssociative(req) isAssoc, err = isAssociative(req)
if err != nil { if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err
} }
return return
} }
func executeQueryReqParams(req *http.Request, def time.Duration) (timeout, frsh time.Duration, lvl command.QueryRequest_Level, isTx, timings, redirect, noRwRandom, isAssoc bool, err error) { func executeQueryReqParams(req *http.Request, def time.Duration) (timeout, frsh time.Duration, lvl command.QueryRequest_Level, isTx, timings, noRwRandom, isAssoc bool, err error) {
timeout, frsh, lvl, isTx, timings, redirect, noRwRandom, isAssoc, err = queryReqParams(req, defaultTimeout) timeout, frsh, lvl, isTx, timings, noRwRandom, isAssoc, err = queryReqParams(req, defaultTimeout)
if err != nil { if err != nil {
return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err
} }
return timeout, frsh, lvl, isTx, timings, redirect, noRwRandom, isAssoc, nil return timeout, frsh, lvl, isTx, timings, noRwRandom, isAssoc, nil
} }
// noLeader returns whether processing should skip the leader check. // noLeader returns whether processing should skip the leader check.

@ -8,6 +8,7 @@ import (
"io" "io"
"io/ioutil" "io/ioutil"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"os" "os"
"strings" "strings"
@ -826,36 +827,90 @@ func Test_RegisterStatus(t *testing.T) {
} }
func Test_FormRedirect(t *testing.T) { func Test_FormRedirect(t *testing.T) {
m := &MockStore{} m := &MockStore{
c := &mockClusterService{} leaderAddr: "foo:4002",
}
c := &mockClusterService{
apiAddr: "http://foo:4001",
}
s := New("127.0.0.1:0", m, c, nil) s := New("127.0.0.1:0", m, c, nil)
req := mustNewHTTPRequest("http://qux:4001") req := mustNewHTTPRequest("http://qux:4001")
if rd := s.FormRedirect(req, "http://foo:4001"); rd != "http://foo:4001" { rd, err := s.FormRedirect(req)
t.Fatal("failed to form redirect for simple URL") if err != nil {
t.Fatalf("failed to form redirect: %s", err.Error())
}
if exp, got := "http://foo:4001", rd; exp != got {
t.Fatalf("incorrect redirect, exp: %s, got: %s", exp, got)
} }
} }
func Test_FormRedirectParam(t *testing.T) { func Test_FormRedirectParam(t *testing.T) {
m := &MockStore{} m := &MockStore{
c := &mockClusterService{} leaderAddr: "foo:4002",
}
c := &mockClusterService{
apiAddr: "http://foo:4001",
}
s := New("127.0.0.1:0", m, c, nil) s := New("127.0.0.1:0", m, c, nil)
req := mustNewHTTPRequest("http://qux:4001/db/query?x=y") req := mustNewHTTPRequest("http://qux:4001/db/query?x=y")
if rd := s.FormRedirect(req, "http://foo:4001"); rd != "http://foo:4001/db/query?x=y" { rd, err := s.FormRedirect(req)
t.Fatal("failed to form redirect for URL") if err != nil {
t.Fatalf("failed to form redirect: %s", err.Error())
}
if exp, got := "http://foo:4001/db/query?x=y", rd; rd != got {
t.Fatalf("incorrect redirect, exp: %s, got: %s", exp, got)
} }
} }
func Test_FormRedirectHTTPS(t *testing.T) { func Test_FormRedirectHTTPS(t *testing.T) {
m := &MockStore{} m := &MockStore{
c := &mockClusterService{} leaderAddr: "foo:4002",
}
c := &mockClusterService{
apiAddr: "https://foo:4001",
}
s := New("127.0.0.1:0", m, c, nil)
req := mustNewHTTPRequest("http://qux:4001")
rd, err := s.FormRedirect(req)
if err != nil {
t.Fatalf("failed to form redirect: %s", err.Error())
}
if exp, got := "https://foo:4001", rd; exp != got {
t.Fatalf("incorrect redirect, exp: %s, got: %s", exp, got)
}
}
func Test_DoRedirect(t *testing.T) {
m := &MockStore{
leaderAddr: "foo:4002",
}
c := &mockClusterService{
apiAddr: "https://foo:4001",
}
s := New("127.0.0.1:0", m, c, nil) s := New("127.0.0.1:0", m, c, nil)
req := mustNewHTTPRequest("http://qux:4001") req := mustNewHTTPRequest("http://qux:4001")
if rd := s.FormRedirect(req, "https://foo:4001"); rd != "https://foo:4001" { if s.DoRedirect(nil, req) {
t.Fatal("failed to form redirect for simple URL") t.Fatalf("incorrectly redirected")
}
req = mustNewHTTPRequest("http://qux:4001/db/query?redirect")
w := httptest.NewRecorder()
if !s.DoRedirect(w, req) {
t.Fatalf("incorrectly not redirected")
}
if exp, got := http.StatusMovedPermanently, w.Code; exp != got {
t.Fatalf("incorrect redirect code, exp: %d, got: %d", exp, got)
}
// check location header
if exp, got := "https://foo:4001/db/query?redirect", w.Header().Get("Location"); exp != got {
t.Fatalf("incorrect redirect location, exp: %s, got: %s", exp, got)
} }
} }

Loading…
Cancel
Save