1
0
Fork 0
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

316 lines
7.4 KiB
Go

package cluster
import (
"errors"
"net"
"reflect"
"sync/atomic"
"testing"
"time"
"github.com/rqlite/rqlite/v8/cluster/proto"
"github.com/rqlite/rqlite/v8/cluster/servicetest"
command "github.com/rqlite/rqlite/v8/command/proto"
pb "google.golang.org/protobuf/proto"
)
// Test_AddressProviderString checks that a string-backed address provider
// hands back exactly the addresses it was constructed with.
func Test_AddressProviderString(t *testing.T) {
	want := []string{"a", "b", "c"}

	got, err := NewAddressProviderString(want).Lookup()
	if err != nil {
		t.Fatalf("failed to lookup addresses: %s", err.Error())
	}
	if !reflect.DeepEqual(want, got) {
		t.Fatalf("failed to get correct addresses")
	}
}
// Test_NewBootstrapper checks that a Bootstrapper built with a nil address
// provider and a nil client is non-nil and starts out in BootUnknown.
func Test_NewBootstrapper(t *testing.T) {
	bs := NewBootstrapper(nil, nil)
	if bs == nil {
		t.Fatalf("failed to create a simple Bootstrapper")
	}

	got, exp := bs.Status(), BootUnknown
	if got != exp {
		t.Fatalf("wrong status, exp %s, got %s", exp, got)
	}
}
// Test_BootstrapperBootDoneImmediately tests that when the done function
// reports true straight away, Boot succeeds without the service receiving
// any request, and the status is BootDone.
func Test_BootstrapperBootDoneImmediately(t *testing.T) {
	srv := servicetest.NewService()
	srv.Handler = func(conn net.Conn) {
		// Any request at all is a failure. The handler may not run on the
		// test goroutine, so report with t.Error (goroutine-safe) rather
		// than t.Fatal, which must only be called from the test goroutine.
		t.Error("client made request")
	}
	srv.Start()
	defer srv.Close()

	done := func() bool {
		return true
	}

	p := NewAddressProviderString([]string{srv.Addr()})
	// A nil client is passed: no client request is expected.
	bs := NewBootstrapper(p, nil)
	if err := bs.Boot("node1", "192.168.1.1:1234", Voter, done, 10*time.Second); err != nil {
		t.Fatalf("failed to boot: %s", err)
	}
	if exp, got := BootDone, bs.Status(); exp != got {
		t.Fatalf("wrong status, exp %s, got %s", exp, got)
	}
}
// Test_BootstrapperBootTimeout checks that Boot returns ErrBootTimeout, and
// leaves the status at BootTimeout, when the done function never reports
// true within the allowed time.
func Test_BootstrapperBootTimeout(t *testing.T) {
	srv := servicetest.NewService()
	// The handler does nothing with incoming connections, so no request
	// ever gets a reply.
	srv.Handler = func(net.Conn) {}
	srv.Start()
	defer srv.Close()

	neverDone := func() bool { return false }

	provider := NewAddressProviderString([]string{srv.Addr()})
	bs := NewBootstrapper(provider, NewClient(&simpleDialer{}, 0))
	bs.Interval = time.Second

	err := bs.Boot("node1", "192.168.1.1:1234", Voter, neverDone, 5*time.Second)
	switch {
	case err == nil:
		t.Fatalf("no error returned from timed-out boot")
	case !errors.Is(err, ErrBootTimeout):
		t.Fatalf("wrong error returned")
	}

	if exp, got := BootTimeout, bs.Status(); exp != got {
		t.Fatalf("wrong status, exp %s, got %s", exp, got)
	}
}
// Test_BootstrapperBootSingleJoin tests that Boot succeeds with status
// BootJoin when the single target acknowledges the node's join request.
func Test_BootstrapperBootSingleJoin(t *testing.T) {
	srv := servicetest.NewService()
	// The handler may run outside the test goroutine, so failures are
	// reported with t.Errorf/t.Error (goroutine-safe) followed by a return,
	// instead of t.Fatalf, which must only be called from the test goroutine.
	srv.Handler = func(conn net.Conn) {
		c := readCommand(conn)
		if c == nil {
			// Connection error handling
			return
		}
		if c.Type != proto.Command_COMMAND_TYPE_JOIN {
			t.Errorf("unexpected command type: %d", c.Type)
			return
		}
		jnr := c.GetJoinRequest()
		if jnr == nil {
			t.Error("expected join node request, got nil")
			return
		}
		if jnr.Address != "192.168.1.1:1234" {
			t.Errorf("unexpected node address, got %s", jnr.Address)
			return
		}
		p, err := pb.Marshal(&proto.CommandJoinResponse{})
		if err != nil {
			conn.Close()
			return
		}
		writeBytesWithLength(conn, p)
	}
	srv.Start()
	defer srv.Close()

	done := func() bool {
		return false
	}

	p := NewAddressProviderString([]string{srv.Addr()})
	bs := NewBootstrapper(p, NewClient(&simpleDialer{}, 0))
	bs.Interval = time.Second
	err := bs.Boot("node1", "192.168.1.1:1234", Voter, done, 5*time.Second)
	if err != nil {
		t.Fatalf("failed to boot: %s", err)
	}
	if exp, got := BootJoin, bs.Status(); exp != got {
		t.Fatalf("wrong status, exp %s, got %s", exp, got)
	}
}
// Test_BootstrapperBootNonVoter tests that a non-voter just attempts
// to join the cluster, and does not send a notify request.
func Test_BootstrapperBootNonVoter(t *testing.T) {
	srv := servicetest.NewService()
	// The handler may run outside the test goroutine, so failures are
	// reported with t.Errorf/t.Error (goroutine-safe) followed by a return,
	// instead of t.Fatalf, which must only be called from the test goroutine.
	srv.Handler = func(conn net.Conn) {
		c := readCommand(conn)
		if c == nil {
			// Connection error handling
			return
		}
		// Anything other than a join (e.g. a notify) is a failure.
		if c.Type != proto.Command_COMMAND_TYPE_JOIN {
			t.Errorf("unexpected command type: %d", c.Type)
			return
		}
		jnr := c.GetJoinRequest()
		if jnr == nil {
			t.Error("expected join node request, got nil")
			return
		}
		if jnr.Address != "192.168.1.1:1234" {
			t.Errorf("unexpected node address, got %s", jnr.Address)
			return
		}
		// Just return without replying, which will cause the bootstrapper
		// to time out.
	}
	srv.Start()
	defer srv.Close()

	done := func() bool {
		return false
	}

	p := NewAddressProviderString([]string{srv.Addr()})
	bs := NewBootstrapper(p, NewClient(&simpleDialer{}, 0))
	bs.Interval = time.Second
	err := bs.Boot("node1", "192.168.1.1:1234", NonVoter, done, 3*time.Second)
	if err == nil {
		t.Fatalf("expected error, got none")
	}
	if exp, got := BootTimeout, bs.Status(); exp != got {
		t.Fatalf("wrong status, exp %s, got %s", exp, got)
	}
}
func Test_BootstrapperBootSingleNotify(t *testing.T) {
var gotNR *command.NotifyRequest
srv := servicetest.NewService()
srv.Handler = func(conn net.Conn) {
var p []byte
var err error
c := readCommand(conn)
if c == nil {
// Connection error handling
return
}
if c.Type != proto.Command_COMMAND_TYPE_NOTIFY {
return
}
gotNR = c.GetNotifyRequest()
p, err = pb.Marshal(&proto.CommandNotifyResponse{})
if err != nil {
conn.Close()
return
}
writeBytesWithLength(conn, p)
}
srv.Start()
defer srv.Close()
n := -1
done := func() bool {
n++
return n == 5
}
p := NewAddressProviderString([]string{srv.Addr()})
bs := NewBootstrapper(p, NewClient(&simpleDialer{}, 0))
bs.Interval = time.Second
err := bs.Boot("node1", "192.168.1.1:1234", Voter, done, 60*time.Second)
if err != nil {
t.Fatalf("failed to boot: %s", err)
}
if got, exp := gotNR.Id, "node1"; got != exp {
t.Fatalf("wrong node ID supplied, exp %s, got %s", exp, got)
}
if got, exp := gotNR.Address, "192.168.1.1:1234"; got != exp {
t.Fatalf("wrong address supplied, exp %s, got %s", exp, got)
}
if exp, got := BootDone, bs.Status(); exp != got {
t.Fatalf("wrong status, exp %s, got %s", exp, got)
}
}
func Test_BootstrapperBootMultiJoinNotify(t *testing.T) {
var srv1JoinC int32
var srv1NotifiedC int32
srv1 := servicetest.NewService()
srv1.Handler = func(conn net.Conn) {
var p []byte
var err error
c := readCommand(conn)
if c == nil {
// Connection error handling
return
}
if c.Type == proto.Command_COMMAND_TYPE_JOIN {
atomic.AddInt32(&srv1JoinC, 1)
}
if c.Type != proto.Command_COMMAND_TYPE_NOTIFY {
return
}
atomic.AddInt32(&srv1NotifiedC, 1)
p, err = pb.Marshal(&proto.CommandNotifyResponse{})
if err != nil {
conn.Close()
return
}
writeBytesWithLength(conn, p)
}
srv1.Start()
defer srv1.Close()
var srv2JoinC int32
var srv2NotifiedC int32
srv2 := servicetest.NewService()
srv2.Handler = func(conn net.Conn) {
var p []byte
var err error
c := readCommand(conn)
if c == nil {
// Connection error handling
return
}
if c.Type == proto.Command_COMMAND_TYPE_JOIN {
atomic.AddInt32(&srv2JoinC, 1)
}
if c.Type != proto.Command_COMMAND_TYPE_NOTIFY {
return
}
atomic.AddInt32(&srv2NotifiedC, 1)
p, err = pb.Marshal(&proto.CommandNotifyResponse{})
if err != nil {
conn.Close()
return
}
writeBytesWithLength(conn, p)
}
srv2.Start()
defer srv2.Close()
n := -1
done := func() bool {
n++
return n == 5
}
p := NewAddressProviderString([]string{srv1.Addr(), srv2.Addr()})
bs := NewBootstrapper(p, NewClient(&simpleDialer{}, 0))
bs.Interval = time.Second
err := bs.Boot("node1", "192.168.1.1:1234", Voter, done, 60*time.Second)
if err != nil {
t.Fatalf("failed to boot: %s", err)
}
if atomic.LoadInt32(&srv1JoinC) < 1 || atomic.LoadInt32(&srv2JoinC) < 1 {
t.Fatalf("all join targets not contacted")
}
if atomic.LoadInt32(&srv2JoinC) < 1 || atomic.LoadInt32(&srv2NotifiedC) < 1 {
t.Fatalf("all notify targets not contacted")
}
if exp, got := BootDone, bs.Status(); exp != got {
t.Fatalf("wrong status, exp %s, got %s", exp, got)
}
}