diff --git a/cmd/rqlited/main.go b/cmd/rqlited/main.go index 1b5cc9b9..670252c6 100644 --- a/cmd/rqlited/main.go +++ b/cmd/rqlited/main.go @@ -68,6 +68,7 @@ var httpAdv string var authFile string var x509Cert string var x509Key string +var nodeEncrypt bool var nodeX509Cert string var nodeX509Key string var raftAddr string @@ -98,14 +99,15 @@ func init() { flag.StringVar(&httpAdv, "httpadv", "", "Advertised HTTP address. If not set, same as HTTP server") flag.StringVar(&x509Cert, "x509cert", "", "Path to X.509 certificate for HTTP endpoint") flag.StringVar(&x509Key, "x509key", "", "Path to X.509 private key for certificate HTTP endpoint") - flag.StringVar(&nodeX509Cert, "nodex509cert", "", "Path to X.509 certificate for inter-node communication") - flag.StringVar(&nodeX509Key, "nodex509key", "", "Path to X.509 private key for inter-node communication") + flag.BoolVar(&noVerify, "noverify", false, "Skip verification of remote HTTPS cert when joining cluster") + flag.BoolVar(&nodeEncrypt, "encrypt", false, "Enable node-to-node encryption") + flag.StringVar(&nodeX509Cert, "nodex509cert", "cert.pem", "Path to X.509 certificate for node-to-node encryption") + flag.StringVar(&nodeX509Key, "nodex509key", "key.pem", "Path to X.509 private key for node-to-node encryption") + flag.BoolVar(&noNodeVerify, "nonodeverify", false, "Skip verification of a remote node cert") flag.StringVar(&authFile, "auth", "", "Path to authentication and authorization file. If not set, not enabled") flag.StringVar(&raftAddr, "raft", "localhost:4002", "Raft communication bind address") flag.StringVar(&raftAdv, "raftadv", "", "Advertised Raft communication address. If not set, same as Raft bind") flag.StringVar(&joinAddr, "join", "", "Comma-delimited list of nodes, through which a cluster can be joined (proto://host:port)") - flag.BoolVar(&noVerify, "noverify", false, "Skip verification of remote HTTPS cert when joining cluster") - flag.BoolVar(&noNodeVerify, "nonodeverify", false, "Skip verification of a remote node cert") flag.StringVar(&discoURL, "disco", "http://discovery.rqlite.com", "Set Discovery Service URL") flag.StringVar(&discoID, "discoid", "", "Set Discovery ID. If not set, Discovery Service not used") flag.BoolVar(&expvar, "expvar", true, "Serve expvar data on HTTP server") @@ -156,7 +158,7 @@ func main() { // Start requested profiling. startProfile(cpuProfile, memProfile) - // Set up internode TCP communication. + // Set up node-to-node TCP communication. ln, err := net.Listen("tcp", raftAddr) if err != nil { log.Fatalf("failed to listen on %s: %s", raftAddr, err.Error()) @@ -169,18 +171,21 @@ func main() { } } - // Encypt internode TCP connection if requested. - if nodeX509Cert != "" && nodeX509Key != "" { - ln, err = tcp.NewTLSListener(ln, nodeX509Cert, nodeX509Key) - if err != nil { - log.Fatalf("failed to create encrypted inter-node communication: %s", err.Error()) - } + // Start up node-to-node network mux. + var mux *tcp.Mux + if nodeEncrypt { log.Printf("encrypting inter-node connection with cert %s, key %s", nodeX509Cert, nodeX509Key) + mux, err = tcp.NewTLSMux(ln, adv, nodeX509Cert, nodeX509Key) + } else { + mux, err = tcp.NewMux(ln, adv) } - - // Start up mux and get transports for cluster. - mux := tcp.NewMux(ln, adv) + if err != nil { + log.Fatalf("failed to create node-to-node mux: %s", err.Error()) + } + mux.InsecureSkipVerify = noNodeVerify go mux.Serve() + + // Get transport for Raft communications. raftTn := mux.Listen(muxRaftHeader) // Create and open the store. diff --git a/tcp/mux.go b/tcp/mux.go index d0126aa3..806ffc3a 100644 --- a/tcp/mux.go +++ b/tcp/mux.go @@ -22,6 +22,9 @@ type Layer struct { ln net.Listener header byte addr net.Addr + + remoteEncrypted bool + skipVerify bool } // Addr returns the local address for the layer. @@ -31,7 +34,18 @@ func (l *Layer) Addr() net.Addr { // Dial creates a new network connection. func (l *Layer) Dial(addr string, timeout time.Duration) (net.Conn, error) { - conn, err := net.DialTimeout("tcp", addr, timeout) + dialer := &net.Dialer{Timeout: timeout} + + var err error + var conn net.Conn + if l.remoteEncrypted { + conf := &tls.Config{ + InsecureSkipVerify: l.skipVerify, + } + conn, err = tls.DialWithDialer(dialer, "tcp", addr, conf) + } else { + conn, err = dialer.Dial("tcp", addr) + } if err != nil { return nil, err } @@ -59,16 +73,27 @@ type Mux struct { wg sync.WaitGroup + remoteEncrypted bool + // The amount of time to wait for the first header byte. Timeout time.Duration // Out-of-band error logger Logger *log.Logger + + // Path to X509 certificate + nodeX509Cert string + + // Path to X509 key. + nodeX509Key string + + // Whether to skip verification of other nodes' certificates. + InsecureSkipVerify bool } // NewMux returns a new instance of Mux for ln. If adv is nil, // then the addr of ln is used. -func NewMux(ln net.Listener, adv net.Addr) *Mux { +func NewMux(ln net.Listener, adv net.Addr) (*Mux, error) { addr := adv if addr == nil { addr = ln.Addr() @@ -80,7 +105,33 @@ func NewMux(ln net.Listener, adv net.Addr) *Mux { m: make(map[byte]*listener), Timeout: DefaultTimeout, Logger: log.New(os.Stderr, "[tcp] ", log.LstdFlags), + }, nil +} + +// NewTLSMux returns a new instance of Mux for ln, and encrypts all traffic +// using TLS. If adv is nil, then the addr of ln is used. +func NewTLSMux(ln net.Listener, adv net.Addr, cert, key string) (*Mux, error) { + addr := adv + if addr == nil { + addr = ln.Addr() + } + + var err error + ln, err = newTLSListener(ln, cert, key) + if err != nil { + return nil, err } + + return &Mux{ + ln: ln, + addr: addr, + m: make(map[byte]*listener), + remoteEncrypted: true, + Timeout: DefaultTimeout, + Logger: log.New(os.Stderr, "[tcp] ", log.LstdFlags), + nodeX509Cert: cert, + nodeX509Key: key, + }, nil } // Serve handles connections from ln and multiplexes then across registered listener. @@ -163,9 +214,11 @@ func (mux *Mux) Listen(header byte) *Layer { mux.m[header] = ln layer := &Layer{ - ln: ln, - header: header, - addr: mux.addr, + ln: ln, + header: header, + addr: mux.addr, + remoteEncrypted: mux.remoteEncrypted, + skipVerify: mux.InsecureSkipVerify, } return layer @@ -191,8 +244,8 @@ func (ln *listener) Close() error { return nil } // Addr always returns nil func (ln *listener) Addr() net.Addr { return nil } -// NewTLSListener returns a net listener which encrypts the traffic using TLS. -func NewTLSListener(ln net.Listener, certFile, keyFile string) (net.Listener, error) { +// newTLSListener returns a net listener which encrypts the traffic using TLS. +func newTLSListener(ln net.Listener, certFile, keyFile string) (net.Listener, error) { config, err := createTLSConfig(certFile, keyFile) if err != nil { return nil, err