diff --git a/rtls/config.go b/rtls/config.go index 6af94d09..e2b3b9c4 100644 --- a/rtls/config.go +++ b/rtls/config.go @@ -9,14 +9,18 @@ import ( "os" ) +const ( + NoCACert = "" + NoServerName = "" +) + // CreateClientConfig creates a TLS configuration for use by a system that does both // client and server authentication using the same cert, key, and CA cert. If noverify -// is true, the client will not verify the server's certificate. If mutual is true, -// the server will verify the client's certificate. If tls1011 is true, the client will -// accept TLS 1.0 or 1.1. Otherwise, it will require TLS 1.2 or higher. -func CreateConfig(certFile, keyFile, caCertFile, serverName string, noverify, mutual bool) (*tls.Config, error) { +// is true, the client will not verify the server's certificate. If mutual is true, the +// server will verify the client's certificate. +func CreateConfig(certFile, keyFile, caCertFile string, noverify, mutual bool) (*tls.Config, error) { var err error - config := createBaseTLSConfig(serverName, noverify) + config := createBaseTLSConfig(NoServerName, noverify) // load the certificate and key if certFile != "" && keyFile != "" { @@ -54,9 +58,9 @@ func CreateConfig(certFile, keyFile, caCertFile, serverName string, noverify, mu // parameters are the paths to the client's certificate and key files, which will be used to // authenticate the client to the server if mutual TLS is active. The caCertFile parameter // is the path to the CA certificate file, which the client will use to verify any certificate -// presented by the server. If noverify is true, the client will not verify the server's certificate. -// If tls1011 is true, the client will accept TLS 1.0 or 1.1. Otherwise, it will require TLS 1.2 -// or higher. +// presented by the server. serverName can also be set, informing the client which hostname +// should appear in the returned certificate. If noverify is true, the client will not verify +// the server's certificate. func CreateClientConfig(certFile, keyFile, caCertFile, serverName string, noverify bool) (*tls.Config, error) { var err error @@ -86,13 +90,11 @@ func CreateClientConfig(certFile, keyFile, caCertFile, serverName string, noveri // parameters are the paths to the server's certificate and key files, which will be used to // authenticate the server to the client. The caCertFile parameter is the path to the CA // certificate file, which the server will use to verify any certificate presented by the -// client. If noverify is true, the server will not verify the client's certificate. If -// tls1011 is true, the server will accept TLS 1.0 or 1.1. Otherwise, it will require TLS 1.2 -// or higher. -func CreateServerConfig(certFile, keyFile, caCertFile, serverName string, noverify bool) (*tls.Config, error) { +// client. If noverify is true, the server will not verify the client's certificate. +func CreateServerConfig(certFile, keyFile, caCertFile string, noverify bool) (*tls.Config, error) { var err error - config := createBaseTLSConfig(serverName, false) + config := createBaseTLSConfig(NoServerName, false) config.Certificates = make([]tls.Certificate, 1) config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile) if err != nil { diff --git a/rtls/config_test.go b/rtls/config_test.go index 7a2a4c1a..6f9ca087 100644 --- a/rtls/config_test.go +++ b/rtls/config_test.go @@ -26,7 +26,7 @@ func Test_CreateConfig(t *testing.T) { caCertFile := mustWriteTempFile(t, caCertPEM) // create a config with no server or client verification - config, err := CreateConfig(certFile, keyFile, caCertFile, "", true, false) + config, err := CreateConfig(certFile, keyFile, caCertFile, true, false) if err != nil { t.Fatalf("failed to create config: %v", err) } @@ -69,7 +69,7 @@ func Test_CreateConfig(t *testing.T) { } // create a config with server cert verification only - config, err = CreateConfig(certFile, keyFile, caCertFile, "", false, false) + config, err = CreateConfig(certFile, keyFile, caCertFile, false, false) if err != nil { t.Fatalf("failed to create config: %v", err) } @@ -81,7 +81,7 @@ func Test_CreateConfig(t *testing.T) { } // create a config with both server and client verification - config, err = CreateConfig(certFile, keyFile, "", "", false, true) + config, err = CreateConfig(certFile, keyFile, "", false, true) if err != nil { t.Fatalf("failed to create config: %v", err) } @@ -103,7 +103,7 @@ func Test_CreateServerConfig(t *testing.T) { keyFile := mustWriteTempFile(t, keyPEM) // create a server config with no client verification - config, err := CreateServerConfig(certFile, keyFile, "", "", true) + config, err := CreateServerConfig(certFile, keyFile, NoCACert, true) if err != nil { t.Fatalf("failed to create server config: %v", err) } @@ -130,7 +130,7 @@ func Test_CreateServerConfig(t *testing.T) { } // create a server config with client verification - config, err = CreateServerConfig(certFile, keyFile, "", "", false) + config, err = CreateServerConfig(certFile, keyFile, NoCACert, false) if err != nil { t.Fatalf("failed to create server config: %v", err) } @@ -149,7 +149,7 @@ func Test_CreateClientConfig(t *testing.T) { keyFile := mustWriteTempFile(t, keyPEM) // create a client config with no server verification - config, err := CreateClientConfig(certFile, keyFile, "", "", true) + config, err := CreateClientConfig(certFile, keyFile, NoCACert, NoServerName, true) if err != nil { t.Fatalf("failed to create client config: %v", err) } @@ -176,13 +176,22 @@ func Test_CreateClientConfig(t *testing.T) { } // create a client config with server verification - config, err = CreateClientConfig(certFile, keyFile, "", "", false) + config, err = CreateClientConfig(certFile, keyFile, NoCACert, NoServerName, false) if err != nil { t.Fatalf("failed to create client config: %v", err) } if config.InsecureSkipVerify { t.Fatalf("expected InsecureSkipVerify to be false, got true") } + + // create a client config with Server Name + config, err = CreateClientConfig(certFile, keyFile, NoCACert, "expected", false) + if err != nil { + t.Fatalf("failed to create client config: %v", err) + } + if config.ServerName != "expected" { + t.Fatalf("expected ServerName to be 'expected', got %s", config.ServerName) + } } // mustWriteTempFile writes the given bytes to a temporary file, and returns the