
Merge pull request #864 from rqlite/conn-pool-2

Use a connection pool for internode communications
master
Philip O'Toole 3 years ago committed by GitHub
commit 9a8b463838

@ -4,17 +4,32 @@ import (
"encoding/binary"
"errors"
"fmt"
"io/ioutil"
"io"
"net"
"sync"
"time"
"github.com/golang/protobuf/proto"
"github.com/rqlite/rqlite/command"
"github.com/rqlite/rqlite/tcp/pool"
)
const (
initialPoolSize = 4
maxPoolCapacity = 64
)
// Client allows communicating with a remote node.
type Client struct {
dialer Dialer
timeout time.Duration
lMu sync.RWMutex
localNodeAddr string
localServ *Service
mu sync.RWMutex
pools map[string]pool.Pool
}
// NewClient returns a client instance for talking to a remote node.
@ -22,14 +37,34 @@ func NewClient(dl Dialer) *Client {
return &Client{
dialer: dl,
timeout: 30 * time.Second,
pools: make(map[string]pool.Pool),
}
}
// SetLocal informs the client instance of the node address for
// the node using this client. Along with the Service instance
// it allows this client to serve requests for this node locally
// without the network hop.
func (c *Client) SetLocal(nodeAddr string, serv *Service) {
c.lMu.Lock()
defer c.lMu.Unlock()
c.localNodeAddr = nodeAddr
c.localServ = serv
}
// GetNodeAPIAddr retrieves the API Address for the node at nodeAddr
func (c *Client) GetNodeAPIAddr(nodeAddr string) (string, error) {
conn, err := c.dialer.Dial(nodeAddr, c.timeout)
c.lMu.RLock()
defer c.lMu.RUnlock()
if c.localNodeAddr == nodeAddr && c.localServ != nil {
// Serve it locally!
stats.Add(numGetNodeAPIRequestLocal, 1)
return c.localServ.GetNodeAPIURL(), nil
}
conn, err := c.dial(nodeAddr, c.timeout)
if err != nil {
return "", fmt.Errorf("dial connection: %s", err)
return "", err
}
defer conn.Close()
@ -48,20 +83,31 @@ func (c *Client) GetNodeAPIAddr(nodeAddr string) (string, error) {
_, err = conn.Write(b)
if err != nil {
handleConnError(conn)
return "", fmt.Errorf("write protobuf length: %s", err)
}
_, err = conn.Write(p)
if err != nil {
handleConnError(conn)
return "", fmt.Errorf("write protobuf: %s", err)
}
b, err = ioutil.ReadAll(conn)
// Read length of response.
_, err = io.ReadFull(conn, b)
if err != nil {
return "", err
}
sz := binary.LittleEndian.Uint16(b[0:])
// Read in the actual response.
p = make([]byte, sz)
_, err = io.ReadFull(conn, p)
if err != nil {
return "", fmt.Errorf("read protobuf bytes: %s", err)
return "", err
}
a := &Address{}
err = proto.Unmarshal(b, a)
err = proto.Unmarshal(p, a)
if err != nil {
return "", fmt.Errorf("protobuf unmarshal: %s", err)
}
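The read side above replaces ioutil.ReadAll, which only completes once the peer closes the connection (something a pooled, reused connection never does), with an explicit frame: a 4-byte header carrying the payload length as a little-endian uint16, followed by the Protobuf bytes. A minimal sketch of that read path as a helper (readFramed is an illustrative name, not part of this commit; it relies on the io, net and encoding/binary imports already present in this file):

func readFramed(conn net.Conn) ([]byte, error) {
	// Read the 4-byte length header.
	hdr := make([]byte, 4)
	if _, err := io.ReadFull(conn, hdr); err != nil {
		return nil, err
	}
	sz := binary.LittleEndian.Uint16(hdr[0:])
	// Read exactly sz bytes of Protobuf payload.
	p := make([]byte, sz)
	if _, err := io.ReadFull(conn, p); err != nil {
		return nil, err
	}
	return p, nil
}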
@ -71,9 +117,9 @@ func (c *Client) GetNodeAPIAddr(nodeAddr string) (string, error) {
// Execute performs an Execute on a remote node.
func (c *Client) Execute(er *command.ExecuteRequest, nodeAddr string, timeout time.Duration) ([]*command.ExecuteResult, error) {
conn, err := c.dialer.Dial(nodeAddr, c.timeout)
conn, err := c.dial(nodeAddr, c.timeout)
if err != nil {
return nil, fmt.Errorf("dial connection: %s", err)
return nil, err
}
defer conn.Close()
@ -94,30 +140,45 @@ func (c *Client) Execute(er *command.ExecuteRequest, nodeAddr string, timeout ti
binary.LittleEndian.PutUint16(b[0:], uint16(len(p)))
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
handleConnError(conn)
return nil, err
}
_, err = conn.Write(b)
if err != nil {
handleConnError(conn)
return nil, err
}
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
handleConnError(conn)
return nil, err
}
_, err = conn.Write(p)
if err != nil {
handleConnError(conn)
return nil, err
}
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
handleConnError(conn)
return nil, err
}
b, err = ioutil.ReadAll(conn)
// Read length of response.
_, err = io.ReadFull(conn, b)
if err != nil {
return nil, err
}
sz := binary.LittleEndian.Uint16(b[0:])
// Read in the actual response.
p = make([]byte, sz)
_, err = io.ReadFull(conn, p)
if err != nil {
return nil, err
}
a := &CommandExecuteResponse{}
err = proto.Unmarshal(b, a)
err = proto.Unmarshal(p, a)
if err != nil {
return nil, err
}
@ -130,9 +191,9 @@ func (c *Client) Execute(er *command.ExecuteRequest, nodeAddr string, timeout ti
// Query performs a Query on a remote node.
func (c *Client) Query(qr *command.QueryRequest, nodeAddr string, timeout time.Duration) ([]*command.QueryRows, error) {
conn, err := c.dialer.Dial(nodeAddr, c.timeout)
conn, err := c.dial(nodeAddr, c.timeout)
if err != nil {
return nil, fmt.Errorf("dial connection: %s", err)
return nil, err
}
defer conn.Close()
@ -148,35 +209,50 @@ func (c *Client) Query(qr *command.QueryRequest, nodeAddr string, timeout time.D
return nil, fmt.Errorf("command marshal: %s", err)
}
// Write length of Protobuf, the Protobuf
// Write length of Protobuf, then the Protobuf
b := make([]byte, 4)
binary.LittleEndian.PutUint16(b[0:], uint16(len(p)))
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
handleConnError(conn)
return nil, err
}
_, err = conn.Write(b)
if err != nil {
handleConnError(conn)
return nil, err
}
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
handleConnError(conn)
return nil, err
}
_, err = conn.Write(p)
if err != nil {
handleConnError(conn)
return nil, err
}
if err := conn.SetDeadline(time.Now().Add(timeout)); err != nil {
handleConnError(conn)
return nil, err
}
b, err = ioutil.ReadAll(conn)
// Read length of response.
_, err = io.ReadFull(conn, b)
if err != nil {
return nil, err
}
sz := binary.LittleEndian.Uint16(b[0:])
// Read in the actual response.
p = make([]byte, sz)
_, err = io.ReadFull(conn, p)
if err != nil {
return nil, err
}
a := &CommandQueryResponse{}
err = proto.Unmarshal(b, a)
err = proto.Unmarshal(p, a)
if err != nil {
return nil, err
}
@ -189,7 +265,65 @@ func (c *Client) Query(qr *command.QueryRequest, nodeAddr string, timeout time.D
// Stats returns stats on the Client instance
func (c *Client) Stats() (map[string]interface{}, error) {
c.mu.RLock()
defer c.mu.RUnlock()
poolStats := make(map[string]interface{}, len(c.pools))
for k, v := range c.pools {
s, err := v.Stats()
if err != nil {
return nil, err
}
poolStats[k] = s
}
return map[string]interface{}{
"timeout": c.timeout,
"conn_pool_stats": poolStats,
}, nil
}
func (c *Client) dial(nodeAddr string, timeout time.Duration) (net.Conn, error) {
var pl pool.Pool
var ok bool
c.mu.RLock()
pl, ok = c.pools[nodeAddr]
c.mu.RUnlock()
// Do we need a new pool for the given address?
if !ok {
if err := func() error {
c.mu.Lock()
defer c.mu.Unlock()
pl, ok = c.pools[nodeAddr]
if ok {
return nil // Pool was inserted just after we checked.
}
// New pool is needed for given address.
factory := func() (net.Conn, error) { return c.dialer.Dial(nodeAddr, c.timeout) }
p, err := pool.NewChannelPool(initialPoolSize, maxPoolCapacity, factory)
if err != nil {
return err
}
c.pools[nodeAddr] = p
pl = p
return nil
}(); err != nil {
return nil, err
}
}
// Got pool, now get a connection.
conn, err := pl.Get()
if err != nil {
return nil, fmt.Errorf("pool get: %s", err)
}
return conn, nil
}
func handleConnError(conn net.Conn) {
if pc, ok := conn.(*pool.PoolConn); ok {
pc.MarkUnusable()
}
}
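A note on the pattern above: pl.Get() hands back a *pool.PoolConn whose Close() normally returns the connection to the pool for reuse (see conn.go below), so handleConnError marks a connection unusable before the deferred Close runs, ensuring a socket that failed mid-protocol is discarded rather than recycled. A minimal sketch of the calling convention, assuming a hypothetical send method and payload argument:

func (c *Client) send(nodeAddr string, payload []byte) error {
	conn, err := c.dial(nodeAddr, c.timeout)
	if err != nil {
		return err
	}
	defer conn.Close() // returns the conn to the pool, unless marked unusable
	if _, err := conn.Write(payload); err != nil {
		handleConnError(conn) // MarkUnusable: the deferred Close now really closes it
		return err
	}
	return nil
}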

@ -24,6 +24,9 @@ const (
numGetNodeAPIResponse = "num_get_node_api_resp"
numExecuteRequest = "num_execute_req"
numQueryRequest = "num_query_req"
// Client stats for this package.
numGetNodeAPIRequestLocal = "num_get_node_api_req_local"
)
const (
@ -40,6 +43,7 @@ func init() {
stats.Add(numGetNodeAPIResponse, 0)
stats.Add(numExecuteRequest, 0)
stats.Add(numQueryRequest, 0)
stats.Add(numGetNodeAPIRequestLocal, 0)
}
// Dialer is the interface dialers must implement.
@ -166,6 +170,7 @@ func (s *Service) serve() error {
func (s *Service) handleConn(conn net.Conn) {
defer conn.Close()
for {
b := make([]byte, 4)
_, err := io.ReadFull(conn, b)
if err != nil {
@ -173,14 +178,14 @@ func (s *Service) handleConn(conn net.Conn) {
}
sz := binary.LittleEndian.Uint16(b[0:])
b = make([]byte, sz)
_, err = io.ReadFull(conn, b)
p := make([]byte, sz)
_, err = io.ReadFull(conn, p)
if err != nil {
return
}
c := &Command{}
err = proto.Unmarshal(b, c)
err = proto.Unmarshal(p, c)
if err != nil {
conn.Close()
}
@ -188,13 +193,18 @@ func (s *Service) handleConn(conn net.Conn) {
switch c.Type {
case Command_COMMAND_TYPE_GET_NODE_API_URL:
stats.Add(numGetNodeAPIRequest, 1)
b, err = proto.Marshal(&Address{
p, err = proto.Marshal(&Address{
Url: s.GetNodeAPIURL(),
})
if err != nil {
conn.Close()
}
// Write length of Protobuf first, then write the actual Protobuf.
b = make([]byte, 4)
binary.LittleEndian.PutUint16(b[0:], uint16(len(p)))
conn.Write(b)
conn.Write(p)
stats.Add(numGetNodeAPIResponse, 1)
case Command_COMMAND_TYPE_EXECUTE:
@ -217,11 +227,15 @@ func (s *Service) handleConn(conn net.Conn) {
}
}
b, err = proto.Marshal(resp)
p, err := proto.Marshal(resp)
if err != nil {
return
}
// Write length of Protobuf first, then write the actual Protobuf.
b = make([]byte, 4)
binary.LittleEndian.PutUint16(b[0:], uint16(len(p)))
conn.Write(b)
conn.Write(p)
case Command_COMMAND_TYPE_QUERY:
stats.Add(numQueryRequest, 1)
@ -243,10 +257,15 @@ func (s *Service) handleConn(conn net.Conn) {
}
}
b, err = proto.Marshal(resp)
p, err = proto.Marshal(resp)
if err != nil {
return
}
// Write length of Protobuf first, then write the actual Protobuf.
b = make([]byte, 4)
binary.LittleEndian.PutUint16(b[0:], uint16(len(p)))
conn.Write(b)
conn.Write(p)
}
}
}
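The response writes above mirror the framing the client now expects: a 4-byte header holding the payload length as a little-endian uint16, then the Protobuf bytes. Sketched as a helper for clarity (writeFramed is an illustrative name, not in the commit):

func writeFramed(conn net.Conn, p []byte) error {
	// Write the 4-byte length header first.
	b := make([]byte, 4)
	binary.LittleEndian.PutUint16(b[0:], uint16(len(p)))
	if _, err := conn.Write(b); err != nil {
		return err
	}
	// Then write the Protobuf payload itself.
	_, err := conn.Write(p)
	return err
}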

@ -95,6 +95,41 @@ func Test_NewServiceSetGetNodeAPIAddr(t *testing.T) {
}
}
func Test_NewServiceSetGetNodeAPIAddrLocal(t *testing.T) {
ml := mustNewMockTransport()
s := New(ml, mustNewMockDatabase())
if s == nil {
t.Fatalf("failed to create cluster service")
}
if err := s.Open(); err != nil {
t.Fatalf("failed to open cluster service")
}
s.SetAPIAddr("foo")
// Check stats to confirm no local request yet.
if stats.Get(numGetNodeAPIRequestLocal).String() != "0" {
t.Fatalf("failed to confirm request served locally")
}
// Test by enabling local answering
c := NewClient(ml)
c.SetLocal(s.Addr(), s)
addr, err := c.GetNodeAPIAddr(s.Addr())
if err != nil {
t.Fatalf("failed to get node API address locally: %s", err)
}
if addr != "http://foo" {
t.Fatalf("failed to get correct node API address locally, exp %s, got %s", "http://foo", addr)
}
// Check stats to confirm local response.
if stats.Get(numGetNodeAPIRequestLocal).String() != "1" {
t.Fatalf("failed to confirm request served locally")
}
}
func Test_NewServiceSetGetNodeAPIAddrTLS(t *testing.T) {
ml := mustNewMockTLSTransport()
s := New(ml, mustNewMockDatabase())

@ -307,6 +307,7 @@ func main() {
// Start the HTTP API server.
clstrDialer := tcp.NewDialer(cluster.MuxClusterHeader, nodeEncrypt, noNodeVerify)
clstrClient := cluster.NewClient(clstrDialer)
clstrClient.SetLocal(raftAdv, clstr)
if err := startHTTPService(str, clstrClient); err != nil {
log.Fatalf("failed to start HTTP server: %s", err.Error())
}
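With SetLocal wired as above, asking the client for this node's own API address is answered in-process by clstr rather than over a pooled connection. A brief illustration, assuming the clstrClient and raftAdv variables from the surrounding code:

// Served locally: no dial, and numGetNodeAPIRequestLocal is incremented.
apiURL, err := clstrClient.GetNodeAPIAddr(raftAdv)
if err != nil {
	log.Fatalf("failed to get local API address: %s", err.Error())
}
_ = apiURL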

@ -26,6 +26,11 @@ import (
"github.com/rqlite/rqlite/store"
)
var (
// ErrLeaderNotFound is returned when a node cannot locate a leader
ErrLeaderNotFound = errors.New("leader not found")
)
// Database is the interface any queryable system must implement
type Database interface {
// Execute executes a slice of queries, each of which is not expected
@ -126,6 +131,7 @@ type Response struct {
var stats *expvar.Map
const (
numLeaderNotFound = "leader_not_found"
numExecutions = "executions"
numQueries = "queries"
numRemoteExecutions = "remote_executions"
@ -167,6 +173,7 @@ const (
func init() {
stats = expvar.NewMap("http")
stats.Add(numLeaderNotFound, 0)
stats.Add(numExecutions, 0)
stats.Add(numQueries, 0)
stats.Add(numRemoteExecutions, 0)
@ -375,7 +382,8 @@ func (s *Service) handleJoin(w http.ResponseWriter, r *http.Request) {
if err == store.ErrNotLeader {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
@ -427,7 +435,8 @@ func (s *Service) handleRemove(w http.ResponseWriter, r *http.Request) {
if err == store.ErrNotLeader {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
@ -470,7 +479,8 @@ func (s *Service) handleBackup(w http.ResponseWriter, r *http.Request) {
if err == store.ErrNotLeader {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
@ -522,7 +532,8 @@ func (s *Service) handleLoad(w http.ResponseWriter, r *http.Request) {
if err == store.ErrNotLeader {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
@ -767,7 +778,8 @@ func (s *Service) handleExecute(w http.ResponseWriter, r *http.Request) {
if redirect {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
loc := s.FormRedirect(r, leaderAPIAddr)
@ -779,6 +791,10 @@ func (s *Service) handleExecute(w http.ResponseWriter, r *http.Request) {
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
if addr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
}
results, resultsErr = s.cluster.Execute(er, addr, timeout)
stats.Add(numRemoteExecutions, 1)
w.Header().Add(ServedByHTTPHeader, addr)
@ -849,7 +865,8 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) {
if redirect {
leaderAPIAddr := s.LeaderAPIAddr()
if leaderAPIAddr == "" {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
return
}
loc := s.FormRedirect(r, leaderAPIAddr)
@ -861,6 +878,10 @@ func (s *Service) handleQuery(w http.ResponseWriter, r *http.Request) {
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
if addr == "" {
stats.Add(numLeaderNotFound, 1)
http.Error(w, ErrLeaderNotFound.Error(), http.StatusServiceUnavailable)
}
results, resultsErr = s.cluster.Query(qr, addr, timeout)
stats.Add(numRemoteQueries, 1)
w.Header().Add(ServedByHTTPHeader, addr)

@ -677,7 +677,7 @@ func Test_Nodes(t *testing.T) {
}
}
func Test_ForwardingRedirectQueries(t *testing.T) {
func Test_ForwardingRedirectQuery(t *testing.T) {
m := &MockStore{
leaderAddr: "foo:1234",
}
@ -725,6 +725,16 @@ func Test_ForwardingRedirectQueries(t *testing.T) {
if resp.StatusCode != http.StatusMovedPermanently {
t.Fatalf("failed to get expected StatusMovedPermanently for query, got %d", resp.StatusCode)
}
// Check leader failure case.
m.leaderAddr = ""
resp, err = client.Get(host + "/db/query?pretty&timings&q=SELECT%20%2A%20FROM%20foo")
if err != nil {
t.Fatalf("failed to make query forwarded request")
}
if resp.StatusCode != http.StatusServiceUnavailable {
t.Fatalf("failed to get expected StatusServiceUnavailable for node with no leader, got %d", resp.StatusCode)
}
}
func Test_ForwardingRedirectExecute(t *testing.T) {
@ -775,6 +785,16 @@ func Test_ForwardingRedirectExecute(t *testing.T) {
if resp.StatusCode != http.StatusMovedPermanently {
t.Fatalf("failed to get expected StatusMovedPermanently for execute, got %d", resp.StatusCode)
}
// Check leader failure case.
m.leaderAddr = ""
resp, err = client.Post(host+"/db/execute", "application/json", strings.NewReader(`["Some SQL"]`))
if err != nil {
t.Fatalf("failed to make execute request")
}
if resp.StatusCode != http.StatusServiceUnavailable {
t.Fatalf("failed to get expected StatusServiceUnavailable for node with no leader, got %d", resp.StatusCode)
}
}
func Test_TLSServce(t *testing.T) {

@ -136,7 +136,7 @@ class Node(object):
r.raise_for_status()
return r.json()
def is_leader(self, constraint_check=True):
def is_leader(self):
'''
is_leader returns whether this node is the cluster leader
It also performs a check, to ensure the node never gives out
@ -144,28 +144,16 @@ class Node(object):
'''
try:
isLeaderRaft = self.status()['store']['raft']['state'] == 'Leader'
isLeaderNodes = self.nodes()[self.node_id]['leader'] is True
return self.status()['store']['raft']['state'] == 'Leader'
except requests.exceptions.ConnectionError:
return False
if (isLeaderRaft != isLeaderNodes) and constraint_check:
raise AssertionError("conflicting states reported for leadership (raft: %s, nodes: %s)"
% (isLeaderRaft, isLeaderNodes))
return isLeaderNodes
def is_follower(self):
try:
isFollowerRaft = self.status()['store']['raft']['state'] == 'Follower'
isFollowersNodes = self.nodes()[self.node_id]['leader'] is False
return self.status()['store']['raft']['state'] == 'Follower'
except requests.exceptions.ConnectionError:
return False
if isFollowerRaft != isFollowersNodes:
raise AssertionError("conflicting states reported for followership (raft: %s, nodes: %s)"
% (isFollowerRaft, isFollowersNodes))
return isFollowersNodes
def wait_for_leader(self, timeout=TIMEOUT):
lr = None
t = 0
@ -289,6 +277,7 @@ class Node(object):
def redirect_addr(self):
r = requests.post(self._execute_url(redirect=True), data=json.dumps(['nonsense']), allow_redirects=False)
r.raise_for_status()
if r.status_code == 301:
return "%s://%s" % (urlparse(r.headers['Location']).scheme, urlparse(r.headers['Location']).netloc)
@ -333,7 +322,7 @@ def deprovision_node(node):
class Cluster(object):
def __init__(self, nodes):
self.nodes = nodes
def wait_for_leader(self, node_exc=None, timeout=TIMEOUT, constraint_check=True):
def wait_for_leader(self, node_exc=None, timeout=TIMEOUT):
t = 0
while True:
if t > timeout:
@ -341,7 +330,7 @@ class Cluster(object):
for n in self.nodes:
if node_exc is not None and n == node_exc:
continue
if n.is_leader(constraint_check):
if n.is_leader():
return n
time.sleep(1)
t+=1
@ -522,6 +511,8 @@ class TestEndToEnd(unittest.TestCase):
for n in fs:
self.assertEqual(l.APIProtoAddr(), n.redirect_addr())
# Kill the leader, wait for a new leader, and check that the
# redirect returns the new leader address.
l.stop()
n = self.cluster.wait_for_leader(node_exc=l)
for f in self.cluster.followers():
@ -682,10 +673,9 @@ class TestEndToEndNonVoterFollowsLeader(unittest.TestCase):
j = n.query('SELECT * FROM foo')
self.assertEqual(str(j), "{u'results': [{u'values': [[1, u'fiona']], u'types': [u'integer', u'text'], u'columns': [u'id', u'name']}]}")
# Kill leader, and then make more changes. Don't perform leader-constraint checks
# since the cluster is changing right now.
n0 = self.cluster.wait_for_leader(constraint_check=False).stop()
n1 = self.cluster.wait_for_leader(node_exc=n0, constraint_check=False)
# Kill leader, and then make more changes.
n0 = self.cluster.wait_for_leader().stop()
n1 = self.cluster.wait_for_leader(node_exc=n0)
n1.wait_for_all_applied()
j = n1.query('SELECT * FROM foo')
self.assertEqual(str(j), "{u'results': [{u'values': [[1, u'fiona']], u'types': [u'integer', u'text'], u'columns': [u'id', u'name']}]}")

@ -434,8 +434,8 @@ func Test_SingleNodeNodes(t *testing.T) {
if n.Addr != node.RaftAddr {
t.Fatalf("node has wrong Raft address")
}
if n.APIAddr != fmt.Sprintf("http://%s", node.APIAddr) {
t.Fatalf("node has wrong API address")
if got, exp := n.APIAddr, fmt.Sprintf("http://%s", node.APIAddr); exp != got {
t.Fatalf("node has wrong API address, exp: %s got: %s", exp, got)
}
if !n.Leader {
t.Fatalf("node is not leader")

@ -0,0 +1,20 @@
The MIT License (MIT)
Copyright (c) 2013 Fatih Arslan
Permission is hereby granted, free of charge, to any person obtaining a copy of
this software and associated documentation files (the "Software"), to deal in
the Software without restriction, including without limitation the rights to
use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
the Software, and to permit persons to whom the Software is furnished to do so,
subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

@ -0,0 +1,155 @@
package pool
import (
"errors"
"fmt"
"net"
"sync"
"sync/atomic"
)
// channelPool implements the Pool interface based on buffered channels.
type channelPool struct {
// storage for our net.Conn connections
mu sync.RWMutex
conns chan net.Conn
// net.Conn generator
factory Factory
nOpenConns int64
}
// Factory is a function to create new connections.
type Factory func() (net.Conn, error)
// NewChannelPool returns a new pool based on buffered channels with an initial
// capacity and maximum capacity. Factory is used when initial capacity is
// greater than zero to fill the pool. A zero initialCap doesn't fill the Pool
until a new Get() is called. During a Get(), if there is no new connection
// available in the pool, a new connection will be created via the Factory()
// method.
func NewChannelPool(initialCap, maxCap int, factory Factory) (Pool, error) {
if initialCap < 0 || maxCap <= 0 || initialCap > maxCap {
return nil, errors.New("invalid capacity settings")
}
c := &channelPool{
conns: make(chan net.Conn, maxCap),
factory: factory,
}
// create initial connections, if something goes wrong,
// just close the pool and error out.
for i := 0; i < initialCap; i++ {
conn, err := factory()
if err != nil {
c.Close()
return nil, fmt.Errorf("factory is not able to fill the pool: %s", err)
}
atomic.AddInt64(&c.nOpenConns, 1)
c.conns <- conn
}
return c, nil
}
func (c *channelPool) getConnsAndFactory() (chan net.Conn, Factory) {
c.mu.RLock()
conns := c.conns
factory := c.factory
c.mu.RUnlock()
return conns, factory
}
// Get implements the Pool interface's Get() method. If there is no new
// connection available in the pool, a new connection will be created via the
// Factory() method.
func (c *channelPool) Get() (net.Conn, error) {
conns, factory := c.getConnsAndFactory()
if conns == nil {
return nil, ErrClosed
}
// wrap our connections with our custom net.Conn implementation (wrapConn
// method) that puts the connection back to the pool if it's closed.
select {
case conn := <-conns:
if conn == nil {
return nil, ErrClosed
}
return c.wrapConn(conn), nil
default:
conn, err := factory()
if err != nil {
return nil, err
}
atomic.AddInt64(&c.nOpenConns, 1)
return c.wrapConn(conn), nil
}
}
// put puts the connection back to the pool. If the pool is full or closed,
// conn is simply closed. A nil conn will be rejected.
func (c *channelPool) put(conn net.Conn) error {
if conn == nil {
return errors.New("connection is nil. rejecting")
}
c.mu.RLock()
defer c.mu.RUnlock()
if c.conns == nil {
// pool is closed, close passed connection
atomic.AddInt64(&c.nOpenConns, -1)
return conn.Close()
}
// put the resource back into the pool. If the pool is full, the send cannot
// proceed, so the default case runs and the connection is closed.
select {
case c.conns <- conn:
return nil
default:
// pool is full, close passed connection
atomic.AddInt64(&c.nOpenConns, -1)
return conn.Close()
}
}
// Close closes every connection in the pool.
func (c *channelPool) Close() {
c.mu.Lock()
conns := c.conns
c.conns = nil
c.factory = nil
c.mu.Unlock()
if conns == nil {
return
}
close(conns)
for conn := range conns {
conn.Close()
}
atomic.StoreInt64(&c.nOpenConns, 0) // reset the open-connection count
}
// Len() returns the number of idle connections.
func (c *channelPool) Len() int {
conns, _ := c.getConnsAndFactory()
return len(conns)
}
// Stats returns stats for the pool.
func (c *channelPool) Stats() (map[string]interface{}, error) {
conns, _ := c.getConnsAndFactory()
return map[string]interface{}{
"idle": len(conns),
"open_connections": c.nOpenConns,
"max_open_connections": cap(conns),
}, nil
}
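For orientation, a small usage sketch of this pool package; the address and capacities are illustrative, and net, log and pool are assumed to be imported:

factory := func() (net.Conn, error) { return net.Dial("tcp", "10.0.0.1:4002") }
p, err := pool.NewChannelPool(4, 64, factory)
if err != nil {
	log.Fatal(err)
}
conn, err := p.Get() // an idle conn if one is available, otherwise a fresh dial via factory
if err != nil {
	log.Fatal(err)
}
conn.Close() // a PoolConn: returns the conn to the pool instead of closing the socket
p.Close()    // closes the pool and every idle connection in it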

@ -0,0 +1,292 @@
package pool
import (
"log"
"math/rand"
"net"
"sync"
"testing"
"time"
)
var (
InitialCap = 5
MaximumCap = 30
network = "tcp"
address = "127.0.0.1:7777"
factory = func() (net.Conn, error) { return net.Dial(network, address) }
)
func init() {
// used for factory function
go simpleTCPServer()
time.Sleep(time.Millisecond * 300) // wait until tcp server has been settled
rand.Seed(time.Now().UTC().UnixNano())
}
func TestNew(t *testing.T) {
_, err := newChannelPool()
if err != nil {
t.Errorf("New error: %s", err)
}
}
func TestPool_Get_Impl(t *testing.T) {
p, _ := newChannelPool()
defer p.Close()
conn, err := p.Get()
if err != nil {
t.Errorf("Get error: %s", err)
}
_, ok := conn.(*PoolConn)
if !ok {
t.Errorf("Conn is not of type poolConn")
}
}
func TestPool_Get(t *testing.T) {
p, _ := newChannelPool()
defer p.Close()
_, err := p.Get()
if err != nil {
t.Errorf("Get error: %s", err)
}
// after one get, current capacity should be lowered by one.
if p.Len() != (InitialCap - 1) {
t.Errorf("Get error. Expecting %d, got %d",
(InitialCap - 1), p.Len())
}
// get them all
var wg sync.WaitGroup
for i := 0; i < (InitialCap - 1); i++ {
wg.Add(1)
go func() {
defer wg.Done()
_, err := p.Get()
if err != nil {
t.Errorf("Get error: %s", err)
}
}()
}
wg.Wait()
if p.Len() != 0 {
t.Errorf("Get error. Expecting %d, got %d",
(InitialCap - 1), p.Len())
}
_, err = p.Get()
if err != nil {
t.Errorf("Get error: %s", err)
}
}
func TestPool_Put(t *testing.T) {
p, err := NewChannelPool(0, 30, factory)
if err != nil {
t.Fatal(err)
}
defer p.Close()
// get/create from the pool
conns := make([]net.Conn, MaximumCap)
for i := 0; i < MaximumCap; i++ {
conn, _ := p.Get()
conns[i] = conn
}
// now put them all back
for _, conn := range conns {
conn.Close()
}
if p.Len() != MaximumCap {
t.Errorf("Put error len. Expecting %d, got %d",
1, p.Len())
}
conn, _ := p.Get()
p.Close() // close pool
conn.Close() // try to put into a full pool
if p.Len() != 0 {
t.Errorf("Put error. Closed pool shouldn't allow to put connections.")
}
}
func TestPool_PutUnusableConn(t *testing.T) {
p, _ := newChannelPool()
defer p.Close()
// ensure pool is not empty
conn, _ := p.Get()
conn.Close()
poolSize := p.Len()
conn, _ = p.Get()
conn.Close()
if p.Len() != poolSize {
t.Errorf("Pool size is expected to be equal to initial size")
}
conn, _ = p.Get()
if pc, ok := conn.(*PoolConn); !ok {
t.Errorf("impossible")
} else {
pc.MarkUnusable()
}
conn.Close()
if p.Len() != poolSize-1 {
t.Errorf("Pool size is expected to be initial_size - 1, %d, %d", p.Len(), poolSize-1)
}
}
func TestPool_UsedCapacity(t *testing.T) {
p, _ := newChannelPool()
defer p.Close()
if p.Len() != InitialCap {
t.Errorf("InitialCap error. Expecting %d, got %d",
InitialCap, p.Len())
}
}
func TestPool_Close(t *testing.T) {
p, _ := newChannelPool()
// now close it and test all cases we are expecting.
p.Close()
c := p.(*channelPool)
if c.conns != nil {
t.Errorf("Close error, conns channel should be nil")
}
if c.factory != nil {
t.Errorf("Close error, factory should be nil")
}
_, err := p.Get()
if err == nil {
t.Errorf("Close error, get conn should return an error")
}
if p.Len() != 0 {
t.Errorf("Close error used capacity. Expecting 0, got %d", p.Len())
}
}
func TestPoolConcurrent(t *testing.T) {
p, _ := newChannelPool()
pipe := make(chan net.Conn, 0)
go func() {
p.Close()
}()
for i := 0; i < MaximumCap; i++ {
go func() {
conn, _ := p.Get()
pipe <- conn
}()
go func() {
conn := <-pipe
if conn == nil {
return
}
conn.Close()
}()
}
}
func TestPoolWriteRead(t *testing.T) {
p, _ := NewChannelPool(0, 30, factory)
conn, _ := p.Get()
msg := "hello"
_, err := conn.Write([]byte(msg))
if err != nil {
t.Error(err)
}
}
func TestPoolConcurrent2(t *testing.T) {
p, _ := NewChannelPool(0, 30, factory)
var wg sync.WaitGroup
go func() {
for i := 0; i < 10; i++ {
wg.Add(1)
go func(i int) {
conn, _ := p.Get()
time.Sleep(time.Millisecond * time.Duration(rand.Intn(100)))
conn.Close()
wg.Done()
}(i)
}
}()
for i := 0; i < 10; i++ {
wg.Add(1)
go func(i int) {
conn, _ := p.Get()
time.Sleep(time.Millisecond * time.Duration(rand.Intn(100)))
conn.Close()
wg.Done()
}(i)
}
wg.Wait()
}
func TestPoolConcurrent3(t *testing.T) {
p, _ := NewChannelPool(0, 1, factory)
var wg sync.WaitGroup
wg.Add(1)
go func() {
p.Close()
wg.Done()
}()
if conn, err := p.Get(); err == nil {
conn.Close()
}
wg.Wait()
}
func newChannelPool() (Pool, error) {
return NewChannelPool(InitialCap, MaximumCap, factory)
}
func simpleTCPServer() {
l, err := net.Listen(network, address)
if err != nil {
log.Fatal(err)
}
defer l.Close()
for {
conn, err := l.Accept()
if err != nil {
log.Fatal(err)
}
go func() {
buffer := make([]byte, 256)
conn.Read(buffer)
}()
}
}

@ -0,0 +1,43 @@
package pool
import (
"net"
"sync"
)
// PoolConn is a wrapper around net.Conn to modify the behavior of
// net.Conn's Close() method.
type PoolConn struct {
net.Conn
mu sync.RWMutex
c *channelPool
unusable bool
}
// Close() puts the given connection back to the pool instead of closing it.
func (p *PoolConn) Close() error {
p.mu.RLock()
defer p.mu.RUnlock()
if p.unusable {
if p.Conn != nil {
return p.Conn.Close()
}
return nil
}
return p.c.put(p.Conn)
}
// MarkUnusable() marks the connection as no longer usable, so the pool closes it instead of returning it to the pool.
func (p *PoolConn) MarkUnusable() {
p.mu.Lock()
p.unusable = true
p.mu.Unlock()
}
// wrapConn wraps a standard net.Conn in a PoolConn.
func (c *channelPool) wrapConn(conn net.Conn) net.Conn {
p := &PoolConn{c: c}
p.Conn = conn
return p
}

@ -0,0 +1,10 @@
package pool
import (
"net"
"testing"
)
func TestConn_Impl(t *testing.T) {
var _ net.Conn = new(PoolConn)
}

@ -0,0 +1,31 @@
// Package pool implements a pool of net.Conn interfaces to manage and reuse them.
package pool
import (
"errors"
"net"
)
var (
// ErrClosed is the error resulting if the pool is closed via pool.Close().
ErrClosed = errors.New("pool is closed")
)
// Pool interface describes a pool implementation. A pool should have maximum
// capacity. An ideal pool is threadsafe and easy to use.
type Pool interface {
// Get returns a new connection from the pool. Closing the connection puts
// it back to the Pool. Closing it when the pool is destroyed or full will
// be counted as an error.
Get() (net.Conn, error)
// Close closes the pool and all its connections. After Close() the pool is
// no longer usable.
Close()
// Len returns the current number of connections of the pool.
Len() int
// Stats returns stats about the pool.
Stats() (map[string]interface{}, error)
}
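A conventional compile-time assertion, not part of this commit, confirming that channelPool satisfies this interface:

var _ Pool = (*channelPool)(nil)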