1
0
Fork 0

Refactor some code out of main

master
Philip O'Toole 9 months ago
parent 6d3f0dafdf
commit 65ab53f23f

@ -3,6 +3,7 @@ package cluster
import (
"bytes"
"compress/gzip"
"crypto/tls"
"encoding/binary"
"errors"
"fmt"
@ -11,7 +12,10 @@ import (
"sync"
"time"
"github.com/rqlite/rqlite/v8/auth"
"github.com/rqlite/rqlite/v8/command"
"github.com/rqlite/rqlite/v8/rtls"
"github.com/rqlite/rqlite/v8/tcp"
"github.com/rqlite/rqlite/v8/tcp/pool"
"google.golang.org/protobuf/proto"
)
@ -24,6 +28,36 @@ const (
protoBufferLengthSize = 8
)
// CreateRaftDialer creates a dialer for connecting to other nodes' Raft service. If the cert and
// key arguments are not set, then the returned dialer will not use TLS.
func CreateRaftDialer(cert, key, caCert, serverName string, Insecure bool) (*tcp.Dialer, error) {
var dialerTLSConfig *tls.Config
var err error
if cert != "" || key != "" {
dialerTLSConfig, err = rtls.CreateClientConfig(cert, key, caCert, serverName, Insecure)
if err != nil {
return nil, fmt.Errorf("failed to create TLS config for Raft dialer: %s", err.Error())
}
}
return tcp.NewDialer(MuxRaftHeader, dialerTLSConfig), nil
}
// CredentialsFor returns a Credentials instance for the given username, or nil if
// the given CredentialsStore is nil, or the username is not found.
func CredentialsFor(credStr *auth.CredentialsStore, username string) *Credentials {
if credStr == nil {
return nil
}
pw, ok := credStr.Password(username)
if !ok {
return nil
}
return &Credentials{
Username: username,
Password: pw,
}
}
// Client allows communicating with a remote node.
type Client struct {
dialer Dialer

@ -95,7 +95,8 @@ func main() {
// Raft internode layer
raftLn := mux.Listen(cluster.MuxRaftHeader)
log.Printf("Raft TCP mux Listener registered with byte header %d", cluster.MuxRaftHeader)
raftDialer, err := createRaftDialer(cfg)
raftDialer, err := cluster.CreateRaftDialer(cfg.NodeX509Cert, cfg.NodeX509Key, cfg.NodeX509CACert,
cfg.NodeVerifyServerName, cfg.NoNodeVerify)
if err != nil {
log.Fatalf("failed to create Raft dialer: %s", err.Error())
}
@ -202,7 +203,7 @@ func main() {
if cfg.RaftClusterRemoveOnShutdown {
remover := cluster.NewRemover(clstrClient, 5*time.Second, str)
remover.SetCredentials(credentialsFor(credStr, cfg.JoinAs))
remover.SetCredentials(cluster.CredentialsFor(credStr, cfg.JoinAs))
log.Printf("initiating removal of this node from cluster before shutdown")
if err := remover.Do(cfg.NodeID, true); err != nil {
log.Fatalf("failed to remove this node from cluster before shutdown: %s", err.Error())
@ -459,22 +460,9 @@ func createClusterClient(cfg *Config, clstr *cluster.Service) (*cluster.Client,
return clstrClient, nil
}
func createRaftDialer(cfg *Config) (*tcp.Dialer, error) {
var dialerTLSConfig *tls.Config
var err error
if cfg.NodeX509Cert != "" || cfg.NodeX509CACert != "" {
dialerTLSConfig, err = rtls.CreateClientConfig(cfg.NodeX509Cert, cfg.NodeX509Key,
cfg.NodeX509CACert, cfg.NodeVerifyServerName, cfg.NoNodeVerify)
if err != nil {
return nil, fmt.Errorf("failed to create TLS config for Raft dialer: %s", err.Error())
}
}
return tcp.NewDialer(cluster.MuxRaftHeader, dialerTLSConfig), nil
}
func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *store.Store, httpServ *httpd.Service, credStr *auth.CredentialsStore) error {
joins := cfg.JoinAddresses()
if err := networkCheckJoinAddrs(cfg, joins); err != nil {
if err := networkCheckJoinAddrs(joins); err != nil {
return err
}
@ -498,7 +486,7 @@ func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *stor
}
joiner := cluster.NewJoiner(client, cfg.JoinAttempts, cfg.JoinInterval)
joiner.SetCredentials(credentialsFor(credStr, cfg.JoinAs))
joiner.SetCredentials(cluster.CredentialsFor(credStr, cfg.JoinAs))
if joins != nil && cfg.BootstrapExpect == 0 {
// Explicit join operation requested, so do it.
j, err := joiner.Do(joins, str.ID(), cfg.RaftAdv, !cfg.RaftNonVoter)
@ -512,7 +500,7 @@ func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *stor
if joins != nil && cfg.BootstrapExpect > 0 {
// Bootstrap with explicit join addresses requests.
bs := cluster.NewBootstrapper(cluster.NewAddressProviderString(joins), client)
bs.SetCredentials(credentialsFor(credStr, cfg.JoinAs))
bs.SetCredentials(cluster.CredentialsFor(credStr, cfg.JoinAs))
return bs.Boot(str.ID(), cfg.RaftAdv, isClustered, cfg.BootstrapExpectTimeout)
}
@ -555,7 +543,7 @@ func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *stor
}
bs := cluster.NewBootstrapper(provider, client)
bs.SetCredentials(credentialsFor(credStr, cfg.JoinAs))
bs.SetCredentials(cluster.CredentialsFor(credStr, cfg.JoinAs))
httpServ.RegisterStatus("disco", provider)
return bs.Boot(str.ID(), cfg.RaftAdv, isClustered, cfg.BootstrapExpectTimeout)
@ -610,29 +598,10 @@ func createCluster(cfg *Config, hasPeers bool, client *cluster.Client, str *stor
return nil
}
func networkCheckJoinAddrs(cfg *Config, joinAddrs []string) error {
if len(joinAddrs) == 0 {
return nil
}
func networkCheckJoinAddrs(joinAddrs []string) error {
log.Println("checking that join addresses don't serve HTTP(S)")
for _, addr := range joinAddrs {
if http.IsServingHTTP(addr) {
return fmt.Errorf("join address %s appears to be serving HTTP when it should be Raft", addr)
}
if http.AnyServingHTTP(joinAddrs) {
return fmt.Errorf("join address appears to be serving HTTP when it should be Raft")
}
return nil
}
func credentialsFor(credStr *auth.CredentialsStore, username string) *cluster.Credentials {
if credStr == nil {
return nil
}
pw, ok := credStr.Password(username)
if !ok {
return nil
}
return &cluster.Credentials{
Username: username,
Password: pw,
}
}

@ -6,6 +6,17 @@ import (
"time"
)
// AnyServingHTTP returns true if there appears to be a HTTP or HTTPS server
// running on any of the given addresses.
func AnyServingHTTP(addrs []string) bool {
for _, addr := range addrs {
if IsServingHTTP(addr) {
return true
}
}
return false
}
// IsServingHTTP returns true if there appears to be a HTTP or HTTPS server
// running on the given address.
func IsServingHTTP(addr string) bool {

@ -21,7 +21,10 @@ func Test_IsServingHTTP_HTTPServer(t *testing.T) {
addr := httpServer.Listener.Addr().String()
if !IsServingHTTP(addr) {
t.Errorf("Expected true for HTTP server running on %s", addr)
t.Fatalf("Expected true for HTTP server running on %s", addr)
}
if !AnyServingHTTP([]string{addr}) {
t.Fatalf("Expected true for HTTP server running on %s", addr)
}
}
@ -36,6 +39,9 @@ func Test_IsServingHTTP_HTTPSServer(t *testing.T) {
if !IsServingHTTP(addr) {
t.Error("Expected true for HTTPS server running")
}
if !AnyServingHTTP([]string{addr}) {
t.Fatalf("Expected true for HTTPS server running")
}
}
// Test_IsServingHTTP_NoServersRunning tests no servers running.
@ -44,6 +50,9 @@ func Test_IsServingHTTP_NoServersRunning(t *testing.T) {
if IsServingHTTP(addr) {
t.Error("Expected false for no servers running")
}
if AnyServingHTTP([]string{addr}) {
t.Error("Expected false for no servers running")
}
}
// Test_IsServingHTTP_InvalidAddress tests invalid address format.
@ -52,6 +61,9 @@ func Test_IsServingHTTP_InvalidAddress(t *testing.T) {
if IsServingHTTP(addr) {
t.Error("Expected false for invalid address")
}
if AnyServingHTTP([]string{addr}) {
t.Error("Expected false for invalid address")
}
}
// Test_IsServingHTTP_HTTPErrorStatusCode tests HTTP server returning error status code.
@ -65,6 +77,9 @@ func Test_IsServingHTTP_HTTPErrorStatusCode(t *testing.T) {
if !IsServingHTTP(addr) {
t.Error("Expected true for HTTP server running, even with error status code")
}
if !AnyServingHTTP([]string{addr}) {
t.Error("Expected true for HTTP server running, even with error status code")
}
}
// Test_IsServingHTTP_HTTPSSuccessStatusCode tests HTTPS server running with success status code.
@ -78,6 +93,9 @@ func Test_IsServingHTTP_HTTPSSuccessStatusCode(t *testing.T) {
if !IsServingHTTP(addr) {
t.Error("Expected true for HTTPS server running with success status code")
}
if !AnyServingHTTP([]string{addr}) {
t.Error("Expected true for HTTPS server running with success status code")
}
}
func Test_IsServingHTTP_OpenPort(t *testing.T) {
@ -92,6 +110,9 @@ func Test_IsServingHTTP_OpenPort(t *testing.T) {
if IsServingHTTP(addr) {
t.Error("Expected false for open port")
}
if AnyServingHTTP([]string{addr}) {
t.Error("Expected false for open port")
}
}
func Test_IsServingHTTP_OpenPortTLS(t *testing.T) {
@ -116,4 +137,7 @@ func Test_IsServingHTTP_OpenPortTLS(t *testing.T) {
if IsServingHTTP(addr) {
t.Error("Expected false for open TLS port")
}
if AnyServingHTTP([]string{addr}) {
t.Error("Expected false for open TLS port")
}
}

Loading…
Cancel
Save