1
0
Fork 0

Add handler for /boot

master
Philip O'Toole 9 months ago
parent 2ff62dfaac
commit 78566ea45e

@ -2,6 +2,7 @@
package http
import (
"bufio"
"context"
"crypto/tls"
"encoding/json"
@ -57,7 +58,7 @@ type Database interface {
// an Execute or Query request.
Request(eqr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, error)
// Load loads a SQLite file into the system
// Load loads a SQLite file into the system via Raft consensus.
Load(lr *command.LoadRequest) error
}
@ -82,6 +83,11 @@ type Store interface {
// Backup writes backup of the node state to dst
Backup(br *command.BackupRequest, dst io.Writer) error
// ReadFrom reads and loads a SQLite database into the node, initially bypassing
// the Raft system. It then triggers a Raft snapshot, which will then make
// Raft aware of the new data.
ReadFrom(r io.Reader) (int64, error)
}
// GetAddresser is the interface that wraps the GetNodeAPIAddr method.
@ -218,6 +224,7 @@ const (
numBackups = "backups"
numLoad = "loads"
numLoadAborted = "loads_aborted"
numBoot = "boot"
numAuthOK = "authOK"
numAuthFail = "authFail"
@ -284,6 +291,7 @@ func ResetStats() {
stats.Add(numBackups, 0)
stats.Add(numLoad, 0)
stats.Add(numLoadAborted, 0)
stats.Add(numBoot, 0)
stats.Add(numAuthOK, 0)
stats.Add(numAuthFail, 0)
}
@ -459,6 +467,9 @@ func (s *Service) ServeHTTP(w http.ResponseWriter, r *http.Request) {
case strings.HasPrefix(r.URL.Path, "/db/load"):
stats.Add(numLoad, 1)
s.handleLoad(w, r, params)
case strings.HasPrefix(r.URL.Path, "/boot"):
stats.Add(numBoot, 1)
s.handleBoot(w, r, params)
case strings.HasPrefix(r.URL.Path, "/remove"):
s.handleRemove(w, r, params)
case strings.HasPrefix(r.URL.Path, "/status"):
@ -731,6 +742,36 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request, qp QueryPar
s.writeResponse(w, r, qp, resp)
}
// handleBoot handles booting this node using a SQLite file.
func (s *Service) handleBoot(w http.ResponseWriter, r *http.Request, qp QueryParams) {
if !s.CheckRequestPerm(r, auth.PermLoad) {
w.WriteHeader(http.StatusUnauthorized)
return
}
if r.Method != "POST" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
bufReader := bufio.NewReader(r.Body)
peek, err := bufReader.Peek(db.SQLiteHeaderSize)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
if !db.IsValidSQLiteData(peek) {
http.Error(w, "invalid SQLite data", http.StatusBadRequest)
return
}
_, err = s.store.ReadFrom(bufReader)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
}
// handleStatus returns status on the system.
func (s *Service) handleStatus(w http.ResponseWriter, r *http.Request, qp QueryParams) {
w.Header().Set("Content-Type", "application/json; charset=utf-8")

@ -269,6 +269,22 @@ func Test_404Routes(t *testing.T) {
}
func Test_405Routes(t *testing.T) {
type testCase struct {
method string
path string
}
tests := []testCase{
{method: "GET", path: "/db/execute"},
{method: "GET", path: "/boot"},
{method: "GET", path: "/db/load"},
{method: "GET", path: "/remove"},
{method: "POST", path: "/remove"},
{method: "POST", path: "/db/backup"},
{method: "POST", path: "/status"},
{method: "POST", path: "/nodes"},
}
m := &MockStore{}
c := &mockClusterService{}
s := New("127.0.0.1:0", m, c, nil)
@ -280,44 +296,25 @@ func Test_405Routes(t *testing.T) {
client := &http.Client{}
resp, err := client.Get(host + "/db/execute")
if err != nil {
t.Fatalf("failed to make request")
}
if resp.StatusCode != 405 {
t.Fatalf("failed to get expected 405, got %d", resp.StatusCode)
}
resp, err = client.Get(host + "/remove")
if err != nil {
t.Fatalf("failed to make request")
}
if resp.StatusCode != 405 {
t.Fatalf("failed to get expected 405, got %d", resp.StatusCode)
}
resp, err = client.Post(host+"/remove", "", nil)
if err != nil {
t.Fatalf("failed to make request")
}
if resp.StatusCode != 405 {
t.Fatalf("failed to get expected 405, got %d", resp.StatusCode)
}
resp, err = client.Post(host+"/db/backup", "", nil)
if err != nil {
t.Fatalf("failed to make request")
}
if resp.StatusCode != 405 {
t.Fatalf("failed to get expected 405, got %d", resp.StatusCode)
}
for _, tc := range tests {
var resp *http.Response
var err error
switch tc.method {
case "GET":
resp, err = client.Get(host + tc.path)
case "POST":
resp, err = client.Post(host+tc.path, "", nil)
default:
t.Fatalf("unsupported method: %s", tc.method)
}
resp, err = client.Post(host+"/status", "", nil)
if err != nil {
t.Fatalf("failed to make request")
}
if resp.StatusCode != 405 {
t.Fatalf("failed to get expected 405, got %d", resp.StatusCode)
if err != nil {
t.Fatalf("failed to make request for %s %s", tc.method, tc.path)
}
if resp.StatusCode != 405 {
t.Fatalf("failed to get expected 405 for %s %s, got %d", tc.method, tc.path, resp.StatusCode)
}
}
}
@ -355,13 +352,13 @@ func Test_401Routes_NoBasicAuth(t *testing.T) {
host := fmt.Sprintf("http://%s", s.Addr().String())
client := &http.Client{}
for _, path := range []string{
"/db/execute",
"/db/query",
"/db/request",
"/db/backup",
"/db/load",
"/boot",
"/remove",
"/status",
"/nodes",
@ -394,13 +391,13 @@ func Test_401Routes_BasicAuthBadPassword(t *testing.T) {
host := fmt.Sprintf("http://%s", s.Addr().String())
client := &http.Client{}
for _, path := range []string{
"/db/execute",
"/db/query",
"/db/request",
"/db/backup",
"/db/load",
"/boot",
"/status",
"/nodes",
"/readyz",
@ -445,6 +442,7 @@ func Test_401Routes_BasicAuthBadPerm(t *testing.T) {
"/db/backup",
"/db/request",
"/db/load",
"/boot",
"/status",
"/nodes",
"/readyz",
@ -808,6 +806,55 @@ func Test_LoadRemoteError(t *testing.T) {
}
}
func Test_Boot(t *testing.T) {
m := &MockStore{
leaderAddr: "foo:1234",
}
c := &mockClusterService{
apiAddr: "http://1.2.3.4:999",
}
s := New("127.0.0.1:0", m, c, nil)
if err := s.Start(); err != nil {
t.Fatalf("failed to start service")
}
defer s.Close()
testData, err := os.ReadFile("testdata/load.db")
if err != nil {
t.Fatalf("failed to load test SQLite data")
}
readFromCalled := false
m.readFromFn = func(r io.Reader) (int64, error) {
// read all data from r and compare to the data in testData
b, err := io.ReadAll(r)
if err != nil {
return 0, err
}
if !bytes.Equal(b, testData) {
t.Fatalf("wrong data passed to ReadFrom")
}
readFromCalled = true
return int64(len(b)), nil
}
client := &http.Client{}
host := fmt.Sprintf("http://%s", s.Addr().String())
resp, err := client.Post(host+"/boot", "application/octet-stream", bytes.NewReader(testData))
if err != nil {
t.Fatalf("failed to make boot request")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
t.Fatalf("failed to get expected StatusOK for boot, got %d", resp.StatusCode)
}
if !readFromCalled {
t.Fatalf("ReadFrom was not called")
}
}
func Test_RegisterStatus(t *testing.T) {
var stats *mockStatusReporter
m := &MockStore{}
@ -1275,6 +1322,7 @@ type MockStore struct {
requestFn func(eqr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, error)
backupFn func(br *command.BackupRequest, dst io.Writer) error
loadFn func(lr *command.LoadRequest) error
readFromFn func(r io.Reader) (int64, error)
leaderAddr string
notReady bool // Default value is true, easier to test.
}
@ -1342,6 +1390,13 @@ func (m *MockStore) Load(lr *command.LoadRequest) error {
return nil
}
func (m *MockStore) ReadFrom(r io.Reader) (int64, error) {
if m.readFromFn != nil {
return m.readFromFn(r)
}
return 0, nil
}
type mockClusterService struct {
apiAddr string
executeFn func(er *command.ExecuteRequest, addr string, t time.Duration) ([]*command.ExecuteResult, error)

Loading…
Cancel
Save