From b6457a46479aee2242f21d840fc4017060c11860 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 23 Dec 2023 10:24:10 -0500 Subject: [PATCH] Remove unused joint TLS config creation This was confusing and too clever. --- http/preflight_test.go | 2 +- http/service.go | 6 ++- rtls/config.go | 53 +++++-------------------- rtls/config_test.go | 88 +----------------------------------------- tcp/dialer_test.go | 2 +- tcp/mux.go | 17 +++++++- 6 files changed, 34 insertions(+), 134 deletions(-) diff --git a/http/preflight_test.go b/http/preflight_test.go index c6ae5f6c..2a4baf71 100644 --- a/http/preflight_test.go +++ b/http/preflight_test.go @@ -101,7 +101,7 @@ func Test_IsServingHTTP_OpenPortTLS(t *testing.T) { } certFile := mustWriteTempFile(t, cert) keyFile := mustWriteTempFile(t, key) - tlsConfig, err := rtls.CreateServerConfig(certFile, keyFile, rtls.NoCACert, false) + tlsConfig, err := rtls.CreateServerConfig(certFile, keyFile, rtls.NoCACert, rtls.MTLSStateEnabled) if err != nil { t.Fatalf("failed to create TLS config: %s", err) } diff --git a/http/service.go b/http/service.go index 754db5f2..4a5d4161 100644 --- a/http/service.go +++ b/http/service.go @@ -370,7 +370,11 @@ func (s *Service) Start() error { return err } } else { - s.tlsConfig, err = rtls.CreateServerConfig(s.CertFile, s.KeyFile, s.CACertFile, !s.ClientVerify) + mTLSState := rtls.MTLSStateDisabled + if s.ClientVerify { + mTLSState = rtls.MTLSStateEnabled + } + s.tlsConfig, err = rtls.CreateServerConfig(s.CertFile, s.KeyFile, s.CACertFile, mTLSState) if err != nil { return err } diff --git a/rtls/config.go b/rtls/config.go index e2b3b9c4..108a1677 100644 --- a/rtls/config.go +++ b/rtls/config.go @@ -14,45 +14,13 @@ const ( NoServerName = "" ) -// CreateClientConfig creates a TLS configuration for use by a system that does both -// client and server authentication using the same cert, key, and CA cert. If noverify -// is true, the client will not verify the server's certificate. If mutual is true, the -// server will verify the client's certificate. -func CreateConfig(certFile, keyFile, caCertFile string, noverify, mutual bool) (*tls.Config, error) { - var err error - config := createBaseTLSConfig(NoServerName, noverify) +// MTLSState indicates whether mutual TLS is enabled or disabled. +type MTLSState tls.ClientAuthType - // load the certificate and key - if certFile != "" && keyFile != "" { - config.Certificates = make([]tls.Certificate, 1) - config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return nil, err - } - } - - // load the CA certificate file, if provided, as the root CA and client CA - if caCertFile != "" { - asn1Data, err := os.ReadFile(caCertFile) - if err != nil { - return nil, err - } - config.RootCAs = x509.NewCertPool() - ok := config.RootCAs.AppendCertsFromPEM(asn1Data) - if !ok { - return nil, fmt.Errorf("failed to load CA certificate(s) for server verification in %q", caCertFile) - } - config.ClientCAs = x509.NewCertPool() - ok = config.ClientCAs.AppendCertsFromPEM(asn1Data) - if !ok { - return nil, fmt.Errorf("failed to load CA certificate(s) for client verification in %q", caCertFile) - } - } - if mutual { - config.ClientAuth = tls.RequireAndVerifyClientCert - } - return config, nil -} +const ( + MTLSStateDisabled MTLSState = MTLSState(tls.NoClientCert) + MTLSStateEnabled MTLSState = MTLSState(tls.RequireAndVerifyClientCert) +) // CreateClientConfig creates a new tls.Config for use by a client. The certFile and keyFile // parameters are the paths to the client's certificate and key files, which will be used to @@ -90,8 +58,9 @@ func CreateClientConfig(certFile, keyFile, caCertFile, serverName string, noveri // parameters are the paths to the server's certificate and key files, which will be used to // authenticate the server to the client. The caCertFile parameter is the path to the CA // certificate file, which the server will use to verify any certificate presented by the -// client. If noverify is true, the server will not verify the client's certificate. -func CreateServerConfig(certFile, keyFile, caCertFile string, noverify bool) (*tls.Config, error) { +// client. If mtls is MTLSStateEnabled, the server will require the client to present a +// valid certificate. +func CreateServerConfig(certFile, keyFile, caCertFile string, mtls MTLSState) (*tls.Config, error) { var err error config := createBaseTLSConfig(NoServerName, false) @@ -111,9 +80,7 @@ func CreateServerConfig(certFile, keyFile, caCertFile string, noverify bool) (*t return nil, fmt.Errorf("failed to load CA certificate(s) for client verification in %q", caCertFile) } } - if !noverify { - config.ClientAuth = tls.RequireAndVerifyClientCert - } + config.ClientAuth = tls.ClientAuthType(mtls) return config, nil } diff --git a/rtls/config_test.go b/rtls/config_test.go index 6f9ca087..aadfbfa7 100644 --- a/rtls/config_test.go +++ b/rtls/config_test.go @@ -9,90 +9,6 @@ import ( "time" ) -func Test_CreateConfig(t *testing.T) { - // generate a cert and key pair, and write both to a temporary file - certPEM, keyPEM, err := GenerateCert(pkix.Name{CommonName: "rqlite"}, 365*24*time.Hour, 2048, nil, nil) - if err != nil { - t.Fatalf("failed to generate cert: %v", err) - } - certFile := mustWriteTempFile(t, certPEM) - keyFile := mustWriteTempFile(t, keyPEM) - - // generate a CA cert, and write it to a temporary file - caCertPEM, _, err := GenerateCACert(pkix.Name{CommonName: "rqlite CA"}, 365*24*time.Hour, 2048) - if err != nil { - t.Fatalf("failed to generate cert: %v", err) - } - caCertFile := mustWriteTempFile(t, caCertPEM) - - // create a config with no server or client verification - config, err := CreateConfig(certFile, keyFile, caCertFile, true, false) - if err != nil { - t.Fatalf("failed to create config: %v", err) - } - if config.ClientAuth != tls.NoClientCert { - t.Fatalf("expected ClientAuth to be NoClientCert, got %v", config.ClientAuth) - } - if !config.InsecureSkipVerify { - t.Fatalf("expected InsecureSkipVerify to be true, got false") - } - - // Check that the certificate is loaded correctly - if len(config.Certificates) != 1 { - t.Fatalf("expected 1 certificate, got %d", len(config.Certificates)) - } - // parse the certificate in the tls config - parsedCert, err := x509.ParseCertificate(config.Certificates[0].Certificate[0]) - if err != nil { - t.Fatalf("failed to parse certificate: %v", err) - } - if parsedCert.Subject.CommonName != "rqlite" { - t.Fatalf("expected certificate subject to be 'rqlite', got %s", parsedCert.Subject.CommonName) - } - - // Check that the root and client CAs are loaded with the correct certificate - caCertPool := x509.NewCertPool() - caCertPool.AppendCertsFromPEM(caCertPEM) - - if config.RootCAs == nil { - t.Fatalf("got nil root CA") - } - if !config.RootCAs.Equal(caCertPool) { - t.Fatalf("expected root CA to be %v, got %v", caCertPool, config.RootCAs) - } - - if config.ClientCAs == nil { - t.Fatalf("got nil client CA") - } - if !config.ClientCAs.Equal(caCertPool) { - t.Fatalf("expected client CA to be %v, got %v", caCertPool, config.ClientCAs) - } - - // create a config with server cert verification only - config, err = CreateConfig(certFile, keyFile, caCertFile, false, false) - if err != nil { - t.Fatalf("failed to create config: %v", err) - } - if config.ClientAuth != tls.NoClientCert { - t.Fatalf("expected ClientAuth to be NoClientCert, got %v", config.ClientAuth) - } - if config.InsecureSkipVerify { - t.Fatalf("expected InsecureSkipVerify to be false, got true") - } - - // create a config with both server and client verification - config, err = CreateConfig(certFile, keyFile, "", false, true) - if err != nil { - t.Fatalf("failed to create config: %v", err) - } - if config.ClientAuth != tls.RequireAndVerifyClientCert { - t.Fatalf("expected ClientAuth to be RequireAndVerifyClientCert, got %v", config.ClientAuth) - } - if config.InsecureSkipVerify { - t.Fatalf("expected InsecureSkipVerify to be false, got true") - } -} - func Test_CreateServerConfig(t *testing.T) { // generate a cert and key pair, and write both to a temporary file certPEM, keyPEM, err := GenerateCert(pkix.Name{CommonName: "rqlite"}, 365*24*time.Hour, 2048, nil, nil) @@ -103,7 +19,7 @@ func Test_CreateServerConfig(t *testing.T) { keyFile := mustWriteTempFile(t, keyPEM) // create a server config with no client verification - config, err := CreateServerConfig(certFile, keyFile, NoCACert, true) + config, err := CreateServerConfig(certFile, keyFile, NoCACert, MTLSStateDisabled) if err != nil { t.Fatalf("failed to create server config: %v", err) } @@ -130,7 +46,7 @@ func Test_CreateServerConfig(t *testing.T) { } // create a server config with client verification - config, err = CreateServerConfig(certFile, keyFile, NoCACert, false) + config, err = CreateServerConfig(certFile, keyFile, NoCACert, MTLSStateEnabled) if err != nil { t.Fatalf("failed to create server config: %v", err) } diff --git a/tcp/dialer_test.go b/tcp/dialer_test.go index 5cc6697b..efec1c49 100644 --- a/tcp/dialer_test.go +++ b/tcp/dialer_test.go @@ -185,7 +185,7 @@ func mustNewEchoServerTLS_ExampleDotCom() (*echoServer, string, string) { cert := x509.CertExampleDotComFile("") key := x509.KeyExampleDotComFile("") - tlsConfig, err := rtls.CreateServerConfig(cert, key, rtls.NoCACert, true) + tlsConfig, err := rtls.CreateServerConfig(cert, key, rtls.NoCACert, rtls.MTLSStateDisabled) if err != nil { panic(fmt.Sprintf("failed to create TLS config: %s", err.Error())) } diff --git a/tcp/mux.go b/tcp/mux.go index a2baa957..8717c36d 100644 --- a/tcp/mux.go +++ b/tcp/mux.go @@ -107,18 +107,31 @@ func NewMux(ln net.Listener, adv net.Addr) (*Mux, error) { // then the server will not verify the client's certificate. If mutual is true, // then the server will require the client to present a trusted certificate. func NewTLSMux(ln net.Listener, adv net.Addr, cert, key, caCert string, insecure, mutual bool) (*Mux, error) { + return newTLSMux(ln, adv, cert, key, caCert, false) +} + +// NewMutualTLSMux returns a new instance of Mux for ln, and encrypts all traffic +// using TLS. The server will also verify the client's certificate. +func NewMutualTLSMux(ln net.Listener, adv net.Addr, cert, key, caCert string) (*Mux, error) { + return newTLSMux(ln, adv, cert, key, caCert, true) +} + +func newTLSMux(ln net.Listener, adv net.Addr, cert, key, caCert string, mutual bool) (*Mux, error) { mux, err := NewMux(ln, adv) if err != nil { return nil, err } - mux.tlsConfig, err = rtls.CreateConfig(cert, key, caCert, insecure, mutual) + mtlsState := rtls.MTLSStateDisabled + if mutual { + mtlsState = rtls.MTLSStateEnabled + } + mux.tlsConfig, err = rtls.CreateServerConfig(cert, key, caCert, mtlsState) if err != nil { return nil, fmt.Errorf("cannot create TLS config: %s", err) } mux.ln = tls.NewListener(ln, mux.tlsConfig) - return mux, nil }