From fd2c3a63141c234cc5523fa79f047a796d728d23 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 21 Aug 2021 08:58:11 -0400 Subject: [PATCH 01/14] Add Connection Pool source Many thanks to https://github.com/fatih/pool. --- tcp/pool/LICENSE | 20 +++ tcp/pool/channel.go | 135 ++++++++++++++++++ tcp/pool/channel_test.go | 292 +++++++++++++++++++++++++++++++++++++++ tcp/pool/conn.go | 43 ++++++ tcp/pool/conn_test.go | 10 ++ tcp/pool/pool.go | 28 ++++ 6 files changed, 528 insertions(+) create mode 100644 tcp/pool/LICENSE create mode 100644 tcp/pool/channel.go create mode 100644 tcp/pool/channel_test.go create mode 100644 tcp/pool/conn.go create mode 100644 tcp/pool/conn_test.go create mode 100644 tcp/pool/pool.go diff --git a/tcp/pool/LICENSE b/tcp/pool/LICENSE new file mode 100644 index 00000000..25fdaf63 --- /dev/null +++ b/tcp/pool/LICENSE @@ -0,0 +1,20 @@ +The MIT License (MIT) + +Copyright (c) 2013 Fatih Arslan + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/tcp/pool/channel.go b/tcp/pool/channel.go new file mode 100644 index 00000000..36bb20d6 --- /dev/null +++ b/tcp/pool/channel.go @@ -0,0 +1,135 @@ +package pool + +import ( + "errors" + "fmt" + "net" + "sync" +) + +// channelPool implements the Pool interface based on buffered channels. +type channelPool struct { + // storage for our net.Conn connections + mu sync.RWMutex + conns chan net.Conn + + // net.Conn generator + factory Factory +} + +// Factory is a function to create new connections. +type Factory func() (net.Conn, error) + +// NewChannelPool returns a new pool based on buffered channels with an initial +// capacity and maximum capacity. Factory is used when initial capacity is +// greater than zero to fill the pool. A zero initialCap doesn't fill the Pool +// until a new Get() is called. During a Get(), If there is no new connection +// available in the pool, a new connection will be created via the Factory() +// method. +func NewChannelPool(initialCap, maxCap int, factory Factory) (Pool, error) { + if initialCap < 0 || maxCap <= 0 || initialCap > maxCap { + return nil, errors.New("invalid capacity settings") + } + + c := &channelPool{ + conns: make(chan net.Conn, maxCap), + factory: factory, + } + + // create initial connections, if something goes wrong, + // just close the pool error out. 
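+ // Eagerly dialing initialCap connections lets the first callers of Get()
+ // avoid the dial cost; once the pool is empty, Get() falls back to the factory.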
+ for i := 0; i < initialCap; i++ { + conn, err := factory() + if err != nil { + c.Close() + return nil, fmt.Errorf("factory is not able to fill the pool: %s", err) + } + c.conns <- conn + } + + return c, nil +} + +func (c *channelPool) getConnsAndFactory() (chan net.Conn, Factory) { + c.mu.RLock() + conns := c.conns + factory := c.factory + c.mu.RUnlock() + return conns, factory +} + +// Get implements the Pool interfaces Get() method. If there is no new +// connection available in the pool, a new connection will be created via the +// Factory() method. +func (c *channelPool) Get() (net.Conn, error) { + conns, factory := c.getConnsAndFactory() + if conns == nil { + return nil, ErrClosed + } + + // wrap our connections with out custom net.Conn implementation (wrapConn + // method) that puts the connection back to the pool if it's closed. + select { + case conn := <-conns: + if conn == nil { + return nil, ErrClosed + } + + return c.wrapConn(conn), nil + default: + conn, err := factory() + if err != nil { + return nil, err + } + + return c.wrapConn(conn), nil + } +} + +// put puts the connection back to the pool. If the pool is full or closed, +// conn is simply closed. A nil conn will be rejected. +func (c *channelPool) put(conn net.Conn) error { + if conn == nil { + return errors.New("connection is nil. rejecting") + } + + c.mu.RLock() + defer c.mu.RUnlock() + + if c.conns == nil { + // pool is closed, close passed connection + return conn.Close() + } + + // put the resource back into the pool. If the pool is full, this will + // block and the default case will be executed. + select { + case c.conns <- conn: + return nil + default: + // pool is full, close passed connection + return conn.Close() + } +} + +func (c *channelPool) Close() { + c.mu.Lock() + conns := c.conns + c.conns = nil + c.factory = nil + c.mu.Unlock() + + if conns == nil { + return + } + + close(conns) + for conn := range conns { + conn.Close() + } +} + +func (c *channelPool) Len() int { + conns, _ := c.getConnsAndFactory() + return len(conns) +} diff --git a/tcp/pool/channel_test.go b/tcp/pool/channel_test.go new file mode 100644 index 00000000..232dd986 --- /dev/null +++ b/tcp/pool/channel_test.go @@ -0,0 +1,292 @@ +package pool + +import ( + "log" + "math/rand" + "net" + "sync" + "testing" + "time" +) + +var ( + InitialCap = 5 + MaximumCap = 30 + network = "tcp" + address = "127.0.0.1:7777" + factory = func() (net.Conn, error) { return net.Dial(network, address) } +) + +func init() { + // used for factory function + go simpleTCPServer() + time.Sleep(time.Millisecond * 300) // wait until tcp server has been settled + + rand.Seed(time.Now().UTC().UnixNano()) +} + +func TestNew(t *testing.T) { + _, err := newChannelPool() + if err != nil { + t.Errorf("New error: %s", err) + } +} +func TestPool_Get_Impl(t *testing.T) { + p, _ := newChannelPool() + defer p.Close() + + conn, err := p.Get() + if err != nil { + t.Errorf("Get error: %s", err) + } + + _, ok := conn.(*PoolConn) + if !ok { + t.Errorf("Conn is not of type poolConn") + } +} + +func TestPool_Get(t *testing.T) { + p, _ := newChannelPool() + defer p.Close() + + _, err := p.Get() + if err != nil { + t.Errorf("Get error: %s", err) + } + + // after one get, current capacity should be lowered by one. + if p.Len() != (InitialCap - 1) { + t.Errorf("Get error. 
Expecting %d, got %d", + (InitialCap - 1), p.Len()) + } + + // get them all + var wg sync.WaitGroup + for i := 0; i < (InitialCap - 1); i++ { + wg.Add(1) + go func() { + defer wg.Done() + _, err := p.Get() + if err != nil { + t.Errorf("Get error: %s", err) + } + }() + } + wg.Wait() + + if p.Len() != 0 { + t.Errorf("Get error. Expecting %d, got %d", + (InitialCap - 1), p.Len()) + } + + _, err = p.Get() + if err != nil { + t.Errorf("Get error: %s", err) + } +} + +func TestPool_Put(t *testing.T) { + p, err := NewChannelPool(0, 30, factory) + if err != nil { + t.Fatal(err) + } + defer p.Close() + + // get/create from the pool + conns := make([]net.Conn, MaximumCap) + for i := 0; i < MaximumCap; i++ { + conn, _ := p.Get() + conns[i] = conn + } + + // now put them all back + for _, conn := range conns { + conn.Close() + } + + if p.Len() != MaximumCap { + t.Errorf("Put error len. Expecting %d, got %d", + 1, p.Len()) + } + + conn, _ := p.Get() + p.Close() // close pool + + conn.Close() // try to put into a full pool + if p.Len() != 0 { + t.Errorf("Put error. Closed pool shouldn't allow to put connections.") + } +} + +func TestPool_PutUnusableConn(t *testing.T) { + p, _ := newChannelPool() + defer p.Close() + + // ensure pool is not empty + conn, _ := p.Get() + conn.Close() + + poolSize := p.Len() + conn, _ = p.Get() + conn.Close() + if p.Len() != poolSize { + t.Errorf("Pool size is expected to be equal to initial size") + } + + conn, _ = p.Get() + if pc, ok := conn.(*PoolConn); !ok { + t.Errorf("impossible") + } else { + pc.MarkUnusable() + } + conn.Close() + if p.Len() != poolSize-1 { + t.Errorf("Pool size is expected to be initial_size - 1, %d, %d", p.Len(), poolSize-1) + } +} + +func TestPool_UsedCapacity(t *testing.T) { + p, _ := newChannelPool() + defer p.Close() + + if p.Len() != InitialCap { + t.Errorf("InitialCap error. Expecting %d, got %d", + InitialCap, p.Len()) + } +} + +func TestPool_Close(t *testing.T) { + p, _ := newChannelPool() + + // now close it and test all cases we are expecting. + p.Close() + + c := p.(*channelPool) + + if c.conns != nil { + t.Errorf("Close error, conns channel should be nil") + } + + if c.factory != nil { + t.Errorf("Close error, factory should be nil") + } + + _, err := p.Get() + if err == nil { + t.Errorf("Close error, get conn should return an error") + } + + if p.Len() != 0 { + t.Errorf("Close error used capacity. 
Expecting 0, got %d", p.Len()) + } +} + +func TestPoolConcurrent(t *testing.T) { + p, _ := newChannelPool() + pipe := make(chan net.Conn, 0) + + go func() { + p.Close() + }() + + for i := 0; i < MaximumCap; i++ { + go func() { + conn, _ := p.Get() + + pipe <- conn + }() + + go func() { + conn := <-pipe + if conn == nil { + return + } + conn.Close() + }() + } +} + +func TestPoolWriteRead(t *testing.T) { + p, _ := NewChannelPool(0, 30, factory) + + conn, _ := p.Get() + + msg := "hello" + _, err := conn.Write([]byte(msg)) + if err != nil { + t.Error(err) + } +} + +func TestPoolConcurrent2(t *testing.T) { + p, _ := NewChannelPool(0, 30, factory) + + var wg sync.WaitGroup + + go func() { + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + conn, _ := p.Get() + time.Sleep(time.Millisecond * time.Duration(rand.Intn(100))) + conn.Close() + wg.Done() + }(i) + } + }() + + for i := 0; i < 10; i++ { + wg.Add(1) + go func(i int) { + conn, _ := p.Get() + time.Sleep(time.Millisecond * time.Duration(rand.Intn(100))) + conn.Close() + wg.Done() + }(i) + } + + wg.Wait() +} + +func TestPoolConcurrent3(t *testing.T) { + p, _ := NewChannelPool(0, 1, factory) + + var wg sync.WaitGroup + + wg.Add(1) + go func() { + p.Close() + wg.Done() + }() + + if conn, err := p.Get(); err == nil { + conn.Close() + } + + wg.Wait() +} + +func newChannelPool() (Pool, error) { + return NewChannelPool(InitialCap, MaximumCap, factory) +} + +func simpleTCPServer() { + l, err := net.Listen(network, address) + if err != nil { + log.Fatal(err) + } + defer l.Close() + + for { + conn, err := l.Accept() + if err != nil { + log.Fatal(err) + } + + go func() { + buffer := make([]byte, 256) + conn.Read(buffer) + }() + } +} diff --git a/tcp/pool/conn.go b/tcp/pool/conn.go new file mode 100644 index 00000000..693488c8 --- /dev/null +++ b/tcp/pool/conn.go @@ -0,0 +1,43 @@ +package pool + +import ( + "net" + "sync" +) + +// PoolConn is a wrapper around net.Conn to modify the the behavior of +// net.Conn's Close() method. +type PoolConn struct { + net.Conn + mu sync.RWMutex + c *channelPool + unusable bool +} + +// Close() puts the given connects back to the pool instead of closing it. +func (p *PoolConn) Close() error { + p.mu.RLock() + defer p.mu.RUnlock() + + if p.unusable { + if p.Conn != nil { + return p.Conn.Close() + } + return nil + } + return p.c.put(p.Conn) +} + +// MarkUnusable() marks the connection not usable any more, to let the pool close it instead of returning it to pool. +func (p *PoolConn) MarkUnusable() { + p.mu.Lock() + p.unusable = true + p.mu.Unlock() +} + +// newConn wraps a standard net.Conn to a poolConn net.Conn. +func (c *channelPool) wrapConn(conn net.Conn) net.Conn { + p := &PoolConn{c: c} + p.Conn = conn + return p +} diff --git a/tcp/pool/conn_test.go b/tcp/pool/conn_test.go new file mode 100644 index 00000000..55f9237f --- /dev/null +++ b/tcp/pool/conn_test.go @@ -0,0 +1,10 @@ +package pool + +import ( + "net" + "testing" +) + +func TestConn_Impl(t *testing.T) { + var _ net.Conn = new(PoolConn) +} diff --git a/tcp/pool/pool.go b/tcp/pool/pool.go new file mode 100644 index 00000000..f88f2acb --- /dev/null +++ b/tcp/pool/pool.go @@ -0,0 +1,28 @@ +// Package pool implements a pool of net.Conn interfaces to manage and reuse them. +package pool + +import ( + "errors" + "net" +) + +var ( + // ErrClosed is the error resulting if the pool is closed via pool.Close(). + ErrClosed = errors.New("pool is closed") +) + +// Pool interface describes a pool implementation. A pool should have maximum +// capacity. 
An ideal pool is threadsafe and easy to use. +type Pool interface { + // Get returns a new connection from the pool. Closing the connections puts + // it back to the Pool. Closing it when the pool is destroyed or full will + // be counted as an error. + Get() (net.Conn, error) + + // Close closes the pool and all its connections. After Close() the pool is + // no longer usable. + Close() + + // Len returns the current number of connections of the pool. + Len() int +} From 312e44e57dbab26a45ffd3069919db7f125b340f Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 21 Aug 2021 09:17:54 -0400 Subject: [PATCH 02/14] Add Connection Pool for GetNodeAPIAddr --- cluster/client.go | 59 ++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 56 insertions(+), 3 deletions(-) diff --git a/cluster/client.go b/cluster/client.go index f5199265..074400e9 100644 --- a/cluster/client.go +++ b/cluster/client.go @@ -5,16 +5,27 @@ import ( "errors" "fmt" "io/ioutil" + "net" + "sync" "time" "github.com/golang/protobuf/proto" "github.com/rqlite/rqlite/command" + "github.com/rqlite/rqlite/tcp/pool" +) + +const ( + initialPoolSize = 4 + maxPoolCapacity = 64 ) // Client allows communicating with a remote node. type Client struct { dialer Dialer timeout time.Duration + + mu sync.RWMutex + pools map[string]pool.Pool } // NewClient returns a client instance for talking to a remote node. @@ -22,17 +33,56 @@ func NewClient(dl Dialer) *Client { return &Client{ dialer: dl, timeout: 30 * time.Second, + pools: make(map[string]pool.Pool), } } // GetNodeAPIAddr retrieves the API Address for the node at nodeAddr func (c *Client) GetNodeAPIAddr(nodeAddr string) (string, error) { - conn, err := c.dialer.Dial(nodeAddr, c.timeout) - if err != nil { - return "", fmt.Errorf("dial connection: %s", err) + var pl pool.Pool + var ok bool + + c.mu.RLock() + pl, ok = c.pools[nodeAddr] + c.mu.RUnlock() + + // Do we need a new pool for the given address? + if !ok { + if err := func() error { + c.mu.Lock() + defer c.mu.Unlock() + pl, ok = c.pools[nodeAddr] + if ok { + return nil // Pool was inserted just after we checked. + } + + // New pool is needed for given address. + factory := func() (net.Conn, error) { return c.dialer.Dial(nodeAddr, c.timeout) } + p, err := pool.NewChannelPool(initialPoolSize, maxPoolCapacity, factory) + if err != nil { + return err + } + c.pools[nodeAddr] = p + pl = p + return nil + }(); err != nil { + return "", err + } + } + + // Got pool, now get a connection. 
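+ // The returned conn is a *pool.PoolConn, so the deferred Close() below
+ // puts it back into the pool rather than closing the underlying TCP connection.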
+ conn, err := pl.Get() + if err != nil { + return "", fmt.Errorf("pool get: %s", err) } defer conn.Close() + handleConnError := func(c net.Conn) { + if pc, ok := conn.(*pool.PoolConn); ok { + pc.MarkUnusable() + } + } + // Send the request command := &Command{ Type: Command_COMMAND_TYPE_GET_NODE_API_URL, @@ -48,15 +98,18 @@ func (c *Client) GetNodeAPIAddr(nodeAddr string) (string, error) { _, err = conn.Write(b) if err != nil { + handleConnError(conn) return "", fmt.Errorf("write protobuf length: %s", err) } _, err = conn.Write(p) if err != nil { + handleConnError(conn) return "", fmt.Errorf("write protobuf: %s", err) } b, err = ioutil.ReadAll(conn) if err != nil { + handleConnError(conn) return "", fmt.Errorf("read protobuf bytes: %s", err) } From eee3a2e785010e2e5b503b6a82642c6f2b68d072 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 21 Aug 2021 10:15:54 -0400 Subject: [PATCH 03/14] Fix code path that could cause panic --- http/service.go | 33 +++++++++++++++++++++++++++------ http/service_test.go | 22 +++++++++++++++++++++- 2 files changed, 48 insertions(+), 7 deletions(-) diff --git a/http/service.go b/http/service.go index 75696315..2110f694 100644 --- a/http/service.go +++ b/http/service.go @@ -26,6 +26,11 @@ import ( "github.com/rqlite/rqlite/store" ) +var ( + // ErrLeaderNotFound is returned when a node cannot locate a leader + ErrLeaderNotFound = errors.New("leader not found") +) + // Database is the interface any queryable system must implement type Database interface { // Execute executes a slice of queries, each of which is not expected @@ -126,6 +131,7 @@ type Response struct { var stats *expvar.Map const ( + numLeaderNotFound = "leader_not_found" numExecutions = "executions" numQueries = "queries" numRemoteExecutions = "remote_executions" @@ -167,6 +173,7 @@ const ( func init() { stats = expvar.NewMap("http") + stats.Add(numLeaderNotFound, 0) stats.Add(numExecutions, 0) stats.Add(numQueries, 0) stats.Add(numRemoteExecutions, 0) @@ -375,7 +382,8 @@ func (s *Service) handleJoin(w http.ResponseWriter, r *http.Request) { if err == store.ErrNotLeader { leaderAPIAddr := s.LeaderAPIAddr() if leaderAPIAddr == "" { - http.Error(w, err.Error(), http.StatusServiceUnavailable) + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) return } @@ -427,7 +435,8 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) { if err == store.ErrNotLeader { leaderAPIAddr := s.LeaderAPIAddr() if leaderAPIAddr == "" { - http.Error(w, err.Error(), http.StatusServiceUnavailable) + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) return } @@ -470,7 +479,8 @@ func (s *Service) handleBackup(w http.ResponseWriter, r *http.Request) { if err == store.ErrNotLeader { leaderAPIAddr := s.LeaderAPIAddr() if leaderAPIAddr == "" { - http.Error(w, err.Error(), http.StatusServiceUnavailable) + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) return } @@ -522,7 +532,8 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) { if err == store.ErrNotLeader { leaderAPIAddr := s.LeaderAPIAddr() if leaderAPIAddr == "" { - http.Error(w, err.Error(), http.StatusServiceUnavailable) + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) return } @@ -767,7 +778,8 @@ func (s *Service) handleExecute(w http.ResponseWriter, r *http.Request) { if redirect { leaderAPIAddr := 
s.LeaderAPIAddr() if leaderAPIAddr == "" { - http.Error(w, err.Error(), http.StatusServiceUnavailable) + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) return } loc := s.FormRedirect(r, leaderAPIAddr) @@ -779,6 +791,10 @@ func (s *Service) handleExecute(w http.ResponseWriter, r *http.Request) { if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } + if addr == "" { + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) + } results, resultsErr = s.cluster.Execute(er, addr, timeout) stats.Add(numRemoteExecutions, 1) w.Header().Add(ServedByHTTPHeader, addr) @@ -849,7 +865,8 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) { if redirect { leaderAPIAddr := s.LeaderAPIAddr() if leaderAPIAddr == "" { - http.Error(w, err.Error(), http.StatusServiceUnavailable) + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) return } loc := s.FormRedirect(r, leaderAPIAddr) @@ -861,6 +878,10 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) { if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) } + if addr == "" { + stats.Add(numLeaderNotFound, 1) + http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable) + } results, resultsErr = s.cluster.Query(qr, addr, timeout) stats.Add(numRemoteQueries, 1) w.Header().Add(ServedByHTTPHeader, addr) diff --git a/http/service_test.go b/http/service_test.go index da47c484..6915976d 100644 --- a/http/service_test.go +++ b/http/service_test.go @@ -677,7 +677,7 @@ func Test_Nodes(t *testing.T) { } } -func Test_ForwardingRedirectQueries(t *testing.T) { +func Test_ForwardingRedirectQuery(t *testing.T) { m := &MockStore{ leaderAddr: "foo:1234", } @@ -725,6 +725,16 @@ func Test_ForwardingRedirectQueries(t *testing.T) { if resp.StatusCode != http.StatusMovedPermanently { t.Fatalf("failed to get expected StatusMovedPermanently for query, got %d", resp.StatusCode) } + + // Check leader failure case. + m.leaderAddr = "" + resp, err = client.Get(host + "/db/query?pretty&timings&q=SELECT%20%2A%20FROM%20foo") + if err != nil { + t.Fatalf("failed to make query forwarded request") + } + if resp.StatusCode != http.StatusServiceUnavailable { + t.Fatalf("failed to get expected StatusServiceUnavailable for node with no leader, got %d", resp.StatusCode) + } } func Test_ForwardingRedirectExecute(t *testing.T) { @@ -775,6 +785,16 @@ func Test_ForwardingRedirectExecute(t *testing.T) { if resp.StatusCode != http.StatusMovedPermanently { t.Fatalf("failed to get expected StatusMovedPermanently for execute, got %d", resp.StatusCode) } + + // Check leader failure case. 
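+ // With no leader address available, the execute request should fail
+ // cleanly with 503 Service Unavailable rather than panicking.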
+ m.leaderAddr = "" + resp, err = client.Post(host+"/db/execute", "application/json", strings.NewReader(`["Some SQL"]`)) + if err != nil { + t.Fatalf("failed to make execute request") + } + if resp.StatusCode != http.StatusServiceUnavailable { + t.Fatalf("failed to get expected StatusServiceUnavailable for node with no leader, got %d", resp.StatusCode) + } } func Test_TLSServce(t *testing.T) { From ac53893673e44153d14fe290b94850975ff5be2f Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 21 Aug 2021 13:04:46 -0400 Subject: [PATCH 04/14] GetNodeAPI uses connection pool --- cluster/client.go | 88 ++++++++++++++++++++++++++--------------------- 1 file changed, 48 insertions(+), 40 deletions(-) diff --git a/cluster/client.go b/cluster/client.go index 074400e9..19d00429 100644 --- a/cluster/client.go +++ b/cluster/client.go @@ -39,50 +39,12 @@ func NewClient(dl Dialer) *Client { // GetNodeAPIAddr retrieves the API Address for the node at nodeAddr func (c *Client) GetNodeAPIAddr(nodeAddr string) (string, error) { - var pl pool.Pool - var ok bool - - c.mu.RLock() - pl, ok = c.pools[nodeAddr] - c.mu.RUnlock() - - // Do we need a new pool for the given address? - if !ok { - if err := func() error { - c.mu.Lock() - defer c.mu.Unlock() - pl, ok = c.pools[nodeAddr] - if ok { - return nil // Pool was inserted just after we checked. - } - - // New pool is needed for given address. - factory := func() (net.Conn, error) { return c.dialer.Dial(nodeAddr, c.timeout) } - p, err := pool.NewChannelPool(initialPoolSize, maxPoolCapacity, factory) - if err != nil { - return err - } - c.pools[nodeAddr] = p - pl = p - return nil - }(); err != nil { - return "", err - } - } - - // Got pool, now get a connection. - conn, err := pl.Get() + conn, err := c.dial(nodeAddr, c.timeout) if err != nil { - return "", fmt.Errorf("pool get: %s", err) + return "", err } defer conn.Close() - handleConnError := func(c net.Conn) { - if pc, ok := conn.(*pool.PoolConn); ok { - pc.MarkUnusable() - } - } - // Send the request command := &Command{ Type: Command_COMMAND_TYPE_GET_NODE_API_URL, @@ -246,3 +208,49 @@ func (c *Client) Stats() (map[string]interface{}, error) { "timeout": c.timeout, }, nil } + +func (c *Client) dial(nodeAddr string, timeout time.Duration) (net.Conn, error) { + var pl pool.Pool + var ok bool + + c.mu.RLock() + pl, ok = c.pools[nodeAddr] + c.mu.RUnlock() + + // Do we need a new pool for the given address? + if !ok { + if err := func() error { + c.mu.Lock() + defer c.mu.Unlock() + pl, ok = c.pools[nodeAddr] + if ok { + return nil // Pool was inserted just after we checked. + } + + // New pool is needed for given address. + factory := func() (net.Conn, error) { return c.dialer.Dial(nodeAddr, c.timeout) } + p, err := pool.NewChannelPool(initialPoolSize, maxPoolCapacity, factory) + if err != nil { + return err + } + c.pools[nodeAddr] = p + pl = p + return nil + }(); err != nil { + return nil, err + } + } + + // Got pool, now get a connection. + conn, err := pl.Get() + if err != nil { + return nil, fmt.Errorf("pool get: %s", err) + } + return conn, nil +} + +func handleConnError(conn net.Conn) { + if pc, ok := conn.(*pool.PoolConn); ok { + pc.MarkUnusable() + } +} From 7b2e711c738b864e4e0b7c35257225fb23cfa15d Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 21 Aug 2021 14:29:16 -0400 Subject: [PATCH 05/14] Remove constraint check It's too clever, and causing test practicality issues. 
--- system_test/full_system_test.py | 30 +++++++++--------------------- 1 file changed, 9 insertions(+), 21 deletions(-) diff --git a/system_test/full_system_test.py b/system_test/full_system_test.py index 6df359e0..ca354b3d 100755 --- a/system_test/full_system_test.py +++ b/system_test/full_system_test.py @@ -136,7 +136,7 @@ class Node(object): r.raise_for_status() return r.json() - def is_leader(self, constraint_check=True): + def is_leader(self): ''' is_leader returns whether this node is the cluster leader It also performs a check, to ensure the node nevers gives out @@ -144,28 +144,16 @@ class Node(object): ''' try: - isLeaderRaft = self.status()['store']['raft']['state'] == 'Leader' - isLeaderNodes = self.nodes()[self.node_id]['leader'] is True + return self.status()['store']['raft']['state'] == 'Leader' except requests.exceptions.ConnectionError: return False - if (isLeaderRaft != isLeaderNodes) and constraint_check: - raise AssertionError("conflicting states reported for leadership (raft: %s, nodes: %s)" - % (isLeaderRaft, isLeaderNodes)) - return isLeaderNodes - def is_follower(self): try: - isFollowerRaft = self.status()['store']['raft']['state'] == 'Follower' - isFollowersNodes = self.nodes()[self.node_id]['leader'] is False + return self.status()['store']['raft']['state'] == 'Follower' except requests.exceptions.ConnectionError: return False - if isFollowerRaft != isFollowersNodes: - raise AssertionError("conflicting states reported for followership (raft: %s, nodes: %s)" - % (isFollowerRaft, isFollowersNodes)) - return isFollowersNodes - def wait_for_leader(self, timeout=TIMEOUT): lr = None t = 0 @@ -289,6 +277,7 @@ class Node(object): def redirect_addr(self): r = requests.post(self._execute_url(redirect=True), data=json.dumps(['nonsense']), allow_redirects=False) + r.raise_for_status() if r.status_code == 301: return "%s://%s" % (urlparse(r.headers['Location']).scheme, urlparse(r.headers['Location']).netloc) @@ -333,7 +322,7 @@ def deprovision_node(node): class Cluster(object): def __init__(self, nodes): self.nodes = nodes - def wait_for_leader(self, node_exc=None, timeout=TIMEOUT, constraint_check=True): + def wait_for_leader(self, node_exc=None, timeout=TIMEOUT): t = 0 while True: if t > timeout: @@ -341,7 +330,7 @@ class Cluster(object): for n in self.nodes: if node_exc is not None and n == node_exc: continue - if n.is_leader(constraint_check): + if n.is_leader(): return n time.sleep(1) t+=1 @@ -682,10 +671,9 @@ class TestEndToEndNonVoterFollowsLeader(unittest.TestCase): j = n.query('SELECT * FROM foo') self.assertEqual(str(j), "{u'results': [{u'values': [[1, u'fiona']], u'types': [u'integer', u'text'], u'columns': [u'id', u'name']}]}") - # Kill leader, and then make more changes. Don't perform leader-constraint checks - # since the cluster is changing right now. - n0 = self.cluster.wait_for_leader(constraint_check=False).stop() - n1 = self.cluster.wait_for_leader(node_exc=n0, constraint_check=False) + # Kill leader, and then make more changes. 
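+ # A new leader should be elected, and the existing data should still be
+ # queryable via that new leader.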
+ n0 = self.cluster.wait_for_leader().stop() + n1 = self.cluster.wait_for_leader(node_exc=n0) n1.wait_for_all_applied() j = n1.query('SELECT * FROM foo') self.assertEqual(str(j), "{u'results': [{u'values': [[1, u'fiona']], u'types': [u'integer', u'text'], u'columns': [u'id', u'name']}]}") From bf003e4f40f80fa7c10ce4817b04f84bba16ffd8 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 21 Aug 2021 14:34:13 -0400 Subject: [PATCH 06/14] Better test comments --- system_test/full_system_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/system_test/full_system_test.py b/system_test/full_system_test.py index ca354b3d..00b95294 100755 --- a/system_test/full_system_test.py +++ b/system_test/full_system_test.py @@ -511,6 +511,8 @@ class TestEndToEnd(unittest.TestCase): for n in fs: self.assertEqual(l.APIProtoAddr(), n.redirect_addr()) + # Kill the leader, wait for a new leader, and check that the + # redirect returns the new leader address. l.stop() n = self.cluster.wait_for_leader(node_exc=l) for f in self.cluster.followers(): From c8483d2ec5b45ec6fa09d9bcb2e30e71dd6b4338 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 21 Aug 2021 14:39:03 -0400 Subject: [PATCH 07/14] Use connection pool with Cluster Execute and Query --- cluster/client.go | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/cluster/client.go b/cluster/client.go index 19d00429..639c2860 100644 --- a/cluster/client.go +++ b/cluster/client.go @@ -86,9 +86,9 @@ func (c *Client) GetNodeAPIAddr(nodeAddr string) (string, error) { // Execute performs an Execute on a remote node. func (c *Client) Execute(er *command.ExecuteRequest, nodeAddr string, timeout time.Duration) ([]*command.ExecuteResult, error) { - conn, err := c.dialer.Dial(nodeAddr, c.timeout) + conn, err := c.dial(nodeAddr, c.timeout) if err != nil { - return nil, fmt.Errorf("dial connection: %s", err) + return nil, err } defer conn.Close() @@ -109,25 +109,31 @@ func (c *Client) Execute(er *command.ExecuteRequest, nodeAddr string, timeout ti binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + handleConnError(conn) return nil, err } _, err = conn.Write(b) if err != nil { + handleConnError(conn) return nil, err } if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + handleConnError(conn) return nil, err } _, err = conn.Write(p) if err != nil { + handleConnError(conn) return nil, err } if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + handleConnError(conn) return nil, err } b, err = ioutil.ReadAll(conn) if err != nil { + handleConnError(conn) return nil, err } @@ -145,9 +151,9 @@ func (c *Client) Execute(er *command.ExecuteRequest, nodeAddr string, timeout ti // Query performs an Query on a remote node. 
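+// Like Execute, it obtains its connection from the per-node pool and marks it
+// unusable on connection errors so it is not put back into the pool.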
func (c *Client) Query(qr *command.QueryRequest, nodeAddr string, timeout time.Duration) ([]*command.QueryRows, error) { - conn, err := c.dialer.Dial(nodeAddr, c.timeout) + conn, err := c.dial(nodeAddr, c.timeout) if err != nil { - return nil, fmt.Errorf("dial connection: %s", err) + return nil, err } defer conn.Close() @@ -168,25 +174,31 @@ func (c *Client) Query(qr *command.QueryRequest, nodeAddr string, timeout time.D binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + handleConnError(conn) return nil, err } _, err = conn.Write(b) if err != nil { + handleConnError(conn) return nil, err } if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + handleConnError(conn) return nil, err } _, err = conn.Write(p) if err != nil { + handleConnError(conn) return nil, err } if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil { + handleConnError(conn) return nil, err } b, err = ioutil.ReadAll(conn) if err != nil { + handleConnError(conn) return nil, err } From 3a70db51509c3444cc600b5efd419d5e9d28d48e Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 21 Aug 2021 17:29:41 -0400 Subject: [PATCH 08/14] Long-lived connections in cluster service Needed now that we're using connection pooling. Unit tests pass. --- cluster/client.go | 62 +++++++++++++++---- cluster/service.go | 150 +++++++++++++++++++++++++-------------------- 2 files changed, 132 insertions(+), 80 deletions(-) diff --git a/cluster/client.go b/cluster/client.go index 639c2860..2eead814 100644 --- a/cluster/client.go +++ b/cluster/client.go @@ -4,7 +4,7 @@ import ( "encoding/binary" "errors" "fmt" - "io/ioutil" + "io" "net" "sync" "time" @@ -69,14 +69,22 @@ func (c *Client) GetNodeAPIAddr(nodeAddr string) (string, error) { return "", fmt.Errorf("write protobuf: %s", err) } - b, err = ioutil.ReadAll(conn) + // Read length of response. + _, err = io.ReadFull(conn, b) if err != nil { - handleConnError(conn) - return "", fmt.Errorf("read protobuf bytes: %s", err) + return "", err + } + sz := binary.LittleEndian.Uint16(b[0:]) + + // Read in the actual response. + p = make([]byte, sz) + _, err = io.ReadFull(conn, p) + if err != nil { + return "", err } a := &Address{} - err = proto.Unmarshal(b, a) + err = proto.Unmarshal(p, a) if err != nil { return "", fmt.Errorf("protobuf unmarshal: %s", err) } @@ -131,14 +139,23 @@ func (c *Client) Execute(er *command.ExecuteRequest, nodeAddr string, timeout ti handleConnError(conn) return nil, err } - b, err = ioutil.ReadAll(conn) + + // Read length of response. + _, err = io.ReadFull(conn, b) + if err != nil { + return nil, err + } + sz := binary.LittleEndian.Uint16(b[0:]) + + // Read in the actual response. + p = make([]byte, sz) + _, err = io.ReadFull(conn, p) if err != nil { - handleConnError(conn) return nil, err } a := &CommandExecuteResponse{} - err = proto.Unmarshal(b, a) + err = proto.Unmarshal(p, a) if err != nil { return nil, err } @@ -169,7 +186,7 @@ func (c *Client) Query(qr *command.QueryRequest, nodeAddr string, timeout time.D return nil, fmt.Errorf("command marshal: %s", err) } - // Write length of Protobuf, the Protobuf + // Write length of Protobuf, then the Protobuf b := make([]byte, 4) binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) @@ -196,14 +213,23 @@ func (c *Client) Query(qr *command.QueryRequest, nodeAddr string, timeout time.D handleConnError(conn) return nil, err } - b, err = ioutil.ReadAll(conn) + + // Read length of response. 
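+ // Responses use the same framing as requests: a 4-byte header holding a
+ // little-endian uint16 length, followed by the protobuf payload itself.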
+ _, err = io.ReadFull(conn, b) + if err != nil { + return nil, err + } + sz := binary.LittleEndian.Uint16(b[0:]) + + // Read in the actual response. + p = make([]byte, sz) + _, err = io.ReadFull(conn, p) if err != nil { - handleConnError(conn) return nil, err } a := &CommandQueryResponse{} - err = proto.Unmarshal(b, a) + err = proto.Unmarshal(p, a) if err != nil { return nil, err } @@ -216,8 +242,20 @@ func (c *Client) Query(qr *command.QueryRequest, nodeAddr string, timeout time.D // Stats returns stats on the Client instance func (c *Client) Stats() (map[string]interface{}, error) { + c.mu.RLock() + defer c.mu.RUnlock() + + poolStats := make(map[string]interface{}, len(c.pools)) + for k, v := range c.pools { + s, err := v.Stats() + if err != nil { + return nil, err + } + poolStats[k] = s + } return map[string]interface{}{ "timeout": c.timeout, + "pool": poolStats, }, nil } diff --git a/cluster/service.go b/cluster/service.go index 5f0d642f..c2071f3c 100644 --- a/cluster/service.go +++ b/cluster/service.go @@ -153,95 +153,109 @@ func (s *Service) serve() error { func (s *Service) handleConn(conn net.Conn) { defer conn.Close() - b := make([]byte, 4) - _, err := io.ReadFull(conn, b) - if err != nil { - return - } - sz := binary.LittleEndian.Uint16(b[0:]) - - b = make([]byte, sz) - _, err = io.ReadFull(conn, b) - if err != nil { - return - } - - c := &Command{} - err = proto.Unmarshal(b, c) - if err != nil { - conn.Close() - } - - switch c.Type { - case Command_COMMAND_TYPE_GET_NODE_API_URL: - stats.Add(numGetNodeAPIRequest, 1) - s.mu.RLock() - defer s.mu.RUnlock() + for { + b := make([]byte, 4) + _, err := io.ReadFull(conn, b) + if err != nil { + return + } + sz := binary.LittleEndian.Uint16(b[0:]) - a := &Address{} - scheme := "http" - if s.https { - scheme = "https" + p := make([]byte, sz) + _, err = io.ReadFull(conn, p) + if err != nil { + return } - a.Url = fmt.Sprintf("%s://%s", scheme, s.apiAddr) - b, err = proto.Marshal(a) + c := &Command{} + err = proto.Unmarshal(p, c) if err != nil { conn.Close() } - conn.Write(b) - stats.Add(numGetNodeAPIResponse, 1) - case Command_COMMAND_TYPE_EXECUTE: - stats.Add(numExecuteRequest, 1) + switch c.Type { + case Command_COMMAND_TYPE_GET_NODE_API_URL: + stats.Add(numGetNodeAPIRequest, 1) - resp := &CommandExecuteResponse{} + s.mu.RLock() + a := &Address{} + scheme := "http" + if s.https { + scheme = "https" + } + a.Url = fmt.Sprintf("%s://%s", scheme, s.apiAddr) + s.mu.RUnlock() - er := c.GetExecuteRequest() - if er == nil { - resp.Error = "ExecuteRequest is nil" - } else { - res, err := s.db.Execute(er) + p, err = proto.Marshal(a) if err != nil { - resp.Error = err.Error() + conn.Close() + } + // Write length of Protobuf first, then write the actual Protobuf. 
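+ // The connection now stays open for further requests, so the client relies
+ // on this length header, not EOF, to find the end of the response.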
+ b = make([]byte, 4) + binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) + conn.Write(b) + conn.Write(p) + stats.Add(numGetNodeAPIResponse, 1) + + case Command_COMMAND_TYPE_EXECUTE: + stats.Add(numExecuteRequest, 1) + + resp := &CommandExecuteResponse{} + + er := c.GetExecuteRequest() + if er == nil { + resp.Error = "ExecuteRequest is nil" } else { - resp.Results = make([]*command.ExecuteResult, len(res)) - for i := range res { - resp.Results[i] = res[i] + res, err := s.db.Execute(er) + if err != nil { + resp.Error = err.Error() + } else { + resp.Results = make([]*command.ExecuteResult, len(res)) + for i := range res { + resp.Results[i] = res[i] + } } } - } - b, err = proto.Marshal(resp) - if err != nil { - return - } - conn.Write(b) + p, err := proto.Marshal(resp) + if err != nil { + return + } + // Write length of Protobuf first, then write the actual Protobuf. + b = make([]byte, 4) + binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) + conn.Write(b) + conn.Write(p) - case Command_COMMAND_TYPE_QUERY: - stats.Add(numQueryRequest, 1) + case Command_COMMAND_TYPE_QUERY: + stats.Add(numQueryRequest, 1) - resp := &CommandQueryResponse{} + resp := &CommandQueryResponse{} - qr := c.GetQueryRequest() - if qr == nil { - resp.Error = "QueryRequest is nil" - } else { - res, err := s.db.Query(qr) - if err != nil { - resp.Error = err.Error() + qr := c.GetQueryRequest() + if qr == nil { + resp.Error = "QueryRequest is nil" } else { - resp.Rows = make([]*command.QueryRows, len(res)) - for i := range res { - resp.Rows[i] = res[i] + res, err := s.db.Query(qr) + if err != nil { + resp.Error = err.Error() + } else { + resp.Rows = make([]*command.QueryRows, len(res)) + for i := range res { + resp.Rows[i] = res[i] + } } } - } - b, err = proto.Marshal(resp) - if err != nil { - return + p, err = proto.Marshal(resp) + if err != nil { + return + } + // Write length of Protobuf first, then write the actual Protobuf. + b = make([]byte, 4) + binary.LittleEndian.PutUint16(b[0:], uint16(len(p))) + conn.Write(b) + conn.Write(p) } - conn.Write(b) } } From ac1166fdcaa6069ed862e04b60b3764ee8bdbbb6 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 21 Aug 2021 17:30:27 -0400 Subject: [PATCH 09/14] Add stats to connection pool --- tcp/pool/channel.go | 21 ++++++++++++++++++++- tcp/pool/pool.go | 3 +++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/tcp/pool/channel.go b/tcp/pool/channel.go index 36bb20d6..9087a567 100644 --- a/tcp/pool/channel.go +++ b/tcp/pool/channel.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "sync" + "sync/atomic" ) // channelPool implements the Pool interface based on buffered channels. @@ -14,7 +15,8 @@ type channelPool struct { conns chan net.Conn // net.Conn generator - factory Factory + factory Factory + nOpenConns int64 } // Factory is a function to create new connections. @@ -81,6 +83,7 @@ func (c *channelPool) Get() (net.Conn, error) { if err != nil { return nil, err } + atomic.AddInt64(&c.nOpenConns, 1) return c.wrapConn(conn), nil } @@ -98,6 +101,7 @@ func (c *channelPool) put(conn net.Conn) error { if c.conns == nil { // pool is closed, close passed connection + atomic.AddInt64(&c.nOpenConns, -1) return conn.Close() } @@ -108,10 +112,13 @@ func (c *channelPool) put(conn net.Conn) error { return nil default: // pool is full, close passed connection + atomic.AddInt64(&c.nOpenConns, -1) + return conn.Close() } } +// Close closes every connection in the pool. 
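+// After Close the pool is unusable; subsequent calls to Get return ErrClosed.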
func (c *channelPool) Close() { c.mu.Lock() conns := c.conns @@ -127,9 +134,21 @@ func (c *channelPool) Close() { for conn := range conns { conn.Close() } + atomic.AddInt64(&c.nOpenConns, 0) } +// Len() returns the number of idle connections. func (c *channelPool) Len() int { conns, _ := c.getConnsAndFactory() return len(conns) } + +// Stats returns stats for the pool. +func (c *channelPool) Stats() (map[string]interface{}, error) { + conns, _ := c.getConnsAndFactory() + return map[string]interface{}{ + "idle": len(conns), + "open_connections": c.nOpenConns, + "max_open_connections": cap(conns), + }, nil +} diff --git a/tcp/pool/pool.go b/tcp/pool/pool.go index f88f2acb..7c66c370 100644 --- a/tcp/pool/pool.go +++ b/tcp/pool/pool.go @@ -25,4 +25,7 @@ type Pool interface { // Len returns the current number of connections of the pool. Len() int + + // Stats returns stats about the pool. + Stats() (map[string]interface{}, error) } From 5613edd887024fb2c5cda13066ae5a13754d5929 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 21 Aug 2021 17:55:09 -0400 Subject: [PATCH 10/14] Fix connection pool stats Also HTTP should read its own Raft address locally. --- cluster/client.go | 4 ++-- http/service.go | 20 +++++++++++++++----- tcp/pool/channel.go | 1 + 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/cluster/client.go b/cluster/client.go index 2eead814..a8dfbda8 100644 --- a/cluster/client.go +++ b/cluster/client.go @@ -254,8 +254,8 @@ func (c *Client) Stats() (map[string]interface{}, error) { poolStats[k] = s } return map[string]interface{}{ - "timeout": c.timeout, - "pool": poolStats, + "timeout": c.timeout, + "conn_pool_stats": poolStats, }, nil } diff --git a/http/service.go b/http/service.go index 2110f694..096bc1e6 100644 --- a/http/service.go +++ b/http/service.go @@ -1002,11 +1002,21 @@ func (s *Service) checkNodesAPIAddr(nodes []*store.Server, timeout time.Duration wg.Add(1) go func(id, raftAddr string) { defer wg.Done() - apiAddr, err := s.cluster.GetNodeAPIAddr(raftAddr) - if err == nil { - mu.Lock() - apiAddrs[id] = apiAddr - mu.Unlock() + + localRaftAddr, err := s.store.LeaderAddr() + if err != nil { + return + } + + if raftAddr == localRaftAddr { + apiAddrs[id] = localRaftAddr + } else { + apiAddr, err := s.cluster.GetNodeAPIAddr(raftAddr) + if err == nil { + mu.Lock() + apiAddrs[id] = apiAddr + mu.Unlock() + } } }(n.ID, n.Addr) } diff --git a/tcp/pool/channel.go b/tcp/pool/channel.go index 9087a567..fd525d1b 100644 --- a/tcp/pool/channel.go +++ b/tcp/pool/channel.go @@ -46,6 +46,7 @@ func NewChannelPool(initialCap, maxCap int, factory Factory) (Pool, error) { c.Close() return nil, fmt.Errorf("factory is not able to fill the pool: %s", err) } + atomic.AddInt64(&c.nOpenConns, 1) c.conns <- conn } From 132c1809feae48cc527c54d2916d71cae8e7d55f Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 21 Aug 2021 17:59:45 -0400 Subject: [PATCH 11/14] Fix race condition --- http/service.go | 2 ++ 1 file changed, 2 insertions(+) diff --git a/http/service.go b/http/service.go index 096bc1e6..1a4eb79f 100644 --- a/http/service.go +++ b/http/service.go @@ -1009,7 +1009,9 @@ func (s *Service) checkNodesAPIAddr(nodes []*store.Server, timeout time.Duration } if raftAddr == localRaftAddr { + mu.Lock() apiAddrs[id] = localRaftAddr + mu.Unlock() } else { apiAddr, err := s.cluster.GetNodeAPIAddr(raftAddr) if err == nil { From 88b80cffa621a62620dc1a5f57184336caf0462e Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sun, 22 Aug 2021 09:25:08 -0400 Subject: [PATCH 12/14] Clearer 
system-level test --- system_test/single_node_test.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/system_test/single_node_test.go b/system_test/single_node_test.go index bbd7a5ff..3ca849eb 100644 --- a/system_test/single_node_test.go +++ b/system_test/single_node_test.go @@ -434,8 +434,8 @@ func Test_SingleNodeNodes(t *testing.T) { if n.Addr != node.RaftAddr { t.Fatalf("node has wrong Raft address") } - if n.APIAddr != fmt.Sprintf("http://%s", node.APIAddr) { - t.Fatalf("node has wrong API address") + if got, exp := n.APIAddr, fmt.Sprintf("http://%s", node.APIAddr); exp != got { + t.Fatalf("node has wrong API address, exp: %s got: %s", exp, got) } if !n.Leader { t.Fatalf("node is not leader") From d90ea2e6b1082473b94055dbbd904a1186dd9fd2 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sun, 22 Aug 2021 10:12:22 -0400 Subject: [PATCH 13/14] Service Node API requests locally if possible --- cluster/client.go | 23 +++++++++++++++++++++++ cluster/service.go | 4 ++++ cluster/service_test.go | 35 +++++++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+) diff --git a/cluster/client.go b/cluster/client.go index a8dfbda8..a5d78f90 100644 --- a/cluster/client.go +++ b/cluster/client.go @@ -24,6 +24,10 @@ type Client struct { dialer Dialer timeout time.Duration + lMu sync.RWMutex + localNodeAddr string + localServ *Service + mu sync.RWMutex pools map[string]pool.Pool } @@ -37,8 +41,27 @@ func NewClient(dl Dialer) *Client { } } +// SetLocal informs the client instance of the node address for +// the node using this client. Along with the Service instance +// it allows this client to serve requests for this node locally +// without the network hop. +func (c *Client) SetLocal(nodeAddr string, serv *Service) { + c.lMu.Lock() + defer c.lMu.Unlock() + c.localNodeAddr = nodeAddr + c.localServ = serv +} + // GetNodeAPIAddr retrieves the API Address for the node at nodeAddr func (c *Client) GetNodeAPIAddr(nodeAddr string) (string, error) { + c.lMu.RLock() + defer c.lMu.RUnlock() + if c.localNodeAddr == nodeAddr && c.localServ != nil { + // Serve it locally! + stats.Add(numGetNodeAPIRequestLocal, 1) + return c.localServ.GetNodeAPIURL(), nil + } + conn, err := c.dial(nodeAddr, c.timeout) if err != nil { return "", err diff --git a/cluster/service.go b/cluster/service.go index 4ac6f82c..807f7b68 100644 --- a/cluster/service.go +++ b/cluster/service.go @@ -24,6 +24,9 @@ const ( numGetNodeAPIResponse = "num_get_node_api_resp" numExecuteRequest = "num_execute_req" numQueryRequest = "num_query_req" + + // Client stats for this package. + numGetNodeAPIRequestLocal = "num_get_node_api_req_local" ) const ( @@ -40,6 +43,7 @@ func init() { stats.Add(numGetNodeAPIResponse, 0) stats.Add(numExecuteRequest, 0) stats.Add(numQueryRequest, 0) + stats.Add(numGetNodeAPIRequestLocal, 0) } // Dialer is the interface dialers must implement. diff --git a/cluster/service_test.go b/cluster/service_test.go index 3ae2e1ee..e4bc5abe 100644 --- a/cluster/service_test.go +++ b/cluster/service_test.go @@ -95,6 +95,41 @@ func Test_NewServiceSetGetNodeAPIAddr(t *testing.T) { } } +func Test_NewServiceSetGetNodeAPIAddrLocal(t *testing.T) { + ml := mustNewMockTransport() + s := New(ml, mustNewMockDatabase()) + if s == nil { + t.Fatalf("failed to create cluster service") + } + + if err := s.Open(); err != nil { + t.Fatalf("failed to open cluster service") + } + + s.SetAPIAddr("foo") + + // Check stats to confirm no local request yet. 
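+ // SetLocal has not been called on any client yet, so the local-serve
+ // counter should still be zero.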
+ if stats.Get(numGetNodeAPIRequestLocal).String() != "0" { + t.Fatalf("failed to confirm request served locally") + } + + // Test by enabling local answering + c := NewClient(ml) + c.SetLocal(s.Addr(), s) + addr, err := c.GetNodeAPIAddr(s.Addr()) + if err != nil { + t.Fatalf("failed to get node API address locally: %s", err) + } + if addr != "http://foo" { + t.Fatalf("failed to get correct node API address locally, exp %s, got %s", "http://foo", addr) + } + + // Check stats to confirm local response. + if stats.Get(numGetNodeAPIRequestLocal).String() != "1" { + t.Fatalf("failed to confirm request served locally") + } +} + func Test_NewServiceSetGetNodeAPIAddrTLS(t *testing.T) { ml := mustNewMockTLSTransport() s := New(ml, mustNewMockDatabase()) From f1a7e7e8ecc127463346cae91e755ae83b4c0590 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sun, 22 Aug 2021 10:19:49 -0400 Subject: [PATCH 14/14] Use locally-enabled cluster client in HTTP service --- cmd/rqlited/main.go | 1 + http/service.go | 18 +++--------------- 2 files changed, 4 insertions(+), 15 deletions(-) diff --git a/cmd/rqlited/main.go b/cmd/rqlited/main.go index ed4bd552..c38ccc72 100644 --- a/cmd/rqlited/main.go +++ b/cmd/rqlited/main.go @@ -307,6 +307,7 @@ func main() { // Start the HTTP API server. clstrDialer := tcp.NewDialer(cluster.MuxClusterHeader, nodeEncrypt, noNodeVerify) clstrClient := cluster.NewClient(clstrDialer) + clstrClient.SetLocal(raftAdv, clstr) if err := startHTTPService(str, clstrClient); err != nil { log.Fatalf("failed to start HTTP server: %s", err.Error()) } diff --git a/http/service.go b/http/service.go index 1a4eb79f..2110f694 100644 --- a/http/service.go +++ b/http/service.go @@ -1002,23 +1002,11 @@ func (s *Service) checkNodesAPIAddr(nodes []*store.Server, timeout time.Duration wg.Add(1) go func(id, raftAddr string) { defer wg.Done() - - localRaftAddr, err := s.store.LeaderAddr() - if err != nil { - return - } - - if raftAddr == localRaftAddr { + apiAddr, err := s.cluster.GetNodeAPIAddr(raftAddr) + if err == nil { mu.Lock() - apiAddrs[id] = localRaftAddr + apiAddrs[id] = apiAddr mu.Unlock() - } else { - apiAddr, err := s.cluster.GetNodeAPIAddr(raftAddr) - if err == nil { - mu.Lock() - apiAddrs[id] = apiAddr - mu.Unlock() - } } }(n.ID, n.Addr) }
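
The following is a minimal usage sketch of the connection pool introduced in this series. The node address and capacities are illustrative only (rqlite itself builds one pool per node address inside cluster.Client.dial, using initialPoolSize and maxPoolCapacity), and error handling is abbreviated.

    package main

    import (
    	"log"
    	"net"

    	"github.com/rqlite/rqlite/tcp/pool"
    )

    func main() {
    	// Factory dials a single (hypothetical) cluster-service address.
    	factory := func() (net.Conn, error) { return net.Dial("tcp", "10.0.0.1:4002") }

    	// Create 4 connections up front, and keep at most 64 idle connections.
    	p, err := pool.NewChannelPool(4, 64, factory)
    	if err != nil {
    		log.Fatal(err)
    	}
    	defer p.Close()

    	// Get returns a *pool.PoolConn wrapped as a net.Conn.
    	conn, err := p.Get()
    	if err != nil {
    		log.Fatal(err)
    	}
    	if _, err := conn.Write([]byte("hello")); err != nil {
    		// A broken connection must not go back into the pool.
    		if pc, ok := conn.(*pool.PoolConn); ok {
    			pc.MarkUnusable()
    		}
    	}
    	conn.Close() // returns the connection to the pool, or closes it if unusable
    }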