diff --git a/tcp/dialer.go b/tcp/dialer.go index d2ae53f4..91b9c901 100644 --- a/tcp/dialer.go +++ b/tcp/dialer.go @@ -2,6 +2,7 @@ package tcp import ( "crypto/tls" + "fmt" "net" "time" ) @@ -23,28 +24,32 @@ type Dialer struct { } // Dial dials the cluster service at the given addr and returns a connection. -func (d *Dialer) Dial(addr string, timeout time.Duration) (net.Conn, error) { +func (d *Dialer) Dial(addr string, timeout time.Duration) (conn net.Conn, retErr error) { dialer := &net.Dialer{Timeout: timeout} - var err error - var conn net.Conn if d.remoteEncrypted { conf := &tls.Config{ InsecureSkipVerify: d.skipVerify, } - conn, err = tls.DialWithDialer(dialer, "tcp", addr, conf) + conn, retErr = tls.DialWithDialer(dialer, "tcp", addr, conf) } else { - conn, err = dialer.Dial("tcp", addr) + conn, retErr = dialer.Dial("tcp", addr) } - if err != nil { - return nil, err + if retErr != nil { + return nil, retErr } + defer func() { + if retErr != nil { + conn.Close() + } + }() // Write a marker byte to indicate message type. - _, err = conn.Write([]byte{d.header}) - if err != nil { - conn.Close() + if err := conn.SetWriteDeadline(time.Now().Add(timeout)); err != nil { + return nil, fmt.Errorf("failed to set WriteDeadline for header: %s", err.Error()) + } + if _, err := conn.Write([]byte{d.header}); err != nil { return nil, err } - return conn, err + return conn, nil }