diff --git a/http/service.go b/http/service.go index 2b6950f9..bc413fa5 100644 --- a/http/service.go +++ b/http/service.go @@ -2,7 +2,6 @@ package http import ( - "bufio" "context" "crypto/tls" "encoding/json" @@ -23,7 +22,6 @@ import ( "github.com/rqlite/rqlite/auth" "github.com/rqlite/rqlite/cluster" "github.com/rqlite/rqlite/command" - "github.com/rqlite/rqlite/command/chunking" "github.com/rqlite/rqlite/command/encoding" "github.com/rqlite/rqlite/db" "github.com/rqlite/rqlite/queue" @@ -68,6 +66,9 @@ type Database interface { // LoadChunk loads a SQLite database into the node, chunk by chunk. LoadChunk(lc *command.LoadChunkRequest) error + + // Load loads a SQLite file into the system + Load(lr *command.LoadRequest) error } // Store is the interface the Raft-based database must implement. @@ -118,6 +119,9 @@ type Cluster interface { // LoadChunk loads a SQLite database into the node, chunk by chunk. LoadChunk(lc *command.LoadChunkRequest, nodeAddr string, creds *cluster.Credentials, timeout time.Duration) error + // Load loads a SQLite database into the node. + Load(lr *command.LoadRequest, nodeAddr string, creds *cluster.Credentials, timeout time.Duration) error + // RemoveNode removes a node from the cluster. RemoveNode(rn *command.RemoveNodeRequest, nodeAddr string, creds *cluster.Credentials, timeout time.Duration) error @@ -648,8 +652,6 @@ func (s *Service) handleBackup(w http.ResponseWriter, r *http.Request, qp QueryP // handleLoad loads the database from the given SQLite database file or SQLite dump. func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request, qp QueryParams) { - startTime := time.Now() - if !s.CheckRequestPerm(r, auth.PermLoad) { w.WriteHeader(http.StatusUnauthorized) return @@ -661,33 +663,68 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request, qp QueryPar } resp := NewResponse() + b, err := io.ReadAll(r.Body) + if err != nil { + http.Error(w, err.Error(), http.StatusBadRequest) + return + } + r.Body.Close() - // Peek at the incoming bytes so we can determine if this is a SQLite database - validSQLite := false - bufReader := bufio.NewReader(r.Body) - peek, err := bufReader.Peek(db.SQLiteHeaderSize) - if err == nil { - validSQLite = db.IsValidSQLiteData(peek) - if validSQLite { - s.logger.Printf("SQLite database file detected as load data") - if db.IsWALModeEnabled(peek) { - s.logger.Printf("SQLite database file is in WAL mode - rejecting load request") - http.Error(w, `SQLite database file is in WAL mode - convert it to DELETE mode via 'PRAGMA journal_mode=DELETE'`, - http.StatusBadRequest) - return - } + if db.IsValidSQLiteData(b) { + s.logger.Printf("SQLite database file detected as load data") + lr := &command.LoadRequest{ + Data: b, } - } - if !validSQLite { - // Assume SQL text - b, err := io.ReadAll(bufReader) - if err != nil { - http.Error(w, err.Error(), http.StatusBadRequest) + if db.IsWALModeEnabled(b) { + s.logger.Printf("SQLite database file is in WAL mode - rejecting load request") + http.Error(w, `SQLite database file is in WAL mode - convert it to DELETE mode via 'PRAGMA journal_mode=DELETE'`, + http.StatusBadRequest) return } - r.Body.Close() + err := s.store.Load(lr) + if err != nil && err != store.ErrNotLeader { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } else if err != nil && err == store.ErrNotLeader { + if s.DoRedirect(w, r, qp) { + return + } + + addr, err := s.store.LeaderAddr() + if err != nil { + http.Error(w, fmt.Sprintf("leader address: %s", err.Error()), + http.StatusInternalServerError) + return + } + if addr == "" { + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) + return + } + + username, password, ok := r.BasicAuth() + if !ok { + username = "" + } + + w.Header().Add(ServedByHTTPHeader, addr) + loadErr := s.cluster.Load(lr, addr, makeCredentials(username, password), qp.Timeout(defaultTimeout)) + if loadErr != nil { + if loadErr.Error() == "unauthorized" { + http.Error(w, "remote load not authorized", http.StatusUnauthorized) + } else { + http.Error(w, loadErr.Error(), http.StatusInternalServerError) + } + return + } + stats.Add(numRemoteLoads, 1) + // Allow this if block to exit, so response remains as before request + // forwarding was put in place. + } + } else { + // No JSON structure expected for this API. queries := []string{string(b)} er := executeRequestFromStrings(queries, qp.Timings(), false) @@ -703,81 +740,10 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request, qp QueryPar resp.Results.ExecuteResult = results } resp.end = time.Now() - } else { - chunker := chunking.NewChunker(bufReader, int64(qp.ChunkKB(defaultChunkSize))) - for { - chunk, err := chunker.Next() - if err != nil { - chunk = chunker.Abort() - } - err = s.store.LoadChunk(chunk) - if err != nil && err != store.ErrNotLeader { - s.store.LoadChunk(chunker.Abort()) - http.Error(w, err.Error(), http.StatusInternalServerError) - return - } else if err != nil && err == store.ErrNotLeader { - if s.DoRedirect(w, r, qp) { - return - } - - addr, err := s.loadClusterChunk(r, qp, chunk) - if err != nil { - if err == ErrRemoteLoadNotAuthorized { - http.Error(w, err.Error(), http.StatusUnauthorized) - } else if err == ErrLeaderNotFound { - http.Error(w, err.Error(), http.StatusServiceUnavailable) - } else { - http.Error(w, err.Error(), http.StatusInternalServerError) - } - s.loadClusterChunk(r, qp, chunker.Abort()) - return - } - w.Header().Add(ServedByHTTPHeader, addr) - // Allow this if block to exit without return, so response remains as before request - // forwarding was put in place. - } - nChunks, nr, nw := chunker.Counts() - if chunk.IsLast { - s.logger.Printf("%d bytes read, %d chunks generated, containing %d bytes of compressed data (compression ratio %.2f)", - nr, nChunks, nw, float64(nr)/float64(nw)) - break - } - if chunk.Abort { - stats.Add(numLoadAborted, 1) - s.logger.Printf("load request aborted after %d bytes read, %d chunks generated", nr, nChunks) - break - } - } } - - s.logger.Printf("load request finished in %s", time.Since(startTime).String()) s.writeResponse(w, r, qp, resp) } -func (s *Service) loadClusterChunk(r *http.Request, qp QueryParams, chunk *command.LoadChunkRequest) (string, error) { - addr, err := s.store.LeaderAddr() - if err != nil { - return "", err - } - if addr == "" { - stats.Add(numLeaderNotFound, 1) - return "", ErrLeaderNotFound - } - username, password, ok := r.BasicAuth() - if !ok { - username = "" - } - err = s.cluster.LoadChunk(chunk, addr, makeCredentials(username, password), qp.Timeout(defaultTimeout)) - if err != nil { - if err.Error() == "unauthorized" { - return "", ErrRemoteLoadNotAuthorized - } - return "", err - } - stats.Add(numRemoteLoads, 1) - return addr, nil -} - // handleStatus returns status on the system. func (s *Service) handleStatus(w http.ResponseWriter, r *http.Request, qp QueryParams) { w.Header().Set("Content-Type", "application/json; charset=utf-8") diff --git a/http/service_test.go b/http/service_test.go index 7c47c4a2..df1ef8bc 100644 --- a/http/service_test.go +++ b/http/service_test.go @@ -1281,6 +1281,7 @@ type MockStore struct { queryFn func(qr *command.QueryRequest) ([]*command.QueryRows, error) requestFn func(eqr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, error) backupFn func(br *command.BackupRequest, dst io.Writer) error + loadFn func(lr *command.LoadRequest) error loadChunkFn func(lr *command.LoadChunkRequest) error leaderAddr string notReady bool // Default value is true, easier to test. @@ -1349,12 +1350,20 @@ func (m *MockStore) LoadChunk(lc *command.LoadChunkRequest) error { return nil } +func (m *MockStore) Load(lr *command.LoadRequest) error { + if m.loadFn != nil { + return m.loadFn(lr) + } + return nil +} + type mockClusterService struct { apiAddr string executeFn func(er *command.ExecuteRequest, addr string, t time.Duration) ([]*command.ExecuteResult, error) queryFn func(qr *command.QueryRequest, addr string, t time.Duration) ([]*command.QueryRows, error) requestFn func(eqr *command.ExecuteQueryRequest, nodeAddr string, timeout time.Duration) ([]*command.ExecuteQueryResponse, error) backupFn func(br *command.BackupRequest, addr string, t time.Duration, w io.Writer) error + loadFn func(lr *command.LoadRequest) error loadChunkFn func(lc *command.LoadChunkRequest, addr string, t time.Duration) error removeNodeFn func(rn *command.RemoveNodeRequest, nodeAddr string, t time.Duration) error } @@ -1398,6 +1407,13 @@ func (m *mockClusterService) LoadChunk(lc *command.LoadChunkRequest, addr string return nil } +func (m *mockClusterService) Load(lr *command.LoadRequest, nodeAddr string, creds *cluster.Credentials, timeout time.Duration) error { + if m.loadFn != nil { + return m.loadFn(lr) + } + return nil +} + func (m *mockClusterService) RemoveNode(rn *command.RemoveNodeRequest, addr string, creds *cluster.Credentials, t time.Duration) error { if m.removeNodeFn != nil { return m.removeNodeFn(rn, addr, t) diff --git a/store/store.go b/store/store.go index bb1bcbde..46ec76c3 100644 --- a/store/store.go +++ b/store/store.go @@ -2056,7 +2056,14 @@ func (s *Store) installRestore() error { return err } defer f.Close() - return s.loadFromReader(f, s.restoreChunkSize) + b, err := io.ReadAll(f) + if err != nil { + return err + } + lr := &command.LoadRequest{ + Data: b, + } + return s.load(lr) } // logSize returns the size of the Raft log on disk.