diff --git a/cmd/rqlite/execute.go b/cmd/rqlite/execute.go index 522239c4..50e1fee8 100644 --- a/cmd/rqlite/execute.go +++ b/cmd/rqlite/execute.go @@ -9,6 +9,7 @@ import ( "strings" "github.com/mkideal/cli" + cl "github.com/rqlite/rqlite/cmd/rqlite/http" ) // Result represents execute result @@ -25,96 +26,75 @@ type executeResponse struct { Time float64 `json:"time,omitempty"` } -func executeWithClient(ctx *cli.Context, client *http.Client, argv *argT, timer bool, stmt string) error { +func executeWithClient(ctx *cli.Context, client *cl.Client, timer bool, stmt string) error { queryStr := url.Values{} if timer { queryStr.Set("timings", "") } u := url.URL{ - Scheme: argv.Protocol, - Host: fmt.Sprintf("%s:%d", argv.Host, argv.Port), - Path: fmt.Sprintf("%sdb/execute", argv.Prefix), + Path: fmt.Sprintf("%sdb/execute", client.Prefix), } - urlStr := u.String() requestData := strings.NewReader(makeJSONBody(stmt)) - nRedirect := 0 - for { - if _, err := requestData.Seek(0, io.SeekStart); err != nil { - return err - } + if _, err := requestData.Seek(0, io.SeekStart); err != nil { + return err + } - req, err := http.NewRequest("POST", urlStr, requestData) - if err != nil { - return err - } - if argv.Credentials != "" { - creds := strings.Split(argv.Credentials, ":") - if len(creds) != 2 { - return fmt.Errorf("invalid Basic Auth credentials format") - } - req.SetBasicAuth(creds[0], creds[1]) - } + resp, err := client.Execute(u, requestData) - resp, err := client.Do(req) - if err != nil { - return err - } - response, err := ioutil.ReadAll(resp.Body) - if err != nil { + var hcr error + if err != nil { + // If the error is HostChangedError, it should be propagated back to the caller to handle + // accordingly (change prompt display), but we should still assume that the request succeeded on some + // host and not treat it as an error. + err, ok := err.(*cl.HostChangedError) + if !ok { return err } - resp.Body.Close() - - if resp.StatusCode == http.StatusUnauthorized { - return fmt.Errorf("unauthorized") - } + hcr = err + } - if resp.StatusCode == http.StatusMovedPermanently { - nRedirect++ - if nRedirect > maxRedirect { - return fmt.Errorf("maximum leader redirect limit exceeded") - } - urlStr = resp.Header["Location"][0] - continue - } + response, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("server responded with %s: %s", resp.Status, response) - } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("server responded with %s: %s", resp.Status, response) + } - // Parse response and write results - ret := &executeResponse{} - if err := parseResponse(&response, &ret); err != nil { - return err - } - if ret.Error != "" { - return fmt.Errorf(ret.Error) - } - if len(ret.Results) != 1 { - return fmt.Errorf("unexpected results length: %d", len(ret.Results)) - } + // Parse response and write results + ret := &executeResponse{} + if err := parseResponse(&response, &ret); err != nil { + return err + } + if ret.Error != "" { + return fmt.Errorf(ret.Error) + } + if len(ret.Results) != 1 { + return fmt.Errorf("unexpected results length: %d", len(ret.Results)) + } - result := ret.Results[0] - if result.Error != "" { - ctx.String("Error: %s\n", result.Error) - return nil - } + result := ret.Results[0] + if result.Error != "" { + ctx.String("Error: %s\n", result.Error) + return nil + } - rowString := "row" - if result.RowsAffected > 1 { - rowString = "rows" - } - if timer { - ctx.String("%d %s affected (%f sec)\n", result.RowsAffected, rowString, result.Time) - } else { - ctx.String("%d %s affected\n", result.RowsAffected, rowString) - } + rowString := "row" + if result.RowsAffected > 1 { + rowString = "rows" + } + if timer { + ctx.String("%d %s affected (%f sec)\n", result.RowsAffected, rowString, result.Time) + } else { + ctx.String("%d %s affected\n", result.RowsAffected, rowString) + } - if timer { - fmt.Printf("Run Time: %f seconds\n", result.Time) - } - return nil + if timer { + fmt.Printf("Run Time: %f seconds\n", result.Time) } + return hcr } diff --git a/cmd/rqlite/http/client.go b/cmd/rqlite/http/client.go new file mode 100644 index 00000000..9050320d --- /dev/null +++ b/cmd/rqlite/http/client.go @@ -0,0 +1,191 @@ +package http + +import ( + "fmt" + "io" + "log" + "net/http" + "net/url" + "os" + "strings" +) + +// ErrNoAvailableHost indicates that the client could not find an available host to send the request to. +var ErrNoAvailableHost = fmt.Errorf("no host available to perform the request") + +// ErrTooManyRedirects indicates that the client exceeded the maximum number of redirects +var ErrTooManyRedirects = fmt.Errorf("maximum leader redirect limit exceeded") + +// HostChangedError indicates that the underlying request was executed on a different host +// different from the caller anticipated +type HostChangedError struct { + NewHost string +} + +func (he *HostChangedError) Error() string { + return fmt.Sprintf("HostChangedErr: new host is '%s'", he.NewHost) +} + +type ConfigFunc func(*Client) + +// Client is a wrapper around stock `http.Client` that adds "retry on another host" behaviour +// based on the supplied configuration. +// +// The client will fall back and try other nodes when the current node is unavailable, and would stop trying +// after exhausting the list of supplied hosts. +// +// Note: +// +// This type is not goroutine safe. +// A node is considered unavailable if the client is not reachable via the network. +// TODO: make the unavailability condition for the client more dynamic. +type Client struct { + *http.Client + scheme string + hosts []string + Prefix string + + // creds stores the http basic authentication username and password + creds string + logger *log.Logger + + // currentHost keeps track of the last available host + currentHost int + maxRedirect int +} + +// NewClient creates a default client that sends `execute` and query `requests` against the +// rqlited nodes supplied via `hosts` argument. +func NewClient(client *http.Client, hosts []string, configFuncs ...ConfigFunc) *Client { + cl := &Client{ + Client: client, + hosts: hosts, + scheme: "http", + maxRedirect: 21, + Prefix: "/", + logger: log.New(os.Stderr, "[client] ", log.LstdFlags), + } + + for _, f := range configFuncs { + f(cl) + } + + return cl +} + +// WithScheme changes the default scheme used i.e "http". +func WithScheme(scheme string) ConfigFunc { + return func(client *Client) { + client.scheme = scheme + } +} + +// WithPrefix sets the prefix to be used when issuing HTTP requests against one of +// the rqlited nodes. +func WithPrefix(prefix string) ConfigFunc { + return func(client *Client) { + client.Prefix = prefix + } +} + +// WithLogger changes the default logger to the one provided. +func WithLogger(logger *log.Logger) ConfigFunc { + return func(client *Client) { + client.logger = logger + } +} + +// WithBasicAuth adds basic authentication behaviour to the client's request. +func WithBasicAuth(creds string) ConfigFunc { + return func(client *Client) { + client.creds = creds + } +} + +// Query sends GET requests to one of the hosts known to the client. +func (c *Client) Query(url url.URL) (*http.Response, error) { + return c.execRequest(http.MethodGet, url, nil) +} + +// Execute sends POST requests to one of the hosts known to the client +func (c *Client) Execute(url url.URL, body io.Reader) (*http.Response, error) { + return c.execRequest(http.MethodPost, url, body) +} + +func (c *Client) execRequest(method string, url url.URL, body io.Reader) (*http.Response, error) { + triedHosts := 0 + for triedHosts < len(c.hosts) { + host := c.hosts[c.currentHost] + url.Scheme = c.scheme + url.Host = host + urlStr := url.String() + resp, err := c.requestFollowRedirect(method, urlStr, body) + + // Found a responsive node + if err == nil { + if triedHosts > 0 { + return resp, &HostChangedError{NewHost: host} + } + return resp, nil + } + + // If we did too many redirects, we will consider the host as unavailable as well, + // and we will retry the request from another host + if err == ErrTooManyRedirects { + c.logger.Printf("too many redirects from host: '%s'", host) + } + + c.logger.Printf("host '%s' is unavailable, retrying with the next available host", host) + triedHosts++ + c.nextHost() + } + + c.logger.Printf("none of the available hosts are responsive") + return nil, ErrNoAvailableHost +} + +func (c *Client) nextHost() { + c.currentHost = (c.currentHost + 1) % len(c.hosts) +} + +func (c *Client) requestFollowRedirect(method string, urlStr string, body io.Reader) (*http.Response, error) { + nRedirects := 0 + for { + req, err := http.NewRequest(method, urlStr, body) + if err != nil { + return nil, err + } + err = c.setBasicAuth(req) + if err != nil { + return nil, err + } + + resp, err := c.Client.Do(req) + if err != nil { + return nil, err + } + + if resp.StatusCode == http.StatusMovedPermanently { + nRedirects++ + if nRedirects > c.maxRedirect { + return resp, ErrTooManyRedirects + } + urlStr = resp.Header["Location"][0] + continue + } + + return resp, nil + } +} + +func (c *Client) setBasicAuth(req *http.Request) error { + if c.creds == "" { + return nil + } + creds := strings.Split(c.creds, ":") + if len(creds) != 2 { + return fmt.Errorf("invalid Basic Auth credential format") + } + req.SetBasicAuth(creds[0], creds[1]) + return nil +} diff --git a/cmd/rqlite/http/client_test.go b/cmd/rqlite/http/client_test.go new file mode 100644 index 00000000..b06a2065 --- /dev/null +++ b/cmd/rqlite/http/client_test.go @@ -0,0 +1,164 @@ +package http + +import ( + "net/http" + "net/http/httptest" + "net/url" + "testing" +) + +func TestClient_QueryWhenAllAvailable(t *testing.T) { + node1 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + writer.WriteHeader(http.StatusOK) + })) + defer node1.Close() + + node2 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + writer.WriteHeader(http.StatusOK) + })) + defer node2.Close() + + httpClient := http.DefaultClient + + u1, _ := url.Parse(node1.URL) + u2, _ := url.Parse(node2.URL) + client := NewClient(httpClient, []string{u1.Host, u2.Host}) + res, err := client.Query(url.URL{ + Path: "/", + }) + defer res.Body.Close() + + if err != nil { + t.Errorf("unexpected error: %v", err) + } + + if client.currentHost != 0 { + t.Errorf("expected to only forward requests to the first host") + } + + if res.StatusCode != http.StatusOK { + t.Errorf("unexpected status code, expected '200' got '%d'", res.StatusCode) + } +} + +func TestClient_QueryWhenSomeAreAvailable(t *testing.T) { + node1 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + writer.WriteHeader(http.StatusOK) + })) + + node2 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + writer.WriteHeader(http.StatusOK) + })) + defer node2.Close() + + httpClient := http.DefaultClient + + // Shutting down one of the hosts making it unavailable + node1.Close() + + u1, _ := url.Parse(node1.URL) + u2, _ := url.Parse(node2.URL) + client := NewClient(httpClient, []string{u1.Host, u2.Host}) + res, err := client.Query(url.URL{ + Path: "/", + }) + defer res.Body.Close() + + // If the request succeeds after changing hosts, it should be reflected in the returned error + // as HostChangedError + if err == nil { + t.Errorf("expected HostChangedError got nil instead") + } + + hcer, ok := err.(*HostChangedError) + + if !ok { + t.Errorf("unexpected error occurred: %v", err) + } + + if hcer.NewHost != u2.Host { + t.Errorf("unexpected responding host") + } + + if client.currentHost != 1 { + t.Errorf("expected to move on to the following host") + } + + if res.StatusCode != http.StatusOK { + t.Errorf("unexpected status code, expected '200' got '%d'", res.StatusCode) + } +} + +func TestClient_QueryWhenAllUnavailable(t *testing.T) { + node1 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + writer.WriteHeader(http.StatusOK) + })) + + node2 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + writer.WriteHeader(http.StatusOK) + })) + + httpClient := http.DefaultClient + + u1, _ := url.Parse(node1.URL) + u2, _ := url.Parse(node2.URL) + + // Shutting down both nodes, both of them now are unavailable + node1.Close() + node2.Close() + client := NewClient(httpClient, []string{u1.Host, u2.Host}) + _, err := client.Query(url.URL{ + Path: "/", + }) + + if err != ErrNoAvailableHost { + t.Errorf("Expected %v, got: %v", ErrNoAvailableHost, err) + } +} + +func TestClient_BasicAuthIsForwarded(t *testing.T) { + mockAuth := func(request *http.Request) bool { + user, pass, ok := request.BasicAuth() + if ok { + if user == "john" && pass == "doe" { + return true + } + } + return false + } + node1 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + if mockAuth(request) { + writer.WriteHeader(http.StatusOK) + return + } + writer.WriteHeader(http.StatusUnauthorized) + })) + defer node1.Close() + + node2 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { + if mockAuth(request) { + writer.WriteHeader(http.StatusOK) + return + } + writer.WriteHeader(http.StatusUnauthorized) + })) + defer node2.Close() + + httpClient := http.DefaultClient + + u1, _ := url.Parse(node1.URL) + u2, _ := url.Parse(node2.URL) + client := NewClient(httpClient, []string{u1.Host, u2.Host}, WithBasicAuth("john:wrongpassword")) + + res, err := client.Query(url.URL{ + Path: "/", + }) + + if err != nil { + t.Errorf("unexpected error") + } + + if res.StatusCode != http.StatusUnauthorized { + t.Errorf("expected unauthorized status") + } +} diff --git a/cmd/rqlite/main.go b/cmd/rqlite/main.go index fbacd771..00c910b4 100644 --- a/cmd/rqlite/main.go +++ b/cmd/rqlite/main.go @@ -18,20 +18,22 @@ import ( "github.com/Bowery/prompt" "github.com/mkideal/cli" "github.com/rqlite/rqlite/cmd" + httpcl "github.com/rqlite/rqlite/cmd/rqlite/http" ) const maxRedirect = 21 type argT struct { cli.Helper - Protocol string `cli:"s,scheme" usage:"protocol scheme (http or https)" dft:"http"` - Host string `cli:"H,host" usage:"rqlited host address" dft:"127.0.0.1"` - Port uint16 `cli:"p,port" usage:"rqlited host port" dft:"4001"` - Prefix string `cli:"P,prefix" usage:"rqlited HTTP URL prefix" dft:"/"` - Insecure bool `cli:"i,insecure" usage:"do not verify rqlited HTTPS certificate" dft:"false"` - CACert string `cli:"c,ca-cert" usage:"path to trusted X.509 root CA certificate"` - Credentials string `cli:"u,user" usage:"set basic auth credentials in form username:password"` - Version bool `cli:"v,version" usage:"display CLI version"` + Alternatives string `cli:"a,alternatives" usage:"comma separated list of 'host:port' pairs to use as fallback"` + Protocol string `cli:"s,scheme" usage:"protocol scheme (http or https)" dft:"http"` + Host string `cli:"H,host" usage:"rqlited host address" dft:"127.0.0.1"` + Port uint16 `cli:"p,port" usage:"rqlited host port" dft:"4001"` + Prefix string `cli:"P,prefix" usage:"rqlited HTTP URL prefix" dft:"/"` + Insecure bool `cli:"i,insecure" usage:"do not verify rqlited HTTPS certificate" dft:"false"` + CACert string `cli:"c,ca-cert" usage:"path to trusted X.509 root CA certificate"` + Credentials string `cli:"u,user" usage:"set basic auth credentials in form username:password"` + Version bool `cli:"v,version" usage:"display CLI version"` } var cliHelp = []string{ @@ -67,13 +69,13 @@ func main() { return nil } - client, err := getHTTPClient(argv) + httpClient, err := getHTTPClient(argv) if err != nil { ctx.String("%s %v\n", ctx.Color().Red("ERR!"), err) return nil } - version, err := getVersionWithClient(client, argv) + version, err := getVersionWithClient(httpClient, argv) if err != nil { ctx.String("%s %v\n", ctx.Color().Red("ERR!"), err) return nil @@ -93,6 +95,12 @@ func main() { } term.Close() + hosts := createHostList(argv) + client := httpcl.NewClient(httpClient, hosts, + httpcl.WithScheme(argv.Protocol), + httpcl.WithBasicAuth(argv.Credentials), + httpcl.WithPrefix(argv.Prefix)) + FOR_READ: for { term.Reopen() @@ -122,23 +130,23 @@ func main() { } err = setConsistency(line[index+1:], &consistency) case ".TABLES": - err = queryWithClient(ctx, client, argv, timer, consistency, `SELECT name FROM sqlite_master WHERE type="table"`) + err = queryWithClient(ctx, client, timer, consistency, `SELECT name FROM sqlite_master WHERE type="table"`) case ".INDEXES": - err = queryWithClient(ctx, client, argv, timer, consistency, `SELECT sql FROM sqlite_master WHERE type="index"`) + err = queryWithClient(ctx, client, timer, consistency, `SELECT sql FROM sqlite_master WHERE type="index"`) case ".SCHEMA": - err = queryWithClient(ctx, client, argv, timer, consistency, `SELECT sql FROM sqlite_master`) + err = queryWithClient(ctx, client, timer, consistency, `SELECT sql FROM sqlite_master`) case ".TIMER": err = toggleTimer(line[index+1:], &timer) case ".STATUS": err = status(ctx, cmd, line, argv) case ".READY": - err = ready(ctx, client, argv) + err = ready(ctx, httpClient, argv) case ".NODES": err = nodes(ctx, cmd, line, argv) case ".EXPVAR": err = expvar(ctx, cmd, line, argv) case ".REMOVE": - err = removeNode(client, line[index+1:], argv, timer) + err = removeNode(httpClient, line[index+1:], argv, timer) case ".BACKUP": if index == -1 || index == len(line)-1 { err = fmt.Errorf("please specify an output file for the backup") @@ -168,12 +176,18 @@ func main() { case ".QUIT", "QUIT", "EXIT": break FOR_READ case "SELECT", "PRAGMA": - err = queryWithClient(ctx, client, argv, timer, consistency, line) + err = queryWithClient(ctx, client, timer, consistency, line) default: - err = executeWithClient(ctx, client, argv, timer, line) + err = executeWithClient(ctx, client, timer, line) } if err != nil { - ctx.String("%s %v\n", ctx.Color().Red("ERR!"), err) + // if a previous request was executed on a different host, make that change + // visible to the user. + if hcerr, ok := err.(*httpcl.HostChangedError); ok { + prefix = fmt.Sprintf("%s>", hcerr.NewHost) + } else { + ctx.String("%s %v\n", ctx.Color().Red("ERR!"), err) + } } } ctx.String("bye~\n") @@ -554,3 +568,10 @@ func urlsToWriter(urls []string, w io.Writer, argv *argT) error { return nil } + +func createHostList(argv *argT) []string { + var hosts = make([]string, 0) + hosts = append(hosts, fmt.Sprintf("%s:%d", argv.Host, argv.Port)) + hosts = append(hosts, strings.Split(argv.Alternatives, ",")...) + return hosts +} diff --git a/cmd/rqlite/query.go b/cmd/rqlite/query.go index d20a0045..1151c788 100644 --- a/cmd/rqlite/query.go +++ b/cmd/rqlite/query.go @@ -5,10 +5,10 @@ import ( "io/ioutil" "net/http" "net/url" - "strings" "github.com/mkideal/cli" "github.com/mkideal/pkg/textutil" + cl "github.com/rqlite/rqlite/cmd/rqlite/http" ) // Rows represents query result @@ -80,7 +80,7 @@ type queryResponse struct { Time float64 `json:"time"` } -func queryWithClient(ctx *cli.Context, client *http.Client, argv *argT, timer bool, consistency, query string) error { +func queryWithClient(ctx *cli.Context, client *cl.Client, timer bool, consistency, query string) error { queryStr := url.Values{} queryStr.Set("level", consistency) queryStr.Set("q", query) @@ -88,75 +88,54 @@ func queryWithClient(ctx *cli.Context, client *http.Client, argv *argT, timer bo queryStr.Set("timings", "") } u := url.URL{ - Scheme: argv.Protocol, - Host: fmt.Sprintf("%s:%d", argv.Host, argv.Port), - Path: fmt.Sprintf("%sdb/query", argv.Prefix), + Path: fmt.Sprintf("%sdb/query", client.Prefix), RawQuery: queryStr.Encode(), } - urlStr := u.String() - nRedirect := 0 - for { - req, err := http.NewRequest("GET", urlStr, nil) - if err != nil { - return err - } - if argv.Credentials != "" { - creds := strings.Split(argv.Credentials, ":") - if len(creds) != 2 { - return fmt.Errorf("invalid Basic Auth credentials format") - } - req.SetBasicAuth(creds[0], creds[1]) - } + resp, err := client.Query(u) - resp, err := client.Do(req) - if err != nil { + var hcr error + if err != nil { + // If the error is HostChangedError, it should be propagated back to the caller to handle + // accordingly (change prompt display), but we should still assume that the request succeeded on some + // host and not treat it as an error. + err, ok := err.(*cl.HostChangedError) + if !ok { return err } - response, err := ioutil.ReadAll(resp.Body) - if err != nil { - return err - } - resp.Body.Close() - - if resp.StatusCode == http.StatusUnauthorized { - return fmt.Errorf("unauthorized") - } + hcr = err + } - if resp.StatusCode == http.StatusMovedPermanently { - nRedirect++ - if nRedirect > maxRedirect { - return fmt.Errorf("maximum leader redirect limit exceeded") - } - urlStr = resp.Header["Location"][0] - continue - } + response, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + resp.Body.Close() - if resp.StatusCode != http.StatusOK { - return fmt.Errorf("server responded with %s: %s", resp.Status, response) - } + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("server responded with %s: %s", resp.Status, response) + } - // Parse response and write results - ret := &queryResponse{} - if err := parseResponse(&response, &ret); err != nil { - return err - } - if ret.Error != "" { - return fmt.Errorf(ret.Error) - } - if len(ret.Results) != 1 { - return fmt.Errorf("unexpected results length: %d", len(ret.Results)) - } + // Parse response and write results + ret := &queryResponse{} + if err := parseResponse(&response, &ret); err != nil { + return err + } + if ret.Error != "" { + return fmt.Errorf(ret.Error) + } + if len(ret.Results) != 1 { + return fmt.Errorf("unexpected results length: %d", len(ret.Results)) + } - result := ret.Results[0] - if err := result.validate(); err != nil { - return err - } - textutil.WriteTable(ctx, result, headerRender) + result := ret.Results[0] + if err := result.validate(); err != nil { + return err + } + textutil.WriteTable(ctx, result, headerRender) - if timer { - fmt.Printf("Run Time: %f seconds\n", result.Time) - } - return nil + if timer { + fmt.Printf("Run Time: %f seconds\n", result.Time) } + return hcr }