1
0
Fork 0
You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

488 lines
13 KiB
Go

package server
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io/ioutil"
"math/rand"
"net/http"
"path"
"path/filepath"
"strconv"
"strings"
"sync"
"time"
"github.com/gorilla/mux"
"github.com/otoolep/raft"
"github.com/otoolep/rqlite/command"
"github.com/otoolep/rqlite/db"
"github.com/otoolep/rqlite/interfaces"
"github.com/rcrowley/go-metrics"
log "code.google.com/p/log4go"
)
type FailedSqlStmt struct {
Sql string `json:"sql"`
Error string `json:"error"`
}
type StmtResponse struct {
Time string `json:"time"`
Failures []FailedSqlStmt `json:"failures"`
}
type QueryResponse struct {
Time string `json:"time"`
Failures []FailedSqlStmt `json:"failures"`
Rows db.RowResults `json:"rows"`
}
type ServerMetrics struct {
registry metrics.Registry
joinSuccess metrics.Counter
joinFail metrics.Counter
queryReceived metrics.Counter
querySuccess metrics.Counter
queryFail metrics.Counter
executeReceived metrics.Counter
executeTxReceived metrics.Counter
executeSuccess metrics.Counter
executeFail metrics.Counter
}
type ServerDiagnostics struct {
startTime time.Time
}
type SnapshotConf struct {
// The index when the last snapshot happened
lastIndex uint64
// If the incremental number of index entries since the last
// snapshot exceeds snapshotAfter rqlite will do a snapshot
snapshotAfter uint64
}
// The raftd server is a combination of the Raft server and an HTTP
// server which acts as the transport.
type Server struct {
name string
host string
port int
path string
router *mux.Router
raftServer raft.Server
httpServer *http.Server
dbPath string
db *db.DB
snapConf *SnapshotConf
metrics *ServerMetrics
diagnostics *ServerDiagnostics
mutex sync.Mutex
}
// queryParam returns whether the given query param is set to true.
func queryParam(req *http.Request, param string) (bool, error) {
err := req.ParseForm()
if err != nil {
return false, err
}
if _, ok := req.Form[param]; ok {
return true, nil
}
return false, nil
}
// isPretty returns whether the HTTP response body should be pretty-printed.
func isPretty(req *http.Request) (bool, error) {
return queryParam(req, "pretty")
}
// isTransaction returns whether the client requested an explicit
// transaction for the request.
func isTransaction(req *http.Request) (bool, error) {
return queryParam(req, "transaction")
}
// Creates a new ServerMetrics object.
func NewServerMetrics() *ServerMetrics {
m := &ServerMetrics{
registry: metrics.NewRegistry(),
joinSuccess: metrics.NewCounter(),
joinFail: metrics.NewCounter(),
queryReceived: metrics.NewCounter(),
querySuccess: metrics.NewCounter(),
queryFail: metrics.NewCounter(),
executeReceived: metrics.NewCounter(),
executeTxReceived: metrics.NewCounter(),
executeSuccess: metrics.NewCounter(),
executeFail: metrics.NewCounter(),
}
m.registry.Register("join.succes", m.joinSuccess)
m.registry.Register("join.fail", m.joinFail)
m.registry.Register("query.Received", m.queryReceived)
m.registry.Register("query.success", m.querySuccess)
m.registry.Register("query.fail", m.queryFail)
m.registry.Register("execute.Received", m.executeReceived)
m.registry.Register("execute.tx.received", m.executeTxReceived)
m.registry.Register("execute.success", m.executeSuccess)
m.registry.Register("execute.fail", m.executeFail)
return m
}
func NewServerDiagnostics() *ServerDiagnostics {
d := &ServerDiagnostics{
startTime: time.Now(),
}
return d
}
// Creates a new server.
func New(dataDir string, dbfile string, snapAfter int, host string, port int) *Server {
dbPath := path.Join(dataDir, dbfile)
s := &Server{
host: host,
port: port,
path: dataDir,
dbPath: dbPath,
db: db.New(dbPath),
snapConf: &SnapshotConf{snapshotAfter: uint64(snapAfter)},
metrics: NewServerMetrics(),
diagnostics: NewServerDiagnostics(),
router: mux.NewRouter(),
}
// Read existing name or generate a new one.
if b, err := ioutil.ReadFile(filepath.Join(dataDir, "name")); err == nil {
s.name = string(b)
} else {
s.name = fmt.Sprintf("%07x", rand.Int())[0:7]
if err = ioutil.WriteFile(filepath.Join(dataDir, "name"), []byte(s.name), 0644); err != nil {
panic(err)
}
}
return s
}
// GetStatistics returns an object storing statistics, which supports JSON
// marshalling.
func (s *Server) GetStatistics() (metrics.Registry, error) {
return s.metrics.registry, nil
}
// Returns the connection string.
func (s *Server) connectionString() string {
return fmt.Sprintf("http://%s:%d", s.host, s.port)
}
// logSnapshot logs about the snapshot that was taken.
func (s *Server) logSnapshot(err error, currentIndex, count uint64) {
info := fmt.Sprintf("%s: snapshot of %d events at index %d", s.connectionString, count, currentIndex)
if err != nil {
log.Info("%s attempted and failed: %v", info, err)
} else {
log.Info("%s completed", info)
}
}
// Starts the server.
func (s *Server) ListenAndServe(leader string) error {
var err error
log.Info("Initializing Raft Server: %s", s.path)
// Initialize and start Raft server.
transporter := raft.NewHTTPTransporter("/raft", 200*time.Millisecond)
stateMachine := NewDbStateMachine(s.dbPath)
s.raftServer, err = raft.NewServer(s.name, s.path, transporter, stateMachine, s.db, "")
if err != nil {
log.Error("Failed to create new Raft server", err.Error())
return err
}
transporter.Install(s.raftServer, s)
s.raftServer.Start()
if leader != "" {
// Join to leader if specified.
log.Info("Attempting to join leader at %s", leader)
if !s.raftServer.IsLogEmpty() {
log.Error("Cannot join with an existing log")
return errors.New("Cannot join with an existing log")
}
if err := s.Join(leader); err != nil {
log.Error("Failed to join leader", err.Error())
return err
}
} else if s.raftServer.IsLogEmpty() {
// Initialize the server by joining itself.
log.Info("Initializing new cluster")
_, err := s.raftServer.Do(&raft.DefaultJoinCommand{
Name: s.raftServer.Name(),
ConnectionString: s.connectionString(),
})
if err != nil {
log.Error("Failed to join to self", err.Error())
}
} else {
log.Info("Recovered from log")
}
log.Info("Initializing HTTP server")
// Initialize and start HTTP server.
s.httpServer = &http.Server{
Addr: fmt.Sprintf(":%d", s.port),
Handler: s.router,
}
s.router.HandleFunc("/statistics", s.serveStatistics).Methods("GET")
s.router.HandleFunc("/diagnostics", s.serveDiagnostics).Methods("GET")
s.router.HandleFunc("/db", s.readHandler).Methods("GET")
s.router.HandleFunc("/db", s.writeHandler).Methods("POST")
s.router.HandleFunc("/join", s.joinHandler).Methods("POST")
log.Info("Listening at %s", s.connectionString())
return s.httpServer.ListenAndServe()
}
// This is a hack around Gorilla mux not providing the correct net/http
// HandleFunc() interface.
func (s *Server) HandleFunc(pattern string, handler func(http.ResponseWriter, *http.Request)) {
s.router.HandleFunc(pattern, handler)
}
// Joins to the leader of an existing cluster.
func (s *Server) Join(leader string) error {
command := &raft.DefaultJoinCommand{
Name: s.raftServer.Name(),
ConnectionString: s.connectionString(),
}
var b bytes.Buffer
json.NewEncoder(&b).Encode(command)
resp, err := http.Post(fmt.Sprintf("http://%s/join", leader), "application/json", &b)
if err != nil {
return err
}
resp.Body.Close()
return nil
}
func (s *Server) joinHandler(w http.ResponseWriter, req *http.Request) {
command := &raft.DefaultJoinCommand{}
if err := json.NewDecoder(req.Body).Decode(&command); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
s.metrics.joinFail.Inc(1)
return
}
if _, err := s.raftServer.Do(command); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
s.metrics.joinFail.Inc(1)
return
}
s.metrics.joinSuccess.Inc(1)
}
func (s *Server) readHandler(w http.ResponseWriter, req *http.Request) {
log.Trace("readHandler for URL: %s", req.URL)
s.metrics.queryReceived.Inc(1)
var failures = make([]FailedSqlStmt, 0)
b, err := ioutil.ReadAll(req.Body)
if err != nil {
log.Trace("Bad HTTP request", err.Error())
w.WriteHeader(http.StatusBadRequest)
s.metrics.queryFail.Inc(1)
return
}
stmt := string(b)
startTime := time.Now()
r, err := s.db.Query(stmt)
if err != nil {
log.Trace("Bad SQL statement", err.Error())
s.metrics.queryFail.Inc(1)
failures = append(failures, FailedSqlStmt{stmt, err.Error()})
} else {
s.metrics.querySuccess.Inc(1)
}
duration := time.Since(startTime)
rr := QueryResponse{Time: duration.String(), Failures: failures, Rows: r}
pretty, _ := isPretty(req)
if pretty {
b, err = json.MarshalIndent(rr, "", " ")
} else {
b, err = json.Marshal(rr)
}
if err != nil {
log.Trace("Failed to marshal JSON data", err.Error())
http.Error(w, err.Error(), http.StatusBadRequest) // Internal error actually
} else {
w.Write([]byte(b))
}
}
func (s *Server) execute(tx bool, stmts []string) ([]FailedSqlStmt, error) {
var failures = make([]FailedSqlStmt, 0)
if tx {
log.Trace("Transaction requested")
s.metrics.executeTxReceived.Inc(1)
_, err := s.raftServer.Do(command.NewTransactionExecuteCommandSet(stmts))
if err != nil {
log.Trace("Transaction failed: %s", err.Error())
s.metrics.executeFail.Inc(1)
failures = append(failures, FailedSqlStmt{stmts[0], err.Error()})
} else {
s.metrics.executeSuccess.Inc(1)
}
} else {
log.Trace("No transaction requested")
for i := range stmts {
_, err := s.raftServer.Do(command.NewExecuteCommand(stmts[i]))
if err != nil {
log.Trace("Execute statement %s failed: %s", stmts[i], err.Error())
s.metrics.executeFail.Inc(1)
failures = append(failures, FailedSqlStmt{stmts[i], err.Error()})
} else {
s.metrics.executeSuccess.Inc(1)
}
}
}
return failures, nil
}
func (s *Server) writeHandler(w http.ResponseWriter, req *http.Request) {
s.mutex.Lock()
defer s.mutex.Unlock()
log.Trace("writeHandler for URL: %s", req.URL)
s.metrics.executeReceived.Inc(1)
currentIndex := s.raftServer.CommitIndex()
count := currentIndex - s.snapConf.lastIndex
if uint64(count) > s.snapConf.snapshotAfter {
err := s.raftServer.TakeSnapshot()
s.logSnapshot(err, currentIndex, count)
s.snapConf.lastIndex = currentIndex
}
var startTime time.Time
// Read the value from the POST body.
b, err := ioutil.ReadAll(req.Body)
if err != nil {
log.Trace("Bad HTTP request", err.Error())
s.metrics.executeFail.Inc(1)
w.WriteHeader(http.StatusBadRequest)
return
}
stmts := strings.Split(string(b), "\n")
if stmts[len(stmts)-1] == "" {
stmts = stmts[:len(stmts)-1]
}
log.Trace("Execute statement contains %d commands", len(stmts))
if len(stmts) == 0 {
log.Trace("No database execute commands supplied")
s.metrics.executeFail.Inc(1)
w.WriteHeader(http.StatusBadRequest)
return
}
transaction, _ := isTransaction(req)
startTime = time.Now()
failures, err := s.execute(transaction, stmts)
if err != nil {
log.Error("Database mutation failed: %s", err.Error())
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
duration := time.Since(startTime)
wr := StmtResponse{Time: duration.String(), Failures: failures}
pretty, _ := isPretty(req)
if pretty {
b, err = json.MarshalIndent(wr, "", " ")
} else {
b, err = json.Marshal(wr)
}
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
} else {
w.Write([]byte(b))
}
}
// serveStatistics returns the statistics for the program
func (s *Server) serveStatistics(w http.ResponseWriter, req *http.Request) {
statistics := make(map[string]interface{})
resources := map[string]interfaces.Statistics{"server": s}
for k, v := range resources {
s, err := v.GetStatistics()
if err != nil {
log.Error("failed to get " + k + " stats")
http.Error(w, "failed to get "+k+" stats", http.StatusInternalServerError)
return
}
statistics[k] = s
}
var b []byte
var err error
pretty, _ := isPretty(req)
if pretty {
b, err = json.MarshalIndent(statistics, "", " ")
} else {
b, err = json.Marshal(statistics)
}
if err != nil {
log.Error("failed to JSON marshal statistics map")
http.Error(w, "failed to JSON marshal statistics map", http.StatusInternalServerError)
return
}
w.Write(b)
}
// serveDiagnostics returns basic server diagnostics
func (s *Server) serveDiagnostics(w http.ResponseWriter, req *http.Request) {
diagnostics := make(map[string]string)
diagnostics["started"] = s.diagnostics.startTime.String()
diagnostics["uptime"] = time.Since(s.diagnostics.startTime).String()
diagnostics["data"] = s.path
diagnostics["database"] = s.dbPath
diagnostics["connection"] = s.connectionString()
diagnostics["snapafter"] = strconv.FormatUint(s.snapConf.snapshotAfter, 10)
diagnostics["snapindex"] = strconv.FormatUint(s.snapConf.lastIndex, 10)
var b []byte
pretty, _ := isPretty(req)
if pretty {
b, _ = json.MarshalIndent(diagnostics, "", " ")
} else {
b, _ = json.Marshal(diagnostics)
}
w.Write(b)
}