From 3be97dd61e7c22402e5683bd95075e54eeeb1f9a Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 24 Apr 2021 21:14:57 -0400 Subject: [PATCH] Add mux code back to tcp package Tests pass, nothing else done yet. --- tcp/mux.go | 276 +++++++++++++++++++++++++++++++++++++++++++++++ tcp/mux_test.go | 235 ++++++++++++++++++++++++++++++++++++++++ tcp/transport.go | 3 +- 3 files changed, 513 insertions(+), 1 deletion(-) create mode 100644 tcp/mux.go create mode 100644 tcp/mux_test.go diff --git a/tcp/mux.go b/tcp/mux.go new file mode 100644 index 00000000..1dbf0769 --- /dev/null +++ b/tcp/mux.go @@ -0,0 +1,276 @@ +package tcp + +import ( + "crypto/tls" + "errors" + "fmt" + "io" + "log" + "net" + "os" + "strconv" + "sync" + "time" +) + +const ( + // DefaultTimeout is the default length of time to wait for first byte. + DefaultTimeout = 30 * time.Second +) + +// Layer represents the connection between nodes. +type Layer struct { + ln net.Listener + header byte + addr net.Addr + + remoteEncrypted bool + skipVerify bool + nodeX509CACert string + tlsConfig *tls.Config +} + +// Addr returns the local address for the layer. +func (l *Layer) Addr() net.Addr { + return l.addr +} + +// Dial creates a new network connection. +func (l *Layer) Dial(addr string, timeout time.Duration) (net.Conn, error) { + dialer := &net.Dialer{Timeout: timeout} + + var err error + var conn net.Conn + if l.remoteEncrypted { + conn, err = tls.DialWithDialer(dialer, "tcp", addr, l.tlsConfig) + } else { + conn, err = dialer.Dial("tcp", addr) + } + if err != nil { + return nil, err + } + + // Write a marker byte to indicate message type. + _, err = conn.Write([]byte{l.header}) + if err != nil { + conn.Close() + return nil, err + } + return conn, err +} + +// Accept waits for the next connection. +func (l *Layer) Accept() (net.Conn, error) { return l.ln.Accept() } + +// Close closes the layer. +func (l *Layer) Close() error { return l.ln.Close() } + +// Mux multiplexes a network connection. +type Mux struct { + ln net.Listener + addr net.Addr + m map[byte]*listener + + wg sync.WaitGroup + + remoteEncrypted bool + + // The amount of time to wait for the first header byte. + Timeout time.Duration + + // Out-of-band error logger + Logger *log.Logger + + // Path to root X.509 certificate. + nodeX509CACert string + + // Path to X509 certificate + nodeX509Cert string + + // Path to X509 key. + nodeX509Key string + + // Whether to skip verification of other nodes' certificates. + InsecureSkipVerify bool + + tlsConfig *tls.Config +} + +// NewMux returns a new instance of Mux for ln. If adv is nil, +// then the addr of ln is used. +func NewMux(ln net.Listener, adv net.Addr) (*Mux, error) { + addr := adv + if addr == nil { + addr = ln.Addr() + } + + return &Mux{ + ln: ln, + addr: addr, + m: make(map[byte]*listener), + Timeout: DefaultTimeout, + Logger: log.New(os.Stderr, "[mux] ", log.LstdFlags), + }, nil +} + +// NewTLSMux returns a new instance of Mux for ln, and encrypts all traffic +// using TLS. If adv is nil, then the addr of ln is used. +func NewTLSMux(ln net.Listener, adv net.Addr, cert, key, caCert string) (*Mux, error) { + mux, err := NewMux(ln, adv) + if err != nil { + return nil, err + } + + mux.tlsConfig, err = createTLSConfig(cert, key, caCert) + if err != nil { + return nil, err + } + + mux.ln = tls.NewListener(ln, mux.tlsConfig) + mux.remoteEncrypted = true + mux.nodeX509CACert = caCert + mux.nodeX509Cert = cert + mux.nodeX509Key = key + + return mux, nil +} + +// Serve handles connections from ln and multiplexes then across registered listener. +func (mux *Mux) Serve() error { + mux.Logger.Printf("mux serving on %s, advertising %s", mux.ln.Addr().String(), mux.addr) + + for { + // Wait for the next connection. + // If it returns a temporary error then simply retry. + // If it returns any other error then exit immediately. + conn, err := mux.ln.Accept() + if err, ok := err.(interface { + Temporary() bool + }); ok && err.Temporary() { + continue + } + if err != nil { + // Wait for all connections to be demuxed + mux.wg.Wait() + for _, ln := range mux.m { + close(ln.c) + } + return err + } + + // Demux in a goroutine to + mux.wg.Add(1) + go mux.handleConn(conn) + } +} + +// Stats returns status of the mux. +func (mux *Mux) Stats() (interface{}, error) { + s := map[string]string{ + "addr": mux.addr.String(), + "timeout": mux.Timeout.String(), + "encrypted": strconv.FormatBool(mux.remoteEncrypted), + } + + if mux.remoteEncrypted { + s["certificate"] = mux.nodeX509Cert + s["key"] = mux.nodeX509Key + s["ca_certificate"] = mux.nodeX509CACert + s["skip_verify"] = strconv.FormatBool(mux.InsecureSkipVerify) + } + + return s, nil +} + +func (mux *Mux) handleConn(conn net.Conn) { + defer mux.wg.Done() + // Set a read deadline so connections with no data don't timeout. + if err := conn.SetReadDeadline(time.Now().Add(mux.Timeout)); err != nil { + conn.Close() + mux.Logger.Printf("tcp.Mux: cannot set read deadline: %s", err) + return + } + + // Read first byte from connection to determine handler. + var typ [1]byte + if _, err := io.ReadFull(conn, typ[:]); err != nil { + conn.Close() + mux.Logger.Printf("tcp.Mux: cannot read header byte: %s", err) + return + } + + // Reset read deadline and let the listener handle that. + if err := conn.SetReadDeadline(time.Time{}); err != nil { + conn.Close() + mux.Logger.Printf("tcp.Mux: cannot reset set read deadline: %s", err) + return + } + + // Retrieve handler based on first byte. + handler := mux.m[typ[0]] + if handler == nil { + conn.Close() + mux.Logger.Printf("tcp.Mux: handler not registered: %d", typ[0]) + return + } + + // Send connection to handler. The handler is responsible for closing the connection. + handler.c <- conn +} + +// Listen returns a listener identified by header. +// Any connection accepted by mux is multiplexed based on the initial header byte. +func (mux *Mux) Listen(header byte) *Layer { + // Ensure two listeners are not created for the same header byte. + if _, ok := mux.m[header]; ok { + panic(fmt.Sprintf("listener already registered under header byte: %d", header)) + } + + // Create a new listener and assign it. + ln := &listener{ + c: make(chan net.Conn), + } + mux.m[header] = ln + + layer := &Layer{ + ln: ln, + header: header, + addr: mux.addr, + remoteEncrypted: mux.remoteEncrypted, + skipVerify: mux.InsecureSkipVerify, + nodeX509CACert: mux.nodeX509CACert, + tlsConfig: mux.tlsConfig, + } + + return layer +} + +// listener is a receiver for connections received by Mux. +type listener struct { + c chan net.Conn +} + +// Accept waits for and returns the next connection to the listener. +func (ln *listener) Accept() (c net.Conn, err error) { + conn, ok := <-ln.c + if !ok { + return nil, errors.New("network connection closed") + } + return conn, nil +} + +// Close is a no-op. The mux's listener should be closed instead. +func (ln *listener) Close() error { return nil } + +// Addr always returns nil +func (ln *listener) Addr() net.Addr { return nil } + +// newTLSListener returns a net listener which encrypts the traffic using TLS. +func newTLSListener(ln net.Listener, certFile, keyFile, caCertFile string) (net.Listener, error) { + config, err := createTLSConfig(certFile, keyFile, caCertFile) + if err != nil { + return nil, err + } + + return tls.NewListener(ln, config), nil +} diff --git a/tcp/mux_test.go b/tcp/mux_test.go new file mode 100644 index 00000000..88e85c0b --- /dev/null +++ b/tcp/mux_test.go @@ -0,0 +1,235 @@ +package tcp + +import ( + "bytes" + "crypto/tls" + "io" + "io/ioutil" + "log" + "net" + "os" + "strings" + "sync" + "testing" + "testing/quick" + "time" + + "github.com/rqlite/rqlite/testdata/x509" +) + +// Ensure the muxer can split a listener's connections across multiple listeners. +func TestMux(t *testing.T) { + if err := quick.Check(func(n uint8, msg []byte) bool { + if testing.Verbose() { + if len(msg) == 0 { + log.Printf("n=%d, ", n) + } else { + log.Printf("n=%d, hdr=%d, len=%d", n, msg[0], len(msg)) + } + } + + var wg sync.WaitGroup + + // Open single listener on random port. + tcpListener := mustTCPListener("127.0.0.1:0") + defer tcpListener.Close() + + // Setup muxer & listeners. + mux, err := NewMux(tcpListener, nil) + if err != nil { + t.Fatalf("failed to create mux: %s", err.Error()) + } + mux.Timeout = 200 * time.Millisecond + if !testing.Verbose() { + mux.Logger = log.New(ioutil.Discard, "", 0) + } + for i := uint8(0); i < n; i++ { + ln := mux.Listen(byte(i)) + + wg.Add(1) + go func(i uint8, ln net.Listener) { + defer wg.Done() + + // Wait for a connection for this listener. + conn, err := ln.Accept() + if conn != nil { + defer conn.Close() + } + + // If there is no message or the header byte + // doesn't match then expect close. + if len(msg) == 0 || msg[0] != byte(i) { + if err == nil || err.Error() != "network connection closed" { + t.Fatalf("unexpected error: %s", err) + } + return + } + + // If the header byte matches this listener + // then expect a connection and read the message. + var buf bytes.Buffer + if _, err := io.CopyN(&buf, conn, int64(len(msg)-1)); err != nil { + t.Fatal(err) + } else if !bytes.Equal(msg[1:], buf.Bytes()) { + t.Fatalf("message mismatch:\n\nexp=%x\n\ngot=%x\n\n", msg[1:], buf.Bytes()) + } + + // Write response. + if _, err := conn.Write([]byte("OK")); err != nil { + t.Fatal(err) + } + }(i, ln) + } + + // Begin serving from the listener. + go mux.Serve() + + // Write message to TCP listener and read OK response. + conn, err := net.Dial("tcp", tcpListener.Addr().String()) + if err != nil { + t.Fatal(err) + } else if _, err = conn.Write(msg); err != nil { + t.Fatal(err) + } + + // Read the response into the buffer. + var resp [2]byte + _, err = io.ReadFull(conn, resp[:]) + + // If the message header is less than n then expect a response. + // Otherwise we should get an EOF because the mux closed. + if len(msg) > 0 && uint8(msg[0]) < n { + if string(resp[:]) != `OK` { + t.Fatalf("unexpected response: %s", resp[:]) + } + } else { + if err == nil || (err != io.EOF && !(strings.Contains(err.Error(), "connection reset by peer") || + strings.Contains(err.Error(), "closed by the remote host"))) { + t.Fatalf("unexpected error: %s", err) + } + } + + // Close connection. + if err := conn.Close(); err != nil { + t.Fatal(err) + } + + // Close original TCP listener and wait for all goroutines to close. + tcpListener.Close() + wg.Wait() + + return true + }, nil); err != nil { + t.Error(err) + } +} + +func TestMux_Advertise(t *testing.T) { + // Setup muxer. + tcpListener := mustTCPListener("127.0.0.1:0") + defer tcpListener.Close() + + addr := &mockAddr{ + Nwk: "tcp", + Addr: "rqlite.com:8081", + } + + mux, err := NewMux(tcpListener, addr) + if err != nil { + t.Fatalf("failed to create mux: %s", err.Error()) + } + mux.Timeout = 200 * time.Millisecond + if !testing.Verbose() { + mux.Logger = log.New(ioutil.Discard, "", 0) + } + + layer := mux.Listen(1) + if layer.Addr().String() != addr.Addr { + t.Fatalf("layer advertise address not correct, exp %s, got %s", + layer.Addr().String(), addr.Addr) + } +} + +// Ensure two handlers cannot be registered for the same header byte. +func TestMux_Listen_ErrAlreadyRegistered(t *testing.T) { + defer func() { + if r := recover(); r != `listener already registered under header byte: 5` { + t.Fatalf("unexpected recover: %#v", r) + } + }() + + // Register two listeners with the same header byte. + tcpListener := mustTCPListener("127.0.0.1:0") + mux, err := NewMux(tcpListener, nil) + if err != nil { + t.Fatalf("failed to create mux: %s", err.Error()) + } + mux.Listen(5) + mux.Listen(5) +} + +func TestTLSMux(t *testing.T) { + tcpListener := mustTCPListener("127.0.0.1:0") + defer tcpListener.Close() + + cert := x509.CertFile("") + defer os.Remove(cert) + key := x509.KeyFile("") + defer os.Remove(key) + + mux, err := NewTLSMux(tcpListener, nil, cert, key, "") + if err != nil { + t.Fatalf("failed to create mux: %s", err.Error()) + } + go mux.Serve() + + // Verify that the listener is secured. + conn, err := tls.Dial("tcp", tcpListener.Addr().String(), &tls.Config{ + InsecureSkipVerify: true, + }) + if err != nil { + t.Fatal(err) + } + + state := conn.ConnectionState() + if !state.HandshakeComplete { + t.Fatal("connection handshake failed to complete") + } +} + +func TestTLSMux_Fail(t *testing.T) { + tcpListener := mustTCPListener("127.0.0.1:0") + defer tcpListener.Close() + + cert := x509.CertFile("") + defer os.Remove(cert) + key := x509.KeyFile("") + defer os.Remove(key) + + _, err := NewTLSMux(tcpListener, nil, "xxxx", "yyyy", "") + if err == nil { + t.Fatalf("created mux unexpectedly with bad resources") + } +} + +type mockAddr struct { + Nwk string + Addr string +} + +func (m *mockAddr) Network() string { + return m.Nwk +} + +func (m *mockAddr) String() string { + return m.Addr +} + +// mustTCPListener returns a listener on bind, or panics. +func mustTCPListener(bind string) net.Listener { + l, err := net.Listen("tcp", bind) + if err != nil { + panic(err) + } + return l +} diff --git a/tcp/transport.go b/tcp/transport.go index c2df991d..a4a0077c 100644 --- a/tcp/transport.go +++ b/tcp/transport.go @@ -104,7 +104,8 @@ func (t *Transport) Addr() net.Addr { return t.ln.Addr() } -// createTLSConfig returns a TLS config from the given cert and key. +// createTLSConfig returns a TLS config from the given cert, key and optionally +// Certificate Authority cert. func createTLSConfig(certFile, keyFile, caCertFile string) (*tls.Config, error) { var err error config := &tls.Config{}