1
0
Fork 0

Merge pull request #1522 from rqlite/main-refactor-1

Refactor some code out of main
master
Philip O'Toole 9 months ago committed by GitHub
commit 859f2bd535
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,3 +1,7 @@
## 8.13.4 (unreleased)
### Implementation changes and bug fixes
- [PR #1522](https://github.com/rqlite/rqlite/pull/1522): Minor refactoring of main module.
## 8.13.4 (December 23rd 2023) ## 8.13.4 (December 23rd 2023)
This release makes sure the version information is correctly recorded in the released binaries. There are no functional changes. This release makes sure the version information is correctly recorded in the released binaries. There are no functional changes.
### Implementation changes and bug fixes ### Implementation changes and bug fixes

@ -3,6 +3,7 @@ package cluster
import ( import (
"bytes" "bytes"
"compress/gzip" "compress/gzip"
"crypto/tls"
"encoding/binary" "encoding/binary"
"errors" "errors"
"fmt" "fmt"
@ -11,7 +12,10 @@ import (
"sync" "sync"
"time" "time"
"github.com/rqlite/rqlite/v8/auth"
"github.com/rqlite/rqlite/v8/command" "github.com/rqlite/rqlite/v8/command"
"github.com/rqlite/rqlite/v8/rtls"
"github.com/rqlite/rqlite/v8/tcp"
"github.com/rqlite/rqlite/v8/tcp/pool" "github.com/rqlite/rqlite/v8/tcp/pool"
"google.golang.org/protobuf/proto" "google.golang.org/protobuf/proto"
) )
@ -24,6 +28,36 @@ const (
protoBufferLengthSize = 8 protoBufferLengthSize = 8
) )
// CreateRaftDialer creates a dialer for connecting to other nodes' Raft service. If the cert and
// key arguments are not set, then the returned dialer will not use TLS.
func CreateRaftDialer(cert, key, caCert, serverName string, Insecure bool) (*tcp.Dialer, error) {
var dialerTLSConfig *tls.Config
var err error
if cert != "" || key != "" {
dialerTLSConfig, err = rtls.CreateClientConfig(cert, key, caCert, serverName, Insecure)
if err != nil {
return nil, fmt.Errorf("failed to create TLS config for Raft dialer: %s", err.Error())
}
}
return tcp.NewDialer(MuxRaftHeader, dialerTLSConfig), nil
}
// CredentialsFor returns a Credentials instance for the given username, or nil if
// the given CredentialsStore is nil, or the username is not found.
func CredentialsFor(credStr *auth.CredentialsStore, username string) *Credentials {
if credStr == nil {
return nil
}
pw, ok := credStr.Password(username)
if !ok {
return nil
}
return &Credentials{
Username: username,
Password: pw,
}
}
// Client allows communicating with a remote node. // Client allows communicating with a remote node.
type Client struct { type Client struct {
dialer Dialer dialer Dialer

@ -95,7 +95,8 @@ func main() {
// Raft internode layer // Raft internode layer
raftLn := mux.Listen(cluster.MuxRaftHeader) raftLn := mux.Listen(cluster.MuxRaftHeader)
log.Printf("Raft TCP mux Listener registered with byte header %d", cluster.MuxRaftHeader) log.Printf("Raft TCP mux Listener registered with byte header %d", cluster.MuxRaftHeader)
raftDialer, err := createRaftDialer(cfg) raftDialer, err := cluster.CreateRaftDialer(cfg.NodeX509Cert, cfg.NodeX509Key, cfg.NodeX509CACert,
cfg.NodeVerifyServerName, cfg.NoNodeVerify)
if err != nil { if err != nil {
log.Fatalf("failed to create Raft dialer: %s", err.Error()) log.Fatalf("failed to create Raft dialer: %s", err.Error())
} }
@ -202,7 +203,7 @@ func main() {
if cfg.RaftClusterRemoveOnShutdown { if cfg.RaftClusterRemoveOnShutdown {
remover := cluster.NewRemover(clstrClient, 5*time.Second, str) remover := cluster.NewRemover(clstrClient, 5*time.Second, str)
remover.SetCredentials(credentialsFor(credStr, cfg.JoinAs)) remover.SetCredentials(cluster.CredentialsFor(credStr, cfg.JoinAs))
log.Printf("initiating removal of this node from cluster before shutdown") log.Printf("initiating removal of this node from cluster before shutdown")
if err := remover.Do(cfg.NodeID, true); err != nil { if err := remover.Do(cfg.NodeID, true); err != nil {
log.Fatalf("failed to remove this node from cluster before shutdown: %s", err.Error()) log.Fatalf("failed to remove this node from cluster before shutdown: %s", err.Error())
@ -459,22 +460,9 @@ func createClusterClient(cfg *Config, clstr *cluster.Service) (*cluster.Client,
return clstrClient, nil return clstrClient, nil
} }
func createRaftDialer(cfg *Config) (*tcp.Dialer, error) {
var dialerTLSConfig *tls.Config
var err error
if cfg.NodeX509Cert != "" || cfg.NodeX509CACert != "" {
dialerTLSConfig, err = rtls.CreateClientConfig(cfg.NodeX509Cert, cfg.NodeX509Key,
cfg.NodeX509CACert, cfg.NodeVerifyServerName, cfg.NoNodeVerify)
if err != nil {
return nil, fmt.Errorf("failed to create TLS config for Raft dialer: %s", err.Error())
}
}
return tcp.NewDialer(cluster.MuxRaftHeader, dialerTLSConfig), nil
}
func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *store.Store, httpServ *httpd.Service, credStr *auth.CredentialsStore) error { func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *store.Store, httpServ *httpd.Service, credStr *auth.CredentialsStore) error {
joins := cfg.JoinAddresses() joins := cfg.JoinAddresses()
if err := networkCheckJoinAddrs(cfg, joins); err != nil { if err := networkCheckJoinAddrs(joins); err != nil {
return err return err
} }
@ -498,7 +486,7 @@ func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *stor
} }
joiner := cluster.NewJoiner(client, cfg.JoinAttempts, cfg.JoinInterval) joiner := cluster.NewJoiner(client, cfg.JoinAttempts, cfg.JoinInterval)
joiner.SetCredentials(credentialsFor(credStr, cfg.JoinAs)) joiner.SetCredentials(cluster.CredentialsFor(credStr, cfg.JoinAs))
if joins != nil && cfg.BootstrapExpect == 0 { if joins != nil && cfg.BootstrapExpect == 0 {
// Explicit join operation requested, so do it. // Explicit join operation requested, so do it.
j, err := joiner.Do(joins, str.ID(), cfg.RaftAdv, !cfg.RaftNonVoter) j, err := joiner.Do(joins, str.ID(), cfg.RaftAdv, !cfg.RaftNonVoter)
@ -512,7 +500,7 @@ func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *stor
if joins != nil && cfg.BootstrapExpect > 0 { if joins != nil && cfg.BootstrapExpect > 0 {
// Bootstrap with explicit join addresses requests. // Bootstrap with explicit join addresses requests.
bs := cluster.NewBootstrapper(cluster.NewAddressProviderString(joins), client) bs := cluster.NewBootstrapper(cluster.NewAddressProviderString(joins), client)
bs.SetCredentials(credentialsFor(credStr, cfg.JoinAs)) bs.SetCredentials(cluster.CredentialsFor(credStr, cfg.JoinAs))
return bs.Boot(str.ID(), cfg.RaftAdv, isClustered, cfg.BootstrapExpectTimeout) return bs.Boot(str.ID(), cfg.RaftAdv, isClustered, cfg.BootstrapExpectTimeout)
} }
@ -555,7 +543,7 @@ func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *stor
} }
bs := cluster.NewBootstrapper(provider, client) bs := cluster.NewBootstrapper(provider, client)
bs.SetCredentials(credentialsFor(credStr, cfg.JoinAs)) bs.SetCredentials(cluster.CredentialsFor(credStr, cfg.JoinAs))
httpServ.RegisterStatus("disco", provider) httpServ.RegisterStatus("disco", provider)
return bs.Boot(str.ID(), cfg.RaftAdv, isClustered, cfg.BootstrapExpectTimeout) return bs.Boot(str.ID(), cfg.RaftAdv, isClustered, cfg.BootstrapExpectTimeout)
@ -610,29 +598,10 @@ func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *stor
return nil return nil
} }
func networkCheckJoinAddrs(cfg *Config, joinAddrs []string) error { func networkCheckJoinAddrs(joinAddrs []string) error {
if len(joinAddrs) == 0 {
return nil
}
log.Println("checking that join addresses don't serve HTTP(S)") log.Println("checking that join addresses don't serve HTTP(S)")
for _, addr := range joinAddrs { if addr, ok := http.AnyServingHTTP(joinAddrs); ok {
if http.IsServingHTTP(addr) { return fmt.Errorf("join address %s appears to be serving HTTP when it should be Raft", addr)
return fmt.Errorf("join address %s appears to be serving HTTP when it should be Raft", addr)
}
} }
return nil return nil
} }
func credentialsFor(credStr *auth.CredentialsStore, username string) *cluster.Credentials {
if credStr == nil {
return nil
}
pw, ok := credStr.Password(username)
if !ok {
return nil
}
return &cluster.Credentials{
Username: username,
Password: pw,
}
}

@ -6,6 +6,17 @@ import (
"time" "time"
) )
// AnyServingHTTP returns the first address in the list that appears to be
// serving HTTP or HTTPS, or false if none of them are.
func AnyServingHTTP(addrs []string) (string, bool) {
for _, addr := range addrs {
if IsServingHTTP(addr) {
return addr, true
}
}
return "", false
}
// IsServingHTTP returns true if there appears to be a HTTP or HTTPS server // IsServingHTTP returns true if there appears to be a HTTP or HTTPS server
// running on the given address. // running on the given address.
func IsServingHTTP(addr string) bool { func IsServingHTTP(addr string) bool {

@ -21,7 +21,10 @@ func Test_IsServingHTTP_HTTPServer(t *testing.T) {
addr := httpServer.Listener.Addr().String() addr := httpServer.Listener.Addr().String()
if !IsServingHTTP(addr) { if !IsServingHTTP(addr) {
t.Errorf("Expected true for HTTP server running on %s", addr) t.Fatalf("Expected true for HTTP server running on %s", addr)
}
if a, ok := AnyServingHTTP([]string{addr}); !ok || a != addr {
t.Fatalf("Expected %s for AnyServingHTTP", addr)
} }
} }
@ -36,6 +39,9 @@ func Test_IsServingHTTP_HTTPSServer(t *testing.T) {
if !IsServingHTTP(addr) { if !IsServingHTTP(addr) {
t.Error("Expected true for HTTPS server running") t.Error("Expected true for HTTPS server running")
} }
if a, ok := AnyServingHTTP([]string{addr}); !ok || a != addr {
t.Fatalf("Expected %s for AnyServingHTTP", addr)
}
} }
// Test_IsServingHTTP_NoServersRunning tests no servers running. // Test_IsServingHTTP_NoServersRunning tests no servers running.
@ -44,6 +50,9 @@ func Test_IsServingHTTP_NoServersRunning(t *testing.T) {
if IsServingHTTP(addr) { if IsServingHTTP(addr) {
t.Error("Expected false for no servers running") t.Error("Expected false for no servers running")
} }
if _, ok := AnyServingHTTP([]string{addr}); ok {
t.Error("Expected false for no servers running")
}
} }
// Test_IsServingHTTP_InvalidAddress tests invalid address format. // Test_IsServingHTTP_InvalidAddress tests invalid address format.
@ -52,6 +61,9 @@ func Test_IsServingHTTP_InvalidAddress(t *testing.T) {
if IsServingHTTP(addr) { if IsServingHTTP(addr) {
t.Error("Expected false for invalid address") t.Error("Expected false for invalid address")
} }
if _, ok := AnyServingHTTP([]string{addr}); ok {
t.Error("Expected false for invalid address")
}
} }
// Test_IsServingHTTP_HTTPErrorStatusCode tests HTTP server returning error status code. // Test_IsServingHTTP_HTTPErrorStatusCode tests HTTP server returning error status code.
@ -65,6 +77,9 @@ func Test_IsServingHTTP_HTTPErrorStatusCode(t *testing.T) {
if !IsServingHTTP(addr) { if !IsServingHTTP(addr) {
t.Error("Expected true for HTTP server running, even with error status code") t.Error("Expected true for HTTP server running, even with error status code")
} }
if a, ok := AnyServingHTTP([]string{addr}); !ok || a != addr {
t.Fatalf("Expected %s for AnyServingHTTP, even with error status code", addr)
}
} }
// Test_IsServingHTTP_HTTPSSuccessStatusCode tests HTTPS server running with success status code. // Test_IsServingHTTP_HTTPSSuccessStatusCode tests HTTPS server running with success status code.
@ -78,6 +93,9 @@ func Test_IsServingHTTP_HTTPSSuccessStatusCode(t *testing.T) {
if !IsServingHTTP(addr) { if !IsServingHTTP(addr) {
t.Error("Expected true for HTTPS server running with success status code") t.Error("Expected true for HTTPS server running with success status code")
} }
if a, ok := AnyServingHTTP([]string{addr}); !ok || a != addr {
t.Fatalf("Expected %s for AnyServingHTTP with success status code", addr)
}
} }
func Test_IsServingHTTP_OpenPort(t *testing.T) { func Test_IsServingHTTP_OpenPort(t *testing.T) {
@ -92,6 +110,9 @@ func Test_IsServingHTTP_OpenPort(t *testing.T) {
if IsServingHTTP(addr) { if IsServingHTTP(addr) {
t.Error("Expected false for open port") t.Error("Expected false for open port")
} }
if _, ok := AnyServingHTTP([]string{addr}); ok {
t.Error("Expected false for open port")
}
} }
func Test_IsServingHTTP_OpenPortTLS(t *testing.T) { func Test_IsServingHTTP_OpenPortTLS(t *testing.T) {
@ -116,4 +137,26 @@ func Test_IsServingHTTP_OpenPortTLS(t *testing.T) {
if IsServingHTTP(addr) { if IsServingHTTP(addr) {
t.Error("Expected false for open TLS port") t.Error("Expected false for open TLS port")
} }
if _, ok := AnyServingHTTP([]string{addr}); ok {
t.Error("Expected false for open TLS port")
}
}
func Test_IsServingHTTP_HTTPServerTCPPort(t *testing.T) {
httpServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}))
defer httpServer.Close()
ln, err := net.Listen("tcp", ":0")
if err != nil {
t.Fatal(err)
}
defer ln.Close()
httpAddr := httpServer.Listener.Addr().String()
tcpAddr := ln.Addr().String()
if a, ok := AnyServingHTTP([]string{httpAddr, tcpAddr}); !ok || a != httpAddr {
t.Fatalf("Expected %s for AnyServingHTTP", httpAddr)
}
} }

Loading…
Cancel
Save