Merge pull request #605 from rqlite/broadcast_meta
Broadcast Store meta via standard consensusmaster
commit
f9388f8344
@ -1,180 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
connectionTimeout = 10 * time.Second
|
||||
)
|
||||
|
||||
var respOKMarshalled []byte
|
||||
|
||||
func init() {
|
||||
var err error
|
||||
respOKMarshalled, err = json.Marshal(response{})
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("unable to JSON marshal OK response: %s", err.Error()))
|
||||
}
|
||||
}
|
||||
|
||||
type response struct {
|
||||
Code int `json:"code,omitempty"`
|
||||
Message string `json:"message,omitempty"`
|
||||
}
|
||||
|
||||
// Transport is the interface the network service must provide.
|
||||
type Transport interface {
|
||||
net.Listener
|
||||
|
||||
// Dial is used to create a new outgoing connection
|
||||
Dial(address string, timeout time.Duration) (net.Conn, error)
|
||||
}
|
||||
|
||||
// Store represents a store of information, managed via consensus.
|
||||
type Store interface {
|
||||
// Leader returns the address of the leader of the consensus system.
|
||||
LeaderAddr() string
|
||||
|
||||
// UpdateAPIPeers updates the API peers on the store.
|
||||
UpdateAPIPeers(peers map[string]string) error
|
||||
}
|
||||
|
||||
// Service allows access to the cluster and associated meta data,
|
||||
// via consensus.
|
||||
type Service struct {
|
||||
tn Transport
|
||||
store Store
|
||||
addr net.Addr
|
||||
|
||||
wg sync.WaitGroup
|
||||
|
||||
logger *log.Logger
|
||||
}
|
||||
|
||||
// NewService returns a new instance of the cluster service
|
||||
func NewService(tn Transport, store Store) *Service {
|
||||
return &Service{
|
||||
tn: tn,
|
||||
store: store,
|
||||
addr: tn.Addr(),
|
||||
logger: log.New(os.Stderr, "[cluster] ", log.LstdFlags),
|
||||
}
|
||||
}
|
||||
|
||||
// Open opens the Service.
|
||||
func (s *Service) Open() error {
|
||||
s.wg.Add(1)
|
||||
go s.serve()
|
||||
s.logger.Println("service listening on", s.tn.Addr())
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the service.
|
||||
func (s *Service) Close() error {
|
||||
s.tn.Close()
|
||||
s.wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Addr returns the address the service is listening on.
|
||||
func (s *Service) Addr() string {
|
||||
return s.addr.String()
|
||||
}
|
||||
|
||||
// SetPeer will set the mapping between raftAddr and apiAddr for the entire cluster.
|
||||
func (s *Service) SetPeer(raftAddr, apiAddr string) error {
|
||||
peer := map[string]string{
|
||||
raftAddr: apiAddr,
|
||||
}
|
||||
|
||||
// Try the local store. It might be the leader.
|
||||
err := s.store.UpdateAPIPeers(peer)
|
||||
if err == nil {
|
||||
// All done! Aren't we lucky?
|
||||
return nil
|
||||
}
|
||||
|
||||
// Try talking to the leader over the network.
|
||||
if leader := s.store.LeaderAddr(); leader == "" {
|
||||
return fmt.Errorf("no leader available")
|
||||
}
|
||||
conn, err := s.tn.Dial(s.store.LeaderAddr(), connectionTimeout)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
b, err := json.Marshal(peer)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := conn.Write(b); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Wait for the response and verify the operation went through.
|
||||
resp := response{}
|
||||
d := json.NewDecoder(conn)
|
||||
err = d.Decode(&resp)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if resp.Code != 0 {
|
||||
return fmt.Errorf(resp.Message)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) serve() error {
|
||||
defer s.wg.Done()
|
||||
|
||||
for {
|
||||
conn, err := s.tn.Accept()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
go s.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) handleConn(conn net.Conn) {
|
||||
s.logger.Printf("received connection from %s", conn.RemoteAddr().String())
|
||||
|
||||
// Only handles peers updates for now.
|
||||
peers := make(map[string]string)
|
||||
d := json.NewDecoder(conn)
|
||||
err := d.Decode(&peers)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Update the peers.
|
||||
if err := s.store.UpdateAPIPeers(peers); err != nil {
|
||||
resp := response{1, err.Error()}
|
||||
b, err := json.Marshal(resp)
|
||||
if err != nil {
|
||||
conn.Close() // Only way left to signal.
|
||||
} else {
|
||||
if _, err := conn.Write(b); err != nil {
|
||||
conn.Close() // Only way left to signal.
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Let the remote node know everything went OK.
|
||||
if _, err := conn.Write(respOKMarshalled); err != nil {
|
||||
conn.Close() // Only way left to signal.
|
||||
}
|
||||
return
|
||||
}
|
@ -1,194 +0,0 @@
|
||||
package cluster
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_NewServiceOpenClose(t *testing.T) {
|
||||
ml := mustNewMockTransport()
|
||||
ms := &mockStore{}
|
||||
s := NewService(ml, ms)
|
||||
if s == nil {
|
||||
t.Fatalf("failed to create cluster service")
|
||||
}
|
||||
|
||||
if err := s.Open(); err != nil {
|
||||
t.Fatalf("failed to open cluster service")
|
||||
}
|
||||
if err := s.Close(); err != nil {
|
||||
t.Fatalf("failed to close cluster service")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SetAPIPeer(t *testing.T) {
|
||||
raftAddr, apiAddr := "localhost:4002", "localhost:4001"
|
||||
|
||||
s, _, ms := mustNewOpenService()
|
||||
defer s.Close()
|
||||
if err := s.SetPeer(raftAddr, apiAddr); err != nil {
|
||||
t.Fatalf("failed to set peer: %s", err.Error())
|
||||
}
|
||||
|
||||
if ms.peers[raftAddr] != apiAddr {
|
||||
t.Fatalf("peer not set correctly, exp %s, got %s", apiAddr, ms.peers[raftAddr])
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SetAPIPeerNetwork(t *testing.T) {
|
||||
raftAddr, apiAddr := "localhost:4002", "localhost:4001"
|
||||
|
||||
s, _, ms := mustNewOpenService()
|
||||
defer s.Close()
|
||||
|
||||
raddr, err := net.ResolveTCPAddr("tcp", s.Addr())
|
||||
if err != nil {
|
||||
t.Fatalf("failed to resolve remote uster ervice address: %s", err.Error())
|
||||
}
|
||||
|
||||
conn, err := net.DialTCP("tcp4", nil, raddr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to remote cluster service: %s", err.Error())
|
||||
}
|
||||
if _, err := conn.Write([]byte(fmt.Sprintf(`{"%s": "%s"}`, raftAddr, apiAddr))); err != nil {
|
||||
t.Fatalf("failed to write to remote cluster service: %s", err.Error())
|
||||
}
|
||||
|
||||
resp := response{}
|
||||
d := json.NewDecoder(conn)
|
||||
err = d.Decode(&resp)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decode response: %s", err.Error())
|
||||
}
|
||||
|
||||
if resp.Code != 0 {
|
||||
t.Fatalf("response code was non-zero")
|
||||
}
|
||||
|
||||
if ms.peers[raftAddr] != apiAddr {
|
||||
t.Fatalf("peer not set correctly, exp %s, got %s", apiAddr, ms.peers[raftAddr])
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SetAPIPeerFailUpdate(t *testing.T) {
|
||||
raftAddr, apiAddr := "localhost:4002", "localhost:4001"
|
||||
|
||||
s, _, ms := mustNewOpenService()
|
||||
defer s.Close()
|
||||
ms.failUpdateAPIPeers = true
|
||||
|
||||
// Attempt to set peer without a leader
|
||||
if err := s.SetPeer(raftAddr, apiAddr); err == nil {
|
||||
t.Fatalf("no error returned by set peer when no leader")
|
||||
}
|
||||
|
||||
// Start a network server.
|
||||
tn, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to open test server: %s", err.Error())
|
||||
}
|
||||
ms.leader = tn.Addr().String()
|
||||
|
||||
c := make(chan map[string]string, 1)
|
||||
go func() {
|
||||
conn, err := tn.Accept()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to accept connection from cluster: %s", err.Error())
|
||||
}
|
||||
t.Logf("test server received connection from: %s", conn.RemoteAddr())
|
||||
|
||||
peers := make(map[string]string)
|
||||
d := json.NewDecoder(conn)
|
||||
err = d.Decode(&peers)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to decode message from cluster: %s", err.Error())
|
||||
}
|
||||
|
||||
// Response OK.
|
||||
// Let the remote node know everything went OK.
|
||||
if _, err := conn.Write(respOKMarshalled); err != nil {
|
||||
t.Fatalf("failed to respond to cluster: %s", err.Error())
|
||||
}
|
||||
|
||||
c <- peers
|
||||
}()
|
||||
|
||||
if err := s.SetPeer(raftAddr, apiAddr); err != nil {
|
||||
t.Fatalf("failed to set peer on cluster: %s", err.Error())
|
||||
}
|
||||
|
||||
peers := <-c
|
||||
if peers[raftAddr] != apiAddr {
|
||||
t.Fatalf("peer not set correctly, exp %s, got %s", apiAddr, ms.peers[raftAddr])
|
||||
}
|
||||
}
|
||||
|
||||
func mustNewOpenService() (*Service, *mockTransport, *mockStore) {
|
||||
ml := mustNewMockTransport()
|
||||
ms := newMockStore()
|
||||
s := NewService(ml, ms)
|
||||
if err := s.Open(); err != nil {
|
||||
panic("failed to open new service")
|
||||
}
|
||||
return s, ml, ms
|
||||
}
|
||||
|
||||
type mockTransport struct {
|
||||
tn net.Listener
|
||||
}
|
||||
|
||||
func mustNewMockTransport() *mockTransport {
|
||||
tn, err := net.Listen("tcp", "localhost:0")
|
||||
if err != nil {
|
||||
panic("failed to create mock listener")
|
||||
}
|
||||
return &mockTransport{
|
||||
tn: tn,
|
||||
}
|
||||
}
|
||||
|
||||
func (ml *mockTransport) Accept() (c net.Conn, err error) {
|
||||
return ml.tn.Accept()
|
||||
}
|
||||
|
||||
func (ml *mockTransport) Addr() net.Addr {
|
||||
return ml.tn.Addr()
|
||||
}
|
||||
|
||||
func (ml *mockTransport) Close() (err error) {
|
||||
return ml.tn.Close()
|
||||
}
|
||||
|
||||
func (ml *mockTransport) Dial(addr string, t time.Duration) (net.Conn, error) {
|
||||
return net.DialTimeout("tcp", addr, 5*time.Second)
|
||||
}
|
||||
|
||||
type mockStore struct {
|
||||
leader string
|
||||
peers map[string]string
|
||||
failUpdateAPIPeers bool
|
||||
}
|
||||
|
||||
func newMockStore() *mockStore {
|
||||
return &mockStore{
|
||||
peers: make(map[string]string),
|
||||
}
|
||||
}
|
||||
|
||||
func (ms *mockStore) LeaderAddr() string {
|
||||
return ms.leader
|
||||
}
|
||||
|
||||
func (ms *mockStore) UpdateAPIPeers(peers map[string]string) error {
|
||||
if ms.failUpdateAPIPeers {
|
||||
return fmt.Errorf("forced fail")
|
||||
}
|
||||
|
||||
for k, v := range peers {
|
||||
ms.peers[k] = v
|
||||
}
|
||||
return nil
|
||||
}
|
@ -0,0 +1,12 @@
|
||||
package store
|
||||
|
||||
// DBConfig represents the configuration of the underlying SQLite database.
|
||||
type DBConfig struct {
|
||||
DSN string // Any custom DSN
|
||||
Memory bool // Whether the database is in-memory only.
|
||||
}
|
||||
|
||||
// NewDBConfig returns a new DB config instance.
|
||||
func NewDBConfig(dsn string, memory bool) *DBConfig {
|
||||
return &DBConfig{DSN: dsn, Memory: memory}
|
||||
}
|
@ -0,0 +1,14 @@
|
||||
package store
|
||||
|
||||
// Server represents another node in the cluster.
|
||||
type Server struct {
|
||||
ID string `json:"id,omitempty"`
|
||||
Addr string `json:"addr,omitempty"`
|
||||
}
|
||||
|
||||
// Servers is a set of Servers.
|
||||
type Servers []*Server
|
||||
|
||||
func (s Servers) Less(i, j int) bool { return s[i].ID < s[j].ID }
|
||||
func (s Servers) Len() int { return len(s) }
|
||||
func (s Servers) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
|
@ -0,0 +1,11 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_NewTransport(t *testing.T) {
|
||||
if NewTransport(nil) == nil {
|
||||
t.Fatal("failed to create new Transport")
|
||||
}
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
/*
|
||||
Package tcp provides various TCP-related utilities. The TCP mux code provided by this package originated with InfluxDB.
|
||||
Package tcp provides the internode communication network layer.
|
||||
*/
|
||||
package tcp
|
||||
|
@ -1,301 +0,0 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
// DefaultTimeout is the default length of time to wait for first byte.
|
||||
DefaultTimeout = 30 * time.Second
|
||||
)
|
||||
|
||||
// Layer represents the connection between nodes.
|
||||
type Layer struct {
|
||||
ln net.Listener
|
||||
header byte
|
||||
addr net.Addr
|
||||
|
||||
remoteEncrypted bool
|
||||
skipVerify bool
|
||||
nodeX509CACert string
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
// Addr returns the local address for the layer.
|
||||
func (l *Layer) Addr() net.Addr {
|
||||
return l.addr
|
||||
}
|
||||
|
||||
// Dial creates a new network connection.
|
||||
func (l *Layer) Dial(addr string, timeout time.Duration) (net.Conn, error) {
|
||||
dialer := &net.Dialer{Timeout: timeout}
|
||||
|
||||
var err error
|
||||
var conn net.Conn
|
||||
if l.remoteEncrypted {
|
||||
conn, err = tls.DialWithDialer(dialer, "tcp", addr, l.tlsConfig)
|
||||
} else {
|
||||
conn, err = dialer.Dial("tcp", addr)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Write a marker byte to indicate message type.
|
||||
_, err = conn.Write([]byte{l.header})
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return nil, err
|
||||
}
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// Accept waits for the next connection.
|
||||
func (l *Layer) Accept() (net.Conn, error) { return l.ln.Accept() }
|
||||
|
||||
// Close closes the layer.
|
||||
func (l *Layer) Close() error { return l.ln.Close() }
|
||||
|
||||
// Mux multiplexes a network connection.
|
||||
type Mux struct {
|
||||
ln net.Listener
|
||||
addr net.Addr
|
||||
m map[byte]*listener
|
||||
|
||||
wg sync.WaitGroup
|
||||
|
||||
remoteEncrypted bool
|
||||
|
||||
// The amount of time to wait for the first header byte.
|
||||
Timeout time.Duration
|
||||
|
||||
// Out-of-band error logger
|
||||
Logger *log.Logger
|
||||
|
||||
// Path to root X.509 certificate.
|
||||
nodeX509CACert string
|
||||
|
||||
// Path to X509 certificate
|
||||
nodeX509Cert string
|
||||
|
||||
// Path to X509 key.
|
||||
nodeX509Key string
|
||||
|
||||
// Whether to skip verification of other nodes' certificates.
|
||||
InsecureSkipVerify bool
|
||||
|
||||
tlsConfig *tls.Config
|
||||
}
|
||||
|
||||
// NewMux returns a new instance of Mux for ln. If adv is nil,
|
||||
// then the addr of ln is used.
|
||||
func NewMux(ln net.Listener, adv net.Addr) (*Mux, error) {
|
||||
addr := adv
|
||||
if addr == nil {
|
||||
addr = ln.Addr()
|
||||
}
|
||||
|
||||
return &Mux{
|
||||
ln: ln,
|
||||
addr: addr,
|
||||
m: make(map[byte]*listener),
|
||||
Timeout: DefaultTimeout,
|
||||
Logger: log.New(os.Stderr, "[mux] ", log.LstdFlags),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewTLSMux returns a new instance of Mux for ln, and encrypts all traffic
|
||||
// using TLS. If adv is nil, then the addr of ln is used.
|
||||
func NewTLSMux(ln net.Listener, adv net.Addr, cert, key, caCert string) (*Mux, error) {
|
||||
mux, err := NewMux(ln, adv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mux.tlsConfig, err = createTLSConfig(cert, key, caCert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mux.ln = tls.NewListener(ln, mux.tlsConfig)
|
||||
mux.remoteEncrypted = true
|
||||
mux.nodeX509CACert = caCert
|
||||
mux.nodeX509Cert = cert
|
||||
mux.nodeX509Key = key
|
||||
|
||||
return mux, nil
|
||||
}
|
||||
|
||||
// Serve handles connections from ln and multiplexes then across registered listener.
|
||||
func (mux *Mux) Serve() error {
|
||||
mux.Logger.Printf("mux serving on %s, advertising %s", mux.ln.Addr().String(), mux.addr)
|
||||
|
||||
for {
|
||||
// Wait for the next connection.
|
||||
// If it returns a temporary error then simply retry.
|
||||
// If it returns any other error then exit immediately.
|
||||
conn, err := mux.ln.Accept()
|
||||
if err, ok := err.(interface {
|
||||
Temporary() bool
|
||||
}); ok && err.Temporary() {
|
||||
continue
|
||||
}
|
||||
if err != nil {
|
||||
// Wait for all connections to be demuxed
|
||||
mux.wg.Wait()
|
||||
for _, ln := range mux.m {
|
||||
close(ln.c)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Demux in a goroutine to
|
||||
mux.wg.Add(1)
|
||||
go mux.handleConn(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// Stats returns status of the mux.
|
||||
func (mux *Mux) Stats() (interface{}, error) {
|
||||
s := map[string]string{
|
||||
"addr": mux.addr.String(),
|
||||
"timeout": mux.Timeout.String(),
|
||||
"encrypted": strconv.FormatBool(mux.remoteEncrypted),
|
||||
}
|
||||
|
||||
if mux.remoteEncrypted {
|
||||
s["certificate"] = mux.nodeX509Cert
|
||||
s["key"] = mux.nodeX509Key
|
||||
s["ca_certificate"] = mux.nodeX509CACert
|
||||
s["skip_verify"] = strconv.FormatBool(mux.InsecureSkipVerify)
|
||||
}
|
||||
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (mux *Mux) handleConn(conn net.Conn) {
|
||||
defer mux.wg.Done()
|
||||
// Set a read deadline so connections with no data don't timeout.
|
||||
if err := conn.SetReadDeadline(time.Now().Add(mux.Timeout)); err != nil {
|
||||
conn.Close()
|
||||
mux.Logger.Printf("tcp.Mux: cannot set read deadline: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Read first byte from connection to determine handler.
|
||||
var typ [1]byte
|
||||
if _, err := io.ReadFull(conn, typ[:]); err != nil {
|
||||
conn.Close()
|
||||
mux.Logger.Printf("tcp.Mux: cannot read header byte: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Reset read deadline and let the listener handle that.
|
||||
if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
||||
conn.Close()
|
||||
mux.Logger.Printf("tcp.Mux: cannot reset set read deadline: %s", err)
|
||||
return
|
||||
}
|
||||
|
||||
// Retrieve handler based on first byte.
|
||||
handler := mux.m[typ[0]]
|
||||
if handler == nil {
|
||||
conn.Close()
|
||||
mux.Logger.Printf("tcp.Mux: handler not registered: %d", typ[0])
|
||||
return
|
||||
}
|
||||
|
||||
// Send connection to handler. The handler is responsible for closing the connection.
|
||||
handler.c <- conn
|
||||
}
|
||||
|
||||
// Listen returns a listener identified by header.
|
||||
// Any connection accepted by mux is multiplexed based on the initial header byte.
|
||||
func (mux *Mux) Listen(header byte) *Layer {
|
||||
// Ensure two listeners are not created for the same header byte.
|
||||
if _, ok := mux.m[header]; ok {
|
||||
panic(fmt.Sprintf("listener already registered under header byte: %d", header))
|
||||
}
|
||||
|
||||
// Create a new listener and assign it.
|
||||
ln := &listener{
|
||||
c: make(chan net.Conn),
|
||||
}
|
||||
mux.m[header] = ln
|
||||
|
||||
layer := &Layer{
|
||||
ln: ln,
|
||||
header: header,
|
||||
addr: mux.addr,
|
||||
remoteEncrypted: mux.remoteEncrypted,
|
||||
skipVerify: mux.InsecureSkipVerify,
|
||||
nodeX509CACert: mux.nodeX509CACert,
|
||||
tlsConfig: mux.tlsConfig,
|
||||
}
|
||||
|
||||
return layer
|
||||
}
|
||||
|
||||
// listener is a receiver for connections received by Mux.
|
||||
type listener struct {
|
||||
c chan net.Conn
|
||||
}
|
||||
|
||||
// Accept waits for and returns the next connection to the listener.
|
||||
func (ln *listener) Accept() (c net.Conn, err error) {
|
||||
conn, ok := <-ln.c
|
||||
if !ok {
|
||||
return nil, errors.New("network connection closed")
|
||||
}
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// Close is a no-op. The mux's listener should be closed instead.
|
||||
func (ln *listener) Close() error { return nil }
|
||||
|
||||
// Addr always returns nil
|
||||
func (ln *listener) Addr() net.Addr { return nil }
|
||||
|
||||
// newTLSListener returns a net listener which encrypts the traffic using TLS.
|
||||
func newTLSListener(ln net.Listener, certFile, keyFile, caCertFile string) (net.Listener, error) {
|
||||
config, err := createTLSConfig(certFile, keyFile, caCertFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tls.NewListener(ln, config), nil
|
||||
}
|
||||
|
||||
// createTLSConfig returns a TLS config from the given cert and key.
|
||||
func createTLSConfig(certFile, keyFile, caCertFile string) (*tls.Config, error) {
|
||||
var err error
|
||||
config := &tls.Config{}
|
||||
config.Certificates = make([]tls.Certificate, 1)
|
||||
config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if caCertFile != "" {
|
||||
asn1Data, err := ioutil.ReadFile(caCertFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
config.RootCAs = x509.NewCertPool()
|
||||
ok := config.RootCAs.AppendCertsFromPEM([]byte(asn1Data))
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("failed to parse root certificate(s) in %s", caCertFile)
|
||||
}
|
||||
}
|
||||
return config, nil
|
||||
}
|
@ -1,235 +0,0 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"net"
|
||||
"os"
|
||||
"strings"
|
||||
"sync"
|
||||
"testing"
|
||||
"testing/quick"
|
||||
"time"
|
||||
|
||||
"github.com/rqlite/rqlite/testdata/x509"
|
||||
)
|
||||
|
||||
// Ensure the muxer can split a listener's connections across multiple listeners.
|
||||
func TestMux(t *testing.T) {
|
||||
if err := quick.Check(func(n uint8, msg []byte) bool {
|
||||
if testing.Verbose() {
|
||||
if len(msg) == 0 {
|
||||
log.Printf("n=%d, <no message>", n)
|
||||
} else {
|
||||
log.Printf("n=%d, hdr=%d, len=%d", n, msg[0], len(msg))
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
// Open single listener on random port.
|
||||
tcpListener := mustTCPListener("127.0.0.1:0")
|
||||
defer tcpListener.Close()
|
||||
|
||||
// Setup muxer & listeners.
|
||||
mux, err := NewMux(tcpListener, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create mux: %s", err.Error())
|
||||
}
|
||||
mux.Timeout = 200 * time.Millisecond
|
||||
if !testing.Verbose() {
|
||||
mux.Logger = log.New(ioutil.Discard, "", 0)
|
||||
}
|
||||
for i := uint8(0); i < n; i++ {
|
||||
ln := mux.Listen(byte(i))
|
||||
|
||||
wg.Add(1)
|
||||
go func(i uint8, ln net.Listener) {
|
||||
defer wg.Done()
|
||||
|
||||
// Wait for a connection for this listener.
|
||||
conn, err := ln.Accept()
|
||||
if conn != nil {
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
// If there is no message or the header byte
|
||||
// doesn't match then expect close.
|
||||
if len(msg) == 0 || msg[0] != byte(i) {
|
||||
if err == nil || err.Error() != "network connection closed" {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// If the header byte matches this listener
|
||||
// then expect a connection and read the message.
|
||||
var buf bytes.Buffer
|
||||
if _, err := io.CopyN(&buf, conn, int64(len(msg)-1)); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if !bytes.Equal(msg[1:], buf.Bytes()) {
|
||||
t.Fatalf("message mismatch:\n\nexp=%x\n\ngot=%x\n\n", msg[1:], buf.Bytes())
|
||||
}
|
||||
|
||||
// Write response.
|
||||
if _, err := conn.Write([]byte("OK")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}(i, ln)
|
||||
}
|
||||
|
||||
// Begin serving from the listener.
|
||||
go mux.Serve()
|
||||
|
||||
// Write message to TCP listener and read OK response.
|
||||
conn, err := net.Dial("tcp", tcpListener.Addr().String())
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
} else if _, err = conn.Write(msg); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Read the response into the buffer.
|
||||
var resp [2]byte
|
||||
_, err = io.ReadFull(conn, resp[:])
|
||||
|
||||
// If the message header is less than n then expect a response.
|
||||
// Otherwise we should get an EOF because the mux closed.
|
||||
if len(msg) > 0 && uint8(msg[0]) < n {
|
||||
if string(resp[:]) != `OK` {
|
||||
t.Fatalf("unexpected response: %s", resp[:])
|
||||
}
|
||||
} else {
|
||||
if err == nil || (err != io.EOF && !(strings.Contains(err.Error(), "connection reset by peer") ||
|
||||
strings.Contains(err.Error(), "closed by the remote host"))) {
|
||||
t.Fatalf("unexpected error: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Close connection.
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Close original TCP listener and wait for all goroutines to close.
|
||||
tcpListener.Close()
|
||||
wg.Wait()
|
||||
|
||||
return true
|
||||
}, nil); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMux_Advertise(t *testing.T) {
|
||||
// Setup muxer.
|
||||
tcpListener := mustTCPListener("127.0.0.1:0")
|
||||
defer tcpListener.Close()
|
||||
|
||||
addr := &mockAddr{
|
||||
Nwk: "tcp",
|
||||
Addr: "rqlite.com:8081",
|
||||
}
|
||||
|
||||
mux, err := NewMux(tcpListener, addr)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create mux: %s", err.Error())
|
||||
}
|
||||
mux.Timeout = 200 * time.Millisecond
|
||||
if !testing.Verbose() {
|
||||
mux.Logger = log.New(ioutil.Discard, "", 0)
|
||||
}
|
||||
|
||||
layer := mux.Listen(1)
|
||||
if layer.Addr().String() != addr.Addr {
|
||||
t.Fatalf("layer advertise address not correct, exp %s, got %s",
|
||||
layer.Addr().String(), addr.Addr)
|
||||
}
|
||||
}
|
||||
|
||||
// Ensure two handlers cannot be registered for the same header byte.
|
||||
func TestMux_Listen_ErrAlreadyRegistered(t *testing.T) {
|
||||
defer func() {
|
||||
if r := recover(); r != `listener already registered under header byte: 5` {
|
||||
t.Fatalf("unexpected recover: %#v", r)
|
||||
}
|
||||
}()
|
||||
|
||||
// Register two listeners with the same header byte.
|
||||
tcpListener := mustTCPListener("127.0.0.1:0")
|
||||
mux, err := NewMux(tcpListener, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create mux: %s", err.Error())
|
||||
}
|
||||
mux.Listen(5)
|
||||
mux.Listen(5)
|
||||
}
|
||||
|
||||
func TestTLSMux(t *testing.T) {
|
||||
tcpListener := mustTCPListener("127.0.0.1:0")
|
||||
defer tcpListener.Close()
|
||||
|
||||
cert := x509.CertFile()
|
||||
defer os.Remove(cert)
|
||||
key := x509.KeyFile()
|
||||
defer os.Remove(key)
|
||||
|
||||
mux, err := NewTLSMux(tcpListener, nil, cert, key, "")
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create mux: %s", err.Error())
|
||||
}
|
||||
go mux.Serve()
|
||||
|
||||
// Verify that the listener is secured.
|
||||
conn, err := tls.Dial("tcp", tcpListener.Addr().String(), &tls.Config{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
state := conn.ConnectionState()
|
||||
if !state.HandshakeComplete {
|
||||
t.Fatal("connection handshake failed to complete")
|
||||
}
|
||||
}
|
||||
|
||||
func TestTLSMux_Fail(t *testing.T) {
|
||||
tcpListener := mustTCPListener("127.0.0.1:0")
|
||||
defer tcpListener.Close()
|
||||
|
||||
cert := x509.CertFile()
|
||||
defer os.Remove(cert)
|
||||
key := x509.KeyFile()
|
||||
defer os.Remove(key)
|
||||
|
||||
_, err := NewTLSMux(tcpListener, nil, "xxxx", "yyyy", "")
|
||||
if err == nil {
|
||||
t.Fatalf("created mux unexpectedly with bad resources")
|
||||
}
|
||||
}
|
||||
|
||||
type mockAddr struct {
|
||||
Nwk string
|
||||
Addr string
|
||||
}
|
||||
|
||||
func (m *mockAddr) Network() string {
|
||||
return m.Nwk
|
||||
}
|
||||
|
||||
func (m *mockAddr) String() string {
|
||||
return m.Addr
|
||||
}
|
||||
|
||||
// mustTCPListener returns a listener on bind, or panics.
|
||||
func mustTCPListener(bind string) net.Listener {
|
||||
l, err := net.Listen("tcp", bind)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return l
|
||||
}
|
@ -0,0 +1,101 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Transport is the network layer for internode communications.
|
||||
type Transport struct {
|
||||
ln net.Listener
|
||||
|
||||
certFile string // Path to local X.509 cert.
|
||||
certKey string // Path to corresponding X.509 key.
|
||||
remoteEncrypted bool // Remote nodes use encrypted communication.
|
||||
skipVerify bool // Skip verification of remote node certs.
|
||||
}
|
||||
|
||||
// NewTransport returns an initialized unecrypted Transport.
|
||||
func NewTransport() *Transport {
|
||||
return &Transport{}
|
||||
}
|
||||
|
||||
// NewTransport returns an initialized TLS-ecrypted Transport.
|
||||
func NewTLSTransport(certFile, keyPath string, skipVerify bool) *Transport {
|
||||
return &Transport{
|
||||
certFile: certFile,
|
||||
certKey: keyPath,
|
||||
remoteEncrypted: true,
|
||||
skipVerify: skipVerify,
|
||||
}
|
||||
}
|
||||
|
||||
// Open opens the transport, binding to the supplied address.
|
||||
func (t *Transport) Open(addr string) error {
|
||||
ln, err := net.Listen("tcp", addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if t.certFile != "" {
|
||||
config, err := createTLSConfig(t.certFile, t.certKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ln = tls.NewListener(ln, config)
|
||||
}
|
||||
|
||||
t.ln = ln
|
||||
return nil
|
||||
}
|
||||
|
||||
// Dial opens a network connection.
|
||||
func (t *Transport) Dial(addr string, timeout time.Duration) (net.Conn, error) {
|
||||
dialer := &net.Dialer{Timeout: timeout}
|
||||
|
||||
var err error
|
||||
var conn net.Conn
|
||||
if t.remoteEncrypted {
|
||||
conf := &tls.Config{
|
||||
InsecureSkipVerify: t.skipVerify,
|
||||
}
|
||||
fmt.Println("doing a TLS dial")
|
||||
conn, err = tls.DialWithDialer(dialer, "tcp", addr, conf)
|
||||
} else {
|
||||
conn, err = dialer.Dial("tcp", addr)
|
||||
}
|
||||
|
||||
return conn, err
|
||||
}
|
||||
|
||||
// Accept waits for the next connection.
|
||||
func (t *Transport) Accept() (net.Conn, error) {
|
||||
c, err := t.ln.Accept()
|
||||
if err != nil {
|
||||
fmt.Println("error accepting: ", err.Error())
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
|
||||
// Close closes the transport
|
||||
func (t *Transport) Close() error {
|
||||
return t.ln.Close()
|
||||
}
|
||||
|
||||
// Addr returns the binding address of the transport.
|
||||
func (t *Transport) Addr() net.Addr {
|
||||
return t.ln.Addr()
|
||||
}
|
||||
|
||||
// createTLSConfig returns a TLS config from the given cert and key.
|
||||
func createTLSConfig(certFile, keyFile string) (*tls.Config, error) {
|
||||
var err error
|
||||
config := &tls.Config{}
|
||||
config.Certificates = make([]tls.Certificate, 1)
|
||||
config.Certificates[0], err = tls.LoadX509KeyPair(certFile, keyFile)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return config, nil
|
||||
}
|
@ -0,0 +1,70 @@
|
||||
package tcp
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/rqlite/rqlite/testdata/x509"
|
||||
)
|
||||
|
||||
func Test_NewTransport(t *testing.T) {
|
||||
if NewTransport() == nil {
|
||||
t.Fatal("failed to create new Transport")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TransportOpenClose(t *testing.T) {
|
||||
tn := NewTransport()
|
||||
if err := tn.Open("localhost:0"); err != nil {
|
||||
t.Fatalf("failed to open transport: %s", err.Error())
|
||||
}
|
||||
if tn.Addr().String() == "localhost:0" {
|
||||
t.Fatalf("transport address set incorrectly, got: %s", tn.Addr().String())
|
||||
}
|
||||
if err := tn.Close(); err != nil {
|
||||
t.Fatalf("failed to close transport: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TransportDial(t *testing.T) {
|
||||
tn1 := NewTransport()
|
||||
tn1.Open("localhost:0")
|
||||
go tn1.Accept()
|
||||
tn2 := NewTransport()
|
||||
|
||||
_, err := tn2.Dial(tn1.Addr().String(), time.Second)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to connect to first transport: %s", err.Error())
|
||||
}
|
||||
tn1.Close()
|
||||
}
|
||||
|
||||
func Test_NewTLSTransport(t *testing.T) {
|
||||
c := x509.CertFile()
|
||||
defer os.Remove(c)
|
||||
k := x509.KeyFile()
|
||||
defer os.Remove(k)
|
||||
|
||||
if NewTLSTransport(c, k, true) == nil {
|
||||
t.Fatal("failed to create new TLS Transport")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_TLSTransportOpenClose(t *testing.T) {
|
||||
c := x509.CertFile()
|
||||
defer os.Remove(c)
|
||||
k := x509.KeyFile()
|
||||
defer os.Remove(k)
|
||||
|
||||
tn := NewTLSTransport(c, k, true)
|
||||
if err := tn.Open("localhost:0"); err != nil {
|
||||
t.Fatalf("failed to open TLS transport: %s", err.Error())
|
||||
}
|
||||
if tn.Addr().String() == "localhost:0" {
|
||||
t.Fatalf("TLS transport address set incorrectly, got: %s", tn.Addr().String())
|
||||
}
|
||||
if err := tn.Close(); err != nil {
|
||||
t.Fatalf("failed to close TLS transport: %s", err.Error())
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue