From 99ac2353b337a0d6cc42f7ab369d64e3b8732c4e Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sat, 28 Dec 2019 11:05:03 -0500 Subject: [PATCH] Simplify rqlite implementation This results in significant duplicated code, but is easier to follow. The previous code was buggy when it came to redirection handling. Longer term tool needs to be rebuilt to use a proper Go SQL-compliant package (yet to be written). --- cmd/rqlite/execute.go | 122 +++++++++++++++++++++++++++--------------- cmd/rqlite/main.go | 46 ++++++++++++++-- cmd/rqlite/query.go | 98 +++++++++++++++++++++------------ 3 files changed, 183 insertions(+), 83 deletions(-) diff --git a/cmd/rqlite/execute.go b/cmd/rqlite/execute.go index 92668577..9dc6d40c 100644 --- a/cmd/rqlite/execute.go +++ b/cmd/rqlite/execute.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "io/ioutil" "net/http" "net/url" "strings" @@ -23,59 +24,92 @@ type executeResponse struct { Time float64 `json:"time,omitempty"` } -func makeExecuteRequest(line string) func(string) (*http.Request, error) { - requestData := strings.NewReader(makeJSONBody(line)) - return func(urlStr string) (*http.Request, error) { - req, err := http.NewRequest("POST", urlStr, requestData) - if err != nil { - return nil, err - } - return req, nil - } -} - -func execute(ctx *cli.Context, cmd, line string, timer bool, argv *argT) error { +func executeWithClient(ctx *cli.Context, client *http.Client, argv *argT, timer bool, stmt string) error { queryStr := url.Values{} if timer { queryStr.Set("timings", "") } u := url.URL{ - Scheme: argv.Protocol, - Host: fmt.Sprintf("%s:%d", argv.Host, argv.Port), - Path: fmt.Sprintf("%sdb/execute", argv.Prefix), - RawQuery: queryStr.Encode(), + Scheme: argv.Protocol, + Host: fmt.Sprintf("%s:%d", argv.Host, argv.Port), + Path: fmt.Sprintf("%sdb/execute", argv.Prefix), } + urlStr := u.String() - response, err := sendRequest(ctx, makeExecuteRequest(line), u.String(), argv) - if err != nil { - return err - } + requestData := strings.NewReader(makeJSONBody(stmt)) - ret := &executeResponse{} - if err := parseResponse(response, &ret); err != nil { - return err - } - if ret.Error != "" { - return fmt.Errorf(ret.Error) - } - if len(ret.Results) != 1 { - return fmt.Errorf("unexpected results length: %d", len(ret.Results)) - } + nRedirect := 0 + for { + req, err := http.NewRequest("POST", urlStr, requestData) + if err != nil { + return err + } + if argv.Credentials != "" { + creds := strings.Split(argv.Credentials, ":") + if len(creds) != 2 { + return fmt.Errorf("invalid Basic Auth credentials format") + } + req.SetBasicAuth(creds[0], creds[1]) + } - result := ret.Results[0] - if result.Error != "" { - ctx.String("Error: %s\n", result.Error) - return nil - } + resp, err := client.Do(req) + if err != nil { + return err + } + response, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + resp.Body.Close() - rowString := "row" - if result.RowsAffected > 1 { - rowString = "rows" - } - if timer { - ctx.String("%d %s affected (%f sec)\n", result.RowsAffected, rowString, result.Time) - } else { - ctx.String("%d %s affected\n", result.RowsAffected, rowString) + if resp.StatusCode == http.StatusUnauthorized { + return fmt.Errorf("unauthorized") + } + + if resp.StatusCode == http.StatusMovedPermanently { + nRedirect++ + if nRedirect > maxRedirect { + return fmt.Errorf("maximum leader redirect limit exceeded") + } + urlStr = resp.Header["Location"][0] + continue + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("server responded with: %s", resp.Status) + } + + // Parse response and write results + ret := &executeResponse{} + if err := parseResponse(&response, &ret); err != nil { + return err + } + if ret.Error != "" { + return fmt.Errorf(ret.Error) + } + if len(ret.Results) != 1 { + return fmt.Errorf("unexpected results length: %d", len(ret.Results)) + } + + result := ret.Results[0] + if result.Error != "" { + ctx.String("Error: %s\n", result.Error) + return nil + } + + rowString := "row" + if result.RowsAffected > 1 { + rowString = "rows" + } + if timer { + ctx.String("%d %s affected (%f sec)\n", result.RowsAffected, rowString, result.Time) + } else { + ctx.String("%d %s affected\n", result.RowsAffected, rowString) + } + + if timer { + fmt.Printf("Run Time: %f seconds\n", result.Time) + } + return nil } - return nil } diff --git a/cmd/rqlite/main.go b/cmd/rqlite/main.go index cd0f259e..9c87f842 100644 --- a/cmd/rqlite/main.go +++ b/cmd/rqlite/main.go @@ -48,6 +48,12 @@ func main() { return nil } + client, err := getHTTPClient(argv) + if err != nil { + ctx.String("%s %v\n", ctx.Color().Red("ERR!"), err) + return nil + } + timer := false prefix := fmt.Sprintf("%s:%d>", argv.Host, argv.Port) term, err := prompt.NewTerminal() @@ -81,11 +87,11 @@ func main() { cmd = strings.ToUpper(cmd) switch cmd { case ".TABLES": - err = query(ctx, cmd, `SELECT name FROM sqlite_master WHERE type="table"`, timer, argv) + err = queryWithClient(ctx, client, argv, timer, `SELECT name FROM sqlite_master WHERE type="table"`) case ".INDEXES": - err = query(ctx, cmd, `SELECT sql FROM sqlite_master WHERE type="index"`, timer, argv) + err = queryWithClient(ctx, client, argv, timer, `SELECT sql FROM sqlite_master WHERE type="index"`) case ".SCHEMA": - err = query(ctx, cmd, "SELECT sql FROM sqlite_master", timer, argv) + err = queryWithClient(ctx, client, argv, timer, "SELECT sql FROM sqlite_master") case ".TIMER": err = toggleTimer(line[index+1:], &timer) case ".STATUS": @@ -115,9 +121,9 @@ func main() { case ".QUIT", "QUIT", "EXIT": break FOR_READ case "SELECT": - err = query(ctx, cmd, line, timer, argv) + err = queryWithClient(ctx, client, argv, timer, line) default: - err = execute(ctx, cmd, line, timer, argv) + err = executeWithClient(ctx, client, argv, timer, line) } if err != nil { ctx.String("%s %v\n", ctx.Color().Red("ERR!"), err) @@ -159,6 +165,36 @@ func expvar(ctx *cli.Context, cmd, line string, argv *argT) error { return cliJSON(ctx, cmd, line, url, argv) } +func getHTTPClient(argv *argT) (*http.Client, error) { + var rootCAs *x509.CertPool + + if argv.CACert != "" { + pemCerts, err := ioutil.ReadFile(argv.CACert) + if err != nil { + return nil, err + } + + rootCAs = x509.NewCertPool() + + ok := rootCAs.AppendCertsFromPEM(pemCerts) + if !ok { + return nil, fmt.Errorf("failed to parse root CA certificate(s)") + } + } + + client := http.Client{Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: argv.Insecure, RootCAs: rootCAs}, + Proxy: http.ProxyFromEnvironment, + }} + + // Explicitly handle redirects. + client.CheckRedirect = func(req *http.Request, via []*http.Request) error { + return http.ErrUseLastResponse + } + + return &client, nil +} + func sendRequest(ctx *cli.Context, makeNewRequest func(string) (*http.Request, error), urlStr string, argv *argT) (*[]byte, error) { url := urlStr var rootCAs *x509.CertPool diff --git a/cmd/rqlite/query.go b/cmd/rqlite/query.go index dd26826b..c8d52479 100644 --- a/cmd/rqlite/query.go +++ b/cmd/rqlite/query.go @@ -2,6 +2,7 @@ package main import ( "fmt" + "io/ioutil" "net/http" "net/url" "strings" @@ -82,19 +83,9 @@ type queryResponse struct { Time float64 `json:"time"` } -func makeQueryRequest(line string) func(string) (*http.Request, error) { - requestData := strings.NewReader(makeJSONBody(line)) - return func(urlStr string) (*http.Request, error) { - req, err := http.NewRequest("POST", urlStr, requestData) - if err != nil { - return nil, err - } - return req, nil - } -} - -func query(ctx *cli.Context, cmd, line string, timer bool, argv *argT) error { +func queryWithClient(ctx *cli.Context, client *http.Client, argv *argT, timer bool, query string) error { queryStr := url.Values{} + queryStr.Set("q", query) if timer { queryStr.Set("timings", "") } @@ -104,31 +95,70 @@ func query(ctx *cli.Context, cmd, line string, timer bool, argv *argT) error { Path: fmt.Sprintf("%sdb/query", argv.Prefix), RawQuery: queryStr.Encode(), } + urlStr := u.String() - response, err := sendRequest(ctx, makeQueryRequest(line), u.String(), argv) - if err != nil { - return err - } + nRedirect := 0 + for { + req, err := http.NewRequest("GET", urlStr, nil) + if err != nil { + return err + } + if argv.Credentials != "" { + creds := strings.Split(argv.Credentials, ":") + if len(creds) != 2 { + return fmt.Errorf("invalid Basic Auth credentials format") + } + req.SetBasicAuth(creds[0], creds[1]) + } - ret := &queryResponse{} - if err := parseResponse(response, &ret); err != nil { - return err - } - if ret.Error != "" { - return fmt.Errorf(ret.Error) - } - if len(ret.Results) != 1 { - return fmt.Errorf("unexpected results length: %d", len(ret.Results)) - } + resp, err := client.Do(req) + if err != nil { + return err + } + response, err := ioutil.ReadAll(resp.Body) + if err != nil { + return err + } + resp.Body.Close() - result := ret.Results[0] - if err := result.validate(); err != nil { - return err - } - textutil.WriteTable(ctx, result, headerRender) + if resp.StatusCode == http.StatusUnauthorized { + return fmt.Errorf("unauthorized") + } - if timer { - fmt.Printf("Run Time: %f seconds\n", result.Time) + if resp.StatusCode == http.StatusMovedPermanently { + nRedirect++ + if nRedirect > maxRedirect { + return fmt.Errorf("maximum leader redirect limit exceeded") + } + urlStr = resp.Header["Location"][0] + continue + } + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("server responded with: %s", resp.Status) + } + + // Parse response and write results + ret := &queryResponse{} + if err := parseResponse(&response, &ret); err != nil { + return err + } + if ret.Error != "" { + return fmt.Errorf(ret.Error) + } + if len(ret.Results) != 1 { + return fmt.Errorf("unexpected results length: %d", len(ret.Results)) + } + + result := ret.Results[0] + if err := result.validate(); err != nil { + return err + } + textutil.WriteTable(ctx, result, headerRender) + + if timer { + fmt.Printf("Run Time: %f seconds\n", result.Time) + } + return nil } - return nil }