1
0
Fork 0
You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

539 lines
14 KiB
Go

package cluster
import (
"bytes"
"compress/gzip"
"encoding/binary"
"expvar"
"fmt"
"io"
"log"
"net"
"os"
"strconv"
"sync"
"time"
"github.com/rqlite/rqlite/v8/auth"
"github.com/rqlite/rqlite/v8/cluster/proto"
command "github.com/rqlite/rqlite/v8/command/proto"
pb "google.golang.org/protobuf/proto"
)
// stats captures stats for the Cluster service. It starts out nil and is
// assigned in init, which creates the expvar map named "cluster".
var stats *expvar.Map
// Keys of the counters kept in the package-level stats map. Every key
// listed here is seeded with zero by init.
const (
numGetNodeAPIRequest = "num_get_node_api_req"
numGetNodeAPIResponse = "num_get_node_api_resp"
numExecuteRequest = "num_execute_req"
numQueryRequest = "num_query_req"
numRequestRequest = "num_request_req"
numBackupRequest = "num_backup_req"
numLoadRequest = "num_load_req"
numRemoveNodeRequest = "num_remove_node_req"
numNotifyRequest = "num_notify_req"
numJoinRequest = "num_join_req"
numClientRetries = "num_client_retries"
numGetNodeAPIRequestRetries = "num_get_node_api_req_retries"
numClientLoadRetries = "num_client_load_retries"
numClientExecuteRetries = "num_client_execute_retries"
numClientQueryRetries = "num_client_query_retries"
numClientRequestRetries = "num_client_request_retries"
// Client stats for this package.
numGetNodeAPIRequestLocal = "num_get_node_api_req_local"
)
// Header byte values used to tell the kinds of internode traffic apart.
const (
// MuxRaftHeader is the byte used to indicate internode Raft communications.
MuxRaftHeader = 1
// MuxClusterHeader is the byte used to request internode cluster state information.
MuxClusterHeader = 2 // Cluster state communications
)
// init creates the "cluster" expvar map and seeds every counter key with
// zero, so each key exists in the map before any request is handled.
func init() {
	stats = expvar.NewMap("cluster")

	counters := []string{
		numGetNodeAPIRequest,
		numGetNodeAPIResponse,
		numExecuteRequest,
		numQueryRequest,
		numRequestRequest,
		numBackupRequest,
		numLoadRequest,
		numRemoveNodeRequest,
		numGetNodeAPIRequestLocal,
		numNotifyRequest,
		numJoinRequest,
		numClientRetries,
		numGetNodeAPIRequestRetries,
		numClientLoadRetries,
		numClientExecuteRetries,
		numClientQueryRetries,
		numClientRequestRetries,
	}
	for _, key := range counters {
		stats.Add(key, 0)
	}
}
// Dialer is the interface dialers must implement.
type Dialer interface {
// Dial is used to create a connection to a service listening
// on an address. It returns the established connection, or an error.
// NOTE(review): timeout presumably bounds the connection attempt —
// confirm against the implementations.
Dial(address string, timeout time.Duration) (net.Conn, error)
}
// Database is the interface any queryable system must implement.
type Database interface {
// Execute executes a slice of queries, none of which is expected
// to return rows.
Execute(er *command.ExecuteRequest) ([]*command.ExecuteResult, error)
// Query executes a slice of queries, each of which returns rows.
Query(qr *command.QueryRequest) ([]*command.QueryRows, error)
// Request processes a request that can contain both executes and queries.
Request(rr *command.ExecuteQueryRequest) ([]*command.ExecuteQueryResponse, error)
// Backup writes a backup of the database to the writer.
Backup(br *command.BackupRequest, dst io.Writer) error
// Load loads an entire SQLite file into the database.
Load(lr *command.LoadRequest) error
}
// Manager is the interface node-management systems must implement.
type Manager interface {
// LeaderAddr returns the Raft address of the leader of the cluster.
LeaderAddr() (string, error)
// CommitIndex returns the Raft commit index of the cluster.
CommitIndex() (uint64, error)
// Remove removes the node, given by id, from the cluster.
Remove(rn *command.RemoveNodeRequest) error
// Notify notifies this node that a remote node is ready
// for bootstrapping.
Notify(n *command.NotifyRequest) error
// Join joins a remote node to the cluster.
Join(n *command.JoinRequest) error
}
// CredentialStore is the interface credential stores must support.
type CredentialStore interface {
// AA authenticates and checks authorization for the given perm.
// It reports whether the username/password pair is allowed perm.
AA(username, password, perm string) bool
}
// Service provides information about the node and cluster. It accepts
// connections on its listener and answers the commands sent over them
// using the configured Database and Manager.
type Service struct {
ln net.Listener // Incoming connections to the service
addr net.Addr // Address on which this service is listening
db Database // The queryable system.
mgr Manager // The cluster management system.
// credentialStore authorizes incoming commands. When it is nil the
// permission checks in this package allow every command.
credentialStore CredentialStore
// mu guards https and apiAddr.
mu sync.RWMutex
https bool // Serving HTTPS?
apiAddr string // host:port this node serves the HTTP API.
logger *log.Logger
}
// New returns a new instance of the cluster service. The service listens
// on ln, answers database commands via db and cluster-management commands
// via m, and authorizes requests with credentialStore. Log output goes to
// stderr with the "[cluster] " prefix.
func New(ln net.Listener, db Database, m Manager, credentialStore CredentialStore) *Service {
	svc := new(Service)
	svc.ln = ln
	svc.addr = ln.Addr()
	svc.db = db
	svc.mgr = m
	svc.credentialStore = credentialStore
	svc.logger = log.New(os.Stderr, "[cluster] ", log.LstdFlags)
	return svc
}
// Open opens the Service: it starts the accept loop on a background
// goroutine and logs the listening address. It always returns nil.
func (s *Service) Open() error {
	// The accept loop's eventual error is intentionally discarded.
	go func() {
		_ = s.serve()
	}()
	s.logger.Println("service listening on", s.addr)
	return nil
}
// Close closes the service by closing its listener. Any error from the
// listener is discarded and nil is always returned.
func (s *Service) Close() error {
	_ = s.ln.Close()
	return nil
}
// Addr returns, in string form, the address the service is listening on.
func (s *Service) Addr() string {
	listenAddr := s.addr
	return listenAddr.String()
}
// EnableHTTPS tells the cluster service whether the API serves HTTPS.
// The flag is updated under the service's write lock.
func (s *Service) EnableHTTPS(b bool) {
	s.mu.Lock()
	s.https = b
	s.mu.Unlock()
}
// SetAPIAddr sets the API address the cluster service returns.
// The address is updated under the service's write lock.
func (s *Service) SetAPIAddr(addr string) {
	s.mu.Lock()
	s.apiAddr = addr
	s.mu.Unlock()
}
// GetAPIAddr returns the previously-set API address, read under the
// service's read lock.
func (s *Service) GetAPIAddr() string {
	s.mu.RLock()
	current := s.apiAddr
	s.mu.RUnlock()
	return current
}
// GetNodeAPIURL returns the fully-specified HTTP(S) API URL for the
// node running this service, in the form "<scheme>://<api address>".
// The scheme is "https" when HTTPS has been enabled and "http" otherwise.
func (s *Service) GetNodeAPIURL() string {
	// Snapshot both fields under the read lock, then build the URL.
	s.mu.RLock()
	secure, apiAddr := s.https, s.apiAddr
	s.mu.RUnlock()

	if secure {
		return "https" + "://" + apiAddr
	}
	return "http" + "://" + apiAddr
}
// Stats returns status of the Service as a map with three entries:
// "addr" (the listening address), "https" (the HTTPS flag formatted as
// "true" or "false") and "api_addr" (the configured API address).
// The returned error is always nil.
func (s *Service) Stats() (map[string]interface{}, error) {
	// https and apiAddr are written under s.mu by EnableHTTPS and
	// SetAPIAddr, so they must be read under the lock too; reading them
	// unlocked is a data race.
	s.mu.RLock()
	https, apiAddr := s.https, s.apiAddr
	s.mu.RUnlock()

	st := map[string]interface{}{
		"addr":     s.addr.String(),
		"https":    strconv.FormatBool(https),
		"api_addr": apiAddr,
	}
	return st, nil
}
// serve accepts connections from the listener until Accept fails, handing
// each accepted connection to handleConn on its own goroutine. It returns
// the error that ended the accept loop.
func (s *Service) serve() error {
	for {
		conn, err := s.ln.Accept()
		if err == nil {
			go s.handleConn(conn)
			continue
		}
		return err
	}
}
// checkCommandPerm reports whether the credentials carried by c are allowed
// perm. With no credential store configured every command is allowed. A
// command without credentials is checked with an empty username and
// password.
func (s *Service) checkCommandPerm(c *proto.Command, perm string) bool {
	store := s.credentialStore
	if store == nil {
		return true
	}

	var username, password string
	if creds := c.Credentials; creds != nil {
		username = creds.GetUsername()
		password = creds.GetPassword()
	}
	return store.AA(username, password, perm)
}
// checkCommandPermAll reports whether the credentials carried by c are
// allowed every permission in perms. With no credential store configured
// every command is allowed. A command without credentials is checked with
// an empty username and password. Checking stops at the first permission
// that is refused.
func (s *Service) checkCommandPermAll(c *proto.Command, perms ...string) bool {
	store := s.credentialStore
	if store == nil {
		return true
	}

	var username, password string
	if creds := c.Credentials; creds != nil {
		username = creds.GetUsername()
		password = creds.GetPassword()
	}

	for i := range perms {
		if allowed := store.AA(username, password, perms[i]); !allowed {
			return false
		}
	}
	return true
}
// handleConn serves a single connection until the peer goes away or an
// unrecoverable error occurs. The connection is always closed on return.
//
// Each request is a little-endian length prefix of protoBufferLengthSize
// bytes followed by that many bytes of a marshaled proto.Command. The
// command's Type selects the handler. Most handlers reply with a
// length-prefixed marshaled response whose Error field carries any
// failure (including "unauthorized") and leave the connection open for
// the next command. Command types with no case below are ignored.
func (s *Service) handleConn(conn net.Conn) {
	// Every early return below relies on this deferred Close.
	defer conn.Close()

	b := make([]byte, protoBufferLengthSize)
	for {
		_, err := io.ReadFull(conn, b)
		if err != nil {
			return
		}
		sz := binary.LittleEndian.Uint64(b[0:])

		// NOTE(review): sz is taken from the wire unchecked, so a bogus
		// length makes this allocation fail or panic. Consider enforcing
		// a protocol-level maximum once one is agreed.
		p := make([]byte, sz)
		_, err = io.ReadFull(conn, p)
		if err != nil {
			return
		}

		c := &proto.Command{}
		err = pb.Unmarshal(p, c)
		if err != nil {
			// Do not dispatch on a command that failed to decode; it may
			// be partially populated.
			return
		}

		switch c.Type {
		case proto.Command_COMMAND_TYPE_GET_NODE_API_URL:
			stats.Add(numGetNodeAPIRequest, 1)
			ci, err := s.mgr.CommitIndex()
			if err != nil {
				return
			}
			p, err = pb.Marshal(&proto.NodeMeta{
				Url:         s.GetNodeAPIURL(),
				CommitIndex: ci,
			})
			if err != nil {
				// Nothing valid to send; do not write or count a response.
				return
			}
			writeBytesWithLength(conn, p)
			stats.Add(numGetNodeAPIResponse, 1)

		case proto.Command_COMMAND_TYPE_EXECUTE:
			stats.Add(numExecuteRequest, 1)

			resp := &proto.CommandExecuteResponse{}

			er := c.GetExecuteRequest()
			if er == nil {
				resp.Error = "ExecuteRequest is nil"
			} else if !s.checkCommandPerm(c, auth.PermExecute) {
				resp.Error = "unauthorized"
			} else {
				res, err := s.db.Execute(er)
				if err != nil {
					resp.Error = err.Error()
				} else {
					resp.Results = make([]*command.ExecuteResult, len(res))
					copy(resp.Results, res)
				}
			}
			marshalAndWrite(conn, resp)

		case proto.Command_COMMAND_TYPE_QUERY:
			stats.Add(numQueryRequest, 1)

			resp := &proto.CommandQueryResponse{}

			qr := c.GetQueryRequest()
			if qr == nil {
				resp.Error = "QueryRequest is nil"
			} else if !s.checkCommandPerm(c, auth.PermQuery) {
				resp.Error = "unauthorized"
			} else {
				res, err := s.db.Query(qr)
				if err != nil {
					resp.Error = err.Error()
				} else {
					resp.Rows = make([]*command.QueryRows, len(res))
					copy(resp.Rows, res)
				}
			}
			marshalAndWrite(conn, resp)

		case proto.Command_COMMAND_TYPE_REQUEST:
			stats.Add(numRequestRequest, 1)

			resp := &proto.CommandRequestResponse{}

			rr := c.GetExecuteQueryRequest()
			if rr == nil {
				resp.Error = "RequestRequest is nil"
			} else if !s.checkCommandPermAll(c, auth.PermQuery, auth.PermExecute) {
				// A mixed request needs both query and execute permission.
				resp.Error = "unauthorized"
			} else {
				res, err := s.db.Request(rr)
				if err != nil {
					resp.Error = err.Error()
				} else {
					resp.Response = make([]*command.ExecuteQueryResponse, len(res))
					copy(resp.Response, res)
				}
			}
			marshalAndWrite(conn, resp)

		case proto.Command_COMMAND_TYPE_BACKUP:
			stats.Add(numBackupRequest, 1)

			resp := &proto.CommandBackupResponse{}

			br := c.GetBackupRequest()
			if br == nil {
				resp.Error = "BackupRequest is nil"
			} else if !s.checkCommandPerm(c, auth.PermBackup) {
				resp.Error = "unauthorized"
			} else {
				buf := new(bytes.Buffer)
				if err := s.db.Backup(br, buf); err != nil {
					resp.Error = err.Error()
				} else {
					resp.Data = buf.Bytes()
				}
			}
			p, err = pb.Marshal(resp)
			if err != nil {
				return
			}
			// Compress the backup for less space on the wire between nodes.
			p, err = gzCompress(p)
			if err != nil {
				return
			}
			writeBytesWithLength(conn, p)

		case proto.Command_COMMAND_TYPE_BACKUP_STREAM:
			stats.Add(numBackupRequest, 1)

			resp := &proto.CommandBackupResponse{}

			br := c.GetBackupRequest()
			if br == nil {
				resp.Error = "BackupRequest is nil"
			} else if !s.checkCommandPerm(c, auth.PermBackup) {
				resp.Error = "unauthorized"
			}
			p, err = pb.Marshal(resp)
			if err != nil {
				return
			}
			writeBytesWithLength(conn, p)

			// The request was rejected: the error has been reported, so do
			// not stream. This avoids dereferencing a nil request and
			// sending the backup to an unauthorized caller.
			if resp.Error != "" {
				continue
			}

			// Now, start streaming the backup. Enable compressed mode
			// regardless of whether the client requested it, so the client
			// can easily detect the end of the stream, as well as saving
			// space on the wire.
			br.Compress = true
			if err := s.db.Backup(br, conn); err != nil {
				s.logger.Printf("failed to stream backup: %s", err.Error())
				return
			}

		case proto.Command_COMMAND_TYPE_LOAD:
			stats.Add(numLoadRequest, 1)

			resp := &proto.CommandLoadResponse{}

			lr := c.GetLoadRequest()
			if lr == nil {
				resp.Error = "LoadRequest is nil"
			} else if !s.checkCommandPerm(c, auth.PermLoad) {
				resp.Error = "unauthorized"
			} else {
				if err := s.db.Load(lr); err != nil {
					resp.Error = fmt.Sprintf("remote node failed to load: %s", err.Error())
				}
			}
			marshalAndWrite(conn, resp)

		case proto.Command_COMMAND_TYPE_LOAD_CHUNK:
			// Chunked loading is always answered with an "unsupported" error.
			resp := &proto.CommandLoadChunkResponse{
				Error: "unsupported",
			}
			marshalAndWrite(conn, resp)

		case proto.Command_COMMAND_TYPE_REMOVE_NODE:
			stats.Add(numRemoveNodeRequest, 1)

			resp := &proto.CommandRemoveNodeResponse{}

			rn := c.GetRemoveNodeRequest()
			if rn == nil {
				resp.Error = "RemoveNodeRequest is nil"
			} else if !s.checkCommandPerm(c, auth.PermRemove) {
				resp.Error = "unauthorized"
			} else {
				if err := s.mgr.Remove(rn); err != nil {
					resp.Error = err.Error()
				}
			}
			marshalAndWrite(conn, resp)

		case proto.Command_COMMAND_TYPE_NOTIFY:
			stats.Add(numNotifyRequest, 1)

			resp := &proto.CommandNotifyResponse{}

			nr := c.GetNotifyRequest()
			if nr == nil {
				resp.Error = "NotifyRequest is nil"
			} else if !s.checkCommandPerm(c, auth.PermJoin) {
				resp.Error = "unauthorized"
			} else {
				if err := s.mgr.Notify(nr); err != nil {
					resp.Error = err.Error()
				}
			}
			marshalAndWrite(conn, resp)

		case proto.Command_COMMAND_TYPE_JOIN:
			stats.Add(numJoinRequest, 1)

			resp := &proto.CommandJoinResponse{}

			jr := c.GetJoinRequest()
			if jr == nil {
				resp.Error = "JoinRequest is nil"
			} else {
				// Voters need the join permission; non-voters need the
				// read-only join permission.
				if (jr.Voter && s.checkCommandPerm(c, auth.PermJoin)) ||
					(!jr.Voter && s.checkCommandPerm(c, auth.PermJoinReadOnly)) {
					if err := s.mgr.Join(jr); err != nil {
						resp.Error = err.Error()
						// When this node is not the leader, also report the
						// leader's address. If that lookup fails, its error
						// replaces the join error.
						if err.Error() == "not leader" {
							laddr, err := s.mgr.LeaderAddr()
							if err != nil {
								resp.Error = err.Error()
							} else {
								resp.Leader = laddr
							}
						}
					}
				} else {
					resp.Error = "unauthorized"
				}
			}
			marshalAndWrite(conn, resp)
		}
	}
}
// marshalAndWrite marshals m and writes it to conn as a length-prefixed
// frame. If m cannot be marshaled the connection is closed and nothing is
// written, so the peer never receives a bogus empty frame.
func marshalAndWrite(conn net.Conn, m pb.Message) {
	p, err := pb.Marshal(m)
	if err != nil {
		conn.Close()
		return
	}
	writeBytesWithLength(conn, p)
}
// writeBytesWithLength writes p to conn as a frame: a little-endian length
// prefix of protoBufferLengthSize bytes holding len(p), followed by p.
// Write errors are not reported to the caller.
func writeBytesWithLength(conn net.Conn, p []byte) {
	header := make([]byte, protoBufferLengthSize)
	payloadLen := uint64(len(p))
	binary.LittleEndian.PutUint64(header, payloadLen)

	conn.Write(header)
	conn.Write(p)
}
// gzCompress returns b compressed as a complete gzip stream at the best
// compression level. It returns a nil slice and a descriptive error if the
// gzip writer cannot be created, written to, or closed.
func gzCompress(b []byte) ([]byte, error) {
	out := new(bytes.Buffer)

	zw, err := gzip.NewWriterLevel(out, gzip.BestCompression)
	if err != nil {
		return nil, fmt.Errorf("gzip new writer: %s", err)
	}

	_, err = zw.Write(b)
	if err != nil {
		return nil, fmt.Errorf("gzip Write: %s", err)
	}

	// Close flushes the remaining data and writes the gzip trailer.
	err = zw.Close()
	if err != nil {
		return nil, fmt.Errorf("gzip Close: %s", err)
	}
	return out.Bytes(), nil
}