1
0
Fork 0

[Draft] Accept alternative servers to connect to in rqlite cli (#947)

Add -alternatives flag to fallback to when hosts are unavailable
master
Mehdi Cheracher 3 years ago committed by GitHub
parent 8f0f5f9ebc
commit 80881e7b8b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -9,6 +9,7 @@ import (
"strings"
"github.com/mkideal/cli"
cl "github.com/rqlite/rqlite/cmd/rqlite/http"
)
// Result represents execute result
@ -25,96 +26,75 @@ type executeResponse struct {
Time float64 `json:"time,omitempty"`
}
func executeWithClient(ctx *cli.Context, client *http.Client, argv *argT, timer bool, stmt string) error {
func executeWithClient(ctx *cli.Context, client *cl.Client, timer bool, stmt string) error {
queryStr := url.Values{}
if timer {
queryStr.Set("timings", "")
}
u := url.URL{
Scheme: argv.Protocol,
Host: fmt.Sprintf("%s:%d", argv.Host, argv.Port),
Path: fmt.Sprintf("%sdb/execute", argv.Prefix),
Path: fmt.Sprintf("%sdb/execute", client.Prefix),
}
urlStr := u.String()
requestData := strings.NewReader(makeJSONBody(stmt))
nRedirect := 0
for {
if _, err := requestData.Seek(0, io.SeekStart); err != nil {
return err
}
if _, err := requestData.Seek(0, io.SeekStart); err != nil {
return err
}
req, err := http.NewRequest("POST", urlStr, requestData)
if err != nil {
return err
}
if argv.Credentials != "" {
creds := strings.Split(argv.Credentials, ":")
if len(creds) != 2 {
return fmt.Errorf("invalid Basic Auth credentials format")
}
req.SetBasicAuth(creds[0], creds[1])
}
resp, err := client.Execute(u, requestData)
resp, err := client.Do(req)
if err != nil {
return err
}
response, err := ioutil.ReadAll(resp.Body)
if err != nil {
var hcr error
if err != nil {
// If the error is HostChangedError, it should be propagated back to the caller to handle
// accordingly (change prompt display), but we should still assume that the request succeeded on some
// host and not treat it as an error.
err, ok := err.(*cl.HostChangedError)
if !ok {
return err
}
resp.Body.Close()
if resp.StatusCode == http.StatusUnauthorized {
return fmt.Errorf("unauthorized")
}
hcr = err
}
if resp.StatusCode == http.StatusMovedPermanently {
nRedirect++
if nRedirect > maxRedirect {
return fmt.Errorf("maximum leader redirect limit exceeded")
}
urlStr = resp.Header["Location"][0]
continue
}
response, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("server responded with %s: %s", resp.Status, response)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("server responded with %s: %s", resp.Status, response)
}
// Parse response and write results
ret := &executeResponse{}
if err := parseResponse(&response, &ret); err != nil {
return err
}
if ret.Error != "" {
return fmt.Errorf(ret.Error)
}
if len(ret.Results) != 1 {
return fmt.Errorf("unexpected results length: %d", len(ret.Results))
}
// Parse response and write results
ret := &executeResponse{}
if err := parseResponse(&response, &ret); err != nil {
return err
}
if ret.Error != "" {
return fmt.Errorf(ret.Error)
}
if len(ret.Results) != 1 {
return fmt.Errorf("unexpected results length: %d", len(ret.Results))
}
result := ret.Results[0]
if result.Error != "" {
ctx.String("Error: %s\n", result.Error)
return nil
}
result := ret.Results[0]
if result.Error != "" {
ctx.String("Error: %s\n", result.Error)
return nil
}
rowString := "row"
if result.RowsAffected > 1 {
rowString = "rows"
}
if timer {
ctx.String("%d %s affected (%f sec)\n", result.RowsAffected, rowString, result.Time)
} else {
ctx.String("%d %s affected\n", result.RowsAffected, rowString)
}
rowString := "row"
if result.RowsAffected > 1 {
rowString = "rows"
}
if timer {
ctx.String("%d %s affected (%f sec)\n", result.RowsAffected, rowString, result.Time)
} else {
ctx.String("%d %s affected\n", result.RowsAffected, rowString)
}
if timer {
fmt.Printf("Run Time: %f seconds\n", result.Time)
}
return nil
if timer {
fmt.Printf("Run Time: %f seconds\n", result.Time)
}
return hcr
}

@ -0,0 +1,191 @@
package http
import (
"fmt"
"io"
"log"
"net/http"
"net/url"
"os"
"strings"
)
// ErrNoAvailableHost indicates that the client could not find an available host to send the request to.
var ErrNoAvailableHost = fmt.Errorf("no host available to perform the request")
// ErrTooManyRedirects indicates that the client exceeded the maximum number of redirects
var ErrTooManyRedirects = fmt.Errorf("maximum leader redirect limit exceeded")
// HostChangedError indicates that the underlying request was executed on a different host
// different from the caller anticipated
type HostChangedError struct {
NewHost string
}
func (he *HostChangedError) Error() string {
return fmt.Sprintf("HostChangedErr: new host is '%s'", he.NewHost)
}
type ConfigFunc func(*Client)
// Client is a wrapper around stock `http.Client` that adds "retry on another host" behaviour
// based on the supplied configuration.
//
// The client will fall back and try other nodes when the current node is unavailable, and would stop trying
// after exhausting the list of supplied hosts.
//
// Note:
//
// This type is not goroutine safe.
// A node is considered unavailable if the client is not reachable via the network.
// TODO: make the unavailability condition for the client more dynamic.
type Client struct {
*http.Client
scheme string
hosts []string
Prefix string
// creds stores the http basic authentication username and password
creds string
logger *log.Logger
// currentHost keeps track of the last available host
currentHost int
maxRedirect int
}
// NewClient creates a default client that sends `execute` and query `requests` against the
// rqlited nodes supplied via `hosts` argument.
func NewClient(client *http.Client, hosts []string, configFuncs ...ConfigFunc) *Client {
cl := &Client{
Client: client,
hosts: hosts,
scheme: "http",
maxRedirect: 21,
Prefix: "/",
logger: log.New(os.Stderr, "[client] ", log.LstdFlags),
}
for _, f := range configFuncs {
f(cl)
}
return cl
}
// WithScheme changes the default scheme used i.e "http".
func WithScheme(scheme string) ConfigFunc {
return func(client *Client) {
client.scheme = scheme
}
}
// WithPrefix sets the prefix to be used when issuing HTTP requests against one of
// the rqlited nodes.
func WithPrefix(prefix string) ConfigFunc {
return func(client *Client) {
client.Prefix = prefix
}
}
// WithLogger changes the default logger to the one provided.
func WithLogger(logger *log.Logger) ConfigFunc {
return func(client *Client) {
client.logger = logger
}
}
// WithBasicAuth adds basic authentication behaviour to the client's request.
func WithBasicAuth(creds string) ConfigFunc {
return func(client *Client) {
client.creds = creds
}
}
// Query sends GET requests to one of the hosts known to the client.
func (c *Client) Query(url url.URL) (*http.Response, error) {
return c.execRequest(http.MethodGet, url, nil)
}
// Execute sends POST requests to one of the hosts known to the client
func (c *Client) Execute(url url.URL, body io.Reader) (*http.Response, error) {
return c.execRequest(http.MethodPost, url, body)
}
func (c *Client) execRequest(method string, url url.URL, body io.Reader) (*http.Response, error) {
triedHosts := 0
for triedHosts < len(c.hosts) {
host := c.hosts[c.currentHost]
url.Scheme = c.scheme
url.Host = host
urlStr := url.String()
resp, err := c.requestFollowRedirect(method, urlStr, body)
// Found a responsive node
if err == nil {
if triedHosts > 0 {
return resp, &HostChangedError{NewHost: host}
}
return resp, nil
}
// If we did too many redirects, we will consider the host as unavailable as well,
// and we will retry the request from another host
if err == ErrTooManyRedirects {
c.logger.Printf("too many redirects from host: '%s'", host)
}
c.logger.Printf("host '%s' is unavailable, retrying with the next available host", host)
triedHosts++
c.nextHost()
}
c.logger.Printf("none of the available hosts are responsive")
return nil, ErrNoAvailableHost
}
func (c *Client) nextHost() {
c.currentHost = (c.currentHost + 1) % len(c.hosts)
}
func (c *Client) requestFollowRedirect(method string, urlStr string, body io.Reader) (*http.Response, error) {
nRedirects := 0
for {
req, err := http.NewRequest(method, urlStr, body)
if err != nil {
return nil, err
}
err = c.setBasicAuth(req)
if err != nil {
return nil, err
}
resp, err := c.Client.Do(req)
if err != nil {
return nil, err
}
if resp.StatusCode == http.StatusMovedPermanently {
nRedirects++
if nRedirects > c.maxRedirect {
return resp, ErrTooManyRedirects
}
urlStr = resp.Header["Location"][0]
continue
}
return resp, nil
}
}
func (c *Client) setBasicAuth(req *http.Request) error {
if c.creds == "" {
return nil
}
creds := strings.Split(c.creds, ":")
if len(creds) != 2 {
return fmt.Errorf("invalid Basic Auth credential format")
}
req.SetBasicAuth(creds[0], creds[1])
return nil
}

@ -0,0 +1,164 @@
package http
import (
"net/http"
"net/http/httptest"
"net/url"
"testing"
)
func TestClient_QueryWhenAllAvailable(t *testing.T) {
node1 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
}))
defer node1.Close()
node2 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
}))
defer node2.Close()
httpClient := http.DefaultClient
u1, _ := url.Parse(node1.URL)
u2, _ := url.Parse(node2.URL)
client := NewClient(httpClient, []string{u1.Host, u2.Host})
res, err := client.Query(url.URL{
Path: "/",
})
defer res.Body.Close()
if err != nil {
t.Errorf("unexpected error: %v", err)
}
if client.currentHost != 0 {
t.Errorf("expected to only forward requests to the first host")
}
if res.StatusCode != http.StatusOK {
t.Errorf("unexpected status code, expected '200' got '%d'", res.StatusCode)
}
}
func TestClient_QueryWhenSomeAreAvailable(t *testing.T) {
node1 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
}))
node2 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
}))
defer node2.Close()
httpClient := http.DefaultClient
// Shutting down one of the hosts making it unavailable
node1.Close()
u1, _ := url.Parse(node1.URL)
u2, _ := url.Parse(node2.URL)
client := NewClient(httpClient, []string{u1.Host, u2.Host})
res, err := client.Query(url.URL{
Path: "/",
})
defer res.Body.Close()
// If the request succeeds after changing hosts, it should be reflected in the returned error
// as HostChangedError
if err == nil {
t.Errorf("expected HostChangedError got nil instead")
}
hcer, ok := err.(*HostChangedError)
if !ok {
t.Errorf("unexpected error occurred: %v", err)
}
if hcer.NewHost != u2.Host {
t.Errorf("unexpected responding host")
}
if client.currentHost != 1 {
t.Errorf("expected to move on to the following host")
}
if res.StatusCode != http.StatusOK {
t.Errorf("unexpected status code, expected '200' got '%d'", res.StatusCode)
}
}
func TestClient_QueryWhenAllUnavailable(t *testing.T) {
node1 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
}))
node2 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
writer.WriteHeader(http.StatusOK)
}))
httpClient := http.DefaultClient
u1, _ := url.Parse(node1.URL)
u2, _ := url.Parse(node2.URL)
// Shutting down both nodes, both of them now are unavailable
node1.Close()
node2.Close()
client := NewClient(httpClient, []string{u1.Host, u2.Host})
_, err := client.Query(url.URL{
Path: "/",
})
if err != ErrNoAvailableHost {
t.Errorf("Expected %v, got: %v", ErrNoAvailableHost, err)
}
}
func TestClient_BasicAuthIsForwarded(t *testing.T) {
mockAuth := func(request *http.Request) bool {
user, pass, ok := request.BasicAuth()
if ok {
if user == "john" && pass == "doe" {
return true
}
}
return false
}
node1 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
if mockAuth(request) {
writer.WriteHeader(http.StatusOK)
return
}
writer.WriteHeader(http.StatusUnauthorized)
}))
defer node1.Close()
node2 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
if mockAuth(request) {
writer.WriteHeader(http.StatusOK)
return
}
writer.WriteHeader(http.StatusUnauthorized)
}))
defer node2.Close()
httpClient := http.DefaultClient
u1, _ := url.Parse(node1.URL)
u2, _ := url.Parse(node2.URL)
client := NewClient(httpClient, []string{u1.Host, u2.Host}, WithBasicAuth("john:wrongpassword"))
res, err := client.Query(url.URL{
Path: "/",
})
if err != nil {
t.Errorf("unexpected error")
}
if res.StatusCode != http.StatusUnauthorized {
t.Errorf("expected unauthorized status")
}
}

@ -18,20 +18,22 @@ import (
"github.com/Bowery/prompt"
"github.com/mkideal/cli"
"github.com/rqlite/rqlite/cmd"
httpcl "github.com/rqlite/rqlite/cmd/rqlite/http"
)
const maxRedirect = 21
type argT struct {
cli.Helper
Protocol string `cli:"s,scheme" usage:"protocol scheme (http or https)" dft:"http"`
Host string `cli:"H,host" usage:"rqlited host address" dft:"127.0.0.1"`
Port uint16 `cli:"p,port" usage:"rqlited host port" dft:"4001"`
Prefix string `cli:"P,prefix" usage:"rqlited HTTP URL prefix" dft:"/"`
Insecure bool `cli:"i,insecure" usage:"do not verify rqlited HTTPS certificate" dft:"false"`
CACert string `cli:"c,ca-cert" usage:"path to trusted X.509 root CA certificate"`
Credentials string `cli:"u,user" usage:"set basic auth credentials in form username:password"`
Version bool `cli:"v,version" usage:"display CLI version"`
Alternatives string `cli:"a,alternatives" usage:"comma separated list of 'host:port' pairs to use as fallback"`
Protocol string `cli:"s,scheme" usage:"protocol scheme (http or https)" dft:"http"`
Host string `cli:"H,host" usage:"rqlited host address" dft:"127.0.0.1"`
Port uint16 `cli:"p,port" usage:"rqlited host port" dft:"4001"`
Prefix string `cli:"P,prefix" usage:"rqlited HTTP URL prefix" dft:"/"`
Insecure bool `cli:"i,insecure" usage:"do not verify rqlited HTTPS certificate" dft:"false"`
CACert string `cli:"c,ca-cert" usage:"path to trusted X.509 root CA certificate"`
Credentials string `cli:"u,user" usage:"set basic auth credentials in form username:password"`
Version bool `cli:"v,version" usage:"display CLI version"`
}
var cliHelp = []string{
@ -67,13 +69,13 @@ func main() {
return nil
}
client, err := getHTTPClient(argv)
httpClient, err := getHTTPClient(argv)
if err != nil {
ctx.String("%s %v\n", ctx.Color().Red("ERR!"), err)
return nil
}
version, err := getVersionWithClient(client, argv)
version, err := getVersionWithClient(httpClient, argv)
if err != nil {
ctx.String("%s %v\n", ctx.Color().Red("ERR!"), err)
return nil
@ -93,6 +95,12 @@ func main() {
}
term.Close()
hosts := createHostList(argv)
client := httpcl.NewClient(httpClient, hosts,
httpcl.WithScheme(argv.Protocol),
httpcl.WithBasicAuth(argv.Credentials),
httpcl.WithPrefix(argv.Prefix))
FOR_READ:
for {
term.Reopen()
@ -122,23 +130,23 @@ func main() {
}
err = setConsistency(line[index+1:], &consistency)
case ".TABLES":
err = queryWithClient(ctx, client, argv, timer, consistency, `SELECT name FROM sqlite_master WHERE type="table"`)
err = queryWithClient(ctx, client, timer, consistency, `SELECT name FROM sqlite_master WHERE type="table"`)
case ".INDEXES":
err = queryWithClient(ctx, client, argv, timer, consistency, `SELECT sql FROM sqlite_master WHERE type="index"`)
err = queryWithClient(ctx, client, timer, consistency, `SELECT sql FROM sqlite_master WHERE type="index"`)
case ".SCHEMA":
err = queryWithClient(ctx, client, argv, timer, consistency, `SELECT sql FROM sqlite_master`)
err = queryWithClient(ctx, client, timer, consistency, `SELECT sql FROM sqlite_master`)
case ".TIMER":
err = toggleTimer(line[index+1:], &timer)
case ".STATUS":
err = status(ctx, cmd, line, argv)
case ".READY":
err = ready(ctx, client, argv)
err = ready(ctx, httpClient, argv)
case ".NODES":
err = nodes(ctx, cmd, line, argv)
case ".EXPVAR":
err = expvar(ctx, cmd, line, argv)
case ".REMOVE":
err = removeNode(client, line[index+1:], argv, timer)
err = removeNode(httpClient, line[index+1:], argv, timer)
case ".BACKUP":
if index == -1 || index == len(line)-1 {
err = fmt.Errorf("please specify an output file for the backup")
@ -168,12 +176,18 @@ func main() {
case ".QUIT", "QUIT", "EXIT":
break FOR_READ
case "SELECT", "PRAGMA":
err = queryWithClient(ctx, client, argv, timer, consistency, line)
err = queryWithClient(ctx, client, timer, consistency, line)
default:
err = executeWithClient(ctx, client, argv, timer, line)
err = executeWithClient(ctx, client, timer, line)
}
if err != nil {
ctx.String("%s %v\n", ctx.Color().Red("ERR!"), err)
// if a previous request was executed on a different host, make that change
// visible to the user.
if hcerr, ok := err.(*httpcl.HostChangedError); ok {
prefix = fmt.Sprintf("%s>", hcerr.NewHost)
} else {
ctx.String("%s %v\n", ctx.Color().Red("ERR!"), err)
}
}
}
ctx.String("bye~\n")
@ -554,3 +568,10 @@ func urlsToWriter(urls []string, w io.Writer, argv *argT) error {
return nil
}
func createHostList(argv *argT) []string {
var hosts = make([]string, 0)
hosts = append(hosts, fmt.Sprintf("%s:%d", argv.Host, argv.Port))
hosts = append(hosts, strings.Split(argv.Alternatives, ",")...)
return hosts
}

@ -5,10 +5,10 @@ import (
"io/ioutil"
"net/http"
"net/url"
"strings"
"github.com/mkideal/cli"
"github.com/mkideal/pkg/textutil"
cl "github.com/rqlite/rqlite/cmd/rqlite/http"
)
// Rows represents query result
@ -80,7 +80,7 @@ type queryResponse struct {
Time float64 `json:"time"`
}
func queryWithClient(ctx *cli.Context, client *http.Client, argv *argT, timer bool, consistency, query string) error {
func queryWithClient(ctx *cli.Context, client *cl.Client, timer bool, consistency, query string) error {
queryStr := url.Values{}
queryStr.Set("level", consistency)
queryStr.Set("q", query)
@ -88,75 +88,54 @@ func queryWithClient(ctx *cli.Context, client *http.Client, argv *argT, timer bo
queryStr.Set("timings", "")
}
u := url.URL{
Scheme: argv.Protocol,
Host: fmt.Sprintf("%s:%d", argv.Host, argv.Port),
Path: fmt.Sprintf("%sdb/query", argv.Prefix),
Path: fmt.Sprintf("%sdb/query", client.Prefix),
RawQuery: queryStr.Encode(),
}
urlStr := u.String()
nRedirect := 0
for {
req, err := http.NewRequest("GET", urlStr, nil)
if err != nil {
return err
}
if argv.Credentials != "" {
creds := strings.Split(argv.Credentials, ":")
if len(creds) != 2 {
return fmt.Errorf("invalid Basic Auth credentials format")
}
req.SetBasicAuth(creds[0], creds[1])
}
resp, err := client.Query(u)
resp, err := client.Do(req)
if err != nil {
var hcr error
if err != nil {
// If the error is HostChangedError, it should be propagated back to the caller to handle
// accordingly (change prompt display), but we should still assume that the request succeeded on some
// host and not treat it as an error.
err, ok := err.(*cl.HostChangedError)
if !ok {
return err
}
response, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
resp.Body.Close()
if resp.StatusCode == http.StatusUnauthorized {
return fmt.Errorf("unauthorized")
}
hcr = err
}
if resp.StatusCode == http.StatusMovedPermanently {
nRedirect++
if nRedirect > maxRedirect {
return fmt.Errorf("maximum leader redirect limit exceeded")
}
urlStr = resp.Header["Location"][0]
continue
}
response, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("server responded with %s: %s", resp.Status, response)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("server responded with %s: %s", resp.Status, response)
}
// Parse response and write results
ret := &queryResponse{}
if err := parseResponse(&response, &ret); err != nil {
return err
}
if ret.Error != "" {
return fmt.Errorf(ret.Error)
}
if len(ret.Results) != 1 {
return fmt.Errorf("unexpected results length: %d", len(ret.Results))
}
// Parse response and write results
ret := &queryResponse{}
if err := parseResponse(&response, &ret); err != nil {
return err
}
if ret.Error != "" {
return fmt.Errorf(ret.Error)
}
if len(ret.Results) != 1 {
return fmt.Errorf("unexpected results length: %d", len(ret.Results))
}
result := ret.Results[0]
if err := result.validate(); err != nil {
return err
}
textutil.WriteTable(ctx, result, headerRender)
result := ret.Results[0]
if err := result.validate(); err != nil {
return err
}
textutil.WriteTable(ctx, result, headerRender)
if timer {
fmt.Printf("Run Time: %f seconds\n", result.Time)
}
return nil
if timer {
fmt.Printf("Run Time: %f seconds\n", result.Time)
}
return hcr
}

Loading…
Cancel
Save