1
0
Fork 0

Refactor some code out of main

master
Philip O'Toole 9 months ago
parent 6d3f0dafdf
commit 65ab53f23f

@ -3,6 +3,7 @@ package cluster
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"crypto/tls"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -11,7 +12,10 @@ import (
"sync" "sync"
"time" "time"
"github.com/rqlite/rqlite/v8/auth"
"github.com/rqlite/rqlite/v8/command" "github.com/rqlite/rqlite/v8/command"
"github.com/rqlite/rqlite/v8/rtls"
"github.com/rqlite/rqlite/v8/tcp"
"github.com/rqlite/rqlite/v8/tcp/pool" "github.com/rqlite/rqlite/v8/tcp/pool"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
) )
@ -24,6 +28,36 @@ const (
protoBufferLengthSize = 8 protoBufferLengthSize = 8
) )
// CreateRaftDialer creates a dialer for connecting to other nodes' Raft service. If the cert and
// key arguments are not set, then the returned dialer will not use TLS.
func CreateRaftDialer(cert, key, caCert, serverName string, Insecure bool) (*tcp.Dialer, error) {
var dialerTLSConfig *tls.Config
var err error
if cert != "" || key != "" {
dialerTLSConfig, err = rtls.CreateClientConfig(cert, key, caCert, serverName, Insecure)
if err != nil {
return nil, fmt.Errorf("failed to create TLS config for Raft dialer: %s", err.Error())
}
}
return tcp.NewDialer(MuxRaftHeader, dialerTLSConfig), nil
}
// CredentialsFor returns a Credentials instance for the given username, or nil if
// the given CredentialsStore is nil, or the username is not found.
func CredentialsFor(credStr *auth.CredentialsStore, username string) *Credentials {
if credStr == nil {
return nil
}
pw, ok := credStr.Password(username)
if !ok {
return nil
}
return &Credentials{
Username: username,
Password: pw,
}
}
// Client allows communicating with a remote node. // Client allows communicating with a remote node.
type Client struct { type Client struct {
dialer Dialer dialer Dialer

@ -95,7 +95,8 @@ func main() {
// Raft internode layer // Raft internode layer
raftLn := mux.Listen(cluster.MuxRaftHeader) raftLn := mux.Listen(cluster.MuxRaftHeader)
log.Printf("Raft TCP mux Listener registered with byte header %d", cluster.MuxRaftHeader) log.Printf("Raft TCP mux Listener registered with byte header %d", cluster.MuxRaftHeader)
raftDialer, err := createRaftDialer(cfg) raftDialer, err := cluster.CreateRaftDialer(cfg.NodeX509Cert, cfg.NodeX509Key, cfg.NodeX509CACert,
cfg.NodeVerifyServerName, cfg.NoNodeVerify)
if err != nil { if err != nil {
log.Fatalf("failed to create Raft dialer: %s", err.Error()) log.Fatalf("failed to create Raft dialer: %s", err.Error())
} }
@ -202,7 +203,7 @@ func main() {
if cfg.RaftClusterRemoveOnShutdown { if cfg.RaftClusterRemoveOnShutdown {
remover := cluster.NewRemover(clstrClient, 5*time.Second, str) remover := cluster.NewRemover(clstrClient, 5*time.Second, str)
remover.SetCredentials(credentialsFor(credStr, cfg.JoinAs)) remover.SetCredentials(cluster.CredentialsFor(credStr, cfg.JoinAs))
log.Printf("initiating removal of this node from cluster before shutdown") log.Printf("initiating removal of this node from cluster before shutdown")
if err := remover.Do(cfg.NodeID, true); err != nil { if err := remover.Do(cfg.NodeID, true); err != nil {
log.Fatalf("failed to remove this node from cluster before shutdown: %s", err.Error()) log.Fatalf("failed to remove this node from cluster before shutdown: %s", err.Error())
@ -459,22 +460,9 @@ func createClusterClient(cfg *Config, clstr *cluster.Service) (*cluster.Client,
return clstrClient, nil return clstrClient, nil
} }
func createRaftDialer(cfg *Config) (*tcp.Dialer, error) {
var dialerTLSConfig *tls.Config
var err error
if cfg.NodeX509Cert != "" || cfg.NodeX509CACert != "" {
dialerTLSConfig, err = rtls.CreateClientConfig(cfg.NodeX509Cert, cfg.NodeX509Key,
cfg.NodeX509CACert, cfg.NodeVerifyServerName, cfg.NoNodeVerify)
if err != nil {
return nil, fmt.Errorf("failed to create TLS config for Raft dialer: %s", err.Error())
}
}
return tcp.NewDialer(cluster.MuxRaftHeader, dialerTLSConfig), nil
}
func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *store.Store, httpServ *httpd.Service, credStr *auth.CredentialsStore) error { func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *store.Store, httpServ *httpd.Service, credStr *auth.CredentialsStore) error {
joins := cfg.JoinAddresses() joins := cfg.JoinAddresses()
if err := networkCheckJoinAddrs(cfg, joins); err != nil { if err := networkCheckJoinAddrs(joins); err != nil {
return err return err
} }
@ -498,7 +486,7 @@ func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *stor
} }
joiner := cluster.NewJoiner(client, cfg.JoinAttempts, cfg.JoinInterval) joiner := cluster.NewJoiner(client, cfg.JoinAttempts, cfg.JoinInterval)
joiner.SetCredentials(credentialsFor(credStr, cfg.JoinAs)) joiner.SetCredentials(cluster.CredentialsFor(credStr, cfg.JoinAs))
if joins != nil && cfg.BootstrapExpect == 0 { if joins != nil && cfg.BootstrapExpect == 0 {
// Explicit join operation requested, so do it. // Explicit join operation requested, so do it.
j, err := joiner.Do(joins, str.ID(), cfg.RaftAdv, !cfg.RaftNonVoter) j, err := joiner.Do(joins, str.ID(), cfg.RaftAdv, !cfg.RaftNonVoter)
@ -512,7 +500,7 @@ func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *stor
if joins != nil && cfg.BootstrapExpect > 0 { if joins != nil && cfg.BootstrapExpect > 0 {
// Bootstrap with explicit join addresses requests. // Bootstrap with explicit join addresses requests.
bs := cluster.NewBootstrapper(cluster.NewAddressProviderString(joins), client) bs := cluster.NewBootstrapper(cluster.NewAddressProviderString(joins), client)
bs.SetCredentials(credentialsFor(credStr, cfg.JoinAs)) bs.SetCredentials(cluster.CredentialsFor(credStr, cfg.JoinAs))
return bs.Boot(str.ID(), cfg.RaftAdv, isClustered, cfg.BootstrapExpectTimeout) return bs.Boot(str.ID(), cfg.RaftAdv, isClustered, cfg.BootstrapExpectTimeout)
} }
@ -555,7 +543,7 @@ func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *stor
} }
bs := cluster.NewBootstrapper(provider, client) bs := cluster.NewBootstrapper(provider, client)
bs.SetCredentials(credentialsFor(credStr, cfg.JoinAs)) bs.SetCredentials(cluster.CredentialsFor(credStr, cfg.JoinAs))
httpServ.RegisterStatus("disco", provider) httpServ.RegisterStatus("disco", provider)
return bs.Boot(str.ID(), cfg.RaftAdv, isClustered, cfg.BootstrapExpectTimeout) return bs.Boot(str.ID(), cfg.RaftAdv, isClustered, cfg.BootstrapExpectTimeout)
@ -610,29 +598,10 @@ func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *stor
return nil return nil
} }
func networkCheckJoinAddrs(cfg *Config, joinAddrs []string) error { func networkCheckJoinAddrs(joinAddrs []string) error {
if len(joinAddrs) == 0 {
return nil
}
log.Println("checking that join addresses don't serve HTTP(S)") log.Println("checking that join addresses don't serve HTTP(S)")
for _, addr := range joinAddrs { if http.AnyServingHTTP(joinAddrs) {
if http.IsServingHTTP(addr) { return fmt.Errorf("join address appears to be serving HTTP when it should be Raft")
return fmt.Errorf("join address %s appears to be serving HTTP when it should be Raft", addr)
}
} }
return nil return nil
} }
func credentialsFor(credStr *auth.CredentialsStore, username string) *cluster.Credentials {
if credStr == nil {
return nil
}
pw, ok := credStr.Password(username)
if !ok {
return nil
}
return &cluster.Credentials{
Username: username,
Password: pw,
}
}

@ -6,6 +6,17 @@ import (
"time" "time"
) )
// AnyServingHTTP returns true if there appears to be a HTTP or HTTPS server
// running on any of the given addresses.
func AnyServingHTTP(addrs []string) bool {
for _, addr := range addrs {
if IsServingHTTP(addr) {
return true
}
}
return false
}
// IsServingHTTP returns true if there appears to be a HTTP or HTTPS server // IsServingHTTP returns true if there appears to be a HTTP or HTTPS server
// running on the given address. // running on the given address.
func IsServingHTTP(addr string) bool { func IsServingHTTP(addr string) bool {

@ -21,7 +21,10 @@ func Test_IsServingHTTP_HTTPServer(t *testing.T) {
addr := httpServer.Listener.Addr().String() addr := httpServer.Listener.Addr().String()
if !IsServingHTTP(addr) { if !IsServingHTTP(addr) {
t.Errorf("Expected true for HTTP server running on %s", addr) t.Fatalf("Expected true for HTTP server running on %s", addr)
}
if !AnyServingHTTP([]string{addr}) {
t.Fatalf("Expected true for HTTP server running on %s", addr)
} }
} }
@ -36,6 +39,9 @@ func Test_IsServingHTTP_HTTPSServer(t *testing.T) {
if !IsServingHTTP(addr) { if !IsServingHTTP(addr) {
t.Error("Expected true for HTTPS server running") t.Error("Expected true for HTTPS server running")
} }
if !AnyServingHTTP([]string{addr}) {
t.Fatalf("Expected true for HTTPS server running")
}
} }
// Test_IsServingHTTP_NoServersRunning tests no servers running. // Test_IsServingHTTP_NoServersRunning tests no servers running.
@ -44,6 +50,9 @@ func Test_IsServingHTTP_NoServersRunning(t *testing.T) {
if IsServingHTTP(addr) { if IsServingHTTP(addr) {
t.Error("Expected false for no servers running") t.Error("Expected false for no servers running")
} }
if AnyServingHTTP([]string{addr}) {
t.Error("Expected false for no servers running")
}
} }
// Test_IsServingHTTP_InvalidAddress tests invalid address format. // Test_IsServingHTTP_InvalidAddress tests invalid address format.
@ -52,6 +61,9 @@ func Test_IsServingHTTP_InvalidAddress(t *testing.T) {
if IsServingHTTP(addr) { if IsServingHTTP(addr) {
t.Error("Expected false for invalid address") t.Error("Expected false for invalid address")
} }
if AnyServingHTTP([]string{addr}) {
t.Error("Expected false for invalid address")
}
} }
// Test_IsServingHTTP_HTTPErrorStatusCode tests HTTP server returning error status code. // Test_IsServingHTTP_HTTPErrorStatusCode tests HTTP server returning error status code.
@ -65,6 +77,9 @@ func Test_IsServingHTTP_HTTPErrorStatusCode(t *testing.T) {
if !IsServingHTTP(addr) { if !IsServingHTTP(addr) {
t.Error("Expected true for HTTP server running, even with error status code") t.Error("Expected true for HTTP server running, even with error status code")
} }
if !AnyServingHTTP([]string{addr}) {
t.Error("Expected true for HTTP server running, even with error status code")
}
} }
// Test_IsServingHTTP_HTTPSSuccessStatusCode tests HTTPS server running with success status code. // Test_IsServingHTTP_HTTPSSuccessStatusCode tests HTTPS server running with success status code.
@ -78,6 +93,9 @@ func Test_IsServingHTTP_HTTPSSuccessStatusCode(t *testing.T) {
if !IsServingHTTP(addr) { if !IsServingHTTP(addr) {
t.Error("Expected true for HTTPS server running with success status code") t.Error("Expected true for HTTPS server running with success status code")
} }
if !AnyServingHTTP([]string{addr}) {
t.Error("Expected true for HTTPS server running with success status code")
}
} }
func Test_IsServingHTTP_OpenPort(t *testing.T) { func Test_IsServingHTTP_OpenPort(t *testing.T) {
@ -92,6 +110,9 @@ func Test_IsServingHTTP_OpenPort(t *testing.T) {
if IsServingHTTP(addr) { if IsServingHTTP(addr) {
t.Error("Expected false for open port") t.Error("Expected false for open port")
} }
if AnyServingHTTP([]string{addr}) {
t.Error("Expected false for open port")
}
} }
func Test_IsServingHTTP_OpenPortTLS(t *testing.T) { func Test_IsServingHTTP_OpenPortTLS(t *testing.T) {
@ -116,4 +137,7 @@ func Test_IsServingHTTP_OpenPortTLS(t *testing.T) {
if IsServingHTTP(addr) { if IsServingHTTP(addr) {
t.Error("Expected false for open TLS port") t.Error("Expected false for open TLS port")
} }
if AnyServingHTTP([]string{addr}) {
t.Error("Expected false for open TLS port")
}
} }

Loading…
Cancel
Save