
Merge pull request #1229 from rqlite/auto-backup

Add support for automatic backup to S3
Philip O'Toole 1 year ago committed by GitHub
commit f1d3a1e893
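Automatic backups are configured with a JSON file passed to the new -auto-backup flag added in this commit. A minimal sketch of such a file, based on the format defined below (the interval, bucket, and path values are illustrative placeholders; environment variables in the file are expanded before parsing):

{
  "version": 1,
  "type": "s3",
  "interval": "5m",
  "sub": {
    "access_key_id": "$RQLITE_S3_ACCESS_KEY",
    "secret_access_key": "$RQLITE_S3_SECRET_ACCESS_KEY",
    "region": "us-west-2",
    "bucket": "example-bucket",
    "path": "backups/db.sqlite3.gz"
  }
}

A node would then be launched with, for example, rqlited -auto-backup backup.json /path/to/data. Backups are gzip-compressed by default; set "no_compress": true to disable compression.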

@@ -170,6 +170,28 @@ jobs:
RQLITED_PATH: /home/circleci/go/bin/rqlited
resource_class: large
end_to_end_auto_backup:
docker:
- image: cimg/go:1.20.0
steps:
- checkout
- restore_cache:
keys:
- go-mod-v4-{{ checksum "go.sum" }}
- run: sudo apt-get update
- run: sudo apt-get install python3
- run: sudo apt install python3-pip
- run: python3 -m pip install requests boto3
- run: go version
- run: go get -t -d -v ./...
- run: go install -tags osusergo,netgo,sqlite_omit_load_extension -ldflags="-extldflags=-static" ./...
- run:
command: python3 system_test/e2e/auto_backup.py
environment:
RQLITED_PATH: /home/circleci/go/bin/rqlited
resource_class: large
workflows:
version: 2
build_and_test:
@@ -183,3 +205,4 @@ workflows:
- end_to_end_multi_adv
- end_to_end_joining
- end_to_end_autoclustering
- end_to_end_auto_backup

@@ -1,3 +1,7 @@
## 7.15.0 (unreleased)
### New features
- [PR #1229](https://github.com/rqlite/rqlite/pull/1229): Add support for automatic backups to AWS S3.
## 7.14.3 (April 25th 2023)
### Implementation changes and bug fixes
- [PR #1218](https://github.com/rqlite/rqlite/pull/1218): Check for more possible errors in peers.json. Thanks @Tjstretchalot

@@ -30,7 +30,7 @@ rqlite uses [Raft](https://raft.github.io/) to achieve consensus across all the
- Choice of [read consistency levels](https://rqlite.io/docs/api/read-consistency/), and support for choosing [write performance over durability](https://rqlite.io/docs/api/queued-writes/).
- Optional [read-only (non-voting) nodes](https://rqlite.io/docs/clustering/read-only-nodes/), which can add read scalability to the system.
- A form of transaction support.
- Hot [backups](https://rqlite.io/docs/guides/backup/), including automatic backups to [AWS S3](https://aws.amazon.com/s3/), as well as [restore directly from SQLite](https://rqlite.io/docs/guides/restore/).
## Quick Start

@@ -1,2 +0,0 @@
// Package aws provides functionality for accessing the AWS API.
package aws

@@ -1,50 +0,0 @@
package aws
import (
"fmt"
"io/ioutil"
"net/http"
)
// MetadataClient is a client for fetching AWS EC2 instance metadata.
type MetadataClient struct {
client *http.Client
URL string
}
// NewMetadataClient returns an instance of a MetadataClient
func NewMetadataClient() *MetadataClient {
return &MetadataClient{
client: &http.Client{},
URL: `http://169.254.169.254/`,
}
}
// LocalIPv4 returns the private IPv4 address of the instance.
func (m *MetadataClient) LocalIPv4() (string, error) {
return m.get("/latest/meta-data/local-ipv4")
}
// PublicIPv4 returns the public IPv4 address of the instance.
func (m *MetadataClient) PublicIPv4() (string, error) {
return m.get("/latest/meta-data/public-ipv4")
}
func (m *MetadataClient) get(path string) (string, error) {
resp, err := m.client.Get(m.URL + path)
if err != nil {
return "", err
}
defer resp.Body.Close()
b, err := ioutil.ReadAll(resp.Body)
if err != nil {
return "", err
}
if resp.StatusCode != http.StatusOK {
return "", fmt.Errorf("failed to request %s, got: %s", path, resp.Status)
}
return string(b), nil
}

@@ -1,89 +0,0 @@
package aws
import (
"fmt"
"net/http"
"net/http/httptest"
"testing"
)
func Test_NewMetadataClient(t *testing.T) {
c := NewMetadataClient()
if c == nil {
t.Fatalf("failed to create new Metadata client")
}
}
func Test_MetadataClient_LocalIPv4(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
t.Fatalf("Client did not use GET")
}
if r.URL.String() != "/latest/meta-data/local-ipv4" {
t.Fatalf("Request URL is wrong, got: %s", r.URL.String())
}
fmt.Fprint(w, "172.31.34.179")
}))
defer ts.Close()
c := NewMetadataClient()
c.URL = ts.URL
addr, err := c.LocalIPv4()
if err != nil {
t.Fatalf("failed to get local IPv4 address: %s", err.Error())
}
if addr != "172.31.34.179" {
t.Fatalf("got incorrect local IPv4 address: %s", addr)
}
}
func Test_MetadataClient_LocalIPv4Fail(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}))
defer ts.Close()
c := NewMetadataClient()
c.URL = ts.URL
_, err := c.LocalIPv4()
if err == nil {
t.Fatalf("failed to get error when server returned 400")
}
}
func Test_MetadataClient_PublicIPv4(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if r.Method != "GET" {
t.Fatalf("Client did not use GET")
}
if r.URL.String() != "/latest/meta-data/public-ipv4" {
t.Fatalf("Request URL is wrong, got: %s", r.URL.String())
}
fmt.Fprint(w, "52.38.41.98")
}))
defer ts.Close()
c := NewMetadataClient()
c.URL = ts.URL
addr, err := c.PublicIPv4()
if err != nil {
t.Fatalf("failed to get local IPv4 address: %s", err.Error())
}
if addr != "52.38.41.98" {
t.Fatalf("got incorrect local IPv4 address: %s", addr)
}
}
func Test_MetadataClient_PublicIPv4Fail(t *testing.T) {
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusBadRequest)
}))
defer ts.Close()
c := NewMetadataClient()
c.URL = ts.URL
_, err := c.PublicIPv4()
if err == nil {
t.Fatalf("failed to get error when server returned 400")
}
}

@@ -0,0 +1,106 @@
package aws
import (
"context"
"fmt"
"io"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
)
type uploader interface {
UploadWithContext(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
}
type downloader interface {
DownloadWithContext(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, opts ...func(*s3manager.Downloader)) (n int64, err error)
}
// S3Client is a client for uploading data to S3.
type S3Client struct {
region string
accessKey string
secretKey string
bucket string
key string
uploader uploader // for testing via dependency injection
}
// NewS3Client returns an instance of an S3Client.
func NewS3Client(region, accessKey, secretKey, bucket, key string) *S3Client {
return &S3Client{
region: region,
accessKey: accessKey,
secretKey: secretKey,
bucket: bucket,
key: key,
}
}
// String returns a string representation of the S3Client.
func (s *S3Client) String() string {
return fmt.Sprintf("s3://%s/%s", s.bucket, s.key)
}
// Upload uploads data to S3.
func (s *S3Client) Upload(ctx context.Context, reader io.Reader) error {
sess, err := s.createSession()
if err != nil {
return err
}
// If an uploader was not provided, use a real S3 uploader.
var uploader uploader
if s.uploader == nil {
uploader = s3manager.NewUploader(sess)
} else {
uploader = s.uploader
}
_, err = uploader.UploadWithContext(ctx, &s3manager.UploadInput{
Bucket: aws.String(s.bucket),
Key: aws.String(s.key),
Body: reader,
})
if err != nil {
return fmt.Errorf("failed to upload to %v: %w", s, err)
}
return nil
}
// Download downloads data from S3.
func (s *S3Client) Download(ctx context.Context, writer io.WriterAt) error {
sess, err := s.createSession()
if err != nil {
return err
}
downloader := s3manager.NewDownloader(sess)
_, err = downloader.DownloadWithContext(ctx, writer, &s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(s.key),
})
if err != nil {
return fmt.Errorf("failed to download from %v: %w", s, err)
}
return nil
}
func (s *S3Client) createSession() (*session.Session, error) {
sess, err := session.NewSession(&aws.Config{
Region: aws.String(s.region),
Credentials: credentials.NewStaticCredentials(s.accessKey, s.secretKey, ""),
})
if err != nil {
return nil, fmt.Errorf("failed to create S3 session: %w", err)
}
return sess, nil
}
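For orientation, a minimal sketch of driving the new client on its own (the region, credentials, bucket, key, and file path below are illustrative placeholders, not values from this commit):

package main

import (
	"context"
	"log"
	"os"

	"github.com/rqlite/rqlite/aws"
)

func main() {
	// Create a client bound to a single bucket/key pair, as NewS3Client expects.
	sc := aws.NewS3Client("us-west-2", "access-key", "secret-key", "example-bucket", "backups/db.sqlite3")

	// Stream a local file to S3; Upload reads the io.Reader to completion.
	f, err := os.Open("/path/to/backup.sqlite3") // hypothetical local file
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()
	if err := sc.Upload(context.Background(), f); err != nil {
		log.Fatalf("upload to %s failed: %v", sc, err)
	}
}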

@@ -0,0 +1,118 @@
package aws
import (
"context"
"fmt"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
)
func Test_NewS3Client(t *testing.T) {
c := NewS3Client("region1", "access", "secret", "bucket2", "key3")
if c.region != "region1" {
t.Fatalf("expected region to be %q, got %q", "region1", c.region)
}
if c.accessKey != "access" {
t.Fatalf("expected accessKey to be %q, got %q", "access", c.accessKey)
}
if c.secretKey != "secret" {
t.Fatalf("expected secretKey to be %q, got %q", "secret", c.secretKey)
}
if c.bucket != "bucket2" {
t.Fatalf("expected bucket to be %q, got %q", "bucket2", c.bucket)
}
if c.key != "key3" {
t.Fatalf("expected key to be %q, got %q", "key3", c.key)
}
}
func Test_S3Client_String(t *testing.T) {
c := NewS3Client("region1", "access", "secret", "bucket2", "key3")
if c.String() != "s3://bucket2/key3" {
t.Fatalf("expected String() to be %q, got %q", "s3://bucket2/key3", c.String())
}
}
func TestS3ClientUploadOK(t *testing.T) {
region := "us-west-2"
accessKey := "your-access-key"
secretKey := "your-secret-key"
bucket := "your-bucket"
key := "your/key/path"
mockUploader := &mockUploader{
uploadFn: func(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) {
if *input.Bucket != bucket {
t.Errorf("expected bucket to be %q, got %q", bucket, *input.Bucket)
}
if *input.Key != key {
t.Errorf("expected key to be %q, got %q", key, *input.Key)
}
if input.Body == nil {
t.Errorf("expected body to be non-nil")
}
return &s3manager.UploadOutput{}, nil
},
}
client := &S3Client{
region: region,
accessKey: accessKey,
secretKey: secretKey,
bucket: bucket,
key: key,
uploader: mockUploader,
}
reader := strings.NewReader("test data")
err := client.Upload(context.Background(), reader)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
}
func TestS3ClientUploadFail(t *testing.T) {
region := "us-west-2"
accessKey := "your-access-key"
secretKey := "your-secret-key"
bucket := "your-bucket"
key := "your/key/path"
mockUploader := &mockUploader{
uploadFn: func(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) {
return &s3manager.UploadOutput{}, fmt.Errorf("some error related to S3")
},
}
client := &S3Client{
region: region,
accessKey: accessKey,
secretKey: secretKey,
bucket: bucket,
key: key,
uploader: mockUploader,
}
reader := strings.NewReader("test data")
err := client.Upload(context.Background(), reader)
if err == nil {
t.Fatal("Expected error, got nil")
}
if !strings.Contains(err.Error(), "some error related to S3") {
t.Fatalf("Expected error to contain %q, got %q", "some error related to S3", err.Error())
}
}
type mockUploader struct {
uploadFn func(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
}
func (m *mockUploader) UploadWithContext(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) {
if m.uploadFn != nil {
return m.uploadFn(ctx, input, opts...)
}
return &s3manager.UploadOutput{}, nil
}

@@ -52,6 +52,9 @@ type Config struct {
// AuthFile is the path to the authentication file. May not be set.
AuthFile string
// AutoBackupFile is the path to the auto-backup file. May not be set.
AutoBackupFile string
// HTTPx509CACert is the path to the CA certficate file for when this node verifies
// other certificates for any HTTP communications. May not be set.
HTTPx509CACert string
@@ -395,6 +398,7 @@ func ParseFlags(name, desc string, build *BuildInfo) (*Config, error) {
flag.BoolVar(&config.NoNodeVerify, "node-no-verify", false, "Skip verification of any node-node certificate")
flag.BoolVar(&config.NodeVerifyClient, "node-verify-client", false, "Enable mutual TLS for node-to-node communication")
flag.StringVar(&config.AuthFile, "auth", "", "Path to authentication and authorization file. If not set, not enabled")
flag.StringVar(&config.AutoBackupFile, "auto-backup", "", "Path to automatic backup configuration file. If not set, not enabled")
flag.StringVar(&config.RaftAddr, RaftAddrFlag, "localhost:4002", "Raft communication bind address")
flag.StringVar(&config.RaftAdv, RaftAdvAddrFlag, "", "Advertised Raft communication address. If not set, same as Raft bind")
flag.StringVar(&config.JoinSrcIP, "join-source-ip", "", "Set source IP address during HTTP Join request")
@@ -432,8 +436,8 @@ func ParseFlags(name, desc string, build *BuildInfo) (*Config, error) {
flag.IntVar(&config.WriteQueueBatchSz, "write-queue-batch-size", 128, "Write queue batch size")
flag.DurationVar(&config.WriteQueueTimeout, "write-queue-timeout", 50*time.Millisecond, "Write queue timeout")
flag.BoolVar(&config.WriteQueueTx, "write-queue-tx", false, "Use a transaction when writing from queue")
flag.IntVar(&config.CompressionSize, "compression-size", 150, "Request query size for Raft log compression attempt")
flag.IntVar(&config.CompressionBatch, "compression-batch", 5, "Request batch threshold for Raft log compression attempt")
flag.StringVar(&config.CPUProfile, "cpu-profile", "", "Path to file for CPU profiling information")
flag.StringVar(&config.MemProfile, "mem-profile", "", "Path to file for memory profiling information")
flag.Usage = func() {

@@ -2,6 +2,7 @@
package main
import (
"context"
"crypto/tls"
"fmt"
"log"
@@ -19,6 +20,7 @@ import (
"github.com/rqlite/rqlite-disco-clients/dnssrv"
etcd "github.com/rqlite/rqlite-disco-clients/etcd"
"github.com/rqlite/rqlite/auth"
"github.com/rqlite/rqlite/aws"
"github.com/rqlite/rqlite/cluster"
"github.com/rqlite/rqlite/cmd"
"github.com/rqlite/rqlite/db"
@@ -27,6 +29,7 @@ import (
"github.com/rqlite/rqlite/rtls"
"github.com/rqlite/rqlite/store"
"github.com/rqlite/rqlite/tcp"
"github.com/rqlite/rqlite/upload"
)
const logo = `
@@ -43,7 +46,9 @@ const logo = `
const name = `rqlited`
const desc = `rqlite is a lightweight, distributed relational database, which uses SQLite as its
storage engine. It provides an easy-to-use, fault-tolerant store for relational data.

Visit https://www.rqlite.io to learn more.`
func init() {
log.SetFlags(log.LstdFlags)
@@ -149,6 +154,16 @@ func main() {
h, p, _ := net.SplitHostPort(cfg.HTTPAdv)
log.Printf("connect using the command-line tool via 'rqlite -H %s -p %s'", h, p)
// Start any requested auto-backups
ctx, backupSrvCancel := context.WithCancel(context.Background())
backupSrv, err := startAutoBackups(ctx, cfg, str)
if err != nil {
log.Fatalf("failed to start auto-backups: %s", err.Error())
}
if backupSrv != nil {
httpServ.RegisterStatus("auto_backups", backupSrv)
}
// Block until signalled.
terminate := make(chan os.Signal, 1)
signal.Notify(terminate, syscall.SIGINT, syscall.SIGTERM, os.Interrupt)
@@ -163,6 +178,7 @@ func main() {
str.Stepdown(true)
}
backupSrvCancel()
httpServ.Close()
if err := str.Close(true); err != nil {
log.Printf("failed to close store: %s", err.Error())
@@ -173,6 +189,26 @@ func main() {
log.Println("rqlite server stopped")
}
func startAutoBackups(ctx context.Context, cfg *Config, str *store.Store) (*upload.Uploader, error) {
if cfg.AutoBackupFile == "" {
return nil, nil
}
b, err := upload.ReadConfigFile(cfg.AutoBackupFile)
if err != nil {
return nil, fmt.Errorf("failed to read auto-backup file: %s", err.Error())
}
uCfg, s3cfg, err := upload.Unmarshal(b)
if err != nil {
return nil, fmt.Errorf("failed to parse auto-backup file: %s", err.Error())
}
sc := aws.NewS3Client(s3cfg.Region, s3cfg.AccessKeyID, s3cfg.SecretAccessKey, s3cfg.Bucket, s3cfg.Path)
u := upload.NewUploader(sc, str, time.Duration(uCfg.Interval), !uCfg.NoCompress)
go u.Start(ctx, nil)
return u, nil
}
func createStore(cfg *Config, ln *tcp.Layer) (*store.Store, error) {
dataPath, err := filepath.Abs(cfg.DataPath)
if err != nil {

@@ -4,6 +4,7 @@ go 1.16
require (
github.com/Bowery/prompt v0.0.0-20190916142128-fa8279994f75
github.com/aws/aws-sdk-go v1.44.250
github.com/coreos/go-semver v0.3.1 // indirect
github.com/coreos/go-systemd/v22 v22.5.0 // indirect
github.com/fatih/color v1.15.0 // indirect

@@ -623,6 +623,8 @@ github.com/armon/go-metrics v0.4.1 h1:hR91U9KYmb6bLBYLQjyM+3j+rcd/UhE+G78SFnF8gJ
github.com/armon/go-metrics v0.4.1/go.mod h1:E6amYzXo6aW1tqzoZGT755KkbgrJsSdpwZ+3JqfkOG4=
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
github.com/armon/go-radix v1.0.0/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
github.com/aws/aws-sdk-go v1.44.250 h1:IuGUO2Hafv/b0yYKI5UPLQShYDx50BCIQhab/H1sX2M=
github.com/aws/aws-sdk-go v1.44.250/go.mod h1:aVsgQcEevwlmQ7qHE9I3h+dtQgpqhFB+i8Phjh7fkwI=
github.com/benbjohnson/clock v1.1.0 h1:Q92kusRqC1XV2MjkWETPvjJVqKetz1OzxZB7mHJLju8=
github.com/benbjohnson/clock v1.1.0/go.mod h1:J11/hYXuz8f4ySSvYwY0FKfm+ezbsZBKZxNJlLklBHA=
github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24TaqPxmxbtue+5NUziq4I4S80YR8gNf3Q=
@@ -883,6 +885,10 @@ github.com/hashicorp/serf v0.10.1/go.mod h1:yL2t6BqATOLGc5HF7qbFkTfXoPIY0WZdWHfE
github.com/iancoleman/strcase v0.2.0/go.mod h1:iwCmte+B7n89clKwxIoIXy/HfoL7AsD47ZCWhYzw7ho=
github.com/ianlancetaylor/demangle v0.0.0-20181102032728-5e5cf60278f6/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/ianlancetaylor/demangle v0.0.0-20200824232613-28f6c0f3b639/go.mod h1:aSSvb/t6k1mPoxDqO4vJh6VOCGPwU4O0C2/Eqndh1Sc=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
github.com/jmespath/go-jmespath v0.4.0/go.mod h1:T8mJZnbsbmF+m6zOOFylbeCJqk5+pHWvzYPziyZiYoo=
github.com/jmespath/go-jmespath/internal/testify v1.5.1 h1:shLQSRRSCCPj3f2gpwzGwWFoC7ycTf1rcQZHOlsJ6N8=
github.com/jmespath/go-jmespath/internal/testify v1.5.1/go.mod h1:L3OGu8Wl2/fWfCI6z80xFu9LTZmf1ZRjMHUOPmWr69U=
github.com/jpillora/backoff v1.0.0/go.mod h1:J/6gKK9jxlEcS3zixgDgUAsiuZ7yrSoa/FX5e0EB2j4=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.9/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
@@ -1223,6 +1229,7 @@ golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug
golang.org/x/net v0.0.0-20220909164309-bea034e7d591/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.0.0-20221012135044-0b7e1fb9d458/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.0.0-20221014081412-f15817d10f9b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY=
golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
@@ -1365,6 +1372,7 @@ golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220829200755-d48e67d00261/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -1375,6 +1383,7 @@ golang.org/x/sys v0.7.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc=
golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
@@ -1738,6 +1747,7 @@ gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.5/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

@@ -28,6 +28,7 @@ import (
"github.com/rqlite/rqlite/command"
sql "github.com/rqlite/rqlite/db"
rlog "github.com/rqlite/rqlite/log"
"github.com/rqlite/rqlite/upload"
)
var (
@@ -70,6 +71,7 @@ const (
const (
numSnaphots = "num_snapshots"
numProvides = "num_provides"
numBackups = "num_backups"
numLoads = "num_loads"
numRestores = "num_restores"
@@ -101,6 +103,7 @@ func init() {
// ResetStats resets the expvar stats for this module. Mostly for test purposes.
func ResetStats() {
stats.Add(numSnaphots, 0)
stats.Add(numProvides, 0)
stats.Add(numBackups, 0)
stats.Add(numRestores, 0)
stats.Add(numRecoveries, 0)
@@ -903,7 +906,7 @@ func (s *Store) Backup(br *command.BackupRequest, dst io.Writer) (retErr error)
}
if br.Format == command.BackupRequest_BACKUP_REQUEST_FORMAT_BINARY {
f, err := os.CreateTemp("", "rqlite-snap-")
if err != nil {
return err
}
@@ -930,6 +933,35 @@ func (s *Store) Backup(br *command.BackupRequest, dst io.Writer) (retErr error)
return ErrInvalidBackupFormat
}
// Provide implements the uploader Provider interface, allowing the
// Store to be used as a DataProvider for an uploader. It returns
// an io.ReadCloser that can be used to read a copy of the entire database.
// When the ReadCloser is closed, the resources backing it are cleaned up.
func (s *Store) Provide() (io.ReadCloser, error) {
if !s.open {
return nil, ErrNotOpen
}
tempFile, err := os.CreateTemp("", "rqlite-upload-")
if err != nil {
return nil, err
}
if err := tempFile.Close(); err != nil {
return nil, err
}
if err := s.db.Backup(tempFile.Name()); err != nil {
return nil, err
}
fd, err := upload.NewAutoDeleteFile(tempFile.Name())
if err != nil {
return nil, err
}
stats.Add(numProvides, 1)
return fd, nil
}
// Loads an entire SQLite file into the database, sending the request
// through the Raft log.
func (s *Store) Load(lr *command.LoadRequest) error {

@@ -3,6 +3,7 @@ package store
import (
"bytes"
"fmt"
"io"
"io/ioutil"
"math/rand"
"net"
@@ -864,6 +865,120 @@ COMMIT;
}
}
func Test_SingleNodeProvide(t *testing.T) {
for _, inmem := range []bool{
true,
false,
} {
func() {
s0, ln := mustNewStore(t, inmem)
defer ln.Close()
if err := s0.Open(); err != nil {
t.Fatalf("failed to open single-node store: %s", err.Error())
}
if err := s0.Bootstrap(NewServer(s0.ID(), s0.Addr(), true)); err != nil {
t.Fatalf("failed to bootstrap single-node store: %s", err.Error())
}
defer s0.Close(true)
if _, err := s0.WaitForLeader(10 * time.Second); err != nil {
t.Fatalf("Error waiting for leader: %s", err)
}
er := executeRequestFromStrings([]string{
`CREATE TABLE foo (id INTEGER NOT NULL PRIMARY KEY, name TEXT)`,
`INSERT INTO foo(id, name) VALUES(1, "fiona")`,
}, false, false)
_, err := s0.Execute(er)
if err != nil {
t.Fatalf("failed to execute on single node: %s", err.Error())
}
qr := queryRequestFromString("SELECT * FROM foo", false, false)
qr.Level = command.QueryRequest_QUERY_REQUEST_LEVEL_NONE
r, err := s0.Query(qr)
if err != nil {
t.Fatalf("failed to query leader node: %s", err.Error())
}
if exp, got := `["id","name"]`, asJSON(r[0].Columns); exp != got {
t.Fatalf("unexpected results for query\nexp: %s\ngot: %s", exp, got)
}
if exp, got := `[[1,"fiona"]]`, asJSON(r[0].Values); exp != got {
t.Fatalf("unexpected results for query\nexp: %s\ngot: %s", exp, got)
}
rc, err := s0.Provide()
if err != nil {
t.Fatalf("store failed to provide: %s", err.Error())
}
f, err := os.CreateTemp("", "rqlite-store-test")
if err != nil {
t.Fatalf("failed to create temp file: %s", err.Error())
}
defer os.Remove(f.Name())
defer f.Close()
_, err = io.Copy(f, rc)
if err != nil {
t.Fatalf("failed to copy data from store: %s", err.Error())
}
// Load the provided data into a new store and check it.
s1, ln := mustNewStore(t, true)
defer ln.Close()
if err := s1.Open(); err != nil {
t.Fatalf("failed to open single-node store: %s", err.Error())
}
if err := s1.Bootstrap(NewServer(s1.ID(), s1.Addr(), true)); err != nil {
t.Fatalf("failed to bootstrap single-node store: %s", err.Error())
}
defer s1.Close(true)
if _, err := s1.WaitForLeader(10 * time.Second); err != nil {
t.Fatalf("Error waiting for leader: %s", err)
}
err = s1.Load(loadRequestFromFile(f.Name()))
if err != nil {
t.Fatalf("failed to load provided SQLite data: %s", err.Error())
}
qr = queryRequestFromString("SELECT * FROM foo", false, false)
qr.Level = command.QueryRequest_QUERY_REQUEST_LEVEL_STRONG
r, err = s1.Query(qr)
if err != nil {
t.Fatalf("failed to query leader node: %s", err.Error())
}
if exp, got := `["id","name"]`, asJSON(r[0].Columns); exp != got {
t.Fatalf("unexpected results for query\nexp: %s\ngot: %s", exp, got)
}
if exp, got := `[[1,"fiona"]]`, asJSON(r[0].Values); exp != got {
t.Fatalf("unexpected results for query\nexp: %s\ngot: %s", exp, got)
}
}()
}
}
func Test_SingleNodeInMemProvideNoData(t *testing.T) {
s, ln := mustNewStore(t, true)
defer ln.Close()
if err := s.Open(); err != nil {
t.Fatalf("failed to open single-node store: %s", err.Error())
}
if err := s.Bootstrap(NewServer(s.ID(), s.Addr(), true)); err != nil {
t.Fatalf("failed to bootstrap single-node store: %s", err.Error())
}
defer s.Close(true)
if _, err := s.WaitForLeader(10 * time.Second); err != nil {
t.Fatalf("Error waiting for leader: %s", err)
}
_, err := s.Provide()
if err != nil {
t.Fatalf("store failed to provide: %s", err.Error())
}
}
// Test_SingleNodeRecoverNoChange tests a node recovery that doesn't
// actually change anything.
func Test_SingleNodeRecoverNoChange(t *testing.T) {

@@ -0,0 +1,140 @@
#!/usr/bin/env python
#
# End-to-end testing using actual rqlited binary.
#
# To run a specific test, execute
#
# python3 system_test/e2e/auto_backup.py Class.test
import os
import json
import unittest
import sqlite3
import time
from helpers import Node, deprovision_node, write_random_file, random_string, env_present, gunzip_file
from s3 import download_s3_object, delete_s3_object
S3_BUCKET = 'rqlite-testing-circleci'
S3_BUCKET_REGION = 'us-west-2'
RQLITED_PATH = os.environ['RQLITED_PATH']
class TestAutoBackupS3(unittest.TestCase):
@unittest.skipUnless(env_present('RQLITE_S3_ACCESS_KEY'), "S3 credentials not available")
def test_no_compress(self):
'''Test that automatic backups to AWS S3 work with compression off'''
node = None
cfg = None
path = None
backup_file = None
access_key_id = os.environ['RQLITE_S3_ACCESS_KEY']
secret_access_key_id = os.environ['RQLITE_S3_SECRET_ACCESS_KEY']
# Create the auto-backup config file
path = random_string(32)
auto_backup_cfg = {
"version": 1,
"type": "s3",
"interval": "1s",
"no_compress": True,
"sub" : {
"access_key_id": access_key_id,
"secret_access_key": secret_access_key_id,
"region": S3_BUCKET_REGION,
"bucket": S3_BUCKET,
"path": path
}
}
cfg = write_random_file(json.dumps(auto_backup_cfg))
# Create a node, enable automatic backups, and start it. Then
# create a table and insert a row. Wait for a backup to happen.
node = Node(RQLITED_PATH, '0', auto_backup=cfg)
node.start()
node.wait_for_leader()
node.execute('CREATE TABLE foo (id INTEGER NOT NULL PRIMARY KEY, name TEXT)')
node.execute('INSERT INTO foo(name) VALUES("fiona")')
node.wait_for_all_fsm()
time.sleep(5)
# Download the backup file from S3 and check it.
backup_data = download_s3_object(access_key_id, secret_access_key_id,
S3_BUCKET, path)
backup_file = write_random_file(backup_data, mode='wb')
conn = sqlite3.connect(backup_file)
c = conn.cursor()
c.execute('SELECT * FROM foo')
rows = c.fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0][1], 'fiona')
conn.close()
deprovision_node(node)
os.remove(cfg)
os.remove(backup_file)
delete_s3_object(access_key_id, secret_access_key_id,
S3_BUCKET, path)
@unittest.skipUnless(env_present('RQLITE_S3_ACCESS_KEY'), "S3 credentials not available")
def test_compress(self):
'''Test that automatic backups to AWS S3 work with compression on'''
node = None
cfg = None
path = None
compressed_backup_file = None
backup_file = None
access_key_id = os.environ['RQLITE_S3_ACCESS_KEY']
secret_access_key_id = os.environ['RQLITE_S3_SECRET_ACCESS_KEY']
# Create the auto-backup config file
path = random_string(32)
auto_backup_cfg = {
"version": 1,
"type": "s3",
"interval": "1s",
"sub" : {
"access_key_id": access_key_id,
"secret_access_key": secret_access_key_id,
"region": S3_BUCKET_REGION,
"bucket": S3_BUCKET,
"path": path
}
}
cfg = write_random_file(json.dumps(auto_backup_cfg))
# Create a node, enable automatic backups, and start it. Then
# create a table and insert a row. Wait for a backup to happen.
node = Node(RQLITED_PATH, '0', auto_backup=cfg)
node.start()
node.wait_for_leader()
node.execute('CREATE TABLE foo (id INTEGER NOT NULL PRIMARY KEY, name TEXT)')
node.execute('INSERT INTO foo(name) VALUES("fiona")')
node.wait_for_all_fsm()
time.sleep(5)
# Download the backup file from S3 and check it.
backup_data = download_s3_object(access_key_id, secret_access_key_id,
S3_BUCKET, path)
compressed_backup_file = write_random_file(backup_data, mode='wb')
backup_file = gunzip_file(compressed_backup_file)
conn = sqlite3.connect(backup_file)
c = conn.cursor()
c.execute('SELECT * FROM foo')
rows = c.fetchall()
self.assertEqual(len(rows), 1)
self.assertEqual(rows[0][1], 'fiona')
conn.close()
deprovision_node(node)
os.remove(cfg)
os.remove(compressed_backup_file)
os.remove(backup_file)
delete_s3_object(access_key_id, secret_access_key_id,
S3_BUCKET, path)
if __name__ == "__main__":
unittest.main(verbosity=2)

@@ -2,6 +2,7 @@
import tempfile
import ast
import gzip
import subprocess
import requests
import json
@@ -22,6 +23,14 @@ seqRe = re.compile("^{'results': \[\], 'sequence_number': \d+}$")
def d_(s):
return ast.literal_eval(s.replace("'", "\""))
def env_present(name):
return name in os.environ and os.environ[name] != ""
def gunzip_file(path):
with gzip.open(path, 'rb') as f:
file_content = f.read()
return write_random_file(file_content, mode='wb')
def is_sequence_number(r):
return seqRe.match(r)
@@ -29,8 +38,8 @@ def random_string(n):
letters = string.ascii_lowercase
return ''.join(random.choice(letters) for i in range(n))
def write_random_file(data, mode='w'):
f = tempfile.NamedTemporaryFile(mode, delete=False)
f.write(data)
f.close()
return f.name
@@ -57,7 +66,7 @@ class Node(object):
raft_snap_threshold=8192, raft_snap_int="1s",
http_cert=None, http_key=None, http_no_verify=False,
node_cert=None, node_key=None, node_no_verify=False,
auth=None, auto_backup=None, dir=None, on_disk=False):
s_api = None
s_raft = None
@@ -100,6 +109,7 @@ class Node(object):
self.node_key = node_key
self.node_no_verify = node_no_verify
self.auth = auth
self.auto_backup = auto_backup
self.disco_key = random_string(10)
self.on_disk = on_disk
self.process = None
@@ -166,6 +176,8 @@ class Node(object):
command += ['-on-disk']
if self.auth is not None:
command += ['-auth', self.auth]
if self.auto_backup is not None:
command += ['-auto-backup', self.auto_backup]
if join is not None:
if join.startswith('http://') is False:
join = 'http://' + join

@@ -0,0 +1,29 @@
#!/usr/bin/env python
import boto3
import os
def delete_s3_object(access_key_id, secret_access_key_id, bucket_name, object_key):
"""
Delete an object from an S3 bucket.
"""
os.environ['AWS_ACCESS_KEY_ID'] = access_key_id
os.environ['AWS_SECRET_ACCESS_KEY'] = secret_access_key_id
s3_client = boto3.client('s3')
s3_client.delete_object(Bucket=bucket_name, Key=object_key)
def download_s3_object(access_key_id, secret_access_key_id, bucket_name, object_key):
"""
Download an object from an S3 bucket.
Args:
bucket_name (str): The name of the S3 bucket.
object_key (str): The key of the object to download.
"""
os.environ['AWS_ACCESS_KEY_ID'] = access_key_id
os.environ['AWS_SECRET_ACCESS_KEY'] = secret_access_key_id
s3_client = boto3.client('s3')
response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
return response['Body'].read()

@@ -0,0 +1,28 @@
package upload
import (
"os"
)
// AutoDeleteFile is a wrapper around os.File that deletes the file when it is
// closed.
type AutoDeleteFile struct {
*os.File
}
// Close implements the io.Closer interface
func (f *AutoDeleteFile) Close() error {
if err := f.File.Close(); err != nil {
return err
}
return os.Remove(f.Name())
}
// NewAutoDeleteFile takes a filename and wraps it in an AutoDeleteFile
func NewAutoDeleteFile(path string) (*AutoDeleteFile, error) {
f, err := os.Open(path)
if err != nil {
return nil, err
}
return &AutoDeleteFile{f}, nil
}

@@ -0,0 +1,57 @@
package upload
import (
"os"
"testing"
)
func Test_NewAutoDeleteTempFile(t *testing.T) {
adFile, err := NewAutoDeleteFile(mustCreateTempFilename())
if err != nil {
t.Fatalf("NewAutoDeleteFile() failed: %v", err)
}
defer adFile.Close()
if _, err := os.Stat(adFile.Name()); os.IsNotExist(err) {
t.Fatalf("Expected file to exist: %s", adFile.Name())
}
}
func Test_AutoDeleteFile_Name(t *testing.T) {
name := mustCreateTempFilename()
adFile, err := NewAutoDeleteFile(name)
if err != nil {
t.Fatalf("NewAutoDeleteFile() failed: %v", err)
}
defer adFile.Close()
if adFile.Name() != name {
t.Fatalf("Expected Name() to return %s, got %s", name, adFile.Name())
}
}
func Test_AutoDeleteFile_Close(t *testing.T) {
adFile, err := NewAutoDeleteFile(mustCreateTempFilename())
if err != nil {
t.Fatalf("NewAutoDeleteFile() failed: %v", err)
}
filename := adFile.Name()
err = adFile.Close()
if err != nil {
t.Fatalf("Close() failed: %v", err)
}
if _, err := os.Stat(filename); !os.IsNotExist(err) {
t.Fatalf("Expected file to be deleted after Close(): %s", filename)
}
}
func mustCreateTempFilename() string {
f, err := os.CreateTemp("", "autodeletefile_test")
if err != nil {
panic(err)
}
f.Close()
return f.Name()
}

@@ -0,0 +1,129 @@
package upload
import (
"encoding/json"
"errors"
"io/ioutil"
"os"
"time"
)
const (
// version is the max version of the config file format supported
version = 1
)
var (
// ErrInvalidVersion is returned when the config file version is not supported.
ErrInvalidVersion = errors.New("invalid version")
// ErrUnsupportedStorageType is returned when the storage type is not supported.
ErrUnsupportedStorageType = errors.New("unsupported storage type")
)
// Duration is a wrapper around time.Duration that allows us to unmarshal
// durations from JSON strings or numbers.
type Duration time.Duration
// MarshalJSON marshals the duration as a string
func (d Duration) MarshalJSON() ([]byte, error) {
return json.Marshal(time.Duration(d).String())
}
// UnmarshalJSON unmarshals the duration from a string or a float64
func (d *Duration) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
switch value := v.(type) {
case float64:
*d = Duration(time.Duration(value))
return nil
case string:
tmp, err := time.ParseDuration(value)
if err != nil {
return err
}
*d = Duration(tmp)
return nil
default:
return errors.New("invalid duration")
}
}
// StorageType is a wrapper around string that allows us to unmarshal
// and validate the storage type from JSON.
type StorageType string
// UnmarshalJSON unmarshals the storage type from a string and validates it
func (s *StorageType) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
switch value := v.(type) {
case string:
*s = StorageType(value)
if *s != "s3" {
return ErrUnsupportedStorageType
}
return nil
default:
return ErrUnsupportedStorageType
}
}
// Config is the config file format for the upload service
type Config struct {
Version int `json:"version"`
Type StorageType `json:"type"`
NoCompress bool `json:"no_compress,omitempty"`
Interval Duration `json:"interval"`
Sub json.RawMessage `json:"sub"`
}
// S3Config is the subconfig for the S3 storage type
type S3Config struct {
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Region string `json:"region"`
Bucket string `json:"bucket"`
Path string `json:"path"`
}
// Unmarshal unmarshals the config file and returns the config and subconfig
func Unmarshal(data []byte) (*Config, *S3Config, error) {
cfg := &Config{}
err := json.Unmarshal(data, cfg)
if err != nil {
return nil, nil, err
}
if cfg.Version > version {
return nil, nil, ErrInvalidVersion
}
s3cfg := &S3Config{}
err = json.Unmarshal(cfg.Sub, s3cfg)
if err != nil {
return nil, nil, err
}
return cfg, s3cfg, nil
}
// ReadConfigFile reads the config file and returns the data. It also expands
// any environment variables in the config file.
func ReadConfigFile(filename string) ([]byte, error) {
f, err := os.Open(filename)
if err != nil {
return nil, err
}
defer f.Close()
data, err := ioutil.ReadAll(f)
if err != nil {
return nil, err
}
data = []byte(os.ExpandEnv(string(data)))
return data, nil
}

@@ -0,0 +1,252 @@
package upload
import (
"bytes"
"errors"
"io/ioutil"
"os"
"reflect"
"testing"
"time"
)
func Test_ReadConfigFile(t *testing.T) {
t.Run("valid config file", func(t *testing.T) {
// Create a temporary config file
tempFile, err := ioutil.TempFile("", "upload_config")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tempFile.Name())
content := []byte("key=value")
if _, err := tempFile.Write(content); err != nil {
t.Fatal(err)
}
tempFile.Close()
data, err := ReadConfigFile(tempFile.Name())
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if !bytes.Equal(data, content) {
t.Fatalf("Expected %v, got %v", content, data)
}
})
t.Run("non-existent file", func(t *testing.T) {
_, err := ReadConfigFile("nonexistentfile")
if !errors.Is(err, os.ErrNotExist) {
t.Fatalf("Expected os.ErrNotExist, got %v", err)
}
})
t.Run("file with environment variables", func(t *testing.T) {
// Set an environment variable
if err := os.Setenv("TEST_VAR", "test_value"); err != nil {
t.Fatal(err)
}
defer os.Unsetenv("TEST_VAR")
// Create a temporary config file with an environment variable
tempFile, err := ioutil.TempFile("", "upload_config")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tempFile.Name())
content := []byte("key=$TEST_VAR")
if _, err := tempFile.Write(content); err != nil {
t.Fatal(err)
}
tempFile.Close()
data, err := ReadConfigFile(tempFile.Name())
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
expectedContent := []byte("key=test_value")
if !bytes.Equal(data, expectedContent) {
t.Fatalf("Expected %v, got %v", expectedContent, data)
}
})
t.Run("longer file with environment variables", func(t *testing.T) {
// Set an environment variable
if err := os.Setenv("TEST_VAR1", "test_value"); err != nil {
t.Fatal(err)
}
defer os.Unsetenv("TEST_VAR1")
// Create a temporary config file with an environment variable
tempFile, err := ioutil.TempFile("", "upload_config")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tempFile.Name())
content := []byte(`
key1=$TEST_VAR1
key2=TEST_VAR2`)
if _, err := tempFile.Write(content); err != nil {
t.Fatal(err)
}
tempFile.Close()
data, err := ReadConfigFile(tempFile.Name())
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
expectedContent := []byte(`
key1=test_value
key2=TEST_VAR2`)
if !bytes.Equal(data, expectedContent) {
t.Fatalf("Expected %v, got %v", expectedContent, data)
}
})
}
func TestUnmarshal(t *testing.T) {
testCases := []struct {
name string
input []byte
expectedCfg *Config
expectedS3 *S3Config
expectedErr error
}{
{
name: "ValidS3Config",
input: []byte(`
{
"version": 1,
"type": "s3",
"no_compress": true,
"interval": "24h",
"sub": {
"access_key_id": "test_id",
"secret_access_key": "test_secret",
"region": "us-west-2",
"bucket": "test_bucket",
"path": "test/path"
}
}
`),
expectedCfg: &Config{
Version: 1,
Type: "s3",
NoCompress: true,
Interval: 24 * Duration(time.Hour),
},
expectedS3: &S3Config{
AccessKeyID: "test_id",
SecretAccessKey: "test_secret",
Region: "us-west-2",
Bucket: "test_bucket",
Path: "test/path",
},
expectedErr: nil,
},
{
name: "ValidS3ConfigNoptionalFields",
input: []byte(`
{
"version": 1,
"type": "s3",
"interval": "24h",
"sub": {
"access_key_id": "test_id",
"secret_access_key": "test_secret",
"region": "us-west-2",
"bucket": "test_bucket",
"path": "test/path"
}
}
`),
expectedCfg: &Config{
Version: 1,
Type: "s3",
NoCompress: false,
Interval: 24 * Duration(time.Hour),
},
expectedS3: &S3Config{
AccessKeyID: "test_id",
SecretAccessKey: "test_secret",
Region: "us-west-2",
Bucket: "test_bucket",
Path: "test/path",
},
expectedErr: nil,
},
{
name: "InvalidVersion",
input: []byte(`
{
"version": 2,
"type": "s3",
"no_compress": false,
"interval": "24h",
"sub": {
"access_key_id": "test_id",
"secret_access_key": "test_secret",
"region": "us-west-2",
"bucket": "test_bucket",
"path": "test/path"
}
} `),
expectedCfg: nil,
expectedS3: nil,
expectedErr: ErrInvalidVersion,
},
{
name: "UnsupportedType",
input: []byte(`
{
"version": 1,
"type": "unsupported",
"no_compress": true,
"interval": "24h",
"sub": {
"access_key_id": "test_id",
"secret_access_key": "test_secret",
"region": "us-west-2",
"bucket": "test_bucket",
"path": "test/path"
}
} `),
expectedCfg: nil,
expectedS3: nil,
expectedErr: ErrUnsupportedStorageType,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cfg, s3Cfg, err := Unmarshal(tc.input)
_ = s3Cfg
if !errors.Is(err, tc.expectedErr) {
t.Fatalf("Test case %s failed, expected error %v, got %v", tc.name, tc.expectedErr, err)
}
if !compareConfig(cfg, tc.expectedCfg) {
t.Fatalf("Test case %s failed, expected config %+v, got %+v", tc.name, tc.expectedCfg, cfg)
}
if tc.expectedS3 != nil {
if !reflect.DeepEqual(s3Cfg, tc.expectedS3) {
t.Fatalf("Test case %s failed, expected S3Config %+v, got %+v", tc.name, tc.expectedS3, s3Cfg)
}
}
})
}
}
func compareConfig(a, b *Config) bool {
if a == nil || b == nil {
return a == b
}
return a.Version == b.Version &&
a.Type == b.Type &&
a.NoCompress == b.NoCompress &&
a.Interval == b.Interval
}

@@ -0,0 +1,163 @@
package upload
import (
"bytes"
"compress/gzip"
"context"
"expvar"
"fmt"
"io"
"log"
"os"
"time"
)
// StorageClient is an interface for uploading data to a storage service.
type StorageClient interface {
Upload(ctx context.Context, reader io.Reader) error
fmt.Stringer
}
// DataProvider is an interface for providing data to be uploaded. The Uploader
// service will call Provide() to get a reader for the data to be uploaded. Once
// the upload completes the reader will be closed, regardless of whether the
// upload succeeded or failed.
type DataProvider interface {
Provide() (io.ReadCloser, error)
}
// stats captures stats for the Uploader service.
var stats *expvar.Map
const (
numUploadsOK = "num_uploads_ok"
numUploadsFail = "num_uploads_fail"
totalUploadBytes = "total_upload_bytes"
lastUploadBytes = "last_upload_bytes"
UploadCompress = true
UploadNoCompress = false
)
func init() {
stats = expvar.NewMap("uploader")
ResetStats()
}
// ResetStats resets the expvar stats for this module. Mostly for test purposes.
func ResetStats() {
stats.Init()
stats.Add(numUploadsOK, 0)
stats.Add(numUploadsFail, 0)
stats.Add(totalUploadBytes, 0)
stats.Add(lastUploadBytes, 0)
}
// Uploader is a service that periodically uploads data to a storage service.
type Uploader struct {
storageClient StorageClient
dataProvider DataProvider
interval time.Duration
compress bool
logger *log.Logger
lastUploadTime time.Time
lastUploadDuration time.Duration
}
// NewUploader creates a new Uploader service.
func NewUploader(storageClient StorageClient, dataProvider DataProvider, interval time.Duration, compress bool) *Uploader {
return &Uploader{
storageClient: storageClient,
dataProvider: dataProvider,
interval: interval,
compress: compress,
logger: log.New(os.Stderr, "[uploader] ", log.LstdFlags),
}
}
// Start starts the Uploader service.
func (u *Uploader) Start(ctx context.Context, isUploadEnabled func() bool) {
if isUploadEnabled == nil {
isUploadEnabled = func() bool { return true }
}
u.logger.Printf("starting upload to %s every %s", u.storageClient, u.interval)
ticker := time.NewTicker(u.interval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
u.logger.Println("upload service shutting down")
return
case <-ticker.C:
if !isUploadEnabled() {
continue
}
if err := u.upload(ctx); err != nil {
u.logger.Printf("failed to upload to %s: %v", u.storageClient, err)
}
}
}
}
// Stats returns the stats for the Uploader service.
func (u *Uploader) Stats() (map[string]interface{}, error) {
status := map[string]interface{}{
"upload_destination": u.storageClient.String(),
"upload_interval": u.interval.String(),
"compress": u.compress,
"last_upload_time": u.lastUploadTime.Format(time.RFC3339),
"last_upload_duration": u.lastUploadDuration.String(),
}
return status, nil
}
func (u *Uploader) upload(ctx context.Context) error {
rc, err := u.dataProvider.Provide()
if err != nil {
return err
}
defer rc.Close()
r := rc.(io.Reader)
if u.compress {
buffer := new(bytes.Buffer)
gw := gzip.NewWriter(buffer)
_, err = io.Copy(gw, rc)
if err != nil {
return err
}
err = gw.Close()
if err != nil {
return err
}
r = buffer
}
cr := &countingReader{reader: r}
startTime := time.Now()
err = u.storageClient.Upload(ctx, cr)
if err != nil {
stats.Add(numUploadsFail, 1)
} else {
stats.Add(numUploadsOK, 1)
stats.Add(totalUploadBytes, cr.count)
stats.Get(lastUploadBytes).(*expvar.Int).Set(cr.count)
u.lastUploadTime = time.Now()
u.lastUploadDuration = time.Since(startTime)
}
return err
}
type countingReader struct {
reader io.Reader
count int64
}
func (c *countingReader) Read(p []byte) (int, error) {
n, err := c.reader.Read(p)
c.count += int64(n)
return n, err
}
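To show how these pieces fit together outside of rqlited, a rough sketch follows. The fileProvider type is hypothetical, written only for this example; StorageClient, NewUploader, Start, and UploadCompress come from this package, and the S3 client from the aws package in this commit. Credential, bucket, and path values are placeholders:

package main

import (
	"context"
	"io"
	"os"
	"time"

	"github.com/rqlite/rqlite/aws"
	"github.com/rqlite/rqlite/upload"
)

// fileProvider is a hypothetical DataProvider that re-opens a fixed file
// each time an upload is due; the Uploader closes the returned ReadCloser.
type fileProvider struct{ path string }

func (p fileProvider) Provide() (io.ReadCloser, error) { return os.Open(p.path) }

func main() {
	// Placeholder credentials and bucket; see NewS3Client in this commit.
	sc := aws.NewS3Client("us-west-2", "access-key", "secret-key", "example-bucket", "backups/db.gz")
	u := upload.NewUploader(sc, fileProvider{"/path/to/data.db"}, 10*time.Minute, upload.UploadCompress)
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	u.Start(ctx, nil) // nil enable-check means uploads always run; blocks until ctx is canceled
}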

@@ -0,0 +1,307 @@
package upload
import (
"compress/gzip"
"context"
"expvar"
"fmt"
"io"
"strings"
"sync"
"sync/atomic"
"testing"
"time"
)
func Test_NewUploader(t *testing.T) {
storageClient := &mockStorageClient{}
dataProvider := &mockDataProvider{}
interval := time.Second
uploader := NewUploader(storageClient, dataProvider, interval, UploadNoCompress)
if uploader.storageClient != storageClient {
t.Errorf("expected storageClient to be %v, got %v", storageClient, uploader.storageClient)
}
if uploader.dataProvider != dataProvider {
t.Errorf("expected dataProvider to be %v, got %v", dataProvider, uploader.dataProvider)
}
if uploader.interval != interval {
t.Errorf("expected interval to be %v, got %v", interval, uploader.interval)
}
}
func Test_UploaderSingleUpload(t *testing.T) {
ResetStats()
var uploadedData []byte
var err error
var wg sync.WaitGroup
wg.Add(1)
sc := &mockStorageClient{
uploadFn: func(ctx context.Context, reader io.Reader) error {
defer wg.Done()
uploadedData, err = io.ReadAll(reader)
return err
},
}
dp := &mockDataProvider{data: "my upload data"}
uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadNoCompress)
ctx, cancel := context.WithCancel(context.Background())
go uploader.Start(ctx, nil)
wg.Wait()
cancel()
<-ctx.Done()
if exp, got := "my upload data", string(uploadedData); exp != got {
t.Errorf("expected uploadedData to be %s, got %s", exp, got)
}
}
func Test_UploaderSingleUploadCompress(t *testing.T) {
ResetStats()
var uploadedData []byte
var wg sync.WaitGroup
wg.Add(1)
sc := &mockStorageClient{
uploadFn: func(ctx context.Context, reader io.Reader) error {
defer wg.Done()
// Wrap a gzip reader about the reader.
gzReader, err := gzip.NewReader(reader)
if err != nil {
return err
}
defer gzReader.Close()
uploadedData, err = io.ReadAll(gzReader)
return err
},
}
dp := &mockDataProvider{data: "my upload data"}
uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadCompress)
ctx, cancel := context.WithCancel(context.Background())
go uploader.Start(ctx, nil)
wg.Wait()
cancel()
<-ctx.Done()
if exp, got := "my upload data", string(uploadedData); exp != got {
t.Errorf("expected uploadedData to be %s, got %s", exp, got)
}
}
func Test_UploaderDoubleUpload(t *testing.T) {
ResetStats()
var uploadedData []byte
var err error
var wg sync.WaitGroup
wg.Add(2)
sc := &mockStorageClient{
uploadFn: func(ctx context.Context, reader io.Reader) error {
defer wg.Done()
uploadedData = nil // Wipe out any previous state.
uploadedData, err = io.ReadAll(reader)
return err
},
}
dp := &mockDataProvider{data: "my upload data"}
uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadNoCompress)
ctx, cancel := context.WithCancel(context.Background())
go uploader.Start(ctx, nil)
wg.Wait()
cancel()
<-ctx.Done()
if exp, got := "my upload data", string(uploadedData); exp != got {
t.Errorf("expected uploadedData to be %s, got %s", exp, got)
}
}
func Test_UploaderFailThenOK(t *testing.T) {
ResetStats()
var uploadedData []byte
uploadCount := 0
var err error
var wg sync.WaitGroup
wg.Add(2)
sc := &mockStorageClient{
uploadFn: func(ctx context.Context, reader io.Reader) error {
defer wg.Done()
if uploadCount == 0 {
uploadCount++
return fmt.Errorf("failed to upload")
}
uploadedData, err = io.ReadAll(reader)
return err
},
}
dp := &mockDataProvider{data: "my upload data"}
uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadNoCompress)
ctx, cancel := context.WithCancel(context.Background())
go uploader.Start(ctx, nil)
wg.Wait()
cancel()
<-ctx.Done()
if exp, got := "my upload data", string(uploadedData); exp != got {
t.Errorf("expected uploadedData to be %s, got %s", exp, got)
}
}
func Test_UploaderOKThenFail(t *testing.T) {
var uploadedData []byte
uploadCount := 0
var err error
var wg sync.WaitGroup
wg.Add(2)
sc := &mockStorageClient{
uploadFn: func(ctx context.Context, reader io.Reader) error {
defer wg.Done()
if uploadCount == 1 {
return fmt.Errorf("failed to upload")
}
uploadCount++
uploadedData, err = io.ReadAll(reader)
return err
},
}
dp := &mockDataProvider{data: "my upload data"}
uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadNoCompress)
ctx, cancel := context.WithCancel(context.Background())
go uploader.Start(ctx, nil)
wg.Wait()
cancel()
<-ctx.Done()
if exp, got := "my upload data", string(uploadedData); exp != got {
t.Errorf("expected uploadedData to be %s, got %s", exp, got)
}
}
func Test_UploaderContextCancellation(t *testing.T) {
var uploadCount int32
sc := &mockStorageClient{
uploadFn: func(ctx context.Context, reader io.Reader) error {
atomic.AddInt32(&uploadCount, 1)
return nil
},
}
dp := &mockDataProvider{data: "my upload data"}
uploader := NewUploader(sc, dp, time.Second, UploadNoCompress)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
go uploader.Start(ctx, nil)
<-ctx.Done()
cancel()
<-ctx.Done()
if exp, got := int32(0), atomic.LoadInt32(&uploadCount); exp != got {
t.Errorf("expected uploadCount to be %d, got %d", exp, got)
}
}
func Test_UploaderEnabledFalse(t *testing.T) {
ResetStats()
sc := &mockStorageClient{}
dp := &mockDataProvider{data: "my upload data"}
uploader := NewUploader(sc, dp, 100*time.Millisecond, false)
ctx, cancel := context.WithCancel(context.Background())
go uploader.Start(ctx, func() bool { return false })
time.Sleep(time.Second)
defer cancel()
if exp, got := int64(0), stats.Get(numUploadsOK).(*expvar.Int); exp != got.Value() {
t.Errorf("expected numUploadsOK to be %d, got %d", exp, got)
}
}
func Test_UploaderEnabledTrue(t *testing.T) {
var uploadedData []byte
var err error
ResetStats()
var wg sync.WaitGroup
wg.Add(1)
sc := &mockStorageClient{
uploadFn: func(ctx context.Context, reader io.Reader) error {
defer wg.Done()
uploadedData, err = io.ReadAll(reader)
return err
},
}
dp := &mockDataProvider{data: "my upload data"}
uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadNoCompress)
ctx, cancel := context.WithCancel(context.Background())
go uploader.Start(ctx, func() bool { return true })
defer cancel()
wg.Wait()
if exp, got := string(uploadedData), "my upload data"; exp != got {
t.Errorf("expected uploadedData to be %s, got %s", exp, got)
}
}
func Test_UploaderStats(t *testing.T) {
sc := &mockStorageClient{}
dp := &mockDataProvider{data: "my upload data"}
interval := 100 * time.Millisecond
uploader := NewUploader(sc, dp, interval, UploadNoCompress)
stats, err := uploader.Stats()
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if exp, got := sc.String(), stats["upload_destination"]; exp != got {
t.Errorf("expected upload_destination to be %s, got %s", exp, got)
}
if exp, got := interval.String(), stats["upload_interval"]; exp != got {
t.Errorf("expected upload_interval to be %s, got %s", exp, got)
}
}
type mockStorageClient struct {
uploadFn func(ctx context.Context, reader io.Reader) error
}
func (mc *mockStorageClient) Upload(ctx context.Context, reader io.Reader) error {
if mc.uploadFn != nil {
return mc.uploadFn(ctx, reader)
}
return nil
}
func (mc *mockStorageClient) String() string {
return "mockStorageClient"
}
type mockDataProvider struct {
data string
err error
}
func (mp *mockDataProvider) Provide() (io.ReadCloser, error) {
if mp.err != nil {
return nil, mp.err
}
return io.NopCloser(strings.NewReader(mp.data)), nil
}