1
0
Fork 0

Simplify rqlite implementation

This results in significant duplicated code, but is easier to follow.
The previous code was buggy when it came to redirection handling.
Longer term tool needs to be rebuilt to use a proper Go SQL-compliant
package (yet to be written).
master
Philip O'Toole 5 years ago
parent 0b7bbd85e2
commit 99ac2353b3

@ -2,6 +2,7 @@ package main
import (
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
@ -23,59 +24,92 @@ type executeResponse struct {
Time float64 `json:"time,omitempty"`
}
func makeExecuteRequest(line string) func(string) (*http.Request, error) {
requestData := strings.NewReader(makeJSONBody(line))
return func(urlStr string) (*http.Request, error) {
req, err := http.NewRequest("POST", urlStr, requestData)
if err != nil {
return nil, err
}
return req, nil
}
}
func execute(ctx *cli.Context, cmd, line string, timer bool, argv *argT) error {
func executeWithClient(ctx *cli.Context, client *http.Client, argv *argT, timer bool, stmt string) error {
queryStr := url.Values{}
if timer {
queryStr.Set("timings", "")
}
u := url.URL{
Scheme: argv.Protocol,
Host: fmt.Sprintf("%s:%d", argv.Host, argv.Port),
Path: fmt.Sprintf("%sdb/execute", argv.Prefix),
RawQuery: queryStr.Encode(),
Scheme: argv.Protocol,
Host: fmt.Sprintf("%s:%d", argv.Host, argv.Port),
Path: fmt.Sprintf("%sdb/execute", argv.Prefix),
}
urlStr := u.String()
response, err := sendRequest(ctx, makeExecuteRequest(line), u.String(), argv)
if err != nil {
return err
}
requestData := strings.NewReader(makeJSONBody(stmt))
ret := &executeResponse{}
if err := parseResponse(response, &ret); err != nil {
return err
}
if ret.Error != "" {
return fmt.Errorf(ret.Error)
}
if len(ret.Results) != 1 {
return fmt.Errorf("unexpected results length: %d", len(ret.Results))
}
nRedirect := 0
for {
req, err := http.NewRequest("POST", urlStr, requestData)
if err != nil {
return err
}
if argv.Credentials != "" {
creds := strings.Split(argv.Credentials, ":")
if len(creds) != 2 {
return fmt.Errorf("invalid Basic Auth credentials format")
}
req.SetBasicAuth(creds[0], creds[1])
}
result := ret.Results[0]
if result.Error != "" {
ctx.String("Error: %s\n", result.Error)
return nil
}
resp, err := client.Do(req)
if err != nil {
return err
}
response, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
resp.Body.Close()
rowString := "row"
if result.RowsAffected > 1 {
rowString = "rows"
}
if timer {
ctx.String("%d %s affected (%f sec)\n", result.RowsAffected, rowString, result.Time)
} else {
ctx.String("%d %s affected\n", result.RowsAffected, rowString)
if resp.StatusCode == http.StatusUnauthorized {
return fmt.Errorf("unauthorized")
}
if resp.StatusCode == http.StatusMovedPermanently {
nRedirect++
if nRedirect > maxRedirect {
return fmt.Errorf("maximum leader redirect limit exceeded")
}
urlStr = resp.Header["Location"][0]
continue
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("server responded with: %s", resp.Status)
}
// Parse response and write results
ret := &executeResponse{}
if err := parseResponse(&response, &ret); err != nil {
return err
}
if ret.Error != "" {
return fmt.Errorf(ret.Error)
}
if len(ret.Results) != 1 {
return fmt.Errorf("unexpected results length: %d", len(ret.Results))
}
result := ret.Results[0]
if result.Error != "" {
ctx.String("Error: %s\n", result.Error)
return nil
}
rowString := "row"
if result.RowsAffected > 1 {
rowString = "rows"
}
if timer {
ctx.String("%d %s affected (%f sec)\n", result.RowsAffected, rowString, result.Time)
} else {
ctx.String("%d %s affected\n", result.RowsAffected, rowString)
}
if timer {
fmt.Printf("Run Time: %f seconds\n", result.Time)
}
return nil
}
return nil
}

@ -48,6 +48,12 @@ func main() {
return nil
}
client, err := getHTTPClient(argv)
if err != nil {
ctx.String("%s %v\n", ctx.Color().Red("ERR!"), err)
return nil
}
timer := false
prefix := fmt.Sprintf("%s:%d>", argv.Host, argv.Port)
term, err := prompt.NewTerminal()
@ -81,11 +87,11 @@ func main() {
cmd = strings.ToUpper(cmd)
switch cmd {
case ".TABLES":
err = query(ctx, cmd, `SELECT name FROM sqlite_master WHERE type="table"`, timer, argv)
err = queryWithClient(ctx, client, argv, timer, `SELECT name FROM sqlite_master WHERE type="table"`)
case ".INDEXES":
err = query(ctx, cmd, `SELECT sql FROM sqlite_master WHERE type="index"`, timer, argv)
err = queryWithClient(ctx, client, argv, timer, `SELECT sql FROM sqlite_master WHERE type="index"`)
case ".SCHEMA":
err = query(ctx, cmd, "SELECT sql FROM sqlite_master", timer, argv)
err = queryWithClient(ctx, client, argv, timer, "SELECT sql FROM sqlite_master")
case ".TIMER":
err = toggleTimer(line[index+1:], &timer)
case ".STATUS":
@ -115,9 +121,9 @@ func main() {
case ".QUIT", "QUIT", "EXIT":
break FOR_READ
case "SELECT":
err = query(ctx, cmd, line, timer, argv)
err = queryWithClient(ctx, client, argv, timer, line)
default:
err = execute(ctx, cmd, line, timer, argv)
err = executeWithClient(ctx, client, argv, timer, line)
}
if err != nil {
ctx.String("%s %v\n", ctx.Color().Red("ERR!"), err)
@ -159,6 +165,36 @@ func expvar(ctx *cli.Context, cmd, line string, argv *argT) error {
return cliJSON(ctx, cmd, line, url, argv)
}
func getHTTPClient(argv *argT) (*http.Client, error) {
var rootCAs *x509.CertPool
if argv.CACert != "" {
pemCerts, err := ioutil.ReadFile(argv.CACert)
if err != nil {
return nil, err
}
rootCAs = x509.NewCertPool()
ok := rootCAs.AppendCertsFromPEM(pemCerts)
if !ok {
return nil, fmt.Errorf("failed to parse root CA certificate(s)")
}
}
client := http.Client{Transport: &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: argv.Insecure, RootCAs: rootCAs},
Proxy: http.ProxyFromEnvironment,
}}
// Explicitly handle redirects.
client.CheckRedirect = func(req *http.Request, via []*http.Request) error {
return http.ErrUseLastResponse
}
return &client, nil
}
func sendRequest(ctx *cli.Context, makeNewRequest func(string) (*http.Request, error), urlStr string, argv *argT) (*[]byte, error) {
url := urlStr
var rootCAs *x509.CertPool

@ -2,6 +2,7 @@ package main
import (
"fmt"
"io/ioutil"
"net/http"
"net/url"
"strings"
@ -82,19 +83,9 @@ type queryResponse struct {
Time float64 `json:"time"`
}
func makeQueryRequest(line string) func(string) (*http.Request, error) {
requestData := strings.NewReader(makeJSONBody(line))
return func(urlStr string) (*http.Request, error) {
req, err := http.NewRequest("POST", urlStr, requestData)
if err != nil {
return nil, err
}
return req, nil
}
}
func query(ctx *cli.Context, cmd, line string, timer bool, argv *argT) error {
func queryWithClient(ctx *cli.Context, client *http.Client, argv *argT, timer bool, query string) error {
queryStr := url.Values{}
queryStr.Set("q", query)
if timer {
queryStr.Set("timings", "")
}
@ -104,31 +95,70 @@ func query(ctx *cli.Context, cmd, line string, timer bool, argv *argT) error {
Path: fmt.Sprintf("%sdb/query", argv.Prefix),
RawQuery: queryStr.Encode(),
}
urlStr := u.String()
response, err := sendRequest(ctx, makeQueryRequest(line), u.String(), argv)
if err != nil {
return err
}
nRedirect := 0
for {
req, err := http.NewRequest("GET", urlStr, nil)
if err != nil {
return err
}
if argv.Credentials != "" {
creds := strings.Split(argv.Credentials, ":")
if len(creds) != 2 {
return fmt.Errorf("invalid Basic Auth credentials format")
}
req.SetBasicAuth(creds[0], creds[1])
}
ret := &queryResponse{}
if err := parseResponse(response, &ret); err != nil {
return err
}
if ret.Error != "" {
return fmt.Errorf(ret.Error)
}
if len(ret.Results) != 1 {
return fmt.Errorf("unexpected results length: %d", len(ret.Results))
}
resp, err := client.Do(req)
if err != nil {
return err
}
response, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
resp.Body.Close()
result := ret.Results[0]
if err := result.validate(); err != nil {
return err
}
textutil.WriteTable(ctx, result, headerRender)
if resp.StatusCode == http.StatusUnauthorized {
return fmt.Errorf("unauthorized")
}
if timer {
fmt.Printf("Run Time: %f seconds\n", result.Time)
if resp.StatusCode == http.StatusMovedPermanently {
nRedirect++
if nRedirect > maxRedirect {
return fmt.Errorf("maximum leader redirect limit exceeded")
}
urlStr = resp.Header["Location"][0]
continue
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("server responded with: %s", resp.Status)
}
// Parse response and write results
ret := &queryResponse{}
if err := parseResponse(&response, &ret); err != nil {
return err
}
if ret.Error != "" {
return fmt.Errorf(ret.Error)
}
if len(ret.Results) != 1 {
return fmt.Errorf("unexpected results length: %d", len(ret.Results))
}
result := ret.Results[0]
if err := result.validate(); err != nil {
return err
}
textutil.WriteTable(ctx, result, headerRender)
if timer {
fmt.Printf("Run Time: %f seconds\n", result.Time)
}
return nil
}
return nil
}

Loading…
Cancel
Save