diff --git a/DOC/CLUSTER_MGMT.md b/DOC/CLUSTER_MGMT.md index baee18b2..1eeff33c 100644 --- a/DOC/CLUSTER_MGMT.md +++ b/DOC/CLUSTER_MGMT.md @@ -57,7 +57,7 @@ You can grow a cluster, at anytime, simply by starting up a new node and having # Removing or replacing a node If a node fails completely and is not coming back, or if you shut down a node because you wish to deprovision it, its record should also be removed from the cluster. To remove the record of a node from a cluster, execute the following command: ``` -curl -XDELETE http://localhost:4001/remove -d '{"addr": ""}' +curl -XDELETE http://localhost:4001/remove -d '{"id": ""}' ``` assuming `localhost` is the address of the cluster leader. If you do not do this the leader will continually attempt to communicate with that node. diff --git a/cluster/join.go b/cluster/join.go index 5458cc72..95b3ac14 100644 --- a/cluster/join.go +++ b/cluster/join.go @@ -20,9 +20,9 @@ const numAttempts int = 3 const attemptInterval time.Duration = 5 * time.Second // It walks through joinAddr in order, and sets the node ID and Raft address of -// the joining node as nodeID advAddr respectively. It returns the endpoint -// successfully used to join the cluster. -func Join(joinAddr []string, nodeID, advAddr string, tlsConfig *tls.Config) (string, error) { +// the joining node as id addr respectively. It returns the endpoint successfully +// used to join the cluster. +func Join(joinAddr []string, id, addr string, meta map[string]string, tlsConfig *tls.Config) (string, error) { var err error var j string logger := log.New(os.Stderr, "[cluster-join] ", log.LstdFlags) @@ -32,7 +32,7 @@ func Join(joinAddr []string, nodeID, advAddr string, tlsConfig *tls.Config) (str for i := 0; i < numAttempts; i++ { for _, a := range joinAddr { - j, err = join(a, nodeID, advAddr, tlsConfig, logger) + j, err = join(a, id, addr, meta, tlsConfig, logger) if err == nil { // Success! return j, nil @@ -45,13 +45,13 @@ func Join(joinAddr []string, nodeID, advAddr string, tlsConfig *tls.Config) (str return "", err } -func join(joinAddr, nodeID, advAddr string, tlsConfig *tls.Config, logger *log.Logger) (string, error) { - if nodeID == "" { +func join(joinAddr, id, addr string, meta map[string]string, tlsConfig *tls.Config, logger *log.Logger) (string, error) { + if id == "" { return "", fmt.Errorf("node ID not set") } // Join using IP address, as that is what Hashicorp Raft works in. - resv, err := net.ResolveTCPAddr("tcp", advAddr) + resv, err := net.ResolveTCPAddr("tcp", addr) if err != nil { return "", err } @@ -59,7 +59,7 @@ func join(joinAddr, nodeID, advAddr string, tlsConfig *tls.Config, logger *log.L // Check for protocol scheme, and insert default if necessary. fullAddr := httpd.NormalizeAddr(fmt.Sprintf("%s/join", joinAddr)) - // Enable skipVerify as requested. + // Create and configure the client to connect to the other node. tr := &http.Transport{ TLSClientConfig: tlsConfig, } @@ -69,9 +69,10 @@ func join(joinAddr, nodeID, advAddr string, tlsConfig *tls.Config, logger *log.L } for { - b, err := json.Marshal(map[string]string{ - "id": nodeID, + b, err := json.Marshal(map[string]interface{}{ + "id": id, "addr": resv.String(), + "meta": meta, }) // Attempt to join. 
diff --git a/cluster/join_test.go b/cluster/join_test.go index c521800f..98a53f67 100644 --- a/cluster/join_test.go +++ b/cluster/join_test.go @@ -1,7 +1,9 @@ package cluster import ( + "encoding/json" "fmt" + "io/ioutil" "net/http" "net/http/httptest" "testing" @@ -16,7 +18,7 @@ func Test_SingleJoinOK(t *testing.T) { })) defer ts.Close() - j, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", nil) + j, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", nil, nil) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) } @@ -25,13 +27,57 @@ func Test_SingleJoinOK(t *testing.T) { } } +func Test_SingleJoinMetaOK(t *testing.T) { + var body map[string]interface{} + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Fatalf("Client did not use POST") + } + w.WriteHeader(http.StatusOK) + + b, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + if err := json.Unmarshal(b, &body); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + })) + defer ts.Close() + + nodeAddr := "127.0.0.1:9090" + md := map[string]string{"foo": "bar"} + j, err := Join([]string{ts.URL}, "id0", nodeAddr, md, nil) + if err != nil { + t.Fatalf("failed to join a single node: %s", err.Error()) + } + if j != ts.URL+"/join" { + t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts.URL) + } + + if id, _ := body["id"]; id != "id0" { + t.Fatalf("node joined supplying wrong ID, exp %s, got %s", "id0", body["id"]) + } + if addr, _ := body["addr"]; addr != nodeAddr { + t.Fatalf("node joined supplying wrong address, exp %s, got %s", nodeAddr, body["addr"]) + } + rxMd, _ := body["meta"].(map[string]interface{}) + if len(rxMd) != len(md) || rxMd["foo"] != "bar" { + t.Fatalf("node joined supplying wrong meta") + } +} + func Test_SingleJoinFail(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) })) defer ts.Close() - _, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", nil) + _, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", nil, nil) if err == nil { t.Fatalf("expected error when joining bad node") } @@ -45,7 +91,7 @@ func Test_DoubleJoinOK(t *testing.T) { })) defer ts2.Close() - j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", nil) + j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", nil, nil) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) } @@ -63,7 +109,7 @@ func Test_DoubleJoinOKSecondNode(t *testing.T) { })) defer ts2.Close() - j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", nil) + j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", nil, nil) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) } @@ -83,7 +129,7 @@ func Test_DoubleJoinOKSecondNodeRedirect(t *testing.T) { })) defer ts2.Close() - j, err := Join([]string{ts2.URL}, "id0", "127.0.0.1:9090", nil) + j, err := Join([]string{ts2.URL}, "id0", "127.0.0.1:9090", nil, nil) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) } diff --git a/cluster/service.go b/cluster/service.go deleted file mode 100644 index d51207b7..00000000 --- a/cluster/service.go +++ /dev/null @@ -1,180 +0,0 @@ -package cluster - -import ( - "encoding/json" - "fmt" - "log" - "net" - "os" - "sync" - "time" -) - -const ( - connectionTimeout = 10 * time.Second -) - -var respOKMarshalled []byte - 
-func init() { - var err error - respOKMarshalled, err = json.Marshal(response{}) - if err != nil { - panic(fmt.Sprintf("unable to JSON marshal OK response: %s", err.Error())) - } -} - -type response struct { - Code int `json:"code,omitempty"` - Message string `json:"message,omitempty"` -} - -// Transport is the interface the network service must provide. -type Transport interface { - net.Listener - - // Dial is used to create a new outgoing connection - Dial(address string, timeout time.Duration) (net.Conn, error) -} - -// Store represents a store of information, managed via consensus. -type Store interface { - // Leader returns the address of the leader of the consensus system. - LeaderAddr() string - - // UpdateAPIPeers updates the API peers on the store. - UpdateAPIPeers(peers map[string]string) error -} - -// Service allows access to the cluster and associated meta data, -// via consensus. -type Service struct { - tn Transport - store Store - addr net.Addr - - wg sync.WaitGroup - - logger *log.Logger -} - -// NewService returns a new instance of the cluster service -func NewService(tn Transport, store Store) *Service { - return &Service{ - tn: tn, - store: store, - addr: tn.Addr(), - logger: log.New(os.Stderr, "[cluster] ", log.LstdFlags), - } -} - -// Open opens the Service. -func (s *Service) Open() error { - s.wg.Add(1) - go s.serve() - s.logger.Println("service listening on", s.tn.Addr()) - return nil -} - -// Close closes the service. -func (s *Service) Close() error { - s.tn.Close() - s.wg.Wait() - return nil -} - -// Addr returns the address the service is listening on. -func (s *Service) Addr() string { - return s.addr.String() -} - -// SetPeer will set the mapping between raftAddr and apiAddr for the entire cluster. -func (s *Service) SetPeer(raftAddr, apiAddr string) error { - peer := map[string]string{ - raftAddr: apiAddr, - } - - // Try the local store. It might be the leader. - err := s.store.UpdateAPIPeers(peer) - if err == nil { - // All done! Aren't we lucky? - return nil - } - - // Try talking to the leader over the network. - if leader := s.store.LeaderAddr(); leader == "" { - return fmt.Errorf("no leader available") - } - conn, err := s.tn.Dial(s.store.LeaderAddr(), connectionTimeout) - if err != nil { - return err - } - defer conn.Close() - - b, err := json.Marshal(peer) - if err != nil { - return err - } - - if _, err := conn.Write(b); err != nil { - return err - } - - // Wait for the response and verify the operation went through. - resp := response{} - d := json.NewDecoder(conn) - err = d.Decode(&resp) - if err != nil { - return err - } - - if resp.Code != 0 { - return fmt.Errorf(resp.Message) - } - return nil -} - -func (s *Service) serve() error { - defer s.wg.Done() - - for { - conn, err := s.tn.Accept() - if err != nil { - return err - } - - go s.handleConn(conn) - } -} - -func (s *Service) handleConn(conn net.Conn) { - s.logger.Printf("received connection from %s", conn.RemoteAddr().String()) - - // Only handles peers updates for now. - peers := make(map[string]string) - d := json.NewDecoder(conn) - err := d.Decode(&peers) - if err != nil { - return - } - - // Update the peers. - if err := s.store.UpdateAPIPeers(peers); err != nil { - resp := response{1, err.Error()} - b, err := json.Marshal(resp) - if err != nil { - conn.Close() // Only way left to signal. - } else { - if _, err := conn.Write(b); err != nil { - conn.Close() // Only way left to signal. - } - } - return - } - - // Let the remote node know everything went OK. 
- if _, err := conn.Write(respOKMarshalled); err != nil { - conn.Close() // Only way left to signal. - } - return -} diff --git a/cluster/service_test.go b/cluster/service_test.go deleted file mode 100644 index 2f38c346..00000000 --- a/cluster/service_test.go +++ /dev/null @@ -1,194 +0,0 @@ -package cluster - -import ( - "encoding/json" - "fmt" - "net" - "testing" - "time" -) - -func Test_NewServiceOpenClose(t *testing.T) { - ml := mustNewMockTransport() - ms := &mockStore{} - s := NewService(ml, ms) - if s == nil { - t.Fatalf("failed to create cluster service") - } - - if err := s.Open(); err != nil { - t.Fatalf("failed to open cluster service") - } - if err := s.Close(); err != nil { - t.Fatalf("failed to close cluster service") - } -} - -func Test_SetAPIPeer(t *testing.T) { - raftAddr, apiAddr := "localhost:4002", "localhost:4001" - - s, _, ms := mustNewOpenService() - defer s.Close() - if err := s.SetPeer(raftAddr, apiAddr); err != nil { - t.Fatalf("failed to set peer: %s", err.Error()) - } - - if ms.peers[raftAddr] != apiAddr { - t.Fatalf("peer not set correctly, exp %s, got %s", apiAddr, ms.peers[raftAddr]) - } -} - -func Test_SetAPIPeerNetwork(t *testing.T) { - raftAddr, apiAddr := "localhost:4002", "localhost:4001" - - s, _, ms := mustNewOpenService() - defer s.Close() - - raddr, err := net.ResolveTCPAddr("tcp", s.Addr()) - if err != nil { - t.Fatalf("failed to resolve remote uster ervice address: %s", err.Error()) - } - - conn, err := net.DialTCP("tcp4", nil, raddr) - if err != nil { - t.Fatalf("failed to connect to remote cluster service: %s", err.Error()) - } - if _, err := conn.Write([]byte(fmt.Sprintf(`{"%s": "%s"}`, raftAddr, apiAddr))); err != nil { - t.Fatalf("failed to write to remote cluster service: %s", err.Error()) - } - - resp := response{} - d := json.NewDecoder(conn) - err = d.Decode(&resp) - if err != nil { - t.Fatalf("failed to decode response: %s", err.Error()) - } - - if resp.Code != 0 { - t.Fatalf("response code was non-zero") - } - - if ms.peers[raftAddr] != apiAddr { - t.Fatalf("peer not set correctly, exp %s, got %s", apiAddr, ms.peers[raftAddr]) - } -} - -func Test_SetAPIPeerFailUpdate(t *testing.T) { - raftAddr, apiAddr := "localhost:4002", "localhost:4001" - - s, _, ms := mustNewOpenService() - defer s.Close() - ms.failUpdateAPIPeers = true - - // Attempt to set peer without a leader - if err := s.SetPeer(raftAddr, apiAddr); err == nil { - t.Fatalf("no error returned by set peer when no leader") - } - - // Start a network server. - tn, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("failed to open test server: %s", err.Error()) - } - ms.leader = tn.Addr().String() - - c := make(chan map[string]string, 1) - go func() { - conn, err := tn.Accept() - if err != nil { - t.Fatalf("failed to accept connection from cluster: %s", err.Error()) - } - t.Logf("test server received connection from: %s", conn.RemoteAddr()) - - peers := make(map[string]string) - d := json.NewDecoder(conn) - err = d.Decode(&peers) - if err != nil { - t.Fatalf("failed to decode message from cluster: %s", err.Error()) - } - - // Response OK. - // Let the remote node know everything went OK. 
- if _, err := conn.Write(respOKMarshalled); err != nil { - t.Fatalf("failed to respond to cluster: %s", err.Error()) - } - - c <- peers - }() - - if err := s.SetPeer(raftAddr, apiAddr); err != nil { - t.Fatalf("failed to set peer on cluster: %s", err.Error()) - } - - peers := <-c - if peers[raftAddr] != apiAddr { - t.Fatalf("peer not set correctly, exp %s, got %s", apiAddr, ms.peers[raftAddr]) - } -} - -func mustNewOpenService() (*Service, *mockTransport, *mockStore) { - ml := mustNewMockTransport() - ms := newMockStore() - s := NewService(ml, ms) - if err := s.Open(); err != nil { - panic("failed to open new service") - } - return s, ml, ms -} - -type mockTransport struct { - tn net.Listener -} - -func mustNewMockTransport() *mockTransport { - tn, err := net.Listen("tcp", "localhost:0") - if err != nil { - panic("failed to create mock listener") - } - return &mockTransport{ - tn: tn, - } -} - -func (ml *mockTransport) Accept() (c net.Conn, err error) { - return ml.tn.Accept() -} - -func (ml *mockTransport) Addr() net.Addr { - return ml.tn.Addr() -} - -func (ml *mockTransport) Close() (err error) { - return ml.tn.Close() -} - -func (ml *mockTransport) Dial(addr string, t time.Duration) (net.Conn, error) { - return net.DialTimeout("tcp", addr, 5*time.Second) -} - -type mockStore struct { - leader string - peers map[string]string - failUpdateAPIPeers bool -} - -func newMockStore() *mockStore { - return &mockStore{ - peers: make(map[string]string), - } -} - -func (ms *mockStore) LeaderAddr() string { - return ms.leader -} - -func (ms *mockStore) UpdateAPIPeers(peers map[string]string) error { - if ms.failUpdateAPIPeers { - return fmt.Errorf("forced fail") - } - - for k, v := range peers { - ms.peers[k] = v - } - return nil -} diff --git a/cmd/rqlited/main.go b/cmd/rqlited/main.go index a1ad44d4..993cd26e 100644 --- a/cmd/rqlited/main.go +++ b/cmd/rqlited/main.go @@ -8,7 +8,6 @@ import ( "fmt" "io/ioutil" "log" - "net" "os" "os/signal" "path/filepath" @@ -161,47 +160,28 @@ func main() { // Start requested profiling. startProfile(cpuProfile, memProfile) - // Set up node-to-node TCP communication. - ln, err := net.Listen("tcp", raftAddr) - if err != nil { - log.Fatalf("failed to listen on %s: %s", raftAddr, err.Error()) - } - var adv net.Addr - if raftAdv != "" { - adv, err = net.ResolveTCPAddr("tcp", raftAdv) - if err != nil { - log.Fatalf("failed to resolve advertise address %s: %s", raftAdv, err.Error()) - } - } - - // Start up node-to-node network mux. - var mux *tcp.Mux + // Create internode network layer. + var tn *tcp.Transport if nodeEncrypt { log.Printf("enabling node-to-node encryption with cert: %s, key: %s", nodeX509Cert, nodeX509Key) - mux, err = tcp.NewTLSMux(ln, adv, nodeX509Cert, nodeX509Key, nodeX509CACert) + tn = tcp.NewTLSTransport(nodeX509Cert, nodeX509Key, noVerify) } else { - mux, err = tcp.NewMux(ln, adv) + tn = tcp.NewTransport() } - if err != nil { - log.Fatalf("failed to create node-to-node mux: %s", err.Error()) + if err := tn.Open(raftAddr); err != nil { + log.Fatalf("failed to open internode network layer: %s", err.Error()) } - mux.InsecureSkipVerify = noNodeVerify - go mux.Serve() - - // Get transport for Raft communications. - raftTn := mux.Listen(muxRaftHeader) // Create and open the store. 
- dataPath, err = filepath.Abs(dataPath) + dataPath, err := filepath.Abs(dataPath) if err != nil { log.Fatalf("failed to determine absolute data path: %s", err.Error()) } dbConf := store.NewDBConfig(dsn, !onDisk) - str := store.New(&store.StoreConfig{ + str := store.New(tn, &store.StoreConfig{ DBConf: dbConf, Dir: dataPath, - Tn: raftTn, ID: idOrRaftAddr(), }) @@ -220,10 +200,6 @@ func main() { if err != nil { log.Fatalf("failed to parse Raft apply timeout %s: %s", raftApplyTimeout, err.Error()) } - str.OpenTimeout, err = time.ParseDuration(raftOpenTimeout) - if err != nil { - log.Fatalf("failed to parse Raft open timeout %s: %s", raftOpenTimeout, err.Error()) - } // Determine join addresses, if necessary. ja, err := store.JoinAllowed(dataPath) @@ -241,16 +217,18 @@ func main() { log.Println("node is already member of cluster, skip determining join addresses") } - // Now, open it. + // Now, open store. if err := str.Open(len(joins) == 0); err != nil { log.Fatalf("failed to open store: %s", err.Error()) } - // Create and configure cluster service. - tn := mux.Listen(muxMetaHeader) - cs := cluster.NewService(tn, str) - if err := cs.Open(); err != nil { - log.Fatalf("failed to open cluster service: %s", err.Error()) + // Prepare metadata for join command. + apiAdv := httpAddr + if httpAdv != "" { + apiAdv = httpAdv + } + meta := map[string]string{ + "api_addr": apiAdv, } // Execute any requested join operation. @@ -274,7 +252,7 @@ func main() { } } - if j, err := cluster.Join(joins, str.ID(), advAddr, &tlsConfig); err != nil { + if j, err := cluster.Join(joins, str.ID(), advAddr, meta, &tlsConfig); err != nil { log.Fatalf("failed to join cluster at %s: %s", joins, err.Error()) } else { log.Println("successfully joined cluster at", j) @@ -284,53 +262,26 @@ func main() { log.Println("no join addresses set") } - // Publish to the cluster the mapping between this Raft address and API address. - // The Raft layer broadcasts the resolved address, so use that as the key. But - // only set different HTTP advertise address if set. - apiAdv := httpAddr - if httpAdv != "" { - apiAdv = httpAdv - } - - if err := publishAPIAddr(cs, raftTn.Addr().String(), apiAdv, publishPeerTimeout); err != nil { - log.Fatalf("failed to set peer for %s to %s: %s", raftAddr, httpAddr, err.Error()) - } - log.Printf("set peer for %s to %s", raftTn.Addr().String(), apiAdv) - - // Get the credential store. - credStr, err := credentialStore() + // Wait until the store is in full consensus. + openTimeout, err := time.ParseDuration(raftOpenTimeout) if err != nil { - log.Fatalf("failed to get credential store: %s", err.Error()) + log.Fatalf("failed to parse Raft open timeout %s: %s", raftOpenTimeout, err.Error()) } + str.WaitForLeader(openTimeout) + str.WaitForApplied(openTimeout) - // Create HTTP server and load authentication information if required. - var s *httpd.Service - if credStr != nil { - s = httpd.New(httpAddr, str, credStr) - } else { - s = httpd.New(httpAddr, str, nil) + // This may be a standalone server. In that case set its own metadata. + if err := str.SetMetadata(meta); err != nil && err != store.ErrNotLeader { + // Non-leader errors are OK, since metadata will then be set through + // consensus as a result of a join. All other errors indicate a problem. 
+ log.Fatalf("failed to set store metadata: %s", err.Error()) } - s.CACertFile = x509CACert - s.CertFile = x509Cert - s.KeyFile = x509Key - s.Expvar = expvar - s.Pprof = pprofEnabled - s.BuildInfo = map[string]interface{}{ - "commit": commit, - "branch": branch, - "version": version, - "build_time": buildtime, - } - if err := s.Start(); err != nil { + // Start the HTTP API server. + if err := startHTTPService(str); err != nil { log.Fatalf("failed to start HTTP server: %s", err.Error()) } - // Register cross-component statuses. - if err := s.RegisterStatus("mux", mux); err != nil { - log.Fatalf("failed to register mux status: %s", err.Error()) - } - // Block until signalled. terminate := make(chan os.Signal, 1) signal.Notify(terminate, os.Interrupt) @@ -373,25 +324,32 @@ func determineJoinAddresses() ([]string, error) { return addrs, nil } -func publishAPIAddr(c *cluster.Service, raftAddr, apiAddr string, t time.Duration) error { - tck := time.NewTicker(publishPeerDelay) - defer tck.Stop() - tmr := time.NewTimer(t) - defer tmr.Stop() - - for { - select { - case <-tck.C: - if err := c.SetPeer(raftAddr, apiAddr); err != nil { - log.Printf("failed to set peer for %s to %s: %s (retrying)", - raftAddr, apiAddr, err.Error()) - continue - } - return nil - case <-tmr.C: - return fmt.Errorf("set peer timeout expired") - } +func startHTTPService(str *store.Store) error { + // Get the credential store. + credStr, err := credentialStore() + if err != nil { + return err + } + + // Create HTTP server and load authentication information if required. + var s *httpd.Service + if credStr != nil { + s = httpd.New(httpAddr, str, credStr) + } else { + s = httpd.New(httpAddr, str, nil) + } + + s.CertFile = x509Cert + s.KeyFile = x509Key + s.Expvar = expvar + s.Pprof = pprofEnabled + s.BuildInfo = map[string]interface{}{ + "commit": commit, + "branch": branch, + "version": version, + "build_time": buildtime, } + return s.Start() } func credentialStore() (*auth.CredentialsStore, error) { diff --git a/http/service.go b/http/service.go index 362ca07b..74a4a45a 100644 --- a/http/service.go +++ b/http/service.go @@ -44,16 +44,16 @@ type Store interface { Query(qr *store.QueryRequest) ([]*sql.Rows, error) // Join joins the node with the given ID, reachable at addr, to this node. - Join(id, addr string) error + Join(id, addr string, metadata map[string]string) error - // Remove removes the node, specified by addr, from the cluster. - Remove(addr string) error + // Remove removes the node, specified by id, from the cluster. + Remove(id string) error - // LeaderAddr returns the Raft address of the leader of the cluster. - LeaderAddr() string + // Metadata returns the value for the given node ID, for the given key. + Metadata(id, key string) string - // Peer returns the API peer for the given address - Peer(addr string) string + // Leader returns the Raft address of the leader of the cluster. + LeaderID() (string, error) // Stats returns stats on the Store. 
 	Stats() (map[string]interface{}, error)
@@ -289,27 +289,34 @@ func (s *Service) handleJoin(w http.ResponseWriter, r *http.Request) {
 		w.WriteHeader(http.StatusBadRequest)
 		return
 	}
 
-	m := map[string]string{}
-	if err := json.Unmarshal(b, &m); err != nil {
+	md := map[string]interface{}{}
+	if err := json.Unmarshal(b, &md); err != nil {
 		w.WriteHeader(http.StatusBadRequest)
 		return
 	}
 
-	remoteID, ok := m["id"]
+	remoteID, ok := md["id"]
 	if !ok {
 		w.WriteHeader(http.StatusBadRequest)
 		return
 	}
+	var m map[string]string
+	if _, ok := md["meta"].(map[string]interface{}); ok {
+		m = make(map[string]string)
+		for k, v := range md["meta"].(map[string]interface{}) {
+			m[k] = v.(string)
+		}
+	}
 
-	remoteAddr, ok := m["addr"]
+	remoteAddr, ok := md["addr"]
 	if !ok {
 		w.WriteHeader(http.StatusBadRequest)
 		return
 	}
 
-	if err := s.store.Join(remoteID, remoteAddr); err != nil {
+	if err := s.store.Join(remoteID.(string), remoteAddr.(string), m); err != nil {
 		if err == store.ErrNotLeader {
-			leader := s.store.Peer(s.store.LeaderAddr())
+			leader := s.leaderAPIAddr()
 			if leader == "" {
 				http.Error(w, err.Error(), http.StatusServiceUnavailable)
 				return
 			}
@@ -355,15 +363,15 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) {
 		return
 	}
 
-	remoteAddr, ok := m["addr"]
+	remoteID, ok := m["id"]
 	if !ok {
 		w.WriteHeader(http.StatusBadRequest)
 		return
 	}
 
-	if err := s.store.Remove(remoteAddr); err != nil {
+	if err := s.store.Remove(remoteID); err != nil {
 		if err == store.ErrNotLeader {
-			leader := s.store.Peer(s.store.LeaderAddr())
+			leader := s.leaderAPIAddr()
 			if leader == "" {
 				http.Error(w, err.Error(), http.StatusServiceUnavailable)
 				return
@@ -448,7 +456,7 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
 	results, err := s.store.ExecuteOrAbort(&store.ExecuteRequest{queries, timings, false})
 	if err != nil {
 		if err == store.ErrNotLeader {
-			leader := s.store.Peer(s.store.LeaderAddr())
+			leader := s.leaderAPIAddr()
 			if leader == "" {
 				http.Error(w, err.Error(), http.StatusServiceUnavailable)
 				return
@@ -498,7 +506,7 @@ func (s *Service) handleStatus(w http.ResponseWriter, r *http.Request) {
 	httpStatus := map[string]interface{}{
 		"addr":     s.Addr().String(),
 		"auth":     prettyEnabled(s.credentialStore != nil),
-		"redirect": s.store.Peer(s.store.LeaderAddr()),
+		"redirect": s.leaderAPIAddr(),
 	}
 
 	nodeStatus := map[string]interface{}{
@@ -595,7 +603,7 @@ func (s *Service) handleExecute(w http.ResponseWriter, r *http.Request) {
 	results, err := s.store.Execute(&store.ExecuteRequest{queries, timings, isTx})
 	if err != nil {
 		if err == store.ErrNotLeader {
-			leader := s.store.Peer(s.store.LeaderAddr())
+			leader := s.leaderAPIAddr()
 			if leader == "" {
 				http.Error(w, err.Error(), http.StatusServiceUnavailable)
 				return
@@ -657,7 +665,7 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) {
 	results, err := s.store.Query(&store.QueryRequest{queries, timings, isTx, lvl})
 	if err != nil {
 		if err == store.ErrNotLeader {
-			leader := s.store.Peer(s.store.LeaderAddr())
+			leader := s.leaderAPIAddr()
 			if leader == "" {
 				http.Error(w, err.Error(), http.StatusServiceUnavailable)
 				return
@@ -746,6 +754,14 @@ func (s *Service) CheckRequestPerm(r *http.Request, perm string) bool {
 	return s.credentialStore.HasPerm(username, PermAll) || s.credentialStore.HasPerm(username, perm)
 }
 
+func (s *Service) leaderAPIAddr() string {
+	id, err := s.store.LeaderID()
+	if err != nil {
+		return ""
+	}
+	return s.store.Metadata(id, "api_addr")
+}
+
 // addBuildVersion adds the build version to the HTTP
response. func (s *Service) addBuildVersion(w http.ResponseWriter) { // Add version header to every response, if available. diff --git a/http/service_test.go b/http/service_test.go index 4d0e51f6..eda78be3 100644 --- a/http/service_test.go +++ b/http/service_test.go @@ -491,19 +491,19 @@ func (m *MockStore) Query(qr *store.QueryRequest) ([]*sql.Rows, error) { return nil, nil } -func (m *MockStore) Join(id, addr string) error { +func (m *MockStore) Join(id, addr string, metadata map[string]string) error { return nil } -func (m *MockStore) Remove(addr string) error { +func (m *MockStore) Remove(id string) error { return nil } -func (m *MockStore) LeaderAddr() string { - return "" +func (m *MockStore) LeaderID() (string, error) { + return "", nil } -func (m *MockStore) Peer(addr string) string { +func (m *MockStore) Metadata(id, key string) string { return "" } diff --git a/store/command.go b/store/command.go index 7414b7be..5538c319 100644 --- a/store/command.go +++ b/store/command.go @@ -8,9 +8,10 @@ import ( type commandType int const ( - execute commandType = iota // Commands which modify the database. - query // Commands which query the database. - peer // Commands that modify peers map. + execute commandType = iota // Commands which modify the database. + query // Commands which query the database. + metadataSet // Commands which sets Store metadata + metadataDelete // Commands which deletes Store metadata ) type command struct { @@ -27,7 +28,14 @@ func newCommand(t commandType, d interface{}) (*command, error) { Typ: t, Sub: b, }, nil +} +func newMetadataSetCommand(id string, md map[string]string) (*command, error) { + m := metadataSetSub{ + RaftID: id, + Data: md, + } + return newCommand(metadataSet, m) } // databaseSub is a command sub which involves interaction with the database. @@ -37,5 +45,7 @@ type databaseSub struct { Timings bool `json:"timings,omitempty"` } -// peersSub is a command which sets the API address for a Raft address. -type peersSub map[string]string +type metadataSetSub struct { + RaftID string `json:"raft_id,omitempty"` + Data map[string]string `json:"data,omitempty"` +} diff --git a/store/db_config.go b/store/db_config.go new file mode 100644 index 00000000..5750843d --- /dev/null +++ b/store/db_config.go @@ -0,0 +1,12 @@ +package store + +// DBConfig represents the configuration of the underlying SQLite database. +type DBConfig struct { + DSN string // Any custom DSN + Memory bool // Whether the database is in-memory only. +} + +// NewDBConfig returns a new DB config instance. +func NewDBConfig(dsn string, memory bool) *DBConfig { + return &DBConfig{DSN: dsn, Memory: memory} +} diff --git a/store/server.go b/store/server.go new file mode 100644 index 00000000..42edb16c --- /dev/null +++ b/store/server.go @@ -0,0 +1,14 @@ +package store + +// Server represents another node in the cluster. +type Server struct { + ID string `json:"id,omitempty"` + Addr string `json:"addr,omitempty"` +} + +// Servers is a set of Servers. 
+type Servers []*Server + +func (s Servers) Less(i, j int) bool { return s[i].ID < s[j].ID } +func (s Servers) Len() int { return len(s) } +func (s Servers) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/store/store.go b/store/store.go index 22dd7864..d2c6c5f6 100644 --- a/store/store.go +++ b/store/store.go @@ -13,7 +13,6 @@ import ( "io" "io/ioutil" "log" - "net" "os" "path/filepath" "sort" @@ -46,6 +45,9 @@ const ( sqliteFile = "db.sqlite" leaderWaitDelay = 100 * time.Millisecond appliedWaitDelay = 100 * time.Millisecond + connectionPoolCount = 5 + connectionTimeout = 10 * time.Second + raftLogCacheSize = 512 ) const ( @@ -114,60 +116,6 @@ const ( Unknown ) -// clusterMeta represents cluster meta which must be kept in consensus. -type clusterMeta struct { - APIPeers map[string]string // Map from Raft address to API address -} - -// NewClusterMeta returns an initialized cluster meta store. -func newClusterMeta() *clusterMeta { - return &clusterMeta{ - APIPeers: make(map[string]string), - } -} - -func (c *clusterMeta) AddrForPeer(addr string) string { - if api, ok := c.APIPeers[addr]; ok && api != "" { - return api - } - - // Go through each entry, and see if any key resolves to addr. - for k, v := range c.APIPeers { - resv, err := net.ResolveTCPAddr("tcp", k) - if err != nil { - continue - } - if resv.String() == addr { - return v - } - } - - return "" -} - -// DBConfig represents the configuration of the underlying SQLite database. -type DBConfig struct { - DSN string // Any custom DSN - Memory bool // Whether the database is in-memory only. -} - -// NewDBConfig returns a new DB config instance. -func NewDBConfig(dsn string, memory bool) *DBConfig { - return &DBConfig{DSN: dsn, Memory: memory} -} - -// Server represents another node in the cluster. -type Server struct { - ID string `json:"id,omitempty"` - Addr string `json:"addr,omitempty"` -} - -type Servers []*Server - -func (s Servers) Less(i, j int) bool { return s[i].ID < s[j].ID } -func (s Servers) Len() int { return len(s) } -func (s Servers) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - // Store is a SQLite database, where all changes are made via Raft consensus. type Store struct { raftDir string @@ -175,14 +123,19 @@ type Store struct { mu sync.RWMutex // Sync access between queries and snapshots. raft *raft.Raft // The consensus mechanism. - raftTn *raftTransport + ln Listener + raftTn *raft.NetworkTransport raftID string // Node ID. dbConf *DBConfig // SQLite database config. dbPath string // Path to underlying SQLite file, if not in-memory. db *sql.DB // The underlying SQLite store. + raftLog raft.LogStore // Persistent log store. + raftStable raft.StableStore // Persistent k-v store. + boltStore *raftboltdb.BoltStore // Physical store. + metaMu sync.RWMutex - meta *clusterMeta + meta map[string]map[string]string logger *log.Logger @@ -192,7 +145,6 @@ type Store struct { HeartbeatTimeout time.Duration ElectionTimeout time.Duration ApplyTimeout time.Duration - OpenTimeout time.Duration } // StoreConfig represents the configuration of the underlying Store. @@ -205,22 +157,21 @@ type StoreConfig struct { } // New returns a new Store. 
-func New(c *StoreConfig) *Store {
+func New(ln Listener, c *StoreConfig) *Store {
 	logger := c.Logger
 	if logger == nil {
 		logger = log.New(os.Stderr, "[store] ", log.LstdFlags)
 	}
 
 	return &Store{
+		ln:           ln,
 		raftDir:      c.Dir,
-		raftTn:       &raftTransport{c.Tn},
 		raftID:       c.ID,
 		dbConf:       c.DBConf,
 		dbPath:       filepath.Join(c.Dir, sqliteFile),
-		meta:         newClusterMeta(),
+		meta:         make(map[string]map[string]string),
 		logger:       logger,
 		ApplyTimeout: applyTimeout,
-		OpenTimeout:  openTimeout,
 	}
 }
 
@@ -234,6 +185,7 @@ func (s *Store) Open(enableSingle bool) error {
 		return err
 	}
 
+	// Open underlying database.
 	db, err := s.open()
 	if err != nil {
 		return err
@@ -243,14 +195,11 @@
 	// Is this a brand new node?
 	newNode := !pathExists(filepath.Join(s.raftDir, "raft.db"))
 
-	// Setup Raft communication.
-	transport := raft.NewNetworkTransport(s.raftTn, 3, 10*time.Second, os.Stderr)
+	// Create Raft-compatible network layer.
+	s.raftTn = raft.NewNetworkTransport(NewTransport(s.ln), connectionPoolCount, connectionTimeout, nil)
 
-	// Get the Raft configuration for this store.
 	config := s.raftConfig()
-	config.LocalID = raft.ServerID(s.raftID)
-	// XXXconfig.Logger = log.New(os.Stderr, "[raft] ", log.LstdFlags)
 
 	// Create the snapshot store. This allows Raft to truncate the log.
 	snapshots, err := raft.NewFileSnapshotStore(s.raftDir, retainSnapshotCount, os.Stderr)
@@ -259,13 +208,18 @@
 	}
 
 	// Create the log store and stable store.
-	logStore, err := raftboltdb.NewBoltStore(filepath.Join(s.raftDir, "raft.db"))
+	s.boltStore, err = raftboltdb.NewBoltStore(filepath.Join(s.raftDir, "raft.db"))
 	if err != nil {
 		return fmt.Errorf("new bolt store: %s", err)
 	}
+	s.raftStable = s.boltStore
+	s.raftLog, err = raft.NewLogCache(raftLogCacheSize, s.boltStore)
+	if err != nil {
+		return fmt.Errorf("new cached store: %s", err)
+	}
 
 	// Instantiate the Raft system.
-	ra, err := raft.NewRaft(config, s, logStore, logStore, snapshots, transport)
+	ra, err := raft.NewRaft(config, s, s.raftLog, s.raftStable, snapshots, s.raftTn)
 	if err != nil {
 		return fmt.Errorf("new raft: %s", err)
 	}
@@ -276,7 +230,7 @@
 			Servers: []raft.Server{
 				raft.Server{
 					ID:      config.LocalID,
-					Address: transport.LocalAddr(),
+					Address: s.raftTn.LocalAddr(),
 				},
 			},
 		}
@@ -287,16 +241,6 @@
 	s.raft = ra
 
-	if s.OpenTimeout != 0 {
-		// Wait until the initial logs are applied.
-		s.logger.Printf("waiting for up to %s for application of initial logs", s.OpenTimeout)
-		if err := s.WaitForAppliedIndex(s.raft.LastIndex(), s.OpenTimeout); err != nil {
-			return ErrOpenTimeout
-		}
-	} else {
-		s.logger.Println("not waiting for application of initial logs")
-	}
-
 	return nil
 }
 
@@ -314,6 +258,19 @@
 	return nil
 }
 
+// WaitForApplied waits for all Raft log entries to be applied to the
+// underlying database.
+func (s *Store) WaitForApplied(timeout time.Duration) error {
+	if timeout == 0 {
+		return nil
+	}
+	s.logger.Printf("waiting for up to %s for application of initial logs", timeout)
+	if err := s.WaitForAppliedIndex(s.raft.LastIndex(), timeout); err != nil {
+		return ErrOpenTimeout
+	}
+	return nil
+}
+
 // IsLeader is used to determine if the current node is cluster leader
 func (s *Store) IsLeader() bool {
 	return s.raft.State() == raft.Leader
@@ -342,8 +299,8 @@ func (s *Store) Path() string {
 }
 
 // Addr returns the address of the store.
-func (s *Store) Addr() net.Addr {
-	return s.raftTn.Addr()
+func (s *Store) Addr() string {
+	return string(s.raftTn.LocalAddr())
 }
 
 // ID returns the Raft ID of the store.
@@ -375,24 +332,6 @@ func (s *Store) LeaderID() (string, error) {
 	return "", nil
 }
 
-// Peer returns the API address for the given addr. If there is no peer
-// for the address, it returns the empty string.
-func (s *Store) Peer(addr string) string {
-	return s.meta.AddrForPeer(addr)
-}
-
-// APIPeers return the map of Raft addresses to API addresses.
-func (s *Store) APIPeers() (map[string]string, error) {
-	s.metaMu.RLock()
-	defer s.metaMu.RUnlock()
-
-	peers := make(map[string]string, len(s.meta.APIPeers))
-	for k, v := range s.meta.APIPeers {
-		peers[k] = v
-	}
-	return peers, nil
-}
-
 // Nodes returns the slice of nodes in the cluster, sorted by ID ascending.
 func (s *Store) Nodes() ([]*Server, error) {
 	f := s.raft.GetConfiguration()
@@ -480,18 +419,25 @@ func (s *Store) Stats() (map[string]interface{}, error) {
 	if err != nil {
 		return nil, err
 	}
+	leaderID, err := s.LeaderID()
+	if err != nil {
+		return nil, err
+	}
+
 	status := map[string]interface{}{
-		"node_id": s.raftID,
-		"raft":    s.raft.Stats(),
-		"addr":    s.Addr().String(),
-		"leader":  s.LeaderAddr(),
+		"node_id": s.raftID,
+		"raft":    s.raft.Stats(),
+		"addr":    s.Addr(),
+		"leader": map[string]string{
+			"node_id": leaderID,
+			"addr":    s.LeaderAddr(),
+		},
 		"apply_timeout":      s.ApplyTimeout.String(),
-		"open_timeout":       s.OpenTimeout.String(),
 		"heartbeat_timeout":  s.HeartbeatTimeout.String(),
 		"election_timeout":   s.ElectionTimeout.String(),
 		"snapshot_threshold": s.SnapshotThreshold,
-		"meta":               s.meta,
-		"peers":              nodes,
+		"metadata":           s.meta,
+		"nodes":              nodes,
 		"dir":                s.raftDir,
 		"sqlite3":            dbStatus,
 		"db_conf":            s.dbConf,
@@ -636,29 +582,38 @@ func (s *Store) Query(qr *QueryRequest) ([]*sql.Rows, error) {
 	return r, err
 }
 
-// UpdateAPIPeers updates the cluster-wide peer information.
-func (s *Store) UpdateAPIPeers(peers map[string]string) error {
-	c, err := newCommand(peer, peers)
-	if err != nil {
-		return err
-	}
-	b, err := json.Marshal(c)
-	if err != nil {
-		return err
-	}
-
-	f := s.raft.Apply(b, s.ApplyTimeout)
-	return f.Error()
-}
-
 // Join joins a node, identified by id and located at addr, to this store.
 // The node must be ready to respond to Raft communications at that address.
-func (s *Store) Join(id, addr string) error {
+func (s *Store) Join(id, addr string, metadata map[string]string) error {
 	s.logger.Printf("received request to join node at %s", addr)
 	if s.raft.State() != raft.Leader {
 		return ErrNotLeader
 	}
 
+	configFuture := s.raft.GetConfiguration()
+	if err := configFuture.Error(); err != nil {
+		s.logger.Printf("failed to get raft configuration: %v", err)
+		return err
+	}
+
+	for _, srv := range configFuture.Configuration().Servers {
+		// If a node already exists with either the joining node's ID or address,
+		// that node may need to be removed from the config first.
+		if srv.ID == raft.ServerID(id) || srv.Address == raft.ServerAddress(addr) {
+			// However if *both* the ID and the address are the same, then no
+			// join is actually needed.
+			if srv.Address == raft.ServerAddress(addr) && srv.ID == raft.ServerID(id) {
+				s.logger.Printf("node %s at %s already member of cluster, ignoring join request", id, addr)
+				return nil
+			}
+
+			if err := s.remove(id); err != nil {
+				s.logger.Printf("failed to remove node: %v", err)
+				return err
+			}
+		}
+	}
+
 	f := s.raft.AddVoter(raft.ServerID(id), raft.ServerAddress(addr), 0, 0)
 	if e := f.(raft.Future); e.Error() != nil {
 		if e.Error() == raft.ErrNotLeader {
@@ -666,6 +621,11 @@
 		}
 		return e.Error()
 	}
+
+	if err := s.setMetadata(id, metadata); err != nil {
+		return err
+	}
+
 	s.logger.Printf("node at %s joined successfully", addr)
 	return nil
 }
@@ -673,18 +633,74 @@
 // Remove removes a node from the store, specified by ID.
 func (s *Store) Remove(id string) error {
 	s.logger.Printf("received request to remove node %s", id)
-	if s.raft.State() != raft.Leader {
-		return ErrNotLeader
+	if err := s.remove(id); err != nil {
+		s.logger.Printf("failed to remove node %s: %s", id, err.Error())
+		return err
 	}
 
-	f := s.raft.RemoveServer(raft.ServerID(id), 0, 0)
-	if f.Error() != nil {
-		if f.Error() == raft.ErrNotLeader {
+	s.logger.Printf("node %s removed successfully", id)
+	return nil
+}
+
+// Metadata returns the value for a given key, for a given node ID.
+func (s *Store) Metadata(id, key string) string {
+	s.metaMu.RLock()
+	defer s.metaMu.RUnlock()
+
+	if _, ok := s.meta[id]; !ok {
+		return ""
+	}
+	v, ok := s.meta[id][key]
+	if ok {
+		return v
+	}
+	return ""
+}
+
+// SetMetadata adds the metadata md to any existing metadata for
+// this node.
+func (s *Store) SetMetadata(md map[string]string) error {
+	return s.setMetadata(s.raftID, md)
+}
+
+// setMetadata adds the metadata md to any existing metadata for
+// the given node ID.
+func (s *Store) setMetadata(id string, md map[string]string) error {
+	// Check local data first.
+	if func() bool {
+		s.metaMu.RLock()
+		defer s.metaMu.RUnlock()
+		if _, ok := s.meta[id]; ok {
+			for k, v := range md {
+				if s.meta[id][k] != v {
+					return false
+				}
+			}
+			return true
+		}
+		return false
+	}() {
+		// Local data is same as data being pushed in,
+		// nothing to do.
+		return nil
+	}
+
+	c, err := newMetadataSetCommand(id, md)
+	if err != nil {
+		return err
+	}
+	b, err := json.Marshal(c)
+	if err != nil {
+		return err
+	}
+	f := s.raft.Apply(b, s.ApplyTimeout)
+	if e := f.(raft.Future); e.Error() != nil {
+		if e.Error() == raft.ErrNotLeader {
 			return ErrNotLeader
 		}
-		return f.Error()
+		return e.Error()
 	}
-	s.logger.Printf("node %s removed successfully", id)
+
 	return nil
 }
 
@@ -712,6 +728,36 @@ func (s *Store) open() (*sql.DB, error) {
 	return db, nil
 }
 
+// remove removes the node, with the given ID, from the cluster.
+func (s *Store) remove(id string) error {
+	if s.raft.State() != raft.Leader {
+		return ErrNotLeader
+	}
+
+	f := s.raft.RemoveServer(raft.ServerID(id), 0, 0)
+	if f.Error() != nil {
+		if f.Error() == raft.ErrNotLeader {
+			return ErrNotLeader
+		}
+		return f.Error()
+	}
+
+	c, err := newCommand(metadataDelete, id)
+	b, err := json.Marshal(c)
+	if err != nil {
+		return err
+	}
+	f = s.raft.Apply(b, s.ApplyTimeout)
+	if e := f.(raft.Future); e.Error() != nil {
+		if e.Error() == raft.ErrNotLeader {
+			return ErrNotLeader
+		}
+		return e.Error()
+	}
+
+	return nil
+}
+
 // raftConfig returns a new Raft config for the store.
func (s *Store) raftConfig() *raft.Config { config := raft.DefaultConfig() @@ -764,17 +810,31 @@ func (s *Store) Apply(l *raft.Log) interface{} { } r, err := s.db.Query(d.Queries, d.Tx, d.Timings) return &fsmQueryResponse{rows: r, error: err} - case peer: - var d peersSub + case metadataSet: + var d metadataSetSub if err := json.Unmarshal(c.Sub, &d); err != nil { return &fsmGenericResponse{error: err} } func() { s.metaMu.Lock() defer s.metaMu.Unlock() - for k, v := range d { - s.meta.APIPeers[k] = v + if _, ok := s.meta[d.RaftID]; !ok { + s.meta[d.RaftID] = make(map[string]string) } + for k, v := range d.Data { + s.meta[d.RaftID][k] = v + } + }() + return &fsmGenericResponse{} + case metadataDelete: + var d string + if err := json.Unmarshal(c.Sub, &d); err != nil { + return &fsmGenericResponse{error: err} + } + func() { + s.metaMu.Lock() + defer s.metaMu.Unlock() + delete(s.meta, d) }() return &fsmGenericResponse{} default: diff --git a/store/store_test.go b/store/store_test.go index 2c7b51d1..758f6d3c 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -6,7 +6,6 @@ import ( "net" "os" "path/filepath" - "reflect" "sort" "testing" "time" @@ -14,35 +13,6 @@ import ( "github.com/rqlite/rqlite/testdata/chinook" ) -type mockSnapshotSink struct { - *os.File -} - -func (m *mockSnapshotSink) ID() string { - return "1" -} - -func (m *mockSnapshotSink) Cancel() error { - return nil -} - -func Test_ClusterMeta(t *testing.T) { - c := newClusterMeta() - c.APIPeers["localhost:4002"] = "localhost:4001" - - if c.AddrForPeer("localhost:4002") != "localhost:4001" { - t.Fatalf("wrong address returned for localhost:4002") - } - - if c.AddrForPeer("127.0.0.1:4002") != "localhost:4001" { - t.Fatalf("wrong address returned for 127.0.0.1:4002") - } - - if c.AddrForPeer("127.0.0.1:4004") != "" { - t.Fatalf("wrong address returned for 127.0.0.1:4003") - } -} - func Test_OpenStoreSingleNode(t *testing.T) { s := mustNewStore(true) defer os.RemoveAll(s.Path()) @@ -52,7 +22,7 @@ func Test_OpenStoreSingleNode(t *testing.T) { } s.WaitForLeader(10 * time.Second) - if got, exp := s.LeaderAddr(), s.Addr().String(); got != exp { + if got, exp := s.LeaderAddr(), s.Addr(); got != exp { t.Fatalf("wrong leader address returned, got: %s, exp %s", got, exp) } id, err := s.LeaderID() @@ -71,6 +41,7 @@ func Test_OpenStoreCloseSingleNode(t *testing.T) { if err := s.Open(true); err != nil { t.Fatalf("failed to open single-node store: %s", err.Error()) } + s.WaitForLeader(10 * time.Second) if err := s.Close(true); err != nil { t.Fatalf("failed to close single-node store: %s", err.Error()) } @@ -450,14 +421,14 @@ func Test_MultiNodeJoinRemove(t *testing.T) { sort.StringSlice(storeNodes).Sort() // Join the second node to the first. - if err := s0.Join(s1.ID(), s1.Addr().String()); err != nil { - t.Fatalf("failed to join to node at %s: %s", s0.Addr().String(), err.Error()) + if err := s0.Join(s1.ID(), s1.Addr(), nil); err != nil { + t.Fatalf("failed to join to node at %s: %s", s0.Addr(), err.Error()) } s1.WaitForLeader(10 * time.Second) // Check leader state on follower. - if got, exp := s1.LeaderAddr(), s0.Addr().String(); got != exp { + if got, exp := s1.LeaderAddr(), s0.Addr(); got != exp { t.Fatalf("wrong leader address returned, got: %s, exp %s", got, exp) } id, err := s1.LeaderID() @@ -514,8 +485,8 @@ func Test_MultiNodeExecuteQuery(t *testing.T) { defer s1.Close(true) // Join the second node to the first. 
- if err := s0.Join(s1.ID(), s1.Addr().String()); err != nil { - t.Fatalf("failed to join to node at %s: %s", s0.Addr().String(), err.Error()) + if err := s0.Join(s1.ID(), s1.Addr(), nil); err != nil { + t.Fatalf("failed to join to node at %s: %s", s0.Addr(), err.Error()) } queries := []string{ @@ -609,7 +580,7 @@ func Test_StoreLogTruncationMultinode(t *testing.T) { defer s1.Close(true) // Join the second node to the first. - if err := s0.Join(s1.ID(), s1.Addr().String()); err != nil { + if err := s0.Join(s1.ID(), s1.Addr(), nil); err != nil { t.Fatalf("failed to join to node at %s: %s", s0.Addr(), err.Error()) } s1.WaitForLeader(10 * time.Second) @@ -754,38 +725,65 @@ func Test_SingleNodeSnapshotInMem(t *testing.T) { } } -func Test_APIPeers(t *testing.T) { - s := mustNewStore(false) - defer os.RemoveAll(s.Path()) - - if err := s.Open(true); err != nil { +func Test_MetadataMultinode(t *testing.T) { + s0 := mustNewStore(true) + if err := s0.Open(true); err != nil { t.Fatalf("failed to open single-node store: %s", err.Error()) } - defer s.Close(true) - s.WaitForLeader(10 * time.Second) + defer s0.Close(true) + s0.WaitForLeader(10 * time.Second) + s1 := mustNewStore(true) + if err := s1.Open(true); err != nil { + t.Fatalf("failed to open single-node store: %s", err.Error()) + } + defer s1.Close(true) + s1.WaitForLeader(10 * time.Second) - peers := map[string]string{ - "localhost:4002": "localhost:4001", - "localhost:4004": "localhost:4003", + if s0.Metadata(s0.raftID, "foo") != "" { + t.Fatal("nonexistent metadata foo found") } - if err := s.UpdateAPIPeers(peers); err != nil { - t.Fatalf("failed to update API peers: %s", err.Error()) + if s0.Metadata("nonsense", "foo") != "" { + t.Fatal("nonexistent metadata foo found for nonexistent node") } - // Retrieve peers and verify them. - apiPeers, err := s.APIPeers() - if err != nil { - t.Fatalf("failed to retrieve API peers: %s", err.Error()) + if err := s0.SetMetadata(map[string]string{"foo": "bar"}); err != nil { + t.Fatalf("failed to set metadata: %s", err.Error()) + } + if s0.Metadata(s0.raftID, "foo") != "bar" { + t.Fatal("key foo not found") } - if !reflect.DeepEqual(peers, apiPeers) { - t.Fatalf("set and retrieved API peers not identical, got %v, exp %v", - apiPeers, peers) + if s0.Metadata("nonsense", "foo") != "" { + t.Fatal("nonexistent metadata foo found for nonexistent node") } - if s.Peer("localhost:4002") != "localhost:4001" || - s.Peer("localhost:4004") != "localhost:4003" || - s.Peer("not exist") != "" { - t.Fatalf("failed to retrieve correct single API peer") + // Join the second node to the first. + meta := map[string]string{"baz": "qux"} + if err := s0.Join(s1.ID(), s1.Addr(), meta); err != nil { + t.Fatalf("failed to join to node at %s: %s", s0.Addr(), err.Error()) + } + s1.WaitForLeader(10 * time.Second) + // Wait until the log entries have been applied to the follower, + // and then query. + if err := s1.WaitForAppliedIndex(5, 5*time.Second); err != nil { + t.Fatalf("error waiting for follower to apply index: %s:", err.Error()) + } + + if s1.Metadata(s0.raftID, "foo") != "bar" { + t.Fatal("key foo not found for s0") + } + if s0.Metadata(s1.raftID, "baz") != "qux" { + t.Fatal("key baz not found for s1") + } + + // Remove a node. 
+ if err := s0.Remove(s1.ID()); err != nil { + t.Fatalf("failed to remove %s from cluster: %s", s1.ID(), err.Error()) + } + if s1.Metadata(s0.raftID, "foo") != "bar" { + t.Fatal("key foo not found for s0") + } + if s0.Metadata(s1.raftID, "baz") != "" { + t.Fatal("key baz found for removed node s1") } } @@ -824,12 +822,10 @@ func mustNewStore(inmem bool) *Store { path := mustTempDir() defer os.RemoveAll(path) - tn := mustMockTransport("localhost:0") cfg := NewDBConfig("", inmem) - s := New(&StoreConfig{ + s := New(mustMockLister("localhost:0"), &StoreConfig{ DBConf: cfg, Dir: path, - Tn: tn, ID: path, // Could be any unique string. }) if s == nil { @@ -838,36 +834,52 @@ func mustNewStore(inmem bool) *Store { return s } -func mustTempDir() string { - var err error - path, err := ioutil.TempDir("", "rqlilte-test-") - if err != nil { - panic("failed to create temp dir") - } - return path +type mockSnapshotSink struct { + *os.File +} + +func (m *mockSnapshotSink) ID() string { + return "1" +} + +func (m *mockSnapshotSink) Cancel() error { + return nil } type mockTransport struct { ln net.Listener } -func mustMockTransport(addr string) Transport { +type mockListener struct { + ln net.Listener +} + +func mustMockLister(addr string) Listener { ln, err := net.Listen("tcp", addr) if err != nil { - panic("failed to create new transport") + panic("failed to create new listner") } - return &mockTransport{ln} + return &mockListener{ln} } -func (m *mockTransport) Dial(addr string, timeout time.Duration) (net.Conn, error) { +func (m *mockListener) Dial(addr string, timeout time.Duration) (net.Conn, error) { return net.DialTimeout("tcp", addr, timeout) } -func (m *mockTransport) Accept() (net.Conn, error) { return m.ln.Accept() } +func (m *mockListener) Accept() (net.Conn, error) { return m.ln.Accept() } + +func (m *mockListener) Close() error { return m.ln.Close() } -func (m *mockTransport) Close() error { return m.ln.Close() } +func (m *mockListener) Addr() net.Addr { return m.ln.Addr() } -func (m *mockTransport) Addr() net.Addr { return m.ln.Addr() } +func mustTempDir() string { + var err error + path, err := ioutil.TempDir("", "rqlilte-test-") + if err != nil { + panic("failed to create temp dir") + } + return path +} func asJSON(v interface{}) string { b, err := json.Marshal(v) diff --git a/store/transport.go b/store/transport.go index 1edd1782..a6df008d 100644 --- a/store/transport.go +++ b/store/transport.go @@ -7,32 +7,39 @@ import ( "github.com/hashicorp/raft" ) -// Transport is the interface the network service must provide. -type Transport interface { +type Listener interface { net.Listener - - // Dial is used to create a new outgoing connection Dial(address string, timeout time.Duration) (net.Conn, error) } -// raftTransport takes a Transport and makes it suitable for use by the Raft -// networking system. -type raftTransport struct { - tn Transport +// Transport is the network service provided to Raft, and wraps a Listener. +type Transport struct { + ln Listener +} + +// NewTransport returns an initialized Transport. +func NewTransport(ln Listener) *Transport { + return &Transport{ + ln: ln, + } } -func (r *raftTransport) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) { - return r.tn.Dial(string(address), timeout) +// Dial creates a new network connection. 
+func (t *Transport) Dial(addr raft.ServerAddress, timeout time.Duration) (net.Conn, error) { + return t.ln.Dial(string(addr), timeout) } -func (r *raftTransport) Accept() (net.Conn, error) { - return r.tn.Accept() +// Accept waits for the next connection. +func (t *Transport) Accept() (net.Conn, error) { + return t.ln.Accept() } -func (r *raftTransport) Addr() net.Addr { - return r.tn.Addr() +// Close closes the transport +func (t *Transport) Close() error { + return t.ln.Close() } -func (r *raftTransport) Close() error { - return r.tn.Close() +// Addr returns the binding address of the transport. +func (t *Transport) Addr() net.Addr { + return t.ln.Addr() } diff --git a/store/transport_test.go b/store/transport_test.go new file mode 100644 index 00000000..94d24c10 --- /dev/null +++ b/store/transport_test.go @@ -0,0 +1,11 @@ +package store + +import ( + "testing" +) + +func Test_NewTransport(t *testing.T) { + if NewTransport(nil) == nil { + t.Fatal("failed to create new Transport") + } +} diff --git a/system_test/helpers.go b/system_test/helpers.go index 8c112de5..7d34654a 100644 --- a/system_test/helpers.go +++ b/system_test/helpers.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "io/ioutil" - "net" "net/http" "net/url" "os" @@ -14,12 +13,14 @@ import ( httpd "github.com/rqlite/rqlite/http" "github.com/rqlite/rqlite/store" + "github.com/rqlite/rqlite/tcp" ) // Node represents a node under test. type Node struct { APIAddr string RaftAddr string + ID string Dir string Store *store.Store Service *httpd.Service @@ -302,18 +303,21 @@ func mustNewNode(enableSingle bool) *Node { } dbConf := store.NewDBConfig("", false) - tn := mustMockTransport("localhost:0") - node.Store = store.New(&store.StoreConfig{ + tn := tcp.NewTransport() + if err := tn.Open("localhost:0"); err != nil { + panic(err.Error()) + } + node.Store = store.New(tn, &store.StoreConfig{ DBConf: dbConf, Dir: node.Dir, - Tn: tn, ID: tn.Addr().String(), }) if err := node.Store.Open(enableSingle); err != nil { node.Deprovision() panic(fmt.Sprintf("failed to open store: %s", err.Error())) } - node.RaftAddr = node.Store.Addr().String() + node.RaftAddr = node.Store.Addr() + node.ID = node.Store.ID() node.Service = httpd.New("localhost:0", node.Store, nil) node.Service.Expvar = true @@ -335,28 +339,6 @@ func mustNewLeaderNode() *Node { return node } -type mockTransport struct { - ln net.Listener -} - -func mustMockTransport(addr string) *mockTransport { - ln, err := net.Listen("tcp", addr) - if err != nil { - panic("failed to create new transport") - } - return &mockTransport{ln} -} - -func (m *mockTransport) Dial(addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout("tcp", addr, timeout) -} - -func (m *mockTransport) Accept() (net.Conn, error) { return m.ln.Accept() } - -func (m *mockTransport) Close() error { return m.ln.Close() } - -func (m *mockTransport) Addr() net.Addr { return m.ln.Addr() } - func mustTempDir() string { var err error path, err := ioutil.TempDir("", "rqlilte-system-test-") diff --git a/tcp/doc.go b/tcp/doc.go index 8630ff9b..0a8e9709 100644 --- a/tcp/doc.go +++ b/tcp/doc.go @@ -1,4 +1,4 @@ /* -Package tcp provides various TCP-related utilities. The TCP mux code provided by this package originated with InfluxDB. +Package tcp provides the internode communication network layer. 
*/ package tcp diff --git a/tcp/mux.go b/tcp/mux.go deleted file mode 100644 index 27d29029..00000000 --- a/tcp/mux.go +++ /dev/null @@ -1,301 +0,0 @@ -package tcp - -import ( - "crypto/tls" - "crypto/x509" - "errors" - "fmt" - "io" - "io/ioutil" - "log" - "net" - "os" - "strconv" - "sync" - "time" -) - -const ( - // DefaultTimeout is the default length of time to wait for first byte. - DefaultTimeout = 30 * time.Second -) - -// Layer represents the connection between nodes. -type Layer struct { - ln net.Listener - header byte - addr net.Addr - - remoteEncrypted bool - skipVerify bool - nodeX509CACert string - tlsConfig *tls.Config -} - -// Addr returns the local address for the layer. -func (l *Layer) Addr() net.Addr { - return l.addr -} - -// Dial creates a new network connection. -func (l *Layer) Dial(addr string, timeout time.Duration) (net.Conn, error) { - dialer := &net.Dialer{Timeout: timeout} - - var err error - var conn net.Conn - if l.remoteEncrypted { - conn, err = tls.DialWithDialer(dialer, "tcp", addr, l.tlsConfig) - } else { - conn, err = dialer.Dial("tcp", addr) - } - if err != nil { - return nil, err - } - - // Write a marker byte to indicate message type. - _, err = conn.Write([]byte{l.header}) - if err != nil { - conn.Close() - return nil, err - } - return conn, err -} - -// Accept waits for the next connection. -func (l *Layer) Accept() (net.Conn, error) { return l.ln.Accept() } - -// Close closes the layer. -func (l *Layer) Close() error { return l.ln.Close() } - -// Mux multiplexes a network connection. -type Mux struct { - ln net.Listener - addr net.Addr - m map[byte]*listener - - wg sync.WaitGroup - - remoteEncrypted bool - - // The amount of time to wait for the first header byte. - Timeout time.Duration - - // Out-of-band error logger - Logger *log.Logger - - // Path to root X.509 certificate. - nodeX509CACert string - - // Path to X509 certificate - nodeX509Cert string - - // Path to X509 key. - nodeX509Key string - - // Whether to skip verification of other nodes' certificates. - InsecureSkipVerify bool - - tlsConfig *tls.Config -} - -// NewMux returns a new instance of Mux for ln. If adv is nil, -// then the addr of ln is used. -func NewMux(ln net.Listener, adv net.Addr) (*Mux, error) { - addr := adv - if addr == nil { - addr = ln.Addr() - } - - return &Mux{ - ln: ln, - addr: addr, - m: make(map[byte]*listener), - Timeout: DefaultTimeout, - Logger: log.New(os.Stderr, "[mux] ", log.LstdFlags), - }, nil -} - -// NewTLSMux returns a new instance of Mux for ln, and encrypts all traffic -// using TLS. If adv is nil, then the addr of ln is used. -func NewTLSMux(ln net.Listener, adv net.Addr, cert, key, caCert string) (*Mux, error) { - mux, err := NewMux(ln, adv) - if err != nil { - return nil, err - } - - mux.tlsConfig, err = createTLSConfig(cert, key, caCert) - if err != nil { - return nil, err - } - - mux.ln = tls.NewListener(ln, mux.tlsConfig) - mux.remoteEncrypted = true - mux.nodeX509CACert = caCert - mux.nodeX509Cert = cert - mux.nodeX509Key = key - - return mux, nil -} - -// Serve handles connections from ln and multiplexes then across registered listener. -func (mux *Mux) Serve() error { - mux.Logger.Printf("mux serving on %s, advertising %s", mux.ln.Addr().String(), mux.addr) - - for { - // Wait for the next connection. - // If it returns a temporary error then simply retry. - // If it returns any other error then exit immediately. 
- conn, err := mux.ln.Accept() - if err, ok := err.(interface { - Temporary() bool - }); ok && err.Temporary() { - continue - } - if err != nil { - // Wait for all connections to be demuxed - mux.wg.Wait() - for _, ln := range mux.m { - close(ln.c) - } - return err - } - - // Demux in a goroutine to - mux.wg.Add(1) - go mux.handleConn(conn) - } -} - -// Stats returns status of the mux. -func (mux *Mux) Stats() (interface{}, error) { - s := map[string]string{ - "addr": mux.addr.String(), - "timeout": mux.Timeout.String(), - "encrypted": strconv.FormatBool(mux.remoteEncrypted), - } - - if mux.remoteEncrypted { - s["certificate"] = mux.nodeX509Cert - s["key"] = mux.nodeX509Key - s["ca_certificate"] = mux.nodeX509CACert - s["skip_verify"] = strconv.FormatBool(mux.InsecureSkipVerify) - } - - return s, nil -} - -func (mux *Mux) handleConn(conn net.Conn) { - defer mux.wg.Done() - // Set a read deadline so connections with no data don't timeout. - if err := conn.SetReadDeadline(time.Now().Add(mux.Timeout)); err != nil { - conn.Close() - mux.Logger.Printf("tcp.Mux: cannot set read deadline: %s", err) - return - } - - // Read first byte from connection to determine handler. - var typ [1]byte - if _, err := io.ReadFull(conn, typ[:]); err != nil { - conn.Close() - mux.Logger.Printf("tcp.Mux: cannot read header byte: %s", err) - return - } - - // Reset read deadline and let the listener handle that. - if err := conn.SetReadDeadline(time.Time{}); err != nil { - conn.Close() - mux.Logger.Printf("tcp.Mux: cannot reset set read deadline: %s", err) - return - } - - // Retrieve handler based on first byte. - handler := mux.m[typ[0]] - if handler == nil { - conn.Close() - mux.Logger.Printf("tcp.Mux: handler not registered: %d", typ[0]) - return - } - - // Send connection to handler. The handler is responsible for closing the connection. - handler.c <- conn -} - -// Listen returns a listener identified by header. -// Any connection accepted by mux is multiplexed based on the initial header byte. -func (mux *Mux) Listen(header byte) *Layer { - // Ensure two listeners are not created for the same header byte. - if _, ok := mux.m[header]; ok { - panic(fmt.Sprintf("listener already registered under header byte: %d", header)) - } - - // Create a new listener and assign it. - ln := &listener{ - c: make(chan net.Conn), - } - mux.m[header] = ln - - layer := &Layer{ - ln: ln, - header: header, - addr: mux.addr, - remoteEncrypted: mux.remoteEncrypted, - skipVerify: mux.InsecureSkipVerify, - nodeX509CACert: mux.nodeX509CACert, - tlsConfig: mux.tlsConfig, - } - - return layer -} - -// listener is a receiver for connections received by Mux. -type listener struct { - c chan net.Conn -} - -// Accept waits for and returns the next connection to the listener. -func (ln *listener) Accept() (c net.Conn, err error) { - conn, ok := <-ln.c - if !ok { - return nil, errors.New("network connection closed") - } - return conn, nil -} - -// Close is a no-op. The mux's listener should be closed instead. -func (ln *listener) Close() error { return nil } - -// Addr always returns nil -func (ln *listener) Addr() net.Addr { return nil } - -// newTLSListener returns a net listener which encrypts the traffic using TLS. -func newTLSListener(ln net.Listener, certFile, keyFile, caCertFile string) (net.Listener, error) { - config, err := createTLSConfig(certFile, keyFile, caCertFile) - if err != nil { - return nil, err - } - - return tls.NewListener(ln, config), nil -} - -// createTLSConfig returns a TLS config from the given cert and key. 
-func createTLSConfig(certFile, keyFile, caCertFile string) (*tls.Config, error) { - var err error - config := &tls.Config{} - config.Certificates = make([]tls.Certificate, 1) - config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return nil, err - } - if caCertFile != "" { - asn1Data, err := ioutil.ReadFile(caCertFile) - if err != nil { - return nil, err - } - config.RootCAs = x509.NewCertPool() - ok := config.RootCAs.AppendCertsFromPEM([]byte(asn1Data)) - if !ok { - return nil, fmt.Errorf("failed to parse root certificate(s) in %s", caCertFile) - } - } - return config, nil -} diff --git a/tcp/mux_test.go b/tcp/mux_test.go deleted file mode 100644 index 8242e733..00000000 --- a/tcp/mux_test.go +++ /dev/null @@ -1,235 +0,0 @@ -package tcp - -import ( - "bytes" - "crypto/tls" - "io" - "io/ioutil" - "log" - "net" - "os" - "strings" - "sync" - "testing" - "testing/quick" - "time" - - "github.com/rqlite/rqlite/testdata/x509" -) - -// Ensure the muxer can split a listener's connections across multiple listeners. -func TestMux(t *testing.T) { - if err := quick.Check(func(n uint8, msg []byte) bool { - if testing.Verbose() { - if len(msg) == 0 { - log.Printf("n=%d, ", n) - } else { - log.Printf("n=%d, hdr=%d, len=%d", n, msg[0], len(msg)) - } - } - - var wg sync.WaitGroup - - // Open single listener on random port. - tcpListener := mustTCPListener("127.0.0.1:0") - defer tcpListener.Close() - - // Setup muxer & listeners. - mux, err := NewMux(tcpListener, nil) - if err != nil { - t.Fatalf("failed to create mux: %s", err.Error()) - } - mux.Timeout = 200 * time.Millisecond - if !testing.Verbose() { - mux.Logger = log.New(ioutil.Discard, "", 0) - } - for i := uint8(0); i < n; i++ { - ln := mux.Listen(byte(i)) - - wg.Add(1) - go func(i uint8, ln net.Listener) { - defer wg.Done() - - // Wait for a connection for this listener. - conn, err := ln.Accept() - if conn != nil { - defer conn.Close() - } - - // If there is no message or the header byte - // doesn't match then expect close. - if len(msg) == 0 || msg[0] != byte(i) { - if err == nil || err.Error() != "network connection closed" { - t.Fatalf("unexpected error: %s", err) - } - return - } - - // If the header byte matches this listener - // then expect a connection and read the message. - var buf bytes.Buffer - if _, err := io.CopyN(&buf, conn, int64(len(msg)-1)); err != nil { - t.Fatal(err) - } else if !bytes.Equal(msg[1:], buf.Bytes()) { - t.Fatalf("message mismatch:\n\nexp=%x\n\ngot=%x\n\n", msg[1:], buf.Bytes()) - } - - // Write response. - if _, err := conn.Write([]byte("OK")); err != nil { - t.Fatal(err) - } - }(i, ln) - } - - // Begin serving from the listener. - go mux.Serve() - - // Write message to TCP listener and read OK response. - conn, err := net.Dial("tcp", tcpListener.Addr().String()) - if err != nil { - t.Fatal(err) - } else if _, err = conn.Write(msg); err != nil { - t.Fatal(err) - } - - // Read the response into the buffer. - var resp [2]byte - _, err = io.ReadFull(conn, resp[:]) - - // If the message header is less than n then expect a response. - // Otherwise we should get an EOF because the mux closed. - if len(msg) > 0 && uint8(msg[0]) < n { - if string(resp[:]) != `OK` { - t.Fatalf("unexpected response: %s", resp[:]) - } - } else { - if err == nil || (err != io.EOF && !(strings.Contains(err.Error(), "connection reset by peer") || - strings.Contains(err.Error(), "closed by the remote host"))) { - t.Fatalf("unexpected error: %s", err) - } - } - - // Close connection. 
-		if err := conn.Close(); err != nil {
-			t.Fatal(err)
-		}
-
-		// Close original TCP listener and wait for all goroutines to close.
-		tcpListener.Close()
-		wg.Wait()
-
-		return true
-	}, nil); err != nil {
-		t.Error(err)
-	}
-}
-
-func TestMux_Advertise(t *testing.T) {
-	// Setup muxer.
-	tcpListener := mustTCPListener("127.0.0.1:0")
-	defer tcpListener.Close()
-
-	addr := &mockAddr{
-		Nwk:  "tcp",
-		Addr: "rqlite.com:8081",
-	}
-
-	mux, err := NewMux(tcpListener, addr)
-	if err != nil {
-		t.Fatalf("failed to create mux: %s", err.Error())
-	}
-	mux.Timeout = 200 * time.Millisecond
-	if !testing.Verbose() {
-		mux.Logger = log.New(ioutil.Discard, "", 0)
-	}
-
-	layer := mux.Listen(1)
-	if layer.Addr().String() != addr.Addr {
-		t.Fatalf("layer advertise address not correct, exp %s, got %s",
-			layer.Addr().String(), addr.Addr)
-	}
-}
-
-// Ensure two handlers cannot be registered for the same header byte.
-func TestMux_Listen_ErrAlreadyRegistered(t *testing.T) {
-	defer func() {
-		if r := recover(); r != `listener already registered under header byte: 5` {
-			t.Fatalf("unexpected recover: %#v", r)
-		}
-	}()
-
-	// Register two listeners with the same header byte.
-	tcpListener := mustTCPListener("127.0.0.1:0")
-	mux, err := NewMux(tcpListener, nil)
-	if err != nil {
-		t.Fatalf("failed to create mux: %s", err.Error())
-	}
-	mux.Listen(5)
-	mux.Listen(5)
-}
-
-func TestTLSMux(t *testing.T) {
-	tcpListener := mustTCPListener("127.0.0.1:0")
-	defer tcpListener.Close()
-
-	cert := x509.CertFile()
-	defer os.Remove(cert)
-	key := x509.KeyFile()
-	defer os.Remove(key)
-
-	mux, err := NewTLSMux(tcpListener, nil, cert, key, "")
-	if err != nil {
-		t.Fatalf("failed to create mux: %s", err.Error())
-	}
-	go mux.Serve()
-
-	// Verify that the listener is secured.
-	conn, err := tls.Dial("tcp", tcpListener.Addr().String(), &tls.Config{
-		InsecureSkipVerify: true,
-	})
-	if err != nil {
-		t.Fatal(err)
-	}
-
-	state := conn.ConnectionState()
-	if !state.HandshakeComplete {
-		t.Fatal("connection handshake failed to complete")
-	}
-}
-
-func TestTLSMux_Fail(t *testing.T) {
-	tcpListener := mustTCPListener("127.0.0.1:0")
-	defer tcpListener.Close()
-
-	cert := x509.CertFile()
-	defer os.Remove(cert)
-	key := x509.KeyFile()
-	defer os.Remove(key)
-
-	_, err := NewTLSMux(tcpListener, nil, "xxxx", "yyyy", "")
-	if err == nil {
-		t.Fatalf("created mux unexpectedly with bad resources")
-	}
-}
-
-type mockAddr struct {
-	Nwk  string
-	Addr string
-}
-
-func (m *mockAddr) Network() string {
-	return m.Nwk
-}
-
-func (m *mockAddr) String() string {
-	return m.Addr
-}
-
-// mustTCPListener returns a listener on bind, or panics.
-func mustTCPListener(bind string) net.Listener {
-	l, err := net.Listen("tcp", bind)
-	if err != nil {
-		panic(err)
-	}
-	return l
-}
diff --git a/tcp/transport.go b/tcp/transport.go
new file mode 100644
index 00000000..4a3156f1
--- /dev/null
+++ b/tcp/transport.go
@@ -0,0 +1,101 @@
+package tcp
+
+import (
+	"crypto/tls"
+	"fmt"
+	"net"
+	"time"
+)
+
+// Transport is the network layer for internode communications.
+type Transport struct {
+	ln net.Listener
+
+	certFile        string // Path to local X.509 cert.
+	certKey         string // Path to corresponding X.509 key.
+	remoteEncrypted bool   // Remote nodes use encrypted communication.
+	skipVerify      bool   // Skip verification of remote node certs.
+}
+
+// NewTransport returns an initialized unencrypted Transport.
+func NewTransport() *Transport {
+	return &Transport{}
+}
+
+// NewTLSTransport returns an initialized TLS-encrypted Transport.
+func NewTLSTransport(certFile, keyPath string, skipVerify bool) *Transport {
+	return &Transport{
+		certFile:        certFile,
+		certKey:         keyPath,
+		remoteEncrypted: true,
+		skipVerify:      skipVerify,
+	}
+}
+
+// Open opens the transport, binding to the supplied address.
+func (t *Transport) Open(addr string) error {
+	ln, err := net.Listen("tcp", addr)
+	if err != nil {
+		return err
+	}
+	if t.certFile != "" {
+		config, err := createTLSConfig(t.certFile, t.certKey)
+		if err != nil {
+			return err
+		}
+		ln = tls.NewListener(ln, config)
+	}
+
+	t.ln = ln
+	return nil
+}
+
+// Dial opens a network connection.
+func (t *Transport) Dial(addr string, timeout time.Duration) (net.Conn, error) {
+	dialer := &net.Dialer{Timeout: timeout}
+
+	var err error
+	var conn net.Conn
+	if t.remoteEncrypted {
+		conf := &tls.Config{
+			InsecureSkipVerify: t.skipVerify,
+		}
+		fmt.Println("doing a TLS dial")
+		conn, err = tls.DialWithDialer(dialer, "tcp", addr, conf)
+	} else {
+		conn, err = dialer.Dial("tcp", addr)
+	}
+
+	return conn, err
+}
+
+// Accept waits for the next connection.
+func (t *Transport) Accept() (net.Conn, error) {
+	c, err := t.ln.Accept()
+	if err != nil {
+		fmt.Println("error accepting: ", err.Error())
+	}
+	return c, err
+}
+
+// Close closes the transport.
+func (t *Transport) Close() error {
+	return t.ln.Close()
+}
+
+// Addr returns the binding address of the transport.
+func (t *Transport) Addr() net.Addr {
+	return t.ln.Addr()
+}
+
+// createTLSConfig returns a TLS config from the given cert and key.
+func createTLSConfig(certFile, keyFile string) (*tls.Config, error) {
+	var err error
+	config := &tls.Config{}
+	config.Certificates = make([]tls.Certificate, 1)
+	config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
+	if err != nil {
+		return nil, err
+	}
+	return config, nil
+}
diff --git a/tcp/transport_test.go b/tcp/transport_test.go
new file mode 100644
index 00000000..782f6655
--- /dev/null
+++ b/tcp/transport_test.go
@@ -0,0 +1,70 @@
+package tcp
+
+import (
+	"os"
+	"testing"
+	"time"
+
+	"github.com/rqlite/rqlite/testdata/x509"
+)
+
+func Test_NewTransport(t *testing.T) {
+	if NewTransport() == nil {
+		t.Fatal("failed to create new Transport")
+	}
+}
+
+func Test_TransportOpenClose(t *testing.T) {
+	tn := NewTransport()
+	if err := tn.Open("localhost:0"); err != nil {
+		t.Fatalf("failed to open transport: %s", err.Error())
+	}
+	if tn.Addr().String() == "localhost:0" {
+		t.Fatalf("transport address set incorrectly, got: %s", tn.Addr().String())
+	}
+	if err := tn.Close(); err != nil {
+		t.Fatalf("failed to close transport: %s", err.Error())
+	}
+}
+
+func Test_TransportDial(t *testing.T) {
+	tn1 := NewTransport()
+	tn1.Open("localhost:0")
+	go tn1.Accept()
+	tn2 := NewTransport()
+
+	_, err := tn2.Dial(tn1.Addr().String(), time.Second)
+	if err != nil {
+		t.Fatalf("failed to connect to first transport: %s", err.Error())
+	}
+	tn1.Close()
+}
+
+func Test_NewTLSTransport(t *testing.T) {
+	c := x509.CertFile()
+	defer os.Remove(c)
+	k := x509.KeyFile()
+	defer os.Remove(k)
+
+	if NewTLSTransport(c, k, true) == nil {
+		t.Fatal("failed to create new TLS Transport")
+	}
+}
+
+func Test_TLSTransportOpenClose(t *testing.T) {
+	c := x509.CertFile()
+	defer os.Remove(c)
+	k := x509.KeyFile()
+	defer os.Remove(k)
+
+	tn := NewTLSTransport(c, k, true)
+	if err := tn.Open("localhost:0"); err != nil {
+		t.Fatalf("failed to open TLS transport: %s", err.Error())
+	}
+	if tn.Addr().String() == "localhost:0" {
+		t.Fatalf("TLS transport address set incorrectly, got: %s",
+			tn.Addr().String())
+	}
+	if err := tn.Close(); err != nil {
+		t.Fatalf("failed to close TLS transport: %s", err.Error())
+	}
+}
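
Note for reviewers: the sketch below shows roughly how the new `tcp.Transport` is expected to be wired into a node, mirroring the `system_test/helpers.go` changes above. The standalone `main` wrapper, bind address, and data directory are illustrative assumptions only, not part of this change; `tcp.NewTLSTransport` could be substituted where internode traffic must be encrypted.

```go
package main

import (
	"log"

	"github.com/rqlite/rqlite/store"
	"github.com/rqlite/rqlite/tcp"
)

func main() {
	// Create the internode network layer and bind it to an address
	// (illustrative bind address).
	tn := tcp.NewTransport()
	if err := tn.Open("localhost:4002"); err != nil {
		log.Fatalf("failed to open internode transport: %s", err.Error())
	}

	// The transport is passed directly to the store, replacing the
	// mux-based layer removed by this change.
	s := store.New(tn, &store.StoreConfig{
		DBConf: store.NewDBConfig("", false),
		Dir:    "/tmp/rqlite-data", // illustrative data directory
		ID:     tn.Addr().String(),
	})
	if err := s.Open(true); err != nil {
		log.Fatalf("failed to open store: %s", err.Error())
	}

	log.Printf("node %s listening for Raft traffic on %s", s.ID(), s.Addr())
}
```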