From 3a0f776e57ac03b3d0d9806308c900a207f8d003 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 9 Dec 2023 11:11:05 -0500 Subject: [PATCH 1/4] Refactor redirect --- http/service.go | 79 ++++++++++++++++++-------------------------- http/service_test.go | 50 +++++++++++++++++++++------- 2 files changed, 71 insertions(+), 58 deletions(-) diff --git a/http/service.go b/http/service.go index b2e40510..26a6939c 100644 --- a/http/service.go +++ b/http/service.go @@ -544,14 +544,11 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) { if err != nil { if err == store.ErrNotLeader { if redirect { - leaderAPIAddr := s.LeaderAPIAddr() - if leaderAPIAddr == "" { - stats.Add(numLeaderNotFound, 1) - http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) + redirect, err := s.FormRedirect(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) return } - - redirect := s.FormRedirect(r, leaderAPIAddr) http.Redirect(w, r, redirect, http.StatusMovedPermanently) return } @@ -642,14 +639,11 @@ func (s *Service) handleBackup(w http.ResponseWriter, r *http.Request) { if err != nil { if err == store.ErrNotLeader { if redirect { - leaderAPIAddr := s.LeaderAPIAddr() - if leaderAPIAddr == "" { - stats.Add(numLeaderNotFound, 1) - http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) + redirect, err := s.FormRedirect(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) return } - - redirect := s.FormRedirect(r, leaderAPIAddr) http.Redirect(w, r, redirect, http.StatusMovedPermanently) return } @@ -767,14 +761,11 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) { results, err := s.store.Execute(er) if err != nil { if err == store.ErrNotLeader { - leaderAPIAddr := s.LeaderAPIAddr() - if leaderAPIAddr == "" { - stats.Add(numLeaderNotFound, 1) - http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) + redirect, err := s.FormRedirect(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) return } - - redirect := s.FormRedirect(r, leaderAPIAddr) http.Redirect(w, r, redirect, http.StatusMovedPermanently) return } @@ -788,7 +779,7 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) { for { chunk, err := chunker.Next() if err != nil { - http.Error(w, err.Error(), http.StatusServiceUnavailable) + http.Error(w, err.Error(), http.StatusInternalServerError) return } err = s.store.LoadChunk(chunk) @@ -797,14 +788,11 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) { return } else if err != nil && err == store.ErrNotLeader { if redirect { - leaderAPIAddr := s.LeaderAPIAddr() - if leaderAPIAddr == "" { - stats.Add(numLeaderNotFound, 1) - http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) + redirect, err := s.FormRedirect(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) return } - - redirect := s.FormRedirect(r, leaderAPIAddr) http.Redirect(w, r, redirect, http.StatusMovedPermanently) return } @@ -1268,14 +1256,12 @@ func (s *Service) execute(w http.ResponseWriter, r *http.Request) { results, resultsErr := s.store.Execute(er) if resultsErr != nil && resultsErr == store.ErrNotLeader { if redirect { - leaderAPIAddr := s.LeaderAPIAddr() - if leaderAPIAddr == "" { - stats.Add(numLeaderNotFound, 1) - http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) + redirect, err := s.FormRedirect(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) return } - loc := s.FormRedirect(r, leaderAPIAddr) - http.Redirect(w, r, loc, http.StatusMovedPermanently) + http.Redirect(w, r, redirect, http.StatusMovedPermanently) return } @@ -1370,14 +1356,12 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) { results, resultsErr := s.store.Query(qr) if resultsErr != nil && resultsErr == store.ErrNotLeader { if redirect { - leaderAPIAddr := s.LeaderAPIAddr() - if leaderAPIAddr == "" { - stats.Add(numLeaderNotFound, 1) - http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) + redirect, err := s.FormRedirect(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) return } - loc := s.FormRedirect(r, leaderAPIAddr) - http.Redirect(w, r, loc, http.StatusMovedPermanently) + http.Redirect(w, r, redirect, http.StatusMovedPermanently) return } @@ -1471,14 +1455,12 @@ func (s *Service) handleRequest(w http.ResponseWriter, r *http.Request) { results, resultErr := s.store.Request(eqr) if resultErr != nil && resultErr == store.ErrNotLeader { if redirect { - leaderAPIAddr := s.LeaderAPIAddr() - if leaderAPIAddr == "" { - stats.Add(numLeaderNotFound, 1) - http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) + redirect, err := s.FormRedirect(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) return } - loc := s.FormRedirect(r, leaderAPIAddr) - http.Redirect(w, r, loc, http.StatusMovedPermanently) + http.Redirect(w, r, redirect, http.StatusMovedPermanently) return } @@ -1567,12 +1549,18 @@ func (s *Service) Addr() net.Addr { } // FormRedirect returns the value for the "Location" header for a 301 response. -func (s *Service) FormRedirect(r *http.Request, url string) string { +func (s *Service) FormRedirect(r *http.Request) (string, error) { + leaderAPIAddr := s.LeaderAPIAddr() + if leaderAPIAddr == "" { + stats.Add(numLeaderNotFound, 1) + return "", ErrLeaderNotFound + } + rq := r.URL.RawQuery if rq != "" { rq = fmt.Sprintf("?%s", rq) } - return fmt.Sprintf("%s%s%s", url, r.URL.Path, rq) + return fmt.Sprintf("%s%s%s", leaderAPIAddr, r.URL.Path, rq), nil } // CheckRequestPerm checks if the request is authenticated and authorized @@ -1636,7 +1624,6 @@ func (s *Service) LeaderAPIAddr() string { } apiAddr, err := s.cluster.GetNodeAPIAddr(nodeAddr, defaultTimeout) - if err != nil { return "" } diff --git a/http/service_test.go b/http/service_test.go index e770a20e..93140cf1 100644 --- a/http/service_test.go +++ b/http/service_test.go @@ -826,36 +826,62 @@ func Test_RegisterStatus(t *testing.T) { } func Test_FormRedirect(t *testing.T) { - m := &MockStore{} - c := &mockClusterService{} + m := &MockStore{ + leaderAddr: "foo:4002", + } + c := &mockClusterService{ + apiAddr: "http://foo:4001", + } s := New("127.0.0.1:0", m, c, nil) req := mustNewHTTPRequest("http://qux:4001") - if rd := s.FormRedirect(req, "http://foo:4001"); rd != "http://foo:4001" { - t.Fatal("failed to form redirect for simple URL") + rd, err := s.FormRedirect(req) + if err != nil { + t.Fatalf("failed to form redirect: %s", err.Error()) + } + if exp, got := "http://foo:4001", rd; exp != got { + t.Fatalf("incorrect redirect, exp: %s, got: %s", exp, got) } } func Test_FormRedirectParam(t *testing.T) { - m := &MockStore{} - c := &mockClusterService{} + m := &MockStore{ + leaderAddr: "foo:4002", + } + c := &mockClusterService{ + apiAddr: "http://foo:4001", + } s := New("127.0.0.1:0", m, c, nil) req := mustNewHTTPRequest("http://qux:4001/db/query?x=y") - if rd := s.FormRedirect(req, "http://foo:4001"); rd != "http://foo:4001/db/query?x=y" { - t.Fatal("failed to form redirect for URL") + rd, err := s.FormRedirect(req) + if err != nil { + t.Fatalf("failed to form redirect: %s", err.Error()) + } + + if exp, got := "http://foo:4001/db/query?x=y", rd; rd != got { + t.Fatalf("incorrect redirect, exp: %s, got: %s", exp, got) } } func Test_FormRedirectHTTPS(t *testing.T) { - m := &MockStore{} - c := &mockClusterService{} + m := &MockStore{ + leaderAddr: "foo:4002", + } + c := &mockClusterService{ + apiAddr: "https://foo:4001", + } + s := New("127.0.0.1:0", m, c, nil) req := mustNewHTTPRequest("http://qux:4001") - if rd := s.FormRedirect(req, "https://foo:4001"); rd != "https://foo:4001" { - t.Fatal("failed to form redirect for simple URL") + rd, err := s.FormRedirect(req) + if err != nil { + t.Fatalf("failed to form redirect: %s", err.Error()) + } + if exp, got := "https://foo:4001", rd; exp != got { + t.Fatalf("incorrect redirect, exp: %s, got: %s", exp, got) } } From e9a0a9ebfa06babe2406191f43d1f3b1beb18151 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 9 Dec 2023 11:24:13 -0500 Subject: [PATCH 2/4] Roll all Redirect logic into single function --- http/service.go | 135 +++++++++++++++++------------------------------- 1 file changed, 48 insertions(+), 87 deletions(-) diff --git a/http/service.go b/http/service.go index 26a6939c..e7996134 100644 --- a/http/service.go +++ b/http/service.go @@ -502,12 +502,6 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) { return } - redirect, err := isRedirect(r) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - b, err := io.ReadAll(r.Body) if err != nil { w.WriteHeader(http.StatusBadRequest) @@ -543,13 +537,7 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) { err = s.store.Remove(rn) if err != nil { if err == store.ErrNotLeader { - if redirect { - redirect, err := s.FormRedirect(r) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - http.Redirect(w, r, redirect, http.StatusMovedPermanently) + if s.DoRedirect(w, r) { return } @@ -605,11 +593,6 @@ func (s *Service) handleBackup(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } - redirect, err := isRedirect(r) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } format, err := backupFormat(w, r) if err != nil { @@ -638,13 +621,7 @@ func (s *Service) handleBackup(w http.ResponseWriter, r *http.Request) { err = s.store.Backup(br, w) if err != nil { if err == store.ErrNotLeader { - if redirect { - redirect, err := s.FormRedirect(r) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - http.Redirect(w, r, redirect, http.StatusMovedPermanently) + if s.DoRedirect(w, r) { return } @@ -716,12 +693,6 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) { return } - redirect, err := isRedirect(r) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - chunkSz, err := chunkSizeParam(r, defaultChunkSize) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -761,13 +732,9 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) { results, err := s.store.Execute(er) if err != nil { if err == store.ErrNotLeader { - redirect, err := s.FormRedirect(r) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) + if s.DoRedirect(w, r) { return } - http.Redirect(w, r, redirect, http.StatusMovedPermanently) - return } resp.Error = err.Error() } else { @@ -787,13 +754,7 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } else if err != nil && err == store.ErrNotLeader { - if redirect { - redirect, err := s.FormRedirect(r) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - http.Redirect(w, r, redirect, http.StatusMovedPermanently) + if s.DoRedirect(w, r) { return } @@ -1221,7 +1182,7 @@ func (s *Service) queuedExecute(w http.ResponseWriter, r *http.Request) { func (s *Service) execute(w http.ResponseWriter, r *http.Request) { resp := NewResponse() - timeout, isTx, timings, redirect, noRewriteRandom, err := reqParams(r, defaultTimeout) + timeout, isTx, timings, noRewriteRandom, err := reqParams(r, defaultTimeout) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -1255,13 +1216,7 @@ func (s *Service) execute(w http.ResponseWriter, r *http.Request) { results, resultsErr := s.store.Execute(er) if resultsErr != nil && resultsErr == store.ErrNotLeader { - if redirect { - redirect, err := s.FormRedirect(r) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - http.Redirect(w, r, redirect, http.StatusMovedPermanently) + if s.DoRedirect(w, r) { return } @@ -1317,7 +1272,7 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) { return } - timeout, frsh, lvl, isTx, timings, redirect, noRewriteRandom, isAssoc, err := queryReqParams(r, defaultTimeout) + timeout, frsh, lvl, isTx, timings, noRewriteRandom, isAssoc, err := queryReqParams(r, defaultTimeout) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -1355,13 +1310,7 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) { results, resultsErr := s.store.Query(qr) if resultsErr != nil && resultsErr == store.ErrNotLeader { - if redirect { - redirect, err := s.FormRedirect(r) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - http.Redirect(w, r, redirect, http.StatusMovedPermanently) + if s.DoRedirect(w, r) { return } @@ -1414,7 +1363,7 @@ func (s *Service) handleRequest(w http.ResponseWriter, r *http.Request) { return } - timeout, frsh, lvl, isTx, timings, redirect, noRewriteRandom, isAssoc, err := executeQueryReqParams(r, defaultTimeout) + timeout, frsh, lvl, isTx, timings, noRewriteRandom, isAssoc, err := executeQueryReqParams(r, defaultTimeout) if err != nil { http.Error(w, err.Error(), http.StatusBadRequest) return @@ -1454,13 +1403,7 @@ func (s *Service) handleRequest(w http.ResponseWriter, r *http.Request) { results, resultErr := s.store.Request(eqr) if resultErr != nil && resultErr == store.ErrNotLeader { - if redirect { - redirect, err := s.FormRedirect(r) - if err != nil { - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } - http.Redirect(w, r, redirect, http.StatusMovedPermanently) + if s.DoRedirect(w, r) { return } @@ -1548,6 +1491,28 @@ func (s *Service) Addr() net.Addr { return s.ln.Addr() } +// DoRedirect checks if the request is a redirect, and if so, performs the redirect. +// Returns true caller can consider the request handled. Returns false if the request +// was not a redirect and the caller should continue processing the request. +func (s *Service) DoRedirect(w http.ResponseWriter, r *http.Request) bool { + redirect, err := isRedirect(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return true + } + if !redirect { + return false + } + + rd, err := s.FormRedirect(r) + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + } else { + http.Redirect(w, r, rd, http.StatusMovedPermanently) + } + return true +} + // FormRedirect returns the value for the "Location" header for a 301 response. func (s *Service) FormRedirect(r *http.Request) (string, error) { leaderAPIAddr := s.LeaderAPIAddr() @@ -1919,61 +1884,57 @@ func isQueue(req *http.Request) (bool, error) { // reqParams is a convenience function to get a bunch of query params // in one function call. -func reqParams(req *http.Request, def time.Duration) (timeout time.Duration, tx, timings, redirect, noRwRandom bool, err error) { +func reqParams(req *http.Request, def time.Duration) (timeout time.Duration, tx, timings, noRwRandom bool, err error) { timeout, err = timeoutParam(req, def) if err != nil { - return 0, false, false, false, true, err + return 0, false, false, true, err } tx, err = isTx(req) if err != nil { - return 0, false, false, false, true, err + return 0, false, false, true, err } timings, err = isTimings(req) if err != nil { - return 0, false, false, false, true, err - } - redirect, err = isRedirect(req) - if err != nil { - return 0, false, false, false, true, err + return 0, false, false, true, err } noRwRandom, err = noRewriteRandom(req) if err != nil { - return 0, false, false, false, true, err + return 0, false, false, true, err } - return timeout, tx, timings, redirect, noRwRandom, nil + return timeout, tx, timings, noRwRandom, nil } // queryReqParams is a convenience function to get a bunch of query params // in one function call. -func queryReqParams(req *http.Request, def time.Duration) (timeout, frsh time.Duration, lvl command.QueryRequest_Level, isTx, timings, redirect, noRwRandom, isAssoc bool, err error) { - timeout, isTx, timings, redirect, noRwRandom, err = reqParams(req, defaultTimeout) +func queryReqParams(req *http.Request, def time.Duration) (timeout, frsh time.Duration, lvl command.QueryRequest_Level, isTx, timings, noRwRandom, isAssoc bool, err error) { + timeout, isTx, timings, noRwRandom, err = reqParams(req, defaultTimeout) if err != nil { - return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err + return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err } lvl, err = level(req) if err != nil { - return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err + return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err } frsh, err = freshness(req) if err != nil { - return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err + return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err } isAssoc, err = isAssociative(req) if err != nil { - return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err + return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err } return } -func executeQueryReqParams(req *http.Request, def time.Duration) (timeout, frsh time.Duration, lvl command.QueryRequest_Level, isTx, timings, redirect, noRwRandom, isAssoc bool, err error) { - timeout, frsh, lvl, isTx, timings, redirect, noRwRandom, isAssoc, err = queryReqParams(req, defaultTimeout) +func executeQueryReqParams(req *http.Request, def time.Duration) (timeout, frsh time.Duration, lvl command.QueryRequest_Level, isTx, timings, noRwRandom, isAssoc bool, err error) { + timeout, frsh, lvl, isTx, timings, noRwRandom, isAssoc, err = queryReqParams(req, defaultTimeout) if err != nil { - return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, false, err + return 0, 0, command.QueryRequest_QUERY_REQUEST_LEVEL_WEAK, false, false, false, false, err } - return timeout, frsh, lvl, isTx, timings, redirect, noRwRandom, isAssoc, nil + return timeout, frsh, lvl, isTx, timings, noRwRandom, isAssoc, nil } // noLeader returns whether processing should skip the leader check. From 4470f751a33799f94a36fdf69cdf5dfe02db18c3 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 9 Dec 2023 11:25:27 -0500 Subject: [PATCH 3/4] CHANGELOG --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index b6ba57ca..0c5b1451 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ### Implementation changes and bug fixes - [PR #1456](https://github.com/rqlite/rqlite/pull/1459): Standardize on chunk size. - [PR #1456](https://github.com/rqlite/rqlite/pull/1459): Set `TrailingLogs=0` to truncate log during user-initiated Snapshotting. +- [PR #1462](https://github.com/rqlite/rqlite/pull/1462): Refactor redirect logic in HTTP service ## 8.0.1 (December 8th 2023) This release fixes an edge case issue during restore-from-SQLite. It's possible if a rqlite system crashes shortly after restoring from SQLite it may not have loaded the data correctly. From 2c57bfa99d25c96dba264bdd5539185059127a1c Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 9 Dec 2023 11:32:52 -0500 Subject: [PATCH 4/4] Unit test DoRedirect --- http/service_test.go | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/http/service_test.go b/http/service_test.go index 93140cf1..d1562c9e 100644 --- a/http/service_test.go +++ b/http/service_test.go @@ -8,6 +8,7 @@ import ( "io" "io/ioutil" "net/http" + "net/http/httptest" "net/url" "os" "strings" @@ -885,6 +886,34 @@ func Test_FormRedirectHTTPS(t *testing.T) { } } +func Test_DoRedirect(t *testing.T) { + m := &MockStore{ + leaderAddr: "foo:4002", + } + c := &mockClusterService{ + apiAddr: "https://foo:4001", + } + s := New("127.0.0.1:0", m, c, nil) + req := mustNewHTTPRequest("http://qux:4001") + + if s.DoRedirect(nil, req) { + t.Fatalf("incorrectly redirected") + } + + req = mustNewHTTPRequest("http://qux:4001/db/query?redirect") + w := httptest.NewRecorder() + if !s.DoRedirect(w, req) { + t.Fatalf("incorrectly not redirected") + } + if exp, got := http.StatusMovedPermanently, w.Code; exp != got { + t.Fatalf("incorrect redirect code, exp: %d, got: %d", exp, got) + } + // check location header + if exp, got := "https://foo:4001/db/query?redirect", w.Header().Get("Location"); exp != got { + t.Fatalf("incorrect redirect location, exp: %s, got: %s", exp, got) + } +} + func Test_Nodes(t *testing.T) { m := &MockStore{ leaderAddr: "foo:1234",