diff --git a/cluster/join.go b/cluster/join.go index 45933e2e..898ad112 100644 --- a/cluster/join.go +++ b/cluster/join.go @@ -26,7 +26,7 @@ var ( // It walks through joinAddr in order, and sets the node ID and Raft address of // the joining node as id addr respectively. It returns the endpoint successfully // used to join the cluster. -func Join(joinAddr []string, id, addr string, voter bool, meta map[string]string, numAttempts int, +func Join(srcIP string, joinAddr []string, id, addr string, voter bool, meta map[string]string, numAttempts int, attemptInterval time.Duration, tlsConfig *tls.Config) (string, error) { var err error var j string @@ -37,7 +37,7 @@ func Join(joinAddr []string, id, addr string, voter bool, meta map[string]string for i := 0; i < numAttempts; i++ { for _, a := range joinAddr { - j, err = join(a, id, addr, voter, meta, tlsConfig, logger) + j, err = join(srcIP, a, id, addr, voter, meta, tlsConfig, logger) if err == nil { // Success! return j, nil @@ -50,11 +50,20 @@ func Join(joinAddr []string, id, addr string, voter bool, meta map[string]string return "", ErrJoinFailed } -func join(joinAddr, id, addr string, voter bool, meta map[string]string, tlsConfig *tls.Config, logger *log.Logger) (string, error) { +func join(srcIP, joinAddr, id, addr string, voter bool, meta map[string]string, tlsConfig *tls.Config, logger *log.Logger) (string, error) { if id == "" { return "", fmt.Errorf("node ID not set") } - + // The specified source IP is optional + var dialer *net.Dialer + dialer = &net.Dialer{} + if srcIP != "" { + netAddr := &net.TCPAddr{ + IP: net.ParseIP(srcIP), + Port: 0, + } + dialer = &net.Dialer{LocalAddr: netAddr} + } // Join using IP address, as that is what Hashicorp Raft works in. resv, err := net.ResolveTCPAddr("tcp", addr) if err != nil { @@ -67,6 +76,7 @@ func join(joinAddr, id, addr string, voter bool, meta map[string]string, tlsConf // Create and configure the client to connect to the other node. tr := &http.Transport{ TLSClientConfig: tlsConfig, + Dial: dialer.Dial, } client := &http.Client{Transport: tr} client.CheckRedirect = func(req *http.Request, via []*http.Request) error { diff --git a/cluster/join_test.go b/cluster/join_test.go index 82c45750..3465d5ad 100644 --- a/cluster/join_test.go +++ b/cluster/join_test.go @@ -35,7 +35,7 @@ func Test_SingleJoinOK(t *testing.T) { defer ts.Close() - j, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", false, nil, + j, err := Join("127.0.0.1", []string{ts.URL}, "id0", "127.0.0.1:9090", false, nil, numAttempts, attemptInterval, nil) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) @@ -60,7 +60,7 @@ func Test_SingleJoinZeroAttempts(t *testing.T) { t.Fatalf("handler should not have been called") })) - _, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", false, nil, 0, attemptInterval, nil) + _, err := Join("127.0.0.1", []string{ts.URL}, "id0", "127.0.0.1:9090", false, nil, 0, attemptInterval, nil) if err != ErrJoinFailed { t.Fatalf("Incorrect error returned when zero attempts specified") } @@ -90,7 +90,7 @@ func Test_SingleJoinMetaOK(t *testing.T) { nodeAddr := "127.0.0.1:9090" md := map[string]string{"foo": "bar"} - j, err := Join([]string{ts.URL}, "id0", nodeAddr, true, md, + j, err := Join("", []string{ts.URL}, "id0", nodeAddr, true, md, numAttempts, attemptInterval, nil) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) @@ -117,7 +117,7 @@ func Test_SingleJoinFail(t *testing.T) { })) defer ts.Close() - _, err := Join([]string{ts.URL}, "id0", "127.0.0.1:9090", true, nil, + _, err := Join("", []string{ts.URL}, "id0", "127.0.0.1:9090", true, nil, numAttempts, attemptInterval, nil) if err == nil { t.Fatalf("expected error when joining bad node") @@ -132,7 +132,7 @@ func Test_DoubleJoinOK(t *testing.T) { })) defer ts2.Close() - j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", true, nil, + j, err := Join("127.0.0.1", []string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", true, nil, numAttempts, attemptInterval, nil) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) @@ -151,7 +151,7 @@ func Test_DoubleJoinOKSecondNode(t *testing.T) { })) defer ts2.Close() - j, err := Join([]string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", true, nil, + j, err := Join("", []string{ts1.URL, ts2.URL}, "id0", "127.0.0.1:9090", true, nil, numAttempts, attemptInterval, nil) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) @@ -172,7 +172,7 @@ func Test_DoubleJoinOKSecondNodeRedirect(t *testing.T) { })) defer ts2.Close() - j, err := Join([]string{ts2.URL}, "id0", "127.0.0.1:9090", true, nil, + j, err := Join("127.0.0.1", []string{ts2.URL}, "id0", "127.0.0.1:9090", true, nil, numAttempts, attemptInterval, nil) if err != nil { t.Fatalf("failed to join a single node: %s", err.Error()) diff --git a/cmd/rqlited/main.go b/cmd/rqlited/main.go index c539bced..b3c52c74 100644 --- a/cmd/rqlited/main.go +++ b/cmd/rqlited/main.go @@ -85,6 +85,7 @@ var compressionBatch int var showVersion bool var cpuProfile string var memProfile string +var srcIP string const name = `rqlited` const desc = `rqlite is a lightweight, distributed relational database, which uses SQLite as its @@ -132,6 +133,8 @@ func init() { flag.IntVar(&compressionBatch, "compression-batch", 5, "Request batch threshold for compression attempt") flag.StringVar(&cpuProfile, "cpu-profile", "", "Path to file for CPU profiling information") flag.StringVar(&memProfile, "mem-profile", "", "Path to file for memory profiling information") + flag.StringVar(&srcIP, "src-ip", "", "Specify a source ip address, when your node has multiple ip address segments") + flag.Usage = func() { fmt.Fprintf(os.Stderr, "\n%s\n\n", desc) fmt.Fprintf(os.Stderr, "Usage: %s [flags] \n", name) @@ -303,7 +306,7 @@ func main() { } } - if j, err := cluster.Join(joins, str.ID(), advAddr, !raftNonVoter, meta, + if j, err := cluster.Join(srcIP, joins, str.ID(), advAddr, !raftNonVoter, meta, joinAttempts, joinDur, &tlsConfig); err != nil { log.Fatalf("failed to join cluster at %s: %s", joins, err.Error()) } else { diff --git a/tcp/transport.go b/tcp/transport.go index 1379ec8b..88a61691 100644 --- a/tcp/transport.go +++ b/tcp/transport.go @@ -15,6 +15,7 @@ type Transport struct { certKey string // Path to corresponding X.509 key. remoteEncrypted bool // Remote nodes use encrypted communication. skipVerify bool // Skip verification of remote node certs. + srcIP string // The specified source IP is optional } // NewTransport returns an initialized unencrypted Transport. @@ -52,7 +53,15 @@ func (t *Transport) Open(addr string) error { // Dial opens a network connection. func (t *Transport) Dial(addr string, timeout time.Duration) (net.Conn, error) { - dialer := &net.Dialer{Timeout: timeout} + var dialer *net.Dialer + dialer = &net.Dialer{Timeout: timeout} + if t.srcIP != "" { + netAddr := &net.TCPAddr{ + IP: net.ParseIP(t.srcIP), + Port: 0, + } + dialer = &net.Dialer{Timeout: timeout, LocalAddr: netAddr} + } var err error var conn net.Conn