diff --git a/cmd/rqlited/main.go b/cmd/rqlited/main.go index bd8de5be..1863a4fd 100644 --- a/cmd/rqlited/main.go +++ b/cmd/rqlited/main.go @@ -21,7 +21,6 @@ import ( "os/signal" "path/filepath" "runtime/pprof" - "strings" "github.com/otoolep/rqlite/auth" sql "github.com/otoolep/rqlite/db" @@ -206,10 +205,7 @@ func join(joinAddr string, skipVerify bool, raftAddr string) error { } // Check for protocol scheme, and insert default if necessary. - fullAddr := fmt.Sprintf("%s/join", joinAddr) - if !strings.HasPrefix(joinAddr, "http://") && !strings.HasPrefix(joinAddr, "https://") { - fullAddr = fmt.Sprintf("http://%s", joinAddr) - } + fullAddr := httpd.NormalizeAddr(fmt.Sprintf("%s/join", joinAddr)) // Enable skipVerify as requested. tr := &http.Transport{ diff --git a/http/service.go b/http/service.go index e6618308..74bfe150 100644 --- a/http/service.go +++ b/http/service.go @@ -591,3 +591,12 @@ func prettyEnabled(e bool) string { } return "disabled" } + +// NormalizeAddr ensures that the given URL has a HTTP protocol prefix. +// If none is supplied, it prefixes the URL with "http://". +func NormalizeAddr(addr string) string { + if !strings.HasPrefix(addr, "http://") && !strings.HasPrefix(addr, "https://") { + return fmt.Sprintf("http://%s", addr) + } + return addr +} diff --git a/http/service_test.go b/http/service_test.go index 0d22df34..adb3d762 100644 --- a/http/service_test.go +++ b/http/service_test.go @@ -9,6 +9,44 @@ import ( "github.com/otoolep/rqlite/store" ) +func Test_NormalizeAddr(t *testing.T) { + tests := []struct { + orig string + norm string + }{ + { + orig: "http://localhost:4001", + norm: "http://localhost:4001", + }, + { + orig: "https://localhost:4001", + norm: "https://localhost:4001", + }, + { + orig: "https://localhost:4001/foo", + norm: "https://localhost:4001/foo", + }, + { + orig: "localhost:4001", + norm: "http://localhost:4001", + }, + { + orig: "localhost", + norm: "http://localhost", + }, + { + orig: ":4001", + norm: "http://:4001", + }, + } + + for _, tt := range tests { + if NormalizeAddr(tt.orig) != tt.norm { + t.Fatalf("%s not normalized correctly, got: %s", tt.orig, tt.norm) + } + } +} + func Test_NewService(t *testing.T) { m := &MockStore{} s := New("127.0.0.1:0", m, nil)