1
0
Fork 0

Pass cluster client to joiner

master
Philip O'Toole 10 months ago
parent 312d27778b
commit 672adab487

@ -1,18 +1,12 @@
package cluster
import (
"bytes"
"crypto/tls"
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net"
"net/http"
"os"
"strings"
"time"
"github.com/rqlite/rqlite/command"
)
var (
@ -28,59 +22,26 @@ var (
// Joiner executes a node-join operation.
type Joiner struct {
srcIP string
numAttempts int
attemptInterval time.Duration
tlsConfig *tls.Config
client *http.Client
client *Client
logger *log.Logger
}
// NewJoiner returns an instantiated Joiner.
func NewJoiner(srcIP string, numAttempts int, attemptInterval time.Duration, tlsCfg *tls.Config) *Joiner {
if tlsCfg == nil {
tlsCfg = &tls.Config{InsecureSkipVerify: true}
}
// Source IP is optional
dialer := &net.Dialer{}
if srcIP != "" {
netAddr := &net.TCPAddr{
IP: net.ParseIP(srcIP),
Port: 0,
}
dialer = &net.Dialer{LocalAddr: netAddr}
}
joiner := &Joiner{
srcIP: srcIP,
func NewJoiner(client *Client, numAttempts int, attemptInterval time.Duration) *Joiner {
return &Joiner{
client: client,
numAttempts: numAttempts,
attemptInterval: attemptInterval,
tlsConfig: tlsCfg,
logger: log.New(os.Stderr, "[cluster-join] ", log.LstdFlags),
}
// Create and configure the client to connect to the other node.
tr := &http.Transport{
TLSClientConfig: joiner.tlsConfig,
Dial: dialer.Dial,
ForceAttemptHTTP2: true,
}
joiner.client = &http.Client{Transport: tr}
joiner.client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
return joiner
}
// Do makes the actual join request. If any of the join addresses do not contain a
// protocol, both http:// and https:// are tried for that address. If the join is successful
// with any address, the Join URL of the node that joined is returned. Otherwise, an error
// is returned.
func (j *Joiner) Do(joinAddrs []string, id, addr string, voter bool) (string, error) {
// Do makes the actual join request. If the join is successful with any address,
// that address is returned. Otherwise, an error is returned.
func (j *Joiner) Do(targetAddrs []string, id, addr string, voter bool) (string, error) {
if id == "" {
return "", ErrNodeIDRequired
}
@ -88,82 +49,34 @@ func (j *Joiner) Do(joinAddrs []string, id, addr string, voter bool) (string, er
var err error
var joinee string
for i := 0; i < j.numAttempts; i++ {
for _, a := range normalizeAddrs(joinAddrs) {
joinee, err = j.join(a, id, addr, voter)
for _, ta := range targetAddrs {
joinee, err = j.join(ta, id, addr, voter)
if err == nil {
// Success!
return joinee, nil
}
j.logger.Printf("failed to join via node at %s: %s", a, err)
j.logger.Printf("failed to join via node at %s: %s", ta, err)
}
if i+1 < j.numAttempts {
// This logic message only make sense if performing more than 1 join-attempt.
j.logger.Printf("failed to join cluster at %s, sleeping %s before retry", joinAddrs, j.attemptInterval)
j.logger.Printf("failed to join cluster at %s, sleeping %s before retry", targetAddrs, j.attemptInterval)
time.Sleep(j.attemptInterval)
}
}
j.logger.Printf("failed to join cluster at %s, after %d attempt(s)", joinAddrs, j.numAttempts)
j.logger.Printf("failed to join cluster at %s, after %d attempt(s)", targetAddrs, j.numAttempts)
return "", ErrJoinFailed
}
func (j *Joiner) join(joinAddr, id, addr string, voter bool) (string, error) {
fullAddr := fmt.Sprintf("%s/join", joinAddr)
reqBody, err := json.Marshal(map[string]interface{}{
"id": id,
"addr": addr,
"voter": voter,
})
if err != nil {
return "", err
func (j *Joiner) join(targetAddr, id, addr string, voter bool) (string, error) {
req := &command.JoinRequest{
Id: id,
Address: addr,
Voter: voter,
}
for {
// Attempt to join.
req, err := http.NewRequest("POST", fullAddr, bytes.NewReader(reqBody))
if err != nil {
return "", err
}
var resp *http.Response
var respB []byte
err = func() error {
req.Header.Add("Content-Type", "application/json")
resp, err = j.client.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
// Only significant in the event of an error response
// from the remote node.
respB, err = io.ReadAll(resp.Body)
if err != nil {
return err
}
return nil
}()
if err != nil {
return "", err
}
switch resp.StatusCode {
case http.StatusOK:
return fullAddr, nil
default:
return "", fmt.Errorf("%s: (%s)", resp.Status, string(respB))
}
}
}
func normalizeAddrs(addrs []string) []string {
var a []string
for _, addr := range addrs {
if strings.Contains(addr, "://") {
a = append(a, addr)
} else {
a = append(a, fmt.Sprintf("http://%s", addr))
a = append(a, fmt.Sprintf("https://%s", addr))
}
// Attempt to join.
if err := j.client.Join(req, targetAddr, time.Second); err != nil {
return "", err
}
return a
return targetAddr, nil
}

@ -1,22 +1,18 @@
package cluster
import (
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/rqlite/rqlite/rtls"
)
const numAttempts int = 3
const attemptInterval = 1 * time.Second
func Test_SingleJoinOK(t *testing.T) {
func Test_SingleJoinOK(t *testing.T) { XXXXXX THIS TEST NEEDS TO BE UPDATED
var body map[string]interface{}
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
@ -82,190 +78,118 @@ func Test_SingleJoinOK(t *testing.T) {
}
}
func Test_SingleJoinHTTPSOK(t *testing.T) {
var body map[string]interface{}
ts := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
t.Fatalf("Client did not use POST")
}
w.WriteHeader(http.StatusOK)
if r.Header["Content-Type"][0] != "application/json" {
t.Fatalf("incorrect Content-Type set")
}
b, err := io.ReadAll(r.Body)
if err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
if err := json.Unmarshal(b, &body); err != nil {
w.WriteHeader(http.StatusBadRequest)
return
}
}))
defer ts.Close()
ts.TLS = &tls.Config{NextProtos: []string{"h2", "http/1.1"}}
ts.StartTLS()
tlsConfig, err := rtls.CreateClientConfig("", "", "", true)
if err != nil {
t.Fatalf("failed to create TLS config: %s", err.Error())
}
joiner := NewJoiner("127.0.0.1", numAttempts, attemptInterval, tlsConfig)
// Ensure joining with protocol prefix works.
j, err := joiner.Do([]string{ts.URL}, "id0", "127.0.0.1:9090", false)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
if j != ts.URL+"/join" {
t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts.URL)
}
if got, exp := body["id"].(string), "id0"; got != exp {
t.Fatalf("wrong node ID supplied, exp %s, got %s", exp, got)
}
if got, exp := body["addr"].(string), "127.0.0.1:9090"; got != exp {
t.Fatalf("wrong address supplied, exp %s, got %s", exp, got)
}
if got, exp := body["voter"].(bool), false; got != exp {
t.Fatalf("wrong voter state supplied, exp %v, got %v", exp, got)
}
// Ensure joining without protocol prefix works.
j, err = joiner.Do([]string{ts.Listener.Addr().String()}, "id0", "127.0.0.1:9090", false)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
if j != ts.URL+"/join" {
t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts.URL)
}
if got, exp := body["id"].(string), "id0"; got != exp {
t.Fatalf("wrong node ID supplied, exp %s, got %s", exp, got)
}
if got, exp := body["addr"].(string), "127.0.0.1:9090"; got != exp {
t.Fatalf("wrong address supplied, exp %s, got %s", exp, got)
}
if got, exp := body["voter"].(bool), false; got != exp {
t.Fatalf("wrong voter state supplied, exp %v, got %v", exp, got)
}
}
func Test_SingleJoinZeroAttempts(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Fatalf("handler should not have been called")
}))
joiner := NewJoiner("127.0.0.1", 0, attemptInterval, nil)
_, err := joiner.Do([]string{ts.URL}, "id0", "127.0.0.1:9090", false)
if err != ErrJoinFailed {
t.Fatalf("Incorrect error returned when zero attempts specified")
}
}
func Test_SingleJoinFail(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}))
defer ts.Close()
joiner := NewJoiner("", 0, attemptInterval, nil)
_, err := joiner.Do([]string{ts.URL}, "id0", "127.0.0.1:9090", true)
if err == nil {
t.Fatalf("expected error when joining bad node")
}
}
func Test_DoubleJoinOK(t *testing.T) {
ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer ts1.Close()
ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer ts2.Close()
joiner := NewJoiner("127.0.0.1", numAttempts, attemptInterval, nil)
// Ensure joining with protocol prefix works.
j, err := joiner.Do([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", true)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
if j != ts1.URL+"/join" {
t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts1.URL)
}
// Ensure joining without protocol prefix works.
j, err = joiner.Do([]string{ts1.Listener.Addr().String(), ts2.Listener.Addr().String()}, "id0", "127.0.0.1:9090", true)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
if j != ts1.URL+"/join" {
t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts1.URL)
}
}
func Test_DoubleJoinOKSecondNode(t *testing.T) {
ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}))
defer ts1.Close()
ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer ts2.Close()
joiner := NewJoiner("", numAttempts, attemptInterval, nil)
// Ensure joining with protocol prefix works.
j, err := joiner.Do([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", true)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
if j != ts2.URL+"/join" {
t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts2.URL)
}
// Ensure joining without protocol prefix works.
j, err = joiner.Do([]string{ts1.Listener.Addr().String(), ts2.Listener.Addr().String()}, "id0", "127.0.0.1:9090", true)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
if j != ts2.URL+"/join" {
t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts2.URL)
}
}
func Test_DoubleJoinOKSecondNodeRedirect(t *testing.T) {
ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
}))
defer ts1.Close()
redirectAddr := fmt.Sprintf("%s%s", ts1.URL, "/join")
ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, redirectAddr, http.StatusMovedPermanently)
}))
defer ts2.Close()
joiner := NewJoiner("127.0.0.1", numAttempts, attemptInterval, nil)
// Ensure joining with protocol prefix works.
j, err := joiner.Do([]string{ts2.URL}, "id0", "127.0.0.1:9090", true)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
if j != redirectAddr {
t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", redirectAddr, j)
}
// Ensure joining without protocol prefix works.
j, err = joiner.Do([]string{ts2.Listener.Addr().String()}, "id0", "127.0.0.1:9090", true)
if err != nil {
t.Fatalf("failed to join a single node: %s", err.Error())
}
if j != redirectAddr {
t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", redirectAddr, j)
}
}
// func Test_SingleJoinZeroAttempts(t *testing.T) {
// ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// t.Fatalf("handler should not have been called")
// }))
// joiner := NewJoiner("127.0.0.1", 0, attemptInterval, nil)
// _, err := joiner.Do([]string{ts.URL}, "id0", "127.0.0.1:9090", false)
// if err != ErrJoinFailed {
// t.Fatalf("Incorrect error returned when zero attempts specified")
// }
// }
// func Test_SingleJoinFail(t *testing.T) {
// ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// w.WriteHeader(http.StatusBadRequest)
// }))
// defer ts.Close()
// joiner := NewJoiner("", 0, attemptInterval, nil)
// _, err := joiner.Do([]string{ts.URL}, "id0", "127.0.0.1:9090", true)
// if err == nil {
// t.Fatalf("expected error when joining bad node")
// }
// }
// func Test_DoubleJoinOK(t *testing.T) {
// ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// }))
// defer ts1.Close()
// ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// }))
// defer ts2.Close()
// joiner := NewJoiner("127.0.0.1", numAttempts, attemptInterval, nil)
// // Ensure joining with protocol prefix works.
// j, err := joiner.Do([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", true)
// if err != nil {
// t.Fatalf("failed to join a single node: %s", err.Error())
// }
// if j != ts1.URL+"/join" {
// t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts1.URL)
// }
// // Ensure joining without protocol prefix works.
// j, err = joiner.Do([]string{ts1.Listener.Addr().String(), ts2.Listener.Addr().String()}, "id0", "127.0.0.1:9090", true)
// if err != nil {
// t.Fatalf("failed to join a single node: %s", err.Error())
// }
// if j != ts1.URL+"/join" {
// t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts1.URL)
// }
// }
// func Test_DoubleJoinOKSecondNode(t *testing.T) {
// ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// w.WriteHeader(http.StatusBadRequest)
// }))
// defer ts1.Close()
// ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// }))
// defer ts2.Close()
// joiner := NewJoiner("", numAttempts, attemptInterval, nil)
// // Ensure joining with protocol prefix works.
// j, err := joiner.Do([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", true)
// if err != nil {
// t.Fatalf("failed to join a single node: %s", err.Error())
// }
// if j != ts2.URL+"/join" {
// t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts2.URL)
// }
// // Ensure joining without protocol prefix works.
// j, err = joiner.Do([]string{ts1.Listener.Addr().String(), ts2.Listener.Addr().String()}, "id0", "127.0.0.1:9090", true)
// if err != nil {
// t.Fatalf("failed to join a single node: %s", err.Error())
// }
// if j != ts2.URL+"/join" {
// t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts2.URL)
// }
// }
// func Test_DoubleJoinOKSecondNodeRedirect(t *testing.T) {
// ts1 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// }))
// defer ts1.Close()
// redirectAddr := fmt.Sprintf("%s%s", ts1.URL, "/join")
// ts2 := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// http.Redirect(w, r, redirectAddr, http.StatusMovedPermanently)
// }))
// defer ts2.Close()
// joiner := NewJoiner("127.0.0.1", numAttempts, attemptInterval, nil)
// // Ensure joining with protocol prefix works.
// j, err := joiner.Do([]string{ts2.URL}, "id0", "127.0.0.1:9090", true)
// if err != nil {
// t.Fatalf("failed to join a single node: %s", err.Error())
// }
// if j != redirectAddr {
// t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", redirectAddr, j)
// }
// // Ensure joining without protocol prefix works.
// j, err = joiner.Do([]string{ts2.Listener.Addr().String()}, "id0", "127.0.0.1:9090", true)
// if err != nil {
// t.Fatalf("failed to join a single node: %s", err.Error())
// }
// if j != redirectAddr {
// t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", redirectAddr, j)
// }
// }

@ -100,9 +100,6 @@ type Config struct {
// RaftAdv is the advertised Raft server address.
RaftAdv string
// JoinSrcIP sets the source IP address during Join request. May not be set.
JoinSrcIP string
// JoinAddrs is the list of Raft addresses to use for a join attempt.
JoinAddrs string
@ -300,9 +297,6 @@ func (c *Config) Validate() error {
}
}
}
if c.JoinSrcIP != "" && net.ParseIP(c.JoinSrcIP) == nil {
return fmt.Errorf("invalid join source IP address: %s", c.JoinSrcIP)
}
// Valid disco mode?
switch c.DiscoMode {
@ -424,7 +418,6 @@ func ParseFlags(name, desc string, build *BuildInfo) (*Config, error) {
flag.StringVar(&config.AutoRestoreFile, "auto-restore", "", "Path to automatic restore configuration file. If not set, not enabled")
flag.StringVar(&config.RaftAddr, RaftAddrFlag, "localhost:4002", "Raft communication bind address")
flag.StringVar(&config.RaftAdv, RaftAdvAddrFlag, "", "Advertised Raft communication address. If not set, same as Raft bind address")
flag.StringVar(&config.JoinSrcIP, "join-source-ip", "", "Set source IP address during HTTP Join request")
flag.StringVar(&config.JoinAddrs, "join", "", "Comma-delimited list of nodes, through which a cluster can be joined (proto://host:port)")
flag.IntVar(&config.JoinAttempts, "join-attempts", 5, "Number of join attempts to make")
flag.DurationVar(&config.JoinInterval, "join-interval", 3*time.Second, "Period between join attempts")

Loading…
Cancel
Save