diff --git a/cluster/client.go b/cluster/client.go index f5199265..a5d78f90 100644 --- a/cluster/client.go +++ b/cluster/client.go @@ -4,17 +4,32 @@ import ( "encoding/binary" "errors" "fmt" - "io/ioutil" + "io" + "net" + "sync" "time" "github.com/golang/protobuf/proto" "github.com/rqlite/rqlite/command" + "github.com/rqlite/rqlite/tcp/pool" +) + +const ( + initialPoolSize = 4 + maxPoolCapacity = 64 ) // Client allows communicating with a remote node. type Client struct { dialer Dialer timeout time.Duration + + lMu sync.RWMutex + localNodeAddr string + localServ *Service + + mu sync.RWMutex + pools map[string]pool.Pool } // NewClient returns a client instance for talking to a remote node. @@ -22,14 +37,34 @@ func NewClient(dl Dialer) *Client { return &Client{ dialer: dl, timeout: 30 * time.Second, + pools: make(map[string]pool.Pool), } } +// SetLocal informs the client instance of the node address for +// the node using this client. Along with the Service instance +// it allows this client to serve requests for this node locally +// without the network hop. +func (c *Client) SetLocal(nodeAddr string, serv *Service) { + c.lMu.Lock() + defer c.lMu.Unlock() + c.localNodeAddr = nodeAddr + c.localServ = serv +} + // GetNodeAPIAddr retrieves the API Address for the node at nodeAddr func (c *Client) GetNodeAPIAddr(nodeAddr string) (string, error) { - conn, err := c.dialer.Dial(nodeAddr, c.timeout) + c.lMu.RLock() + defer c.lMu.RUnlock() + if c.localNodeAddr == nodeAddr && c.localServ != nil { + // Serve it locally! + stats.Add(numGetNodeAPIRequestLocal, 1) + return c.localServ.GetNodeAPIURL(), nil + } + + conn, err := c.dial(nodeAddr, c.timeout) if err != nil { - return "", fmt.Errorf("dial connection: %s", err) + return "", err } defer conn.Close() @@ -48,20 +83,31 @@ func (c *Client) GetNodeAPIAddr(nodeAddr string) (string, error) { _, err = conn.Write(b) if err != nil { + handleConnError(conn) return "", fmt.Errorf("write protobuf length: %s", err) } _, err = conn.Write(p) if err != nil { + handleConnError(conn) return "", fmt.Errorf("write protobuf: %s", err) } - b, err = ioutil.ReadAll(conn) + // Read length of response. + _, err = io.ReadFull(conn, b) if err != nil { - return "", fmt.Errorf("read protobuf bytes: %s", err) + return "", err + } + sz := binary.LittleEndian.Uint16(b[0:]) + + // Read in the actual response. + p = make([]byte, sz) + _, err = io.ReadFull(conn, p) + if err != nil { + return "", err } a := &Address{} - err = proto.Unmarshal(b, a) + err = proto.Unmarshal(p, a) if err != nil { return "", fmt.Errorf("protobuf unmarshal: %s", err) } @@ -71,9 +117,9 @@ func (c *Client) GetNodeAPIAddr(nodeAddr string) (string, error) { // Execute performs an Execute on a remote node. 
func (c *Client) Execute(er *command.ExecuteRequest, nodeAddr string, timeout time.Duration) ([]*command.ExecuteResult, error) { - conn, err := c.dialer.Dial(nodeAddr, c.timeout) + conn, err := c.dial(nodeAddr, c.timeout) if err != nil { - return nil, fmt.Errorf("dial connection: %s", err) + return nil, err } defer conn.Close() @@ -94,30 +140,45 @@ func (c *Client) Execute(er *command.ExecuteRequest, nodeAddr string, timeout ti binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + handleConnError(conn) return nil, err } _, err = conn.Write(b) if err != nil { + handleConnError(conn) return nil, err } if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + handleConnError(conn) return nil, err } _, err = conn.Write(p) if err != nil { + handleConnError(conn) return nil, err } if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + handleConnError(conn) return nil, err } - b, err = ioutil.ReadAll(conn) + + // Read length of response. + _, err = io.ReadFull(conn, b) + if err != nil { + return nil, err + } + sz := binary.LittleEndian.Uint16(b[0:]) + + // Read in the actual response. + p = make([]byte, sz) + _, err = io.ReadFull(conn, p) if err != nil { return nil, err } a := &CommandExecuteResponse{} - err = proto.Unmarshal(b, a) + err = proto.Unmarshal(p, a) if err != nil { return nil, err } @@ -130,9 +191,9 @@ func (c *Client) Execute(er *command.ExecuteRequest, nodeAddr string, timeout ti // Query performs an Query on a remote node. func (c *Client) Query(qr *command.QueryRequest, nodeAddr string, timeout time.Duration) ([]*command.QueryRows, error) { - conn, err := c.dialer.Dial(nodeAddr, c.timeout) + conn, err := c.dial(nodeAddr, c.timeout) if err != nil { - return nil, fmt.Errorf("dial connection: %s", err) + return nil, err } defer conn.Close() @@ -148,35 +209,50 @@ func (c *Client) Query(qr *command.QueryRequest, nodeAddr string, timeout time.D return nil, fmt.Errorf("command marshal: %s", err) } - // Write length of Protobuf, the Protobuf + // Write length of Protobuf, then the Protobuf b := make([]byte, 4) binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + handleConnError(conn) return nil, err } _, err = conn.Write(b) if err != nil { + handleConnError(conn) return nil, err } if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + handleConnError(conn) return nil, err } _, err = conn.Write(p) if err != nil { + handleConnError(conn) return nil, err } if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + handleConnError(conn) return nil, err } - b, err = ioutil.ReadAll(conn) + + // Read length of response. + _, err = io.ReadFull(conn, b) + if err != nil { + return nil, err + } + sz := binary.LittleEndian.Uint16(b[0:]) + + // Read in the actual response. 
+ p = make([]byte, sz) + _, err = io.ReadFull(conn, p) if err != nil { return nil, err } a := &CommandQueryResponse{} - err = proto.Unmarshal(b, a) + err = proto.Unmarshal(p, a) if err != nil { return nil, err } @@ -189,7 +265,65 @@ func (c *Client) Query(qr *command.QueryRequest, nodeAddr string, timeout time.D // Stats returns stats on the Client instance func (c *Client) Stats() (map[string]interface{}, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + poolStats := make(map[string]interface{}, len(c.pools)) + for k, v := range c.pools { + s, err := v.Stats() + if err != nil { + return nil, err + } + poolStats[k] = s + } return map[string]interface{}{ - "timeout": c.timeout, + "timeout": c.timeout, + "conn_pool_stats": poolStats, }, nil } + +func (c *Client) dial(nodeAddr string, timeout time.Duration) (net.Conn, error) { + var pl pool.Pool + var ok bool + + c.mu.RLock() + pl, ok = c.pools[nodeAddr] + c.mu.RUnlock() + + // Do we need a new pool for the given address? + if !ok { + if err := func() error { + c.mu.Lock() + defer c.mu.Unlock() + pl, ok = c.pools[nodeAddr] + if ok { + return nil // Pool was inserted just after we checked. + } + + // New pool is needed for given address. + factory := func() (net.Conn, error) { return c.dialer.Dial(nodeAddr, c.timeout) } + p, err := pool.NewChannelPool(initialPoolSize, maxPoolCapacity, factory) + if err != nil { + return err + } + c.pools[nodeAddr] = p + pl = p + return nil + }(); err != nil { + return nil, err + } + } + + // Got pool, now get a connection. + conn, err := pl.Get() + if err != nil { + return nil, fmt.Errorf("pool get: %s", err) + } + return conn, nil +} + +func handleConnError(conn net.Conn) { + if pc, ok := conn.(*pool.PoolConn); ok { + pc.MarkUnusable() + } +} diff --git a/cluster/service.go b/cluster/service.go index d71f56db..807f7b68 100644 --- a/cluster/service.go +++ b/cluster/service.go @@ -24,6 +24,9 @@ const ( numGetNodeAPIResponse = "num_get_node_api_resp" numExecuteRequest = "num_execute_req" numQueryRequest = "num_query_req" + + // Client stats for this package. + numGetNodeAPIRequestLocal = "num_get_node_api_req_local" ) const ( @@ -40,6 +43,7 @@ func init() { stats.Add(numGetNodeAPIResponse, 0) stats.Add(numExecuteRequest, 0) stats.Add(numQueryRequest, 0) + stats.Add(numGetNodeAPIRequestLocal, 0) } // Dialer is the interface dialers must implement. 
@@ -166,87 +170,102 @@ func (s *Service) serve() error { func (s *Service) handleConn(conn net.Conn) { defer conn.Close() - b := make([]byte, 4) - _, err := io.ReadFull(conn, b) - if err != nil { - return - } - sz := binary.LittleEndian.Uint16(b[0:]) - - b = make([]byte, sz) - _, err = io.ReadFull(conn, b) - if err != nil { - return - } + for { + b := make([]byte, 4) + _, err := io.ReadFull(conn, b) + if err != nil { + return + } + sz := binary.LittleEndian.Uint16(b[0:]) - c := &Command{} - err = proto.Unmarshal(b, c) - if err != nil { - conn.Close() - } + p := make([]byte, sz) + _, err = io.ReadFull(conn, p) + if err != nil { + return + } - switch c.Type { - case Command_COMMAND_TYPE_GET_NODE_API_URL: - stats.Add(numGetNodeAPIRequest, 1) - b, err = proto.Marshal(&Address{ - Url: s.GetNodeAPIURL(), - }) + c := &Command{} + err = proto.Unmarshal(p, c) if err != nil { conn.Close() } - conn.Write(b) - stats.Add(numGetNodeAPIResponse, 1) - case Command_COMMAND_TYPE_EXECUTE: - stats.Add(numExecuteRequest, 1) + switch c.Type { + case Command_COMMAND_TYPE_GET_NODE_API_URL: + stats.Add(numGetNodeAPIRequest, 1) + p, err = proto.Marshal(&Address{ + Url: s.GetNodeAPIURL(), + }) + if err != nil { + conn.Close() + } - resp := &CommandExecuteResponse{} + // Write length of Protobuf first, then write the actual Protobuf. + b = make([]byte, 4) + binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) + conn.Write(b) + conn.Write(p) + stats.Add(numGetNodeAPIResponse, 1) - er := c.GetExecuteRequest() - if er == nil { - resp.Error = "ExecuteRequest is nil" - } else { - res, err := s.db.Execute(er) - if err != nil { - resp.Error = err.Error() + case Command_COMMAND_TYPE_EXECUTE: + stats.Add(numExecuteRequest, 1) + + resp := &CommandExecuteResponse{} + + er := c.GetExecuteRequest() + if er == nil { + resp.Error = "ExecuteRequest is nil" } else { - resp.Results = make([]*command.ExecuteResult, len(res)) - for i := range res { - resp.Results[i] = res[i] + res, err := s.db.Execute(er) + if err != nil { + resp.Error = err.Error() + } else { + resp.Results = make([]*command.ExecuteResult, len(res)) + for i := range res { + resp.Results[i] = res[i] + } } } - } - b, err = proto.Marshal(resp) - if err != nil { - return - } - conn.Write(b) + p, err := proto.Marshal(resp) + if err != nil { + return + } + // Write length of Protobuf first, then write the actual Protobuf. + b = make([]byte, 4) + binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) + conn.Write(b) + conn.Write(p) - case Command_COMMAND_TYPE_QUERY: - stats.Add(numQueryRequest, 1) + case Command_COMMAND_TYPE_QUERY: + stats.Add(numQueryRequest, 1) - resp := &CommandQueryResponse{} + resp := &CommandQueryResponse{} - qr := c.GetQueryRequest() - if qr == nil { - resp.Error = "QueryRequest is nil" - } else { - res, err := s.db.Query(qr) - if err != nil { - resp.Error = err.Error() + qr := c.GetQueryRequest() + if qr == nil { + resp.Error = "QueryRequest is nil" } else { - resp.Rows = make([]*command.QueryRows, len(res)) - for i := range res { - resp.Rows[i] = res[i] + res, err := s.db.Query(qr) + if err != nil { + resp.Error = err.Error() + } else { + resp.Rows = make([]*command.QueryRows, len(res)) + for i := range res { + resp.Rows[i] = res[i] + } } } - } - b, err = proto.Marshal(resp) - if err != nil { - return + p, err = proto.Marshal(resp) + if err != nil { + return + } + // Write length of Protobuf first, then write the actual Protobuf. 
+ b = make([]byte, 4) + binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) + conn.Write(b) + conn.Write(p) } - conn.Write(b) } } diff --git a/cluster/service_test.go b/cluster/service_test.go index 3ae2e1ee..e4bc5abe 100644 --- a/cluster/service_test.go +++ b/cluster/service_test.go @@ -95,6 +95,41 @@ func Test_NewServiceSetGetNodeAPIAddr(t *testing.T) { } } +func Test_NewServiceSetGetNodeAPIAddrLocal(t *testing.T) { + ml := mustNewMockTransport() + s := New(ml, mustNewMockDatabase()) + if s == nil { + t.Fatalf("failed to create cluster service") + } + + if err := s.Open(); err != nil { + t.Fatalf("failed to open cluster service") + } + + s.SetAPIAddr("foo") + + // Check stats to confirm no local request yet. + if stats.Get(numGetNodeAPIRequestLocal).String() != "0" { + t.Fatalf("failed to confirm request served locally") + } + + // Test by enabling local answering + c := NewClient(ml) + c.SetLocal(s.Addr(), s) + addr, err := c.GetNodeAPIAddr(s.Addr()) + if err != nil { + t.Fatalf("failed to get node API address locally: %s", err) + } + if addr != "http://foo" { + t.Fatalf("failed to get correct node API address locally, exp %s, got %s", "http://foo", addr) + } + + // Check stats to confirm local response. + if stats.Get(numGetNodeAPIRequestLocal).String() != "1" { + t.Fatalf("failed to confirm request served locally") + } +} + func Test_NewServiceSetGetNodeAPIAddrTLS(t *testing.T) { ml := mustNewMockTLSTransport() s := New(ml, mustNewMockDatabase()) diff --git a/cmd/rqlited/main.go b/cmd/rqlited/main.go index ed4bd552..c38ccc72 100644 --- a/cmd/rqlited/main.go +++ b/cmd/rqlited/main.go @@ -307,6 +307,7 @@ func main() { // Start the HTTP API server. clstrDialer := tcp.NewDialer(cluster.MuxClusterHeader, nodeEncrypt, noNodeVerify) clstrClient := cluster.NewClient(clstrDialer) + clstrClient.SetLocal(raftAdv, clstr) if err := startHTTPService(str, clstrClient); err != nil { log.Fatalf("failed to start HTTP server: %s", err.Error()) } diff --git a/http/service.go b/http/service.go index 75696315..2110f694 100644 --- a/http/service.go +++ b/http/service.go @@ -26,6 +26,11 @@ import ( "github.com/rqlite/rqlite/store" ) +var ( + // ErrLeaderNotFound is returned when a node cannot locate a leader + ErrLeaderNotFound = errors.New("leader not found") +) + // Database is the interface any queryable system must implement type Database interface { // Execute executes a slice of queries, each of which is not expected @@ -126,6 +131,7 @@ type Response struct { var stats *expvar.Map const ( + numLeaderNotFound = "leader_not_found" numExecutions = "executions" numQueries = "queries" numRemoteExecutions = "remote_executions" @@ -167,6 +173,7 @@ const ( func init() { stats = expvar.NewMap("http") + stats.Add(numLeaderNotFound, 0) stats.Add(numExecutions, 0) stats.Add(numQueries, 0) stats.Add(numRemoteExecutions, 0) @@ -375,7 +382,8 @@ func (s *Service) handleJoin(w http.ResponseWriter, r *http.Request) { if err == store.ErrNotLeader { leaderAPIAddr := s.LeaderAPIAddr() if leaderAPIAddr == "" { - http.Error(w, err.Error(), http.StatusServiceUnavailable) + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) return } @@ -427,7 +435,8 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) { if err == store.ErrNotLeader { leaderAPIAddr := s.LeaderAPIAddr() if leaderAPIAddr == "" { - http.Error(w, err.Error(), http.StatusServiceUnavailable) + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), 
http.StatusServiceUnavailable) return } @@ -470,7 +479,8 @@ func (s *Service) handleBackup(w http.ResponseWriter, r *http.Request) { if err == store.ErrNotLeader { leaderAPIAddr := s.LeaderAPIAddr() if leaderAPIAddr == "" { - http.Error(w, err.Error(), http.StatusServiceUnavailable) + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) return } @@ -522,7 +532,8 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) { if err == store.ErrNotLeader { leaderAPIAddr := s.LeaderAPIAddr() if leaderAPIAddr == "" { - http.Error(w, err.Error(), http.StatusServiceUnavailable) + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) return } @@ -767,7 +778,8 @@ func (s *Service) handleExecute(w http.ResponseWriter, r *http.Request) { if redirect { leaderAPIAddr := s.LeaderAPIAddr() if leaderAPIAddr == "" { - http.Error(w, err.Error(), http.StatusServiceUnavailable) + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) return } loc := s.FormRedirect(r, leaderAPIAddr) @@ -779,6 +791,11 @@ func (s *Service) handleExecute(w http.ResponseWriter, r *http.Request) { if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } + if addr == "" { + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) + return + } results, resultsErr = s.cluster.Execute(er, addr, timeout) stats.Add(numRemoteExecutions, 1) w.Header().Add(ServedByHTTPHeader, addr) @@ -849,7 +866,8 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) { if redirect { leaderAPIAddr := s.LeaderAPIAddr() if leaderAPIAddr == "" { - http.Error(w, err.Error(), http.StatusServiceUnavailable) + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) return } loc := s.FormRedirect(r, leaderAPIAddr) @@ -861,6 +879,11 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) { if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } + if addr == "" { + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) + return + } results, resultsErr = s.cluster.Query(qr, addr, timeout) stats.Add(numRemoteQueries, 1) w.Header().Add(ServedByHTTPHeader, addr) diff --git a/http/service_test.go b/http/service_test.go index da47c484..6915976d 100644 --- a/http/service_test.go +++ b/http/service_test.go @@ -677,7 +677,7 @@ func Test_Nodes(t *testing.T) { } } -func Test_ForwardingRedirectQueries(t *testing.T) { +func Test_ForwardingRedirectQuery(t *testing.T) { m := &MockStore{ leaderAddr: "foo:1234", } @@ -725,6 +725,16 @@ func Test_ForwardingRedirectQueries(t *testing.T) { if resp.StatusCode != http.StatusMovedPermanently { t.Fatalf("failed to get expected StatusMovedPermanently for query, got %d", resp.StatusCode) } + + // Check leader failure case. 
+ m.leaderAddr = "" + resp, err = client.Get(host + "/db/query?pretty&timings&q=SELECT%20%2A%20FROM%20foo") + if err != nil { + t.Fatalf("failed to make query forwarded request") + } + if resp.StatusCode != http.StatusServiceUnavailable { + t.Fatalf("failed to get expected StatusServiceUnavailable for node with no leader, got %d", resp.StatusCode) + } } func Test_ForwardingRedirectExecute(t *testing.T) { @@ -775,6 +785,16 @@ func Test_ForwardingRedirectExecute(t *testing.T) { if resp.StatusCode != http.StatusMovedPermanently { t.Fatalf("failed to get expected StatusMovedPermanently for execute, got %d", resp.StatusCode) } + + // Check leader failure case. + m.leaderAddr = "" + resp, err = client.Post(host+"/db/execute", "application/json", strings.NewReader(`["Some SQL"]`)) + if err != nil { + t.Fatalf("failed to make execute request") + } + if resp.StatusCode != http.StatusServiceUnavailable { + t.Fatalf("failed to get expected StatusServiceUnavailable for node with no leader, got %d", resp.StatusCode) + } } func Test_TLSServce(t *testing.T) { diff --git a/system_test/full_system_test.py b/system_test/full_system_test.py index 6df359e0..00b95294 100755 --- a/system_test/full_system_test.py +++ b/system_test/full_system_test.py @@ -136,7 +136,7 @@ class Node(object): r.raise_for_status() return r.json() - def is_leader(self, constraint_check=True): + def is_leader(self): ''' is_leader returns whether this node is the cluster leader It also performs a check, to ensure the node nevers gives out @@ -144,28 +144,16 @@ class Node(object): ''' try: - isLeaderRaft = self.status()['store']['raft']['state'] == 'Leader' - isLeaderNodes = self.nodes()[self.node_id]['leader'] is True + return self.status()['store']['raft']['state'] == 'Leader' except requests.exceptions.ConnectionError: return False - if (isLeaderRaft != isLeaderNodes) and constraint_check: - raise AssertionError("conflicting states reported for leadership (raft: %s, nodes: %s)" - % (isLeaderRaft, isLeaderNodes)) - return isLeaderNodes - def is_follower(self): try: - isFollowerRaft = self.status()['store']['raft']['state'] == 'Follower' - isFollowersNodes = self.nodes()[self.node_id]['leader'] is False + return self.status()['store']['raft']['state'] == 'Follower' except requests.exceptions.ConnectionError: return False - if isFollowerRaft != isFollowersNodes: - raise AssertionError("conflicting states reported for followership (raft: %s, nodes: %s)" - % (isFollowerRaft, isFollowersNodes)) - return isFollowersNodes - def wait_for_leader(self, timeout=TIMEOUT): lr = None t = 0 @@ -289,6 +277,7 @@ class Node(object): def redirect_addr(self): r = requests.post(self._execute_url(redirect=True), data=json.dumps(['nonsense']), allow_redirects=False) + r.raise_for_status() if r.status_code == 301: return "%s://%s" % (urlparse(r.headers['Location']).scheme, urlparse(r.headers['Location']).netloc) @@ -333,7 +322,7 @@ def deprovision_node(node): class Cluster(object): def __init__(self, nodes): self.nodes = nodes - def wait_for_leader(self, node_exc=None, timeout=TIMEOUT, constraint_check=True): + def wait_for_leader(self, node_exc=None, timeout=TIMEOUT): t = 0 while True: if t > timeout: @@ -341,7 +330,7 @@ class Cluster(object): for n in self.nodes: if node_exc is not None and n == node_exc: continue - if n.is_leader(constraint_check): + if n.is_leader(): return n time.sleep(1) t+=1 @@ -522,6 +511,8 @@ class TestEndToEnd(unittest.TestCase): for n in fs: self.assertEqual(l.APIProtoAddr(), n.redirect_addr()) + # Kill the leader, wait for a new 
leader, and check that the + # redirect returns the new leader address. l.stop() n = self.cluster.wait_for_leader(node_exc=l) for f in self.cluster.followers(): @@ -682,10 +673,9 @@ class TestEndToEndNonVoterFollowsLeader(unittest.TestCase): j = n.query('SELECT * FROM foo') self.assertEqual(str(j), "{u'results': [{u'values': [[1, u'fiona']], u'types': [u'integer', u'text'], u'columns': [u'id', u'name']}]}") - # Kill leader, and then make more changes. Don't perform leader-constraint checks - # since the cluster is changing right now. - n0 = self.cluster.wait_for_leader(constraint_check=False).stop() - n1 = self.cluster.wait_for_leader(node_exc=n0, constraint_check=False) + # Kill leader, and then make more changes. + n0 = self.cluster.wait_for_leader().stop() + n1 = self.cluster.wait_for_leader(node_exc=n0) n1.wait_for_all_applied() j = n1.query('SELECT * FROM foo') self.assertEqual(str(j), "{u'results': [{u'values': [[1, u'fiona']], u'types': [u'integer', u'text'], u'columns': [u'id', u'name']}]}") diff --git a/system_test/single_node_test.go b/system_test/single_node_test.go index bbd7a5ff..3ca849eb 100644 --- a/system_test/single_node_test.go +++ b/system_test/single_node_test.go @@ -434,8 +434,8 @@ func Test_SingleNodeNodes(t *testing.T) { if n.Addr != node.RaftAddr { t.Fatalf("node has wrong Raft address") } - if n.APIAddr != fmt.Sprintf("http://%s", node.APIAddr) { - t.Fatalf("node has wrong API address") + if got, exp := n.APIAddr, fmt.Sprintf("http://%s", node.APIAddr); exp != got { + t.Fatalf("node has wrong API address, exp: %s got: %s", exp, got) } if !n.Leader { t.Fatalf("node is not leader") diff --git a/tcp/pool/LICENSE b/tcp/pool/LICENSE new file mode 100644 index 00000000..25fdaf63 --- /dev/null +++ b/tcp/pool/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2013 Fatih Arslan + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/tcp/pool/channel.go b/tcp/pool/channel.go new file mode 100644 index 00000000..fd525d1b --- /dev/null +++ b/tcp/pool/channel.go @@ -0,0 +1,155 @@ +package pool + +import ( + "errors" + "fmt" + "net" + "sync" + "sync/atomic" +) + +// channelPool implements the Pool interface based on buffered channels. +type channelPool struct { + // storage for our net.Conn connections + mu sync.RWMutex + conns chan net.Conn + + // net.Conn generator + factory Factory + nOpenConns int64 +} + +// Factory is a function to create new connections. 
+type Factory func() (net.Conn, error) + +// NewChannelPool returns a new pool based on buffered channels with an initial +// capacity and maximum capacity. Factory is used when initial capacity is +// greater than zero to fill the pool. A zero initialCap doesn't fill the Pool +// until Get() is called. During a Get(), if there is no connection +// available in the pool, a new connection will be created via the Factory() +// method. +func NewChannelPool(initialCap, maxCap int, factory Factory) (Pool, error) { + if initialCap < 0 || maxCap <= 0 || initialCap > maxCap { + return nil, errors.New("invalid capacity settings") + } + + c := &channelPool{ + conns: make(chan net.Conn, maxCap), + factory: factory, + } + + // create initial connections; if something goes wrong, + // just close the pool and error out. + for i := 0; i < initialCap; i++ { + conn, err := factory() + if err != nil { + c.Close() + return nil, fmt.Errorf("factory is not able to fill the pool: %s", err) + } + atomic.AddInt64(&c.nOpenConns, 1) + c.conns <- conn + } + + return c, nil +} + +func (c *channelPool) getConnsAndFactory() (chan net.Conn, Factory) { + c.mu.RLock() + conns := c.conns + factory := c.factory + c.mu.RUnlock() + return conns, factory +} + +// Get implements the Pool interface's Get() method. If there is no connection +// available in the pool, a new connection will be created via the +// Factory() method. +func (c *channelPool) Get() (net.Conn, error) { + conns, factory := c.getConnsAndFactory() + if conns == nil { + return nil, ErrClosed + } + + // wrap our connections with our custom net.Conn implementation (wrapConn + // method) that puts the connection back to the pool when it's closed. + select { + case conn := <-conns: + if conn == nil { + return nil, ErrClosed + } + + return c.wrapConn(conn), nil + default: + conn, err := factory() + if err != nil { + return nil, err + } + atomic.AddInt64(&c.nOpenConns, 1) + + return c.wrapConn(conn), nil + } +} + +// put puts the connection back to the pool. If the pool is full or closed, +// conn is simply closed. A nil conn will be rejected. +func (c *channelPool) put(conn net.Conn) error { + if conn == nil { + return errors.New("connection is nil. rejecting") + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.conns == nil { + // pool is closed, close passed connection + atomic.AddInt64(&c.nOpenConns, -1) + return conn.Close() + } + + // put the connection back into the pool. If the pool is full, this does + // not block; the default case runs and the connection is closed. + select { + case c.conns <- conn: + return nil + default: + // pool is full, close passed connection + atomic.AddInt64(&c.nOpenConns, -1) + + return conn.Close() + } +} + +// Close closes every connection in the pool. +func (c *channelPool) Close() { + c.mu.Lock() + conns := c.conns + c.conns = nil + c.factory = nil + c.mu.Unlock() + + if conns == nil { + return + } + + close(conns) + for conn := range conns { + conn.Close() + } + atomic.StoreInt64(&c.nOpenConns, 0) +} + +// Len() returns the number of idle connections. +func (c *channelPool) Len() int { + conns, _ := c.getConnsAndFactory() + return len(conns) +} + +// Stats returns stats for the pool. 
+func (c *channelPool) Stats() (map[string]interface{}, error) { + conns, _ := c.getConnsAndFactory() + return map[string]interface{}{ + "idle": len(conns), + "open_connections": c.nOpenConns, + "max_open_connections": cap(conns), + }, nil +} diff --git a/tcp/pool/channel_test.go b/tcp/pool/channel_test.go new file mode 100644 index 00000000..232dd986 --- /dev/null +++ b/tcp/pool/channel_test.go @@ -0,0 +1,292 @@ +package pool + +import ( + "log" + "math/rand" + "net" + "sync" + "testing" + "time" +) + +var ( + InitialCap = 5 + MaximumCap = 30 + network = "tcp" + address = "127.0.0.1:7777" + factory = func() (net.Conn, error) { return net.Dial(network, address) } +) + +func init() { + // used for factory function + go simpleTCPServer() + time.Sleep(time.Millisecond * 300) // wait until tcp server has been settled + + rand.Seed(time.Now().UTC().UnixNano()) +} + +func TestNew(t *testing.T) { + _, err := newChannelPool() + if err != nil { + t.Errorf("New error: %s", err) + } +} +func TestPool_Get_Impl(t *testing.T) { + p, _ := newChannelPool() + defer p.Close() + + conn, err := p.Get() + if err != nil { + t.Errorf("Get error: %s", err) + } + + _, ok := conn.(*PoolConn) + if !ok { + t.Errorf("Conn is not of type poolConn") + } +} + +func TestPool_Get(t *testing.T) { + p, _ := newChannelPool() + defer p.Close() + + _, err := p.Get() + if err != nil { + t.Errorf("Get error: %s", err) + } + + // after one get, current capacity should be lowered by one. + if p.Len() != (InitialCap - 1) { + t.Errorf("Get error. Expecting %d, got %d", + (InitialCap - 1), p.Len()) + } + + // get them all + var wg sync.WaitGroup + for i := 0; i < (InitialCap - 1); i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := p.Get() + if err != nil { + t.Errorf("Get error: %s", err) + } + }() + } + wg.Wait() + + if p.Len() != 0 { + t.Errorf("Get error. Expecting %d, got %d", + (InitialCap - 1), p.Len()) + } + + _, err = p.Get() + if err != nil { + t.Errorf("Get error: %s", err) + } +} + +func TestPool_Put(t *testing.T) { + p, err := NewChannelPool(0, 30, factory) + if err != nil { + t.Fatal(err) + } + defer p.Close() + + // get/create from the pool + conns := make([]net.Conn, MaximumCap) + for i := 0; i < MaximumCap; i++ { + conn, _ := p.Get() + conns[i] = conn + } + + // now put them all back + for _, conn := range conns { + conn.Close() + } + + if p.Len() != MaximumCap { + t.Errorf("Put error len. Expecting %d, got %d", + 1, p.Len()) + } + + conn, _ := p.Get() + p.Close() // close pool + + conn.Close() // try to put into a full pool + if p.Len() != 0 { + t.Errorf("Put error. Closed pool shouldn't allow to put connections.") + } +} + +func TestPool_PutUnusableConn(t *testing.T) { + p, _ := newChannelPool() + defer p.Close() + + // ensure pool is not empty + conn, _ := p.Get() + conn.Close() + + poolSize := p.Len() + conn, _ = p.Get() + conn.Close() + if p.Len() != poolSize { + t.Errorf("Pool size is expected to be equal to initial size") + } + + conn, _ = p.Get() + if pc, ok := conn.(*PoolConn); !ok { + t.Errorf("impossible") + } else { + pc.MarkUnusable() + } + conn.Close() + if p.Len() != poolSize-1 { + t.Errorf("Pool size is expected to be initial_size - 1, %d, %d", p.Len(), poolSize-1) + } +} + +func TestPool_UsedCapacity(t *testing.T) { + p, _ := newChannelPool() + defer p.Close() + + if p.Len() != InitialCap { + t.Errorf("InitialCap error. 
Expecting %d, got %d", + InitialCap, p.Len()) + } +} + +func TestPool_Close(t *testing.T) { + p, _ := newChannelPool() + + // now close it and test all cases we are expecting. + p.Close() + + c := p.(*channelPool) + + if c.conns != nil { + t.Errorf("Close error, conns channel should be nil") + } + + if c.factory != nil { + t.Errorf("Close error, factory should be nil") + } + + _, err := p.Get() + if err == nil { + t.Errorf("Close error, get conn should return an error") + } + + if p.Len() != 0 { + t.Errorf("Close error used capacity. Expecting 0, got %d", p.Len()) + } +} + +func TestPoolConcurrent(t *testing.T) { + p, _ := newChannelPool() + pipe := make(chan net.Conn, 0) + + go func() { + p.Close() + }() + + for i := 0; i < MaximumCap; i++ { + go func() { + conn, _ := p.Get() + + pipe <- conn + }() + + go func() { + conn := <-pipe + if conn == nil { + return + } + conn.Close() + }() + } +} + +func TestPoolWriteRead(t *testing.T) { + p, _ := NewChannelPool(0, 30, factory) + + conn, _ := p.Get() + + msg := "hello" + _, err := conn.Write([]byte(msg)) + if err != nil { + t.Error(err) + } +} + +func TestPoolConcurrent2(t *testing.T) { + p, _ := NewChannelPool(0, 30, factory) + + var wg sync.WaitGroup + + go func() { + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + conn, _ := p.Get() + time.Sleep(time.Millisecond * time.Duration(rand.Intn(100))) + conn.Close() + wg.Done() + }(i) + } + }() + + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + conn, _ := p.Get() + time.Sleep(time.Millisecond * time.Duration(rand.Intn(100))) + conn.Close() + wg.Done() + }(i) + } + + wg.Wait() +} + +func TestPoolConcurrent3(t *testing.T) { + p, _ := NewChannelPool(0, 1, factory) + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + p.Close() + wg.Done() + }() + + if conn, err := p.Get(); err == nil { + conn.Close() + } + + wg.Wait() +} + +func newChannelPool() (Pool, error) { + return NewChannelPool(InitialCap, MaximumCap, factory) +} + +func simpleTCPServer() { + l, err := net.Listen(network, address) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + for { + conn, err := l.Accept() + if err != nil { + log.Fatal(err) + } + + go func() { + buffer := make([]byte, 256) + conn.Read(buffer) + }() + } +} diff --git a/tcp/pool/conn.go b/tcp/pool/conn.go new file mode 100644 index 00000000..693488c8 --- /dev/null +++ b/tcp/pool/conn.go @@ -0,0 +1,43 @@ +package pool + +import ( + "net" + "sync" +) + +// PoolConn is a wrapper around net.Conn to modify the the behavior of +// net.Conn's Close() method. +type PoolConn struct { + net.Conn + mu sync.RWMutex + c *channelPool + unusable bool +} + +// Close() puts the given connects back to the pool instead of closing it. +func (p *PoolConn) Close() error { + p.mu.RLock() + defer p.mu.RUnlock() + + if p.unusable { + if p.Conn != nil { + return p.Conn.Close() + } + return nil + } + return p.c.put(p.Conn) +} + +// MarkUnusable() marks the connection not usable any more, to let the pool close it instead of returning it to pool. +func (p *PoolConn) MarkUnusable() { + p.mu.Lock() + p.unusable = true + p.mu.Unlock() +} + +// newConn wraps a standard net.Conn to a poolConn net.Conn. 
+func (c *channelPool) wrapConn(conn net.Conn) net.Conn { + p := &PoolConn{c: c} + p.Conn = conn + return p +} diff --git a/tcp/pool/conn_test.go b/tcp/pool/conn_test.go new file mode 100644 index 00000000..55f9237f --- /dev/null +++ b/tcp/pool/conn_test.go @@ -0,0 +1,10 @@ +package pool + +import ( + "net" + "testing" +) + +func TestConn_Impl(t *testing.T) { + var _ net.Conn = new(PoolConn) +} diff --git a/tcp/pool/pool.go b/tcp/pool/pool.go new file mode 100644 index 00000000..7c66c370 --- /dev/null +++ b/tcp/pool/pool.go @@ -0,0 +1,31 @@ +// Package pool implements a pool of net.Conn interfaces to manage and reuse them. +package pool + +import ( + "errors" + "net" +) + +var ( + // ErrClosed is the error returned when the pool is closed via pool.Close(). + ErrClosed = errors.New("pool is closed") +) + +// Pool interface describes a pool implementation. A pool should have maximum +// capacity. An ideal pool is thread-safe and easy to use. +type Pool interface { + // Get returns a connection from the pool. Closing the connection puts + // it back into the Pool. Closing it when the pool is destroyed or full will + // be counted as an error. + Get() (net.Conn, error) + + // Close closes the pool and all its connections. After Close() the pool is + // no longer usable. + Close() + + // Len returns the current number of connections of the pool. + Len() int + + // Stats returns stats about the pool. + Stats() (map[string]interface{}, error) +}
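
For reference, below is a minimal usage sketch of the new tcp/pool package, mirroring how Client.dial above builds a per-node pool. It is illustrative only: the address, capacities, and payload are assumed placeholders, not values taken from this change.

package main

import (
	"log"
	"net"
	"time"

	"github.com/rqlite/rqlite/tcp/pool"
)

func main() {
	// factory dials the remote node; the pool invokes it whenever no idle
	// connection is available. The address is a placeholder.
	factory := func() (net.Conn, error) {
		return net.DialTimeout("tcp", "127.0.0.1:4002", 10*time.Second)
	}

	// Open 4 connections eagerly and hold at most 64, the same shape as the
	// cluster client's initialPoolSize/maxPoolCapacity.
	p, err := pool.NewChannelPool(4, 64, factory)
	if err != nil {
		log.Fatal(err)
	}
	defer p.Close()

	conn, err := p.Get()
	if err != nil {
		log.Fatal(err)
	}

	if _, err := conn.Write([]byte("ping")); err != nil {
		// A connection that failed mid-protocol must not be reused.
		if pc, ok := conn.(*pool.PoolConn); ok {
			pc.MarkUnusable()
		}
	}
	// Close returns a healthy connection to the pool; an unusable one is
	// really closed.
	conn.Close()
}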