From f57ace7da25eb405971e26c0da5613a4f54219b0 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Fri, 20 Dec 2019 07:52:27 -0500 Subject: [PATCH] Broadcast Store meta via standard consensus With this change the cluster metadata (arbitrary key-value data associated with each node) is now broadcast across the cluster using the standard consensus mechanism. Specifically the use case for this metadata is to allow all nodes to know the HTTP API address of all other nodes, for the purpose of redirecting requests to the leader. This change removes the need for multiplexing two logical connections over the single Raft TCP connection, which greatly simplifies the networking code generally. Original PR https://github.com/rqlite/rqlite/pull/434 --- DOC/CLUSTER_MGMT.md | 2 +- cluster/join.go | 21 +-- cluster/join_test.go | 56 ++++++- cluster/service.go | 180 ---------------------- cluster/service_test.go | 194 ------------------------ cmd/rqlited/main.go | 150 +++++++----------- http/service.go | 56 ++++--- http/service_test.go | 10 +- store/command.go | 20 ++- store/db_config.go | 12 ++ store/server.go | 14 ++ store/store.go | 328 ++++++++++++++++++++++++---------------- store/store_test.go | 166 ++++++++++---------- store/transport.go | 39 +++-- store/transport_test.go | 11 ++ system_test/helpers.go | 36 ++--- tcp/doc.go | 2 +- tcp/mux.go | 301 ------------------------------------ tcp/mux_test.go | 235 ---------------------------- tcp/transport.go | 101 +++++++++++++ tcp/transport_test.go | 70 +++++++++ 21 files changed, 697 insertions(+), 1307 deletions(-) delete mode 100644 cluster/service.go delete mode 100644 cluster/service_test.go create mode 100644 store/db_config.go create mode 100644 store/server.go create mode 100644 store/transport_test.go delete mode 100644 tcp/mux.go delete mode 100644 tcp/mux_test.go create mode 100644 tcp/transport.go create mode 100644 tcp/transport_test.go diff --git a/DOC/CLUSTER_MGMT.md b/DOC/CLUSTER_MGMT.md index baee18b2..1eeff33c 100644 
--- a/DOC/CLUSTER_MGMT.md +++ b/DOC/CLUSTER_MGMT.md @@ -57,7 +57,7 @@ You can grow a cluster, at anytime, simply by starting up a new node and having # Removing or replacing a node If a node fails completely and is not coming back, or if you shut down a node because you wish to deprovision it, its record should also be removed from the cluster. To remove the record of a node from a cluster, execute the following command: ``` -curl -XDELETE http://localhost:4001/remove -d '{"addr": ""}' +curl -XDELETE http://localhost:4001/remove -d '{"id": ""}' ``` assuming `localhost` is the address of the cluster leader. If you do not do this the leader will continually attempt to communicate with that node. diff --git a/cluster/join.go b/cluster/join.go index 5458cc72..95b3ac14 100644 --- a/cluster/join.go +++ b/cluster/join.go @@ -20,9 +20,9 @@ const numAttempts int = 3 const attemptInterval time.Duration = 5 * time.Second // It walks through joinAddr in order, and sets the node ID and Raft address of -// the joining node as nodeID advAddr respectively. It returns the endpoint -// successfully used to join the cluster. -func Join(joinAddr []string, nodeID, advAddr string, tlsConfig *tls.Config) (string, error) { +// the joining node as id addr respectively. It returns the endpoint successfully +// used to join the cluster. +func Join(joinAddr []string, id, addr string, meta map[string]string, tlsConfig *tls.Config) (string, error) { var err error var j string logger := log.New(os.Stderr, "[cluster-join] ", log.LstdFlags) @@ -32,7 +32,7 @@ func Join(joinAddr []string, nodeID, advAddr string, tlsConfig *tls.Config) (str for i := 0; i < numAttempts; i++ { for _, a := range joinAddr { - j, err = join(a, nodeID, advAddr, tlsConfig, logger) + j, err = join(a, id, addr, meta, tlsConfig, logger) if err == nil { // Success! 
return j, nil @@ -45,13 +45,13 @@ func Join(joinAddr []string, nodeID, advAddr string, tlsConfig *tls.Config) (str return "", err } -func join(joinAddr, nodeID, advAddr string, tlsConfig *tls.Config, logger *log.Logger) (string, error) { - if nodeID == "" { +func join(joinAddr, id, addr string, meta map[string]string, tlsConfig *tls.Config, logger *log.Logger) (string, error) { + if id == "" { return "", fmt.Errorf("node ID not set") } // Join using IP address, as that is what Hashicorp Raft works in. - resv, err := net.ResolveTCPAddr("tcp", advAddr) + resv, err := net.ResolveTCPAddr("tcp", addr) if err != nil { return "", err } @@ -59,7 +59,7 @@ func join(joinAddr, nodeID, advAddr string, tlsConfig *tls.Config, logger *log.L // Check for protocol scheme, and insert default if necessary. fullAddr := httpd.NormalizeAddr(fmt.Sprintf("%s/join", joinAddr)) - // Enable skipVerify as requested. + // Create and configure the client to connect to the other node. tr := &http.Transport{ TLSClientConfig: tlsConfig, } @@ -69,9 +69,10 @@ func join(joinAddr, nodeID, advAddr string, tlsConfig *tls.Config, logger *log.L } for { - b, err := json.Marshal(map[string]string{ - "id": nodeID, + b, err := json.Marshal(map[string]interface{}{ + "id": id, "addr": resv.String(), + "meta": meta, }) // Attempt to join. 
diff --git a/cluster/join_test.go b/cluster/join_test.go index c521800f..98a53f67 100644 --- a/cluster/join_test.go +++ b/cluster/join_test.go @@ -1,7 +1,9 @@ package cluster import ( + "encoding/json" "fmt" + "io/ioutil" "net/http" "net/http/httptest" "testing" @@ -16,7 +18,7 @@ func Test_SingleJoinOK(t *testing.T) { })) defer ts.Close() - j, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", nil) + j, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", nil, nil) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) } @@ -25,13 +27,57 @@ func Test_SingleJoinOK(t *testing.T) { } } +func Test_SingleJoinMetaOK(t *testing.T) { + var body map[string]interface{} + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != "POST" { + t.Fatalf("Client did not use POST") + } + w.WriteHeader(http.StatusOK) + + b, err := ioutil.ReadAll(r.Body) + if err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + + if err := json.Unmarshal(b, &body); err != nil { + w.WriteHeader(http.StatusBadRequest) + return + } + })) + defer ts.Close() + + nodeAddr := "127.0.0.1:9090" + md := map[string]string{"foo": "bar"} + j, err := Join([]string{ts.URL}, "id0", nodeAddr, md, nil) + if err != nil { + t.Fatalf("failed to join a single node: %s", err.Error()) + } + if j != ts.URL+"/join" { + t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts.URL) + } + + if id, _ := body["id"]; id != "id0" { + t.Fatalf("node joined supplying wrong ID, exp %s, got %s", "id0", body["id"]) + } + if addr, _ := body["addr"]; addr != nodeAddr { + t.Fatalf("node joined supplying wrong address, exp %s, got %s", nodeAddr, body["addr"]) + } + rxMd, _ := body["meta"].(map[string]interface{}) + if len(rxMd) != len(md) || rxMd["foo"] != "bar" { + t.Fatalf("node joined supplying wrong meta") + } +} + func Test_SingleJoinFail(t *testing.T) { ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r 
*http.Request) { w.WriteHeader(http.StatusBadRequest) })) defer ts.Close() - _, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", nil) + _, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", nil, nil) if err == nil { t.Fatalf("expected error when joining bad node") } @@ -45,7 +91,7 @@ func Test_DoubleJoinOK(t *testing.T) { })) defer ts2.Close() - j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", nil) + j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", nil, nil) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) } @@ -63,7 +109,7 @@ func Test_DoubleJoinOKSecondNode(t *testing.T) { })) defer ts2.Close() - j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", nil) + j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", nil, nil) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) } @@ -83,7 +129,7 @@ func Test_DoubleJoinOKSecondNodeRedirect(t *testing.T) { })) defer ts2.Close() - j, err := Join([]string{ts2.URL}, "id0", "127.0.0.1:9090", nil) + j, err := Join([]string{ts2.URL}, "id0", "127.0.0.1:9090", nil, nil) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) } diff --git a/cluster/service.go b/cluster/service.go deleted file mode 100644 index d51207b7..00000000 --- a/cluster/service.go +++ /dev/null @@ -1,180 +0,0 @@ -package cluster - -import ( - "encoding/json" - "fmt" - "log" - "net" - "os" - "sync" - "time" -) - -const ( - connectionTimeout = 10 * time.Second -) - -var respOKMarshalled []byte - -func init() { - var err error - respOKMarshalled, err = json.Marshal(response{}) - if err != nil { - panic(fmt.Sprintf("unable to JSON marshal OK response: %s", err.Error())) - } -} - -type response struct { - Code int `json:"code,omitempty"` - Message string `json:"message,omitempty"` -} - -// Transport is the interface the network service must provide. 
-type Transport interface { - net.Listener - - // Dial is used to create a new outgoing connection - Dial(address string, timeout time.Duration) (net.Conn, error) -} - -// Store represents a store of information, managed via consensus. -type Store interface { - // Leader returns the address of the leader of the consensus system. - LeaderAddr() string - - // UpdateAPIPeers updates the API peers on the store. - UpdateAPIPeers(peers map[string]string) error -} - -// Service allows access to the cluster and associated meta data, -// via consensus. -type Service struct { - tn Transport - store Store - addr net.Addr - - wg sync.WaitGroup - - logger *log.Logger -} - -// NewService returns a new instance of the cluster service -func NewService(tn Transport, store Store) *Service { - return &Service{ - tn: tn, - store: store, - addr: tn.Addr(), - logger: log.New(os.Stderr, "[cluster] ", log.LstdFlags), - } -} - -// Open opens the Service. -func (s *Service) Open() error { - s.wg.Add(1) - go s.serve() - s.logger.Println("service listening on", s.tn.Addr()) - return nil -} - -// Close closes the service. -func (s *Service) Close() error { - s.tn.Close() - s.wg.Wait() - return nil -} - -// Addr returns the address the service is listening on. -func (s *Service) Addr() string { - return s.addr.String() -} - -// SetPeer will set the mapping between raftAddr and apiAddr for the entire cluster. -func (s *Service) SetPeer(raftAddr, apiAddr string) error { - peer := map[string]string{ - raftAddr: apiAddr, - } - - // Try the local store. It might be the leader. - err := s.store.UpdateAPIPeers(peer) - if err == nil { - // All done! Aren't we lucky? - return nil - } - - // Try talking to the leader over the network. 
- if leader := s.store.LeaderAddr(); leader == "" { - return fmt.Errorf("no leader available") - } - conn, err := s.tn.Dial(s.store.LeaderAddr(), connectionTimeout) - if err != nil { - return err - } - defer conn.Close() - - b, err := json.Marshal(peer) - if err != nil { - return err - } - - if _, err := conn.Write(b); err != nil { - return err - } - - // Wait for the response and verify the operation went through. - resp := response{} - d := json.NewDecoder(conn) - err = d.Decode(&resp) - if err != nil { - return err - } - - if resp.Code != 0 { - return fmt.Errorf(resp.Message) - } - return nil -} - -func (s *Service) serve() error { - defer s.wg.Done() - - for { - conn, err := s.tn.Accept() - if err != nil { - return err - } - - go s.handleConn(conn) - } -} - -func (s *Service) handleConn(conn net.Conn) { - s.logger.Printf("received connection from %s", conn.RemoteAddr().String()) - - // Only handles peers updates for now. - peers := make(map[string]string) - d := json.NewDecoder(conn) - err := d.Decode(&peers) - if err != nil { - return - } - - // Update the peers. - if err := s.store.UpdateAPIPeers(peers); err != nil { - resp := response{1, err.Error()} - b, err := json.Marshal(resp) - if err != nil { - conn.Close() // Only way left to signal. - } else { - if _, err := conn.Write(b); err != nil { - conn.Close() // Only way left to signal. - } - } - return - } - - // Let the remote node know everything went OK. - if _, err := conn.Write(respOKMarshalled); err != nil { - conn.Close() // Only way left to signal. 
- } - return -} diff --git a/cluster/service_test.go b/cluster/service_test.go deleted file mode 100644 index 2f38c346..00000000 --- a/cluster/service_test.go +++ /dev/null @@ -1,194 +0,0 @@ -package cluster - -import ( - "encoding/json" - "fmt" - "net" - "testing" - "time" -) - -func Test_NewServiceOpenClose(t *testing.T) { - ml := mustNewMockTransport() - ms := &mockStore{} - s := NewService(ml, ms) - if s == nil { - t.Fatalf("failed to create cluster service") - } - - if err := s.Open(); err != nil { - t.Fatalf("failed to open cluster service") - } - if err := s.Close(); err != nil { - t.Fatalf("failed to close cluster service") - } -} - -func Test_SetAPIPeer(t *testing.T) { - raftAddr, apiAddr := "localhost:4002", "localhost:4001" - - s, _, ms := mustNewOpenService() - defer s.Close() - if err := s.SetPeer(raftAddr, apiAddr); err != nil { - t.Fatalf("failed to set peer: %s", err.Error()) - } - - if ms.peers[raftAddr] != apiAddr { - t.Fatalf("peer not set correctly, exp %s, got %s", apiAddr, ms.peers[raftAddr]) - } -} - -func Test_SetAPIPeerNetwork(t *testing.T) { - raftAddr, apiAddr := "localhost:4002", "localhost:4001" - - s, _, ms := mustNewOpenService() - defer s.Close() - - raddr, err := net.ResolveTCPAddr("tcp", s.Addr()) - if err != nil { - t.Fatalf("failed to resolve remote uster ervice address: %s", err.Error()) - } - - conn, err := net.DialTCP("tcp4", nil, raddr) - if err != nil { - t.Fatalf("failed to connect to remote cluster service: %s", err.Error()) - } - if _, err := conn.Write([]byte(fmt.Sprintf(`{"%s": "%s"}`, raftAddr, apiAddr))); err != nil { - t.Fatalf("failed to write to remote cluster service: %s", err.Error()) - } - - resp := response{} - d := json.NewDecoder(conn) - err = d.Decode(&resp) - if err != nil { - t.Fatalf("failed to decode response: %s", err.Error()) - } - - if resp.Code != 0 { - t.Fatalf("response code was non-zero") - } - - if ms.peers[raftAddr] != apiAddr { - t.Fatalf("peer not set correctly, exp %s, got %s", apiAddr, 
ms.peers[raftAddr]) - } -} - -func Test_SetAPIPeerFailUpdate(t *testing.T) { - raftAddr, apiAddr := "localhost:4002", "localhost:4001" - - s, _, ms := mustNewOpenService() - defer s.Close() - ms.failUpdateAPIPeers = true - - // Attempt to set peer without a leader - if err := s.SetPeer(raftAddr, apiAddr); err == nil { - t.Fatalf("no error returned by set peer when no leader") - } - - // Start a network server. - tn, err := net.Listen("tcp", "localhost:0") - if err != nil { - t.Fatalf("failed to open test server: %s", err.Error()) - } - ms.leader = tn.Addr().String() - - c := make(chan map[string]string, 1) - go func() { - conn, err := tn.Accept() - if err != nil { - t.Fatalf("failed to accept connection from cluster: %s", err.Error()) - } - t.Logf("test server received connection from: %s", conn.RemoteAddr()) - - peers := make(map[string]string) - d := json.NewDecoder(conn) - err = d.Decode(&peers) - if err != nil { - t.Fatalf("failed to decode message from cluster: %s", err.Error()) - } - - // Response OK. - // Let the remote node know everything went OK. 
- if _, err := conn.Write(respOKMarshalled); err != nil { - t.Fatalf("failed to respond to cluster: %s", err.Error()) - } - - c <- peers - }() - - if err := s.SetPeer(raftAddr, apiAddr); err != nil { - t.Fatalf("failed to set peer on cluster: %s", err.Error()) - } - - peers := <-c - if peers[raftAddr] != apiAddr { - t.Fatalf("peer not set correctly, exp %s, got %s", apiAddr, ms.peers[raftAddr]) - } -} - -func mustNewOpenService() (*Service, *mockTransport, *mockStore) { - ml := mustNewMockTransport() - ms := newMockStore() - s := NewService(ml, ms) - if err := s.Open(); err != nil { - panic("failed to open new service") - } - return s, ml, ms -} - -type mockTransport struct { - tn net.Listener -} - -func mustNewMockTransport() *mockTransport { - tn, err := net.Listen("tcp", "localhost:0") - if err != nil { - panic("failed to create mock listener") - } - return &mockTransport{ - tn: tn, - } -} - -func (ml *mockTransport) Accept() (c net.Conn, err error) { - return ml.tn.Accept() -} - -func (ml *mockTransport) Addr() net.Addr { - return ml.tn.Addr() -} - -func (ml *mockTransport) Close() (err error) { - return ml.tn.Close() -} - -func (ml *mockTransport) Dial(addr string, t time.Duration) (net.Conn, error) { - return net.DialTimeout("tcp", addr, 5*time.Second) -} - -type mockStore struct { - leader string - peers map[string]string - failUpdateAPIPeers bool -} - -func newMockStore() *mockStore { - return &mockStore{ - peers: make(map[string]string), - } -} - -func (ms *mockStore) LeaderAddr() string { - return ms.leader -} - -func (ms *mockStore) UpdateAPIPeers(peers map[string]string) error { - if ms.failUpdateAPIPeers { - return fmt.Errorf("forced fail") - } - - for k, v := range peers { - ms.peers[k] = v - } - return nil -} diff --git a/cmd/rqlited/main.go b/cmd/rqlited/main.go index a1ad44d4..993cd26e 100644 --- a/cmd/rqlited/main.go +++ b/cmd/rqlited/main.go @@ -8,7 +8,6 @@ import ( "fmt" "io/ioutil" "log" - "net" "os" "os/signal" "path/filepath" @@ -161,47 
+160,28 @@ func main() { // Start requested profiling. startProfile(cpuProfile, memProfile) - // Set up node-to-node TCP communication. - ln, err := net.Listen("tcp", raftAddr) - if err != nil { - log.Fatalf("failed to listen on %s: %s", raftAddr, err.Error()) - } - var adv net.Addr - if raftAdv != "" { - adv, err = net.ResolveTCPAddr("tcp", raftAdv) - if err != nil { - log.Fatalf("failed to resolve advertise address %s: %s", raftAdv, err.Error()) - } - } - - // Start up node-to-node network mux. - var mux *tcp.Mux + // Create internode network layer. + var tn *tcp.Transport if nodeEncrypt { log.Printf("enabling node-to-node encryption with cert: %s, key: %s", nodeX509Cert, nodeX509Key) - mux, err = tcp.NewTLSMux(ln, adv, nodeX509Cert, nodeX509Key, nodeX509CACert) + tn = tcp.NewTLSTransport(nodeX509Cert, nodeX509Key, noVerify) } else { - mux, err = tcp.NewMux(ln, adv) + tn = tcp.NewTransport() } - if err != nil { - log.Fatalf("failed to create node-to-node mux: %s", err.Error()) + if err := tn.Open(raftAddr); err != nil { + log.Fatalf("failed to open internode network layer: %s", err.Error()) } - mux.InsecureSkipVerify = noNodeVerify - go mux.Serve() - - // Get transport for Raft communications. - raftTn := mux.Listen(muxRaftHeader) // Create and open the store. - dataPath, err = filepath.Abs(dataPath) + dataPath, err := filepath.Abs(dataPath) if err != nil { log.Fatalf("failed to determine absolute data path: %s", err.Error()) } dbConf := store.NewDBConfig(dsn, !onDisk) - str := store.New(&store.StoreConfig{ + str := store.New(tn, &store.StoreConfig{ DBConf: dbConf, Dir: dataPath, - Tn: raftTn, ID: idOrRaftAddr(), }) @@ -220,10 +200,6 @@ func main() { if err != nil { log.Fatalf("failed to parse Raft apply timeout %s: %s", raftApplyTimeout, err.Error()) } - str.OpenTimeout, err = time.ParseDuration(raftOpenTimeout) - if err != nil { - log.Fatalf("failed to parse Raft open timeout %s: %s", raftOpenTimeout, err.Error()) - } // Determine join addresses, if necessary. 
ja, err := store.JoinAllowed(dataPath) @@ -241,16 +217,18 @@ func main() { log.Println("node is already member of cluster, skip determining join addresses") } - // Now, open it. + // Now, open store. if err := str.Open(len(joins) == 0); err != nil { log.Fatalf("failed to open store: %s", err.Error()) } - // Create and configure cluster service. - tn := mux.Listen(muxMetaHeader) - cs := cluster.NewService(tn, str) - if err := cs.Open(); err != nil { - log.Fatalf("failed to open cluster service: %s", err.Error()) + // Prepare metadata for join command. + apiAdv := httpAddr + if httpAdv != "" { + apiAdv = httpAdv + } + meta := map[string]string{ + "api_addr": apiAdv, } // Execute any requested join operation. @@ -274,7 +252,7 @@ func main() { } } - if j, err := cluster.Join(joins, str.ID(), advAddr, &tlsConfig); err != nil { + if j, err := cluster.Join(joins, str.ID(), advAddr, meta, &tlsConfig); err != nil { log.Fatalf("failed to join cluster at %s: %s", joins, err.Error()) } else { log.Println("successfully joined cluster at", j) @@ -284,53 +262,26 @@ func main() { log.Println("no join addresses set") } - // Publish to the cluster the mapping between this Raft address and API address. - // The Raft layer broadcasts the resolved address, so use that as the key. But - // only set different HTTP advertise address if set. - apiAdv := httpAddr - if httpAdv != "" { - apiAdv = httpAdv - } - - if err := publishAPIAddr(cs, raftTn.Addr().String(), apiAdv, publishPeerTimeout); err != nil { - log.Fatalf("failed to set peer for %s to %s: %s", raftAddr, httpAddr, err.Error()) - } - log.Printf("set peer for %s to %s", raftTn.Addr().String(), apiAdv) - - // Get the credential store. - credStr, err := credentialStore() + // Wait until the store is in full consensus. 
+ openTimeout, err := time.ParseDuration(raftOpenTimeout) if err != nil { - log.Fatalf("failed to get credential store: %s", err.Error()) + log.Fatalf("failed to parse Raft open timeout %s: %s", raftOpenTimeout, err.Error()) } + str.WaitForLeader(openTimeout) + str.WaitForApplied(openTimeout) - // Create HTTP server and load authentication information if required. - var s *httpd.Service - if credStr != nil { - s = httpd.New(httpAddr, str, credStr) - } else { - s = httpd.New(httpAddr, str, nil) + // This may be a standalone server. In that case set its own metadata. + if err := str.SetMetadata(meta); err != nil && err != store.ErrNotLeader { + // Non-leader errors are OK, since metadata will then be set through + // consensus as a result of a join. All other errors indicate a problem. + log.Fatalf("failed to set store metadata: %s", err.Error()) } - s.CACertFile = x509CACert - s.CertFile = x509Cert - s.KeyFile = x509Key - s.Expvar = expvar - s.Pprof = pprofEnabled - s.BuildInfo = map[string]interface{}{ - "commit": commit, - "branch": branch, - "version": version, - "build_time": buildtime, - } - if err := s.Start(); err != nil { + // Start the HTTP API server. + if err := startHTTPService(str); err != nil { log.Fatalf("failed to start HTTP server: %s", err.Error()) } - // Register cross-component statuses. - if err := s.RegisterStatus("mux", mux); err != nil { - log.Fatalf("failed to register mux status: %s", err.Error()) - } - // Block until signalled. 
terminate := make(chan os.Signal, 1) signal.Notify(terminate, os.Interrupt) @@ -373,25 +324,32 @@ func determineJoinAddresses() ([]string, error) { return addrs, nil } -func publishAPIAddr(c *cluster.Service, raftAddr, apiAddr string, t time.Duration) error { - tck := time.NewTicker(publishPeerDelay) - defer tck.Stop() - tmr := time.NewTimer(t) - defer tmr.Stop() - - for { - select { - case <-tck.C: - if err := c.SetPeer(raftAddr, apiAddr); err != nil { - log.Printf("failed to set peer for %s to %s: %s (retrying)", - raftAddr, apiAddr, err.Error()) - continue - } - return nil - case <-tmr.C: - return fmt.Errorf("set peer timeout expired") - } +func startHTTPService(str *store.Store) error { + // Get the credential store. + credStr, err := credentialStore() + if err != nil { + return err + } + + // Create HTTP server and load authentication information if required. + var s *httpd.Service + if credStr != nil { + s = httpd.New(httpAddr, str, credStr) + } else { + s = httpd.New(httpAddr, str, nil) + } + + s.CertFile = x509Cert + s.KeyFile = x509Key + s.Expvar = expvar + s.Pprof = pprofEnabled + s.BuildInfo = map[string]interface{}{ + "commit": commit, + "branch": branch, + "version": version, + "build_time": buildtime, } + return s.Start() } func credentialStore() (*auth.CredentialsStore, error) { diff --git a/http/service.go b/http/service.go index 362ca07b..74a4a45a 100644 --- a/http/service.go +++ b/http/service.go @@ -44,16 +44,16 @@ type Store interface { Query(qr *store.QueryRequest) ([]*sql.Rows, error) // Join joins the node with the given ID, reachable at addr, to this node. - Join(id, addr string) error + Join(id, addr string, metadata map[string]string) error - // Remove removes the node, specified by addr, from the cluster. - Remove(addr string) error + // Remove removes the node, specified by id, from the cluster. + Remove(id string) error - // LeaderAddr returns the Raft address of the leader of the cluster. 
- LeaderAddr() string + // Metadata returns the value for the given node ID, for the given key. + Metadata(id, key string) string - // Peer returns the API peer for the given address - Peer(addr string) string + // Leader returns the Raft address of the leader of the cluster. + LeaderID() (string, error) // Stats returns stats on the Store. Stats() (map[string]interface{}, error) @@ -289,27 +289,35 @@ func (s *Service) handleJoin(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusBadRequest) return } - m := map[string]string{} - if err := json.Unmarshal(b, &m); err != nil { + md := map[string]interface{}{} + if err := json.Unmarshal(b, &md); err != nil { w.WriteHeader(http.StatusBadRequest) return } - remoteID, ok := m["id"] + remoteID, ok := md["id"] if !ok { w.WriteHeader(http.StatusBadRequest) return } + var m map[string]string + if _, ok := md["meta"].(map[string]interface{}); ok { + m = make(map[string]string) + for k, v := range md["meta"].(map[string]interface{}) { + m[k] = v.(string) + } + } - remoteAddr, ok := m["addr"] + remoteAddr, ok := md["addr"] if !ok { + fmt.Println("4444") w.WriteHeader(http.StatusBadRequest) return } - if err := s.store.Join(remoteID, remoteAddr); err != nil { + if err := s.store.Join(remoteID.(string), remoteAddr.(string), m); err != nil { if err == store.ErrNotLeader { - leader := s.store.Peer(s.store.LeaderAddr()) + leader := s.leaderAPIAddr() if leader == "" { http.Error(w, err.Error(), http.StatusServiceUnavailable) return @@ -355,15 +363,15 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) { return } - remoteAddr, ok := m["addr"] + remoteID, ok := m["id"] if !ok { w.WriteHeader(http.StatusBadRequest) return } - if err := s.store.Remove(remoteAddr); err != nil { + if err := s.store.Remove(remoteID); err != nil { if err == store.ErrNotLeader { - leader := s.store.Peer(s.store.LeaderAddr()) + leader := s.leaderAPIAddr() if leader == "" { http.Error(w, err.Error(), 
http.StatusServiceUnavailable) return @@ -448,7 +456,7 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) { results, err := s.store.ExecuteOrAbort(&store.ExecuteRequest{queries, timings, false}) if err != nil { if err == store.ErrNotLeader { - leader := s.store.Peer(s.store.LeaderAddr()) + leader := s.leaderAPIAddr() if leader == "" { http.Error(w, err.Error(), http.StatusServiceUnavailable) return @@ -498,7 +506,7 @@ func (s *Service) handleStatus(w http.ResponseWriter, r *http.Request) { httpStatus := map[string]interface{}{ "addr": s.Addr().String(), "auth": prettyEnabled(s.credentialStore != nil), - "redirect": s.store.Peer(s.store.LeaderAddr()), + "redirect": s.leaderAPIAddr(), } nodeStatus := map[string]interface{}{ @@ -595,7 +603,7 @@ func (s *Service) handleExecute(w http.ResponseWriter, r *http.Request) { results, err := s.store.Execute(&store.ExecuteRequest{queries, timings, isTx}) if err != nil { if err == store.ErrNotLeader { - leader := s.store.Peer(s.store.LeaderAddr()) + leader := s.leaderAPIAddr() if leader == "" { http.Error(w, err.Error(), http.StatusServiceUnavailable) return @@ -657,7 +665,7 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) { results, err := s.store.Query(&store.QueryRequest{queries, timings, isTx, lvl}) if err != nil { if err == store.ErrNotLeader { - leader := s.store.Peer(s.store.LeaderAddr()) + leader := s.leaderAPIAddr() if leader == "" { http.Error(w, err.Error(), http.StatusServiceUnavailable) return @@ -746,6 +754,14 @@ func (s *Service) CheckRequestPerm(r *http.Request, perm string) bool { return s.credentialStore.HasPerm(username, PermAll) || s.credentialStore.HasPerm(username, perm) } +func (s *Service) leaderAPIAddr() string { + id, err := s.store.LeaderID() + if err != nil { + return "" + } + return s.store.Metadata(id, "api_addr") +} + // addBuildVersion adds the build version to the HTTP response. 
func (s *Service) addBuildVersion(w http.ResponseWriter) { // Add version header to every response, if available. diff --git a/http/service_test.go b/http/service_test.go index 4d0e51f6..eda78be3 100644 --- a/http/service_test.go +++ b/http/service_test.go @@ -491,19 +491,19 @@ func (m *MockStore) Query(qr *store.QueryRequest) ([]*sql.Rows, error) { return nil, nil } -func (m *MockStore) Join(id, addr string) error { +func (m *MockStore) Join(id, addr string, metadata map[string]string) error { return nil } -func (m *MockStore) Remove(addr string) error { +func (m *MockStore) Remove(id string) error { return nil } -func (m *MockStore) LeaderAddr() string { - return "" +func (m *MockStore) LeaderID() (string, error) { + return "", nil } -func (m *MockStore) Peer(addr string) string { +func (m *MockStore) Metadata(id, key string) string { return "" } diff --git a/store/command.go b/store/command.go index 7414b7be..5538c319 100644 --- a/store/command.go +++ b/store/command.go @@ -8,9 +8,10 @@ import ( type commandType int const ( - execute commandType = iota // Commands which modify the database. - query // Commands which query the database. - peer // Commands that modify peers map. + execute commandType = iota // Commands which modify the database. + query // Commands which query the database. + metadataSet // Commands which sets Store metadata + metadataDelete // Commands which deletes Store metadata ) type command struct { @@ -27,7 +28,14 @@ func newCommand(t commandType, d interface{}) (*command, error) { Typ: t, Sub: b, }, nil +} +func newMetadataSetCommand(id string, md map[string]string) (*command, error) { + m := metadataSetSub{ + RaftID: id, + Data: md, + } + return newCommand(metadataSet, m) } // databaseSub is a command sub which involves interaction with the database. @@ -37,5 +45,7 @@ type databaseSub struct { Timings bool `json:"timings,omitempty"` } -// peersSub is a command which sets the API address for a Raft address. 
-type peersSub map[string]string +type metadataSetSub struct { + RaftID string `json:"raft_id,omitempty"` + Data map[string]string `json:"data,omitempty"` +} diff --git a/store/db_config.go b/store/db_config.go new file mode 100644 index 00000000..5750843d --- /dev/null +++ b/store/db_config.go @@ -0,0 +1,12 @@ +package store + +// DBConfig represents the configuration of the underlying SQLite database. +type DBConfig struct { + DSN string // Any custom DSN + Memory bool // Whether the database is in-memory only. +} + +// NewDBConfig returns a new DB config instance. +func NewDBConfig(dsn string, memory bool) *DBConfig { + return &DBConfig{DSN: dsn, Memory: memory} +} diff --git a/store/server.go b/store/server.go new file mode 100644 index 00000000..42edb16c --- /dev/null +++ b/store/server.go @@ -0,0 +1,14 @@ +package store + +// Server represents another node in the cluster. +type Server struct { + ID string `json:"id,omitempty"` + Addr string `json:"addr,omitempty"` +} + +// Servers is a set of Servers. +type Servers []*Server + +func (s Servers) Less(i, j int) bool { return s[i].ID < s[j].ID } +func (s Servers) Len() int { return len(s) } +func (s Servers) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/store/store.go b/store/store.go index 22dd7864..d2c6c5f6 100644 --- a/store/store.go +++ b/store/store.go @@ -13,7 +13,6 @@ import ( "io" "io/ioutil" "log" - "net" "os" "path/filepath" "sort" @@ -46,6 +45,9 @@ const ( sqliteFile = "db.sqlite" leaderWaitDelay = 100 * time.Millisecond appliedWaitDelay = 100 * time.Millisecond + connectionPoolCount = 5 + connectionTimeout = 10 * time.Second + raftLogCacheSize = 512 ) const ( @@ -114,60 +116,6 @@ const ( Unknown ) -// clusterMeta represents cluster meta which must be kept in consensus. -type clusterMeta struct { - APIPeers map[string]string // Map from Raft address to API address -} - -// NewClusterMeta returns an initialized cluster meta store. 
-func newClusterMeta() *clusterMeta { - return &clusterMeta{ - APIPeers: make(map[string]string), - } -} - -func (c *clusterMeta) AddrForPeer(addr string) string { - if api, ok := c.APIPeers[addr]; ok && api != "" { - return api - } - - // Go through each entry, and see if any key resolves to addr. - for k, v := range c.APIPeers { - resv, err := net.ResolveTCPAddr("tcp", k) - if err != nil { - continue - } - if resv.String() == addr { - return v - } - } - - return "" -} - -// DBConfig represents the configuration of the underlying SQLite database. -type DBConfig struct { - DSN string // Any custom DSN - Memory bool // Whether the database is in-memory only. -} - -// NewDBConfig returns a new DB config instance. -func NewDBConfig(dsn string, memory bool) *DBConfig { - return &DBConfig{DSN: dsn, Memory: memory} -} - -// Server represents another node in the cluster. -type Server struct { - ID string `json:"id,omitempty"` - Addr string `json:"addr,omitempty"` -} - -type Servers []*Server - -func (s Servers) Less(i, j int) bool { return s[i].ID < s[j].ID } -func (s Servers) Len() int { return len(s) } -func (s Servers) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - // Store is a SQLite database, where all changes are made via Raft consensus. type Store struct { raftDir string @@ -175,14 +123,19 @@ type Store struct { mu sync.RWMutex // Sync access between queries and snapshots. raft *raft.Raft // The consensus mechanism. - raftTn *raftTransport + ln Listener + raftTn *raft.NetworkTransport raftID string // Node ID. dbConf *DBConfig // SQLite database config. dbPath string // Path to underlying SQLite file, if not in-memory. db *sql.DB // The underlying SQLite store. + raftLog raft.LogStore // Persistent log store. + raftStable raft.StableStore // Persistent k-v store. + boltStore *raftboltdb.BoltStore // Physical store. 
+ metaMu sync.RWMutex - meta *clusterMeta + meta map[string]map[string]string logger *log.Logger @@ -192,7 +145,6 @@ type Store struct { HeartbeatTimeout time.Duration ElectionTimeout time.Duration ApplyTimeout time.Duration - OpenTimeout time.Duration } // StoreConfig represents the configuration of the underlying Store. @@ -205,22 +157,21 @@ type StoreConfig struct { } // New returns a new Store. -func New(c *StoreConfig) *Store { +func New(ln Listener, c *StoreConfig) *Store { logger := c.Logger if logger == nil { logger = log.New(os.Stderr, "[store] ", log.LstdFlags) } return &Store{ + ln: ln, raftDir: c.Dir, - raftTn: &raftTransport{c.Tn}, raftID: c.ID, dbConf: c.DBConf, dbPath: filepath.Join(c.Dir, sqliteFile), - meta: newClusterMeta(), + meta: make(map[string]map[string]string), logger: logger, ApplyTimeout: applyTimeout, - OpenTimeout: openTimeout, } } @@ -234,6 +185,7 @@ func (s *Store) Open(enableSingle bool) error { return err } + // Open underlying database. db, err := s.open() if err != nil { return err @@ -243,14 +195,11 @@ func (s *Store) Open(enableSingle bool) error { // Is this a brand new node? newNode := !pathExists(filepath.Join(s.raftDir, "raft.db")) - // Setup Raft communication. - transport := raft.NewNetworkTransport(s.raftTn, 3, 10*time.Second, os.Stderr) + // Create Raft-compatible network layer. + s.raftTn = raft.NewNetworkTransport(NewTransport(s.ln), connectionPoolCount, connectionTimeout, nil) - // Get the Raft configuration for this store. config := s.raftConfig() - config.LocalID = raft.ServerID(s.raftID) - // XXXconfig.Logger = log.New(os.Stderr, "[raft] ", log.LstdFlags) // Create the snapshot store. This allows Raft to truncate the log. snapshots, err := raft.NewFileSnapshotStore(s.raftDir, retainSnapshotCount, os.Stderr) @@ -259,13 +208,18 @@ func (s *Store) Open(enableSingle bool) error { } // Create the log store and stable store. 
- logStore, err := raftboltdb.NewBoltStore(filepath.Join(s.raftDir, "raft.db")) + s.boltStore, err = raftboltdb.NewBoltStore(filepath.Join(s.raftDir, "raft.db")) if err != nil { return fmt.Errorf("new bolt store: %s", err) } + s.raftStable = s.boltStore + s.raftLog, err = raft.NewLogCache(raftLogCacheSize, s.boltStore) + if err != nil { + return fmt.Errorf("new cached store: %s", err) + } // Instantiate the Raft system. - ra, err := raft.NewRaft(config, s, logStore, logStore, snapshots, transport) + ra, err := raft.NewRaft(config, s, s.raftLog, s.raftStable, snapshots, s.raftTn) if err != nil { return fmt.Errorf("new raft: %s", err) } @@ -276,7 +230,7 @@ func (s *Store) Open(enableSingle bool) error { Servers: []raft.Server{ raft.Server{ ID: config.LocalID, - Address: transport.LocalAddr(), + Address: s.raftTn.LocalAddr(), }, }, } @@ -287,16 +241,6 @@ func (s *Store) Open(enableSingle bool) error { s.raft = ra - if s.OpenTimeout != 0 { - // Wait until the initial logs are applied. - s.logger.Printf("waiting for up to %s for application of initial logs", s.OpenTimeout) - if err := s.WaitForAppliedIndex(s.raft.LastIndex(), s.OpenTimeout); err != nil { - return ErrOpenTimeout - } - } else { - s.logger.Println("not waiting for application of initial logs") - } - return nil } @@ -314,6 +258,19 @@ func (s *Store) Close(wait bool) error { return nil } +// WaitForApplied waits for all Raft log entries to to be applied to the +// underlying database. 
+func (s *Store) WaitForApplied(timeout time.Duration) error { + if timeout == 0 { + return nil + } + s.logger.Printf("waiting for up to %s for application of initial logs", timeout) + if err := s.WaitForAppliedIndex(s.raft.LastIndex(), timeout); err != nil { + return ErrOpenTimeout + } + return nil +} + // IsLeader is used to determine if the current node is cluster leader func (s *Store) IsLeader() bool { return s.raft.State() == raft.Leader @@ -342,8 +299,8 @@ func (s *Store) Path() string { } // Addr returns the address of the store. -func (s *Store) Addr() net.Addr { - return s.raftTn.Addr() +func (s *Store) Addr() string { + return string(s.raftTn.LocalAddr()) } // ID returns the Raft ID of the store. @@ -375,24 +332,6 @@ func (s *Store) LeaderID() (string, error) { return "", nil } -// Peer returns the API address for the given addr. If there is no peer -// for the address, it returns the empty string. -func (s *Store) Peer(addr string) string { - return s.meta.AddrForPeer(addr) -} - -// APIPeers return the map of Raft addresses to API addresses. -func (s *Store) APIPeers() (map[string]string, error) { - s.metaMu.RLock() - defer s.metaMu.RUnlock() - - peers := make(map[string]string, len(s.meta.APIPeers)) - for k, v := range s.meta.APIPeers { - peers[k] = v - } - return peers, nil -} - // Nodes returns the slice of nodes in the cluster, sorted by ID ascending. 
func (s *Store) Nodes() ([]*Server, error) { f := s.raft.GetConfiguration() @@ -480,18 +419,25 @@ func (s *Store) Stats() (map[string]interface{}, error) { if err != nil { return nil, err } + leaderID, err := s.LeaderID() + if err != nil { + return nil, err + } + status := map[string]interface{}{ - "node_id": s.raftID, - "raft": s.raft.Stats(), - "addr": s.Addr().String(), - "leader": s.LeaderAddr(), + "node_id": s.raftID, + "raft": s.raft.Stats(), + "addr": s.Addr(), + "leader": map[string]string{ + "node_id": leaderID, + "addr": s.LeaderAddr(), + }, "apply_timeout": s.ApplyTimeout.String(), - "open_timeout": s.OpenTimeout.String(), "heartbeat_timeout": s.HeartbeatTimeout.String(), "election_timeout": s.ElectionTimeout.String(), "snapshot_threshold": s.SnapshotThreshold, - "meta": s.meta, - "peers": nodes, + "metadata": s.meta, + "nodes": nodes, "dir": s.raftDir, "sqlite3": dbStatus, "db_conf": s.dbConf, @@ -636,29 +582,38 @@ func (s *Store) Query(qr *QueryRequest) ([]*sql.Rows, error) { return r, err } -// UpdateAPIPeers updates the cluster-wide peer information. -func (s *Store) UpdateAPIPeers(peers map[string]string) error { - c, err := newCommand(peer, peers) - if err != nil { - return err - } - b, err := json.Marshal(c) - if err != nil { - return err - } - - f := s.raft.Apply(b, s.ApplyTimeout) - return f.Error() -} - // Join joins a node, identified by id and located at addr, to this store. // The node must be ready to respond to Raft communications at that address. 
-func (s *Store) Join(id, addr string) error { +func (s *Store) Join(id, addr string, metadata map[string]string) error { s.logger.Printf("received request to join node at %s", addr) if s.raft.State() != raft.Leader { return ErrNotLeader } + configFuture := s.raft.GetConfiguration() + if err := configFuture.Error(); err != nil { + s.logger.Printf("failed to get raft configuration: %v", err) + return err + } + + for _, srv := range configFuture.Configuration().Servers { + // If a node already exists with either the joining node's ID or address, + // that node may need to be removed from the config first. + if srv.ID == raft.ServerID(id) || srv.Address == raft.ServerAddress(addr) { + // However if *both* the ID and the address are the same, then no + // join is actually needed. + if srv.Address == raft.ServerAddress(addr) && srv.ID == raft.ServerID(id) { + s.logger.Printf("node %s at %s already member of cluster, ignoring join request", id, addr) + return nil + } + + if err := s.remove(id); err != nil { + s.logger.Printf("failed to remove node: %v", err) + return err + } + } + } + f := s.raft.AddVoter(raft.ServerID(id), raft.ServerAddress(addr), 0, 0) if e := f.(raft.Future); e.Error() != nil { if e.Error() == raft.ErrNotLeader { @@ -666,6 +621,11 @@ } return e.Error() } + + if err := s.setMetadata(id, metadata); err != nil { + return err + } + s.logger.Printf("node at %s joined successfully", addr) return nil } @@ -673,18 +633,74 @@ // Remove removes a node from the store, specified by ID. 
func (s *Store) Remove(id string) error { s.logger.Printf("received request to remove node %s", id) - if s.raft.State() != raft.Leader { - return ErrNotLeader + if err := s.remove(id); err != nil { + s.logger.Printf("failed to remove node %s: %s", id, err.Error()) + return err } - f := s.raft.RemoveServer(raft.ServerID(id), 0, 0) - if f.Error() != nil { - if f.Error() == raft.ErrNotLeader { + s.logger.Printf("node %s removed successfully", id) + return nil +} + +// Metadata returns the value for a given key, for a given node ID. +func (s *Store) Metadata(id, key string) string { + s.metaMu.RLock() + defer s.metaMu.RUnlock() + + if _, ok := s.meta[id]; !ok { + return "" + } + v, ok := s.meta[id][key] + if ok { + return v + } + return "" +} + +// SetMetadata adds the metadata md to any existing metadata for +// this node. +func (s *Store) SetMetadata(md map[string]string) error { + return s.setMetadata(s.raftID, md) +} + +// setMetadata adds the metadata md to any existing metadata for +// the given node ID. +func (s *Store) setMetadata(id string, md map[string]string) error { + // Check local data first. + if func() bool { + s.metaMu.RLock() + defer s.metaMu.RUnlock() + if _, ok := s.meta[id]; ok { + for k, v := range md { + if s.meta[id][k] != v { + return false + } + } + return true + } + return false + }() { + // Local data is same as data being pushed in, + // nothing to do. + return nil + } + + c, err := newMetadataSetCommand(id, md) + if err != nil { + return err + } + b, err := json.Marshal(c) + if err != nil { + return err + } + f := s.raft.Apply(b, s.ApplyTimeout) + if e := f.(raft.Future); e.Error() != nil { + if e.Error() == raft.ErrNotLeader { return ErrNotLeader } - return f.Error() + return e.Error() } - s.logger.Printf("node %s removed successfully", id) + return nil } @@ -712,6 +728,36 @@ func (s *Store) open() (*sql.DB, error) { return db, nil } +// remove removes the node, with the given ID, from the cluster. 
+func (s *Store) remove(id string) error { + if s.raft.State() != raft.Leader { + return ErrNotLeader + } + + f := s.raft.RemoveServer(raft.ServerID(id), 0, 0) + if f.Error() != nil { + if f.Error() == raft.ErrNotLeader { + return ErrNotLeader + } + return f.Error() + } + + c, err := newCommand(metadataDelete, id) + b, err := json.Marshal(c) + if err != nil { + return err + } + f = s.raft.Apply(b, s.ApplyTimeout) + if e := f.(raft.Future); e.Error() != nil { + if e.Error() == raft.ErrNotLeader { + return ErrNotLeader + } + return e.Error() + } + + return nil +} + // raftConfig returns a new Raft config for the store. func (s *Store) raftConfig() *raft.Config { config := raft.DefaultConfig() @@ -764,17 +810,31 @@ func (s *Store) Apply(l *raft.Log) interface{} { } r, err := s.db.Query(d.Queries, d.Tx, d.Timings) return &fsmQueryResponse{rows: r, error: err} - case peer: - var d peersSub + case metadataSet: + var d metadataSetSub if err := json.Unmarshal(c.Sub, &d); err != nil { return &fsmGenericResponse{error: err} } func() { s.metaMu.Lock() defer s.metaMu.Unlock() - for k, v := range d { - s.meta.APIPeers[k] = v + if _, ok := s.meta[d.RaftID]; !ok { + s.meta[d.RaftID] = make(map[string]string) } + for k, v := range d.Data { + s.meta[d.RaftID][k] = v + } + }() + return &fsmGenericResponse{} + case metadataDelete: + var d string + if err := json.Unmarshal(c.Sub, &d); err != nil { + return &fsmGenericResponse{error: err} + } + func() { + s.metaMu.Lock() + defer s.metaMu.Unlock() + delete(s.meta, d) }() + return &fsmGenericResponse{} default: diff --git a/store/store_test.go b/store/store_test.go index 2c7b51d1..758f6d3c 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -6,7 +6,6 @@ import ( "net" "os" "path/filepath" - "reflect" "sort" "testing" "time" @@ -14,35 +13,6 @@ import ( "github.com/rqlite/rqlite/testdata/chinook" ) -type mockSnapshotSink struct { - *os.File -} - -func (m *mockSnapshotSink) ID() string { - return "1" -} - -func (m *mockSnapshotSink) 
Cancel() error { - return nil -} - -func Test_ClusterMeta(t *testing.T) { - c := newClusterMeta() - c.APIPeers["localhost:4002"] = "localhost:4001" - - if c.AddrForPeer("localhost:4002") != "localhost:4001" { - t.Fatalf("wrong address returned for localhost:4002") - } - - if c.AddrForPeer("127.0.0.1:4002") != "localhost:4001" { - t.Fatalf("wrong address returned for 127.0.0.1:4002") - } - - if c.AddrForPeer("127.0.0.1:4004") != "" { - t.Fatalf("wrong address returned for 127.0.0.1:4003") - } -} - func Test_OpenStoreSingleNode(t *testing.T) { s := mustNewStore(true) defer os.RemoveAll(s.Path()) @@ -52,7 +22,7 @@ func Test_OpenStoreSingleNode(t *testing.T) { } s.WaitForLeader(10 * time.Second) - if got, exp := s.LeaderAddr(), s.Addr().String(); got != exp { + if got, exp := s.LeaderAddr(), s.Addr(); got != exp { t.Fatalf("wrong leader address returned, got: %s, exp %s", got, exp) } id, err := s.LeaderID() @@ -71,6 +41,7 @@ func Test_OpenStoreCloseSingleNode(t *testing.T) { if err := s.Open(true); err != nil { t.Fatalf("failed to open single-node store: %s", err.Error()) } + s.WaitForLeader(10 * time.Second) if err := s.Close(true); err != nil { t.Fatalf("failed to close single-node store: %s", err.Error()) } @@ -450,14 +421,14 @@ func Test_MultiNodeJoinRemove(t *testing.T) { sort.StringSlice(storeNodes).Sort() // Join the second node to the first. - if err := s0.Join(s1.ID(), s1.Addr().String()); err != nil { - t.Fatalf("failed to join to node at %s: %s", s0.Addr().String(), err.Error()) + if err := s0.Join(s1.ID(), s1.Addr(), nil); err != nil { + t.Fatalf("failed to join to node at %s: %s", s0.Addr(), err.Error()) } s1.WaitForLeader(10 * time.Second) // Check leader state on follower. 
- if got, exp := s1.LeaderAddr(), s0.Addr().String(); got != exp { + if got, exp := s1.LeaderAddr(), s0.Addr(); got != exp { t.Fatalf("wrong leader address returned, got: %s, exp %s", got, exp) } id, err := s1.LeaderID() @@ -514,8 +485,8 @@ func Test_MultiNodeExecuteQuery(t *testing.T) { defer s1.Close(true) // Join the second node to the first. - if err := s0.Join(s1.ID(), s1.Addr().String()); err != nil { - t.Fatalf("failed to join to node at %s: %s", s0.Addr().String(), err.Error()) + if err := s0.Join(s1.ID(), s1.Addr(), nil); err != nil { + t.Fatalf("failed to join to node at %s: %s", s0.Addr(), err.Error()) } queries := []string{ @@ -609,7 +580,7 @@ func Test_StoreLogTruncationMultinode(t *testing.T) { defer s1.Close(true) // Join the second node to the first. - if err := s0.Join(s1.ID(), s1.Addr().String()); err != nil { + if err := s0.Join(s1.ID(), s1.Addr(), nil); err != nil { t.Fatalf("failed to join to node at %s: %s", s0.Addr(), err.Error()) } s1.WaitForLeader(10 * time.Second) @@ -754,38 +725,65 @@ func Test_SingleNodeSnapshotInMem(t *testing.T) { } } -func Test_APIPeers(t *testing.T) { - s := mustNewStore(false) - defer os.RemoveAll(s.Path()) - - if err := s.Open(true); err != nil { +func Test_MetadataMultinode(t *testing.T) { + s0 := mustNewStore(true) + if err := s0.Open(true); err != nil { t.Fatalf("failed to open single-node store: %s", err.Error()) } - defer s.Close(true) - s.WaitForLeader(10 * time.Second) + defer s0.Close(true) + s0.WaitForLeader(10 * time.Second) + s1 := mustNewStore(true) + if err := s1.Open(true); err != nil { + t.Fatalf("failed to open single-node store: %s", err.Error()) + } + defer s1.Close(true) + s1.WaitForLeader(10 * time.Second) - peers := map[string]string{ - "localhost:4002": "localhost:4001", - "localhost:4004": "localhost:4003", + if s0.Metadata(s0.raftID, "foo") != "" { + t.Fatal("nonexistent metadata foo found") } - if err := s.UpdateAPIPeers(peers); err != nil { - t.Fatalf("failed to update API peers: %s", 
err.Error()) + if s0.Metadata("nonsense", "foo") != "" { + t.Fatal("nonexistent metadata foo found for nonexistent node") } - // Retrieve peers and verify them. - apiPeers, err := s.APIPeers() - if err != nil { - t.Fatalf("failed to retrieve API peers: %s", err.Error()) + if err := s0.SetMetadata(map[string]string{"foo": "bar"}); err != nil { + t.Fatalf("failed to set metadata: %s", err.Error()) + } + if s0.Metadata(s0.raftID, "foo") != "bar" { + t.Fatal("key foo not found") } - if !reflect.DeepEqual(peers, apiPeers) { - t.Fatalf("set and retrieved API peers not identical, got %v, exp %v", - apiPeers, peers) + if s0.Metadata("nonsense", "foo") != "" { + t.Fatal("nonexistent metadata foo found for nonexistent node") } - if s.Peer("localhost:4002") != "localhost:4001" || - s.Peer("localhost:4004") != "localhost:4003" || - s.Peer("not exist") != "" { - t.Fatalf("failed to retrieve correct single API peer") + // Join the second node to the first. + meta := map[string]string{"baz": "qux"} + if err := s0.Join(s1.ID(), s1.Addr(), meta); err != nil { + t.Fatalf("failed to join to node at %s: %s", s0.Addr(), err.Error()) + } + s1.WaitForLeader(10 * time.Second) + // Wait until the log entries have been applied to the follower, + // and then query. + if err := s1.WaitForAppliedIndex(5, 5*time.Second); err != nil { + t.Fatalf("error waiting for follower to apply index: %s:", err.Error()) + } + + if s1.Metadata(s0.raftID, "foo") != "bar" { + t.Fatal("key foo not found for s0") + } + if s0.Metadata(s1.raftID, "baz") != "qux" { + t.Fatal("key baz not found for s1") + } + + // Remove a node. 
+ if err := s0.Remove(s1.ID()); err != nil { + t.Fatalf("failed to remove %s from cluster: %s", s1.ID(), err.Error()) + } + if s1.Metadata(s0.raftID, "foo") != "bar" { + t.Fatal("key foo not found for s0") + } + if s0.Metadata(s1.raftID, "baz") != "" { + t.Fatal("key baz found for removed node s1") } } @@ -824,12 +822,10 @@ func mustNewStore(inmem bool) *Store { path := mustTempDir() defer os.RemoveAll(path) - tn := mustMockTransport("localhost:0") cfg := NewDBConfig("", inmem) - s := New(&StoreConfig{ + s := New(mustMockLister("localhost:0"), &StoreConfig{ DBConf: cfg, Dir: path, - Tn: tn, ID: path, // Could be any unique string. }) if s == nil { @@ -838,36 +834,52 @@ func mustNewStore(inmem bool) *Store { return s } -func mustTempDir() string { - var err error - path, err := ioutil.TempDir("", "rqlilte-test-") - if err != nil { - panic("failed to create temp dir") - } - return path +type mockSnapshotSink struct { + *os.File +} + +func (m *mockSnapshotSink) ID() string { + return "1" +} + +func (m *mockSnapshotSink) Cancel() error { + return nil } type mockTransport struct { ln net.Listener } -func mustMockTransport(addr string) Transport { +type mockListener struct { + ln net.Listener +} + +func mustMockLister(addr string) Listener { ln, err := net.Listen("tcp", addr) if err != nil { - panic("failed to create new transport") + panic("failed to create new listner") } - return &mockTransport{ln} + return &mockListener{ln} } -func (m *mockTransport) Dial(addr string, timeout time.Duration) (net.Conn, error) { +func (m *mockListener) Dial(addr string, timeout time.Duration) (net.Conn, error) { return net.DialTimeout("tcp", addr, timeout) } -func (m *mockTransport) Accept() (net.Conn, error) { return m.ln.Accept() } +func (m *mockListener) Accept() (net.Conn, error) { return m.ln.Accept() } + +func (m *mockListener) Close() error { return m.ln.Close() } -func (m *mockTransport) Close() error { return m.ln.Close() } +func (m *mockListener) Addr() net.Addr { return 
m.ln.Addr() } -func (m *mockTransport) Addr() net.Addr { return m.ln.Addr() } +func mustTempDir() string { + var err error + path, err := ioutil.TempDir("", "rqlilte-test-") + if err != nil { + panic("failed to create temp dir") + } + return path +} func asJSON(v interface{}) string { b, err := json.Marshal(v) diff --git a/store/transport.go b/store/transport.go index 1edd1782..a6df008d 100644 --- a/store/transport.go +++ b/store/transport.go @@ -7,32 +7,39 @@ import ( "github.com/hashicorp/raft" ) -// Transport is the interface the network service must provide. -type Transport interface { +type Listener interface { net.Listener - - // Dial is used to create a new outgoing connection Dial(address string, timeout time.Duration) (net.Conn, error) } -// raftTransport takes a Transport and makes it suitable for use by the Raft -// networking system. -type raftTransport struct { - tn Transport +// Transport is the network service provided to Raft, and wraps a Listener. +type Transport struct { + ln Listener +} + +// NewTransport returns an initialized Transport. +func NewTransport(ln Listener) *Transport { + return &Transport{ + ln: ln, + } } -func (r *raftTransport) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) { - return r.tn.Dial(string(address), timeout) +// Dial creates a new network connection. +func (t *Transport) Dial(addr raft.ServerAddress, timeout time.Duration) (net.Conn, error) { + return t.ln.Dial(string(addr), timeout) } -func (r *raftTransport) Accept() (net.Conn, error) { - return r.tn.Accept() +// Accept waits for the next connection. +func (t *Transport) Accept() (net.Conn, error) { + return t.ln.Accept() } -func (r *raftTransport) Addr() net.Addr { - return r.tn.Addr() +// Close closes the transport +func (t *Transport) Close() error { + return t.ln.Close() } -func (r *raftTransport) Close() error { - return r.tn.Close() +// Addr returns the binding address of the transport. 
+func (t *Transport) Addr() net.Addr { + return t.ln.Addr() } diff --git a/store/transport_test.go b/store/transport_test.go new file mode 100644 index 00000000..94d24c10 --- /dev/null +++ b/store/transport_test.go @@ -0,0 +1,11 @@ +package store + +import ( + "testing" +) + +func Test_NewTransport(t *testing.T) { + if NewTransport(nil) == nil { + t.Fatal("failed to create new Transport") + } +} diff --git a/system_test/helpers.go b/system_test/helpers.go index 8c112de5..7d34654a 100644 --- a/system_test/helpers.go +++ b/system_test/helpers.go @@ -5,7 +5,6 @@ import ( "encoding/json" "fmt" "io/ioutil" - "net" "net/http" "net/url" "os" @@ -14,12 +13,14 @@ import ( httpd "github.com/rqlite/rqlite/http" "github.com/rqlite/rqlite/store" + "github.com/rqlite/rqlite/tcp" ) // Node represents a node under test. type Node struct { APIAddr string RaftAddr string + ID string Dir string Store *store.Store Service *httpd.Service @@ -302,18 +303,21 @@ func mustNewNode(enableSingle bool) *Node { } dbConf := store.NewDBConfig("", false) - tn := mustMockTransport("localhost:0") - node.Store = store.New(&store.StoreConfig{ + tn := tcp.NewTransport() + if err := tn.Open("localhost:0"); err != nil { + panic(err.Error()) + } + node.Store = store.New(tn, &store.StoreConfig{ DBConf: dbConf, Dir: node.Dir, - Tn: tn, ID: tn.Addr().String(), }) if err := node.Store.Open(enableSingle); err != nil { node.Deprovision() panic(fmt.Sprintf("failed to open store: %s", err.Error())) } - node.RaftAddr = node.Store.Addr().String() + node.RaftAddr = node.Store.Addr() + node.ID = node.Store.ID() node.Service = httpd.New("localhost:0", node.Store, nil) node.Service.Expvar = true @@ -335,28 +339,6 @@ func mustNewLeaderNode() *Node { return node } -type mockTransport struct { - ln net.Listener -} - -func mustMockTransport(addr string) *mockTransport { - ln, err := net.Listen("tcp", addr) - if err != nil { - panic("failed to create new transport") - } - return &mockTransport{ln} -} - -func (m 
*mockTransport) Dial(addr string, timeout time.Duration) (net.Conn, error) { - return net.DialTimeout("tcp", addr, timeout) -} - -func (m *mockTransport) Accept() (net.Conn, error) { return m.ln.Accept() } - -func (m *mockTransport) Close() error { return m.ln.Close() } - -func (m *mockTransport) Addr() net.Addr { return m.ln.Addr() } - func mustTempDir() string { var err error path, err := ioutil.TempDir("", "rqlilte-system-test-") diff --git a/tcp/doc.go b/tcp/doc.go index 8630ff9b..0a8e9709 100644 --- a/tcp/doc.go +++ b/tcp/doc.go @@ -1,4 +1,4 @@ /* -Package tcp provides various TCP-related utilities. The TCP mux code provided by this package originated with InfluxDB. +Package tcp provides the internode communication network layer. */ package tcp diff --git a/tcp/mux.go b/tcp/mux.go deleted file mode 100644 index 27d29029..00000000 --- a/tcp/mux.go +++ /dev/null @@ -1,301 +0,0 @@ -package tcp - -import ( - "crypto/tls" - "crypto/x509" - "errors" - "fmt" - "io" - "io/ioutil" - "log" - "net" - "os" - "strconv" - "sync" - "time" -) - -const ( - // DefaultTimeout is the default length of time to wait for first byte. - DefaultTimeout = 30 * time.Second -) - -// Layer represents the connection between nodes. -type Layer struct { - ln net.Listener - header byte - addr net.Addr - - remoteEncrypted bool - skipVerify bool - nodeX509CACert string - tlsConfig *tls.Config -} - -// Addr returns the local address for the layer. -func (l *Layer) Addr() net.Addr { - return l.addr -} - -// Dial creates a new network connection. -func (l *Layer) Dial(addr string, timeout time.Duration) (net.Conn, error) { - dialer := &net.Dialer{Timeout: timeout} - - var err error - var conn net.Conn - if l.remoteEncrypted { - conn, err = tls.DialWithDialer(dialer, "tcp", addr, l.tlsConfig) - } else { - conn, err = dialer.Dial("tcp", addr) - } - if err != nil { - return nil, err - } - - // Write a marker byte to indicate message type. 
- _, err = conn.Write([]byte{l.header}) - if err != nil { - conn.Close() - return nil, err - } - return conn, err -} - -// Accept waits for the next connection. -func (l *Layer) Accept() (net.Conn, error) { return l.ln.Accept() } - -// Close closes the layer. -func (l *Layer) Close() error { return l.ln.Close() } - -// Mux multiplexes a network connection. -type Mux struct { - ln net.Listener - addr net.Addr - m map[byte]*listener - - wg sync.WaitGroup - - remoteEncrypted bool - - // The amount of time to wait for the first header byte. - Timeout time.Duration - - // Out-of-band error logger - Logger *log.Logger - - // Path to root X.509 certificate. - nodeX509CACert string - - // Path to X509 certificate - nodeX509Cert string - - // Path to X509 key. - nodeX509Key string - - // Whether to skip verification of other nodes' certificates. - InsecureSkipVerify bool - - tlsConfig *tls.Config -} - -// NewMux returns a new instance of Mux for ln. If adv is nil, -// then the addr of ln is used. -func NewMux(ln net.Listener, adv net.Addr) (*Mux, error) { - addr := adv - if addr == nil { - addr = ln.Addr() - } - - return &Mux{ - ln: ln, - addr: addr, - m: make(map[byte]*listener), - Timeout: DefaultTimeout, - Logger: log.New(os.Stderr, "[mux] ", log.LstdFlags), - }, nil -} - -// NewTLSMux returns a new instance of Mux for ln, and encrypts all traffic -// using TLS. If adv is nil, then the addr of ln is used. -func NewTLSMux(ln net.Listener, adv net.Addr, cert, key, caCert string) (*Mux, error) { - mux, err := NewMux(ln, adv) - if err != nil { - return nil, err - } - - mux.tlsConfig, err = createTLSConfig(cert, key, caCert) - if err != nil { - return nil, err - } - - mux.ln = tls.NewListener(ln, mux.tlsConfig) - mux.remoteEncrypted = true - mux.nodeX509CACert = caCert - mux.nodeX509Cert = cert - mux.nodeX509Key = key - - return mux, nil -} - -// Serve handles connections from ln and multiplexes then across registered listener. 
-func (mux *Mux) Serve() error { - mux.Logger.Printf("mux serving on %s, advertising %s", mux.ln.Addr().String(), mux.addr) - - for { - // Wait for the next connection. - // If it returns a temporary error then simply retry. - // If it returns any other error then exit immediately. - conn, err := mux.ln.Accept() - if err, ok := err.(interface { - Temporary() bool - }); ok && err.Temporary() { - continue - } - if err != nil { - // Wait for all connections to be demuxed - mux.wg.Wait() - for _, ln := range mux.m { - close(ln.c) - } - return err - } - - // Demux in a goroutine to - mux.wg.Add(1) - go mux.handleConn(conn) - } -} - -// Stats returns status of the mux. -func (mux *Mux) Stats() (interface{}, error) { - s := map[string]string{ - "addr": mux.addr.String(), - "timeout": mux.Timeout.String(), - "encrypted": strconv.FormatBool(mux.remoteEncrypted), - } - - if mux.remoteEncrypted { - s["certificate"] = mux.nodeX509Cert - s["key"] = mux.nodeX509Key - s["ca_certificate"] = mux.nodeX509CACert - s["skip_verify"] = strconv.FormatBool(mux.InsecureSkipVerify) - } - - return s, nil -} - -func (mux *Mux) handleConn(conn net.Conn) { - defer mux.wg.Done() - // Set a read deadline so connections with no data don't timeout. - if err := conn.SetReadDeadline(time.Now().Add(mux.Timeout)); err != nil { - conn.Close() - mux.Logger.Printf("tcp.Mux: cannot set read deadline: %s", err) - return - } - - // Read first byte from connection to determine handler. - var typ [1]byte - if _, err := io.ReadFull(conn, typ[:]); err != nil { - conn.Close() - mux.Logger.Printf("tcp.Mux: cannot read header byte: %s", err) - return - } - - // Reset read deadline and let the listener handle that. - if err := conn.SetReadDeadline(time.Time{}); err != nil { - conn.Close() - mux.Logger.Printf("tcp.Mux: cannot reset set read deadline: %s", err) - return - } - - // Retrieve handler based on first byte. 
- handler := mux.m[typ[0]] - if handler == nil { - conn.Close() - mux.Logger.Printf("tcp.Mux: handler not registered: %d", typ[0]) - return - } - - // Send connection to handler. The handler is responsible for closing the connection. - handler.c <- conn -} - -// Listen returns a listener identified by header. -// Any connection accepted by mux is multiplexed based on the initial header byte. -func (mux *Mux) Listen(header byte) *Layer { - // Ensure two listeners are not created for the same header byte. - if _, ok := mux.m[header]; ok { - panic(fmt.Sprintf("listener already registered under header byte: %d", header)) - } - - // Create a new listener and assign it. - ln := &listener{ - c: make(chan net.Conn), - } - mux.m[header] = ln - - layer := &Layer{ - ln: ln, - header: header, - addr: mux.addr, - remoteEncrypted: mux.remoteEncrypted, - skipVerify: mux.InsecureSkipVerify, - nodeX509CACert: mux.nodeX509CACert, - tlsConfig: mux.tlsConfig, - } - - return layer -} - -// listener is a receiver for connections received by Mux. -type listener struct { - c chan net.Conn -} - -// Accept waits for and returns the next connection to the listener. -func (ln *listener) Accept() (c net.Conn, err error) { - conn, ok := <-ln.c - if !ok { - return nil, errors.New("network connection closed") - } - return conn, nil -} - -// Close is a no-op. The mux's listener should be closed instead. -func (ln *listener) Close() error { return nil } - -// Addr always returns nil -func (ln *listener) Addr() net.Addr { return nil } - -// newTLSListener returns a net listener which encrypts the traffic using TLS. -func newTLSListener(ln net.Listener, certFile, keyFile, caCertFile string) (net.Listener, error) { - config, err := createTLSConfig(certFile, keyFile, caCertFile) - if err != nil { - return nil, err - } - - return tls.NewListener(ln, config), nil -} - -// createTLSConfig returns a TLS config from the given cert and key. 
-func createTLSConfig(certFile, keyFile, caCertFile string) (*tls.Config, error) { - var err error - config := &tls.Config{} - config.Certificates = make([]tls.Certificate, 1) - config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return nil, err - } - if caCertFile != "" { - asn1Data, err := ioutil.ReadFile(caCertFile) - if err != nil { - return nil, err - } - config.RootCAs = x509.NewCertPool() - ok := config.RootCAs.AppendCertsFromPEM([]byte(asn1Data)) - if !ok { - return nil, fmt.Errorf("failed to parse root certificate(s) in %s", caCertFile) - } - } - return config, nil -} diff --git a/tcp/mux_test.go b/tcp/mux_test.go deleted file mode 100644 index 8242e733..00000000 --- a/tcp/mux_test.go +++ /dev/null @@ -1,235 +0,0 @@ -package tcp - -import ( - "bytes" - "crypto/tls" - "io" - "io/ioutil" - "log" - "net" - "os" - "strings" - "sync" - "testing" - "testing/quick" - "time" - - "github.com/rqlite/rqlite/testdata/x509" -) - -// Ensure the muxer can split a listener's connections across multiple listeners. -func TestMux(t *testing.T) { - if err := quick.Check(func(n uint8, msg []byte) bool { - if testing.Verbose() { - if len(msg) == 0 { - log.Printf("n=%d, ", n) - } else { - log.Printf("n=%d, hdr=%d, len=%d", n, msg[0], len(msg)) - } - } - - var wg sync.WaitGroup - - // Open single listener on random port. - tcpListener := mustTCPListener("127.0.0.1:0") - defer tcpListener.Close() - - // Setup muxer & listeners. - mux, err := NewMux(tcpListener, nil) - if err != nil { - t.Fatalf("failed to create mux: %s", err.Error()) - } - mux.Timeout = 200 * time.Millisecond - if !testing.Verbose() { - mux.Logger = log.New(ioutil.Discard, "", 0) - } - for i := uint8(0); i < n; i++ { - ln := mux.Listen(byte(i)) - - wg.Add(1) - go func(i uint8, ln net.Listener) { - defer wg.Done() - - // Wait for a connection for this listener. 
- conn, err := ln.Accept() - if conn != nil { - defer conn.Close() - } - - // If there is no message or the header byte - // doesn't match then expect close. - if len(msg) == 0 || msg[0] != byte(i) { - if err == nil || err.Error() != "network connection closed" { - t.Fatalf("unexpected error: %s", err) - } - return - } - - // If the header byte matches this listener - // then expect a connection and read the message. - var buf bytes.Buffer - if _, err := io.CopyN(&buf, conn, int64(len(msg)-1)); err != nil { - t.Fatal(err) - } else if !bytes.Equal(msg[1:], buf.Bytes()) { - t.Fatalf("message mismatch:\n\nexp=%x\n\ngot=%x\n\n", msg[1:], buf.Bytes()) - } - - // Write response. - if _, err := conn.Write([]byte("OK")); err != nil { - t.Fatal(err) - } - }(i, ln) - } - - // Begin serving from the listener. - go mux.Serve() - - // Write message to TCP listener and read OK response. - conn, err := net.Dial("tcp", tcpListener.Addr().String()) - if err != nil { - t.Fatal(err) - } else if _, err = conn.Write(msg); err != nil { - t.Fatal(err) - } - - // Read the response into the buffer. - var resp [2]byte - _, err = io.ReadFull(conn, resp[:]) - - // If the message header is less than n then expect a response. - // Otherwise we should get an EOF because the mux closed. - if len(msg) > 0 && uint8(msg[0]) < n { - if string(resp[:]) != `OK` { - t.Fatalf("unexpected response: %s", resp[:]) - } - } else { - if err == nil || (err != io.EOF && !(strings.Contains(err.Error(), "connection reset by peer") || - strings.Contains(err.Error(), "closed by the remote host"))) { - t.Fatalf("unexpected error: %s", err) - } - } - - // Close connection. - if err := conn.Close(); err != nil { - t.Fatal(err) - } - - // Close original TCP listener and wait for all goroutines to close. - tcpListener.Close() - wg.Wait() - - return true - }, nil); err != nil { - t.Error(err) - } -} - -func TestMux_Advertise(t *testing.T) { - // Setup muxer. 
- tcpListener := mustTCPListener("127.0.0.1:0") - defer tcpListener.Close() - - addr := &mockAddr{ - Nwk: "tcp", - Addr: "rqlite.com:8081", - } - - mux, err := NewMux(tcpListener, addr) - if err != nil { - t.Fatalf("failed to create mux: %s", err.Error()) - } - mux.Timeout = 200 * time.Millisecond - if !testing.Verbose() { - mux.Logger = log.New(ioutil.Discard, "", 0) - } - - layer := mux.Listen(1) - if layer.Addr().String() != addr.Addr { - t.Fatalf("layer advertise address not correct, exp %s, got %s", - layer.Addr().String(), addr.Addr) - } -} - -// Ensure two handlers cannot be registered for the same header byte. -func TestMux_Listen_ErrAlreadyRegistered(t *testing.T) { - defer func() { - if r := recover(); r != `listener already registered under header byte: 5` { - t.Fatalf("unexpected recover: %#v", r) - } - }() - - // Register two listeners with the same header byte. - tcpListener := mustTCPListener("127.0.0.1:0") - mux, err := NewMux(tcpListener, nil) - if err != nil { - t.Fatalf("failed to create mux: %s", err.Error()) - } - mux.Listen(5) - mux.Listen(5) -} - -func TestTLSMux(t *testing.T) { - tcpListener := mustTCPListener("127.0.0.1:0") - defer tcpListener.Close() - - cert := x509.CertFile() - defer os.Remove(cert) - key := x509.KeyFile() - defer os.Remove(key) - - mux, err := NewTLSMux(tcpListener, nil, cert, key, "") - if err != nil { - t.Fatalf("failed to create mux: %s", err.Error()) - } - go mux.Serve() - - // Verify that the listener is secured. 
- conn, err := tls.Dial("tcp", tcpListener.Addr().String(), &tls.Config{
- InsecureSkipVerify: true,
- })
- if err != nil {
- t.Fatal(err)
- }
-
- state := conn.ConnectionState()
- if !state.HandshakeComplete {
- t.Fatal("connection handshake failed to complete")
- }
-}
-
-func TestTLSMux_Fail(t *testing.T) {
- tcpListener := mustTCPListener("127.0.0.1:0")
- defer tcpListener.Close()
-
- cert := x509.CertFile()
- defer os.Remove(cert)
- key := x509.KeyFile()
- defer os.Remove(key)
-
- _, err := NewTLSMux(tcpListener, nil, "xxxx", "yyyy", "")
- if err == nil {
- t.Fatalf("created mux unexpectedly with bad resources")
- }
-}
-
-type mockAddr struct {
- Nwk string
- Addr string
-}
-
-func (m *mockAddr) Network() string {
- return m.Nwk
-}
-
-func (m *mockAddr) String() string {
- return m.Addr
-}
-
-// mustTCPListener returns a listener on bind, or panics.
-func mustTCPListener(bind string) net.Listener {
- l, err := net.Listen("tcp", bind)
- if err != nil {
- panic(err)
- }
- return l
-}
diff --git a/tcp/transport.go b/tcp/transport.go
new file mode 100644
index 00000000..4a3156f1
--- /dev/null
+++ b/tcp/transport.go
@@ -0,0 +1,101 @@
+package tcp
+
+import (
+ "crypto/tls"
+ "fmt"
+ "net"
+ "time"
+)
+
+// Transport is the network layer for internode communications.
+type Transport struct {
+ ln net.Listener
+
+ certFile string // Path to local X.509 cert.
+ certKey string // Path to corresponding X.509 key.
+ remoteEncrypted bool // Remote nodes use encrypted communication.
+ skipVerify bool // Skip verification of remote node certs.
+}
+
+// NewTransport returns an initialized unencrypted Transport.
+func NewTransport() *Transport {
+ return &Transport{}
+}
+
+// NewTLSTransport returns an initialized TLS-encrypted Transport.
+func NewTLSTransport(certFile, keyPath string, skipVerify bool) *Transport { + return &Transport{ + certFile: certFile, + certKey: keyPath, + remoteEncrypted: true, + skipVerify: skipVerify, + } +} + +// Open opens the transport, binding to the supplied address. +func (t *Transport) Open(addr string) error { + ln, err := net.Listen("tcp", addr) + if err != nil { + return err + } + if t.certFile != "" { + config, err := createTLSConfig(t.certFile, t.certKey) + if err != nil { + return err + } + ln = tls.NewListener(ln, config) + } + + t.ln = ln + return nil +} + +// Dial opens a network connection. +func (t *Transport) Dial(addr string, timeout time.Duration) (net.Conn, error) { + dialer := &net.Dialer{Timeout: timeout} + + var err error + var conn net.Conn + if t.remoteEncrypted { + conf := &tls.Config{ + InsecureSkipVerify: t.skipVerify, + } + fmt.Println("doing a TLS dial") + conn, err = tls.DialWithDialer(dialer, "tcp", addr, conf) + } else { + conn, err = dialer.Dial("tcp", addr) + } + + return conn, err +} + +// Accept waits for the next connection. +func (t *Transport) Accept() (net.Conn, error) { + c, err := t.ln.Accept() + if err != nil { + fmt.Println("error accepting: ", err.Error()) + } + return c, err +} + +// Close closes the transport +func (t *Transport) Close() error { + return t.ln.Close() +} + +// Addr returns the binding address of the transport. +func (t *Transport) Addr() net.Addr { + return t.ln.Addr() +} + +// createTLSConfig returns a TLS config from the given cert and key. 
+func createTLSConfig(certFile, keyFile string) (*tls.Config, error) { + var err error + config := &tls.Config{} + config.Certificates = make([]tls.Certificate, 1) + config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, err + } + return config, nil +} diff --git a/tcp/transport_test.go b/tcp/transport_test.go new file mode 100644 index 00000000..782f6655 --- /dev/null +++ b/tcp/transport_test.go @@ -0,0 +1,70 @@ +package tcp + +import ( + "os" + "testing" + "time" + + "github.com/rqlite/rqlite/testdata/x509" +) + +func Test_NewTransport(t *testing.T) { + if NewTransport() == nil { + t.Fatal("failed to create new Transport") + } +} + +func Test_TransportOpenClose(t *testing.T) { + tn := NewTransport() + if err := tn.Open("localhost:0"); err != nil { + t.Fatalf("failed to open transport: %s", err.Error()) + } + if tn.Addr().String() == "localhost:0" { + t.Fatalf("transport address set incorrectly, got: %s", tn.Addr().String()) + } + if err := tn.Close(); err != nil { + t.Fatalf("failed to close transport: %s", err.Error()) + } +} + +func Test_TransportDial(t *testing.T) { + tn1 := NewTransport() + tn1.Open("localhost:0") + go tn1.Accept() + tn2 := NewTransport() + + _, err := tn2.Dial(tn1.Addr().String(), time.Second) + if err != nil { + t.Fatalf("failed to connect to first transport: %s", err.Error()) + } + tn1.Close() +} + +func Test_NewTLSTransport(t *testing.T) { + c := x509.CertFile() + defer os.Remove(c) + k := x509.KeyFile() + defer os.Remove(k) + + if NewTLSTransport(c, k, true) == nil { + t.Fatal("failed to create new TLS Transport") + } +} + +func Test_TLSTransportOpenClose(t *testing.T) { + c := x509.CertFile() + defer os.Remove(c) + k := x509.KeyFile() + defer os.Remove(k) + + tn := NewTLSTransport(c, k, true) + if err := tn.Open("localhost:0"); err != nil { + t.Fatalf("failed to open TLS transport: %s", err.Error()) + } + if tn.Addr().String() == "localhost:0" { + t.Fatalf("TLS transport address set 
incorrectly, got: %s", tn.Addr().String()) + } + if err := tn.Close(); err != nil { + t.Fatalf("failed to close TLS transport: %s", err.Error()) + } +}