diff --git a/cmd/rqlite/execute.go b/cmd/rqlite/execute.go index 92668577..9dc6d40c 100644 --- a/cmd/rqlite/execute.go +++ b/cmd/rqlite/execute.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "io/ioutil" "net/http" "net/url" "strings" @@ -23,59 +24,92 @@ type executeResponse struct { Time float64 `json:"time,omitempty"` } -func makeExecuteRequest(line string) func(string) (*http.Request, error) { - requestData := strings.NewReader(makeJSONBody(line)) - return func(urlStr string) (*http.Request, error) { - req, err := http.NewRequest("POST", urlStr, requestData) - if err != nil { - return nil, err - } - return req, nil - } -} - -func execute(ctx *cli.Context, cmd, line string, timer bool, argv *argT) error { +func executeWithClient(ctx *cli.Context, client *http.Client, argv *argT, timer bool, stmt string) error { queryStr := url.Values{} if timer { queryStr.Set("timings", "") } u := url.URL{ - Scheme: argv.Protocol, - Host: fmt.Sprintf("%s:%d", argv.Host, argv.Port), - Path: fmt.Sprintf("%sdb/execute", argv.Prefix), - RawQuery: queryStr.Encode(), + Scheme: argv.Protocol, + Host: fmt.Sprintf("%s:%d", argv.Host, argv.Port), + Path: fmt.Sprintf("%sdb/execute", argv.Prefix), } + urlStr := u.String() - response, err := sendRequest(ctx, makeExecuteRequest(line), u.String(), argv) - if err != nil { - return err - } + requestData := strings.NewReader(makeJSONBody(stmt)) - ret := &executeResponse{} - if err := parseResponse(response, &ret); err != nil { - return err - } - if ret.Error != "" { - return fmt.Errorf(ret.Error) - } - if len(ret.Results) != 1 { - return fmt.Errorf("unexpected results length: %d", len(ret.Results)) - } + nRedirect := 0 + for { + req, err := http.NewRequest("POST", urlStr, requestData) + if err != nil { + return err + } + if argv.Credentials != "" { + creds := strings.Split(argv.Credentials, ":") + if len(creds) != 2 { + return fmt.Errorf("invalid Basic Auth credentials format") + } + req.SetBasicAuth(creds[0], creds[1]) + } - result := ret.Results[0] - if result.Error != "" { - ctx.String("Error: %s\n", result.Error) - return nil - } + resp, err := client.Do(req) + if err != nil { + return err + } + response, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + resp.Body.Close() - rowString := "row" - if result.RowsAffected > 1 { - rowString = "rows" - } - if timer { - ctx.String("%d %s affected (%f sec)\n", result.RowsAffected, rowString, result.Time) - } else { - ctx.String("%d %s affected\n", result.RowsAffected, rowString) + if resp.StatusCode == http.StatusUnauthorized { + return fmt.Errorf("unauthorized") + } + + if resp.StatusCode == http.StatusMovedPermanently { + nRedirect++ + if nRedirect > maxRedirect { + return fmt.Errorf("maximum leader redirect limit exceeded") + } + urlStr = resp.Header["Location"][0] + continue + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("server responded with: %s", resp.Status) + } + + // Parse response and write results + ret := &executeResponse{} + if err := parseResponse(&response, &ret); err != nil { + return err + } + if ret.Error != "" { + return fmt.Errorf(ret.Error) + } + if len(ret.Results) != 1 { + return fmt.Errorf("unexpected results length: %d", len(ret.Results)) + } + + result := ret.Results[0] + if result.Error != "" { + ctx.String("Error: %s\n", result.Error) + return nil + } + + rowString := "row" + if result.RowsAffected > 1 { + rowString = "rows" + } + if timer { + ctx.String("%d %s affected (%f sec)\n", result.RowsAffected, rowString, result.Time) + } else { + ctx.String("%d %s affected\n", result.RowsAffected, rowString) + } + + if timer { + fmt.Printf("Run Time: %f seconds\n", result.Time) + } + return nil } - return nil } diff --git a/cmd/rqlite/main.go b/cmd/rqlite/main.go index cd0f259e..9c87f842 100644 --- a/cmd/rqlite/main.go +++ b/cmd/rqlite/main.go @@ -48,6 +48,12 @@ func main() { return nil } + client, err := getHTTPClient(argv) + if err != nil { + ctx.String("%s %v\n", ctx.Color().Red("ERR!"), err) + return nil + } + timer := false prefix := fmt.Sprintf("%s:%d>", argv.Host, argv.Port) term, err := prompt.NewTerminal() @@ -81,11 +87,11 @@ func main() { cmd = strings.ToUpper(cmd) switch cmd { case ".TABLES": - err = query(ctx, cmd, `SELECT name FROM sqlite_master WHERE type="table"`, timer, argv) + err = queryWithClient(ctx, client, argv, timer, `SELECT name FROM sqlite_master WHERE type="table"`) case ".INDEXES": - err = query(ctx, cmd, `SELECT sql FROM sqlite_master WHERE type="index"`, timer, argv) + err = queryWithClient(ctx, client, argv, timer, `SELECT sql FROM sqlite_master WHERE type="index"`) case ".SCHEMA": - err = query(ctx, cmd, "SELECT sql FROM sqlite_master", timer, argv) + err = queryWithClient(ctx, client, argv, timer, "SELECT sql FROM sqlite_master") case ".TIMER": err = toggleTimer(line[index+1:], &timer) case ".STATUS": @@ -115,9 +121,9 @@ func main() { case ".QUIT", "QUIT", "EXIT": break FOR_READ case "SELECT": - err = query(ctx, cmd, line, timer, argv) + err = queryWithClient(ctx, client, argv, timer, line) default: - err = execute(ctx, cmd, line, timer, argv) + err = executeWithClient(ctx, client, argv, timer, line) } if err != nil { ctx.String("%s %v\n", ctx.Color().Red("ERR!"), err) @@ -159,6 +165,36 @@ func expvar(ctx *cli.Context, cmd, line string, argv *argT) error { return cliJSON(ctx, cmd, line, url, argv) } +func getHTTPClient(argv *argT) (*http.Client, error) { + var rootCAs *x509.CertPool + + if argv.CACert != "" { + pemCerts, err := ioutil.ReadFile(argv.CACert) + if err != nil { + return nil, err + } + + rootCAs = x509.NewCertPool() + + ok := rootCAs.AppendCertsFromPEM(pemCerts) + if !ok { + return nil, fmt.Errorf("failed to parse root CA certificate(s)") + } + } + + client := http.Client{Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: argv.Insecure, RootCAs: rootCAs}, + Proxy: http.ProxyFromEnvironment, + }} + + // Explicitly handle redirects. + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + + return &client, nil +} + func sendRequest(ctx *cli.Context, makeNewRequest func(string) (*http.Request, error), urlStr string, argv *argT) (*[]byte, error) { url := urlStr var rootCAs *x509.CertPool diff --git a/cmd/rqlite/query.go b/cmd/rqlite/query.go index dd26826b..c8d52479 100644 --- a/cmd/rqlite/query.go +++ b/cmd/rqlite/query.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "io/ioutil" "net/http" "net/url" "strings" @@ -82,19 +83,9 @@ type queryResponse struct { Time float64 `json:"time"` } -func makeQueryRequest(line string) func(string) (*http.Request, error) { - requestData := strings.NewReader(makeJSONBody(line)) - return func(urlStr string) (*http.Request, error) { - req, err := http.NewRequest("POST", urlStr, requestData) - if err != nil { - return nil, err - } - return req, nil - } -} - -func query(ctx *cli.Context, cmd, line string, timer bool, argv *argT) error { +func queryWithClient(ctx *cli.Context, client *http.Client, argv *argT, timer bool, query string) error { queryStr := url.Values{} + queryStr.Set("q", query) if timer { queryStr.Set("timings", "") } @@ -104,31 +95,70 @@ func query(ctx *cli.Context, cmd, line string, timer bool, argv *argT) error { Path: fmt.Sprintf("%sdb/query", argv.Prefix), RawQuery: queryStr.Encode(), } + urlStr := u.String() - response, err := sendRequest(ctx, makeQueryRequest(line), u.String(), argv) - if err != nil { - return err - } + nRedirect := 0 + for { + req, err := http.NewRequest("GET", urlStr, nil) + if err != nil { + return err + } + if argv.Credentials != "" { + creds := strings.Split(argv.Credentials, ":") + if len(creds) != 2 { + return fmt.Errorf("invalid Basic Auth credentials format") + } + req.SetBasicAuth(creds[0], creds[1]) + } - ret := &queryResponse{} - if err := parseResponse(response, &ret); err != nil { - return err - } - if ret.Error != "" { - return fmt.Errorf(ret.Error) - } - if len(ret.Results) != 1 { - return fmt.Errorf("unexpected results length: %d", len(ret.Results)) - } + resp, err := client.Do(req) + if err != nil { + return err + } + response, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + resp.Body.Close() - result := ret.Results[0] - if err := result.validate(); err != nil { - return err - } - textutil.WriteTable(ctx, result, headerRender) + if resp.StatusCode == http.StatusUnauthorized { + return fmt.Errorf("unauthorized") + } - if timer { - fmt.Printf("Run Time: %f seconds\n", result.Time) + if resp.StatusCode == http.StatusMovedPermanently { + nRedirect++ + if nRedirect > maxRedirect { + return fmt.Errorf("maximum leader redirect limit exceeded") + } + urlStr = resp.Header["Location"][0] + continue + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("server responded with: %s", resp.Status) + } + + // Parse response and write results + ret := &queryResponse{} + if err := parseResponse(&response, &ret); err != nil { + return err + } + if ret.Error != "" { + return fmt.Errorf(ret.Error) + } + if len(ret.Results) != 1 { + return fmt.Errorf("unexpected results length: %d", len(ret.Results)) + } + + result := ret.Results[0] + if err := result.validate(); err != nil { + return err + } + textutil.WriteTable(ctx, result, headerRender) + + if timer { + fmt.Printf("Run Time: %f seconds\n", result.Time) + } + return nil } - return nil }