diff --git a/cluster/join.go b/cluster/join.go index f3f53ceb..ef233d57 100644 --- a/cluster/join.go +++ b/cluster/join.go @@ -13,8 +13,6 @@ import ( "os" "strings" "time" - - rurl "github.com/rqlite/rqlite/http/url" ) var ( @@ -47,8 +45,10 @@ type Joiner struct { } // NewJoiner returns an instantiated Joiner. -func NewJoiner(srcIP string, numAttempts int, attemptInterval time.Duration, - tlsCfg *tls.Config) *Joiner { +func NewJoiner(srcIP string, numAttempts int, attemptInterval time.Duration, tlsCfg *tls.Config) *Joiner { + if tlsCfg == nil { + tlsCfg = &tls.Config{InsecureSkipVerify: true} + } // Source IP is optional dialer := &net.Dialer{} @@ -67,9 +67,6 @@ func NewJoiner(srcIP string, numAttempts int, attemptInterval time.Duration, tlsConfig: tlsCfg, logger: log.New(os.Stderr, "[cluster-join] ", log.LstdFlags), } - if joiner.tlsConfig == nil { - joiner.tlsConfig = &tls.Config{InsecureSkipVerify: true} - } // Create and configure the client to connect to the other node. tr := &http.Transport{ @@ -90,7 +87,10 @@ func (j *Joiner) SetBasicAuth(username, password string) { j.username, j.password = username, password } -// Do makes the actual join request. +// Do makes the actual join request. If any of the join addresses do not contain a +// protocol, both http:// and https:// are tried for that address. If the join is successful +// with any address, the Join URL of the node that joined is returned. Otherwise, an error +// is returned. func (j *Joiner) Do(joinAddrs []string, id, addr string, voter bool) (string, error) { if id == "" { return "", ErrNodeIDRequired @@ -99,7 +99,7 @@ func (j *Joiner) Do(joinAddrs []string, id, addr string, voter bool) (string, er var err error var joinee string for i := 0; i < j.numAttempts; i++ { - for _, a := range joinAddrs { + for _, a := range normalizeAddrs(joinAddrs) { joinee, err = j.join(a, id, addr, voter) if err == nil { // Success! @@ -119,35 +119,44 @@ func (j *Joiner) Do(joinAddrs []string, id, addr string, voter bool) (string, er } func (j *Joiner) join(joinAddr, id, addr string, voter bool) (string, error) { - // Check for protocol scheme, and insert default if necessary. - fullAddr := rurl.NormalizeAddr(fmt.Sprintf("%s/join", joinAddr)) + fullAddr := fmt.Sprintf("%s/join", joinAddr) + reqBody, err := json.Marshal(map[string]interface{}{ + "id": id, + "addr": addr, + "voter": voter, + }) + if err != nil { + return "", err + } for { - b, err := json.Marshal(map[string]interface{}{ - "id": id, - "addr": addr, - "voter": voter, - }) - if err != nil { - return "", err - } - // Attempt to join. - req, err := http.NewRequest("POST", fullAddr, bytes.NewReader(b)) + req, err := http.NewRequest("POST", fullAddr, bytes.NewReader(reqBody)) if err != nil { return "", err } if j.username != "" && j.password != "" { req.SetBasicAuth(j.username, j.password) } - req.Header.Add("Content-Type", "application/json") - resp, err := j.client.Do(req) - if err != nil { - return "", err - } - defer resp.Body.Close() - b, err = io.ReadAll(resp.Body) + var resp *http.Response + var respB []byte + err = func() error { + req.Header.Add("Content-Type", "application/json") + resp, err = j.client.Do(req) + if err != nil { + return err + } + defer resp.Body.Close() + + // Only significant in the event of an error response + // from the remote node. + respB, err = io.ReadAll(resp.Body) + if err != nil { + return err + } + return nil + }() if err != nil { return "", err } @@ -161,21 +170,21 @@ func (j *Joiner) join(joinAddr, id, addr string, voter bool) (string, error) { return "", ErrInvalidRedirect } continue - case http.StatusBadRequest: - // One possible cause is that the target server is listening for HTTPS, but an HTTP - // attempt was made. Switch the protocol to HTTPS, and try again. This can happen - // when using the Disco service, since it doesn't record information about which - // protocol a registered node is actually using. - if strings.HasPrefix(fullAddr, "https://") { - // It's already HTTPS, give up. - return "", fmt.Errorf("%s: (%s)", resp.Status, string(b)) - } - - j.logger.Print("join via HTTP failed, trying via HTTPS") - fullAddr = rurl.EnsureHTTPS(fullAddr) - continue default: - return "", fmt.Errorf("%s: (%s)", resp.Status, string(b)) + return "", fmt.Errorf("%s: (%s)", resp.Status, string(respB)) + } + } +} + +func normalizeAddrs(addrs []string) []string { + var a []string + for _, addr := range addrs { + if strings.Contains(addr, "://") { + a = append(a, addr) + } else { + a = append(a, fmt.Sprintf("http://%s", addr)) + a = append(a, fmt.Sprintf("https://%s", addr)) } } + return a } diff --git a/cluster/join_test.go b/cluster/join_test.go index 3fe4f92a..cdb46c18 100644 --- a/cluster/join_test.go +++ b/cluster/join_test.go @@ -43,6 +43,7 @@ func Test_SingleJoinOK(t *testing.T) { joiner := NewJoiner("127.0.0.1", numAttempts, attemptInterval, nil) + // Ensure joining with protocol prefix works. j, err := joiner.Do([]string{ts.URL}, "id0", "127.0.0.1:9090", false) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) @@ -60,6 +61,25 @@ func Test_SingleJoinOK(t *testing.T) { if got, exp := body["voter"].(bool), false; got != exp { t.Fatalf("wrong voter state supplied, exp %v, got %v", exp, got) } + + // Ensure joining without protocol prefix works. + j, err = joiner.Do([]string{ts.Listener.Addr().String()}, "id0", "127.0.0.1:9090", false) + if err != nil { + t.Fatalf("failed to join a single node: %s", err.Error()) + } + if j != ts.URL+"/join" { + t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts.URL) + } + + if got, exp := body["id"].(string), "id0"; got != exp { + t.Fatalf("wrong node ID supplied, exp %s, got %s", exp, got) + } + if got, exp := body["addr"].(string), "127.0.0.1:9090"; got != exp { + t.Fatalf("wrong address supplied, exp %s, got %s", exp, got) + } + if got, exp := body["voter"].(bool), false; got != exp { + t.Fatalf("wrong voter state supplied, exp %v, got %v", exp, got) + } } func Test_SingleJoinHTTPSOK(t *testing.T) { @@ -95,6 +115,7 @@ func Test_SingleJoinHTTPSOK(t *testing.T) { } joiner := NewJoiner("127.0.0.1", numAttempts, attemptInterval, tlsConfig) + // Ensure joining with protocol prefix works. j, err := joiner.Do([]string{ts.URL}, "id0", "127.0.0.1:9090", false) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) @@ -112,6 +133,25 @@ func Test_SingleJoinHTTPSOK(t *testing.T) { if got, exp := body["voter"].(bool), false; got != exp { t.Fatalf("wrong voter state supplied, exp %v, got %v", exp, got) } + + // Ensure joining without protocol prefix works. + j, err = joiner.Do([]string{ts.Listener.Addr().String()}, "id0", "127.0.0.1:9090", false) + if err != nil { + t.Fatalf("failed to join a single node: %s", err.Error()) + } + if j != ts.URL+"/join" { + t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts.URL) + } + + if got, exp := body["id"].(string), "id0"; got != exp { + t.Fatalf("wrong node ID supplied, exp %s, got %s", exp, got) + } + if got, exp := body["addr"].(string), "127.0.0.1:9090"; got != exp { + t.Fatalf("wrong address supplied, exp %s, got %s", exp, got) + } + if got, exp := body["voter"].(bool), false; got != exp { + t.Fatalf("wrong voter state supplied, exp %v, got %v", exp, got) + } } func Test_SingleJoinOKBasicAuth(t *testing.T) { @@ -200,6 +240,7 @@ func Test_DoubleJoinOK(t *testing.T) { joiner := NewJoiner("127.0.0.1", numAttempts, attemptInterval, nil) + // Ensure joining with protocol prefix works. j, err := joiner.Do([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", true) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) @@ -207,6 +248,15 @@ func Test_DoubleJoinOK(t *testing.T) { if j != ts1.URL+"/join" { t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts1.URL) } + + // Ensure joining without protocol prefix works. + j, err = joiner.Do([]string{ts1.Listener.Addr().String(), ts2.Listener.Addr().String()}, "id0", "127.0.0.1:9090", true) + if err != nil { + t.Fatalf("failed to join a single node: %s", err.Error()) + } + if j != ts1.URL+"/join" { + t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts1.URL) + } } func Test_DoubleJoinOKSecondNode(t *testing.T) { @@ -220,6 +270,7 @@ func Test_DoubleJoinOKSecondNode(t *testing.T) { joiner := NewJoiner("", numAttempts, attemptInterval, nil) + // Ensure joining with protocol prefix works. j, err := joiner.Do([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", true) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) @@ -227,6 +278,15 @@ func Test_DoubleJoinOKSecondNode(t *testing.T) { if j != ts2.URL+"/join" { t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts2.URL) } + + // Ensure joining without protocol prefix works. + j, err = joiner.Do([]string{ts1.Listener.Addr().String(), ts2.Listener.Addr().String()}, "id0", "127.0.0.1:9090", true) + if err != nil { + t.Fatalf("failed to join a single node: %s", err.Error()) + } + if j != ts2.URL+"/join" { + t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", j, ts2.URL) + } } func Test_DoubleJoinOKSecondNodeRedirect(t *testing.T) { @@ -242,6 +302,7 @@ func Test_DoubleJoinOKSecondNodeRedirect(t *testing.T) { joiner := NewJoiner("127.0.0.1", numAttempts, attemptInterval, nil) + // Ensure joining with protocol prefix works. j, err := joiner.Do([]string{ts2.URL}, "id0", "127.0.0.1:9090", true) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) @@ -249,4 +310,13 @@ func Test_DoubleJoinOKSecondNodeRedirect(t *testing.T) { if j != redirectAddr { t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", redirectAddr, j) } + + // Ensure joining without protocol prefix works. + j, err = joiner.Do([]string{ts2.Listener.Addr().String()}, "id0", "127.0.0.1:9090", true) + if err != nil { + t.Fatalf("failed to join a single node: %s", err.Error()) + } + if j != redirectAddr { + t.Fatalf("node joined using wrong endpoint, exp: %s, got: %s", redirectAddr, j) + } }