@ -2,9 +2,11 @@ package tcp
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"os"
@ -26,6 +28,8 @@ type Layer struct {
remoteEncrypted bool
skipVerify bool
nodeX509CACert string
tlsConfig * tls . Config
}
// Addr returns the local address for the layer.
@ -40,10 +44,7 @@ func (l *Layer) Dial(addr string, timeout time.Duration) (net.Conn, error) {
var err error
var conn net . Conn
if l . remoteEncrypted {
conf := & tls . Config {
InsecureSkipVerify : l . skipVerify ,
}
conn , err = tls . DialWithDialer ( dialer , "tcp" , addr , conf )
conn , err = tls . DialWithDialer ( dialer , "tcp" , addr , l . tlsConfig )
} else {
conn , err = dialer . Dial ( "tcp" , addr )
}
@ -82,6 +83,9 @@ type Mux struct {
// Out-of-band error logger
Logger * log . Logger
// Path to root X.509 certificate.
nodeX509CACert string
// Path to X509 certificate
nodeX509Cert string
@ -90,6 +94,8 @@ type Mux struct {
// Whether to skip verification of other nodes' certificates.
InsecureSkipVerify bool
tlsConfig * tls . Config
}
// NewMux returns a new instance of Mux for ln. If adv is nil,
@ -111,17 +117,20 @@ func NewMux(ln net.Listener, adv net.Addr) (*Mux, error) {
// NewTLSMux returns a new instance of Mux for ln, and encrypts all traffic
// using TLS. If adv is nil, then the addr of ln is used.
func NewTLSMux ( ln net . Listener , adv net . Addr , cert , key string ) ( * Mux , error ) {
func NewTLSMux ( ln net . Listener , adv net . Addr , cert , key , caCert string ) ( * Mux , error ) {
mux , err := NewMux ( ln , adv )
if err != nil {
return nil , err
}
mux . ln, err = newTLSListener ( mux . ln , cert , key )
mux . tlsConfig, err = createTLSConfig ( cert , key , caCert )
if err != nil {
return nil , err
}
mux . ln = tls . NewListener ( ln , mux . tlsConfig )
mux . remoteEncrypted = true
mux . nodeX509CACert = caCert
mux . nodeX509Cert = cert
mux . nodeX509Key = key
@ -168,6 +177,7 @@ func (mux *Mux) Stats() (interface{}, error) {
if mux . remoteEncrypted {
s [ "certificate" ] = mux . nodeX509Cert
s [ "key" ] = mux . nodeX509Key
s [ "ca_certificate" ] = mux . nodeX509CACert
s [ "skip_verify" ] = strconv . FormatBool ( mux . InsecureSkipVerify )
}
@ -230,6 +240,8 @@ func (mux *Mux) Listen(header byte) *Layer {
addr : mux . addr ,
remoteEncrypted : mux . remoteEncrypted ,
skipVerify : mux . InsecureSkipVerify ,
nodeX509CACert : mux . nodeX509CACert ,
tlsConfig : mux . tlsConfig ,
}
return layer
@ -256,8 +268,8 @@ func (ln *listener) Close() error { return nil }
func ( ln * listener ) Addr ( ) net . Addr { return nil }
// newTLSListener returns a net listener which encrypts the traffic using TLS.
func newTLSListener ( ln net . Listener , certFile , keyFile string ) ( net . Listener , error ) {
config , err := createTLSConfig ( certFile , keyFile )
func newTLSListener ( ln net . Listener , certFile , keyFile , caCertFile string ) ( net . Listener , error ) {
config , err := createTLSConfig ( certFile , keyFile , caCertFile )
if err != nil {
return nil , err
}
@ -266,7 +278,7 @@ func newTLSListener(ln net.Listener, certFile, keyFile string) (net.Listener, er
}
// createTLSConfig returns a TLS config from the given cert and key.
func createTLSConfig ( certFile , keyFile string ) ( * tls . Config , error ) {
func createTLSConfig ( certFile , keyFile , caCertFile string ) ( * tls . Config , error ) {
var err error
config := & tls . Config { }
config . Certificates = make ( [ ] tls . Certificate , 1 )
@ -274,5 +286,16 @@ func createTLSConfig(certFile, keyFile string) (*tls.Config, error) {
if err != nil {
return nil , err
}
if caCertFile != "" {
asn1Data , err := ioutil . ReadFile ( caCertFile )
if err != nil {
return nil , err
}
config . RootCAs = x509 . NewCertPool ( )
ok := config . RootCAs . AppendCertsFromPEM ( [ ] byte ( asn1Data ) )
if ! ok {
return nil , fmt . Errorf ( "failed to parse root certificate(s) in %s" , caCertFile )
}
}
return config , nil
}