From 954503dc660c42808fbd02e31fa04bd58eb10e2a Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 23 Dec 2023 09:23:52 -0500 Subject: [PATCH] Move Store to Layer and create NewLayer in tcp --- store/store.go | 12 ++++++------ store/store_test.go | 22 +++++++++++----------- store/transport.go | 19 ++++++++++--------- tcp/mux.go | 30 +++++++++++++++++------------- tcp/mux_test.go | 8 ++++---- 5 files changed, 48 insertions(+), 43 deletions(-) diff --git a/store/store.go b/store/store.go index 248a32e8..25082d5a 100644 --- a/store/store.go +++ b/store/store.go @@ -208,7 +208,7 @@ type Store struct { restoreDoneCh chan struct{} raft *raft.Raft // The consensus mechanism. - ln Listener + ly Layer raftTn *NodeTransport raftID string // Node ID. dbConf *DBConfig // SQLite database config. @@ -309,7 +309,7 @@ type Config struct { } // New returns a new Store. -func New(ln Listener, c *Config) *Store { +func New(ly Layer, c *Config) *Store { logger := c.Logger if logger == nil { logger = log.New(os.Stderr, "[store] ", log.LstdFlags) @@ -321,7 +321,7 @@ func New(ln Listener, c *Config) *Store { } return &Store{ - ln: ln, + ly: ly, raftDir: c.Dir, peersPath: filepath.Join(c.Dir, peersPath), peersInfoPath: filepath.Join(c.Dir, peersInfoPath), @@ -376,7 +376,7 @@ func (s *Store) Open() (retErr error) { } s.openT = time.Now() - s.logger.Printf("opening store with node ID %s, listening on %s", s.raftID, s.ln.Addr().String()) + s.logger.Printf("opening store with node ID %s, listening on %s", s.raftID, s.ly.Addr().String()) // Create all the required Raft directories. s.logger.Printf("ensuring data directory exists at %s", s.raftDir) @@ -403,7 +403,7 @@ func (s *Store) Open() (retErr error) { } // Create Raft-compatible network layer. - nt := raft.NewNetworkTransport(NewTransport(s.ln), connectionPoolCount, connectionTimeout, nil) + nt := raft.NewNetworkTransport(NewTransport(s.ly), connectionPoolCount, connectionTimeout, nil) s.raftTn = NewNodeTransport(nt) // Don't allow control over trailing logs directly, just implement a policy. @@ -583,7 +583,7 @@ func (s *Store) Ready() bool { func (s *Store) Close(wait bool) (retErr error) { defer func() { if retErr == nil { - s.logger.Printf("store closed with node ID %s, listening on %s", s.raftID, s.ln.Addr().String()) + s.logger.Printf("store closed with node ID %s, listening on %s", s.raftID, s.ly.Addr().String()) s.open = false } }() diff --git a/store/store_test.go b/store/store_test.go index f71128c5..67504c07 100644 --- a/store/store_test.go +++ b/store/store_test.go @@ -2061,7 +2061,7 @@ func Test_MultiNodeStoreAutoRestoreBootstrap(t *testing.T) { // Trigger a bootstrap. s0.BootstrapExpect = 3 for _, s := range []*Store{s0, s1, s2} { - if err := s0.Notify(notifyRequest(s.ID(), s.ln.Addr().String())); err != nil { + if err := s0.Notify(notifyRequest(s.ID(), s.ly.Addr().String())); err != nil { t.Fatalf("failed to notify store: %s", err.Error()) } } @@ -2797,8 +2797,8 @@ func mustNewStoreAtPathsLn(id, dataPath, sqlitePath string, fk bool) (*Store, ne cfg.FKConstraints = fk cfg.OnDiskPath = sqlitePath - ln := mustMockLister("localhost:0") - s := New(ln, &Config{ + ly := mustMockLayer("localhost:0") + s := New(ly, &Config{ DBConf: cfg, Dir: dataPath, ID: id, @@ -2806,7 +2806,7 @@ func mustNewStoreAtPathsLn(id, dataPath, sqlitePath string, fk bool) (*Store, ne if s == nil { panic("failed to create new store") } - return s, ln + return s, ly } func mustNewStore(t *testing.T) (*Store, net.Listener) { @@ -2837,27 +2837,27 @@ func (m *mockSnapshotSink) Cancel() error { return nil } -type mockListener struct { +type mockLayer struct { ln net.Listener } -func mustMockLister(addr string) Listener { +func mustMockLayer(addr string) Layer { ln, err := net.Listen("tcp", addr) if err != nil { panic("failed to create new listner") } - return &mockListener{ln} + return &mockLayer{ln} } -func (m *mockListener) Dial(addr string, timeout time.Duration) (net.Conn, error) { +func (m *mockLayer) Dial(addr string, timeout time.Duration) (net.Conn, error) { return net.DialTimeout("tcp", addr, timeout) } -func (m *mockListener) Accept() (net.Conn, error) { return m.ln.Accept() } +func (m *mockLayer) Accept() (net.Conn, error) { return m.ln.Accept() } -func (m *mockListener) Close() error { return m.ln.Close() } +func (m *mockLayer) Close() error { return m.ln.Close() } -func (m *mockListener) Addr() net.Addr { return m.ln.Addr() } +func (m *mockLayer) Addr() net.Addr { return m.ln.Addr() } func mustCreateTempFile() string { f, err := os.CreateTemp("", "rqlite-temp") diff --git a/store/transport.go b/store/transport.go index 6ddde994..6400d3af 100644 --- a/store/transport.go +++ b/store/transport.go @@ -9,42 +9,43 @@ import ( "github.com/rqlite/rqlite/v8/store/gzip" ) -// Listener is the interface expected by the Store for Transports. -type Listener interface { +// Layer is the interface expected by the Store for network communication +// between nodes, which is used for Raft distributed consensus. +type Layer interface { net.Listener Dial(address string, timeout time.Duration) (net.Conn, error) } // Transport is the network service provided to Raft, and wraps a Listener. type Transport struct { - ln Listener + ly Layer } // NewTransport returns an initialized Transport. -func NewTransport(ln Listener) *Transport { +func NewTransport(ly Layer) *Transport { return &Transport{ - ln: ln, + ly: ly, } } // Dial creates a new network connection. func (t *Transport) Dial(addr raft.ServerAddress, timeout time.Duration) (net.Conn, error) { - return t.ln.Dial(string(addr), timeout) + return t.ly.Dial(string(addr), timeout) } // Accept waits for the next connection. func (t *Transport) Accept() (net.Conn, error) { - return t.ln.Accept() + return t.ly.Accept() } // Close closes the transport func (t *Transport) Close() error { - return t.ln.Close() + return t.ly.Close() } // Addr returns the binding address of the transport. func (t *Transport) Addr() net.Addr { - return t.ln.Addr() + return t.ly.Addr() } // NodeTransport is a wrapper around the Raft NetworkTransport, which allows diff --git a/tcp/mux.go b/tcp/mux.go index 4f8f3839..a2baa957 100644 --- a/tcp/mux.go +++ b/tcp/mux.go @@ -43,6 +43,15 @@ type Layer struct { dialer *Dialer } +// NewLayer returns a new instance of Layer. +func NewLayer(ln net.Listener, dialer *Dialer) *Layer { + return &Layer{ + ln: ln, + addr: ln.Addr(), + dialer: dialer, + } +} + // Dial creates a new network connection. func (l *Layer) Dial(addr string, timeout time.Duration) (net.Conn, error) { return l.dialer.Dial(addr, timeout) @@ -161,9 +170,9 @@ func (mux *Mux) Stats() (interface{}, error) { return s, nil } -// Listen returns a Layer associated with the given header. Any connection +// Listen returns a Listener associated with the given header. Any connection // accepted by mux is multiplexed based on the initial header byte. -func (mux *Mux) Listen(header byte) *Layer { +func (mux *Mux) Listen(header byte) net.Listener { // Ensure two listeners are not created for the same header byte. if _, ok := mux.m[header]; ok { panic(fmt.Sprintf("listener already registered under header byte: %d", header)) @@ -171,17 +180,11 @@ func (mux *Mux) Listen(header byte) *Layer { // Create a new listener and assign it. ln := &listener{ - c: make(chan net.Conn), - } - mux.m[header] = ln - - layer := &Layer{ - ln: ln, + c: make(chan net.Conn), addr: mux.addr, } - layer.dialer = NewDialer(header, mux.tlsConfig) - - return layer + mux.m[header] = ln + return ln } func (mux *Mux) handleConn(conn net.Conn) { @@ -226,7 +229,8 @@ func (mux *Mux) handleConn(conn net.Conn) { // listener is a receiver for connections received by Mux. type listener struct { - c chan net.Conn + c chan net.Conn + addr net.Addr } // Accept waits for and returns the next connection to the listener. @@ -242,4 +246,4 @@ func (ln *listener) Accept() (c net.Conn, err error) { func (ln *listener) Close() error { return nil } // Addr always returns nil -func (ln *listener) Addr() net.Addr { return nil } +func (ln *listener) Addr() net.Addr { return ln.addr } diff --git a/tcp/mux_test.go b/tcp/mux_test.go index 8ef17a1c..9e53b1c1 100644 --- a/tcp/mux_test.go +++ b/tcp/mux_test.go @@ -142,10 +142,10 @@ func TestMux_Advertise(t *testing.T) { mux.Logger = log.New(io.Discard, "", 0) } - layer := mux.Listen(1) - if layer.Addr().String() != addr.Addr { - t.Fatalf("layer advertise address not correct, exp %s, got %s", - layer.Addr().String(), addr.Addr) + ln := mux.Listen(1) + if ln.Addr().String() != addr.Addr { + t.Fatalf("listener advertise address not correct, exp %s, got %s", + ln.Addr().String(), addr.Addr) } }