1
0
Fork 0

Broadcast Store meta via standard consensus

With this change the cluster metadata (arbitrary key-value data associated with each node) is now broadcast across the cluster using the standard consensus mechanism. Specifically the use case for this metadata is to allow all nodes know the HTTP API address of all other nodes, for the purpose of redirecting requests to the leader.

This change removed the need for multiplexing two logical connections
over the single Raft TCP connection, which greatly simplifies the
networking code generally.

Original PR https://github.com/rqlite/rqlite/pull/434
master
Philip O'Toole 5 years ago
parent 4a9345cc20
commit f57ace7da2

@ -57,7 +57,7 @@ You can grow a cluster, at anytime, simply by starting up a new node and having
# Removing or replacing a node
If a node fails completely and is not coming back, or if you shut down a node because you wish to deprovision it, its record should also be removed from the cluster. To remove the record of a node from a cluster, execute the following command:
```
curl -XDELETE http://localhost:4001/remove -d '{"addr": "<node raft address>"}'
curl -XDELETE http://localhost:4001/remove -d '{"id": "<node raft ID>"}'
```
assuming `localhost` is the address of the cluster leader. If you do not do this the leader will continually attempt to communicate with that node.

@ -20,9 +20,9 @@ const numAttempts int = 3
const attemptInterval time.Duration = 5 * time.Second
// It walks through joinAddr in order, and sets the node ID and Raft address of
// the joining node as nodeID advAddr respectively. It returns the endpoint
// successfully used to join the cluster.
func Join(joinAddr []string, nodeID, advAddr string, tlsConfig *tls.Config) (string, error) {
// the joining node as id addr respectively. It returns the endpoint successfully
// used to join the cluster.
func Join(joinAddr []string, id, addr string, meta map[string]string, tlsConfig *tls.Config) (string, error) {
var err error
var j string
logger := log.New(os.Stderr, "[cluster-join] ", log.LstdFlags)
@ -32,7 +32,7 @@ func Join(joinAddr []string, nodeID, advAddr string, tlsConfig *tls.Config) (str
for i := 0; i < numAttempts; i++ {
for _, a := range joinAddr {
j, err = join(a, nodeID, advAddr, tlsConfig, logger)
j, err = join(a, id, addr, meta, tlsConfig, logger)
if err == nil {
// Success!
return j, nil
@ -45,13 +45,13 @@ func Join(joinAddr []string, nodeID, advAddr string, tlsConfig *tls.Config) (str
return "", err
}
func join(joinAddr, nodeID, advAddr string, tlsConfig *tls.Config, logger *log.Logger) (string, error) {
if nodeID == "" {
func join(joinAddr, id, addr string, meta map[string]string, tlsConfig *tls.Config, logger *log.Logger) (string, error) {
if id == "" {
return "", fmt.Errorf("node ID not set")
}
// Join using IP address, as that is what Hashicorp Raft works in.
resv, err := net.ResolveTCPAddr("tcp", advAddr)
resv, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return "", err
}
@ -59,7 +59,7 @@ func join(joinAddr, nodeID, advAddr string, tlsConfig *tls.Config, logger *log.L
// Check for protocol scheme, and insert default if necessary.
fullAddr := httpd.NormalizeAddr(fmt.Sprintf("%s/join", joinAddr))
// Enable skipVerify as requested.
// Create and configure the client to connect to the other node.
tr := &http.Transport{
TLSClientConfig: tlsConfig,
}
@ -69,9 +69,10 @@ func join(joinAddr, nodeID, advAddr string, tlsConfig *tls.Config, logger *log.L
}
for {
b, err := json.Marshal(map[string]string{
"id": nodeID,
b, err := json.Marshal(map[string]interface{}{
"id": id,
"addr": resv.String(),
"meta": meta,
})
// Attempt to join.

@ -1,7 +1,9 @@
package cluster
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
@ -16,7 +18,7 @@ func Test_SingleJoinOK(t *testing.T) {
}))
defer ts.Close()
j, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", nil)
j, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", nil, nil)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
@ -25,13 +27,57 @@ func Test_SingleJoinOK(t *testing.T) {
}
}
func Test_SingleJoinMetaOK(t *testing.T) {
var body map[string]interface{}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Fatalf("Client did not use POST")
}
w.WriteHeader(http.StatusOK)
b, err := ioutil.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
if err := json.Unmarshal(b, &body); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
}))
defer ts.Close()
nodeAddr := "127.0.0.1:9090"
md := map[string]string{"foo": "bar"}
j, err := Join([]string{ts.URL}, "id0", nodeAddr, md, nil)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
if j != ts.URL+"/join" {
t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts.URL)
}
if id, _ := body["id"]; id != "id0" {
t.Fatalf("node joined supplying wrong ID, exp %s, got %s", "id0", body["id"])
}
if addr, _ := body["addr"]; addr != nodeAddr {
t.Fatalf("node joined supplying wrong address, exp %s, got %s", nodeAddr, body["addr"])
}
rxMd, _ := body["meta"].(map[string]interface{})
if len(rxMd) != len(md) || rxMd["foo"] != "bar" {
t.Fatalf("node joined supplying wrong meta")
}
}
func Test_SingleJoinFail(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}))
defer ts.Close()
_, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", nil)
_, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", nil, nil)
if err == nil {
t.Fatalf("expected error when joining bad node")
}
@ -45,7 +91,7 @@ func Test_DoubleJoinOK(t *testing.T) {
}))
defer ts2.Close()
j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", nil)
j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", nil, nil)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
@ -63,7 +109,7 @@ func Test_DoubleJoinOKSecondNode(t *testing.T) {
}))
defer ts2.Close()
j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", nil)
j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", nil, nil)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
@ -83,7 +129,7 @@ func Test_DoubleJoinOKSecondNodeRedirect(t *testing.T) {
}))
defer ts2.Close()
j, err := Join([]string{ts2.URL}, "id0", "127.0.0.1:9090", nil)
j, err := Join([]string{ts2.URL}, "id0", "127.0.0.1:9090", nil, nil)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}

@ -1,180 +0,0 @@
package cluster
import (
"encoding/json"
"fmt"
"log"
"net"
"os"
"sync"
"time"
)
const (
connectionTimeout = 10 * time.Second
)
var respOKMarshalled []byte
func init() {
var err error
respOKMarshalled, err = json.Marshal(response{})
if err != nil {
panic(fmt.Sprintf("unable to JSON marshal OK response: %s", err.Error()))
}
}
type response struct {
Code int `json:"code,omitempty"`
Message string `json:"message,omitempty"`
}
// Transport is the interface the network service must provide.
type Transport interface {
net.Listener
// Dial is used to create a new outgoing connection
Dial(address string, timeout time.Duration) (net.Conn, error)
}
// Store represents a store of information, managed via consensus.
type Store interface {
// Leader returns the address of the leader of the consensus system.
LeaderAddr() string
// UpdateAPIPeers updates the API peers on the store.
UpdateAPIPeers(peers map[string]string) error
}
// Service allows access to the cluster and associated meta data,
// via consensus.
type Service struct {
tn Transport
store Store
addr net.Addr
wg sync.WaitGroup
logger *log.Logger
}
// NewService returns a new instance of the cluster service
func NewService(tn Transport, store Store) *Service {
return &Service{
tn: tn,
store: store,
addr: tn.Addr(),
logger: log.New(os.Stderr, "[cluster] ", log.LstdFlags),
}
}
// Open opens the Service.
func (s *Service) Open() error {
s.wg.Add(1)
go s.serve()
s.logger.Println("service listening on", s.tn.Addr())
return nil
}
// Close closes the service.
func (s *Service) Close() error {
s.tn.Close()
s.wg.Wait()
return nil
}
// Addr returns the address the service is listening on.
func (s *Service) Addr() string {
return s.addr.String()
}
// SetPeer will set the mapping between raftAddr and apiAddr for the entire cluster.
func (s *Service) SetPeer(raftAddr, apiAddr string) error {
peer := map[string]string{
raftAddr: apiAddr,
}
// Try the local store. It might be the leader.
err := s.store.UpdateAPIPeers(peer)
if err == nil {
// All done! Aren't we lucky?
return nil
}
// Try talking to the leader over the network.
if leader := s.store.LeaderAddr(); leader == "" {
return fmt.Errorf("no leader available")
}
conn, err := s.tn.Dial(s.store.LeaderAddr(), connectionTimeout)
if err != nil {
return err
}
defer conn.Close()
b, err := json.Marshal(peer)
if err != nil {
return err
}
if _, err := conn.Write(b); err != nil {
return err
}
// Wait for the response and verify the operation went through.
resp := response{}
d := json.NewDecoder(conn)
err = d.Decode(&resp)
if err != nil {
return err
}
if resp.Code != 0 {
return fmt.Errorf(resp.Message)
}
return nil
}
func (s *Service) serve() error {
defer s.wg.Done()
for {
conn, err := s.tn.Accept()
if err != nil {
return err
}
go s.handleConn(conn)
}
}
func (s *Service) handleConn(conn net.Conn) {
s.logger.Printf("received connection from %s", conn.RemoteAddr().String())
// Only handles peers updates for now.
peers := make(map[string]string)
d := json.NewDecoder(conn)
err := d.Decode(&peers)
if err != nil {
return
}
// Update the peers.
if err := s.store.UpdateAPIPeers(peers); err != nil {
resp := response{1, err.Error()}
b, err := json.Marshal(resp)
if err != nil {
conn.Close() // Only way left to signal.
} else {
if _, err := conn.Write(b); err != nil {
conn.Close() // Only way left to signal.
}
}
return
}
// Let the remote node know everything went OK.
if _, err := conn.Write(respOKMarshalled); err != nil {
conn.Close() // Only way left to signal.
}
return
}

@ -1,194 +0,0 @@
package cluster
import (
"encoding/json"
"fmt"
"net"
"testing"
"time"
)
func Test_NewServiceOpenClose(t *testing.T) {
ml := mustNewMockTransport()
ms := &mockStore{}
s := NewService(ml, ms)
if s == nil {
t.Fatalf("failed to create cluster service")
}
if err := s.Open(); err != nil {
t.Fatalf("failed to open cluster service")
}
if err := s.Close(); err != nil {
t.Fatalf("failed to close cluster service")
}
}
func Test_SetAPIPeer(t *testing.T) {
raftAddr, apiAddr := "localhost:4002", "localhost:4001"
s, _, ms := mustNewOpenService()
defer s.Close()
if err := s.SetPeer(raftAddr, apiAddr); err != nil {
t.Fatalf("failed to set peer: %s", err.Error())
}
if ms.peers[raftAddr] != apiAddr {
t.Fatalf("peer not set correctly, exp %s, got %s", apiAddr, ms.peers[raftAddr])
}
}
func Test_SetAPIPeerNetwork(t *testing.T) {
raftAddr, apiAddr := "localhost:4002", "localhost:4001"
s, _, ms := mustNewOpenService()
defer s.Close()
raddr, err := net.ResolveTCPAddr("tcp", s.Addr())
if err != nil {
t.Fatalf("failed to resolve remote uster ervice address: %s", err.Error())
}
conn, err := net.DialTCP("tcp4", nil, raddr)
if err != nil {
t.Fatalf("failed to connect to remote cluster service: %s", err.Error())
}
if _, err := conn.Write([]byte(fmt.Sprintf(`{"%s": "%s"}`, raftAddr, apiAddr))); err != nil {
t.Fatalf("failed to write to remote cluster service: %s", err.Error())
}
resp := response{}
d := json.NewDecoder(conn)
err = d.Decode(&resp)
if err != nil {
t.Fatalf("failed to decode response: %s", err.Error())
}
if resp.Code != 0 {
t.Fatalf("response code was non-zero")
}
if ms.peers[raftAddr] != apiAddr {
t.Fatalf("peer not set correctly, exp %s, got %s", apiAddr, ms.peers[raftAddr])
}
}
func Test_SetAPIPeerFailUpdate(t *testing.T) {
raftAddr, apiAddr := "localhost:4002", "localhost:4001"
s, _, ms := mustNewOpenService()
defer s.Close()
ms.failUpdateAPIPeers = true
// Attempt to set peer without a leader
if err := s.SetPeer(raftAddr, apiAddr); err == nil {
t.Fatalf("no error returned by set peer when no leader")
}
// Start a network server.
tn, err := net.Listen("tcp", "localhost:0")
if err != nil {
t.Fatalf("failed to open test server: %s", err.Error())
}
ms.leader = tn.Addr().String()
c := make(chan map[string]string, 1)
go func() {
conn, err := tn.Accept()
if err != nil {
t.Fatalf("failed to accept connection from cluster: %s", err.Error())
}
t.Logf("test server received connection from: %s", conn.RemoteAddr())
peers := make(map[string]string)
d := json.NewDecoder(conn)
err = d.Decode(&peers)
if err != nil {
t.Fatalf("failed to decode message from cluster: %s", err.Error())
}
// Response OK.
// Let the remote node know everything went OK.
if _, err := conn.Write(respOKMarshalled); err != nil {
t.Fatalf("failed to respond to cluster: %s", err.Error())
}
c <- peers
}()
if err := s.SetPeer(raftAddr, apiAddr); err != nil {
t.Fatalf("failed to set peer on cluster: %s", err.Error())
}
peers := <-c
if peers[raftAddr] != apiAddr {
t.Fatalf("peer not set correctly, exp %s, got %s", apiAddr, ms.peers[raftAddr])
}
}
func mustNewOpenService() (*Service, *mockTransport, *mockStore) {
ml := mustNewMockTransport()
ms := newMockStore()
s := NewService(ml, ms)
if err := s.Open(); err != nil {
panic("failed to open new service")
}
return s, ml, ms
}
type mockTransport struct {
tn net.Listener
}
func mustNewMockTransport() *mockTransport {
tn, err := net.Listen("tcp", "localhost:0")
if err != nil {
panic("failed to create mock listener")
}
return &mockTransport{
tn: tn,
}
}
func (ml *mockTransport) Accept() (c net.Conn, err error) {
return ml.tn.Accept()
}
func (ml *mockTransport) Addr() net.Addr {
return ml.tn.Addr()
}
func (ml *mockTransport) Close() (err error) {
return ml.tn.Close()
}
func (ml *mockTransport) Dial(addr string, t time.Duration) (net.Conn, error) {
return net.DialTimeout("tcp", addr, 5*time.Second)
}
type mockStore struct {
leader string
peers map[string]string
failUpdateAPIPeers bool
}
func newMockStore() *mockStore {
return &mockStore{
peers: make(map[string]string),
}
}
func (ms *mockStore) LeaderAddr() string {
return ms.leader
}
func (ms *mockStore) UpdateAPIPeers(peers map[string]string) error {
if ms.failUpdateAPIPeers {
return fmt.Errorf("forced fail")
}
for k, v := range peers {
ms.peers[k] = v
}
return nil
}

@ -8,7 +8,6 @@ import (
"fmt"
"io/ioutil"
"log"
"net"
"os"
"os/signal"
"path/filepath"
@ -161,47 +160,28 @@ func main() {
// Start requested profiling.
startProfile(cpuProfile, memProfile)
// Set up node-to-node TCP communication.
ln, err := net.Listen("tcp", raftAddr)
if err != nil {
log.Fatalf("failed to listen on %s: %s", raftAddr, err.Error())
}
var adv net.Addr
if raftAdv != "" {
adv, err = net.ResolveTCPAddr("tcp", raftAdv)
if err != nil {
log.Fatalf("failed to resolve advertise address %s: %s", raftAdv, err.Error())
}
}
// Start up node-to-node network mux.
var mux *tcp.Mux
// Create internode network layer.
var tn *tcp.Transport
if nodeEncrypt {
log.Printf("enabling node-to-node encryption with cert: %s, key: %s", nodeX509Cert, nodeX509Key)
mux, err = tcp.NewTLSMux(ln, adv, nodeX509Cert, nodeX509Key, nodeX509CACert)
tn = tcp.NewTLSTransport(nodeX509Cert, nodeX509Key, noVerify)
} else {
mux, err = tcp.NewMux(ln, adv)
tn = tcp.NewTransport()
}
if err != nil {
log.Fatalf("failed to create node-to-node mux: %s", err.Error())
if err := tn.Open(raftAddr); err != nil {
log.Fatalf("failed to open internode network layer: %s", err.Error())
}
mux.InsecureSkipVerify = noNodeVerify
go mux.Serve()
// Get transport for Raft communications.
raftTn := mux.Listen(muxRaftHeader)
// Create and open the store.
dataPath, err = filepath.Abs(dataPath)
dataPath, err := filepath.Abs(dataPath)
if err != nil {
log.Fatalf("failed to determine absolute data path: %s", err.Error())
}
dbConf := store.NewDBConfig(dsn, !onDisk)
str := store.New(&store.StoreConfig{
str := store.New(tn, &store.StoreConfig{
DBConf: dbConf,
Dir: dataPath,
Tn: raftTn,
ID: idOrRaftAddr(),
})
@ -220,10 +200,6 @@ func main() {
if err != nil {
log.Fatalf("failed to parse Raft apply timeout %s: %s", raftApplyTimeout, err.Error())
}
str.OpenTimeout, err = time.ParseDuration(raftOpenTimeout)
if err != nil {
log.Fatalf("failed to parse Raft open timeout %s: %s", raftOpenTimeout, err.Error())
}
// Determine join addresses, if necessary.
ja, err := store.JoinAllowed(dataPath)
@ -241,16 +217,18 @@ func main() {
log.Println("node is already member of cluster, skip determining join addresses")
}
// Now, open it.
// Now, open store.
if err := str.Open(len(joins) == 0); err != nil {
log.Fatalf("failed to open store: %s", err.Error())
}
// Create and configure cluster service.
tn := mux.Listen(muxMetaHeader)
cs := cluster.NewService(tn, str)
if err := cs.Open(); err != nil {
log.Fatalf("failed to open cluster service: %s", err.Error())
// Prepare metadata for join command.
apiAdv := httpAddr
if httpAdv != "" {
apiAdv = httpAdv
}
meta := map[string]string{
"api_addr": apiAdv,
}
// Execute any requested join operation.
@ -274,7 +252,7 @@ func main() {
}
}
if j, err := cluster.Join(joins, str.ID(), advAddr, &tlsConfig); err != nil {
if j, err := cluster.Join(joins, str.ID(), advAddr, meta, &tlsConfig); err != nil {
log.Fatalf("failed to join cluster at %s: %s", joins, err.Error())
} else {
log.Println("successfully joined cluster at", j)
@ -284,53 +262,26 @@ func main() {
log.Println("no join addresses set")
}
// Publish to the cluster the mapping between this Raft address and API address.
// The Raft layer broadcasts the resolved address, so use that as the key. But
// only set different HTTP advertise address if set.
apiAdv := httpAddr
if httpAdv != "" {
apiAdv = httpAdv
}
if err := publishAPIAddr(cs, raftTn.Addr().String(), apiAdv, publishPeerTimeout); err != nil {
log.Fatalf("failed to set peer for %s to %s: %s", raftAddr, httpAddr, err.Error())
}
log.Printf("set peer for %s to %s", raftTn.Addr().String(), apiAdv)
// Get the credential store.
credStr, err := credentialStore()
// Wait until the store is in full consensus.
openTimeout, err := time.ParseDuration(raftOpenTimeout)
if err != nil {
log.Fatalf("failed to get credential store: %s", err.Error())
log.Fatalf("failed to parse Raft open timeout %s: %s", raftOpenTimeout, err.Error())
}
str.WaitForLeader(openTimeout)
str.WaitForApplied(openTimeout)
// Create HTTP server and load authentication information if required.
var s *httpd.Service
if credStr != nil {
s = httpd.New(httpAddr, str, credStr)
} else {
s = httpd.New(httpAddr, str, nil)
// This may be a standalone server. In that case set its own metadata.
if err := str.SetMetadata(meta); err != nil && err != store.ErrNotLeader {
// Non-leader errors are OK, since metadata will then be set through
// consensus as a result of a join. All other errors indicate a problem.
log.Fatalf("failed to set store metadata: %s", err.Error())
}
s.CACertFile = x509CACert
s.CertFile = x509Cert
s.KeyFile = x509Key
s.Expvar = expvar
s.Pprof = pprofEnabled
s.BuildInfo = map[string]interface{}{
"commit": commit,
"branch": branch,
"version": version,
"build_time": buildtime,
}
if err := s.Start(); err != nil {
// Start the HTTP API server.
if err := startHTTPService(str); err != nil {
log.Fatalf("failed to start HTTP server: %s", err.Error())
}
// Register cross-component statuses.
if err := s.RegisterStatus("mux", mux); err != nil {
log.Fatalf("failed to register mux status: %s", err.Error())
}
// Block until signalled.
terminate := make(chan os.Signal, 1)
signal.Notify(terminate, os.Interrupt)
@ -373,25 +324,32 @@ func determineJoinAddresses() ([]string, error) {
return addrs, nil
}
func publishAPIAddr(c *cluster.Service, raftAddr, apiAddr string, t time.Duration) error {
tck := time.NewTicker(publishPeerDelay)
defer tck.Stop()
tmr := time.NewTimer(t)
defer tmr.Stop()
for {
select {
case <-tck.C:
if err := c.SetPeer(raftAddr, apiAddr); err != nil {
log.Printf("failed to set peer for %s to %s: %s (retrying)",
raftAddr, apiAddr, err.Error())
continue
func startHTTPService(str *store.Store) error {
// Get the credential store.
credStr, err := credentialStore()
if err != nil {
return err
}
return nil
case <-tmr.C:
return fmt.Errorf("set peer timeout expired")
// Create HTTP server and load authentication information if required.
var s *httpd.Service
if credStr != nil {
s = httpd.New(httpAddr, str, credStr)
} else {
s = httpd.New(httpAddr, str, nil)
}
s.CertFile = x509Cert
s.KeyFile = x509Key
s.Expvar = expvar
s.Pprof = pprofEnabled
s.BuildInfo = map[string]interface{}{
"commit": commit,
"branch": branch,
"version": version,
"build_time": buildtime,
}
return s.Start()
}
func credentialStore() (*auth.CredentialsStore, error) {

@ -44,16 +44,16 @@ type Store interface {
Query(qr *store.QueryRequest) ([]*sql.Rows, error)
// Join joins the node with the given ID, reachable at addr, to this node.
Join(id, addr string) error
Join(id, addr string, metadata map[string]string) error
// Remove removes the node, specified by addr, from the cluster.
Remove(addr string) error
// Remove removes the node, specified by id, from the cluster.
Remove(id string) error
// LeaderAddr returns the Raft address of the leader of the cluster.
LeaderAddr() string
// Metadata returns the value for the given node ID, for the given key.
Metadata(id, key string) string
// Peer returns the API peer for the given address
Peer(addr string) string
// Leader returns the Raft address of the leader of the cluster.
LeaderID() (string, error)
// Stats returns stats on the Store.
Stats() (map[string]interface{}, error)
@ -289,27 +289,35 @@ func (s *Service) handleJoin(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
return
}
m := map[string]string{}
if err := json.Unmarshal(b, &m); err != nil {
md := map[string]interface{}{}
if err := json.Unmarshal(b, &md); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
remoteID, ok := m["id"]
remoteID, ok := md["id"]
if !ok {
w.WriteHeader(http.StatusBadRequest)
return
}
var m map[string]string
if _, ok := md["meta"].(map[string]interface{}); ok {
m = make(map[string]string)
for k, v := range md["meta"].(map[string]interface{}) {
m[k] = v.(string)
}
}
remoteAddr, ok := m["addr"]
remoteAddr, ok := md["addr"]
if !ok {
fmt.Println("4444")
w.WriteHeader(http.StatusBadRequest)
return
}
if err := s.store.Join(remoteID, remoteAddr); err != nil {
if err := s.store.Join(remoteID.(string), remoteAddr.(string), m); err != nil {
if err == store.ErrNotLeader {
leader := s.store.Peer(s.store.LeaderAddr())
leader := s.leaderAPIAddr()
if leader == "" {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
@ -355,15 +363,15 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) {
return
}
remoteAddr, ok := m["addr"]
remoteID, ok := m["id"]
if !ok {
w.WriteHeader(http.StatusBadRequest)
return
}
if err := s.store.Remove(remoteAddr); err != nil {
if err := s.store.Remove(remoteID); err != nil {
if err == store.ErrNotLeader {
leader := s.store.Peer(s.store.LeaderAddr())
leader := s.leaderAPIAddr()
if leader == "" {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
@ -448,7 +456,7 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
results, err := s.store.ExecuteOrAbort(&store.ExecuteRequest{queries, timings, false})
if err != nil {
if err == store.ErrNotLeader {
leader := s.store.Peer(s.store.LeaderAddr())
leader := s.leaderAPIAddr()
if leader == "" {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
@ -498,7 +506,7 @@ func (s *Service) handleStatus(w http.ResponseWriter, r *http.Request) {
httpStatus := map[string]interface{}{
"addr": s.Addr().String(),
"auth": prettyEnabled(s.credentialStore != nil),
"redirect": s.store.Peer(s.store.LeaderAddr()),
"redirect": s.leaderAPIAddr(),
}
nodeStatus := map[string]interface{}{
@ -595,7 +603,7 @@ func (s *Service) handleExecute(w http.ResponseWriter, r *http.Request) {
results, err := s.store.Execute(&store.ExecuteRequest{queries, timings, isTx})
if err != nil {
if err == store.ErrNotLeader {
leader := s.store.Peer(s.store.LeaderAddr())
leader := s.leaderAPIAddr()
if leader == "" {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
@ -657,7 +665,7 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) {
results, err := s.store.Query(&store.QueryRequest{queries, timings, isTx, lvl})
if err != nil {
if err == store.ErrNotLeader {
leader := s.store.Peer(s.store.LeaderAddr())
leader := s.leaderAPIAddr()
if leader == "" {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
@ -746,6 +754,14 @@ func (s *Service) CheckRequestPerm(r *http.Request, perm string) bool {
return s.credentialStore.HasPerm(username, PermAll) || s.credentialStore.HasPerm(username, perm)
}
func (s *Service) leaderAPIAddr() string {
id, err := s.store.LeaderID()
if err != nil {
return ""
}
return s.store.Metadata(id, "api_addr")
}
// addBuildVersion adds the build version to the HTTP response.
func (s *Service) addBuildVersion(w http.ResponseWriter) {
// Add version header to every response, if available.

@ -491,19 +491,19 @@ func (m *MockStore) Query(qr *store.QueryRequest) ([]*sql.Rows, error) {
return nil, nil
}
func (m *MockStore) Join(id, addr string) error {
func (m *MockStore) Join(id, addr string, metadata map[string]string) error {
return nil
}
func (m *MockStore) Remove(addr string) error {
func (m *MockStore) Remove(id string) error {
return nil
}
func (m *MockStore) LeaderAddr() string {
return ""
func (m *MockStore) LeaderID() (string, error) {
return "", nil
}
func (m *MockStore) Peer(addr string) string {
func (m *MockStore) Metadata(id, key string) string {
return ""
}

@ -10,7 +10,8 @@ type commandType int
const (
execute commandType = iota // Commands which modify the database.
query // Commands which query the database.
peer // Commands that modify peers map.
metadataSet // Commands which sets Store metadata
metadataDelete // Commands which deletes Store metadata
)
type command struct {
@ -27,7 +28,14 @@ func newCommand(t commandType, d interface{}) (*command, error) {
Typ: t,
Sub: b,
}, nil
}
func newMetadataSetCommand(id string, md map[string]string) (*command, error) {
m := metadataSetSub{
RaftID: id,
Data: md,
}
return newCommand(metadataSet, m)
}
// databaseSub is a command sub which involves interaction with the database.
@ -37,5 +45,7 @@ type databaseSub struct {
Timings bool `json:"timings,omitempty"`
}
// peersSub is a command which sets the API address for a Raft address.
type peersSub map[string]string
type metadataSetSub struct {
RaftID string `json:"raft_id,omitempty"`
Data map[string]string `json:"data,omitempty"`
}

@ -0,0 +1,12 @@
package store
// DBConfig represents the configuration of the underlying SQLite database.
type DBConfig struct {
DSN string // Any custom DSN
Memory bool // Whether the database is in-memory only.
}
// NewDBConfig returns a new DB config instance.
func NewDBConfig(dsn string, memory bool) *DBConfig {
return &DBConfig{DSN: dsn, Memory: memory}
}

@ -0,0 +1,14 @@
package store
// Server represents another node in the cluster.
type Server struct {
ID string `json:"id,omitempty"`
Addr string `json:"addr,omitempty"`
}
// Servers is a set of Servers.
type Servers []*Server
func (s Servers) Less(i, j int) bool { return s[i].ID < s[j].ID }
func (s Servers) Len() int { return len(s) }
func (s Servers) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

@ -13,7 +13,6 @@ import (
"io"
"io/ioutil"
"log"
"net"
"os"
"path/filepath"
"sort"
@ -46,6 +45,9 @@ const (
sqliteFile = "db.sqlite"
leaderWaitDelay = 100 * time.Millisecond
appliedWaitDelay = 100 * time.Millisecond
connectionPoolCount = 5
connectionTimeout = 10 * time.Second
raftLogCacheSize = 512
)
const (
@ -114,60 +116,6 @@ const (
Unknown
)
// clusterMeta represents cluster meta which must be kept in consensus.
type clusterMeta struct {
APIPeers map[string]string // Map from Raft address to API address
}
// NewClusterMeta returns an initialized cluster meta store.
func newClusterMeta() *clusterMeta {
return &clusterMeta{
APIPeers: make(map[string]string),
}
}
func (c *clusterMeta) AddrForPeer(addr string) string {
if api, ok := c.APIPeers[addr]; ok && api != "" {
return api
}
// Go through each entry, and see if any key resolves to addr.
for k, v := range c.APIPeers {
resv, err := net.ResolveTCPAddr("tcp", k)
if err != nil {
continue
}
if resv.String() == addr {
return v
}
}
return ""
}
// DBConfig represents the configuration of the underlying SQLite database.
type DBConfig struct {
DSN string // Any custom DSN
Memory bool // Whether the database is in-memory only.
}
// NewDBConfig returns a new DB config instance.
func NewDBConfig(dsn string, memory bool) *DBConfig {
return &DBConfig{DSN: dsn, Memory: memory}
}
// Server represents another node in the cluster.
type Server struct {
ID string `json:"id,omitempty"`
Addr string `json:"addr,omitempty"`
}
type Servers []*Server
func (s Servers) Less(i, j int) bool { return s[i].ID < s[j].ID }
func (s Servers) Len() int { return len(s) }
func (s Servers) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// Store is a SQLite database, where all changes are made via Raft consensus.
type Store struct {
raftDir string
@ -175,14 +123,19 @@ type Store struct {
mu sync.RWMutex // Sync access between queries and snapshots.
raft *raft.Raft // The consensus mechanism.
raftTn *raftTransport
ln Listener
raftTn *raft.NetworkTransport
raftID string // Node ID.
dbConf *DBConfig // SQLite database config.
dbPath string // Path to underlying SQLite file, if not in-memory.
db *sql.DB // The underlying SQLite store.
raftLog raft.LogStore // Persistent log store.
raftStable raft.StableStore // Persistent k-v store.
boltStore *raftboltdb.BoltStore // Physical store.
metaMu sync.RWMutex
meta *clusterMeta
meta map[string]map[string]string
logger *log.Logger
@ -192,7 +145,6 @@ type Store struct {
HeartbeatTimeout time.Duration
ElectionTimeout time.Duration
ApplyTimeout time.Duration
OpenTimeout time.Duration
}
// StoreConfig represents the configuration of the underlying Store.
@ -205,22 +157,21 @@ type StoreConfig struct {
}
// New returns a new Store.
func New(c *StoreConfig) *Store {
func New(ln Listener, c *StoreConfig) *Store {
logger := c.Logger
if logger == nil {
logger = log.New(os.Stderr, "[store] ", log.LstdFlags)
}
return &Store{
ln: ln,
raftDir: c.Dir,
raftTn: &raftTransport{c.Tn},
raftID: c.ID,
dbConf: c.DBConf,
dbPath: filepath.Join(c.Dir, sqliteFile),
meta: newClusterMeta(),
meta: make(map[string]map[string]string),
logger: logger,
ApplyTimeout: applyTimeout,
OpenTimeout: openTimeout,
}
}
@ -234,6 +185,7 @@ func (s *Store) Open(enableSingle bool) error {
return err
}
// Open underlying database.
db, err := s.open()
if err != nil {
return err
@ -243,14 +195,11 @@ func (s *Store) Open(enableSingle bool) error {
// Is this a brand new node?
newNode := !pathExists(filepath.Join(s.raftDir, "raft.db"))
// Setup Raft communication.
transport := raft.NewNetworkTransport(s.raftTn, 3, 10*time.Second, os.Stderr)
// Create Raft-compatible network layer.
s.raftTn = raft.NewNetworkTransport(NewTransport(s.ln), connectionPoolCount, connectionTimeout, nil)
// Get the Raft configuration for this store.
config := s.raftConfig()
config.LocalID = raft.ServerID(s.raftID)
// XXXconfig.Logger = log.New(os.Stderr, "[raft] ", log.LstdFlags)
// Create the snapshot store. This allows Raft to truncate the log.
snapshots, err := raft.NewFileSnapshotStore(s.raftDir, retainSnapshotCount, os.Stderr)
@ -259,13 +208,18 @@ func (s *Store) Open(enableSingle bool) error {
}
// Create the log store and stable store.
logStore, err := raftboltdb.NewBoltStore(filepath.Join(s.raftDir, "raft.db"))
s.boltStore, err = raftboltdb.NewBoltStore(filepath.Join(s.raftDir, "raft.db"))
if err != nil {
return fmt.Errorf("new bolt store: %s", err)
}
s.raftStable = s.boltStore
s.raftLog, err = raft.NewLogCache(raftLogCacheSize, s.boltStore)
if err != nil {
return fmt.Errorf("new cached store: %s", err)
}
// Instantiate the Raft system.
ra, err := raft.NewRaft(config, s, logStore, logStore, snapshots, transport)
ra, err := raft.NewRaft(config, s, s.raftLog, s.raftStable, snapshots, s.raftTn)
if err != nil {
return fmt.Errorf("new raft: %s", err)
}
@ -276,7 +230,7 @@ func (s *Store) Open(enableSingle bool) error {
Servers: []raft.Server{
raft.Server{
ID: config.LocalID,
Address: transport.LocalAddr(),
Address: s.raftTn.LocalAddr(),
},
},
}
@ -287,16 +241,6 @@ func (s *Store) Open(enableSingle bool) error {
s.raft = ra
if s.OpenTimeout != 0 {
// Wait until the initial logs are applied.
s.logger.Printf("waiting for up to %s for application of initial logs", s.OpenTimeout)
if err := s.WaitForAppliedIndex(s.raft.LastIndex(), s.OpenTimeout); err != nil {
return ErrOpenTimeout
}
} else {
s.logger.Println("not waiting for application of initial logs")
}
return nil
}
@ -314,6 +258,19 @@ func (s *Store) Close(wait bool) error {
return nil
}
// WaitForApplied waits for all Raft log entries to to be applied to the
// underlying database.
func (s *Store) WaitForApplied(timeout time.Duration) error {
if timeout == 0 {
return nil
}
s.logger.Printf("waiting for up to %s for application of initial logs", timeout)
if err := s.WaitForAppliedIndex(s.raft.LastIndex(), timeout); err != nil {
return ErrOpenTimeout
}
return nil
}
// IsLeader is used to determine if the current node is cluster leader
func (s *Store) IsLeader() bool {
return s.raft.State() == raft.Leader
@ -342,8 +299,8 @@ func (s *Store) Path() string {
}
// Addr returns the address of the store.
func (s *Store) Addr() net.Addr {
return s.raftTn.Addr()
func (s *Store) Addr() string {
return string(s.raftTn.LocalAddr())
}
// ID returns the Raft ID of the store.
@ -375,24 +332,6 @@ func (s *Store) LeaderID() (string, error) {
return "", nil
}
// Peer returns the API address for the given addr. If there is no peer
// for the address, it returns the empty string.
func (s *Store) Peer(addr string) string {
return s.meta.AddrForPeer(addr)
}
// APIPeers return the map of Raft addresses to API addresses.
func (s *Store) APIPeers() (map[string]string, error) {
s.metaMu.RLock()
defer s.metaMu.RUnlock()
peers := make(map[string]string, len(s.meta.APIPeers))
for k, v := range s.meta.APIPeers {
peers[k] = v
}
return peers, nil
}
// Nodes returns the slice of nodes in the cluster, sorted by ID ascending.
func (s *Store) Nodes() ([]*Server, error) {
f := s.raft.GetConfiguration()
@ -480,18 +419,25 @@ func (s *Store) Stats() (map[string]interface{}, error) {
if err != nil {
return nil, err
}
leaderID, err := s.LeaderID()
if err != nil {
return nil, err
}
status := map[string]interface{}{
"node_id": s.raftID,
"raft": s.raft.Stats(),
"addr": s.Addr().String(),
"leader": s.LeaderAddr(),
"addr": s.Addr(),
"leader": map[string]string{
"node_id": leaderID,
"addr": s.LeaderAddr(),
},
"apply_timeout": s.ApplyTimeout.String(),
"open_timeout": s.OpenTimeout.String(),
"heartbeat_timeout": s.HeartbeatTimeout.String(),
"election_timeout": s.ElectionTimeout.String(),
"snapshot_threshold": s.SnapshotThreshold,
"meta": s.meta,
"peers": nodes,
"metadata": s.meta,
"nodes": nodes,
"dir": s.raftDir,
"sqlite3": dbStatus,
"db_conf": s.dbConf,
@ -636,27 +582,36 @@ func (s *Store) Query(qr *QueryRequest) ([]*sql.Rows, error) {
return r, err
}
// UpdateAPIPeers updates the cluster-wide peer information.
func (s *Store) UpdateAPIPeers(peers map[string]string) error {
c, err := newCommand(peer, peers)
if err != nil {
return err
// Join joins a node, identified by id and located at addr, to this store.
// The node must be ready to respond to Raft communications at that address.
func (s *Store) Join(id, addr string, metadata map[string]string) error {
s.logger.Printf("received request to join node at %s", addr)
if s.raft.State() != raft.Leader {
return ErrNotLeader
}
b, err := json.Marshal(c)
if err != nil {
configFuture := s.raft.GetConfiguration()
if err := configFuture.Error(); err != nil {
s.logger.Printf("failed to get raft configuration: %v", err)
return err
}
f := s.raft.Apply(b, s.ApplyTimeout)
return f.Error()
for _, srv := range configFuture.Configuration().Servers {
// If a node already exists with either the joining node's ID or address,
// that node may need to be removed from the config first.
if srv.ID == raft.ServerID(id) || srv.Address == raft.ServerAddress(addr) {
// However if *both* the ID and the address are the same, the no
// join is actually needed.
if srv.Address == raft.ServerAddress(addr) && srv.ID == raft.ServerID(id) {
s.logger.Printf("node %s at %s already member of cluster, ignoring join request", id, addr)
return nil
}
// Join joins a node, identified by id and located at addr, to this store.
// The node must be ready to respond to Raft communications at that address.
func (s *Store) Join(id, addr string) error {
s.logger.Printf("received request to join node at %s", addr)
if s.raft.State() != raft.Leader {
return ErrNotLeader
if err := s.remove(id); err != nil {
s.logger.Printf("failed to remove node: %v", err)
return err
}
}
}
f := s.raft.AddVoter(raft.ServerID(id), raft.ServerAddress(addr), 0, 0)
@ -666,6 +621,11 @@ func (s *Store) Join(id, addr string) error {
}
return e.Error()
}
if err := s.setMetadata(id, metadata); err != nil {
return err
}
s.logger.Printf("node at %s joined successfully", addr)
return nil
}
@ -673,18 +633,74 @@ func (s *Store) Join(id, addr string) error {
// Remove removes a node from the store, specified by ID.
func (s *Store) Remove(id string) error {
s.logger.Printf("received request to remove node %s", id)
if s.raft.State() != raft.Leader {
return ErrNotLeader
if err := s.remove(id); err != nil {
s.logger.Printf("failed to remove node %s: %s", id, err.Error())
return err
}
f := s.raft.RemoveServer(raft.ServerID(id), 0, 0)
if f.Error() != nil {
if f.Error() == raft.ErrNotLeader {
s.logger.Printf("node %s removed successfully", id)
return nil
}
// Metadata returns the value for a given key, for a given node ID.
func (s *Store) Metadata(id, key string) string {
s.metaMu.RLock()
defer s.metaMu.RUnlock()
if _, ok := s.meta[id]; !ok {
return ""
}
v, ok := s.meta[id][key]
if ok {
return v
}
return ""
}
// SetMetadata adds the metadata md to any existing metadata for
// this node.
func (s *Store) SetMetadata(md map[string]string) error {
return s.setMetadata(s.raftID, md)
}
// setMetadata adds the metadata md to any existing metadata for
// the given node ID.
func (s *Store) setMetadata(id string, md map[string]string) error {
// Check local data first.
if func() bool {
s.metaMu.RLock()
defer s.metaMu.RUnlock()
if _, ok := s.meta[id]; ok {
for k, v := range md {
if s.meta[id][k] != v {
return false
}
}
return true
}
return false
}() {
// Local data is same as data being pushed in,
// nothing to do.
return nil
}
c, err := newMetadataSetCommand(id, md)
if err != nil {
return err
}
b, err := json.Marshal(c)
if err != nil {
return err
}
f := s.raft.Apply(b, s.ApplyTimeout)
if e := f.(raft.Future); e.Error() != nil {
if e.Error() == raft.ErrNotLeader {
return ErrNotLeader
}
return f.Error()
e.Error()
}
s.logger.Printf("node %s removed successfully", id)
return nil
}
@ -712,6 +728,36 @@ func (s *Store) open() (*sql.DB, error) {
return db, nil
}
// remove removes the node, with the given ID, from the cluster.
func (s *Store) remove(id string) error {
if s.raft.State() != raft.Leader {
return ErrNotLeader
}
f := s.raft.RemoveServer(raft.ServerID(id), 0, 0)
if f.Error() != nil {
if f.Error() == raft.ErrNotLeader {
return ErrNotLeader
}
return f.Error()
}
c, err := newCommand(metadataDelete, id)
b, err := json.Marshal(c)
if err != nil {
return err
}
f = s.raft.Apply(b, s.ApplyTimeout)
if e := f.(raft.Future); e.Error() != nil {
if e.Error() == raft.ErrNotLeader {
return ErrNotLeader
}
e.Error()
}
return nil
}
// raftConfig returns a new Raft config for the store.
func (s *Store) raftConfig() *raft.Config {
config := raft.DefaultConfig()
@ -764,17 +810,31 @@ func (s *Store) Apply(l *raft.Log) interface{} {
}
r, err := s.db.Query(d.Queries, d.Tx, d.Timings)
return &fsmQueryResponse{rows: r, error: err}
case peer:
var d peersSub
case metadataSet:
var d metadataSetSub
if err := json.Unmarshal(c.Sub, &d); err != nil {
return &fsmGenericResponse{error: err}
}
func() {
s.metaMu.Lock()
defer s.metaMu.Unlock()
for k, v := range d {
s.meta.APIPeers[k] = v
if _, ok := s.meta[d.RaftID]; !ok {
s.meta[d.RaftID] = make(map[string]string)
}
for k, v := range d.Data {
s.meta[d.RaftID][k] = v
}
}()
return &fsmGenericResponse{}
case metadataDelete:
var d string
if err := json.Unmarshal(c.Sub, &d); err != nil {
return &fsmGenericResponse{error: err}
}
func() {
s.metaMu.Lock()
defer s.metaMu.Unlock()
delete(s.meta, d)
}()
return &fsmGenericResponse{}
default:

@ -6,7 +6,6 @@ import (
"net"
"os"
"path/filepath"
"reflect"
"sort"
"testing"
"time"
@ -14,35 +13,6 @@ import (
"github.com/rqlite/rqlite/testdata/chinook"
)
type mockSnapshotSink struct {
*os.File
}
func (m *mockSnapshotSink) ID() string {
return "1"
}
func (m *mockSnapshotSink) Cancel() error {
return nil
}
func Test_ClusterMeta(t *testing.T) {
c := newClusterMeta()
c.APIPeers["localhost:4002"] = "localhost:4001"
if c.AddrForPeer("localhost:4002") != "localhost:4001" {
t.Fatalf("wrong address returned for localhost:4002")
}
if c.AddrForPeer("127.0.0.1:4002") != "localhost:4001" {
t.Fatalf("wrong address returned for 127.0.0.1:4002")
}
if c.AddrForPeer("127.0.0.1:4004") != "" {
t.Fatalf("wrong address returned for 127.0.0.1:4003")
}
}
func Test_OpenStoreSingleNode(t *testing.T) {
s := mustNewStore(true)
defer os.RemoveAll(s.Path())
@ -52,7 +22,7 @@ func Test_OpenStoreSingleNode(t *testing.T) {
}
s.WaitForLeader(10 * time.Second)
if got, exp := s.LeaderAddr(), s.Addr().String(); got != exp {
if got, exp := s.LeaderAddr(), s.Addr(); got != exp {
t.Fatalf("wrong leader address returned, got: %s, exp %s", got, exp)
}
id, err := s.LeaderID()
@ -71,6 +41,7 @@ func Test_OpenStoreCloseSingleNode(t *testing.T) {
if err := s.Open(true); err != nil {
t.Fatalf("failed to open single-node store: %s", err.Error())
}
s.WaitForLeader(10 * time.Second)
if err := s.Close(true); err != nil {
t.Fatalf("failed to close single-node store: %s", err.Error())
}
@ -450,14 +421,14 @@ func Test_MultiNodeJoinRemove(t *testing.T) {
sort.StringSlice(storeNodes).Sort()
// Join the second node to the first.
if err := s0.Join(s1.ID(), s1.Addr().String()); err != nil {
t.Fatalf("failed to join to node at %s: %s", s0.Addr().String(), err.Error())
if err := s0.Join(s1.ID(), s1.Addr(), nil); err != nil {
t.Fatalf("failed to join to node at %s: %s", s0.Addr(), err.Error())
}
s1.WaitForLeader(10 * time.Second)
// Check leader state on follower.
if got, exp := s1.LeaderAddr(), s0.Addr().String(); got != exp {
if got, exp := s1.LeaderAddr(), s0.Addr(); got != exp {
t.Fatalf("wrong leader address returned, got: %s, exp %s", got, exp)
}
id, err := s1.LeaderID()
@ -514,8 +485,8 @@ func Test_MultiNodeExecuteQuery(t *testing.T) {
defer s1.Close(true)
// Join the second node to the first.
if err := s0.Join(s1.ID(), s1.Addr().String()); err != nil {
t.Fatalf("failed to join to node at %s: %s", s0.Addr().String(), err.Error())
if err := s0.Join(s1.ID(), s1.Addr(), nil); err != nil {
t.Fatalf("failed to join to node at %s: %s", s0.Addr(), err.Error())
}
queries := []string{
@ -609,7 +580,7 @@ func Test_StoreLogTruncationMultinode(t *testing.T) {
defer s1.Close(true)
// Join the second node to the first.
if err := s0.Join(s1.ID(), s1.Addr().String()); err != nil {
if err := s0.Join(s1.ID(), s1.Addr(), nil); err != nil {
t.Fatalf("failed to join to node at %s: %s", s0.Addr(), err.Error())
}
s1.WaitForLeader(10 * time.Second)
@ -754,38 +725,65 @@ func Test_SingleNodeSnapshotInMem(t *testing.T) {
}
}
func Test_APIPeers(t *testing.T) {
s := mustNewStore(false)
defer os.RemoveAll(s.Path())
if err := s.Open(true); err != nil {
func Test_MetadataMultinode(t *testing.T) {
s0 := mustNewStore(true)
if err := s0.Open(true); err != nil {
t.Fatalf("failed to open single-node store: %s", err.Error())
}
defer s.Close(true)
s.WaitForLeader(10 * time.Second)
defer s0.Close(true)
s0.WaitForLeader(10 * time.Second)
s1 := mustNewStore(true)
if err := s1.Open(true); err != nil {
t.Fatalf("failed to open single-node store: %s", err.Error())
}
defer s1.Close(true)
s1.WaitForLeader(10 * time.Second)
peers := map[string]string{
"localhost:4002": "localhost:4001",
"localhost:4004": "localhost:4003",
if s0.Metadata(s0.raftID, "foo") != "" {
t.Fatal("nonexistent metadata foo found")
}
if err := s.UpdateAPIPeers(peers); err != nil {
t.Fatalf("failed to update API peers: %s", err.Error())
if s0.Metadata("nonsense", "foo") != "" {
t.Fatal("nonexistent metadata foo found for nonexistent node")
}
// Retrieve peers and verify them.
apiPeers, err := s.APIPeers()
if err != nil {
t.Fatalf("failed to retrieve API peers: %s", err.Error())
if err := s0.SetMetadata(map[string]string{"foo": "bar"}); err != nil {
t.Fatalf("failed to set metadata: %s", err.Error())
}
if s0.Metadata(s0.raftID, "foo") != "bar" {
t.Fatal("key foo not found")
}
if !reflect.DeepEqual(peers, apiPeers) {
t.Fatalf("set and retrieved API peers not identical, got %v, exp %v",
apiPeers, peers)
if s0.Metadata("nonsense", "foo") != "" {
t.Fatal("nonexistent metadata foo found for nonexistent node")
}
if s.Peer("localhost:4002") != "localhost:4001" ||
s.Peer("localhost:4004") != "localhost:4003" ||
s.Peer("not exist") != "" {
t.Fatalf("failed to retrieve correct single API peer")
// Join the second node to the first.
meta := map[string]string{"baz": "qux"}
if err := s0.Join(s1.ID(), s1.Addr(), meta); err != nil {
t.Fatalf("failed to join to node at %s: %s", s0.Addr(), err.Error())
}
s1.WaitForLeader(10 * time.Second)
// Wait until the log entries have been applied to the follower,
// and then query.
if err := s1.WaitForAppliedIndex(5, 5*time.Second); err != nil {
t.Fatalf("error waiting for follower to apply index: %s:", err.Error())
}
if s1.Metadata(s0.raftID, "foo") != "bar" {
t.Fatal("key foo not found for s0")
}
if s0.Metadata(s1.raftID, "baz") != "qux" {
t.Fatal("key baz not found for s1")
}
// Remove a node.
if err := s0.Remove(s1.ID()); err != nil {
t.Fatalf("failed to remove %s from cluster: %s", s1.ID(), err.Error())
}
if s1.Metadata(s0.raftID, "foo") != "bar" {
t.Fatal("key foo not found for s0")
}
if s0.Metadata(s1.raftID, "baz") != "" {
t.Fatal("key baz found for removed node s1")
}
}
@ -824,12 +822,10 @@ func mustNewStore(inmem bool) *Store {
path := mustTempDir()
defer os.RemoveAll(path)
tn := mustMockTransport("localhost:0")
cfg := NewDBConfig("", inmem)
s := New(&StoreConfig{
s := New(mustMockLister("localhost:0"), &StoreConfig{
DBConf: cfg,
Dir: path,
Tn: tn,
ID: path, // Could be any unique string.
})
if s == nil {
@ -838,36 +834,52 @@ func mustNewStore(inmem bool) *Store {
return s
}
func mustTempDir() string {
var err error
path, err := ioutil.TempDir("", "rqlilte-test-")
if err != nil {
panic("failed to create temp dir")
type mockSnapshotSink struct {
*os.File
}
return path
func (m *mockSnapshotSink) ID() string {
return "1"
}
func (m *mockSnapshotSink) Cancel() error {
return nil
}
type mockTransport struct {
ln net.Listener
}
func mustMockTransport(addr string) Transport {
type mockListener struct {
ln net.Listener
}
func mustMockLister(addr string) Listener {
ln, err := net.Listen("tcp", addr)
if err != nil {
panic("failed to create new transport")
panic("failed to create new listner")
}
return &mockTransport{ln}
return &mockListener{ln}
}
func (m *mockTransport) Dial(addr string, timeout time.Duration) (net.Conn, error) {
func (m *mockListener) Dial(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("tcp", addr, timeout)
}
func (m *mockTransport) Accept() (net.Conn, error) { return m.ln.Accept() }
func (m *mockListener) Accept() (net.Conn, error) { return m.ln.Accept() }
func (m *mockListener) Close() error { return m.ln.Close() }
func (m *mockTransport) Close() error { return m.ln.Close() }
func (m *mockListener) Addr() net.Addr { return m.ln.Addr() }
func (m *mockTransport) Addr() net.Addr { return m.ln.Addr() }
func mustTempDir() string {
var err error
path, err := ioutil.TempDir("", "rqlilte-test-")
if err != nil {
panic("failed to create temp dir")
}
return path
}
func asJSON(v interface{}) string {
b, err := json.Marshal(v)

@ -7,32 +7,39 @@ import (
"github.com/hashicorp/raft"
)
// Transport is the interface the network service must provide.
type Transport interface {
type Listener interface {
net.Listener
// Dial is used to create a new outgoing connection
Dial(address string, timeout time.Duration) (net.Conn, error)
}
// raftTransport takes a Transport and makes it suitable for use by the Raft
// networking system.
type raftTransport struct {
tn Transport
// Transport is the network service provided to Raft, and wraps a Listener.
type Transport struct {
ln Listener
}
// NewTransport returns an initialized Transport.
func NewTransport(ln Listener) *Transport {
return &Transport{
ln: ln,
}
}
func (r *raftTransport) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
return r.tn.Dial(string(address), timeout)
// Dial creates a new network connection.
func (t *Transport) Dial(addr raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
return t.ln.Dial(string(addr), timeout)
}
func (r *raftTransport) Accept() (net.Conn, error) {
return r.tn.Accept()
// Accept waits for the next connection.
func (t *Transport) Accept() (net.Conn, error) {
return t.ln.Accept()
}
func (r *raftTransport) Addr() net.Addr {
return r.tn.Addr()
// Close closes the transport
func (t *Transport) Close() error {
return t.ln.Close()
}
func (r *raftTransport) Close() error {
return r.tn.Close()
// Addr returns the binding address of the transport.
func (t *Transport) Addr() net.Addr {
return t.ln.Addr()
}

@ -0,0 +1,11 @@
package store
import (
"testing"
)
func Test_NewTransport(t *testing.T) {
if NewTransport(nil) == nil {
t.Fatal("failed to create new Transport")
}
}

@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
@ -14,12 +13,14 @@ import (
httpd "github.com/rqlite/rqlite/http"
"github.com/rqlite/rqlite/store"
"github.com/rqlite/rqlite/tcp"
)
// Node represents a node under test.
type Node struct {
APIAddr string
RaftAddr string
ID string
Dir string
Store *store.Store
Service *httpd.Service
@ -302,18 +303,21 @@ func mustNewNode(enableSingle bool) *Node {
}
dbConf := store.NewDBConfig("", false)
tn := mustMockTransport("localhost:0")
node.Store = store.New(&store.StoreConfig{
tn := tcp.NewTransport()
if err := tn.Open("localhost:0"); err != nil {
panic(err.Error())
}
node.Store = store.New(tn, &store.StoreConfig{
DBConf: dbConf,
Dir: node.Dir,
Tn: tn,
ID: tn.Addr().String(),
})
if err := node.Store.Open(enableSingle); err != nil {
node.Deprovision()
panic(fmt.Sprintf("failed to open store: %s", err.Error()))
}
node.RaftAddr = node.Store.Addr().String()
node.RaftAddr = node.Store.Addr()
node.ID = node.Store.ID()
node.Service = httpd.New("localhost:0", node.Store, nil)
node.Service.Expvar = true
@ -335,28 +339,6 @@ func mustNewLeaderNode() *Node {
return node
}
type mockTransport struct {
ln net.Listener
}
func mustMockTransport(addr string) *mockTransport {
ln, err := net.Listen("tcp", addr)
if err != nil {
panic("failed to create new transport")
}
return &mockTransport{ln}
}
func (m *mockTransport) Dial(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("tcp", addr, timeout)
}
func (m *mockTransport) Accept() (net.Conn, error) { return m.ln.Accept() }
func (m *mockTransport) Close() error { return m.ln.Close() }
func (m *mockTransport) Addr() net.Addr { return m.ln.Addr() }
func mustTempDir() string {
var err error
path, err := ioutil.TempDir("", "rqlilte-system-test-")

@ -1,4 +1,4 @@
/*
Package tcp provides various TCP-related utilities. The TCP mux code provided by this package originated with InfluxDB.
Package tcp provides the internode communication network layer.
*/
package tcp

@ -1,301 +0,0 @@
package tcp
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"strconv"
"sync"
"time"
)
const (
// DefaultTimeout is the default length of time to wait for first byte.
DefaultTimeout = 30 * time.Second
)
// Layer represents the connection between nodes.
type Layer struct {
ln net.Listener
header byte
addr net.Addr
remoteEncrypted bool
skipVerify bool
nodeX509CACert string
tlsConfig *tls.Config
}
// Addr returns the local address for the layer.
func (l *Layer) Addr() net.Addr {
return l.addr
}
// Dial creates a new network connection.
func (l *Layer) Dial(addr string, timeout time.Duration) (net.Conn, error) {
dialer := &net.Dialer{Timeout: timeout}
var err error
var conn net.Conn
if l.remoteEncrypted {
conn, err = tls.DialWithDialer(dialer, "tcp", addr, l.tlsConfig)
} else {
conn, err = dialer.Dial("tcp", addr)
}
if err != nil {
return nil, err
}
// Write a marker byte to indicate message type.
_, err = conn.Write([]byte{l.header})
if err != nil {
conn.Close()
return nil, err
}
return conn, err
}
// Accept waits for the next connection.
func (l *Layer) Accept() (net.Conn, error) { return l.ln.Accept() }
// Close closes the layer.
func (l *Layer) Close() error { return l.ln.Close() }
// Mux multiplexes a network connection.
type Mux struct {
ln net.Listener
addr net.Addr
m map[byte]*listener
wg sync.WaitGroup
remoteEncrypted bool
// The amount of time to wait for the first header byte.
Timeout time.Duration
// Out-of-band error logger
Logger *log.Logger
// Path to root X.509 certificate.
nodeX509CACert string
// Path to X509 certificate
nodeX509Cert string
// Path to X509 key.
nodeX509Key string
// Whether to skip verification of other nodes' certificates.
InsecureSkipVerify bool
tlsConfig *tls.Config
}
// NewMux returns a new instance of Mux for ln. If adv is nil,
// then the addr of ln is used.
func NewMux(ln net.Listener, adv net.Addr) (*Mux, error) {
addr := adv
if addr == nil {
addr = ln.Addr()
}
return &Mux{
ln: ln,
addr: addr,
m: make(map[byte]*listener),
Timeout: DefaultTimeout,
Logger: log.New(os.Stderr, "[mux] ", log.LstdFlags),
}, nil
}
// NewTLSMux returns a new instance of Mux for ln, and encrypts all traffic
// using TLS. If adv is nil, then the addr of ln is used.
func NewTLSMux(ln net.Listener, adv net.Addr, cert, key, caCert string) (*Mux, error) {
mux, err := NewMux(ln, adv)
if err != nil {
return nil, err
}
mux.tlsConfig, err = createTLSConfig(cert, key, caCert)
if err != nil {
return nil, err
}
mux.ln = tls.NewListener(ln, mux.tlsConfig)
mux.remoteEncrypted = true
mux.nodeX509CACert = caCert
mux.nodeX509Cert = cert
mux.nodeX509Key = key
return mux, nil
}
// Serve handles connections from ln and multiplexes then across registered listener.
func (mux *Mux) Serve() error {
mux.Logger.Printf("mux serving on %s, advertising %s", mux.ln.Addr().String(), mux.addr)
for {
// Wait for the next connection.
// If it returns a temporary error then simply retry.
// If it returns any other error then exit immediately.
conn, err := mux.ln.Accept()
if err, ok := err.(interface {
Temporary() bool
}); ok && err.Temporary() {
continue
}
if err != nil {
// Wait for all connections to be demuxed
mux.wg.Wait()
for _, ln := range mux.m {
close(ln.c)
}
return err
}
// Demux in a goroutine to
mux.wg.Add(1)
go mux.handleConn(conn)
}
}
// Stats returns status of the mux.
func (mux *Mux) Stats() (interface{}, error) {
s := map[string]string{
"addr": mux.addr.String(),
"timeout": mux.Timeout.String(),
"encrypted": strconv.FormatBool(mux.remoteEncrypted),
}
if mux.remoteEncrypted {
s["certificate"] = mux.nodeX509Cert
s["key"] = mux.nodeX509Key
s["ca_certificate"] = mux.nodeX509CACert
s["skip_verify"] = strconv.FormatBool(mux.InsecureSkipVerify)
}
return s, nil
}
func (mux *Mux) handleConn(conn net.Conn) {
defer mux.wg.Done()
// Set a read deadline so connections with no data don't timeout.
if err := conn.SetReadDeadline(time.Now().Add(mux.Timeout)); err != nil {
conn.Close()
mux.Logger.Printf("tcp.Mux: cannot set read deadline: %s", err)
return
}
// Read first byte from connection to determine handler.
var typ [1]byte
if _, err := io.ReadFull(conn, typ[:]); err != nil {
conn.Close()
mux.Logger.Printf("tcp.Mux: cannot read header byte: %s", err)
return
}
// Reset read deadline and let the listener handle that.
if err := conn.SetReadDeadline(time.Time{}); err != nil {
conn.Close()
mux.Logger.Printf("tcp.Mux: cannot reset set read deadline: %s", err)
return
}
// Retrieve handler based on first byte.
handler := mux.m[typ[0]]
if handler == nil {
conn.Close()
mux.Logger.Printf("tcp.Mux: handler not registered: %d", typ[0])
return
}
// Send connection to handler. The handler is responsible for closing the connection.
handler.c <- conn
}
// Listen returns a listener identified by header.
// Any connection accepted by mux is multiplexed based on the initial header byte.
func (mux *Mux) Listen(header byte) *Layer {
// Ensure two listeners are not created for the same header byte.
if _, ok := mux.m[header]; ok {
panic(fmt.Sprintf("listener already registered under header byte: %d", header))
}
// Create a new listener and assign it.
ln := &listener{
c: make(chan net.Conn),
}
mux.m[header] = ln
layer := &Layer{
ln: ln,
header: header,
addr: mux.addr,
remoteEncrypted: mux.remoteEncrypted,
skipVerify: mux.InsecureSkipVerify,
nodeX509CACert: mux.nodeX509CACert,
tlsConfig: mux.tlsConfig,
}
return layer
}
// listener is a receiver for connections received by Mux.
type listener struct {
c chan net.Conn
}
// Accept waits for and returns the next connection to the listener.
func (ln *listener) Accept() (c net.Conn, err error) {
conn, ok := <-ln.c
if !ok {
return nil, errors.New("network connection closed")
}
return conn, nil
}
// Close is a no-op. The mux's listener should be closed instead.
func (ln *listener) Close() error { return nil }
// Addr always returns nil
func (ln *listener) Addr() net.Addr { return nil }
// newTLSListener returns a net listener which encrypts the traffic using TLS.
func newTLSListener(ln net.Listener, certFile, keyFile, caCertFile string) (net.Listener, error) {
config, err := createTLSConfig(certFile, keyFile, caCertFile)
if err != nil {
return nil, err
}
return tls.NewListener(ln, config), nil
}
// createTLSConfig returns a TLS config from the given cert and key.
func createTLSConfig(certFile, keyFile, caCertFile string) (*tls.Config, error) {
var err error
config := &tls.Config{}
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
if caCertFile != "" {
asn1Data, err := ioutil.ReadFile(caCertFile)
if err != nil {
return nil, err
}
config.RootCAs = x509.NewCertPool()
ok := config.RootCAs.AppendCertsFromPEM([]byte(asn1Data))
if !ok {
return nil, fmt.Errorf("failed to parse root certificate(s) in %s", caCertFile)
}
}
return config, nil
}

@ -1,235 +0,0 @@
package tcp
import (
"bytes"
"crypto/tls"
"io"
"io/ioutil"
"log"
"net"
"os"
"strings"
"sync"
"testing"
"testing/quick"
"time"
"github.com/rqlite/rqlite/testdata/x509"
)
// Ensure the muxer can split a listener's connections across multiple listeners.
func TestMux(t *testing.T) {
if err := quick.Check(func(n uint8, msg []byte) bool {
if testing.Verbose() {
if len(msg) == 0 {
log.Printf("n=%d, <no message>", n)
} else {
log.Printf("n=%d, hdr=%d, len=%d", n, msg[0], len(msg))
}
}
var wg sync.WaitGroup
// Open single listener on random port.
tcpListener := mustTCPListener("127.0.0.1:0")
defer tcpListener.Close()
// Setup muxer & listeners.
mux, err := NewMux(tcpListener, nil)
if err != nil {
t.Fatalf("failed to create mux: %s", err.Error())
}
mux.Timeout = 200 * time.Millisecond
if !testing.Verbose() {
mux.Logger = log.New(ioutil.Discard, "", 0)
}
for i := uint8(0); i < n; i++ {
ln := mux.Listen(byte(i))
wg.Add(1)
go func(i uint8, ln net.Listener) {
defer wg.Done()
// Wait for a connection for this listener.
conn, err := ln.Accept()
if conn != nil {
defer conn.Close()
}
// If there is no message or the header byte
// doesn't match then expect close.
if len(msg) == 0 || msg[0] != byte(i) {
if err == nil || err.Error() != "network connection closed" {
t.Fatalf("unexpected error: %s", err)
}
return
}
// If the header byte matches this listener
// then expect a connection and read the message.
var buf bytes.Buffer
if _, err := io.CopyN(&buf, conn, int64(len(msg)-1)); err != nil {
t.Fatal(err)
} else if !bytes.Equal(msg[1:], buf.Bytes()) {
t.Fatalf("message mismatch:\n\nexp=%x\n\ngot=%x\n\n", msg[1:], buf.Bytes())
}
// Write response.
if _, err := conn.Write([]byte("OK")); err != nil {
t.Fatal(err)
}
}(i, ln)
}
// Begin serving from the listener.
go mux.Serve()
// Write message to TCP listener and read OK response.
conn, err := net.Dial("tcp", tcpListener.Addr().String())
if err != nil {
t.Fatal(err)
} else if _, err = conn.Write(msg); err != nil {
t.Fatal(err)
}
// Read the response into the buffer.
var resp [2]byte
_, err = io.ReadFull(conn, resp[:])
// If the message header is less than n then expect a response.
// Otherwise we should get an EOF because the mux closed.
if len(msg) > 0 && uint8(msg[0]) < n {
if string(resp[:]) != `OK` {
t.Fatalf("unexpected response: %s", resp[:])
}
} else {
if err == nil || (err != io.EOF && !(strings.Contains(err.Error(), "connection reset by peer") ||
strings.Contains(err.Error(), "closed by the remote host"))) {
t.Fatalf("unexpected error: %s", err)
}
}
// Close connection.
if err := conn.Close(); err != nil {
t.Fatal(err)
}
// Close original TCP listener and wait for all goroutines to close.
tcpListener.Close()
wg.Wait()
return true
}, nil); err != nil {
t.Error(err)
}
}
func TestMux_Advertise(t *testing.T) {
// Setup muxer.
tcpListener := mustTCPListener("127.0.0.1:0")
defer tcpListener.Close()
addr := &mockAddr{
Nwk: "tcp",
Addr: "rqlite.com:8081",
}
mux, err := NewMux(tcpListener, addr)
if err != nil {
t.Fatalf("failed to create mux: %s", err.Error())
}
mux.Timeout = 200 * time.Millisecond
if !testing.Verbose() {
mux.Logger = log.New(ioutil.Discard, "", 0)
}
layer := mux.Listen(1)
if layer.Addr().String() != addr.Addr {
t.Fatalf("layer advertise address not correct, exp %s, got %s",
layer.Addr().String(), addr.Addr)
}
}
// Ensure two handlers cannot be registered for the same header byte.
func TestMux_Listen_ErrAlreadyRegistered(t *testing.T) {
defer func() {
if r := recover(); r != `listener already registered under header byte: 5` {
t.Fatalf("unexpected recover: %#v", r)
}
}()
// Register two listeners with the same header byte.
tcpListener := mustTCPListener("127.0.0.1:0")
mux, err := NewMux(tcpListener, nil)
if err != nil {
t.Fatalf("failed to create mux: %s", err.Error())
}
mux.Listen(5)
mux.Listen(5)
}
func TestTLSMux(t *testing.T) {
tcpListener := mustTCPListener("127.0.0.1:0")
defer tcpListener.Close()
cert := x509.CertFile()
defer os.Remove(cert)
key := x509.KeyFile()
defer os.Remove(key)
mux, err := NewTLSMux(tcpListener, nil, cert, key, "")
if err != nil {
t.Fatalf("failed to create mux: %s", err.Error())
}
go mux.Serve()
// Verify that the listener is secured.
conn, err := tls.Dial("tcp", tcpListener.Addr().String(), &tls.Config{
InsecureSkipVerify: true,
})
if err != nil {
t.Fatal(err)
}
state := conn.ConnectionState()
if !state.HandshakeComplete {
t.Fatal("connection handshake failed to complete")
}
}
func TestTLSMux_Fail(t *testing.T) {
tcpListener := mustTCPListener("127.0.0.1:0")
defer tcpListener.Close()
cert := x509.CertFile()
defer os.Remove(cert)
key := x509.KeyFile()
defer os.Remove(key)
_, err := NewTLSMux(tcpListener, nil, "xxxx", "yyyy", "")
if err == nil {
t.Fatalf("created mux unexpectedly with bad resources")
}
}
type mockAddr struct {
Nwk string
Addr string
}
func (m *mockAddr) Network() string {
return m.Nwk
}
func (m *mockAddr) String() string {
return m.Addr
}
// mustTCPListener returns a listener on bind, or panics.
func mustTCPListener(bind string) net.Listener {
l, err := net.Listen("tcp", bind)
if err != nil {
panic(err)
}
return l
}

@ -0,0 +1,101 @@
package tcp
import (
"crypto/tls"
"fmt"
"net"
"time"
)
// Transport is the network layer for internode communications.
type Transport struct {
ln net.Listener
certFile string // Path to local X.509 cert.
certKey string // Path to corresponding X.509 key.
remoteEncrypted bool // Remote nodes use encrypted communication.
skipVerify bool // Skip verification of remote node certs.
}
// NewTransport returns an initialized unecrypted Transport.
func NewTransport() *Transport {
return &Transport{}
}
// NewTransport returns an initialized TLS-ecrypted Transport.
func NewTLSTransport(certFile, keyPath string, skipVerify bool) *Transport {
return &Transport{
certFile: certFile,
certKey: keyPath,
remoteEncrypted: true,
skipVerify: skipVerify,
}
}
// Open opens the transport, binding to the supplied address.
func (t *Transport) Open(addr string) error {
ln, err := net.Listen("tcp", addr)
if err != nil {
return err
}
if t.certFile != "" {
config, err := createTLSConfig(t.certFile, t.certKey)
if err != nil {
return err
}
ln = tls.NewListener(ln, config)
}
t.ln = ln
return nil
}
// Dial opens a network connection.
func (t *Transport) Dial(addr string, timeout time.Duration) (net.Conn, error) {
dialer := &net.Dialer{Timeout: timeout}
var err error
var conn net.Conn
if t.remoteEncrypted {
conf := &tls.Config{
InsecureSkipVerify: t.skipVerify,
}
fmt.Println("doing a TLS dial")
conn, err = tls.DialWithDialer(dialer, "tcp", addr, conf)
} else {
conn, err = dialer.Dial("tcp", addr)
}
return conn, err
}
// Accept waits for the next connection.
func (t *Transport) Accept() (net.Conn, error) {
c, err := t.ln.Accept()
if err != nil {
fmt.Println("error accepting: ", err.Error())
}
return c, err
}
// Close closes the transport
func (t *Transport) Close() error {
return t.ln.Close()
}
// Addr returns the binding address of the transport.
func (t *Transport) Addr() net.Addr {
return t.ln.Addr()
}
// createTLSConfig returns a TLS config from the given cert and key.
func createTLSConfig(certFile, keyFile string) (*tls.Config, error) {
var err error
config := &tls.Config{}
config.Certificates = make([]tls.Certificate, 1)
config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
if err != nil {
return nil, err
}
return config, nil
}

@ -0,0 +1,70 @@
package tcp
import (
"os"
"testing"
"time"
"github.com/rqlite/rqlite/testdata/x509"
)
func Test_NewTransport(t *testing.T) {
if NewTransport() == nil {
t.Fatal("failed to create new Transport")
}
}
func Test_TransportOpenClose(t *testing.T) {
tn := NewTransport()
if err := tn.Open("localhost:0"); err != nil {
t.Fatalf("failed to open transport: %s", err.Error())
}
if tn.Addr().String() == "localhost:0" {
t.Fatalf("transport address set incorrectly, got: %s", tn.Addr().String())
}
if err := tn.Close(); err != nil {
t.Fatalf("failed to close transport: %s", err.Error())
}
}
func Test_TransportDial(t *testing.T) {
tn1 := NewTransport()
tn1.Open("localhost:0")
go tn1.Accept()
tn2 := NewTransport()
_, err := tn2.Dial(tn1.Addr().String(), time.Second)
if err != nil {
t.Fatalf("failed to connect to first transport: %s", err.Error())
}
tn1.Close()
}
func Test_NewTLSTransport(t *testing.T) {
c := x509.CertFile()
defer os.Remove(c)
k := x509.KeyFile()
defer os.Remove(k)
if NewTLSTransport(c, k, true) == nil {
t.Fatal("failed to create new TLS Transport")
}
}
func Test_TLSTransportOpenClose(t *testing.T) {
c := x509.CertFile()
defer os.Remove(c)
k := x509.KeyFile()
defer os.Remove(k)
tn := NewTLSTransport(c, k, true)
if err := tn.Open("localhost:0"); err != nil {
t.Fatalf("failed to open TLS transport: %s", err.Error())
}
if tn.Addr().String() == "localhost:0" {
t.Fatalf("TLS transport address set incorrectly, got: %s", tn.Addr().String())
}
if err := tn.Close(); err != nil {
t.Fatalf("failed to close TLS transport: %s", err.Error())
}
}
Loading…
Cancel
Save