diff --git a/tcp/dialer.go b/tcp/dialer.go new file mode 100644 index 00000000..d2ae53f4 --- /dev/null +++ b/tcp/dialer.go @@ -0,0 +1,50 @@ +package tcp + +import ( + "crypto/tls" + "net" + "time" +) + +// NewDialer returns an initialized Dialer +func NewDialer(header byte, remoteEncrypted, skipVerify bool) *Dialer { + return &Dialer{ + header: header, + remoteEncrypted: remoteEncrypted, + skipVerify: skipVerify, + } +} + +// Dialer supports dialing a cluster service. +type Dialer struct { + header byte + remoteEncrypted bool + skipVerify bool +} + +// Dial dials the cluster service at the given addr and returns a connection. +func (d *Dialer) Dial(addr string, timeout time.Duration) (net.Conn, error) { + dialer := &net.Dialer{Timeout: timeout} + + var err error + var conn net.Conn + if d.remoteEncrypted { + conf := &tls.Config{ + InsecureSkipVerify: d.skipVerify, + } + conn, err = tls.DialWithDialer(dialer, "tcp", addr, conf) + } else { + conn, err = dialer.Dial("tcp", addr) + } + if err != nil { + return nil, err + } + + // Write a marker byte to indicate message type. + _, err = conn.Write([]byte{d.header}) + if err != nil { + conn.Close() + return nil, err + } + return conn, err +} diff --git a/tcp/dialer_test.go b/tcp/dialer_test.go new file mode 100644 index 00000000..a703075c --- /dev/null +++ b/tcp/dialer_test.go @@ -0,0 +1,86 @@ +package tcp + +import ( + "net" + "testing" + "time" +) + +func Test_NewDialer(t *testing.T) { + d := NewDialer(1, false, false) + if d == nil { + t.Fatal("failed to create a dialer") + } +} + +func Test_DialerNoConnect(t *testing.T) { + d := NewDialer(87, false, false) + _, err := d.Dial("127.0.0.1:0", 5*time.Second) + if err == nil { + t.Fatalf("no error connecting to bad address") + } +} + +func Test_DialerHeader(t *testing.T) { + s := mustNewEchoServer() + defer s.Close() + + go s.MustStart() + + d := NewDialer(87, false, false) + conn, err := d.Dial(s.Addr(), 10*time.Second) + if err != nil { + t.Fatalf("failed to dial echo server: %s", err.Error()) + } + + buf := make([]byte, 1) + _, err = conn.Read(buf) + if err != nil { + t.Fatalf("failed to read from echo server: %s", err.Error()) + } + if exp, got := buf[0], byte(87); exp != got { + t.Fatalf("got wrong response from echo server, exp %d, got %d", exp, got) + } +} + +type echoSever struct { + ln net.Listener +} + +// Addr returns the address of the echo server. +func (e *echoSever) Addr() string { + return e.ln.Addr().String() +} + +// MustStart starts the echo server. +func (e *echoSever) MustStart() { + for { + conn, err := e.ln.Accept() + if err != nil { + return + } + go func(c net.Conn) { + buf := make([]byte, 1) + _, err := c.Read(buf) + if err != nil { + panic("failed to read byte") + } + _, err = c.Write(buf) + if err != nil { + panic("failed to echo received byte") + } + c.Close() + }(conn) + } +} + +// Close closes the echo server. +func (e *echoSever) Close() { + e.ln.Close() +} + +func mustNewEchoServer() *echoSever { + return &echoSever{ + ln: mustTCPListener("127.0.0.1:0"), + } +} diff --git a/tcp/mux.go b/tcp/mux.go index b2dfada2..08428a79 100644 --- a/tcp/mux.go +++ b/tcp/mux.go @@ -35,49 +35,6 @@ func init() { stats.Add(numUnregisteredHandlers, 0) } -// NewDialer returns an initialized Dialer -func NewDialer(header byte, remoteEncrypted, skipVerify bool) *Dialer { - return &Dialer{ - header: header, - remoteEncrypted: remoteEncrypted, - skipVerify: skipVerify, - } -} - -// Dialer supports dialing a cluster service. -type Dialer struct { - header byte - remoteEncrypted bool - skipVerify bool -} - -// Dial dials the cluster service a the given addr and returns a connection. -func (d *Dialer) Dial(addr string, timeout time.Duration) (net.Conn, error) { - dialer := &net.Dialer{Timeout: timeout} - - var err error - var conn net.Conn - if d.remoteEncrypted { - conf := &tls.Config{ - InsecureSkipVerify: d.skipVerify, - } - conn, err = tls.DialWithDialer(dialer, "tcp", addr, conf) - } else { - conn, err = dialer.Dial("tcp", addr) - } - if err != nil { - return nil, err - } - - // Write a marker byte to indicate message type. - _, err = conn.Write([]byte{d.header}) - if err != nil { - conn.Close() - return nil, err - } - return conn, err -} - // Layer represents the connection between nodes. It can be both used to // make connections to other nodes, and receive connections from other // nodes.