1
0
Fork 0

Refactor join code with Joiner type (#986)

Refactor join code with Joiner type
master
Philip O'Toole 3 years ago committed by GitHub
parent 61e66893b1
commit e1aeb9a664
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -12,6 +12,7 @@ This release introduces supported for [DNS-based](https://www.cloudflare.com/lea
- [PR #981](https://github.com/rqlite/rqlite/pull/981): Add current time to node `/status` output.
- [PR #982](https://github.com/rqlite/rqlite/pull/982): `/readyz` can skip leader check via `noleader` query param.
- [PR #984](https://github.com/rqlite/rqlite/pull/984): Count number of `/status` and `/readyz` requests via expvar.
- [PR #986](https://github.com/rqlite/rqlite/pull/986): Refactor join code with new Joiner type.
## 7.1.0 (January 28th 2022)
This release introduces a new automatic clustering approach, known as _Bootstrapping_, which allows rqlite clusters to form without assistance from an external system such as Consul. This can be very useful for certain deployment scenarios. See the [documentation](https://github.com/rqlite/rqlite/blob/master/DOC/AUTO_CLUSTERING.md) for full details on using the new Bootstrapping mode. Special thanks to [Nathan Ferch](https://github.com/nferch) for his advice regarding the design and development of this feature.

@ -38,6 +38,11 @@ type Bootstrapper struct {
expect int
tlsConfig *tls.Config
joiner *Joiner
username string
password string
logger *log.Logger
Interval time.Duration
}
@ -48,6 +53,7 @@ func NewBootstrapper(p AddressProvider, expect int, tlsConfig *tls.Config) *Boot
provider: p,
expect: expect,
tlsConfig: &tls.Config{InsecureSkipVerify: true},
joiner: NewJoiner("", 1, 0, tlsConfig),
logger: log.New(os.Stderr, "[cluster-bootstrap] ", log.LstdFlags),
Interval: jitter(5 * time.Second),
}
@ -57,6 +63,11 @@ func NewBootstrapper(p AddressProvider, expect int, tlsConfig *tls.Config) *Boot
return bs
}
// SetBasicAuth records the Basic Auth username and password on the
// Bootstrapper so that subsequent bootstrap attempts can present them.
func (b *Bootstrapper) SetBasicAuth(username, password string) {
	b.username = username
	b.password = password
}
// Boot performs the bootstrapping process for this node. This means it will
// ensure this node becomes part of a cluster. It does this by either joining
// an existing cluster by explicitly joining it through one of these nodes,
@ -98,9 +109,9 @@ func (b *Bootstrapper) Boot(id, raftAddr string, done func() bool, timeout time.
}
// Try an explicit join.
if j, err := Join("", targets, id, raftAddr, true, 1, 0, b.tlsConfig); err == nil {
b.logger.Printf("succeeded directly joining cluster via node at %s",
httpd.RemoveBasicAuth(j))
b.joiner.SetBasicAuth(b.username, b.password)
if j, err := b.joiner.Do(targets, id, raftAddr, true); err == nil {
b.logger.Printf("succeeded directly joining cluster via node at %s", j)
return nil
}
@ -139,11 +150,18 @@ func (b *Bootstrapper) notify(targets []string, id, raftAddr string) error {
TargetLoop:
for {
resp, err := client.Post(fullTarget, "application/json", bytes.NewReader(buf))
req, err := http.NewRequest("POST", fullTarget, bytes.NewReader(buf))
if err != nil {
return err
}
if b.username != "" && b.password != "" {
req.SetBasicAuth(b.username, b.password)
}
req.Header.Add("Content-Type", "application/json")
resp, err := client.Do(req)
if err != nil {
return err
// time.Sleep(bs.joinInterval) // need to count loops....? Or this just does one loop?
// continue
}
resp.Body.Close()
switch resp.StatusCode {
@ -156,8 +174,7 @@ func (b *Bootstrapper) notify(targets []string, id, raftAddr string) error {
// record information about which protocol a registered node is actually using.
if strings.HasPrefix(fullTarget, "https://") {
// It's already HTTPS, give up.
return fmt.Errorf("failed to notify node at %s: %s",
httpd.RemoveBasicAuth(fullTarget), resp.Status)
return fmt.Errorf("failed to notify node at %s: %s", fullTarget, resp.Status)
}
fullTarget = httpd.EnsureHTTPS(fullTarget)
default:

@ -115,7 +115,48 @@ func Test_BootstrapperBootSingleNotify(t *testing.T) {
if got, exp := body["addr"], "192.168.1.1:1234"; got != exp {
t.Fatalf("wrong address supplied, exp %s, got %s", exp, got)
}
}
// Test_BootstrapperBootSingleNotifyAuth checks that a Bootstrapper configured
// with Basic Auth credentials presents them on every request it makes. The
// test server rejects join requests with 503 so that the Bootstrapper has to
// fall back to notifying the target.
func Test_BootstrapperBootSingleNotifyAuth(t *testing.T) {
	tsNotified := false
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// This handler runs on a server goroutine, not the test goroutine, so
		// failures are reported with t.Errorf (t.Fatalf must only be called
		// from the goroutine running the test function).
		username, password, ok := r.BasicAuth()
		if !ok {
			t.Errorf("request did not have Basic Auth credentials")
			w.WriteHeader(http.StatusUnauthorized)
			return
		}
		if username != "username1" || password != "password1" {
			t.Errorf("bad Basic Auth credentials received (%s, %s)", username, password)
			w.WriteHeader(http.StatusUnauthorized)
			return
		}

		if r.URL.Path == "/join" {
			w.WriteHeader(http.StatusServiceUnavailable)
			return
		}

		tsNotified = true
	}))
	// Release the listener and any open connections when the test ends.
	defer ts.Close()

	// Report "done" on the sixth call.
	n := -1
	done := func() bool {
		n++
		return n == 5
	}

	p := NewAddressProviderString([]string{ts.URL})
	bs := NewBootstrapper(p, 1, nil)
	bs.SetBasicAuth("username1", "password1")
	bs.Interval = time.Second

	err := bs.Boot("node1", "192.168.1.1:1234", done, 60*time.Second)
	if err != nil {
		t.Fatalf("failed to boot: %s", err)
	}

	if !tsNotified {
		t.Fatalf("notify target not contacted")
	}
}
func Test_BootstrapperBootMultiNotify(t *testing.T) {

@ -18,6 +18,12 @@ import (
)
var (
// ErrInvalidRedirect is returned when a node returns an invalid HTTP redirect.
ErrInvalidRedirect = errors.New("invalid redirect received")
// ErrNodeIDRequired is returned when a join request doesn't supply a node ID.
ErrNodeIDRequired = errors.New("node required")
// ErrJoinFailed is returned when a node fails to join a cluster
ErrJoinFailed = errors.New("failed to join cluster")
@ -25,39 +31,26 @@ var (
ErrNotifyFailed = errors.New("failed to notify node")
)
// Join attempts to join the cluster at one of the addresses given in joinAddr.
// It walks through joinAddr in order, and sets the node ID and Raft address of
// the joining node as id addr respectively. It returns the endpoint successfully
// used to join the cluster.
func Join(srcIP string, joinAddr []string, id, addr string, voter bool, numAttempts int,
attemptInterval time.Duration, tlsConfig *tls.Config) (string, error) {
var err error
var j string
logger := log.New(os.Stderr, "[cluster-join] ", log.LstdFlags)
if tlsConfig == nil {
tlsConfig = &tls.Config{InsecureSkipVerify: true}
}
// Joiner executes a node-join operation.
type Joiner struct {
srcIP string
numAttempts int
attemptInterval time.Duration
tlsConfig *tls.Config
for i := 0; i < numAttempts; i++ {
for _, a := range joinAddr {
j, err = join(srcIP, a, id, addr, voter, tlsConfig, logger)
if err == nil {
// Success!
return j, nil
}
}
logger.Printf("failed to join cluster at %s: %s, sleeping %s before retry", joinAddr, err.Error(), attemptInterval)
time.Sleep(attemptInterval)
}
logger.Printf("failed to join cluster at %s, after %d attempts", joinAddr, numAttempts)
return "", ErrJoinFailed
}
username string
password string
func join(srcIP, joinAddr, id, addr string, voter bool, tlsConfig *tls.Config, logger *log.Logger) (string, error) {
if id == "" {
return "", fmt.Errorf("node ID not set")
client *http.Client
logger *log.Logger
}
// The specified source IP is optional
// NewJoiner returns an instantiated Joiner.
func NewJoiner(srcIP string, numAttempts int, attemptInterval time.Duration,
tlsCfg *tls.Config) *Joiner {
// Source IP is optional
var dialer *net.Dialer
dialer = &net.Dialer{}
if srcIP != "" {
@ -67,25 +60,69 @@ func join(srcIP, joinAddr, id, addr string, voter bool, tlsConfig *tls.Config, l
}
dialer = &net.Dialer{LocalAddr: netAddr}
}
// Join using IP address, as that is what Hashicorp Raft works in.
resv, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return "", err
}
// Check for protocol scheme, and insert default if necessary.
fullAddr := httpd.NormalizeAddr(fmt.Sprintf("%s/join", joinAddr))
joiner := &Joiner{
srcIP: srcIP,
numAttempts: numAttempts,
attemptInterval: attemptInterval,
tlsConfig: tlsCfg,
logger: log.New(os.Stderr, "[cluster-join] ", log.LstdFlags),
}
if joiner.tlsConfig == nil {
joiner.tlsConfig = &tls.Config{InsecureSkipVerify: true}
}
// Create and configure the client to connect to the other node.
tr := &http.Transport{
TLSClientConfig: tlsConfig,
TLSClientConfig: joiner.tlsConfig,
Dial: dialer.Dial,
}
client := &http.Client{Transport: tr}
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
joiner.client = &http.Client{Transport: tr}
joiner.client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
return joiner
}
// SetBasicAuth records the Basic Auth username and password on the Joiner
// so that subsequent join attempts can present them.
func (j *Joiner) SetBasicAuth(username, password string) {
	j.username = username
	j.password = password
}
// Do makes the actual join request. It walks through joinAddrs in order,
// up to numAttempts times, supplying id and addr as the node ID and Raft
// address of the joining node, and voter as the requested voting status.
//
// It returns the endpoint that accepted the join. It returns
// ErrNodeIDRequired if id is empty, and ErrJoinFailed if no address
// accepted the join (including when joinAddrs is empty or numAttempts
// is zero).
func (j *Joiner) Do(joinAddrs []string, id, addr string, voter bool) (string, error) {
	if id == "" {
		return "", ErrNodeIDRequired
	}

	// With no addresses the loop below would never assign err, and the
	// retry log line would dereference a nil error. Fail explicitly instead.
	if len(joinAddrs) == 0 {
		j.logger.Printf("failed to join cluster, no join addresses supplied")
		return "", ErrJoinFailed
	}

	var err error
	var joinee string
	for i := 0; i < j.numAttempts; i++ {
		for _, a := range joinAddrs {
			joinee, err = j.join(a, id, addr, voter)
			if err == nil {
				// Success!
				return joinee, nil
			}
		}
		// err is non-nil here: joinAddrs is non-empty and every address failed.
		j.logger.Printf("failed to join cluster at %s: %s, sleeping %s before retry", joinAddrs, err.Error(), j.attemptInterval)
		time.Sleep(j.attemptInterval)
	}
	j.logger.Printf("failed to join cluster at %s, after %d attempts", joinAddrs, j.numAttempts)
	return "", ErrJoinFailed
}
func (j *Joiner) join(joinAddr, id, addr string, voter bool) (string, error) {
// Join using IP address, as that is what Hashicorp Raft works in.
resv, err := net.ResolveTCPAddr("tcp", addr)
if err != nil {
return "", err
}
// Check for protocol scheme, and insert default if necessary.
fullAddr := httpd.NormalizeAddr(fmt.Sprintf("%s/join", joinAddr))
for {
b, err := json.Marshal(map[string]interface{}{
"id": id,
@ -97,7 +134,15 @@ func join(srcIP, joinAddr, id, addr string, voter bool, tlsConfig *tls.Config, l
}
// Attempt to join.
resp, err := client.Post(fullAddr, "application/json", bytes.NewReader(b))
req, err := http.NewRequest("POST", fullAddr, bytes.NewReader(b))
if err != nil {
return "", err
}
if j.username != "" && j.password != "" {
req.SetBasicAuth(j.username, j.password)
}
req.Header.Add("Content-Type", "application/json")
resp, err := j.client.Do(req)
if err != nil {
return "", err
}
@ -114,7 +159,7 @@ func join(srcIP, joinAddr, id, addr string, voter bool, tlsConfig *tls.Config, l
case http.StatusMovedPermanently:
fullAddr = resp.Header.Get("location")
if fullAddr == "" {
return "", fmt.Errorf("failed to join, invalid redirect received")
return "", ErrInvalidRedirect
}
continue
case http.StatusBadRequest:
@ -127,7 +172,7 @@ func join(srcIP, joinAddr, id, addr string, voter bool, tlsConfig *tls.Config, l
return "", fmt.Errorf("failed to join, node returned: %s: (%s)", resp.Status, string(b))
}
logger.Print("join via HTTP failed, trying via HTTPS")
j.logger.Print("join via HTTP failed, trying via HTTPS")
fullAddr = httpd.EnsureHTTPS(fullAddr)
continue
default:

@ -11,7 +11,7 @@ import (
)
const numAttempts int = 3
const attemptInterval = 5 * time.Second
const attemptInterval = 1 * time.Second
func Test_SingleJoinOK(t *testing.T) {
var body map[string]interface{}
@ -21,6 +21,10 @@ func Test_SingleJoinOK(t *testing.T) {
}
w.WriteHeader(http.StatusOK)
if r.Header["Content-Type"][0] != "application/json" {
t.Fatalf("incorrect Content-Type set")
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
@ -32,11 +36,62 @@ func Test_SingleJoinOK(t *testing.T) {
return
}
}))
defer ts.Close()
joiner := NewJoiner("127.0.0.1", numAttempts, attemptInterval, nil)
j, err := joiner.Do([]string{ts.URL}, "id0", "127.0.0.1:9090", false)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
if j != ts.URL+"/join" {
t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts.URL)
}
if got, exp := body["id"].(string), "id0"; got != exp {
t.Fatalf("wrong node ID supplied, exp %s, got %s", exp, got)
}
if got, exp := body["addr"].(string), "127.0.0.1:9090"; got != exp {
t.Fatalf("wrong address supplied, exp %s, got %s", exp, got)
}
if got, exp := body["voter"].(bool), false; got != exp {
t.Fatalf("wrong voter state supplied, exp %v, got %v", exp, got)
}
}
func Test_SingleJoinOKBasicAuth(t *testing.T) {
var body map[string]interface{}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Fatalf("Client did not use POST")
}
w.WriteHeader(http.StatusOK)
username, password, ok := r.BasicAuth()
if !ok {
t.Fatalf("request did not have Basic Auth credentials")
}
if username != "user1" || password != "password1" {
t.Fatalf("bad Basic Auth credentials received (%s, %s", username, password)
}
b, err := ioutil.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
if err := json.Unmarshal(b, &body); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
}))
defer ts.Close()
j, err := Join("127.0.0.1", []string{ts.URL}, "id0", "127.0.0.1:9090", false,
numAttempts, attemptInterval, nil)
joiner := NewJoiner("127.0.0.1", numAttempts, attemptInterval, nil)
joiner.SetBasicAuth("user1", "password1")
j, err := joiner.Do([]string{ts.URL}, "id0", "127.0.0.1:9090", false)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
@ -60,7 +115,8 @@ func Test_SingleJoinZeroAttempts(t *testing.T) {
t.Fatalf("handler should not have been called")
}))
_, err := Join("127.0.0.1", []string{ts.URL}, "id0", "127.0.0.1:9090", false, 0, attemptInterval, nil)
joiner := NewJoiner("127.0.0.1", 0, attemptInterval, nil)
_, err := joiner.Do([]string{ts.URL}, "id0", "127.0.0.1:9090", false)
if err != ErrJoinFailed {
t.Fatalf("Incorrect error returned when zero attempts specified")
}
@ -72,8 +128,8 @@ func Test_SingleJoinFail(t *testing.T) {
}))
defer ts.Close()
_, err := Join("", []string{ts.URL}, "id0", "127.0.0.1:9090", true,
numAttempts, attemptInterval, nil)
joiner := NewJoiner("", 0, attemptInterval, nil)
_, err := joiner.Do([]string{ts.URL}, "id0", "127.0.0.1:9090", true)
if err == nil {
t.Fatalf("expected error when joining bad node")
}
@ -87,8 +143,9 @@ func Test_DoubleJoinOK(t *testing.T) {
}))
defer ts2.Close()
j, err := Join("127.0.0.1", []string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", true,
numAttempts, attemptInterval, nil)
joiner := NewJoiner("127.0.0.1", numAttempts, attemptInterval, nil)
j, err := joiner.Do([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", true)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
@ -106,8 +163,9 @@ func Test_DoubleJoinOKSecondNode(t *testing.T) {
}))
defer ts2.Close()
j, err := Join("", []string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", true,
numAttempts, attemptInterval, nil)
joiner := NewJoiner("", numAttempts, attemptInterval, nil)
j, err := joiner.Do([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", true)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
@ -127,8 +185,9 @@ func Test_DoubleJoinOKSecondNodeRedirect(t *testing.T) {
}))
defer ts2.Close()
j, err := Join("127.0.0.1", []string{ts2.URL}, "id0", "127.0.0.1:9090", true,
numAttempts, attemptInterval, nil)
joiner := NewJoiner("127.0.0.1", numAttempts, attemptInterval, nil)
j, err := joiner.Do([]string{ts2.URL}, "id0", "127.0.0.1:9090", true)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}

@ -331,37 +331,49 @@ func createCluster(cfg *Config, tlsConfig *tls.Config, hasPeers bool, str *store
return nil
}
if joins != nil {
if cfg.BootstrapExpect == 0 {
// Explicit join operation requested, so do it.
if err := addJoinCreds(joins, cfg.JoinAs, credStr); err != nil {
return fmt.Errorf("failed to add BasicAuth creds: %s", err.Error())
// Prepare the Joiner
joiner := cluster.NewJoiner(cfg.JoinSrcIP, cfg.JoinAttempts, cfg.JoinInterval, tlsConfig)
if cfg.JoinAs != "" {
pw, ok := credStr.Password(cfg.JoinAs)
if !ok {
return fmt.Errorf("user %s does not exist in credential store", cfg.JoinAs)
}
joiner.SetBasicAuth(cfg.JoinAs, pw)
}
// Prepare definition of being part of a cluster.
isClustered := func() bool {
leader, _ := str.LeaderAddr()
return leader != ""
}
j, err := cluster.Join(cfg.JoinSrcIP, joins, str.ID(), cfg.RaftAdv, !cfg.RaftNonVoter,
cfg.JoinAttempts, cfg.JoinInterval, tlsConfig)
if joins != nil && cfg.BootstrapExpect == 0 {
// Explicit join operation requested, so do it.
j, err := joiner.Do(joins, str.ID(), cfg.RaftAdv, !cfg.RaftNonVoter)
if err != nil {
return fmt.Errorf("failed to join cluster: %s", err.Error())
}
log.Println("successfully joined cluster at", httpd.RemoveBasicAuth(j))
log.Println("successfully joined cluster at", j)
return nil
}
if joins != nil && cfg.BootstrapExpect > 0 {
// Bootstrap with explicit join addresses requests.
if hasPeers {
log.Println("preexisting node configuration detected, ignoring bootstrap request")
return nil
}
if err := addJoinCreds(joins, cfg.JoinAs, credStr); err != nil {
return fmt.Errorf("failed to add BasicAuth creds: %s", err.Error())
}
bs := cluster.NewBootstrapper(cluster.NewAddressProviderString(joins),
cfg.BootstrapExpect, tlsConfig)
done := func() bool {
leader, _ := str.LeaderAddr()
return leader != ""
if cfg.JoinAs != "" {
pw, ok := credStr.Password(cfg.JoinAs)
if !ok {
return fmt.Errorf("user %s does not exist in credential store", cfg.JoinAs)
}
return bs.Boot(str.ID(), cfg.RaftAdv, done, cfg.BootstrapExpectTimeout)
bs.SetBasicAuth(cfg.JoinAs, pw)
}
return bs.Boot(str.ID(), cfg.RaftAdv, isClustered, cfg.BootstrapExpectTimeout)
}
if cfg.DiscoMode == "" {
@ -369,10 +381,10 @@ func createCluster(cfg *Config, tlsConfig *tls.Config, hasPeers bool, str *store
// existing Raft state.
return nil
}
log.Printf("discovery mode: %s", cfg.DiscoMode)
// DNS-based discovery involves a few different options.
if cfg.DiscoMode == DiscoModeDNS || cfg.DiscoMode == DiscoModeDNSSRV {
log.Printf("discovery mode: %s", cfg.DiscoMode)
switch cfg.DiscoMode {
case DiscoModeDNS, DiscoModeDNSSRV:
if hasPeers {
log.Printf("preexisting node configuration detected, ignoring %s", cfg.DiscoMode)
return nil
@ -404,14 +416,17 @@ func createCluster(cfg *Config, tlsConfig *tls.Config, hasPeers bool, str *store
}
bs := cluster.NewBootstrapper(provider, cfg.BootstrapExpect, tlsConfig)
done := func() bool {
leader, _ := str.LeaderAddr()
return leader != ""
if cfg.JoinAs != "" {
pw, ok := credStr.Password(cfg.JoinAs)
if !ok {
return fmt.Errorf("user %s does not exist in credential store", cfg.JoinAs)
}
bs.SetBasicAuth(cfg.JoinAs, pw)
}
httpServ.RegisterStatus("disco", provider)
return bs.Boot(str.ID(), cfg.RaftAdv, isClustered, cfg.BootstrapExpectTimeout)
return bs.Boot(str.ID(), cfg.RaftAdv, done, cfg.BootstrapExpectTimeout)
} else {
case DiscoModeEtcdKV, DiscoModeConsulKV:
discoService, err := createDiscoService(cfg, str)
if err != nil {
return fmt.Errorf("failed to start discovery service: %s", err.Error())
@ -432,12 +447,7 @@ func createCluster(cfg *Config, tlsConfig *tls.Config, hasPeers bool, str *store
} else {
for {
log.Printf("discovery service returned %s as join address", addr)
if err := addJoinCreds([]string{addr}, cfg.JoinAs, credStr); err != nil {
return fmt.Errorf("failed too add auth creds: %s", err.Error())
}
if j, err := cluster.Join(cfg.JoinSrcIP, []string{addr}, str.ID(), cfg.RaftAdv, !cfg.RaftNonVoter,
cfg.JoinAttempts, cfg.JoinInterval, tlsConfig); err != nil {
if j, err := joiner.Do([]string{addr}, str.ID(), cfg.RaftAdv, !cfg.RaftNonVoter); err != nil {
log.Printf("failed to join cluster at %s: %s", addr, err.Error())
time.Sleep(time.Second)
@ -455,30 +465,11 @@ func createCluster(cfg *Config, tlsConfig *tls.Config, hasPeers bool, str *store
} else {
log.Println("preexisting node configuration detected, not registering with discovery service")
}
go discoService.StartReporting(cfg.NodeID, cfg.HTTPURL(), cfg.RaftAdv)
httpServ.RegisterStatus("disco", discoService)
}
return nil
}
// addJoinCreds adds credentials to any join addresses, if necessary.
func addJoinCreds(joins []string, joinAs string, credStr *auth.CredentialsStore) error {
if credStr == nil || joinAs == "" {
return nil
}
pw, ok := credStr.Password(joinAs)
if !ok {
return fmt.Errorf("user %s does not exist in credential store", joinAs)
}
var err error
for i := range joins {
joins[i], err = httpd.AddBasicAuth(joins[i], joinAs, pw)
if err != nil {
return fmt.Errorf("failed to use credential store join_as: %s", err.Error())
}
default:
return fmt.Errorf("invalid disco mode %s", cfg.DiscoMode)
}
return nil
}

Loading…
Cancel
Save