1
0
Fork 0

Merge pull request #605 from rqlite/broadcast_meta

Broadcast Store meta via standard consensus
master
Philip O'Toole 5 years ago committed by GitHub
commit f9388f8344
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -57,7 +57,7 @@ You can grow a cluster, at anytime, simply by starting up a new node and having
# Removing or replacing a node
If a node fails completely and is not coming back, or if you shut down a node because you wish to deprovision it, its record should also be removed from the cluster. To remove the record of a node from a cluster, execute the following command:
```
curl -XDELETE http://localhost:4001/remove -d '{"addr": "<node raft address>"}'
curl -XDELETE http://localhost:4001/remove -d '{"id": "<node raft ID>"}'
```
assuming `localhost` is the address of the cluster leader. If you do not do this, the leader will continually attempt to communicate with that node.

@ -20,9 +20,9 @@ const numAttempts int = 3
const attemptInterval time.Duration = 5 * time.Second
// It walks through joinAddr in order, and sets the node ID and Raft address of
// the joining node as nodeID advAddr respectively. It returns the endpoint
// successfully used to join the cluster.
func Join(joinAddr []string, nodeID, advAddr string, tlsConfig *tls.Config) (string, error) {
// the joining node as id addr respectively. It returns the endpoint successfully
// used to join the cluster.
func Join(joinAddr []string, id, addr string, meta map[string]string, tlsConfig *tls.Config) (string, error) {
var err error
var j string
logger := log.New(os.Stderr, "[cluster-join] ", log.LstdFlags)
@ -32,7 +32,7 @@ func Join(joinAddr []string, nodeID, advAddr string, tlsConfig *tls.Config) (str
for i := 0; i < numAttempts; i++ {
for _, a := range joinAddr {
j, err = join(a, nodeID, advAddr, tlsConfig, logger)
j, err = join(a, id, addr, meta, tlsConfig, logger)
if err == nil {
// Success!
return j, nil
@ -45,13 +45,13 @@ func Join(joinAddr []string, nodeID, advAddr string, tlsConfig *tls.Config) (str
return "", err
}
func join(joinAddr, nodeID, advAddr string, tlsConfig *tls.Config, logger *log.Logger) (string, error) {
if nodeID == "" {
func join(joinAddr, id, addr string, meta map[string]string, tlsConfig *tls.Config, logger *log.Logger) (string, error) {
if id == "" {
return "", fmt.Errorf("node ID not set")
}
// Join using IP address, as that is what Hashicorp Raft works in.
resv, err := net.ResolveTCPAddr("tcp", advAddr)
resv, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return "", err
}
@ -59,7 +59,7 @@ func join(joinAddr, nodeID, advAddr string, tlsConfig *tls.Config, logger *log.L
// Check for protocol scheme, and insert default if necessary.
fullAddr := httpd.NormalizeAddr(fmt.Sprintf("%s/join", joinAddr))
// Enable skipVerify as requested.
// Create and configure the client to connect to the other node.
tr := &http.Transport{
TLSClientConfig: tlsConfig,
}
@ -69,9 +69,10 @@ func join(joinAddr, nodeID, advAddr string, tlsConfig *tls.Config, logger *log.L
}
for {
b, err := json.Marshal(map[string]string{
"id": nodeID,
b, err := json.Marshal(map[string]interface{}{
"id": id,
"addr": resv.String(),
"meta": meta,
})
// Attempt to join.

@ -1,7 +1,9 @@
package cluster
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"testing"
@ -16,7 +18,7 @@ func Test_SingleJoinOK(t *testing.T) {
}))
defer ts.Close()
j, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", nil)
j, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", nil, nil)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
@ -25,13 +27,57 @@ func Test_SingleJoinOK(t *testing.T) {
}
}
// Test_SingleJoinMetaOK verifies that a join request carries the node ID,
// Raft address, and metadata map in its JSON body, and that Join reports
// the endpoint it used.
func Test_SingleJoinMetaOK(t *testing.T) {
	var body map[string]interface{}
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		if r.Method != "POST" {
			// Fatalf must only be called from the test goroutine; this
			// handler runs on the server's goroutine, so record the
			// failure and reply with an error status instead.
			t.Errorf("Client did not use POST")
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		b, err := ioutil.ReadAll(r.Body)
		if err != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		if err := json.Unmarshal(b, &body); err != nil {
			w.WriteHeader(http.StatusBadRequest)
			return
		}
		// Write the status only after validation; the original wrote 200
		// first, making the later error WriteHeader calls no-ops.
		w.WriteHeader(http.StatusOK)
	}))
	defer ts.Close()

	nodeAddr := "127.0.0.1:9090"
	md := map[string]string{"foo": "bar"}
	j, err := Join([]string{ts.URL}, "id0", nodeAddr, md, nil)
	if err != nil {
		t.Fatalf("failed to join a single node: %s", err.Error())
	}
	if j != ts.URL+"/join" {
		t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts.URL)
	}

	if id, ok := body["id"]; !ok || id != "id0" {
		t.Fatalf("node joined supplying wrong ID, exp %s, got %s", "id0", body["id"])
	}
	if addr, ok := body["addr"]; !ok || addr != nodeAddr {
		t.Fatalf("node joined supplying wrong address, exp %s, got %s", nodeAddr, body["addr"])
	}
	rxMd, _ := body["meta"].(map[string]interface{})
	if len(rxMd) != len(md) || rxMd["foo"] != "bar" {
		t.Fatalf("node joined supplying wrong meta")
	}
}
func Test_SingleJoinFail(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}))
defer ts.Close()
_, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", nil)
_, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", nil, nil)
if err == nil {
t.Fatalf("expected error when joining bad node")
}
@ -45,7 +91,7 @@ func Test_DoubleJoinOK(t *testing.T) {
}))
defer ts2.Close()
j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", nil)
j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", nil, nil)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
@ -63,7 +109,7 @@ func Test_DoubleJoinOKSecondNode(t *testing.T) {
}))
defer ts2.Close()
j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", nil)
j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", nil, nil)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
@ -83,7 +129,7 @@ func Test_DoubleJoinOKSecondNodeRedirect(t *testing.T) {
}))
defer ts2.Close()
j, err := Join([]string{ts2.URL}, "id0", "127.0.0.1:9090", nil)
j, err := Join([]string{ts2.URL}, "id0", "127.0.0.1:9090", nil, nil)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}

@ -1,180 +0,0 @@
package cluster
import (
"encoding/json"
"fmt"
"log"
"net"
"os"
"sync"
"time"
)
const (
connectionTimeout = 10 * time.Second
)
var respOKMarshalled []byte
// init pre-marshals the canonical OK response so the serving code can
// write it without re-encoding on every connection.
func init() {
	b, err := json.Marshal(response{})
	if err != nil {
		panic(fmt.Sprintf("unable to JSON marshal OK response: %s", err.Error()))
	}
	respOKMarshalled = b
}
// response is the JSON-serializable reply sent back to a remote node after
// it submits a peers update. A zero Code means success; Message carries the
// error text otherwise.
type response struct {
	Code    int    `json:"code,omitempty"`
	Message string `json:"message,omitempty"`
}
// Transport is the interface the network service must provide. It is a
// net.Listener for inbound connections, plus a dialer for outbound ones.
type Transport interface {
	net.Listener

	// Dial is used to create a new outgoing connection.
	Dial(address string, timeout time.Duration) (net.Conn, error)
}
// Store represents a store of information, managed via consensus.
type Store interface {
	// LeaderAddr returns the address of the leader of the consensus system.
	LeaderAddr() string

	// UpdateAPIPeers updates the API peers on the store.
	UpdateAPIPeers(peers map[string]string) error
}
// Service allows access to the cluster and associated meta data,
// via consensus.
type Service struct {
	tn     Transport      // Network layer used to accept and dial connections.
	store  Store          // Consensus store holding the peers map.
	addr   net.Addr       // Listen address, captured at construction time.
	wg     sync.WaitGroup // Tracks the accept loop for clean shutdown.
	logger *log.Logger
}
// NewService returns a new instance of the cluster service, using tn for
// all network traffic and store for consensus-managed state.
func NewService(tn Transport, store Store) *Service {
	s := Service{
		tn:     tn,
		store:  store,
		addr:   tn.Addr(),
		logger: log.New(os.Stderr, "[cluster] ", log.LstdFlags),
	}
	return &s
}
// Open opens the Service, starting the background accept loop.
func (s *Service) Open() error {
	s.logger.Println("service listening on", s.tn.Addr())
	s.wg.Add(1)
	go s.serve()
	return nil
}
// Close closes the service: it shuts the transport, waits for the accept
// loop to exit, and returns any error from closing the transport (the
// original silently discarded it).
func (s *Service) Close() error {
	// Closing the listener unblocks Accept in serve, so Wait terminates.
	err := s.tn.Close()
	s.wg.Wait()
	return err
}
// Addr returns the address the service is listening on, as captured from
// the transport when the service was created.
func (s *Service) Addr() string {
	return s.addr.String()
}
// SetPeer will set the mapping between raftAddr and apiAddr for the entire cluster.
//
// It first tries the update on the local store, which succeeds if this node
// is the leader. Otherwise the update is forwarded to the leader over the
// network and the response is checked.
func (s *Service) SetPeer(raftAddr, apiAddr string) error {
	peer := map[string]string{
		raftAddr: apiAddr,
	}

	// Try the local store. It might be the leader.
	if err := s.store.UpdateAPIPeers(peer); err == nil {
		// All done! Aren't we lucky?
		return nil
	}

	// Try talking to the leader over the network. Capture the leader
	// address once so the address checked is the address dialed (the
	// original called LeaderAddr twice, racing a leadership change).
	leader := s.store.LeaderAddr()
	if leader == "" {
		return fmt.Errorf("no leader available")
	}
	conn, err := s.tn.Dial(leader, connectionTimeout)
	if err != nil {
		return err
	}
	defer conn.Close()

	b, err := json.Marshal(peer)
	if err != nil {
		return err
	}
	if _, err := conn.Write(b); err != nil {
		return err
	}

	// Wait for the response and verify the operation went through.
	resp := response{}
	if err := json.NewDecoder(conn).Decode(&resp); err != nil {
		return err
	}
	if resp.Code != 0 {
		// Use an explicit verb so the remote message cannot be interpreted
		// as a format string (go vet: non-constant format string).
		return fmt.Errorf("%s", resp.Message)
	}
	return nil
}
// serve accepts connections from remote nodes and hands each off to its own
// goroutine. It returns when Accept fails, which happens once the transport
// is closed by Close.
func (s *Service) serve() error {
	defer s.wg.Done()
	for {
		conn, err := s.tn.Accept()
		if err != nil {
			return err
		}
		go s.handleConn(conn)
	}
}
// handleConn services one connection from a remote node. The only message
// currently understood is a peers update: a JSON map from Raft address to
// API address. A response is written back indicating success or failure.
func (s *Service) handleConn(conn net.Conn) {
	// Always release the connection. The original closed it only on write
	// failures, leaking it on every successful exchange and decode error.
	defer conn.Close()

	s.logger.Printf("received connection from %s", conn.RemoteAddr().String())

	// Only handles peers updates for now.
	peers := make(map[string]string)
	if err := json.NewDecoder(conn).Decode(&peers); err != nil {
		return
	}

	// Update the peers.
	if err := s.store.UpdateAPIPeers(peers); err != nil {
		resp := response{1, err.Error()}
		if b, merr := json.Marshal(resp); merr == nil {
			// Best effort; the deferred Close signals failure otherwise.
			conn.Write(b)
		}
		return
	}

	// Let the remote node know everything went OK.
	conn.Write(respOKMarshalled)
}

@ -1,194 +0,0 @@
package cluster
import (
"encoding/json"
"fmt"
"net"
"testing"
"time"
)
func Test_NewServiceOpenClose(t *testing.T) {
ml := mustNewMockTransport()
ms := &mockStore{}
s := NewService(ml, ms)
if s == nil {
t.Fatalf("failed to create cluster service")
}
if err := s.Open(); err != nil {
t.Fatalf("failed to open cluster service")
}
if err := s.Close(); err != nil {
t.Fatalf("failed to close cluster service")
}
}
// Test_SetAPIPeer checks that setting a peer mapping on a node able to
// apply it locally updates the mock store directly.
func Test_SetAPIPeer(t *testing.T) {
	const raftAddr, apiAddr = "localhost:4002", "localhost:4001"

	s, _, ms := mustNewOpenService()
	defer s.Close()

	if err := s.SetPeer(raftAddr, apiAddr); err != nil {
		t.Fatalf("failed to set peer: %s", err.Error())
	}
	if got := ms.peers[raftAddr]; got != apiAddr {
		t.Fatalf("peer not set correctly, exp %s, got %s", apiAddr, got)
	}
}
// Test_SetAPIPeerNetwork checks that a peers update sent directly over the
// network protocol is applied to the store and acknowledged.
func Test_SetAPIPeerNetwork(t *testing.T) {
	raftAddr, apiAddr := "localhost:4002", "localhost:4001"

	s, _, ms := mustNewOpenService()
	defer s.Close()

	raddr, err := net.ResolveTCPAddr("tcp", s.Addr())
	if err != nil {
		t.Fatalf("failed to resolve remote cluster service address: %s", err.Error())
	}
	conn, err := net.DialTCP("tcp4", nil, raddr)
	if err != nil {
		t.Fatalf("failed to connect to remote cluster service: %s", err.Error())
	}
	// Close the connection when done; the original leaked it.
	defer conn.Close()

	if _, err := conn.Write([]byte(fmt.Sprintf(`{"%s": "%s"}`, raftAddr, apiAddr))); err != nil {
		t.Fatalf("failed to write to remote cluster service: %s", err.Error())
	}

	resp := response{}
	d := json.NewDecoder(conn)
	err = d.Decode(&resp)
	if err != nil {
		t.Fatalf("failed to decode response: %s", err.Error())
	}
	if resp.Code != 0 {
		t.Fatalf("response code was non-zero")
	}
	if ms.peers[raftAddr] != apiAddr {
		t.Fatalf("peer not set correctly, exp %s, got %s", apiAddr, ms.peers[raftAddr])
	}
}
// Test_SetAPIPeerFailUpdate checks the fallback path: when the local store
// cannot apply the update, the service forwards the peers update to the
// leader over the network.
func Test_SetAPIPeerFailUpdate(t *testing.T) {
	raftAddr, apiAddr := "localhost:4002", "localhost:4001"

	s, _, ms := mustNewOpenService()
	defer s.Close()
	ms.failUpdateAPIPeers = true

	// Attempt to set peer without a leader.
	if err := s.SetPeer(raftAddr, apiAddr); err == nil {
		t.Fatalf("no error returned by set peer when no leader")
	}

	// Start a network server standing in for the leader.
	tn, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		t.Fatalf("failed to open test server: %s", err.Error())
	}
	defer tn.Close() // Original never closed this listener.
	ms.leader = tn.Addr().String()

	c := make(chan map[string]string, 1)
	go func() {
		// This runs off the test goroutine, so failures must be recorded
		// with Errorf; Fatalf may only be called from the test goroutine.
		conn, err := tn.Accept()
		if err != nil {
			t.Errorf("failed to accept connection from cluster: %s", err.Error())
			return
		}
		t.Logf("test server received connection from: %s", conn.RemoteAddr())

		peers := make(map[string]string)
		d := json.NewDecoder(conn)
		if err := d.Decode(&peers); err != nil {
			t.Errorf("failed to decode message from cluster: %s", err.Error())
			return
		}

		// Let the remote node know everything went OK.
		if _, err := conn.Write(respOKMarshalled); err != nil {
			t.Errorf("failed to respond to cluster: %s", err.Error())
			return
		}
		c <- peers
	}()

	if err := s.SetPeer(raftAddr, apiAddr); err != nil {
		t.Fatalf("failed to set peer on cluster: %s", err.Error())
	}
	peers := <-c
	if peers[raftAddr] != apiAddr {
		// Report the map actually checked; the original printed ms.peers,
		// which the forced failure leaves unset.
		t.Fatalf("peer not set correctly, exp %s, got %s", apiAddr, peers[raftAddr])
	}
}
// mustNewOpenService creates and opens a Service backed by a mock transport
// and mock store, panicking on failure. It returns the service plus both
// mocks so tests can inspect them.
func mustNewOpenService() (*Service, *mockTransport, *mockStore) {
	ml, ms := mustNewMockTransport(), newMockStore()
	s := NewService(ml, ms)
	if err := s.Open(); err != nil {
		panic("failed to open new service")
	}
	return s, ml, ms
}
// mockTransport implements Transport by wrapping a real TCP listener bound
// to a loopback address.
type mockTransport struct {
	tn net.Listener
}
// mustNewMockTransport returns a mockTransport listening on an ephemeral
// loopback port, panicking if the listener cannot be created.
func mustNewMockTransport() *mockTransport {
	tn, err := net.Listen("tcp", "localhost:0")
	if err != nil {
		panic("failed to create mock listener")
	}
	return &mockTransport{tn}
}
// Accept waits for and returns the next connection on the wrapped listener.
func (ml *mockTransport) Accept() (c net.Conn, err error) {
	return ml.tn.Accept()
}
// Addr returns the address of the wrapped listener.
func (ml *mockTransport) Addr() net.Addr {
	return ml.tn.Addr()
}
// Close shuts down the wrapped listener, unblocking any pending Accept.
func (ml *mockTransport) Close() (err error) {
	return ml.tn.Close()
}
// Dial opens a TCP connection to addr, honoring the caller-supplied
// timeout. The original ignored t and hard-coded 5 seconds.
func (ml *mockTransport) Dial(addr string, t time.Duration) (net.Conn, error) {
	return net.DialTimeout("tcp", addr, t)
}
// mockStore implements Store for tests, recording peers updates and
// optionally forcing them to fail.
type mockStore struct {
	leader             string            // Address reported by LeaderAddr.
	peers              map[string]string // Accumulated peers updates.
	failUpdateAPIPeers bool              // When true, UpdateAPIPeers returns an error.
}
// newMockStore returns a mockStore ready to record peers updates.
func newMockStore() *mockStore {
	ms := new(mockStore)
	ms.peers = make(map[string]string)
	return ms
}
// LeaderAddr returns the leader address configured on the mock, which is
// empty until a test assigns it.
func (ms *mockStore) LeaderAddr() string {
	return ms.leader
}
// UpdateAPIPeers merges peers into the mock's map, or returns an error when
// a failure has been forced via failUpdateAPIPeers.
func (ms *mockStore) UpdateAPIPeers(peers map[string]string) error {
	if ms.failUpdateAPIPeers {
		return fmt.Errorf("forced fail")
	}
	for key, value := range peers {
		ms.peers[key] = value
	}
	return nil
}

@ -8,7 +8,6 @@ import (
"fmt"
"io/ioutil"
"log"
"net"
"os"
"os/signal"
"path/filepath"
@ -161,47 +160,28 @@ func main() {
// Start requested profiling.
startProfile(cpuProfile, memProfile)
// Set up node-to-node TCP communication.
ln, err := net.Listen("tcp", raftAddr)
if err != nil {
log.Fatalf("failed to listen on %s: %s", raftAddr, err.Error())
}
var adv net.Addr
if raftAdv != "" {
adv, err = net.ResolveTCPAddr("tcp", raftAdv)
if err != nil {
log.Fatalf("failed to resolve advertise address %s: %s", raftAdv, err.Error())
}
}
// Start up node-to-node network mux.
var mux *tcp.Mux
// Create internode network layer.
var tn *tcp.Transport
if nodeEncrypt {
log.Printf("enabling node-to-node encryption with cert: %s, key: %s", nodeX509Cert, nodeX509Key)
mux, err = tcp.NewTLSMux(ln, adv, nodeX509Cert, nodeX509Key, nodeX509CACert)
tn = tcp.NewTLSTransport(nodeX509Cert, nodeX509Key, noVerify)
} else {
mux, err = tcp.NewMux(ln, adv)
tn = tcp.NewTransport()
}
if err != nil {
log.Fatalf("failed to create node-to-node mux: %s", err.Error())
if err := tn.Open(raftAddr); err != nil {
log.Fatalf("failed to open internode network layer: %s", err.Error())
}
mux.InsecureSkipVerify = noNodeVerify
go mux.Serve()
// Get transport for Raft communications.
raftTn := mux.Listen(muxRaftHeader)
// Create and open the store.
dataPath, err = filepath.Abs(dataPath)
dataPath, err := filepath.Abs(dataPath)
if err != nil {
log.Fatalf("failed to determine absolute data path: %s", err.Error())
}
dbConf := store.NewDBConfig(dsn, !onDisk)
str := store.New(&store.StoreConfig{
str := store.New(tn, &store.StoreConfig{
DBConf: dbConf,
Dir: dataPath,
Tn: raftTn,
ID: idOrRaftAddr(),
})
@ -220,10 +200,6 @@ func main() {
if err != nil {
log.Fatalf("failed to parse Raft apply timeout %s: %s", raftApplyTimeout, err.Error())
}
str.OpenTimeout, err = time.ParseDuration(raftOpenTimeout)
if err != nil {
log.Fatalf("failed to parse Raft open timeout %s: %s", raftOpenTimeout, err.Error())
}
// Determine join addresses, if necessary.
ja, err := store.JoinAllowed(dataPath)
@ -241,16 +217,18 @@ func main() {
log.Println("node is already member of cluster, skip determining join addresses")
}
// Now, open it.
// Now, open store.
if err := str.Open(len(joins) == 0); err != nil {
log.Fatalf("failed to open store: %s", err.Error())
}
// Create and configure cluster service.
tn := mux.Listen(muxMetaHeader)
cs := cluster.NewService(tn, str)
if err := cs.Open(); err != nil {
log.Fatalf("failed to open cluster service: %s", err.Error())
// Prepare metadata for join command.
apiAdv := httpAddr
if httpAdv != "" {
apiAdv = httpAdv
}
meta := map[string]string{
"api_addr": apiAdv,
}
// Execute any requested join operation.
@ -274,7 +252,7 @@ func main() {
}
}
if j, err := cluster.Join(joins, str.ID(), advAddr, &tlsConfig); err != nil {
if j, err := cluster.Join(joins, str.ID(), advAddr, meta, &tlsConfig); err != nil {
log.Fatalf("failed to join cluster at %s: %s", joins, err.Error())
} else {
log.Println("successfully joined cluster at", j)
@ -284,53 +262,26 @@ func main() {
log.Println("no join addresses set")
}
// Publish to the cluster the mapping between this Raft address and API address.
// The Raft layer broadcasts the resolved address, so use that as the key. But
// only set different HTTP advertise address if set.
apiAdv := httpAddr
if httpAdv != "" {
apiAdv = httpAdv
}
if err := publishAPIAddr(cs, raftTn.Addr().String(), apiAdv, publishPeerTimeout); err != nil {
log.Fatalf("failed to set peer for %s to %s: %s", raftAddr, httpAddr, err.Error())
}
log.Printf("set peer for %s to %s", raftTn.Addr().String(), apiAdv)
// Get the credential store.
credStr, err := credentialStore()
// Wait until the store is in full consensus.
openTimeout, err := time.ParseDuration(raftOpenTimeout)
if err != nil {
log.Fatalf("failed to get credential store: %s", err.Error())
log.Fatalf("failed to parse Raft open timeout %s: %s", raftOpenTimeout, err.Error())
}
str.WaitForLeader(openTimeout)
str.WaitForApplied(openTimeout)
// Create HTTP server and load authentication information if required.
var s *httpd.Service
if credStr != nil {
s = httpd.New(httpAddr, str, credStr)
} else {
s = httpd.New(httpAddr, str, nil)
// This may be a standalone server. In that case set its own metadata.
if err := str.SetMetadata(meta); err != nil && err != store.ErrNotLeader {
// Non-leader errors are OK, since metadata will then be set through
// consensus as a result of a join. All other errors indicate a problem.
log.Fatalf("failed to set store metadata: %s", err.Error())
}
s.CACertFile = x509CACert
s.CertFile = x509Cert
s.KeyFile = x509Key
s.Expvar = expvar
s.Pprof = pprofEnabled
s.BuildInfo = map[string]interface{}{
"commit": commit,
"branch": branch,
"version": version,
"build_time": buildtime,
}
if err := s.Start(); err != nil {
// Start the HTTP API server.
if err := startHTTPService(str); err != nil {
log.Fatalf("failed to start HTTP server: %s", err.Error())
}
// Register cross-component statuses.
if err := s.RegisterStatus("mux", mux); err != nil {
log.Fatalf("failed to register mux status: %s", err.Error())
}
// Block until signalled.
terminate := make(chan os.Signal, 1)
signal.Notify(terminate, os.Interrupt)
@ -373,25 +324,32 @@ func determineJoinAddresses() ([]string, error) {
return addrs, nil
}
func publishAPIAddr(c *cluster.Service, raftAddr, apiAddr string, t time.Duration) error {
tck := time.NewTicker(publishPeerDelay)
defer tck.Stop()
tmr := time.NewTimer(t)
defer tmr.Stop()
for {
select {
case <-tck.C:
if err := c.SetPeer(raftAddr, apiAddr); err != nil {
log.Printf("failed to set peer for %s to %s: %s (retrying)",
raftAddr, apiAddr, err.Error())
continue
}
return nil
case <-tmr.C:
return fmt.Errorf("set peer timeout expired")
}
func startHTTPService(str *store.Store) error {
// Get the credential store.
credStr, err := credentialStore()
if err != nil {
return err
}
// Create HTTP server and load authentication information if required.
var s *httpd.Service
if credStr != nil {
s = httpd.New(httpAddr, str, credStr)
} else {
s = httpd.New(httpAddr, str, nil)
}
s.CertFile = x509Cert
s.KeyFile = x509Key
s.Expvar = expvar
s.Pprof = pprofEnabled
s.BuildInfo = map[string]interface{}{
"commit": commit,
"branch": branch,
"version": version,
"build_time": buildtime,
}
return s.Start()
}
func credentialStore() (*auth.CredentialsStore, error) {

@ -44,16 +44,16 @@ type Store interface {
Query(qr *store.QueryRequest) ([]*sql.Rows, error)
// Join joins the node with the given ID, reachable at addr, to this node.
Join(id, addr string) error
Join(id, addr string, metadata map[string]string) error
// Remove removes the node, specified by addr, from the cluster.
Remove(addr string) error
// Remove removes the node, specified by id, from the cluster.
Remove(id string) error
// LeaderAddr returns the Raft address of the leader of the cluster.
LeaderAddr() string
// Metadata returns the value for the given node ID, for the given key.
Metadata(id, key string) string
// Peer returns the API peer for the given address
Peer(addr string) string
// Leader returns the Raft address of the leader of the cluster.
LeaderID() (string, error)
// Stats returns stats on the Store.
Stats() (map[string]interface{}, error)
@ -289,27 +289,35 @@ func (s *Service) handleJoin(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
return
}
m := map[string]string{}
if err := json.Unmarshal(b, &m); err != nil {
md := map[string]interface{}{}
if err := json.Unmarshal(b, &md); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
remoteID, ok := m["id"]
remoteID, ok := md["id"]
if !ok {
w.WriteHeader(http.StatusBadRequest)
return
}
var m map[string]string
if _, ok := md["meta"].(map[string]interface{}); ok {
m = make(map[string]string)
for k, v := range md["meta"].(map[string]interface{}) {
m[k] = v.(string)
}
}
remoteAddr, ok := m["addr"]
remoteAddr, ok := md["addr"]
if !ok {
fmt.Println("4444")
w.WriteHeader(http.StatusBadRequest)
return
}
if err := s.store.Join(remoteID, remoteAddr); err != nil {
if err := s.store.Join(remoteID.(string), remoteAddr.(string), m); err != nil {
if err == store.ErrNotLeader {
leader := s.store.Peer(s.store.LeaderAddr())
leader := s.leaderAPIAddr()
if leader == "" {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
@ -355,15 +363,15 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) {
return
}
remoteAddr, ok := m["addr"]
remoteID, ok := m["id"]
if !ok {
w.WriteHeader(http.StatusBadRequest)
return
}
if err := s.store.Remove(remoteAddr); err != nil {
if err := s.store.Remove(remoteID); err != nil {
if err == store.ErrNotLeader {
leader := s.store.Peer(s.store.LeaderAddr())
leader := s.leaderAPIAddr()
if leader == "" {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
@ -448,7 +456,7 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
results, err := s.store.ExecuteOrAbort(&store.ExecuteRequest{queries, timings, false})
if err != nil {
if err == store.ErrNotLeader {
leader := s.store.Peer(s.store.LeaderAddr())
leader := s.leaderAPIAddr()
if leader == "" {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
@ -498,7 +506,7 @@ func (s *Service) handleStatus(w http.ResponseWriter, r *http.Request) {
httpStatus := map[string]interface{}{
"addr": s.Addr().String(),
"auth": prettyEnabled(s.credentialStore != nil),
"redirect": s.store.Peer(s.store.LeaderAddr()),
"redirect": s.leaderAPIAddr(),
}
nodeStatus := map[string]interface{}{
@ -595,7 +603,7 @@ func (s *Service) handleExecute(w http.ResponseWriter, r *http.Request) {
results, err := s.store.Execute(&store.ExecuteRequest{queries, timings, isTx})
if err != nil {
if err == store.ErrNotLeader {
leader := s.store.Peer(s.store.LeaderAddr())
leader := s.leaderAPIAddr()
if leader == "" {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
@ -657,7 +665,7 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) {
results, err := s.store.Query(&store.QueryRequest{queries, timings, isTx, lvl})
if err != nil {
if err == store.ErrNotLeader {
leader := s.store.Peer(s.store.LeaderAddr())
leader := s.leaderAPIAddr()
if leader == "" {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
@ -746,6 +754,14 @@ func (s *Service) CheckRequestPerm(r *http.Request, perm string) bool {
return s.credentialStore.HasPerm(username, PermAll) || s.credentialStore.HasPerm(username, perm)
}
// leaderAPIAddr returns the HTTP API address of the cluster leader, or the
// empty string if the leader cannot be determined.
func (s *Service) leaderAPIAddr() string {
	if id, err := s.store.LeaderID(); err == nil {
		return s.store.Metadata(id, "api_addr")
	}
	return ""
}
// addBuildVersion adds the build version to the HTTP response.
func (s *Service) addBuildVersion(w http.ResponseWriter) {
// Add version header to every response, if available.

@ -491,19 +491,19 @@ func (m *MockStore) Query(qr *store.QueryRequest) ([]*sql.Rows, error) {
return nil, nil
}
func (m *MockStore) Join(id, addr string) error {
func (m *MockStore) Join(id, addr string, metadata map[string]string) error {
return nil
}
func (m *MockStore) Remove(addr string) error {
func (m *MockStore) Remove(id string) error {
return nil
}
func (m *MockStore) LeaderAddr() string {
return ""
func (m *MockStore) LeaderID() (string, error) {
return "", nil
}
func (m *MockStore) Peer(addr string) string {
func (m *MockStore) Metadata(id, key string) string {
return ""
}

@ -8,9 +8,10 @@ import (
type commandType int
const (
execute commandType = iota // Commands which modify the database.
query // Commands which query the database.
peer // Commands that modify peers map.
execute commandType = iota // Commands which modify the database.
query // Commands which query the database.
metadataSet // Commands which sets Store metadata
metadataDelete // Commands which deletes Store metadata
)
type command struct {
@ -27,7 +28,14 @@ func newCommand(t commandType, d interface{}) (*command, error) {
Typ: t,
Sub: b,
}, nil
}
// newMetadataSetCommand returns a metadataSet command carrying the given
// metadata for the node with Raft ID id.
func newMetadataSetCommand(id string, md map[string]string) (*command, error) {
	return newCommand(metadataSet, metadataSetSub{
		RaftID: id,
		Data:   md,
	})
}
// databaseSub is a command sub which involves interaction with the database.
@ -37,5 +45,7 @@ type databaseSub struct {
Timings bool `json:"timings,omitempty"`
}
// peersSub is a command which sets the API address for a Raft address.
type peersSub map[string]string
// metadataSetSub is a command sub which sets metadata for a single node,
// keyed by its Raft ID.
type metadataSetSub struct {
	RaftID string            `json:"raft_id,omitempty"`
	Data   map[string]string `json:"data,omitempty"`
}

@ -0,0 +1,12 @@
package store
// DBConfig represents the configuration of the underlying SQLite database.
type DBConfig struct {
	DSN    string // Any custom DSN
	Memory bool   // Whether the database is in-memory only.
}

// NewDBConfig returns a new DB config instance.
func NewDBConfig(dsn string, memory bool) *DBConfig {
	c := DBConfig{
		DSN:    dsn,
		Memory: memory,
	}
	return &c
}

@ -0,0 +1,14 @@
package store
// Server represents another node in the cluster.
type Server struct {
	ID   string `json:"id,omitempty"`
	Addr string `json:"addr,omitempty"`
}

// Servers is a set of Servers, sortable by node ID via sort.Interface.
type Servers []*Server

// Len returns the number of servers in the set.
func (s Servers) Len() int { return len(s) }

// Less reports whether server i orders before server j by node ID.
func (s Servers) Less(i, j int) bool { return s[i].ID < s[j].ID }

// Swap exchanges the servers at positions i and j.
func (s Servers) Swap(i, j int) { s[i], s[j] = s[j], s[i] }

@ -13,7 +13,6 @@ import (
"io"
"io/ioutil"
"log"
"net"
"os"
"path/filepath"
"sort"
@ -46,6 +45,9 @@ const (
sqliteFile = "db.sqlite"
leaderWaitDelay = 100 * time.Millisecond
appliedWaitDelay = 100 * time.Millisecond
connectionPoolCount = 5
connectionTimeout = 10 * time.Second
raftLogCacheSize = 512
)
const (
@ -114,60 +116,6 @@ const (
Unknown
)
// clusterMeta represents cluster meta which must be kept in consensus.
type clusterMeta struct {
APIPeers map[string]string // Map from Raft address to API address
}
// NewClusterMeta returns an initialized cluster meta store.
func newClusterMeta() *clusterMeta {
return &clusterMeta{
APIPeers: make(map[string]string),
}
}
func (c *clusterMeta) AddrForPeer(addr string) string {
if api, ok := c.APIPeers[addr]; ok && api != "" {
return api
}
// Go through each entry, and see if any key resolves to addr.
for k, v := range c.APIPeers {
resv, err := net.ResolveTCPAddr("tcp", k)
if err != nil {
continue
}
if resv.String() == addr {
return v
}
}
return ""
}
// DBConfig represents the configuration of the underlying SQLite database.
type DBConfig struct {
DSN string // Any custom DSN
Memory bool // Whether the database is in-memory only.
}
// NewDBConfig returns a new DB config instance.
func NewDBConfig(dsn string, memory bool) *DBConfig {
return &DBConfig{DSN: dsn, Memory: memory}
}
// Server represents another node in the cluster.
type Server struct {
ID string `json:"id,omitempty"`
Addr string `json:"addr,omitempty"`
}
type Servers []*Server
func (s Servers) Less(i, j int) bool { return s[i].ID < s[j].ID }
func (s Servers) Len() int { return len(s) }
func (s Servers) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
// Store is a SQLite database, where all changes are made via Raft consensus.
type Store struct {
raftDir string
@ -175,14 +123,19 @@ type Store struct {
mu sync.RWMutex // Sync access between queries and snapshots.
raft *raft.Raft // The consensus mechanism.
raftTn *raftTransport
ln Listener
raftTn *raft.NetworkTransport
raftID string // Node ID.
dbConf *DBConfig // SQLite database config.
dbPath string // Path to underlying SQLite file, if not in-memory.
db *sql.DB // The underlying SQLite store.
raftLog raft.LogStore // Persistent log store.
raftStable raft.StableStore // Persistent k-v store.
boltStore *raftboltdb.BoltStore // Physical store.
metaMu sync.RWMutex
meta *clusterMeta
meta map[string]map[string]string
logger *log.Logger
@ -192,7 +145,6 @@ type Store struct {
HeartbeatTimeout time.Duration
ElectionTimeout time.Duration
ApplyTimeout time.Duration
OpenTimeout time.Duration
}
// StoreConfig represents the configuration of the underlying Store.
@ -205,22 +157,21 @@ type StoreConfig struct {
}
// New returns a new Store.
func New(c *StoreConfig) *Store {
func New(ln Listener, c *StoreConfig) *Store {
logger := c.Logger
if logger == nil {
logger = log.New(os.Stderr, "[store] ", log.LstdFlags)
}
return &Store{
ln: ln,
raftDir: c.Dir,
raftTn: &raftTransport{c.Tn},
raftID: c.ID,
dbConf: c.DBConf,
dbPath: filepath.Join(c.Dir, sqliteFile),
meta: newClusterMeta(),
meta: make(map[string]map[string]string),
logger: logger,
ApplyTimeout: applyTimeout,
OpenTimeout: openTimeout,
}
}
@ -234,6 +185,7 @@ func (s *Store) Open(enableSingle bool) error {
return err
}
// Open underlying database.
db, err := s.open()
if err != nil {
return err
@ -243,14 +195,11 @@ func (s *Store) Open(enableSingle bool) error {
// Is this a brand new node?
newNode := !pathExists(filepath.Join(s.raftDir, "raft.db"))
// Setup Raft communication.
transport := raft.NewNetworkTransport(s.raftTn, 3, 10*time.Second, os.Stderr)
// Create Raft-compatible network layer.
s.raftTn = raft.NewNetworkTransport(NewTransport(s.ln), connectionPoolCount, connectionTimeout, nil)
// Get the Raft configuration for this store.
config := s.raftConfig()
config.LocalID = raft.ServerID(s.raftID)
// XXXconfig.Logger = log.New(os.Stderr, "[raft] ", log.LstdFlags)
// Create the snapshot store. This allows Raft to truncate the log.
snapshots, err := raft.NewFileSnapshotStore(s.raftDir, retainSnapshotCount, os.Stderr)
@ -259,13 +208,18 @@ func (s *Store) Open(enableSingle bool) error {
}
// Create the log store and stable store.
logStore, err := raftboltdb.NewBoltStore(filepath.Join(s.raftDir, "raft.db"))
s.boltStore, err = raftboltdb.NewBoltStore(filepath.Join(s.raftDir, "raft.db"))
if err != nil {
return fmt.Errorf("new bolt store: %s", err)
}
s.raftStable = s.boltStore
s.raftLog, err = raft.NewLogCache(raftLogCacheSize, s.boltStore)
if err != nil {
return fmt.Errorf("new cached store: %s", err)
}
// Instantiate the Raft system.
ra, err := raft.NewRaft(config, s, logStore, logStore, snapshots, transport)
ra, err := raft.NewRaft(config, s, s.raftLog, s.raftStable, snapshots, s.raftTn)
if err != nil {
return fmt.Errorf("new raft: %s", err)
}
@ -276,7 +230,7 @@ func (s *Store) Open(enableSingle bool) error {
Servers: []raft.Server{
raft.Server{
ID: config.LocalID,
Address: transport.LocalAddr(),
Address: s.raftTn.LocalAddr(),
},
},
}
@ -287,16 +241,6 @@ func (s *Store) Open(enableSingle bool) error {
s.raft = ra
if s.OpenTimeout != 0 {
// Wait until the initial logs are applied.
s.logger.Printf("waiting for up to %s for application of initial logs", s.OpenTimeout)
if err := s.WaitForAppliedIndex(s.raft.LastIndex(), s.OpenTimeout); err != nil {
return ErrOpenTimeout
}
} else {
s.logger.Println("not waiting for application of initial logs")
}
return nil
}
@ -314,6 +258,19 @@ func (s *Store) Close(wait bool) error {
return nil
}
// WaitForApplied waits for all Raft log entries to be applied to the
// underlying database. A timeout of zero disables waiting entirely and
// returns immediately with no error. If the entries are not applied
// within the timeout, ErrOpenTimeout is returned.
func (s *Store) WaitForApplied(timeout time.Duration) error {
	if timeout == 0 {
		return nil
	}
	s.logger.Printf("waiting for up to %s for application of initial logs", timeout)
	// Waiting up to the last known Raft index covers every entry that
	// existed when this call was made.
	if err := s.WaitForAppliedIndex(s.raft.LastIndex(), timeout); err != nil {
		return ErrOpenTimeout
	}
	return nil
}
// IsLeader is used to determine if the current node is cluster leader
func (s *Store) IsLeader() bool {
return s.raft.State() == raft.Leader
@ -342,8 +299,8 @@ func (s *Store) Path() string {
}
// Addr returns the address of the store.
func (s *Store) Addr() net.Addr {
return s.raftTn.Addr()
func (s *Store) Addr() string {
return string(s.raftTn.LocalAddr())
}
// ID returns the Raft ID of the store.
@ -375,24 +332,6 @@ func (s *Store) LeaderID() (string, error) {
return "", nil
}
// Peer returns the API address for the given addr. If there is no peer
// for the address, it returns the empty string.
func (s *Store) Peer(addr string) string {
return s.meta.AddrForPeer(addr)
}
// APIPeers return the map of Raft addresses to API addresses.
func (s *Store) APIPeers() (map[string]string, error) {
s.metaMu.RLock()
defer s.metaMu.RUnlock()
peers := make(map[string]string, len(s.meta.APIPeers))
for k, v := range s.meta.APIPeers {
peers[k] = v
}
return peers, nil
}
// Nodes returns the slice of nodes in the cluster, sorted by ID ascending.
func (s *Store) Nodes() ([]*Server, error) {
f := s.raft.GetConfiguration()
@ -480,18 +419,25 @@ func (s *Store) Stats() (map[string]interface{}, error) {
if err != nil {
return nil, err
}
leaderID, err := s.LeaderID()
if err != nil {
return nil, err
}
status := map[string]interface{}{
"node_id": s.raftID,
"raft": s.raft.Stats(),
"addr": s.Addr().String(),
"leader": s.LeaderAddr(),
"node_id": s.raftID,
"raft": s.raft.Stats(),
"addr": s.Addr(),
"leader": map[string]string{
"node_id": leaderID,
"addr": s.LeaderAddr(),
},
"apply_timeout": s.ApplyTimeout.String(),
"open_timeout": s.OpenTimeout.String(),
"heartbeat_timeout": s.HeartbeatTimeout.String(),
"election_timeout": s.ElectionTimeout.String(),
"snapshot_threshold": s.SnapshotThreshold,
"meta": s.meta,
"peers": nodes,
"metadata": s.meta,
"nodes": nodes,
"dir": s.raftDir,
"sqlite3": dbStatus,
"db_conf": s.dbConf,
@ -636,29 +582,38 @@ func (s *Store) Query(qr *QueryRequest) ([]*sql.Rows, error) {
return r, err
}
// UpdateAPIPeers updates the cluster-wide peer information.
func (s *Store) UpdateAPIPeers(peers map[string]string) error {
c, err := newCommand(peer, peers)
if err != nil {
return err
}
b, err := json.Marshal(c)
if err != nil {
return err
}
f := s.raft.Apply(b, s.ApplyTimeout)
return f.Error()
}
// Join joins a node, identified by id and located at addr, to this store.
// The node must be ready to respond to Raft communications at that address.
func (s *Store) Join(id, addr string) error {
func (s *Store) Join(id, addr string, metadata map[string]string) error {
s.logger.Printf("received request to join node at %s", addr)
if s.raft.State() != raft.Leader {
return ErrNotLeader
}
configFuture := s.raft.GetConfiguration()
if err := configFuture.Error(); err != nil {
s.logger.Printf("failed to get raft configuration: %v", err)
return err
}
for _, srv := range configFuture.Configuration().Servers {
// If a node already exists with either the joining node's ID or address,
// that node may need to be removed from the config first.
if srv.ID == raft.ServerID(id) || srv.Address == raft.ServerAddress(addr) {
// However if *both* the ID and the address are the same, the no
// join is actually needed.
if srv.Address == raft.ServerAddress(addr) && srv.ID == raft.ServerID(id) {
s.logger.Printf("node %s at %s already member of cluster, ignoring join request", id, addr)
return nil
}
if err := s.remove(id); err != nil {
s.logger.Printf("failed to remove node: %v", err)
return err
}
}
}
f := s.raft.AddVoter(raft.ServerID(id), raft.ServerAddress(addr), 0, 0)
if e := f.(raft.Future); e.Error() != nil {
if e.Error() == raft.ErrNotLeader {
@ -666,6 +621,11 @@ func (s *Store) Join(id, addr string) error {
}
return e.Error()
}
if err := s.setMetadata(id, metadata); err != nil {
return err
}
s.logger.Printf("node at %s joined successfully", addr)
return nil
}
@ -673,18 +633,74 @@ func (s *Store) Join(id, addr string) error {
// Remove removes a node from the store, specified by ID.
func (s *Store) Remove(id string) error {
s.logger.Printf("received request to remove node %s", id)
if s.raft.State() != raft.Leader {
return ErrNotLeader
if err := s.remove(id); err != nil {
s.logger.Printf("failed to remove node %s: %s", id, err.Error())
return err
}
f := s.raft.RemoveServer(raft.ServerID(id), 0, 0)
if f.Error() != nil {
if f.Error() == raft.ErrNotLeader {
s.logger.Printf("node %s removed successfully", id)
return nil
}
// Metadata returns the value for a given key, for a given node ID.
// An empty string is returned when either the node or the key is unknown.
func (s *Store) Metadata(id, key string) string {
	s.metaMu.RLock()
	defer s.metaMu.RUnlock()
	md, ok := s.meta[id]
	if !ok {
		return ""
	}
	// A missing key yields the map's zero value, which is "".
	return md[key]
}
// SetMetadata adds the metadata md to any existing metadata for
// this node. The update is replicated through the Raft log via
// setMetadata, so on a follower it will fail with ErrNotLeader.
func (s *Store) SetMetadata(md map[string]string) error {
	return s.setMetadata(s.raftID, md)
}
// setMetadata adds the metadata md to any existing metadata for
// the given node ID. It is a no-op if the node's existing metadata
// already contains every key-value pair in md. The change is written
// through the Raft log, so ErrNotLeader is returned on a follower.
func (s *Store) setMetadata(id string, md map[string]string) error {
	// Check local data first: if the pushed data is already present there
	// is nothing to do, and no consensus entry needs to be written.
	if func() bool {
		s.metaMu.RLock()
		defer s.metaMu.RUnlock()
		if _, ok := s.meta[id]; ok {
			for k, v := range md {
				if s.meta[id][k] != v {
					return false
				}
			}
			return true
		}
		return false
	}() {
		// Local data is same as data being pushed in, nothing to do.
		return nil
	}
	c, err := newMetadataSetCommand(id, md)
	if err != nil {
		return err
	}
	b, err := json.Marshal(c)
	if err != nil {
		return err
	}
	f := s.raft.Apply(b, s.ApplyTimeout)
	if e := f.(raft.Future); e.Error() != nil {
		if e.Error() == raft.ErrNotLeader {
			return ErrNotLeader
		}
		// Previously dead code discarded this error and a stale log
		// message claimed the node was "removed"; return the real error.
		return e.Error()
	}
	return nil
}
@ -712,6 +728,36 @@ func (s *Store) open() (*sql.DB, error) {
return db, nil
}
// remove removes the node, with the given ID, from the cluster. Once the
// node has been removed from the Raft configuration, its metadata is
// deleted via a consensus entry. ErrNotLeader is returned on a follower.
func (s *Store) remove(id string) error {
	if s.raft.State() != raft.Leader {
		return ErrNotLeader
	}
	f := s.raft.RemoveServer(raft.ServerID(id), 0, 0)
	if f.Error() != nil {
		if f.Error() == raft.ErrNotLeader {
			return ErrNotLeader
		}
		return f.Error()
	}
	// Delete the departed node's metadata through the Raft log so every
	// node drops its record.
	c, err := newCommand(metadataDelete, id)
	if err != nil {
		// Previously this error was silently shadowed by json.Marshal.
		return err
	}
	b, err := json.Marshal(c)
	if err != nil {
		return err
	}
	f = s.raft.Apply(b, s.ApplyTimeout)
	if e := f.(raft.Future); e.Error() != nil {
		if e.Error() == raft.ErrNotLeader {
			return ErrNotLeader
		}
		// Previously this error was evaluated but discarded, so callers
		// saw success even when the metadata delete failed.
		return e.Error()
	}
	return nil
}
// raftConfig returns a new Raft config for the store.
func (s *Store) raftConfig() *raft.Config {
config := raft.DefaultConfig()
@ -764,17 +810,31 @@ func (s *Store) Apply(l *raft.Log) interface{} {
}
r, err := s.db.Query(d.Queries, d.Tx, d.Timings)
return &fsmQueryResponse{rows: r, error: err}
case peer:
var d peersSub
case metadataSet:
var d metadataSetSub
if err := json.Unmarshal(c.Sub, &d); err != nil {
return &fsmGenericResponse{error: err}
}
func() {
s.metaMu.Lock()
defer s.metaMu.Unlock()
for k, v := range d {
s.meta.APIPeers[k] = v
if _, ok := s.meta[d.RaftID]; !ok {
s.meta[d.RaftID] = make(map[string]string)
}
for k, v := range d.Data {
s.meta[d.RaftID][k] = v
}
}()
return &fsmGenericResponse{}
case metadataDelete:
var d string
if err := json.Unmarshal(c.Sub, &d); err != nil {
return &fsmGenericResponse{error: err}
}
func() {
s.metaMu.Lock()
defer s.metaMu.Unlock()
delete(s.meta, d)
}()
return &fsmGenericResponse{}
default:

@ -6,7 +6,6 @@ import (
"net"
"os"
"path/filepath"
"reflect"
"sort"
"testing"
"time"
@ -14,35 +13,6 @@ import (
"github.com/rqlite/rqlite/testdata/chinook"
)
type mockSnapshotSink struct {
*os.File
}
func (m *mockSnapshotSink) ID() string {
return "1"
}
func (m *mockSnapshotSink) Cancel() error {
return nil
}
func Test_ClusterMeta(t *testing.T) {
c := newClusterMeta()
c.APIPeers["localhost:4002"] = "localhost:4001"
if c.AddrForPeer("localhost:4002") != "localhost:4001" {
t.Fatalf("wrong address returned for localhost:4002")
}
if c.AddrForPeer("127.0.0.1:4002") != "localhost:4001" {
t.Fatalf("wrong address returned for 127.0.0.1:4002")
}
if c.AddrForPeer("127.0.0.1:4004") != "" {
t.Fatalf("wrong address returned for 127.0.0.1:4003")
}
}
func Test_OpenStoreSingleNode(t *testing.T) {
s := mustNewStore(true)
defer os.RemoveAll(s.Path())
@ -52,7 +22,7 @@ func Test_OpenStoreSingleNode(t *testing.T) {
}
s.WaitForLeader(10 * time.Second)
if got, exp := s.LeaderAddr(), s.Addr().String(); got != exp {
if got, exp := s.LeaderAddr(), s.Addr(); got != exp {
t.Fatalf("wrong leader address returned, got: %s, exp %s", got, exp)
}
id, err := s.LeaderID()
@ -71,6 +41,7 @@ func Test_OpenStoreCloseSingleNode(t *testing.T) {
if err := s.Open(true); err != nil {
t.Fatalf("failed to open single-node store: %s", err.Error())
}
s.WaitForLeader(10 * time.Second)
if err := s.Close(true); err != nil {
t.Fatalf("failed to close single-node store: %s", err.Error())
}
@ -450,14 +421,14 @@ func Test_MultiNodeJoinRemove(t *testing.T) {
sort.StringSlice(storeNodes).Sort()
// Join the second node to the first.
if err := s0.Join(s1.ID(), s1.Addr().String()); err != nil {
t.Fatalf("failed to join to node at %s: %s", s0.Addr().String(), err.Error())
if err := s0.Join(s1.ID(), s1.Addr(), nil); err != nil {
t.Fatalf("failed to join to node at %s: %s", s0.Addr(), err.Error())
}
s1.WaitForLeader(10 * time.Second)
// Check leader state on follower.
if got, exp := s1.LeaderAddr(), s0.Addr().String(); got != exp {
if got, exp := s1.LeaderAddr(), s0.Addr(); got != exp {
t.Fatalf("wrong leader address returned, got: %s, exp %s", got, exp)
}
id, err := s1.LeaderID()
@ -514,8 +485,8 @@ func Test_MultiNodeExecuteQuery(t *testing.T) {
defer s1.Close(true)
// Join the second node to the first.
if err := s0.Join(s1.ID(), s1.Addr().String()); err != nil {
t.Fatalf("failed to join to node at %s: %s", s0.Addr().String(), err.Error())
if err := s0.Join(s1.ID(), s1.Addr(), nil); err != nil {
t.Fatalf("failed to join to node at %s: %s", s0.Addr(), err.Error())
}
queries := []string{
@ -609,7 +580,7 @@ func Test_StoreLogTruncationMultinode(t *testing.T) {
defer s1.Close(true)
// Join the second node to the first.
if err := s0.Join(s1.ID(), s1.Addr().String()); err != nil {
if err := s0.Join(s1.ID(), s1.Addr(), nil); err != nil {
t.Fatalf("failed to join to node at %s: %s", s0.Addr(), err.Error())
}
s1.WaitForLeader(10 * time.Second)
@ -754,38 +725,65 @@ func Test_SingleNodeSnapshotInMem(t *testing.T) {
}
}
func Test_APIPeers(t *testing.T) {
s := mustNewStore(false)
defer os.RemoveAll(s.Path())
if err := s.Open(true); err != nil {
func Test_MetadataMultinode(t *testing.T) {
s0 := mustNewStore(true)
if err := s0.Open(true); err != nil {
t.Fatalf("failed to open single-node store: %s", err.Error())
}
defer s.Close(true)
s.WaitForLeader(10 * time.Second)
defer s0.Close(true)
s0.WaitForLeader(10 * time.Second)
s1 := mustNewStore(true)
if err := s1.Open(true); err != nil {
t.Fatalf("failed to open single-node store: %s", err.Error())
}
defer s1.Close(true)
s1.WaitForLeader(10 * time.Second)
peers := map[string]string{
"localhost:4002": "localhost:4001",
"localhost:4004": "localhost:4003",
if s0.Metadata(s0.raftID, "foo") != "" {
t.Fatal("nonexistent metadata foo found")
}
if err := s.UpdateAPIPeers(peers); err != nil {
t.Fatalf("failed to update API peers: %s", err.Error())
if s0.Metadata("nonsense", "foo") != "" {
t.Fatal("nonexistent metadata foo found for nonexistent node")
}
// Retrieve peers and verify them.
apiPeers, err := s.APIPeers()
if err != nil {
t.Fatalf("failed to retrieve API peers: %s", err.Error())
if err := s0.SetMetadata(map[string]string{"foo": "bar"}); err != nil {
t.Fatalf("failed to set metadata: %s", err.Error())
}
if s0.Metadata(s0.raftID, "foo") != "bar" {
t.Fatal("key foo not found")
}
if !reflect.DeepEqual(peers, apiPeers) {
t.Fatalf("set and retrieved API peers not identical, got %v, exp %v",
apiPeers, peers)
if s0.Metadata("nonsense", "foo") != "" {
t.Fatal("nonexistent metadata foo found for nonexistent node")
}
if s.Peer("localhost:4002") != "localhost:4001" ||
s.Peer("localhost:4004") != "localhost:4003" ||
s.Peer("not exist") != "" {
t.Fatalf("failed to retrieve correct single API peer")
// Join the second node to the first.
meta := map[string]string{"baz": "qux"}
if err := s0.Join(s1.ID(), s1.Addr(), meta); err != nil {
t.Fatalf("failed to join to node at %s: %s", s0.Addr(), err.Error())
}
s1.WaitForLeader(10 * time.Second)
// Wait until the log entries have been applied to the follower,
// and then query.
if err := s1.WaitForAppliedIndex(5, 5*time.Second); err != nil {
t.Fatalf("error waiting for follower to apply index: %s:", err.Error())
}
if s1.Metadata(s0.raftID, "foo") != "bar" {
t.Fatal("key foo not found for s0")
}
if s0.Metadata(s1.raftID, "baz") != "qux" {
t.Fatal("key baz not found for s1")
}
// Remove a node.
if err := s0.Remove(s1.ID()); err != nil {
t.Fatalf("failed to remove %s from cluster: %s", s1.ID(), err.Error())
}
if s1.Metadata(s0.raftID, "foo") != "bar" {
t.Fatal("key foo not found for s0")
}
if s0.Metadata(s1.raftID, "baz") != "" {
t.Fatal("key baz found for removed node s1")
}
}
@ -824,12 +822,10 @@ func mustNewStore(inmem bool) *Store {
path := mustTempDir()
defer os.RemoveAll(path)
tn := mustMockTransport("localhost:0")
cfg := NewDBConfig("", inmem)
s := New(&StoreConfig{
s := New(mustMockLister("localhost:0"), &StoreConfig{
DBConf: cfg,
Dir: path,
Tn: tn,
ID: path, // Could be any unique string.
})
if s == nil {
@ -838,36 +834,52 @@ func mustNewStore(inmem bool) *Store {
return s
}
func mustTempDir() string {
var err error
path, err := ioutil.TempDir("", "rqlilte-test-")
if err != nil {
panic("failed to create temp dir")
}
return path
type mockSnapshotSink struct {
*os.File
}
func (m *mockSnapshotSink) ID() string {
return "1"
}
func (m *mockSnapshotSink) Cancel() error {
return nil
}
type mockTransport struct {
ln net.Listener
}
func mustMockTransport(addr string) Transport {
type mockListener struct {
ln net.Listener
}
func mustMockLister(addr string) Listener {
ln, err := net.Listen("tcp", addr)
if err != nil {
panic("failed to create new transport")
panic("failed to create new listner")
}
return &mockTransport{ln}
return &mockListener{ln}
}
func (m *mockTransport) Dial(addr string, timeout time.Duration) (net.Conn, error) {
func (m *mockListener) Dial(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("tcp", addr, timeout)
}
func (m *mockTransport) Accept() (net.Conn, error) { return m.ln.Accept() }
func (m *mockListener) Accept() (net.Conn, error) { return m.ln.Accept() }
func (m *mockListener) Close() error { return m.ln.Close() }
func (m *mockTransport) Close() error { return m.ln.Close() }
func (m *mockListener) Addr() net.Addr { return m.ln.Addr() }
func (m *mockTransport) Addr() net.Addr { return m.ln.Addr() }
func mustTempDir() string {
var err error
path, err := ioutil.TempDir("", "rqlilte-test-")
if err != nil {
panic("failed to create temp dir")
}
return path
}
func asJSON(v interface{}) string {
b, err := json.Marshal(v)

@ -7,32 +7,39 @@ import (
"github.com/hashicorp/raft"
)
// Transport is the interface the network service must provide.
type Transport interface {
type Listener interface {
net.Listener
// Dial is used to create a new outgoing connection
Dial(address string, timeout time.Duration) (net.Conn, error)
}
// raftTransport takes a Transport and makes it suitable for use by the Raft
// networking system.
type raftTransport struct {
tn Transport
// Transport is the network service provided to Raft, and wraps a Listener.
type Transport struct {
ln Listener
}
// NewTransport returns an initialized Transport that wraps the given
// Listener, making it suitable for use by the Raft networking system.
func NewTransport(ln Listener) *Transport {
	return &Transport{
		ln: ln,
	}
}
func (r *raftTransport) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
return r.tn.Dial(string(address), timeout)
// Dial creates a new network connection.
func (t *Transport) Dial(addr raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
return t.ln.Dial(string(addr), timeout)
}
func (r *raftTransport) Accept() (net.Conn, error) {
return r.tn.Accept()
// Accept waits for the next connection.
func (t *Transport) Accept() (net.Conn, error) {
return t.ln.Accept()
}
func (r *raftTransport) Addr() net.Addr {
return r.tn.Addr()
// Close closes the transport
func (t *Transport) Close() error {
return t.ln.Close()
}
func (r *raftTransport) Close() error {
return r.tn.Close()
// Addr returns the binding address of the transport.
func (t *Transport) Addr() net.Addr {
return t.ln.Addr()
}

@ -0,0 +1,11 @@
package store
import (
"testing"
)
// Test_NewTransport confirms a Transport is constructed even when the
// wrapped Listener is nil.
func Test_NewTransport(t *testing.T) {
	tn := NewTransport(nil)
	if tn == nil {
		t.Fatal("failed to create new Transport")
	}
}

@ -5,7 +5,6 @@ import (
"encoding/json"
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
"os"
@ -14,12 +13,14 @@ import (
httpd "github.com/rqlite/rqlite/http"
"github.com/rqlite/rqlite/store"
"github.com/rqlite/rqlite/tcp"
)
// Node represents a node under test.
type Node struct {
APIAddr string
RaftAddr string
ID string
Dir string
Store *store.Store
Service *httpd.Service
@ -302,18 +303,21 @@ func mustNewNode(enableSingle bool) *Node {
}
dbConf := store.NewDBConfig("", false)
tn := mustMockTransport("localhost:0")
node.Store = store.New(&store.StoreConfig{
tn := tcp.NewTransport()
if err := tn.Open("localhost:0"); err != nil {
panic(err.Error())
}
node.Store = store.New(tn, &store.StoreConfig{
DBConf: dbConf,
Dir: node.Dir,
Tn: tn,
ID: tn.Addr().String(),
})
if err := node.Store.Open(enableSingle); err != nil {
node.Deprovision()
panic(fmt.Sprintf("failed to open store: %s", err.Error()))
}
node.RaftAddr = node.Store.Addr().String()
node.RaftAddr = node.Store.Addr()
node.ID = node.Store.ID()
node.Service = httpd.New("localhost:0", node.Store, nil)
node.Service.Expvar = true
@ -335,28 +339,6 @@ func mustNewLeaderNode() *Node {
return node
}
type mockTransport struct {
ln net.Listener
}
func mustMockTransport(addr string) *mockTransport {
ln, err := net.Listen("tcp", addr)
if err != nil {
panic("failed to create new transport")
}
return &mockTransport{ln}
}
func (m *mockTransport) Dial(addr string, timeout time.Duration) (net.Conn, error) {
return net.DialTimeout("tcp", addr, timeout)
}
func (m *mockTransport) Accept() (net.Conn, error) { return m.ln.Accept() }
func (m *mockTransport) Close() error { return m.ln.Close() }
func (m *mockTransport) Addr() net.Addr { return m.ln.Addr() }
func mustTempDir() string {
var err error
path, err := ioutil.TempDir("", "rqlilte-system-test-")

@ -1,4 +1,4 @@
/*
Package tcp provides various TCP-related utilities. The TCP mux code provided by this package originated with InfluxDB.
Package tcp provides the internode communication network layer.
*/
package tcp

@ -1,301 +0,0 @@
package tcp
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
"strconv"
"sync"
"time"
)
const (
// DefaultTimeout is the default length of time to wait for first byte.
DefaultTimeout = 30 * time.Second
)
// Layer represents the connection between nodes.
type Layer struct {
ln net.Listener
header byte
addr net.Addr
remoteEncrypted bool
skipVerify bool
nodeX509CACert string
tlsConfig *tls.Config
}
// Addr returns the local address for the layer.
func (l *Layer) Addr() net.Addr {
return l.addr
}
// Dial creates a new network connection.
func (l *Layer) Dial(addr string, timeout time.Duration) (net.Conn, error) {
dialer := &net.Dialer{Timeout: timeout}
var err error
var conn net.Conn
if l.remoteEncrypted {
conn, err = tls.DialWithDialer(dialer, "tcp", addr, l.tlsConfig)
} else {
conn, err = dialer.Dial("tcp", addr)
}
if err != nil {
return nil, err
}
// Write a marker byte to indicate message type.
_, err = conn.Write([]byte{l.header})
if err != nil {
conn.Close()
return nil, err
}
return conn, err
}
// Accept waits for the next connection.
func (l *Layer) Accept() (net.Conn, error) { return l.ln.Accept() }
// Close closes the layer.
func (l *Layer) Close() error { return l.ln.Close() }
// Mux multiplexes a network connection.
type Mux struct {
ln net.Listener
addr net.Addr
m map[byte]*listener
wg sync.WaitGroup
remoteEncrypted bool
// The amount of time to wait for the first header byte.
Timeout time.Duration
// Out-of-band error logger
Logger *log.Logger
// Path to root X.509 certificate.
nodeX509CACert string
// Path to X509 certificate
nodeX509Cert string
// Path to X509 key.
nodeX509Key string
// Whether to skip verification of other nodes' certificates.
InsecureSkipVerify bool
tlsConfig *tls.Config
}
// NewMux returns a new instance of Mux for ln. If adv is nil,
// then the addr of ln is used as the advertised address. The returned
// error is always nil; the error return exists for interface symmetry
// with NewTLSMux.
func NewMux(ln net.Listener, adv net.Addr) (*Mux, error) {
	addr := adv
	if addr == nil {
		addr = ln.Addr()
	}
	return &Mux{
		ln:      ln,
		addr:    addr,
		m:       make(map[byte]*listener),
		Timeout: DefaultTimeout,
		Logger:  log.New(os.Stderr, "[mux] ", log.LstdFlags),
	}, nil
}
// NewTLSMux returns a new instance of Mux for ln, and encrypts all traffic
// using TLS. If adv is nil, then the addr of ln is used.
func NewTLSMux(ln net.Listener, adv net.Addr, cert, key, caCert string) (*Mux, error) {
mux, err := NewMux(ln, adv)
if err != nil {
return nil, err
}
mux.tlsConfig, err = createTLSConfig(cert, key, caCert)
if err != nil {
return nil, err
}
mux.ln = tls.NewListener(ln, mux.tlsConfig)
mux.remoteEncrypted = true
mux.nodeX509CACert = caCert
mux.nodeX509Cert = cert
mux.nodeX509Key = key
return mux, nil
}
// Serve handles connections from ln and multiplexes them across the
// registered listeners, routing each connection by its first byte.
// It returns only when the underlying listener fails with a
// non-temporary error, after all in-flight connections have been demuxed
// and every registered listener's channel has been closed.
func (mux *Mux) Serve() error {
	mux.Logger.Printf("mux serving on %s, advertising %s", mux.ln.Addr().String(), mux.addr)
	for {
		// Wait for the next connection.
		// If it returns a temporary error then simply retry.
		// If it returns any other error then exit immediately.
		conn, err := mux.ln.Accept()
		if err, ok := err.(interface {
			Temporary() bool
		}); ok && err.Temporary() {
			continue
		}
		if err != nil {
			// Wait for all connections to be demuxed before closing each
			// registered listener's channel, so their Accept calls return.
			mux.wg.Wait()
			for _, ln := range mux.m {
				close(ln.c)
			}
			return err
		}
		// Demux in a goroutine so a slow handler cannot block Accept.
		mux.wg.Add(1)
		go mux.handleConn(conn)
	}
}
// Stats returns status of the mux.
func (mux *Mux) Stats() (interface{}, error) {
s := map[string]string{
"addr": mux.addr.String(),
"timeout": mux.Timeout.String(),
"encrypted": strconv.FormatBool(mux.remoteEncrypted),
}
if mux.remoteEncrypted {
s["certificate"] = mux.nodeX509Cert
s["key"] = mux.nodeX509Key
s["ca_certificate"] = mux.nodeX509CACert
s["skip_verify"] = strconv.FormatBool(mux.InsecureSkipVerify)
}
return s, nil
}
// handleConn demultiplexes a single accepted connection: it reads the
// one-byte header, looks up the listener registered for that byte, and
// hands the connection over. On any failure the connection is closed
// and the error logged out-of-band.
func (mux *Mux) handleConn(conn net.Conn) {
	defer mux.wg.Done()
	// Set a read deadline so connections sending no data don't hold this
	// goroutine forever.
	if err := conn.SetReadDeadline(time.Now().Add(mux.Timeout)); err != nil {
		conn.Close()
		mux.Logger.Printf("tcp.Mux: cannot set read deadline: %s", err)
		return
	}
	// Read first byte from connection to determine handler.
	var typ [1]byte
	if _, err := io.ReadFull(conn, typ[:]); err != nil {
		conn.Close()
		mux.Logger.Printf("tcp.Mux: cannot read header byte: %s", err)
		return
	}
	// Reset read deadline (zero time means no deadline) and let the
	// listener handle that.
	if err := conn.SetReadDeadline(time.Time{}); err != nil {
		conn.Close()
		mux.Logger.Printf("tcp.Mux: cannot reset set read deadline: %s", err)
		return
	}
	// Retrieve handler based on first byte.
	handler := mux.m[typ[0]]
	if handler == nil {
		conn.Close()
		mux.Logger.Printf("tcp.Mux: handler not registered: %d", typ[0])
		return
	}
	// Send connection to handler. The handler is responsible for closing the connection.
	handler.c <- conn
}
// Listen returns a listener identified by header.
// Any connection accepted by mux is multiplexed based on the initial
// header byte. It panics if a listener is already registered for header.
func (mux *Mux) Listen(header byte) *Layer {
	// Refuse to register two listeners for the same header byte.
	if _, ok := mux.m[header]; ok {
		panic(fmt.Sprintf("listener already registered under header byte: %d", header))
	}
	ln := &listener{c: make(chan net.Conn)}
	mux.m[header] = ln
	// The returned Layer carries the mux's advertised address and TLS
	// settings so callers dial peers consistently.
	return &Layer{
		ln:              ln,
		header:          header,
		addr:            mux.addr,
		remoteEncrypted: mux.remoteEncrypted,
		skipVerify:      mux.InsecureSkipVerify,
		nodeX509CACert:  mux.nodeX509CACert,
		tlsConfig:       mux.tlsConfig,
	}
}
// listener is a receiver for connections received by Mux.
type listener struct {
c chan net.Conn
}
// Accept waits for and returns the next connection to the listener.
func (ln *listener) Accept() (c net.Conn, err error) {
conn, ok := <-ln.c
if !ok {
return nil, errors.New("network connection closed")
}
return conn, nil
}
// Close is a no-op. The mux's listener should be closed instead.
func (ln *listener) Close() error { return nil }
// Addr always returns nil
func (ln *listener) Addr() net.Addr { return nil }
// newTLSListener returns a net listener which encrypts the traffic using TLS.
func newTLSListener(ln net.Listener, certFile, keyFile, caCertFile string) (net.Listener, error) {
config, err := createTLSConfig(certFile, keyFile, caCertFile)
if err != nil {
return nil, err
}
return tls.NewListener(ln, config), nil
}
// createTLSConfig returns a TLS config from the given cert and key files,
// and optionally a CA certificate file whose certs are installed as the
// root CAs used to verify remote certificates.
func createTLSConfig(certFile, keyFile, caCertFile string) (*tls.Config, error) {
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return nil, err
	}
	config := &tls.Config{
		Certificates: []tls.Certificate{cert},
	}
	if caCertFile != "" {
		// The file contains PEM-encoded certificates; ReadFile already
		// returns []byte so no conversion is needed (the original
		// performed a redundant []byte(...) copy here).
		pemData, err := ioutil.ReadFile(caCertFile)
		if err != nil {
			return nil, err
		}
		config.RootCAs = x509.NewCertPool()
		// AppendCertsFromPEM reports whether any certificate was parsed.
		if ok := config.RootCAs.AppendCertsFromPEM(pemData); !ok {
			return nil, fmt.Errorf("failed to parse root certificate(s) in %s", caCertFile)
		}
	}
	return config, nil
}

@ -1,235 +0,0 @@
package tcp
import (
"bytes"
"crypto/tls"
"io"
"io/ioutil"
"log"
"net"
"os"
"strings"
"sync"
"testing"
"testing/quick"
"time"
"github.com/rqlite/rqlite/testdata/x509"
)
// Ensure the muxer can split a listener's connections across multiple listeners.
func TestMux(t *testing.T) {
if err := quick.Check(func(n uint8, msg []byte) bool {
if testing.Verbose() {
if len(msg) == 0 {
log.Printf("n=%d, <no message>", n)
} else {
log.Printf("n=%d, hdr=%d, len=%d", n, msg[0], len(msg))
}
}
var wg sync.WaitGroup
// Open single listener on random port.
tcpListener := mustTCPListener("127.0.0.1:0")
defer tcpListener.Close()
// Setup muxer & listeners.
mux, err := NewMux(tcpListener, nil)
if err != nil {
t.Fatalf("failed to create mux: %s", err.Error())
}
mux.Timeout = 200 * time.Millisecond
if !testing.Verbose() {
mux.Logger = log.New(ioutil.Discard, "", 0)
}
for i := uint8(0); i < n; i++ {
ln := mux.Listen(byte(i))
wg.Add(1)
go func(i uint8, ln net.Listener) {
defer wg.Done()
// Wait for a connection for this listener.
conn, err := ln.Accept()
if conn != nil {
defer conn.Close()
}
// If there is no message or the header byte
// doesn't match then expect close.
if len(msg) == 0 || msg[0] != byte(i) {
if err == nil || err.Error() != "network connection closed" {
t.Fatalf("unexpected error: %s", err)
}
return
}
// If the header byte matches this listener
// then expect a connection and read the message.
var buf bytes.Buffer
if _, err := io.CopyN(&buf, conn, int64(len(msg)-1)); err != nil {
t.Fatal(err)
} else if !bytes.Equal(msg[1:], buf.Bytes()) {
t.Fatalf("message mismatch:\n\nexp=%x\n\ngot=%x\n\n", msg[1:], buf.Bytes())
}
// Write response.
if _, err := conn.Write([]byte("OK")); err != nil {
t.Fatal(err)
}
}(i, ln)
}
// Begin serving from the listener.
go mux.Serve()
// Write message to TCP listener and read OK response.
conn, err := net.Dial("tcp", tcpListener.Addr().String())
if err != nil {
t.Fatal(err)
} else if _, err = conn.Write(msg); err != nil {
t.Fatal(err)
}
// Read the response into the buffer.
var resp [2]byte
_, err = io.ReadFull(conn, resp[:])
// If the message header is less than n then expect a response.
// Otherwise we should get an EOF because the mux closed.
if len(msg) > 0 && uint8(msg[0]) < n {
if string(resp[:]) != `OK` {
t.Fatalf("unexpected response: %s", resp[:])
}
} else {
if err == nil || (err != io.EOF && !(strings.Contains(err.Error(), "connection reset by peer") ||
strings.Contains(err.Error(), "closed by the remote host"))) {
t.Fatalf("unexpected error: %s", err)
}
}
// Close connection.
if err := conn.Close(); err != nil {
t.Fatal(err)
}
// Close original TCP listener and wait for all goroutines to close.
tcpListener.Close()
wg.Wait()
return true
}, nil); err != nil {
t.Error(err)
}
}
func TestMux_Advertise(t *testing.T) {
// Setup muxer.
tcpListener := mustTCPListener("127.0.0.1:0")
defer tcpListener.Close()
addr := &mockAddr{
Nwk: "tcp",
Addr: "rqlite.com:8081",
}
mux, err := NewMux(tcpListener, addr)
if err != nil {
t.Fatalf("failed to create mux: %s", err.Error())
}
mux.Timeout = 200 * time.Millisecond
if !testing.Verbose() {
mux.Logger = log.New(ioutil.Discard, "", 0)
}
layer := mux.Listen(1)
if layer.Addr().String() != addr.Addr {
t.Fatalf("layer advertise address not correct, exp %s, got %s",
layer.Addr().String(), addr.Addr)
}
}
// Ensure two handlers cannot be registered for the same header byte.
func TestMux_Listen_ErrAlreadyRegistered(t *testing.T) {
defer func() {
if r := recover(); r != `listener already registered under header byte: 5` {
t.Fatalf("unexpected recover: %#v", r)
}
}()
// Register two listeners with the same header byte.
tcpListener := mustTCPListener("127.0.0.1:0")
mux, err := NewMux(tcpListener, nil)
if err != nil {
t.Fatalf("failed to create mux: %s", err.Error())
}
mux.Listen(5)
mux.Listen(5)
}
func TestTLSMux(t *testing.T) {
tcpListener := mustTCPListener("127.0.0.1:0")
defer tcpListener.Close()
cert := x509.CertFile()
defer os.Remove(cert)
key := x509.KeyFile()
defer os.Remove(key)
mux, err := NewTLSMux(tcpListener, nil, cert, key, "")
if err != nil {
t.Fatalf("failed to create mux: %s", err.Error())
}
go mux.Serve()
// Verify that the listener is secured.
conn, err := tls.Dial("tcp", tcpListener.Addr().String(), &tls.Config{
InsecureSkipVerify: true,
})
if err != nil {
t.Fatal(err)
}
state := conn.ConnectionState()
if !state.HandshakeComplete {
t.Fatal("connection handshake failed to complete")
}
}
func TestTLSMux_Fail(t *testing.T) {
tcpListener := mustTCPListener("127.0.0.1:0")
defer tcpListener.Close()
cert := x509.CertFile()
defer os.Remove(cert)
key := x509.KeyFile()
defer os.Remove(key)
_, err := NewTLSMux(tcpListener, nil, "xxxx", "yyyy", "")
if err == nil {
t.Fatalf("created mux unexpectedly with bad resources")
}
}
type mockAddr struct {
Nwk string
Addr string
}
func (m *mockAddr) Network() string {
return m.Nwk
}
func (m *mockAddr) String() string {
return m.Addr
}
// mustTCPListener returns a listener on bind, or panics.
func mustTCPListener(bind string) net.Listener {
l, err := net.Listen("tcp", bind)
if err != nil {
panic(err)
}
return l
}

@ -0,0 +1,101 @@
package tcp
import (
"crypto/tls"
"fmt"
"net"
"time"
)
// Transport is the network layer for internode communications.
type Transport struct {
ln net.Listener
certFile string // Path to local X.509 cert.
certKey string // Path to corresponding X.509 key.
remoteEncrypted bool // Remote nodes use encrypted communication.
skipVerify bool // Skip verification of remote node certs.
}
// NewTransport returns an initialized unencrypted Transport.
func NewTransport() *Transport {
	return &Transport{}
}
// NewTLSTransport returns an initialized TLS-encrypted Transport. When
// skipVerify is true, certificates presented by remote nodes are not
// verified on Dial.
func NewTLSTransport(certFile, keyPath string, skipVerify bool) *Transport {
	return &Transport{
		certFile:        certFile,
		certKey:         keyPath,
		remoteEncrypted: true,
		skipVerify:      skipVerify,
	}
}
// Open opens the transport, binding to the supplied address. If the
// Transport was created with NewTLSTransport, inbound connections are
// wrapped in TLS.
func (t *Transport) Open(addr string) error {
	ln, err := net.Listen("tcp", addr)
	if err != nil {
		return err
	}
	if t.certFile != "" {
		config, err := createTLSConfig(t.certFile, t.certKey)
		if err != nil {
			// Don't leak the raw listener when the TLS config is bad.
			ln.Close()
			return err
		}
		ln = tls.NewListener(ln, config)
	}
	t.ln = ln
	return nil
}
// Dial opens a network connection to addr, timing out after timeout. If
// remote encryption is enabled the connection uses TLS, honoring the
// transport's skipVerify setting.
func (t *Transport) Dial(addr string, timeout time.Duration) (net.Conn, error) {
	dialer := &net.Dialer{Timeout: timeout}

	// Plain TCP when remote nodes are not encrypted.
	if !t.remoteEncrypted {
		return dialer.Dial("tcp", addr)
	}

	// Removed stray debug print that ran on every encrypted dial.
	conf := &tls.Config{
		InsecureSkipVerify: t.skipVerify,
	}
	return tls.DialWithDialer(dialer, "tcp", addr, conf)
}
// Accept waits for, and returns, the next connection to the transport.
// Any accept error is printed before being returned to the caller.
func (t *Transport) Accept() (net.Conn, error) {
	conn, err := t.ln.Accept()
	if err == nil {
		return conn, nil
	}
	fmt.Println("error accepting: ", err.Error())
	return conn, err
}
// Close closes the transport by closing its underlying listener. It must
// only be called after a successful Open, since Open sets the listener.
func (t *Transport) Close() error {
	return t.ln.Close()
}
// Addr returns the binding address of the transport. Valid only after a
// successful call to Open, which sets the listener.
func (t *Transport) Addr() net.Addr {
	return t.ln.Addr()
}
// createTLSConfig returns a TLS config from the given cert and key.
func createTLSConfig(certFile, keyFile string) (*tls.Config, error) {
	cert, err := tls.LoadX509KeyPair(certFile, keyFile)
	if err != nil {
		return nil, err
	}
	return &tls.Config{
		Certificates: []tls.Certificate{cert},
	}, nil
}

@ -0,0 +1,70 @@
package tcp
import (
"os"
"testing"
"time"
"github.com/rqlite/rqlite/testdata/x509"
)
// Test_NewTransport checks that an unencrypted Transport can be created.
func Test_NewTransport(t *testing.T) {
	tn := NewTransport()
	if tn == nil {
		t.Fatal("failed to create new Transport")
	}
}
// Test_TransportOpenClose tests that a Transport opens on a real port
// and closes cleanly.
func Test_TransportOpenClose(t *testing.T) {
	tn := NewTransport()
	if err := tn.Open("localhost:0"); err != nil {
		t.Fatalf("failed to open transport: %s", err.Error())
	}
	boundAddr := tn.Addr().String()
	// Port 0 should have been replaced by an OS-assigned port.
	if boundAddr == "localhost:0" {
		t.Fatalf("transport address set incorrectly, got: %s", boundAddr)
	}
	if err := tn.Close(); err != nil {
		t.Fatalf("failed to close transport: %s", err.Error())
	}
}
// Test_TransportDial tests that one Transport can dial another. The error
// from Open is now checked, and the dialed connection is closed instead of
// being leaked.
func Test_TransportDial(t *testing.T) {
	tn1 := NewTransport()
	if err := tn1.Open("localhost:0"); err != nil {
		t.Fatalf("failed to open first transport: %s", err.Error())
	}
	defer tn1.Close()
	go tn1.Accept()

	tn2 := NewTransport()
	conn, err := tn2.Dial(tn1.Addr().String(), time.Second)
	if err != nil {
		t.Fatalf("failed to connect to first transport: %s", err.Error())
	}
	conn.Close()
}
// Test_NewTLSTransport checks that a TLS-encrypted Transport can be created.
func Test_NewTLSTransport(t *testing.T) {
	certPath := x509.CertFile()
	defer os.Remove(certPath)
	keyPath := x509.KeyFile()
	defer os.Remove(keyPath)

	if tn := NewTLSTransport(certPath, keyPath, true); tn == nil {
		t.Fatal("failed to create new TLS Transport")
	}
}
// Test_TLSTransportOpenClose tests that a TLS Transport opens on a real
// port and closes cleanly.
func Test_TLSTransportOpenClose(t *testing.T) {
	certPath := x509.CertFile()
	defer os.Remove(certPath)
	keyPath := x509.KeyFile()
	defer os.Remove(keyPath)

	tn := NewTLSTransport(certPath, keyPath, true)
	if err := tn.Open("localhost:0"); err != nil {
		t.Fatalf("failed to open TLS transport: %s", err.Error())
	}
	boundAddr := tn.Addr().String()
	// Port 0 should have been replaced by an OS-assigned port.
	if boundAddr == "localhost:0" {
		t.Fatalf("TLS transport address set incorrectly, got: %s", boundAddr)
	}
	if err := tn.Close(); err != nil {
		t.Fatalf("failed to close TLS transport: %s", err.Error())
	}
}
Loading…
Cancel
Save