commit
f1d3a1e893
@ -1,2 +0,0 @@
|
||||
// Package aws provides functionality for accessing the AWS API.
|
||||
package aws
|
@ -1,50 +0,0 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// MetadataClient is a client for fetching AWS EC2 instance metadata from
// the EC2 instance metadata service.
type MetadataClient struct {
	client *http.Client

	// URL is the base address of the metadata service. It is exported so
	// tests can point the client at a stub server.
	URL string
}

// NewMetadataClient returns an instance of a MetadataClient.
func NewMetadataClient() *MetadataClient {
	return &MetadataClient{
		client: &http.Client{},
		URL:    `http://169.254.169.254/`,
	}
}

// LocalIPv4 returns the private IPv4 address of the instance.
func (m *MetadataClient) LocalIPv4() (string, error) {
	return m.get("/latest/meta-data/local-ipv4")
}

// PublicIPv4 returns the public IPv4 address of the instance.
func (m *MetadataClient) PublicIPv4() (string, error) {
	return m.get("/latest/meta-data/public-ipv4")
}

// get performs a GET against the metadata service at the given path and
// returns the response body as a string.
func (m *MetadataClient) get(path string) (string, error) {
	resp, err := m.client.Get(m.URL + path)
	if err != nil {
		return "", err
	}
	defer resp.Body.Close()

	// Check the status before reading the body. Previously the body was
	// read first, so a read failure on a non-200 response masked the more
	// useful status error.
	if resp.StatusCode != http.StatusOK {
		return "", fmt.Errorf("failed to request %s, got: %s", path, resp.Status)
	}

	b, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return "", err
	}
	return string(b), nil
}
|
@ -1,89 +0,0 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_NewMetadataClient(t *testing.T) {
|
||||
c := NewMetadataClient()
|
||||
if c == nil {
|
||||
t.Fatalf("failed to create new Metadata client")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_MetadataClient_LocalIPv4(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
t.Fatalf("Client did not use GET")
|
||||
}
|
||||
if r.URL.String() != "/latest/meta-data/local-ipv4" {
|
||||
t.Fatalf("Request URL is wrong, got: %s", r.URL.String())
|
||||
}
|
||||
fmt.Fprint(w, "172.31.34.179")
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
c := NewMetadataClient()
|
||||
c.URL = ts.URL
|
||||
addr, err := c.LocalIPv4()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get local IPv4 address: %s", err.Error())
|
||||
}
|
||||
if addr != "172.31.34.179" {
|
||||
t.Fatalf("got incorrect local IPv4 address: %s", addr)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_MetadataClient_LocalIPv4Fail(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
c := NewMetadataClient()
|
||||
c.URL = ts.URL
|
||||
_, err := c.LocalIPv4()
|
||||
if err == nil {
|
||||
t.Fatalf("failed to get error when server returned 400")
|
||||
}
|
||||
}
|
||||
|
||||
func Test_MetadataClient_PublicIPv4(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "GET" {
|
||||
t.Fatalf("Client did not use GET")
|
||||
}
|
||||
if r.URL.String() != "/latest/meta-data/public-ipv4" {
|
||||
t.Fatalf("Request URL is wrong, got: %s", r.URL.String())
|
||||
}
|
||||
fmt.Fprint(w, "52.38.41.98")
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
c := NewMetadataClient()
|
||||
c.URL = ts.URL
|
||||
addr, err := c.PublicIPv4()
|
||||
if err != nil {
|
||||
t.Fatalf("failed to get local IPv4 address: %s", err.Error())
|
||||
}
|
||||
if addr != "52.38.41.98" {
|
||||
t.Fatalf("got incorrect local IPv4 address: %s", addr)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_MetadataClient_PublicIPv4Fail(t *testing.T) {
|
||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
}))
|
||||
defer ts.Close()
|
||||
|
||||
c := NewMetadataClient()
|
||||
c.URL = ts.URL
|
||||
_, err := c.PublicIPv4()
|
||||
if err == nil {
|
||||
t.Fatalf("failed to get error when server returned 400")
|
||||
}
|
||||
}
|
@ -0,0 +1,106 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||
"github.com/aws/aws-sdk-go/aws/session"
|
||||
"github.com/aws/aws-sdk-go/service/s3"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
)
|
||||
|
||||
// uploader abstracts s3manager.Uploader so a test double can be injected.
type uploader interface {
	UploadWithContext(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
}

// downloader abstracts s3manager.Downloader.
// NOTE(review): Download() constructs a real s3manager.Downloader directly
// and S3Client has no downloader field, so this seam is currently unused —
// confirm whether a downloader field should be added to mirror uploader.
type downloader interface {
	DownloadWithContext(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, opts ...func(*s3manager.Downloader)) (n int64, err error)
}

// S3Client is a client for uploading data to S3.
type S3Client struct {
	region    string // AWS region of the bucket
	accessKey string // static AWS access key ID
	secretKey string // static AWS secret access key
	bucket    string // target S3 bucket
	key       string // object key within the bucket

	uploader uploader // for testing via dependency injection
}
|
||||
|
||||
// NewS3Client returns an instance of an S3Client.
|
||||
func NewS3Client(region, accessKey, secretKey, bucket, key string) *S3Client {
|
||||
return &S3Client{
|
||||
region: region,
|
||||
accessKey: accessKey,
|
||||
secretKey: secretKey,
|
||||
bucket: bucket,
|
||||
key: key,
|
||||
}
|
||||
}
|
||||
|
||||
// String returns a string representation of the S3Client.
|
||||
func (s *S3Client) String() string {
|
||||
return fmt.Sprintf("s3://%s/%s", s.bucket, s.key)
|
||||
}
|
||||
|
||||
// Upload uploads data to S3.
|
||||
func (s *S3Client) Upload(ctx context.Context, reader io.Reader) error {
|
||||
sess, err := s.createSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If an uploader was not provided, use a real S3 uploader.
|
||||
var uploader uploader
|
||||
if s.uploader == nil {
|
||||
uploader = s3manager.NewUploader(sess)
|
||||
} else {
|
||||
uploader = s.uploader
|
||||
}
|
||||
|
||||
_, err = uploader.UploadWithContext(ctx, &s3manager.UploadInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s.key),
|
||||
Body: reader,
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upload to %v: %w", s, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Download downloads data from S3.
|
||||
func (s *S3Client) Download(ctx context.Context, writer io.WriterAt) error {
|
||||
sess, err := s.createSession()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
downloader := s3manager.NewDownloader(sess)
|
||||
|
||||
_, err = downloader.DownloadWithContext(ctx, writer, &s3.GetObjectInput{
|
||||
Bucket: aws.String(s.bucket),
|
||||
Key: aws.String(s.key),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to download from %v: %w", s, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSession builds an AWS session scoped to the client's region using
// its static credentials. The empty third argument to NewStaticCredentials
// is the session token, which is not used here.
func (s *S3Client) createSession() (*session.Session, error) {
	sess, err := session.NewSession(&aws.Config{
		Region:      aws.String(s.region),
		Credentials: credentials.NewStaticCredentials(s.accessKey, s.secretKey, ""),
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create S3 session: %w", err)
	}
	return sess, nil
}
|
@ -0,0 +1,118 @@
|
||||
package aws
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/aws/aws-sdk-go/aws"
|
||||
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||
)
|
||||
|
||||
func Test_NewS3Client(t *testing.T) {
|
||||
c := NewS3Client("region1", "access", "secret", "bucket2", "key3")
|
||||
if c.region != "region1" {
|
||||
t.Fatalf("expected region to be %q, got %q", "region1", c.region)
|
||||
}
|
||||
if c.accessKey != "access" {
|
||||
t.Fatalf("expected accessKey to be %q, got %q", "access", c.accessKey)
|
||||
}
|
||||
if c.secretKey != "secret" {
|
||||
t.Fatalf("expected secretKey to be %q, got %q", "secret", c.secretKey)
|
||||
}
|
||||
if c.bucket != "bucket2" {
|
||||
t.Fatalf("expected bucket to be %q, got %q", "bucket2", c.bucket)
|
||||
}
|
||||
if c.key != "key3" {
|
||||
t.Fatalf("expected key to be %q, got %q", "key3", c.key)
|
||||
}
|
||||
}
|
||||
|
||||
func Test_S3Client_String(t *testing.T) {
|
||||
c := NewS3Client("region1", "access", "secret", "bucket2", "key3")
|
||||
if c.String() != "s3://bucket2/key3" {
|
||||
t.Fatalf("expected String() to be %q, got %q", "s3://bucket2/key3", c.String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestS3ClientUploadOK(t *testing.T) {
|
||||
region := "us-west-2"
|
||||
accessKey := "your-access-key"
|
||||
secretKey := "your-secret-key"
|
||||
bucket := "your-bucket"
|
||||
key := "your/key/path"
|
||||
|
||||
mockUploader := &mockUploader{
|
||||
uploadFn: func(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) {
|
||||
if *input.Bucket != bucket {
|
||||
t.Errorf("expected bucket to be %q, got %q", bucket, *input.Bucket)
|
||||
}
|
||||
if *input.Key != key {
|
||||
t.Errorf("expected key to be %q, got %q", key, *input.Key)
|
||||
}
|
||||
if input.Body == nil {
|
||||
t.Errorf("expected body to be non-nil")
|
||||
}
|
||||
return &s3manager.UploadOutput{}, nil
|
||||
},
|
||||
}
|
||||
|
||||
client := &S3Client{
|
||||
region: region,
|
||||
accessKey: accessKey,
|
||||
secretKey: secretKey,
|
||||
bucket: bucket,
|
||||
key: key,
|
||||
uploader: mockUploader,
|
||||
}
|
||||
|
||||
reader := strings.NewReader("test data")
|
||||
err := client.Upload(context.Background(), reader)
|
||||
if err != nil {
|
||||
t.Fatalf("Unexpected error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestS3ClientUploadFail(t *testing.T) {
|
||||
region := "us-west-2"
|
||||
accessKey := "your-access-key"
|
||||
secretKey := "your-secret-key"
|
||||
bucket := "your-bucket"
|
||||
key := "your/key/path"
|
||||
|
||||
mockUploader := &mockUploader{
|
||||
uploadFn: func(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) {
|
||||
return &s3manager.UploadOutput{}, fmt.Errorf("some error related to S3")
|
||||
},
|
||||
}
|
||||
|
||||
client := &S3Client{
|
||||
region: region,
|
||||
accessKey: accessKey,
|
||||
secretKey: secretKey,
|
||||
bucket: bucket,
|
||||
key: key,
|
||||
uploader: mockUploader,
|
||||
}
|
||||
|
||||
reader := strings.NewReader("test data")
|
||||
err := client.Upload(context.Background(), reader)
|
||||
if err == nil {
|
||||
t.Fatal("Expected error, got nil")
|
||||
}
|
||||
if !strings.Contains(err.Error(), "some error related to S3") {
|
||||
t.Fatalf("Expected error to contain %q, got %q", "some error related to S3", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
// mockUploader is a test double for the uploader interface. If uploadFn is
// set it is invoked for each call; otherwise an empty successful result is
// returned.
type mockUploader struct {
	uploadFn func(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
}

// UploadWithContext implements the uploader interface by delegating to
// uploadFn when provided.
func (m *mockUploader) UploadWithContext(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) {
	if m.uploadFn != nil {
		return m.uploadFn(ctx, input, opts...)
	}
	return &s3manager.UploadOutput{}, nil
}
|
@ -0,0 +1,140 @@
|
||||
#!/usr/bin/env python
|
||||
#
|
||||
# End-to-end testing using actual rqlited binary.
|
||||
#
|
||||
# To run a specific test, execute
|
||||
#
|
||||
# python system_test/full_system_test.py Class.test
|
||||
|
||||
import os
|
||||
import json
|
||||
import unittest
|
||||
import sqlite3
|
||||
import time
|
||||
|
||||
from helpers import Node, deprovision_node, write_random_file, random_string, env_present, gunzip_file
|
||||
from s3 import download_s3_object, delete_s3_object
|
||||
|
||||
S3_BUCKET = 'rqlite-testing-circleci'
|
||||
S3_BUCKET_REGION = 'us-west-2'
|
||||
|
||||
RQLITED_PATH = os.environ['RQLITED_PATH']
|
||||
|
||||
class TestAutoBackupS3(unittest.TestCase):
    '''End-to-end tests of automatic backups to AWS S3.

    The compressed and uncompressed variants previously duplicated the
    entire setup/verify/teardown sequence; the shared flow now lives in
    _run_backup_test.
    '''

    def _run_backup_test(self, compress):
        '''Run one auto-backup cycle and verify the uploaded S3 object.

        Starts a node with auto-backup enabled, writes a row, waits for a
        backup to be uploaded, downloads the object from S3 (gunzipping it
        when compress is True), and checks the backed-up database contents.
        Node, temp files, and the S3 object are cleaned up before returning.
        '''
        access_key_id = os.environ['RQLITE_S3_ACCESS_KEY']
        secret_access_key_id = os.environ['RQLITE_S3_SECRET_ACCESS_KEY']

        # Create the auto-backup config file.
        path = random_string(32)
        auto_backup_cfg = {
            "version": 1,
            "type": "s3",
            "interval": "1s",
            "sub": {
                "access_key_id": access_key_id,
                "secret_access_key": secret_access_key_id,
                "region": S3_BUCKET_REGION,
                "bucket": S3_BUCKET,
                "path": path
            }
        }
        if not compress:
            auto_backup_cfg["no_compress"] = True
        cfg = write_random_file(json.dumps(auto_backup_cfg))

        # Create a node, enable automatic backups, and start it. Then
        # create a table and insert a row. Wait for a backup to happen.
        node = Node(RQLITED_PATH, '0', auto_backup=cfg)
        node.start()
        node.wait_for_leader()
        node.execute('CREATE TABLE foo (id INTEGER NOT NULL PRIMARY KEY, name TEXT)')
        node.execute('INSERT INTO foo(name) VALUES("fiona")')
        node.wait_for_all_fsm()
        time.sleep(5)

        # Download the backup file from S3 and check it.
        backup_data = download_s3_object(access_key_id, secret_access_key_id,
                                         S3_BUCKET, path)
        downloaded_file = write_random_file(backup_data, mode='wb')
        temp_files = [cfg, downloaded_file]
        if compress:
            backup_file = gunzip_file(downloaded_file)
            temp_files.append(backup_file)
        else:
            backup_file = downloaded_file

        conn = sqlite3.connect(backup_file)
        c = conn.cursor()
        c.execute('SELECT * FROM foo')
        rows = c.fetchall()
        self.assertEqual(len(rows), 1)
        self.assertEqual(rows[0][1], 'fiona')
        conn.close()

        deprovision_node(node)
        for f in temp_files:
            os.remove(f)
        delete_s3_object(access_key_id, secret_access_key_id,
                         S3_BUCKET, path)

    @unittest.skipUnless(env_present('RQLITE_S3_ACCESS_KEY'), "S3 credentials not available")
    def test_no_compress(self):
        '''Test that automatic backups to AWS S3 work with compression off'''
        self._run_backup_test(compress=False)

    @unittest.skipUnless(env_present('RQLITE_S3_ACCESS_KEY'), "S3 credentials not available")
    def test_compress(self):
        '''Test that automatic backups to AWS S3 work with compression on'''
        self._run_backup_test(compress=True)
||||
|
||||
|
||||
if __name__ == "__main__":
    # Verbose output lists each test as it runs.
    unittest.main(verbosity=2)
|
@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import boto3
|
||||
import os
|
||||
|
||||
def delete_s3_object(access_key_id, secret_access_key_id, bucket_name, object_key):
    """
    Delete an object from an S3 bucket.

    Args:
        access_key_id (str): AWS access key ID.
        secret_access_key_id (str): AWS secret access key.
        bucket_name (str): The name of the S3 bucket.
        object_key (str): The key of the object to delete.
    """
    # Pass credentials directly to the client rather than mutating
    # os.environ, which leaked the keys into global process state and could
    # race with other code reading/writing the same variables.
    s3_client = boto3.client('s3',
                             aws_access_key_id=access_key_id,
                             aws_secret_access_key=secret_access_key_id)
    s3_client.delete_object(Bucket=bucket_name, Key=object_key)
|
||||
|
||||
def download_s3_object(access_key_id, secret_access_key_id, bucket_name, object_key):
    """
    Download an object from an S3 bucket and return its contents as bytes.

    Args:
        access_key_id (str): AWS access key ID.
        secret_access_key_id (str): AWS secret access key.
        bucket_name (str): The name of the S3 bucket.
        object_key (str): The key of the object to download.
    """
    # Pass credentials directly to the client rather than mutating
    # os.environ (global side effect shared with the rest of the process).
    s3_client = boto3.client('s3',
                             aws_access_key_id=access_key_id,
                             aws_secret_access_key=secret_access_key_id)
    response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
    return response['Body'].read()
|
@ -0,0 +1,28 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"os"
|
||||
)
|
||||
|
||||
// AutoDeleteFile is a wrapper around os.File that deletes the file when it is
// closed.
type AutoDeleteFile struct {
	*os.File
}

// Close implements the io.Closer interface. The underlying file is closed
// first; if that succeeds the file is removed from the filesystem.
func (f *AutoDeleteFile) Close() error {
	err := f.File.Close()
	if err != nil {
		return err
	}
	return os.Remove(f.Name())
}

// NewAutoDeleteFile opens the file at path and wraps it in an AutoDeleteFile.
func NewAutoDeleteFile(path string) (*AutoDeleteFile, error) {
	file, err := os.Open(path)
	if err != nil {
		return nil, err
	}
	return &AutoDeleteFile{File: file}, nil
}
|
@ -0,0 +1,57 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func Test_NewAutoDeleteTempFile(t *testing.T) {
|
||||
adFile, err := NewAutoDeleteFile(mustCreateTempFilename())
|
||||
if err != nil {
|
||||
t.Fatalf("NewAutoDeleteFile() failed: %v", err)
|
||||
}
|
||||
defer adFile.Close()
|
||||
|
||||
if _, err := os.Stat(adFile.Name()); os.IsNotExist(err) {
|
||||
t.Fatalf("Expected file to exist: %s", adFile.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_AutoDeleteFile_Name(t *testing.T) {
|
||||
name := mustCreateTempFilename()
|
||||
adFile, err := NewAutoDeleteFile(name)
|
||||
if err != nil {
|
||||
t.Fatalf("NewAutoDeleteFile() failed: %v", err)
|
||||
}
|
||||
defer adFile.Close()
|
||||
|
||||
if adFile.Name() != name {
|
||||
t.Fatalf("Expected Name() to return %s, got %s", name, adFile.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func Test_AutoDeleteFile_Close(t *testing.T) {
|
||||
adFile, err := NewAutoDeleteFile(mustCreateTempFilename())
|
||||
if err != nil {
|
||||
t.Fatalf("NewAutoDeleteFile() failed: %v", err)
|
||||
}
|
||||
filename := adFile.Name()
|
||||
|
||||
err = adFile.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("Close() failed: %v", err)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filename); !os.IsNotExist(err) {
|
||||
t.Fatalf("Expected file to be deleted after Close(): %s", filename)
|
||||
}
|
||||
}
|
||||
|
||||
// mustCreateTempFilename creates an empty temporary file, closes it, and
// returns its name. It panics on failure, so it is suitable only for tests.
func mustCreateTempFilename() string {
	f, err := os.CreateTemp("", "autodeletefile_test")
	if err != nil {
		panic(err)
	}
	defer f.Close()
	return f.Name()
}
|
@ -0,0 +1,129 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
	// version is the max version of the config file format supported.
	version = 1
)

var (
	// ErrInvalidVersion is returned when the config file version is not supported.
	ErrInvalidVersion = errors.New("invalid version")

	// ErrUnsupportedStorageType is returned when the storage type is not supported.
	ErrUnsupportedStorageType = errors.New("unsupported storage type")
)
|
||||
|
||||
// Duration is a wrapper around time.Duration that supports JSON marshaling
// to a duration string, and unmarshaling from either a duration string
// (e.g. "24h") or a number of nanoseconds.
type Duration time.Duration

// MarshalJSON marshals the duration as a string such as "1h30m".
func (d Duration) MarshalJSON() ([]byte, error) {
	return json.Marshal(time.Duration(d).String())
}

// UnmarshalJSON unmarshals the duration from a string or a float64.
func (d *Duration) UnmarshalJSON(b []byte) error {
	var raw interface{}
	if err := json.Unmarshal(b, &raw); err != nil {
		return err
	}
	switch val := raw.(type) {
	case float64:
		// Plain JSON numbers are interpreted as nanoseconds.
		*d = Duration(time.Duration(val))
	case string:
		parsed, err := time.ParseDuration(val)
		if err != nil {
			return err
		}
		*d = Duration(parsed)
	default:
		return errors.New("invalid duration")
	}
	return nil
}
|
||||
|
||||
// StorageType is a wrapper around string that allows us to unmarshal
|
||||
type StorageType string
|
||||
|
||||
// UnmarshalJSON unmarshals the storage type from a string and validates it
|
||||
func (s *StorageType) UnmarshalJSON(b []byte) error {
|
||||
var v interface{}
|
||||
if err := json.Unmarshal(b, &v); err != nil {
|
||||
return err
|
||||
}
|
||||
switch value := v.(type) {
|
||||
case string:
|
||||
*s = StorageType(value)
|
||||
if *s != "s3" {
|
||||
return ErrUnsupportedStorageType
|
||||
}
|
||||
return nil
|
||||
default:
|
||||
return ErrUnsupportedStorageType
|
||||
}
|
||||
}
|
||||
|
||||
// Config is the config file format for the upload service.
type Config struct {
	Version    int             `json:"version"`               // format version; rejected by Unmarshal when > the package "version" const
	Type       StorageType     `json:"type"`                  // storage backend; only "s3" is accepted
	NoCompress bool            `json:"no_compress,omitempty"` // per field name, disables compression of uploads — confirm with uploader service
	Interval   Duration        `json:"interval"`              // time between uploads, e.g. "24h" or nanoseconds
	Sub        json.RawMessage `json:"sub"`                   // backend-specific subconfig, decoded separately by Unmarshal
}

// S3Config is the subconfig for the S3 storage type.
type S3Config struct {
	AccessKeyID     string `json:"access_key_id"`
	SecretAccessKey string `json:"secret_access_key"`
	Region          string `json:"region"`
	Bucket          string `json:"bucket"`
	Path            string `json:"path"`
}
|
||||
|
||||
// Unmarshal unmarshals the config file and returns the config and subconfig
|
||||
func Unmarshal(data []byte) (*Config, *S3Config, error) {
|
||||
cfg := &Config{}
|
||||
err := json.Unmarshal(data, cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if cfg.Version > version {
|
||||
return nil, nil, ErrInvalidVersion
|
||||
}
|
||||
|
||||
s3cfg := &S3Config{}
|
||||
err = json.Unmarshal(cfg.Sub, s3cfg)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return cfg, s3cfg, nil
|
||||
}
|
||||
|
||||
// ReadConfigFile reads the config file and returns the data. It also expands
// any environment variables in the config file.
func ReadConfigFile(filename string) ([]byte, error) {
	f, err := os.Open(filename)
	if err != nil {
		return nil, err
	}
	defer f.Close()

	raw, err := ioutil.ReadAll(f)
	if err != nil {
		return nil, err
	}

	// Substitute $VAR / ${VAR} references with their environment values.
	return []byte(os.ExpandEnv(string(raw))), nil
}
|
@ -0,0 +1,252 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Test_ReadConfigFile exercises ReadConfigFile: a plain read, the
// missing-file error path, and environment-variable expansion.
// NOTE(review): ioutil.TempFile is deprecated since Go 1.16 in favor of
// os.CreateTemp — left as-is to match this file's existing ioutil usage.
func Test_ReadConfigFile(t *testing.T) {
	t.Run("valid config file", func(t *testing.T) {
		// Create a temporary config file
		tempFile, err := ioutil.TempFile("", "upload_config")
		if err != nil {
			t.Fatal(err)
		}
		defer os.Remove(tempFile.Name())

		content := []byte("key=value")
		if _, err := tempFile.Write(content); err != nil {
			t.Fatal(err)
		}
		tempFile.Close()

		data, err := ReadConfigFile(tempFile.Name())
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		if !bytes.Equal(data, content) {
			t.Fatalf("Expected %v, got %v", content, data)
		}
	})

	t.Run("non-existent file", func(t *testing.T) {
		_, err := ReadConfigFile("nonexistentfile")
		if !errors.Is(err, os.ErrNotExist) {
			t.Fatalf("Expected os.ErrNotExist, got %v", err)
		}
	})

	t.Run("file with environment variables", func(t *testing.T) {
		// Set an environment variable
		if err := os.Setenv("TEST_VAR", "test_value"); err != nil {
			t.Fatal(err)
		}
		defer os.Unsetenv("TEST_VAR")

		// Create a temporary config file with an environment variable
		tempFile, err := ioutil.TempFile("", "upload_config")
		if err != nil {
			t.Fatal(err)
		}
		defer os.Remove(tempFile.Name())

		content := []byte("key=$TEST_VAR")
		if _, err := tempFile.Write(content); err != nil {
			t.Fatal(err)
		}
		tempFile.Close()

		data, err := ReadConfigFile(tempFile.Name())
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		expectedContent := []byte("key=test_value")
		if !bytes.Equal(data, expectedContent) {
			t.Fatalf("Expected %v, got %v", expectedContent, data)
		}
	})

	t.Run("longer file with environment variables", func(t *testing.T) {
		// Set an environment variable
		if err := os.Setenv("TEST_VAR1", "test_value"); err != nil {
			t.Fatal(err)
		}
		defer os.Unsetenv("TEST_VAR1")

		// Create a temporary config file with an environment variable
		tempFile, err := ioutil.TempFile("", "upload_config")
		if err != nil {
			t.Fatal(err)
		}
		defer os.Remove(tempFile.Name())

		// Only $TEST_VAR1 is expanded; TEST_VAR2 has no '$' and must pass
		// through unchanged.
		content := []byte(`
key1=$TEST_VAR1
key2=TEST_VAR2`)
		if _, err := tempFile.Write(content); err != nil {
			t.Fatal(err)
		}
		tempFile.Close()

		data, err := ReadConfigFile(tempFile.Name())
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		expectedContent := []byte(`
key1=test_value
key2=TEST_VAR2`)
		if !bytes.Equal(data, expectedContent) {
			t.Fatalf("Expected %v, got %v", expectedContent, data)
		}
	})
}
|
||||
|
||||
func TestUnmarshal(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expectedCfg *Config
|
||||
expectedS3 *S3Config
|
||||
expectedErr error
|
||||
}{
|
||||
{
|
||||
name: "ValidS3Config",
|
||||
input: []byte(`
|
||||
{
|
||||
"version": 1,
|
||||
"type": "s3",
|
||||
"no_compress": true,
|
||||
"interval": "24h",
|
||||
"sub": {
|
||||
"access_key_id": "test_id",
|
||||
"secret_access_key": "test_secret",
|
||||
"region": "us-west-2",
|
||||
"bucket": "test_bucket",
|
||||
"path": "test/path"
|
||||
}
|
||||
}
|
||||
`),
|
||||
expectedCfg: &Config{
|
||||
Version: 1,
|
||||
Type: "s3",
|
||||
NoCompress: true,
|
||||
Interval: 24 * Duration(time.Hour),
|
||||
},
|
||||
expectedS3: &S3Config{
|
||||
AccessKeyID: "test_id",
|
||||
SecretAccessKey: "test_secret",
|
||||
Region: "us-west-2",
|
||||
Bucket: "test_bucket",
|
||||
Path: "test/path",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "ValidS3ConfigNoptionalFields",
|
||||
input: []byte(`
|
||||
{
|
||||
"version": 1,
|
||||
"type": "s3",
|
||||
"interval": "24h",
|
||||
"sub": {
|
||||
"access_key_id": "test_id",
|
||||
"secret_access_key": "test_secret",
|
||||
"region": "us-west-2",
|
||||
"bucket": "test_bucket",
|
||||
"path": "test/path"
|
||||
}
|
||||
}
|
||||
`),
|
||||
expectedCfg: &Config{
|
||||
Version: 1,
|
||||
Type: "s3",
|
||||
NoCompress: false,
|
||||
Interval: 24 * Duration(time.Hour),
|
||||
},
|
||||
expectedS3: &S3Config{
|
||||
AccessKeyID: "test_id",
|
||||
SecretAccessKey: "test_secret",
|
||||
Region: "us-west-2",
|
||||
Bucket: "test_bucket",
|
||||
Path: "test/path",
|
||||
},
|
||||
expectedErr: nil,
|
||||
},
|
||||
{
|
||||
name: "InvalidVersion",
|
||||
input: []byte(`
|
||||
{
|
||||
"version": 2,
|
||||
"type": "s3",
|
||||
"no_compress": false,
|
||||
"interval": "24h",
|
||||
"sub": {
|
||||
"access_key_id": "test_id",
|
||||
"secret_access_key": "test_secret",
|
||||
"region": "us-west-2",
|
||||
"bucket": "test_bucket",
|
||||
"path": "test/path"
|
||||
}
|
||||
} `),
|
||||
expectedCfg: nil,
|
||||
expectedS3: nil,
|
||||
expectedErr: ErrInvalidVersion,
|
||||
},
|
||||
{
|
||||
name: "UnsupportedType",
|
||||
input: []byte(`
|
||||
{
|
||||
"version": 1,
|
||||
"type": "unsupported",
|
||||
"no_compress": true,
|
||||
"interval": "24h",
|
||||
"sub": {
|
||||
"access_key_id": "test_id",
|
||||
"secret_access_key": "test_secret",
|
||||
"region": "us-west-2",
|
||||
"bucket": "test_bucket",
|
||||
"path": "test/path"
|
||||
}
|
||||
} `),
|
||||
expectedCfg: nil,
|
||||
expectedS3: nil,
|
||||
expectedErr: ErrUnsupportedStorageType,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
cfg, s3Cfg, err := Unmarshal(tc.input)
|
||||
_ = s3Cfg
|
||||
|
||||
if !errors.Is(err, tc.expectedErr) {
|
||||
t.Fatalf("Test case %s failed, expected error %v, got %v", tc.name, tc.expectedErr, err)
|
||||
}
|
||||
|
||||
if !compareConfig(cfg, tc.expectedCfg) {
|
||||
t.Fatalf("Test case %s failed, expected config %+v, got %+v", tc.name, tc.expectedCfg, cfg)
|
||||
}
|
||||
|
||||
if tc.expectedS3 != nil {
|
||||
if !reflect.DeepEqual(s3Cfg, tc.expectedS3) {
|
||||
t.Fatalf("Test case %s failed, expected S3Config %+v, got %+v", tc.name, tc.expectedS3, s3Cfg)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func compareConfig(a, b *Config) bool {
|
||||
if a == nil || b == nil {
|
||||
return a == b
|
||||
}
|
||||
return a.Version == b.Version &&
|
||||
a.Type == b.Type &&
|
||||
a.NoCompress == b.NoCompress &&
|
||||
a.Interval == b.Interval
|
||||
}
|
@ -0,0 +1,163 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"expvar"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
)
|
||||
|
||||
// StorageClient is an interface for uploading data to a storage service.
// Implementations also provide String() (fmt.Stringer) so the upload
// destination can be logged.
type StorageClient interface {
	Upload(ctx context.Context, reader io.Reader) error
	fmt.Stringer
}

// DataProvider is an interface for providing data to be uploaded. The Uploader
// service will call Provide() to get a reader for the data to be uploaded. Once
// the upload completes the reader will be closed, regardless of whether the
// upload succeeded or failed.
type DataProvider interface {
	Provide() (io.ReadCloser, error)
}
|
||||
|
||||
// stats captures stats for the Uploader service.
var stats *expvar.Map

const (
	// Keys used in the expvar stats map.
	numUploadsOK     = "num_uploads_ok"
	numUploadsFail   = "num_uploads_fail"
	totalUploadBytes = "total_upload_bytes"
	lastUploadBytes  = "last_upload_bytes"

	// Readable aliases for the compress argument to NewUploader.
	UploadCompress   = true
	UploadNoCompress = false
)

func init() {
	// Register the expvar map once at package load, then zero the counters.
	stats = expvar.NewMap("uploader")
	ResetStats()
}
|
||||
|
||||
// ResetStats resets the expvar stats for this module. Mostly for test purposes.
|
||||
func ResetStats() {
|
||||
stats.Init()
|
||||
stats.Add(numUploadsOK, 0)
|
||||
stats.Add(numUploadsFail, 0)
|
||||
stats.Add(totalUploadBytes, 0)
|
||||
stats.Add(lastUploadBytes, 0)
|
||||
}
|
||||
|
||||
// Uploader is a service that periodically uploads data to a storage service.
type Uploader struct {
	storageClient StorageClient // destination for uploads
	dataProvider  DataProvider  // source of the data to upload
	interval      time.Duration // time between upload attempts
	compress      bool          // gzip-compress data before uploading

	logger *log.Logger
	lastUploadTime     time.Time     // set only after a successful upload
	lastUploadDuration time.Duration // duration of the last successful upload
}
|
||||
|
||||
// NewUploader creates a new Uploader service.
|
||||
func NewUploader(storageClient StorageClient, dataProvider DataProvider, interval time.Duration, compress bool) *Uploader {
|
||||
return &Uploader{
|
||||
storageClient: storageClient,
|
||||
dataProvider: dataProvider,
|
||||
interval: interval,
|
||||
compress: compress,
|
||||
logger: log.New(os.Stderr, "[uploader] ", log.LstdFlags),
|
||||
}
|
||||
}
|
||||
|
||||
// Start starts the Uploader service.
|
||||
func (u *Uploader) Start(ctx context.Context, isUploadEnabled func() bool) {
|
||||
if isUploadEnabled == nil {
|
||||
isUploadEnabled = func() bool { return true }
|
||||
}
|
||||
|
||||
u.logger.Printf("starting upload to %s every %s", u.storageClient, u.interval)
|
||||
ticker := time.NewTicker(u.interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
u.logger.Println("upload service shutting down")
|
||||
return
|
||||
case <-ticker.C:
|
||||
if !isUploadEnabled() {
|
||||
continue
|
||||
}
|
||||
if err := u.upload(ctx); err != nil {
|
||||
u.logger.Printf("failed to upload to %s: %v", u.storageClient, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stats returns the stats for the Uploader service.
|
||||
func (u *Uploader) Stats() (map[string]interface{}, error) {
|
||||
status := map[string]interface{}{
|
||||
"upload_destination": u.storageClient.String(),
|
||||
"upload_interval": u.interval.String(),
|
||||
"compress": u.compress,
|
||||
"last_upload_time": u.lastUploadTime.Format(time.RFC3339),
|
||||
"last_upload_duration": u.lastUploadDuration.String(),
|
||||
}
|
||||
return status, nil
|
||||
}
|
||||
|
||||
func (u *Uploader) upload(ctx context.Context) error {
|
||||
rc, err := u.dataProvider.Provide()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rc.Close()
|
||||
|
||||
r := rc.(io.Reader)
|
||||
if u.compress {
|
||||
buffer := new(bytes.Buffer)
|
||||
gw := gzip.NewWriter(buffer)
|
||||
_, err = io.Copy(gw, rc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = gw.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
r = buffer
|
||||
}
|
||||
|
||||
cr := &countingReader{reader: r}
|
||||
startTime := time.Now()
|
||||
err = u.storageClient.Upload(ctx, cr)
|
||||
if err != nil {
|
||||
stats.Add(numUploadsFail, 1)
|
||||
} else {
|
||||
stats.Add(numUploadsOK, 1)
|
||||
stats.Add(totalUploadBytes, cr.count)
|
||||
stats.Get(lastUploadBytes).(*expvar.Int).Set(cr.count)
|
||||
u.lastUploadTime = time.Now()
|
||||
u.lastUploadDuration = time.Since(startTime)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type countingReader struct {
|
||||
reader io.Reader
|
||||
count int64
|
||||
}
|
||||
|
||||
func (c *countingReader) Read(p []byte) (int, error) {
|
||||
n, err := c.reader.Read(p)
|
||||
c.count += int64(n)
|
||||
return n, err
|
||||
}
|
@ -0,0 +1,307 @@
|
||||
package upload
|
||||
|
||||
import (
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"expvar"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func Test_NewUploader(t *testing.T) {
|
||||
storageClient := &mockStorageClient{}
|
||||
dataProvider := &mockDataProvider{}
|
||||
interval := time.Second
|
||||
uploader := NewUploader(storageClient, dataProvider, interval, UploadNoCompress)
|
||||
if uploader.storageClient != storageClient {
|
||||
t.Errorf("expected storageClient to be %v, got %v", storageClient, uploader.storageClient)
|
||||
}
|
||||
if uploader.dataProvider != dataProvider {
|
||||
t.Errorf("expected dataProvider to be %v, got %v", dataProvider, uploader.dataProvider)
|
||||
}
|
||||
if uploader.interval != interval {
|
||||
t.Errorf("expected interval to be %v, got %v", interval, uploader.interval)
|
||||
}
|
||||
}
|
||||
|
||||
// Test_UploaderSingleUpload checks that uncompressed data from the provider
// arrives at the storage client intact.
func Test_UploaderSingleUpload(t *testing.T) {
	ResetStats()
	var uploadedData []byte
	var err error

	// Released once the storage client has received one upload.
	var wg sync.WaitGroup
	wg.Add(1)
	sc := &mockStorageClient{
		uploadFn: func(ctx context.Context, reader io.Reader) error {
			// Done fires after ReadAll completes, so uploadedData is fully
			// written before wg.Wait returns in the test goroutine.
			defer wg.Done()
			uploadedData, err = io.ReadAll(reader)
			return err
		},
	}
	dp := &mockDataProvider{data: "my upload data"}
	uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadNoCompress)
	ctx, cancel := context.WithCancel(context.Background())

	go uploader.Start(ctx, nil)
	wg.Wait()
	cancel()
	<-ctx.Done()

	// NOTE(review): if another tick fires before cancel takes effect, the
	// extra wg.Done would drive the counter negative and panic — timing
	// dependent; confirm this is acceptable flakiness risk.
	if exp, got := "my upload data", string(uploadedData); exp != got {
		t.Errorf("expected uploadedData to be %s, got %s", exp, got)
	}
}
|
||||
|
||||
// Test_UploaderSingleUploadCompress checks that when compression is enabled
// the uploaded stream is valid gzip that decompresses to the original data.
func Test_UploaderSingleUploadCompress(t *testing.T) {
	ResetStats()
	var uploadedData []byte

	// Released once the storage client has received one upload.
	var wg sync.WaitGroup
	wg.Add(1)
	sc := &mockStorageClient{
		uploadFn: func(ctx context.Context, reader io.Reader) error {
			defer wg.Done()

			// Wrap a gzip reader about the reader.
			gzReader, err := gzip.NewReader(reader)
			if err != nil {
				return err
			}
			defer gzReader.Close()

			// Decompress so the assertion below compares plain text.
			uploadedData, err = io.ReadAll(gzReader)
			return err
		},
	}
	dp := &mockDataProvider{data: "my upload data"}
	uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadCompress)
	ctx, cancel := context.WithCancel(context.Background())

	go uploader.Start(ctx, nil)
	wg.Wait()
	cancel()
	<-ctx.Done()

	if exp, got := "my upload data", string(uploadedData); exp != got {
		t.Errorf("expected uploadedData to be %s, got %s", exp, got)
	}
}
|
||||
|
||||
// Test_UploaderDoubleUpload checks that a second, periodic upload delivers
// the same data again.
func Test_UploaderDoubleUpload(t *testing.T) {
	ResetStats()

	var uploadedData []byte
	var err error

	// Wait for two complete uploads before asserting.
	var wg sync.WaitGroup
	wg.Add(2)
	sc := &mockStorageClient{
		uploadFn: func(ctx context.Context, reader io.Reader) error {
			defer wg.Done()
			uploadedData = nil // Wipe out any previous state.
			uploadedData, err = io.ReadAll(reader)
			return err
		},
	}
	dp := &mockDataProvider{data: "my upload data"}
	uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadNoCompress)
	ctx, cancel := context.WithCancel(context.Background())

	go uploader.Start(ctx, nil)
	wg.Wait()
	cancel()
	<-ctx.Done()

	// NOTE(review): a third tick before cancel lands would call wg.Done a
	// third time and panic the WaitGroup; also a concurrent third upload
	// could race the read of uploadedData below — timing dependent; confirm.
	if exp, got := "my upload data", string(uploadedData); exp != got {
		t.Errorf("expected uploadedData to be %s, got %s", exp, got)
	}
}
|
||||
|
||||
// Test_UploaderFailThenOK checks that the Uploader keeps trying after a
// failed upload: the first attempt errors, the second succeeds and delivers
// the data.
func Test_UploaderFailThenOK(t *testing.T) {
	ResetStats()

	var uploadedData []byte
	uploadCount := 0
	var err error

	// Wait for both attempts (the failure and the success).
	var wg sync.WaitGroup
	wg.Add(2)
	sc := &mockStorageClient{
		uploadFn: func(ctx context.Context, reader io.Reader) error {
			defer wg.Done()

			// Fail only the very first attempt.
			if uploadCount == 0 {
				uploadCount++
				return fmt.Errorf("failed to upload")
			}

			uploadedData, err = io.ReadAll(reader)
			return err
		},
	}
	dp := &mockDataProvider{data: "my upload data"}
	uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadNoCompress)
	ctx, cancel := context.WithCancel(context.Background())

	go uploader.Start(ctx, nil)
	wg.Wait()
	cancel()
	<-ctx.Done()

	if exp, got := "my upload data", string(uploadedData); exp != got {
		t.Errorf("expected uploadedData to be %s, got %s", exp, got)
	}
}
|
||||
|
||||
// Test_UploaderOKThenFail checks that data from a successful first upload is
// retained even though the following attempt fails.
//
// NOTE(review): unlike the sibling tests this one does not call ResetStats;
// presumably intentional since it asserts no stats — confirm.
func Test_UploaderOKThenFail(t *testing.T) {
	var uploadedData []byte
	uploadCount := 0
	var err error

	// Wait for both attempts (the success and the failure).
	var wg sync.WaitGroup
	wg.Add(2)
	sc := &mockStorageClient{
		uploadFn: func(ctx context.Context, reader io.Reader) error {
			defer wg.Done()

			// uploadCount is never incremented past 1, so every attempt
			// after the first fails.
			if uploadCount == 1 {
				return fmt.Errorf("failed to upload")
			}

			uploadCount++
			uploadedData, err = io.ReadAll(reader)
			return err
		},
	}
	dp := &mockDataProvider{data: "my upload data"}
	uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadNoCompress)
	ctx, cancel := context.WithCancel(context.Background())

	go uploader.Start(ctx, nil)
	wg.Wait()
	cancel()
	<-ctx.Done()

	// NOTE(review): a third tick before cancel lands would call wg.Done a
	// third time and panic the WaitGroup — timing dependent; confirm.
	if exp, got := "my upload data", string(uploadedData); exp != got {
		t.Errorf("expected uploadedData to be %s, got %s", exp, got)
	}
}
|
||||
|
||||
// Test_UploaderContextCancellation checks that canceling the context stops
// the service before any upload happens.
func Test_UploaderContextCancellation(t *testing.T) {
	var uploadCount int32

	sc := &mockStorageClient{
		uploadFn: func(ctx context.Context, reader io.Reader) error {
			atomic.AddInt32(&uploadCount, 1)
			return nil
		},
	}
	dp := &mockDataProvider{data: "my upload data"}
	// The interval (1s) far exceeds the context timeout (1ms), so the
	// service should shut down before the first tick ever fires.
	uploader := NewUploader(sc, dp, time.Second, UploadNoCompress)
	ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)

	go uploader.Start(ctx, nil)
	<-ctx.Done()
	cancel()
	<-ctx.Done()

	if exp, got := int32(0), atomic.LoadInt32(&uploadCount); exp != got {
		t.Errorf("expected uploadCount to be %d, got %d", exp, got)
	}
}
|
||||
|
||||
func Test_UploaderEnabledFalse(t *testing.T) {
|
||||
ResetStats()
|
||||
|
||||
sc := &mockStorageClient{}
|
||||
dp := &mockDataProvider{data: "my upload data"}
|
||||
uploader := NewUploader(sc, dp, 100*time.Millisecond, false)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
go uploader.Start(ctx, func() bool { return false })
|
||||
time.Sleep(time.Second)
|
||||
defer cancel()
|
||||
|
||||
if exp, got := int64(0), stats.Get(numUploadsOK).(*expvar.Int); exp != got.Value() {
|
||||
t.Errorf("expected numUploadsOK to be %d, got %d", exp, got)
|
||||
}
|
||||
}
|
||||
|
||||
// Test_UploaderEnabledTrue checks that uploads proceed normally when the
// enabled callback returns true.
func Test_UploaderEnabledTrue(t *testing.T) {
	var uploadedData []byte
	var err error
	ResetStats()

	// Released once the storage client has received one upload.
	var wg sync.WaitGroup
	wg.Add(1)
	sc := &mockStorageClient{
		uploadFn: func(ctx context.Context, reader io.Reader) error {
			defer wg.Done()
			uploadedData, err = io.ReadAll(reader)
			return err
		},
	}
	dp := &mockDataProvider{data: "my upload data"}
	uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadNoCompress)
	ctx, cancel := context.WithCancel(context.Background())

	go uploader.Start(ctx, func() bool { return true })
	defer cancel()

	wg.Wait()
	if exp, got := string(uploadedData), "my upload data"; exp != got {
		t.Errorf("expected uploadedData to be %s, got %s", exp, got)
	}
}
|
||||
|
||||
func Test_UploaderStats(t *testing.T) {
|
||||
sc := &mockStorageClient{}
|
||||
dp := &mockDataProvider{data: "my upload data"}
|
||||
interval := 100 * time.Millisecond
|
||||
uploader := NewUploader(sc, dp, interval, UploadNoCompress)
|
||||
|
||||
stats, err := uploader.Stats()
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
|
||||
if exp, got := sc.String(), stats["upload_destination"]; exp != got {
|
||||
t.Errorf("expected upload_destination to be %s, got %s", exp, got)
|
||||
}
|
||||
|
||||
if exp, got := interval.String(), stats["upload_interval"]; exp != got {
|
||||
t.Errorf("expected upload_interval to be %s, got %s", exp, got)
|
||||
}
|
||||
}
|
||||
|
||||
type mockStorageClient struct {
|
||||
uploadFn func(ctx context.Context, reader io.Reader) error
|
||||
}
|
||||
|
||||
func (mc *mockStorageClient) Upload(ctx context.Context, reader io.Reader) error {
|
||||
if mc.uploadFn != nil {
|
||||
return mc.uploadFn(ctx, reader)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mc *mockStorageClient) String() string {
|
||||
return "mockStorageClient"
|
||||
}
|
||||
|
||||
type mockDataProvider struct {
|
||||
data string
|
||||
err error
|
||||
}
|
||||
|
||||
func (mp *mockDataProvider) Provide() (io.ReadCloser, error) {
|
||||
if mp.err != nil {
|
||||
return nil, mp.err
|
||||
}
|
||||
return io.NopCloser(strings.NewReader(mp.data)), nil
|
||||
}
|
Loading…
Reference in New Issue