1
0
Fork 0
You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

595 lines
16 KiB
Go

package server
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"net/http"
"net/url"
"path"
"path/filepath"
"strings"
"sync"
"time"
"github.com/goraft/raft"
"github.com/gorilla/mux"
"github.com/otoolep/rqlite/command"
"github.com/otoolep/rqlite/db"
"github.com/otoolep/rqlite/interfaces"
"github.com/otoolep/rqlite/log"
"github.com/rcrowley/go-metrics"
)
// FailedSqlStmt contains a SQL query and an error.
type FailedSqlStmt struct {
Sql string `json:"sql"`
Error string `json:"error"`
}
// StmtResponse contains a date and a list of failed
// SQL statements.
type StmtResponse struct {
Time string `json:"time"`
Failures []FailedSqlStmt `json:"failures"`
}
// QueryResponse contains the response to a query.
type QueryResponse struct {
Time string `json:"time"`
Failures []FailedSqlStmt `json:"failures"`
Rows db.RowResults `json:"rows"`
}
// Metrics are the server metrics user for statistics.
type Metrics struct {
registry metrics.Registry
joinSuccess metrics.Counter
joinFail metrics.Counter
queryReceived metrics.Counter
querySuccess metrics.Counter
queryFail metrics.Counter
executeReceived metrics.Counter
executeTxReceived metrics.Counter
executeSuccess metrics.Counter
executeFail metrics.Counter
snapshotCreated metrics.Counter
}
// Diagnostics contains a start time of the server.
type Diagnostics struct {
startTime time.Time
}
// SnapshotConf contains the index when the last snapshot happened
// and a threshold for index entries since the last snapshot.
type SnapshotConf struct {
// The index when the last snapshot happened
lastIndex uint64
// If the incremental number of index entries since the last
// snapshot exceeds snapshotAfter rqlite will do a snapshot
snapshotAfter uint64
}
// Server is is a combination of the Raft server and an HTTP
// server which acts as the transport.
type Server struct {
name string
host string
port int
path string
router *mux.Router
raftServer raft.Server
httpServer *http.Server
dbPath string
db *db.DB
snapConf *SnapshotConf
metrics *Metrics
diagnostics *Diagnostics
mutex sync.Mutex
}
// ensurePrettyPrint returns a JSON representation of the object o. If
// the HTTP request requested pretty-printing, it ensures that happens.
func ensurePrettyPrint(req *http.Request, o map[string]interface{}) []byte {
var b []byte
pretty, _ := isPretty(req)
if pretty {
b, _ = json.MarshalIndent(o, "", " ")
} else {
b, _ = json.Marshal(o)
}
return b
}
// queryParam returns whether the given query param is set to true.
func queryParam(req *http.Request, param string) (bool, error) {
err := req.ParseForm()
if err != nil {
return false, err
}
if _, ok := req.Form[param]; ok {
return true, nil
}
return false, nil
}
// isPretty returns whether the HTTP response body should be pretty-printed.
func isPretty(req *http.Request) (bool, error) {
return queryParam(req, "pretty")
}
// isTransaction returns whether the client requested an explicit
// transaction for the request.
func isTransaction(req *http.Request) (bool, error) {
return queryParam(req, "transaction")
}
// NewMetrics creates a new Metrics object.
func NewMetrics() *Metrics {
m := &Metrics{
registry: metrics.NewRegistry(),
joinSuccess: metrics.NewCounter(),
joinFail: metrics.NewCounter(),
queryReceived: metrics.NewCounter(),
querySuccess: metrics.NewCounter(),
queryFail: metrics.NewCounter(),
executeReceived: metrics.NewCounter(),
executeTxReceived: metrics.NewCounter(),
executeSuccess: metrics.NewCounter(),
executeFail: metrics.NewCounter(),
snapshotCreated: metrics.NewCounter(),
}
_ = m.registry.Register("join.success", m.joinSuccess)
_ = m.registry.Register("join.fail", m.joinFail)
_ = m.registry.Register("query.Received", m.queryReceived)
_ = m.registry.Register("query.success", m.querySuccess)
_ = m.registry.Register("query.fail", m.queryFail)
_ = m.registry.Register("execute.Received", m.executeReceived)
_ = m.registry.Register("execute.tx.received", m.executeTxReceived)
_ = m.registry.Register("execute.success", m.executeSuccess)
_ = m.registry.Register("execute.fail", m.executeFail)
_ = m.registry.Register("snapshot.created", m.snapshotCreated)
return m
}
// NewDiagnostics creates a new Diagnostics object.
func NewDiagnostics() *Diagnostics {
d := &Diagnostics{
startTime: time.Now(),
}
return d
}
10 years ago
// NewServer creates a new server.
func NewServer(dataDir string, dbfile string, snapAfter int, host string, port int) *Server {
dbPath := path.Join(dataDir, dbfile)
// Raft requires randomness.
rand.Seed(time.Now().UnixNano())
log.Info("Raft random seed initialized")
// Setup commands.
raft.RegisterCommand(&command.ExecuteCommand{})
raft.RegisterCommand(&command.TransactionExecuteCommandSet{})
log.Info("Raft commands registered")
s := &Server{
host: host,
port: port,
path: dataDir,
dbPath: dbPath,
db: db.New(dbPath),
snapConf: &SnapshotConf{snapshotAfter: uint64(snapAfter)},
metrics: NewMetrics(),
diagnostics: NewDiagnostics(),
router: mux.NewRouter(),
}
// Read existing name or generate a new one.
if b, err := ioutil.ReadFile(filepath.Join(dataDir, "name")); err == nil {
s.name = string(b)
} else {
s.name = fmt.Sprintf("%07x", rand.Int())[0:7]
if err = ioutil.WriteFile(filepath.Join(dataDir, "name"), []byte(s.name), 0644); err != nil {
panic(err)
}
}
return s
}
// GetStatistics returns an object storing statistics, which supports JSON
// marshalling.
func (s *Server) GetStatistics() (metrics.Registry, error) {
return s.metrics.registry, nil
}
10 years ago
// connectionString returns the string used to connect to this server.
func (s *Server) connectionString() string {
return fmt.Sprintf("http://%s:%d", s.host, s.port)
}
// logSnapshot logs about the snapshot that was taken.
func (s *Server) logSnapshot(err error, currentIndex, count uint64) {
info := fmt.Sprintf("%s: snapshot of %d events at index %d", s.connectionString(), count, currentIndex)
if err != nil {
log.Infof("%s attempted and failed: %v", info, err)
} else {
log.Infof("%s completed", info)
}
}
10 years ago
// ListenAndServe starts the server.
func (s *Server) ListenAndServe(leader string) error {
var err error
log.Infof("Initializing Raft Server: %s", s.path)
// Initialize and start Raft server.
transporter := raft.NewHTTPTransporter("/raft", 200*time.Millisecond)
stateMachine := NewDbStateMachine(s.dbPath)
s.raftServer, err = raft.NewServer(s.name, s.path, transporter, stateMachine, s.db, "")
if err != nil {
log.Errorf("Failed to create new Raft server: %s", err.Error())
return err
}
log.Info("Loading latest snapshot, if any, from disk")
if err := s.raftServer.LoadSnapshot(); err != nil {
log.Errorf("Error loading snapshot: %s", err.Error())
}
transporter.Install(s.raftServer, s)
if err := s.raftServer.Start(); err != nil {
log.Errorf("Error starting raft server: %s", err.Error())
}
if leader != "" {
// Join to leader if specified.
log.Infof("Attempting to join leader at %s", leader)
if !s.raftServer.IsLogEmpty() {
log.Error("Cannot join with an existing log")
return errors.New("Cannot join with an existing log")
}
if err := s.Join(leader); err != nil {
log.Errorf("Failed to join leader: %s", err.Error())
return err
}
} else if s.raftServer.IsLogEmpty() {
// Initialize the server by joining itself.
log.Info("Initializing new cluster")
_, err := s.raftServer.Do(&raft.DefaultJoinCommand{
Name: s.raftServer.Name(),
ConnectionString: s.connectionString(),
})
if err != nil {
log.Errorf("Failed to join to self: %s", err.Error())
}
} else {
log.Info("Recovered from log")
}
log.Info("Initializing HTTP server")
// Initialize and start HTTP server.
s.httpServer = &http.Server{
Addr: fmt.Sprintf(":%d", s.port),
Handler: s.router,
}
s.router.HandleFunc("/statistics", s.serveStatistics).Methods("GET")
s.router.HandleFunc("/diagnostics", s.serveDiagnostics).Methods("GET")
s.router.HandleFunc("/raft", s.serveRaftInfo).Methods("GET")
s.router.HandleFunc("/db", s.readHandler).Methods("GET")
s.router.HandleFunc("/db", s.writeHandler).Methods("POST")
s.router.HandleFunc("/join", s.joinHandler).Methods("POST")
log.Infof("Listening at %s", s.connectionString())
return s.httpServer.ListenAndServe()
}
// HandleFunc is a hack around Gorilla mux not providing the correct net/http
// HandleFunc() interface.
func (s *Server) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
s.router.HandleFunc(pattern, handler)
}
// Join joins to the leader of an existing cluster.
func (s *Server) Join(leader string) error {
command := &raft.DefaultJoinCommand{
Name: s.raftServer.Name(),
ConnectionString: s.connectionString(),
}
var b bytes.Buffer
if err := json.NewEncoder(&b).Encode(command); err != nil {
return nil
}
resp, err := http.Post(fmt.Sprintf("http://%s/join", leader), "application/json", &b)
if err != nil {
return err
}
defer func() {
_ = resp.Body.Close()
}()
// Look for redirect.
if resp.StatusCode == http.StatusTemporaryRedirect {
leader := resp.Header.Get("Location")
if leader == "" {
return errors.New("Redirect requested, but no location header supplied")
}
u, err := url.Parse(leader)
if err != nil {
return errors.New("Failed to parse redirect location")
}
log.Infof("Redirecting to leader at %s", u.Host)
return s.Join(u.Host)
}
return nil
}
func (s *Server) joinHandler(w http.ResponseWriter, req *http.Request) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.raftServer.State() != "leader" {
s.leaderRedirect(w, req)
return
}
command := &raft.DefaultJoinCommand{}
if err := json.NewDecoder(req.Body).Decode(&command); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
s.metrics.joinFail.Inc(1)
return
}
if _, err := s.raftServer.Do(command); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
s.metrics.joinFail.Inc(1)
return
}
s.metrics.joinSuccess.Inc(1)
}
func (s *Server) readHandler(w http.ResponseWriter, req *http.Request) {
log.Tracef("readHandler for URL: %s", req.URL)
s.metrics.queryReceived.Inc(1)
var failures = make([]FailedSqlStmt, 0)
b, err := ioutil.ReadAll(req.Body)
if err != nil {
log.Tracef("Bad HTTP request: %s", err.Error())
w.WriteHeader(http.StatusBadRequest)
s.metrics.queryFail.Inc(1)
return
}
stmt := string(b)
startTime := time.Now()
r, err := s.db.Query(stmt)
if err != nil {
log.Tracef("Bad SQL statement: %s", err.Error())
s.metrics.queryFail.Inc(1)
failures = append(failures, FailedSqlStmt{stmt, err.Error()})
} else {
s.metrics.querySuccess.Inc(1)
}
duration := time.Since(startTime)
rr := QueryResponse{Time: duration.String(), Failures: failures, Rows: r}
pretty, _ := isPretty(req)
if pretty {
b, err = json.MarshalIndent(rr, "", " ")
} else {
b, err = json.Marshal(rr)
}
if err != nil {
log.Tracef("Failed to marshal JSON data: %s", err.Error())
http.Error(w, err.Error(), http.StatusBadRequest) // Internal error actually
} else {
_, err = w.Write([]byte(b))
if err != nil {
log.Errorf("Error writting JSON data: %s", err.Error())
}
}
}
func (s *Server) execute(tx bool, stmts []string) ([]FailedSqlStmt, error) {
var failures = make([]FailedSqlStmt, 0)
if tx {
log.Trace("Transaction requested")
s.metrics.executeTxReceived.Inc(1)
_, err := s.raftServer.Do(command.NewTransactionExecuteCommandSet(stmts))
if err != nil {
log.Tracef("Transaction failed: %s", err.Error())
s.metrics.executeFail.Inc(1)
failures = append(failures, FailedSqlStmt{stmts[0], err.Error()})
} else {
s.metrics.executeSuccess.Inc(1)
}
} else {
log.Trace("No transaction requested")
for i := range stmts {
_, err := s.raftServer.Do(command.NewExecuteCommand(stmts[i]))
if err != nil {
log.Tracef("Execute statement %s failed: %s", stmts[i], err.Error())
s.metrics.executeFail.Inc(1)
failures = append(failures, FailedSqlStmt{stmts[i], err.Error()})
} else {
s.metrics.executeSuccess.Inc(1)
}
}
}
return failures, nil
}
func (s *Server) writeHandler(w http.ResponseWriter, req *http.Request) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.raftServer.State() != "leader" {
s.leaderRedirect(w, req)
return
}
log.Tracef("writeHandler for URL: %s", req.URL)
s.metrics.executeReceived.Inc(1)
currentIndex := s.raftServer.CommitIndex()
count := currentIndex - s.snapConf.lastIndex
if uint64(count) > s.snapConf.snapshotAfter {
log.Info("Committed log entries snapshot threshold reached, starting snapshot")
err := s.raftServer.TakeSnapshot()
s.logSnapshot(err, currentIndex, count)
s.snapConf.lastIndex = currentIndex
s.metrics.snapshotCreated.Inc(1)
}
var startTime time.Time
// Read the value from the POST body.
b, err := ioutil.ReadAll(req.Body)
if err != nil {
log.Tracef("Bad HTTP request: %s", err.Error())
s.metrics.executeFail.Inc(1)
w.WriteHeader(http.StatusBadRequest)
return
}
stmts := strings.Split(string(b), "\n")
if stmts[len(stmts)-1] == "" {
stmts = stmts[:len(stmts)-1]
}
log.Tracef("Execute statement contains %d commands", len(stmts))
if len(stmts) == 0 {
log.Trace("No database execute commands supplied")
s.metrics.executeFail.Inc(1)
w.WriteHeader(http.StatusBadRequest)
return
}
transaction, _ := isTransaction(req)
startTime = time.Now()
failures, err := s.execute(transaction, stmts)
if err != nil {
log.Errorf("Database mutation failed: %s", err.Error())
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
duration := time.Since(startTime)
wr := StmtResponse{Time: duration.String(), Failures: failures}
pretty, _ := isPretty(req)
if pretty {
b, err = json.MarshalIndent(wr, "", " ")
} else {
b, err = json.Marshal(wr)
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
} else {
_, err = w.Write([]byte(b))
if err != nil {
log.Errorf("Error writting JSON data: %s", err.Error())
}
}
}
// serveStatistics returns the statistics for the program
func (s *Server) serveStatistics(w http.ResponseWriter, req *http.Request) {
statistics := make(map[string]interface{})
resources := map[string]interfaces.Statistics{"server": s}
for k, v := range resources {
s, err := v.GetStatistics()
if err != nil {
log.Error("failed to get " + k + " stats")
http.Error(w, "failed to get "+k+" stats", http.StatusInternalServerError)
return
}
statistics[k] = s
}
_, err := w.Write(ensurePrettyPrint(req, statistics))
if err != nil {
log.Error("failed to serve stats")
http.Error(w, "failed to serve stats", http.StatusInternalServerError)
return
}
}
// serveDiagnostics returns basic server diagnostics
func (s *Server) serveDiagnostics(w http.ResponseWriter, req *http.Request) {
diagnostics := make(map[string]interface{})
diagnostics["started"] = s.diagnostics.startTime.String()
diagnostics["uptime"] = time.Since(s.diagnostics.startTime).String()
diagnostics["host"] = s.host
diagnostics["port"] = s.port
diagnostics["data"] = s.path
diagnostics["database"] = s.dbPath
diagnostics["connection"] = s.connectionString()
diagnostics["snapafter"] = s.snapConf.snapshotAfter
diagnostics["snapindex"] = s.snapConf.lastIndex
_, err := w.Write(ensurePrettyPrint(req, diagnostics))
if err != nil {
log.Error("failed to serve diagnostics")
http.Error(w, "failed to serve diagnostics", http.StatusInternalServerError)
return
}
}
// serveRaftInfo returns information about the underlying Raft server
func (s *Server) serveRaftInfo(w http.ResponseWriter, req *http.Request) {
peers := s.raftServer.Peers()
info := make(map[string]interface{})
info["name"] = s.raftServer.Name()
info["state"] = s.raftServer.State()
info["leader"] = s.raftServer.Leader()
info["peers"] = peers
_, err := w.Write(ensurePrettyPrint(req, info))
if err != nil {
log.Error("failed to serve raft info")
http.Error(w, "failed to serve raft info", http.StatusInternalServerError)
return
}
}
// leaderRedirect returns a 307 Temporary Redirect, with the full path
// to the leader.
func (s *Server) leaderRedirect(w http.ResponseWriter, r *http.Request) {
peers := s.raftServer.Peers()
leader := peers[s.raftServer.Leader()]
var u string
for _, p := range peers {
if p.Name == leader.Name {
u = p.ConnectionString
break
}
}
http.Redirect(w, r, u+r.URL.Path, http.StatusTemporaryRedirect)
}