parent
4f77db64af
commit
3be97dd61e
@ -0,0 +1,276 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultTimeout is the default length of time to wait for first byte.
|
||||||
|
DefaultTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// Layer represents the connection between nodes.
|
||||||
|
type Layer struct {
|
||||||
|
ln net.Listener
|
||||||
|
header byte
|
||||||
|
addr net.Addr
|
||||||
|
|
||||||
|
remoteEncrypted bool
|
||||||
|
skipVerify bool
|
||||||
|
nodeX509CACert string
|
||||||
|
tlsConfig *tls.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// Addr returns the local address for the layer.
|
||||||
|
func (l *Layer) Addr() net.Addr {
|
||||||
|
return l.addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial creates a new network connection.
|
||||||
|
func (l *Layer) Dial(addr string, timeout time.Duration) (net.Conn, error) {
|
||||||
|
dialer := &net.Dialer{Timeout: timeout}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
var conn net.Conn
|
||||||
|
if l.remoteEncrypted {
|
||||||
|
conn, err = tls.DialWithDialer(dialer, "tcp", addr, l.tlsConfig)
|
||||||
|
} else {
|
||||||
|
conn, err = dialer.Dial("tcp", addr)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write a marker byte to indicate message type.
|
||||||
|
_, err = conn.Write([]byte{l.header})
|
||||||
|
if err != nil {
|
||||||
|
conn.Close()
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return conn, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept waits for the next connection.
|
||||||
|
func (l *Layer) Accept() (net.Conn, error) { return l.ln.Accept() }
|
||||||
|
|
||||||
|
// Close closes the layer.
|
||||||
|
func (l *Layer) Close() error { return l.ln.Close() }
|
||||||
|
|
||||||
|
// Mux multiplexes a network connection.
|
||||||
|
type Mux struct {
|
||||||
|
ln net.Listener
|
||||||
|
addr net.Addr
|
||||||
|
m map[byte]*listener
|
||||||
|
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
|
remoteEncrypted bool
|
||||||
|
|
||||||
|
// The amount of time to wait for the first header byte.
|
||||||
|
Timeout time.Duration
|
||||||
|
|
||||||
|
// Out-of-band error logger
|
||||||
|
Logger *log.Logger
|
||||||
|
|
||||||
|
// Path to root X.509 certificate.
|
||||||
|
nodeX509CACert string
|
||||||
|
|
||||||
|
// Path to X509 certificate
|
||||||
|
nodeX509Cert string
|
||||||
|
|
||||||
|
// Path to X509 key.
|
||||||
|
nodeX509Key string
|
||||||
|
|
||||||
|
// Whether to skip verification of other nodes' certificates.
|
||||||
|
InsecureSkipVerify bool
|
||||||
|
|
||||||
|
tlsConfig *tls.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewMux returns a new instance of Mux for ln. If adv is nil,
|
||||||
|
// then the addr of ln is used.
|
||||||
|
func NewMux(ln net.Listener, adv net.Addr) (*Mux, error) {
|
||||||
|
addr := adv
|
||||||
|
if addr == nil {
|
||||||
|
addr = ln.Addr()
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Mux{
|
||||||
|
ln: ln,
|
||||||
|
addr: addr,
|
||||||
|
m: make(map[byte]*listener),
|
||||||
|
Timeout: DefaultTimeout,
|
||||||
|
Logger: log.New(os.Stderr, "[mux] ", log.LstdFlags),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTLSMux returns a new instance of Mux for ln, and encrypts all traffic
|
||||||
|
// using TLS. If adv is nil, then the addr of ln is used.
|
||||||
|
func NewTLSMux(ln net.Listener, adv net.Addr, cert, key, caCert string) (*Mux, error) {
|
||||||
|
mux, err := NewMux(ln, adv)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
mux.tlsConfig, err = createTLSConfig(cert, key, caCert)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
mux.ln = tls.NewListener(ln, mux.tlsConfig)
|
||||||
|
mux.remoteEncrypted = true
|
||||||
|
mux.nodeX509CACert = caCert
|
||||||
|
mux.nodeX509Cert = cert
|
||||||
|
mux.nodeX509Key = key
|
||||||
|
|
||||||
|
return mux, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serve handles connections from ln and multiplexes then across registered listener.
|
||||||
|
func (mux *Mux) Serve() error {
|
||||||
|
mux.Logger.Printf("mux serving on %s, advertising %s", mux.ln.Addr().String(), mux.addr)
|
||||||
|
|
||||||
|
for {
|
||||||
|
// Wait for the next connection.
|
||||||
|
// If it returns a temporary error then simply retry.
|
||||||
|
// If it returns any other error then exit immediately.
|
||||||
|
conn, err := mux.ln.Accept()
|
||||||
|
if err, ok := err.(interface {
|
||||||
|
Temporary() bool
|
||||||
|
}); ok && err.Temporary() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
// Wait for all connections to be demuxed
|
||||||
|
mux.wg.Wait()
|
||||||
|
for _, ln := range mux.m {
|
||||||
|
close(ln.c)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Demux in a goroutine to
|
||||||
|
mux.wg.Add(1)
|
||||||
|
go mux.handleConn(conn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns status of the mux.
|
||||||
|
func (mux *Mux) Stats() (interface{}, error) {
|
||||||
|
s := map[string]string{
|
||||||
|
"addr": mux.addr.String(),
|
||||||
|
"timeout": mux.Timeout.String(),
|
||||||
|
"encrypted": strconv.FormatBool(mux.remoteEncrypted),
|
||||||
|
}
|
||||||
|
|
||||||
|
if mux.remoteEncrypted {
|
||||||
|
s["certificate"] = mux.nodeX509Cert
|
||||||
|
s["key"] = mux.nodeX509Key
|
||||||
|
s["ca_certificate"] = mux.nodeX509CACert
|
||||||
|
s["skip_verify"] = strconv.FormatBool(mux.InsecureSkipVerify)
|
||||||
|
}
|
||||||
|
|
||||||
|
return s, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mux *Mux) handleConn(conn net.Conn) {
|
||||||
|
defer mux.wg.Done()
|
||||||
|
// Set a read deadline so connections with no data don't timeout.
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(mux.Timeout)); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
mux.Logger.Printf("tcp.Mux: cannot set read deadline: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read first byte from connection to determine handler.
|
||||||
|
var typ [1]byte
|
||||||
|
if _, err := io.ReadFull(conn, typ[:]); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
mux.Logger.Printf("tcp.Mux: cannot read header byte: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset read deadline and let the listener handle that.
|
||||||
|
if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
||||||
|
conn.Close()
|
||||||
|
mux.Logger.Printf("tcp.Mux: cannot reset set read deadline: %s", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve handler based on first byte.
|
||||||
|
handler := mux.m[typ[0]]
|
||||||
|
if handler == nil {
|
||||||
|
conn.Close()
|
||||||
|
mux.Logger.Printf("tcp.Mux: handler not registered: %d", typ[0])
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send connection to handler. The handler is responsible for closing the connection.
|
||||||
|
handler.c <- conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen returns a listener identified by header.
|
||||||
|
// Any connection accepted by mux is multiplexed based on the initial header byte.
|
||||||
|
func (mux *Mux) Listen(header byte) *Layer {
|
||||||
|
// Ensure two listeners are not created for the same header byte.
|
||||||
|
if _, ok := mux.m[header]; ok {
|
||||||
|
panic(fmt.Sprintf("listener already registered under header byte: %d", header))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new listener and assign it.
|
||||||
|
ln := &listener{
|
||||||
|
c: make(chan net.Conn),
|
||||||
|
}
|
||||||
|
mux.m[header] = ln
|
||||||
|
|
||||||
|
layer := &Layer{
|
||||||
|
ln: ln,
|
||||||
|
header: header,
|
||||||
|
addr: mux.addr,
|
||||||
|
remoteEncrypted: mux.remoteEncrypted,
|
||||||
|
skipVerify: mux.InsecureSkipVerify,
|
||||||
|
nodeX509CACert: mux.nodeX509CACert,
|
||||||
|
tlsConfig: mux.tlsConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
return layer
|
||||||
|
}
|
||||||
|
|
||||||
|
// listener is a receiver for connections received by Mux.
|
||||||
|
type listener struct {
|
||||||
|
c chan net.Conn
|
||||||
|
}
|
||||||
|
|
||||||
|
// Accept waits for and returns the next connection to the listener.
|
||||||
|
func (ln *listener) Accept() (c net.Conn, err error) {
|
||||||
|
conn, ok := <-ln.c
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("network connection closed")
|
||||||
|
}
|
||||||
|
return conn, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close is a no-op. The mux's listener should be closed instead.
|
||||||
|
func (ln *listener) Close() error { return nil }
|
||||||
|
|
||||||
|
// Addr always returns nil
|
||||||
|
func (ln *listener) Addr() net.Addr { return nil }
|
||||||
|
|
||||||
|
// newTLSListener returns a net listener which encrypts the traffic using TLS.
|
||||||
|
func newTLSListener(ln net.Listener, certFile, keyFile, caCertFile string) (net.Listener, error) {
|
||||||
|
config, err := createTLSConfig(certFile, keyFile, caCertFile)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return tls.NewListener(ln, config), nil
|
||||||
|
}
|
@ -0,0 +1,235 @@
|
|||||||
|
package tcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/tls"
|
||||||
|
"io"
|
||||||
|
"io/ioutil"
|
||||||
|
"log"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
"testing/quick"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/rqlite/rqlite/testdata/x509"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Ensure the muxer can split a listener's connections across multiple listeners.
|
||||||
|
func TestMux(t *testing.T) {
|
||||||
|
if err := quick.Check(func(n uint8, msg []byte) bool {
|
||||||
|
if testing.Verbose() {
|
||||||
|
if len(msg) == 0 {
|
||||||
|
log.Printf("n=%d, <no message>", n)
|
||||||
|
} else {
|
||||||
|
log.Printf("n=%d, hdr=%d, len=%d", n, msg[0], len(msg))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
|
// Open single listener on random port.
|
||||||
|
tcpListener := mustTCPListener("127.0.0.1:0")
|
||||||
|
defer tcpListener.Close()
|
||||||
|
|
||||||
|
// Setup muxer & listeners.
|
||||||
|
mux, err := NewMux(tcpListener, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create mux: %s", err.Error())
|
||||||
|
}
|
||||||
|
mux.Timeout = 200 * time.Millisecond
|
||||||
|
if !testing.Verbose() {
|
||||||
|
mux.Logger = log.New(ioutil.Discard, "", 0)
|
||||||
|
}
|
||||||
|
for i := uint8(0); i < n; i++ {
|
||||||
|
ln := mux.Listen(byte(i))
|
||||||
|
|
||||||
|
wg.Add(1)
|
||||||
|
go func(i uint8, ln net.Listener) {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
// Wait for a connection for this listener.
|
||||||
|
conn, err := ln.Accept()
|
||||||
|
if conn != nil {
|
||||||
|
defer conn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
// If there is no message or the header byte
|
||||||
|
// doesn't match then expect close.
|
||||||
|
if len(msg) == 0 || msg[0] != byte(i) {
|
||||||
|
if err == nil || err.Error() != "network connection closed" {
|
||||||
|
t.Fatalf("unexpected error: %s", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the header byte matches this listener
|
||||||
|
// then expect a connection and read the message.
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if _, err := io.CopyN(&buf, conn, int64(len(msg)-1)); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else if !bytes.Equal(msg[1:], buf.Bytes()) {
|
||||||
|
t.Fatalf("message mismatch:\n\nexp=%x\n\ngot=%x\n\n", msg[1:], buf.Bytes())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write response.
|
||||||
|
if _, err := conn.Write([]byte("OK")); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}(i, ln)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Begin serving from the listener.
|
||||||
|
go mux.Serve()
|
||||||
|
|
||||||
|
// Write message to TCP listener and read OK response.
|
||||||
|
conn, err := net.Dial("tcp", tcpListener.Addr().String())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
} else if _, err = conn.Write(msg); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Read the response into the buffer.
|
||||||
|
var resp [2]byte
|
||||||
|
_, err = io.ReadFull(conn, resp[:])
|
||||||
|
|
||||||
|
// If the message header is less than n then expect a response.
|
||||||
|
// Otherwise we should get an EOF because the mux closed.
|
||||||
|
if len(msg) > 0 && uint8(msg[0]) < n {
|
||||||
|
if string(resp[:]) != `OK` {
|
||||||
|
t.Fatalf("unexpected response: %s", resp[:])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if err == nil || (err != io.EOF && !(strings.Contains(err.Error(), "connection reset by peer") ||
|
||||||
|
strings.Contains(err.Error(), "closed by the remote host"))) {
|
||||||
|
t.Fatalf("unexpected error: %s", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close connection.
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close original TCP listener and wait for all goroutines to close.
|
||||||
|
tcpListener.Close()
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
return true
|
||||||
|
}, nil); err != nil {
|
||||||
|
t.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestMux_Advertise(t *testing.T) {
|
||||||
|
// Setup muxer.
|
||||||
|
tcpListener := mustTCPListener("127.0.0.1:0")
|
||||||
|
defer tcpListener.Close()
|
||||||
|
|
||||||
|
addr := &mockAddr{
|
||||||
|
Nwk: "tcp",
|
||||||
|
Addr: "rqlite.com:8081",
|
||||||
|
}
|
||||||
|
|
||||||
|
mux, err := NewMux(tcpListener, addr)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create mux: %s", err.Error())
|
||||||
|
}
|
||||||
|
mux.Timeout = 200 * time.Millisecond
|
||||||
|
if !testing.Verbose() {
|
||||||
|
mux.Logger = log.New(ioutil.Discard, "", 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
layer := mux.Listen(1)
|
||||||
|
if layer.Addr().String() != addr.Addr {
|
||||||
|
t.Fatalf("layer advertise address not correct, exp %s, got %s",
|
||||||
|
layer.Addr().String(), addr.Addr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure two handlers cannot be registered for the same header byte.
|
||||||
|
func TestMux_Listen_ErrAlreadyRegistered(t *testing.T) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != `listener already registered under header byte: 5` {
|
||||||
|
t.Fatalf("unexpected recover: %#v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Register two listeners with the same header byte.
|
||||||
|
tcpListener := mustTCPListener("127.0.0.1:0")
|
||||||
|
mux, err := NewMux(tcpListener, nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create mux: %s", err.Error())
|
||||||
|
}
|
||||||
|
mux.Listen(5)
|
||||||
|
mux.Listen(5)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSMux(t *testing.T) {
|
||||||
|
tcpListener := mustTCPListener("127.0.0.1:0")
|
||||||
|
defer tcpListener.Close()
|
||||||
|
|
||||||
|
cert := x509.CertFile("")
|
||||||
|
defer os.Remove(cert)
|
||||||
|
key := x509.KeyFile("")
|
||||||
|
defer os.Remove(key)
|
||||||
|
|
||||||
|
mux, err := NewTLSMux(tcpListener, nil, cert, key, "")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("failed to create mux: %s", err.Error())
|
||||||
|
}
|
||||||
|
go mux.Serve()
|
||||||
|
|
||||||
|
// Verify that the listener is secured.
|
||||||
|
conn, err := tls.Dial("tcp", tcpListener.Addr().String(), &tls.Config{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
state := conn.ConnectionState()
|
||||||
|
if !state.HandshakeComplete {
|
||||||
|
t.Fatal("connection handshake failed to complete")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTLSMux_Fail(t *testing.T) {
|
||||||
|
tcpListener := mustTCPListener("127.0.0.1:0")
|
||||||
|
defer tcpListener.Close()
|
||||||
|
|
||||||
|
cert := x509.CertFile("")
|
||||||
|
defer os.Remove(cert)
|
||||||
|
key := x509.KeyFile("")
|
||||||
|
defer os.Remove(key)
|
||||||
|
|
||||||
|
_, err := NewTLSMux(tcpListener, nil, "xxxx", "yyyy", "")
|
||||||
|
if err == nil {
|
||||||
|
t.Fatalf("created mux unexpectedly with bad resources")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockAddr struct {
|
||||||
|
Nwk string
|
||||||
|
Addr string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAddr) Network() string {
|
||||||
|
return m.Nwk
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockAddr) String() string {
|
||||||
|
return m.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
// mustTCPListener returns a listener on bind, or panics.
|
||||||
|
func mustTCPListener(bind string) net.Listener {
|
||||||
|
l, err := net.Listen("tcp", bind)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return l
|
||||||
|
}
|
Loading…
Reference in New Issue