diff --git a/CHANGELOG.md b/CHANGELOG.md index 456f502c..1c3276a5 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,7 @@ +## 8.13.4 (unreleased) +### Implementation changes and bug fixes +- [PR #1522](https://github.com/rqlite/rqlite/pull/1522): Minor refactoring of main module. + ## 8.13.4 (December 23rd 2023) This release makes sure the version information is correctly recorded in the released binaries. There are no functional changes. ### Implementation changes and bug fixes diff --git a/cluster/client.go b/cluster/client.go index 333e9395..04c2b621 100644 --- a/cluster/client.go +++ b/cluster/client.go @@ -3,6 +3,7 @@ package cluster import ( "bytes" "compress/gzip" + "crypto/tls" "encoding/binary" "errors" "fmt" @@ -11,7 +12,10 @@ import ( "sync" "time" + "github.com/rqlite/rqlite/v8/auth" "github.com/rqlite/rqlite/v8/command" + "github.com/rqlite/rqlite/v8/rtls" + "github.com/rqlite/rqlite/v8/tcp" "github.com/rqlite/rqlite/v8/tcp/pool" "google.golang.org/protobuf/proto" ) @@ -24,6 +28,36 @@ const ( protoBufferLengthSize = 8 ) +// CreateRaftDialer creates a dialer for connecting to other nodes' Raft service. If the cert and +// key arguments are not set, then the returned dialer will not use TLS. +func CreateRaftDialer(cert, key, caCert, serverName string, Insecure bool) (*tcp.Dialer, error) { + var dialerTLSConfig *tls.Config + var err error + if cert != "" || key != "" { + dialerTLSConfig, err = rtls.CreateClientConfig(cert, key, caCert, serverName, Insecure) + if err != nil { + return nil, fmt.Errorf("failed to create TLS config for Raft dialer: %s", err.Error()) + } + } + return tcp.NewDialer(MuxRaftHeader, dialerTLSConfig), nil +} + +// CredentialsFor returns a Credentials instance for the given username, or nil if +// the given CredentialsStore is nil, or the username is not found. +func CredentialsFor(credStr *auth.CredentialsStore, username string) *Credentials { + if credStr == nil { + return nil + } + pw, ok := credStr.Password(username) + if !ok { + return nil + } + return &Credentials{ + Username: username, + Password: pw, + } +} + // Client allows communicating with a remote node. type Client struct { dialer Dialer diff --git a/cmd/rqlited/main.go b/cmd/rqlited/main.go index 009b8a38..94335d0b 100644 --- a/cmd/rqlited/main.go +++ b/cmd/rqlited/main.go @@ -95,7 +95,8 @@ func main() { // Raft internode layer raftLn := mux.Listen(cluster.MuxRaftHeader) log.Printf("Raft TCP mux Listener registered with byte header %d", cluster.MuxRaftHeader) - raftDialer, err := createRaftDialer(cfg) + raftDialer, err := cluster.CreateRaftDialer(cfg.NodeX509Cert, cfg.NodeX509Key, cfg.NodeX509CACert, + cfg.NodeVerifyServerName, cfg.NoNodeVerify) if err != nil { log.Fatalf("failed to create Raft dialer: %s", err.Error()) } @@ -202,7 +203,7 @@ func main() { if cfg.RaftClusterRemoveOnShutdown { remover := cluster.NewRemover(clstrClient, 5*time.Second, str) - remover.SetCredentials(credentialsFor(credStr, cfg.JoinAs)) + remover.SetCredentials(cluster.CredentialsFor(credStr, cfg.JoinAs)) log.Printf("initiating removal of this node from cluster before shutdown") if err := remover.Do(cfg.NodeID, true); err != nil { log.Fatalf("failed to remove this node from cluster before shutdown: %s", err.Error()) @@ -459,22 +460,9 @@ func createClusterClient(cfg *Config, clstr *cluster.Service) (*cluster.Client, return clstrClient, nil } -func createRaftDialer(cfg *Config) (*tcp.Dialer, error) { - var dialerTLSConfig *tls.Config - var err error - if cfg.NodeX509Cert != "" || cfg.NodeX509CACert != "" { - dialerTLSConfig, err = rtls.CreateClientConfig(cfg.NodeX509Cert, cfg.NodeX509Key, - cfg.NodeX509CACert, cfg.NodeVerifyServerName, cfg.NoNodeVerify) - if err != nil { - return nil, fmt.Errorf("failed to create TLS config for Raft dialer: %s", err.Error()) - } - } - return tcp.NewDialer(cluster.MuxRaftHeader, dialerTLSConfig), nil -} - func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *store.Store, httpServ *httpd.Service, credStr *auth.CredentialsStore) error { joins := cfg.JoinAddresses() - if err := networkCheckJoinAddrs(cfg, joins); err != nil { + if err := networkCheckJoinAddrs(joins); err != nil { return err } @@ -498,7 +486,7 @@ func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *stor } joiner := cluster.NewJoiner(client, cfg.JoinAttempts, cfg.JoinInterval) - joiner.SetCredentials(credentialsFor(credStr, cfg.JoinAs)) + joiner.SetCredentials(cluster.CredentialsFor(credStr, cfg.JoinAs)) if joins != nil && cfg.BootstrapExpect == 0 { // Explicit join operation requested, so do it. j, err := joiner.Do(joins, str.ID(), cfg.RaftAdv, !cfg.RaftNonVoter) @@ -512,7 +500,7 @@ func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *stor if joins != nil && cfg.BootstrapExpect > 0 { // Bootstrap with explicit join addresses requests. bs := cluster.NewBootstrapper(cluster.NewAddressProviderString(joins), client) - bs.SetCredentials(credentialsFor(credStr, cfg.JoinAs)) + bs.SetCredentials(cluster.CredentialsFor(credStr, cfg.JoinAs)) return bs.Boot(str.ID(), cfg.RaftAdv, isClustered, cfg.BootstrapExpectTimeout) } @@ -555,7 +543,7 @@ func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *stor } bs := cluster.NewBootstrapper(provider, client) - bs.SetCredentials(credentialsFor(credStr, cfg.JoinAs)) + bs.SetCredentials(cluster.CredentialsFor(credStr, cfg.JoinAs)) httpServ.RegisterStatus("disco", provider) return bs.Boot(str.ID(), cfg.RaftAdv, isClustered, cfg.BootstrapExpectTimeout) @@ -610,29 +598,10 @@ func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *stor return nil } -func networkCheckJoinAddrs(cfg *Config, joinAddrs []string) error { - if len(joinAddrs) == 0 { - return nil - } +func networkCheckJoinAddrs(joinAddrs []string) error { log.Println("checking that join addresses don't serve HTTP(S)") - for _, addr := range joinAddrs { - if http.IsServingHTTP(addr) { - return fmt.Errorf("join address %s appears to be serving HTTP when it should be Raft", addr) - } + if addr, ok := http.AnyServingHTTP(joinAddrs); ok { + return fmt.Errorf("join address %s appears to be serving HTTP when it should be Raft", addr) } return nil } - -func credentialsFor(credStr *auth.CredentialsStore, username string) *cluster.Credentials { - if credStr == nil { - return nil - } - pw, ok := credStr.Password(username) - if !ok { - return nil - } - return &cluster.Credentials{ - Username: username, - Password: pw, - } -} diff --git a/http/preflight.go b/http/preflight.go index d4d41a1f..462efdca 100644 --- a/http/preflight.go +++ b/http/preflight.go @@ -6,6 +6,17 @@ import ( "time" ) +// AnyServingHTTP returns the first address in the list that appears to be +// serving HTTP or HTTPS, or false if none of them are. +func AnyServingHTTP(addrs []string) (string, bool) { + for _, addr := range addrs { + if IsServingHTTP(addr) { + return addr, true + } + } + return "", false +} + // IsServingHTTP returns true if there appears to be a HTTP or HTTPS server // running on the given address. func IsServingHTTP(addr string) bool { diff --git a/http/preflight_test.go b/http/preflight_test.go index 2a4baf71..4fa54675 100644 --- a/http/preflight_test.go +++ b/http/preflight_test.go @@ -21,7 +21,10 @@ func Test_IsServingHTTP_HTTPServer(t *testing.T) { addr := httpServer.Listener.Addr().String() if !IsServingHTTP(addr) { - t.Errorf("Expected true for HTTP server running on %s", addr) + t.Fatalf("Expected true for HTTP server running on %s", addr) + } + if a, ok := AnyServingHTTP([]string{addr}); !ok || a != addr { + t.Fatalf("Expected %s for AnyServingHTTP", addr) } } @@ -36,6 +39,9 @@ func Test_IsServingHTTP_HTTPSServer(t *testing.T) { if !IsServingHTTP(addr) { t.Error("Expected true for HTTPS server running") } + if a, ok := AnyServingHTTP([]string{addr}); !ok || a != addr { + t.Fatalf("Expected %s for AnyServingHTTP", addr) + } } // Test_IsServingHTTP_NoServersRunning tests no servers running. @@ -44,6 +50,9 @@ func Test_IsServingHTTP_NoServersRunning(t *testing.T) { if IsServingHTTP(addr) { t.Error("Expected false for no servers running") } + if _, ok := AnyServingHTTP([]string{addr}); ok { + t.Error("Expected false for no servers running") + } } // Test_IsServingHTTP_InvalidAddress tests invalid address format. @@ -52,6 +61,9 @@ func Test_IsServingHTTP_InvalidAddress(t *testing.T) { if IsServingHTTP(addr) { t.Error("Expected false for invalid address") } + if _, ok := AnyServingHTTP([]string{addr}); ok { + t.Error("Expected false for invalid address") + } } // Test_IsServingHTTP_HTTPErrorStatusCode tests HTTP server returning error status code. @@ -65,6 +77,9 @@ func Test_IsServingHTTP_HTTPErrorStatusCode(t *testing.T) { if !IsServingHTTP(addr) { t.Error("Expected true for HTTP server running, even with error status code") } + if a, ok := AnyServingHTTP([]string{addr}); !ok || a != addr { + t.Fatalf("Expected %s for AnyServingHTTP, even with error status code", addr) + } } // Test_IsServingHTTP_HTTPSSuccessStatusCode tests HTTPS server running with success status code. @@ -78,6 +93,9 @@ func Test_IsServingHTTP_HTTPSSuccessStatusCode(t *testing.T) { if !IsServingHTTP(addr) { t.Error("Expected true for HTTPS server running with success status code") } + if a, ok := AnyServingHTTP([]string{addr}); !ok || a != addr { + t.Fatalf("Expected %s for AnyServingHTTP with success status code", addr) + } } func Test_IsServingHTTP_OpenPort(t *testing.T) { @@ -92,6 +110,9 @@ func Test_IsServingHTTP_OpenPort(t *testing.T) { if IsServingHTTP(addr) { t.Error("Expected false for open port") } + if _, ok := AnyServingHTTP([]string{addr}); ok { + t.Error("Expected false for open port") + } } func Test_IsServingHTTP_OpenPortTLS(t *testing.T) { @@ -116,4 +137,26 @@ func Test_IsServingHTTP_OpenPortTLS(t *testing.T) { if IsServingHTTP(addr) { t.Error("Expected false for open TLS port") } + if _, ok := AnyServingHTTP([]string{addr}); ok { + t.Error("Expected false for open TLS port") + } +} + +func Test_IsServingHTTP_HTTPServerTCPPort(t *testing.T) { + httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + })) + defer httpServer.Close() + + ln, err := net.Listen("tcp", ":0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + + httpAddr := httpServer.Listener.Addr().String() + tcpAddr := ln.Addr().String() + if a, ok := AnyServingHTTP([]string{httpAddr, tcpAddr}); !ok || a != httpAddr { + t.Fatalf("Expected %s for AnyServingHTTP", httpAddr) + } }