package pool

import (
	"log"
	"math/rand"
	"net"
	"sync"
	"testing"
	"time"
)

// Test fixtures: the pool sizes used by newChannelPool and a factory that
// dials the local TCP server started in init.
var (
	InitialCap = 5
	MaximumCap = 30
	network    = "tcp"
	address    = "127.0.0.1:7777"
	factory    = func() (net.Conn, error) { return net.Dial(network, address) }
)

func init() {
	// simpleTCPServer returns only after the listener is bound, so every
	// factory dial made by the tests below finds a live server. This replaces
	// a fixed sleep that could race with slow machines.
	simpleTCPServer()
	rand.Seed(time.Now().UTC().UnixNano())
}

// mustGet fetches a connection from p and fails the test immediately on error.
// It calls t.Fatalf, so it must only be used from the test's own goroutine.
func mustGet(t *testing.T, p Pool) net.Conn {
	t.Helper()
	conn, err := p.Get()
	if err != nil {
		t.Fatalf("Get error: %s", err)
	}
	return conn
}

// TestNew checks that a pool can be constructed with the default fixtures.
func TestNew(t *testing.T) {
	p, err := newChannelPool()
	if err != nil {
		t.Fatalf("New error: %s", err)
	}
	p.Close()
}

// TestPool_Get_Impl checks that Get hands out *PoolConn wrappers.
func TestPool_Get_Impl(t *testing.T) {
	p, err := newChannelPool()
	if err != nil {
		t.Fatalf("New error: %s", err)
	}
	defer p.Close()

	conn := mustGet(t, p)
	defer conn.Close()

	if _, ok := conn.(*PoolConn); !ok {
		t.Errorf("Conn is not of type poolConn")
	}
}

// TestPool_Get checks that Len drops as pre-created connections are taken,
// including under concurrent Gets, and that Get still succeeds once the idle
// connections are exhausted.
func TestPool_Get(t *testing.T) {
	p, err := newChannelPool()
	if err != nil {
		t.Fatalf("New error: %s", err)
	}
	defer p.Close()

	if _, err := p.Get(); err != nil {
		t.Errorf("Get error: %s", err)
	}

	// after one get, current capacity should be lowered by one.
	if p.Len() != (InitialCap - 1) {
		t.Errorf("Get error. Expecting %d, got %d", (InitialCap - 1), p.Len())
	}

	// get them all
	var wg sync.WaitGroup
	for i := 0; i < (InitialCap - 1); i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			// t.Errorf (not Fatalf) is safe to call from a non-test goroutine.
			if _, err := p.Get(); err != nil {
				t.Errorf("Get error: %s", err)
			}
		}()
	}
	wg.Wait()

	if p.Len() != 0 {
		t.Errorf("Get error. Expecting %d, got %d", 0, p.Len())
	}

	if _, err := p.Get(); err != nil {
		t.Errorf("Get error: %s", err)
	}
}

// TestPool_Put checks that closing borrowed connections returns them to the
// pool, and that closing a connection after the pool is closed does not
// repopulate it.
func TestPool_Put(t *testing.T) {
	p, err := NewChannelPool(0, MaximumCap, factory)
	if err != nil {
		t.Fatal(err)
	}
	defer p.Close()

	// get/create from the pool
	conns := make([]net.Conn, MaximumCap)
	for i := range conns {
		conns[i] = mustGet(t, p)
	}

	// now put them all back
	for _, conn := range conns {
		conn.Close()
	}

	if p.Len() != MaximumCap {
		t.Errorf("Put error len. Expecting %d, got %d", MaximumCap, p.Len())
	}

	conn := mustGet(t, p)
	p.Close() // close pool

	conn.Close() // try to put into a closed pool
	if p.Len() != 0 {
		t.Errorf("Put error. Closed pool shouldn't allow to put connections.")
	}
}

// TestPool_PutUnusableConn checks that a healthy connection is returned to the
// pool on Close, while one marked unusable is not.
func TestPool_PutUnusableConn(t *testing.T) {
	p, err := newChannelPool()
	if err != nil {
		t.Fatalf("New error: %s", err)
	}
	defer p.Close()

	// ensure pool is not empty
	mustGet(t, p).Close()
	poolSize := p.Len()

	mustGet(t, p).Close()
	if p.Len() != poolSize {
		t.Errorf("Pool size is expected to be equal to initial size")
	}

	conn := mustGet(t, p)
	pc, ok := conn.(*PoolConn)
	if !ok {
		t.Fatalf("impossible")
	}
	pc.MarkUnusable()
	conn.Close()
	if p.Len() != poolSize-1 {
		t.Errorf("Pool size is expected to be initial_size - 1, %d, %d", p.Len(), poolSize-1)
	}
}

// TestPool_UsedCapacity checks that a fresh pool holds InitialCap idle conns.
func TestPool_UsedCapacity(t *testing.T) {
	p, err := newChannelPool()
	if err != nil {
		t.Fatalf("New error: %s", err)
	}
	defer p.Close()

	if p.Len() != InitialCap {
		t.Errorf("InitialCap error. Expecting %d, got %d", InitialCap, p.Len())
	}
}

// TestPool_Close checks the state of a channelPool after Close: internal
// fields cleared, Get failing, and Len reporting zero.
func TestPool_Close(t *testing.T) {
	p, err := newChannelPool()
	if err != nil {
		t.Fatalf("New error: %s", err)
	}

	// now close it and test all cases we are expecting.
	p.Close()

	c, ok := p.(*channelPool)
	if !ok {
		t.Fatalf("pool is %T, not *channelPool", p)
	}
	if c.conns != nil {
		t.Errorf("Close error, conns channel should be nil")
	}
	if c.factory != nil {
		t.Errorf("Close error, factory should be nil")
	}

	if _, err := p.Get(); err == nil {
		t.Errorf("Close error, get conn should return an error")
	}

	if p.Len() != 0 {
		t.Errorf("Close error used capacity. Expecting 0, got %d", p.Len())
	}
}

// TestPoolConcurrent races Get and connection Close against pool Close.
// It waits for every goroutine so none outlives the test.
func TestPoolConcurrent(t *testing.T) {
	p, err := newChannelPool()
	if err != nil {
		t.Fatalf("New error: %s", err)
	}
	pipe := make(chan net.Conn)

	var wg sync.WaitGroup
	// One closer plus a getter/returner pair per iteration. Every send on the
	// unbuffered pipe has exactly one matching receive, so Wait terminates.
	wg.Add(1 + 2*MaximumCap)

	go func() {
		defer wg.Done()
		p.Close()
	}()

	for i := 0; i < MaximumCap; i++ {
		go func() {
			defer wg.Done()
			// Get is expected to fail once the pool is closed; the resulting
			// nil conn tells the receiver there is nothing to return.
			conn, _ := p.Get()
			pipe <- conn
		}()
		go func() {
			defer wg.Done()
			conn := <-pipe
			if conn == nil {
				return
			}
			conn.Close()
		}()
	}
	wg.Wait()
}

// TestPoolWriteRead checks that a pooled connection can write to the server.
func TestPoolWriteRead(t *testing.T) {
	p, err := NewChannelPool(0, MaximumCap, factory)
	if err != nil {
		t.Fatal(err)
	}
	defer p.Close()

	conn := mustGet(t, p)
	defer conn.Close()

	msg := "hello"
	if _, err := conn.Write([]byte(msg)); err != nil {
		t.Error(err)
	}
}

// TestPoolConcurrent2 runs twenty borrowers that each hold a connection for a
// random 0-99ms before returning it.
func TestPoolConcurrent2(t *testing.T) {
	p, err := NewChannelPool(0, MaximumCap, factory)
	if err != nil {
		t.Fatal(err)
	}
	defer p.Close()

	var wg sync.WaitGroup

	// worker borrows a connection, holds it briefly, then returns it.
	worker := func() {
		defer wg.Done()
		conn, err := p.Get()
		if err != nil {
			t.Errorf("Get error: %s", err)
			return
		}
		time.Sleep(time.Millisecond * time.Duration(rand.Intn(100)))
		conn.Close()
	}

	// Ten workers are launched from a helper goroutine and ten from the test
	// goroutine, so the two batches start concurrently.
	wg.Add(10)
	go func() {
		for i := 0; i < 10; i++ {
			go worker()
		}
	}()

	wg.Add(10)
	for i := 0; i < 10; i++ {
		go worker()
	}
	wg.Wait()
}

// TestPoolConcurrent3 races a single Get against Close on a pool with a
// maximum capacity of one.
func TestPoolConcurrent3(t *testing.T) {
	p, err := NewChannelPool(0, 1, factory)
	if err != nil {
		t.Fatal(err)
	}

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		defer wg.Done()
		p.Close()
	}()

	// Get may legitimately fail if Close wins the race.
	if conn, err := p.Get(); err == nil {
		conn.Close()
	}
	wg.Wait()
}

// newChannelPool builds a pool with the default InitialCap/MaximumCap fixtures.
func newChannelPool() (Pool, error) {
	return NewChannelPool(InitialCap, MaximumCap, factory)
}

// simpleTCPServer binds the test listener on network/address and accepts
// connections from a background goroutine. It returns as soon as the listener
// is bound, so callers may dial immediately. A listen or accept failure aborts
// the test binary via log.Fatal. The listener is intentionally kept open for
// the lifetime of the process.
func simpleTCPServer() {
	l, err := net.Listen(network, address)
	if err != nil {
		log.Fatal(err)
	}

	go func() {
		for {
			conn, err := l.Accept()
			if err != nil {
				log.Fatal(err)
			}
			go func(c net.Conn) {
				defer c.Close()
				// Best effort: block until the client sends data or hangs up.
				// The data and any read error are deliberately ignored.
				buffer := make([]byte, 256)
				c.Read(buffer)
			}(conn)
		}
	}()
}