[Draft] Accept alternative servers to connect to in rqlite cli (#947)
Add -alternatives flag to fallback to when hosts are unavailablemaster
parent
8f0f5f9ebc
commit
80881e7b8b
@ -0,0 +1,191 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ErrNoAvailableHost indicates that the client could not find an available host to send the request to.
|
||||
var ErrNoAvailableHost = fmt.Errorf("no host available to perform the request")
|
||||
|
||||
// ErrTooManyRedirects indicates that the client exceeded the maximum number of redirects
|
||||
var ErrTooManyRedirects = fmt.Errorf("maximum leader redirect limit exceeded")
|
||||
|
||||
// HostChangedError indicates that the underlying request was executed on a different host
|
||||
// different from the caller anticipated
|
||||
type HostChangedError struct {
|
||||
NewHost string
|
||||
}
|
||||
|
||||
func (he *HostChangedError) Error() string {
|
||||
return fmt.Sprintf("HostChangedErr: new host is '%s'", he.NewHost)
|
||||
}
|
||||
|
||||
type ConfigFunc func(*Client)
|
||||
|
||||
// Client is a wrapper around stock `http.Client` that adds "retry on another host" behaviour
|
||||
// based on the supplied configuration.
|
||||
//
|
||||
// The client will fall back and try other nodes when the current node is unavailable, and would stop trying
|
||||
// after exhausting the list of supplied hosts.
|
||||
//
|
||||
// Note:
|
||||
//
|
||||
// This type is not goroutine safe.
|
||||
// A node is considered unavailable if the client is not reachable via the network.
|
||||
// TODO: make the unavailability condition for the client more dynamic.
|
||||
type Client struct {
|
||||
*http.Client
|
||||
scheme string
|
||||
hosts []string
|
||||
Prefix string
|
||||
|
||||
// creds stores the http basic authentication username and password
|
||||
creds string
|
||||
logger *log.Logger
|
||||
|
||||
// currentHost keeps track of the last available host
|
||||
currentHost int
|
||||
maxRedirect int
|
||||
}
|
||||
|
||||
// NewClient creates a default client that sends `execute` and query `requests` against the
|
||||
// rqlited nodes supplied via `hosts` argument.
|
||||
func NewClient(client *http.Client, hosts []string, configFuncs ...ConfigFunc) *Client {
|
||||
cl := &Client{
|
||||
Client: client,
|
||||
hosts: hosts,
|
||||
scheme: "http",
|
||||
maxRedirect: 21,
|
||||
Prefix: "/",
|
||||
logger: log.New(os.Stderr, "[client] ", log.LstdFlags),
|
||||
}
|
||||
|
||||
for _, f := range configFuncs {
|
||||
f(cl)
|
||||
}
|
||||
|
||||
return cl
|
||||
}
|
||||
|
||||
// WithScheme changes the default scheme used i.e "http".
|
||||
func WithScheme(scheme string) ConfigFunc {
|
||||
return func(client *Client) {
|
||||
client.scheme = scheme
|
||||
}
|
||||
}
|
||||
|
||||
// WithPrefix sets the prefix to be used when issuing HTTP requests against one of
|
||||
// the rqlited nodes.
|
||||
func WithPrefix(prefix string) ConfigFunc {
|
||||
return func(client *Client) {
|
||||
client.Prefix = prefix
|
||||
}
|
||||
}
|
||||
|
||||
// WithLogger changes the default logger to the one provided.
|
||||
func WithLogger(logger *log.Logger) ConfigFunc {
|
||||
return func(client *Client) {
|
||||
client.logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
// WithBasicAuth adds basic authentication behaviour to the client's request.
|
||||
func WithBasicAuth(creds string) ConfigFunc {
|
||||
return func(client *Client) {
|
||||
client.creds = creds
|
||||
}
|
||||
}
|
||||
|
||||
// Query sends GET requests to one of the hosts known to the client.
|
||||
func (c *Client) Query(url url.URL) (*http.Response, error) {
|
||||
return c.execRequest(http.MethodGet, url, nil)
|
||||
}
|
||||
|
||||
// Execute sends POST requests to one of the hosts known to the client
|
||||
func (c *Client) Execute(url url.URL, body io.Reader) (*http.Response, error) {
|
||||
return c.execRequest(http.MethodPost, url, body)
|
||||
}
|
||||
|
||||
func (c *Client) execRequest(method string, url url.URL, body io.Reader) (*http.Response, error) {
|
||||
triedHosts := 0
|
||||
for triedHosts < len(c.hosts) {
|
||||
host := c.hosts[c.currentHost]
|
||||
url.Scheme = c.scheme
|
||||
url.Host = host
|
||||
urlStr := url.String()
|
||||
resp, err := c.requestFollowRedirect(method, urlStr, body)
|
||||
|
||||
// Found a responsive node
|
||||
if err == nil {
|
||||
if triedHosts > 0 {
|
||||
return resp, &HostChangedError{NewHost: host}
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// If we did too many redirects, we will consider the host as unavailable as well,
|
||||
// and we will retry the request from another host
|
||||
if err == ErrTooManyRedirects {
|
||||
c.logger.Printf("too many redirects from host: '%s'", host)
|
||||
}
|
||||
|
||||
c.logger.Printf("host '%s' is unavailable, retrying with the next available host", host)
|
||||
triedHosts++
|
||||
c.nextHost()
|
||||
}
|
||||
|
||||
c.logger.Printf("none of the available hosts are responsive")
|
||||
return nil, ErrNoAvailableHost
|
||||
}
|
||||
|
||||
func (c *Client) nextHost() {
|
||||
c.currentHost = (c.currentHost + 1) % len(c.hosts)
|
||||
}
|
||||
|
||||
func (c *Client) requestFollowRedirect(method string, urlStr string, body io.Reader) (*http.Response, error) {
|
||||
nRedirects := 0
|
||||
for {
|
||||
req, err := http.NewRequest(method, urlStr, body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = c.setBasicAuth(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resp, err := c.Client.Do(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusMovedPermanently {
|
||||
nRedirects++
|
||||
if nRedirects > c.maxRedirect {
|
||||
return resp, ErrTooManyRedirects
|
||||
}
|
||||
urlStr = resp.Header["Location"][0]
|
||||
continue
|
||||
}
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) setBasicAuth(req *http.Request) error {
|
||||
if c.creds == "" {
|
||||
return nil
|
||||
}
|
||||
creds := strings.Split(c.creds, ":")
|
||||
if len(creds) != 2 {
|
||||
return fmt.Errorf("invalid Basic Auth credential format")
|
||||
}
|
||||
req.SetBasicAuth(creds[0], creds[1])
|
||||
return nil
|
||||
}
|
@ -0,0 +1,164 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestClient_QueryWhenAllAvailable(t *testing.T) {
|
||||
node1 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer node1.Close()
|
||||
|
||||
node2 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer node2.Close()
|
||||
|
||||
httpClient := http.DefaultClient
|
||||
|
||||
u1, _ := url.Parse(node1.URL)
|
||||
u2, _ := url.Parse(node2.URL)
|
||||
client := NewClient(httpClient, []string{u1.Host, u2.Host})
|
||||
res, err := client.Query(url.URL{
|
||||
Path: "/",
|
||||
})
|
||||
defer res.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if client.currentHost != 0 {
|
||||
t.Errorf("expected to only forward requests to the first host")
|
||||
}
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("unexpected status code, expected '200' got '%d'", res.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_QueryWhenSomeAreAvailable(t *testing.T) {
|
||||
node1 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
node2 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
defer node2.Close()
|
||||
|
||||
httpClient := http.DefaultClient
|
||||
|
||||
// Shutting down one of the hosts making it unavailable
|
||||
node1.Close()
|
||||
|
||||
u1, _ := url.Parse(node1.URL)
|
||||
u2, _ := url.Parse(node2.URL)
|
||||
client := NewClient(httpClient, []string{u1.Host, u2.Host})
|
||||
res, err := client.Query(url.URL{
|
||||
Path: "/",
|
||||
})
|
||||
defer res.Body.Close()
|
||||
|
||||
// If the request succeeds after changing hosts, it should be reflected in the returned error
|
||||
// as HostChangedError
|
||||
if err == nil {
|
||||
t.Errorf("expected HostChangedError got nil instead")
|
||||
}
|
||||
|
||||
hcer, ok := err.(*HostChangedError)
|
||||
|
||||
if !ok {
|
||||
t.Errorf("unexpected error occurred: %v", err)
|
||||
}
|
||||
|
||||
if hcer.NewHost != u2.Host {
|
||||
t.Errorf("unexpected responding host")
|
||||
}
|
||||
|
||||
if client.currentHost != 1 {
|
||||
t.Errorf("expected to move on to the following host")
|
||||
}
|
||||
|
||||
if res.StatusCode != http.StatusOK {
|
||||
t.Errorf("unexpected status code, expected '200' got '%d'", res.StatusCode)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_QueryWhenAllUnavailable(t *testing.T) {
|
||||
node1 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
node2 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
}))
|
||||
|
||||
httpClient := http.DefaultClient
|
||||
|
||||
u1, _ := url.Parse(node1.URL)
|
||||
u2, _ := url.Parse(node2.URL)
|
||||
|
||||
// Shutting down both nodes, both of them now are unavailable
|
||||
node1.Close()
|
||||
node2.Close()
|
||||
client := NewClient(httpClient, []string{u1.Host, u2.Host})
|
||||
_, err := client.Query(url.URL{
|
||||
Path: "/",
|
||||
})
|
||||
|
||||
if err != ErrNoAvailableHost {
|
||||
t.Errorf("Expected %v, got: %v", ErrNoAvailableHost, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClient_BasicAuthIsForwarded(t *testing.T) {
|
||||
mockAuth := func(request *http.Request) bool {
|
||||
user, pass, ok := request.BasicAuth()
|
||||
if ok {
|
||||
if user == "john" && pass == "doe" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
node1 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
if mockAuth(request) {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
writer.WriteHeader(http.StatusUnauthorized)
|
||||
}))
|
||||
defer node1.Close()
|
||||
|
||||
node2 := httptest.NewServer(http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
if mockAuth(request) {
|
||||
writer.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
writer.WriteHeader(http.StatusUnauthorized)
|
||||
}))
|
||||
defer node2.Close()
|
||||
|
||||
httpClient := http.DefaultClient
|
||||
|
||||
u1, _ := url.Parse(node1.URL)
|
||||
u2, _ := url.Parse(node2.URL)
|
||||
client := NewClient(httpClient, []string{u1.Host, u2.Host}, WithBasicAuth("john:wrongpassword"))
|
||||
|
||||
res, err := client.Query(url.URL{
|
||||
Path: "/",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("unexpected error")
|
||||
}
|
||||
|
||||
if res.StatusCode != http.StatusUnauthorized {
|
||||
t.Errorf("expected unauthorized status")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue