commit
f1d3a1e893
@ -1,2 +0,0 @@
|
|||||||
// Package aws provides functionality for accessing the AWS API.
|
|
||||||
package aws
|
|
@ -1,50 +0,0 @@
|
|||||||
package aws
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"io/ioutil"
|
|
||||||
"net/http"
|
|
||||||
)
|
|
||||||
|
|
||||||
// MetadataClient is a client for fetching AWS EC2 instance metadata.
|
|
||||||
type MetadataClient struct {
|
|
||||||
client *http.Client
|
|
||||||
URL string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewMetadataClient returns an instance of a MetadataClient
|
|
||||||
func NewMetadataClient() *MetadataClient {
|
|
||||||
return &MetadataClient{
|
|
||||||
client: &http.Client{},
|
|
||||||
URL: `http://169.254.169.254/`,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// LocalIPv4 returns the private IPv4 address of the instance.
|
|
||||||
func (m *MetadataClient) LocalIPv4() (string, error) {
|
|
||||||
return m.get("/latest/meta-data/local-ipv4")
|
|
||||||
}
|
|
||||||
|
|
||||||
// PublicIPv4 returns the public IPv4 address of the instance.
|
|
||||||
func (m *MetadataClient) PublicIPv4() (string, error) {
|
|
||||||
return m.get("/latest/meta-data/public-ipv4")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *MetadataClient) get(path string) (string, error) {
|
|
||||||
resp, err := m.client.Get(m.URL + path)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
defer resp.Body.Close()
|
|
||||||
|
|
||||||
b, err := ioutil.ReadAll(resp.Body)
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
|
|
||||||
if resp.StatusCode != http.StatusOK {
|
|
||||||
return "", fmt.Errorf("failed to request %s, got: %s", path, resp.Status)
|
|
||||||
}
|
|
||||||
return string(b), nil
|
|
||||||
}
|
|
@ -1,89 +0,0 @@
|
|||||||
package aws
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"net/http"
|
|
||||||
"net/http/httptest"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_NewMetadataClient(t *testing.T) {
|
|
||||||
c := NewMetadataClient()
|
|
||||||
if c == nil {
|
|
||||||
t.Fatalf("failed to create new Metadata client")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_MetadataClient_LocalIPv4(t *testing.T) {
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Method != "GET" {
|
|
||||||
t.Fatalf("Client did not use GET")
|
|
||||||
}
|
|
||||||
if r.URL.String() != "/latest/meta-data/local-ipv4" {
|
|
||||||
t.Fatalf("Request URL is wrong, got: %s", r.URL.String())
|
|
||||||
}
|
|
||||||
fmt.Fprint(w, "172.31.34.179")
|
|
||||||
}))
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
c := NewMetadataClient()
|
|
||||||
c.URL = ts.URL
|
|
||||||
addr, err := c.LocalIPv4()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to get local IPv4 address: %s", err.Error())
|
|
||||||
}
|
|
||||||
if addr != "172.31.34.179" {
|
|
||||||
t.Fatalf("got incorrect local IPv4 address: %s", addr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_MetadataClient_LocalIPv4Fail(t *testing.T) {
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
}))
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
c := NewMetadataClient()
|
|
||||||
c.URL = ts.URL
|
|
||||||
_, err := c.LocalIPv4()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("failed to get error when server returned 400")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_MetadataClient_PublicIPv4(t *testing.T) {
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
if r.Method != "GET" {
|
|
||||||
t.Fatalf("Client did not use GET")
|
|
||||||
}
|
|
||||||
if r.URL.String() != "/latest/meta-data/public-ipv4" {
|
|
||||||
t.Fatalf("Request URL is wrong, got: %s", r.URL.String())
|
|
||||||
}
|
|
||||||
fmt.Fprint(w, "52.38.41.98")
|
|
||||||
}))
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
c := NewMetadataClient()
|
|
||||||
c.URL = ts.URL
|
|
||||||
addr, err := c.PublicIPv4()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to get local IPv4 address: %s", err.Error())
|
|
||||||
}
|
|
||||||
if addr != "52.38.41.98" {
|
|
||||||
t.Fatalf("got incorrect local IPv4 address: %s", addr)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_MetadataClient_PublicIPv4Fail(t *testing.T) {
|
|
||||||
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
||||||
w.WriteHeader(http.StatusBadRequest)
|
|
||||||
}))
|
|
||||||
defer ts.Close()
|
|
||||||
|
|
||||||
c := NewMetadataClient()
|
|
||||||
c.URL = ts.URL
|
|
||||||
_, err := c.PublicIPv4()
|
|
||||||
if err == nil {
|
|
||||||
t.Fatalf("failed to get error when server returned 400")
|
|
||||||
}
|
|
||||||
}
|
|
@ -0,0 +1,106 @@
|
|||||||
|
package aws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/credentials"
|
||||||
|
"github.com/aws/aws-sdk-go/aws/session"
|
||||||
|
"github.com/aws/aws-sdk-go/service/s3"
|
||||||
|
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
type uploader interface {
|
||||||
|
UploadWithContext(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
type downloader interface {
|
||||||
|
DownloadWithContext(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, opts ...func(*s3manager.Downloader)) (n int64, err error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// S3Client is a client for uploading data to S3.
|
||||||
|
type S3Client struct {
|
||||||
|
region string
|
||||||
|
accessKey string
|
||||||
|
secretKey string
|
||||||
|
bucket string
|
||||||
|
key string
|
||||||
|
|
||||||
|
uploader uploader // for testing via dependency injection
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewS3Client returns an instance of an S3Client.
|
||||||
|
func NewS3Client(region, accessKey, secretKey, bucket, key string) *S3Client {
|
||||||
|
return &S3Client{
|
||||||
|
region: region,
|
||||||
|
accessKey: accessKey,
|
||||||
|
secretKey: secretKey,
|
||||||
|
bucket: bucket,
|
||||||
|
key: key,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns a string representation of the S3Client.
|
||||||
|
func (s *S3Client) String() string {
|
||||||
|
return fmt.Sprintf("s3://%s/%s", s.bucket, s.key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Upload uploads data to S3.
|
||||||
|
func (s *S3Client) Upload(ctx context.Context, reader io.Reader) error {
|
||||||
|
sess, err := s.createSession()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// If an uploader was not provided, use a real S3 uploader.
|
||||||
|
var uploader uploader
|
||||||
|
if s.uploader == nil {
|
||||||
|
uploader = s3manager.NewUploader(sess)
|
||||||
|
} else {
|
||||||
|
uploader = s.uploader
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = uploader.UploadWithContext(ctx, &s3manager.UploadInput{
|
||||||
|
Bucket: aws.String(s.bucket),
|
||||||
|
Key: aws.String(s.key),
|
||||||
|
Body: reader,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to upload to %v: %w", s, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Download downloads data from S3.
|
||||||
|
func (s *S3Client) Download(ctx context.Context, writer io.WriterAt) error {
|
||||||
|
sess, err := s.createSession()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
downloader := s3manager.NewDownloader(sess)
|
||||||
|
|
||||||
|
_, err = downloader.DownloadWithContext(ctx, writer, &s3.GetObjectInput{
|
||||||
|
Bucket: aws.String(s.bucket),
|
||||||
|
Key: aws.String(s.key),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to download from %v: %w", s, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *S3Client) createSession() (*session.Session, error) {
|
||||||
|
sess, err := session.NewSession(&aws.Config{
|
||||||
|
Region: aws.String(s.region),
|
||||||
|
Credentials: credentials.NewStaticCredentials(s.accessKey, s.secretKey, ""),
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create S3 session: %w", err)
|
||||||
|
}
|
||||||
|
return sess, nil
|
||||||
|
}
|
@ -0,0 +1,118 @@
|
|||||||
|
package aws
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/aws/aws-sdk-go/aws"
|
||||||
|
"github.com/aws/aws-sdk-go/service/s3/s3manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_NewS3Client(t *testing.T) {
|
||||||
|
c := NewS3Client("region1", "access", "secret", "bucket2", "key3")
|
||||||
|
if c.region != "region1" {
|
||||||
|
t.Fatalf("expected region to be %q, got %q", "region1", c.region)
|
||||||
|
}
|
||||||
|
if c.accessKey != "access" {
|
||||||
|
t.Fatalf("expected accessKey to be %q, got %q", "access", c.accessKey)
|
||||||
|
}
|
||||||
|
if c.secretKey != "secret" {
|
||||||
|
t.Fatalf("expected secretKey to be %q, got %q", "secret", c.secretKey)
|
||||||
|
}
|
||||||
|
if c.bucket != "bucket2" {
|
||||||
|
t.Fatalf("expected bucket to be %q, got %q", "bucket2", c.bucket)
|
||||||
|
}
|
||||||
|
if c.key != "key3" {
|
||||||
|
t.Fatalf("expected key to be %q, got %q", "key3", c.key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_S3Client_String(t *testing.T) {
|
||||||
|
c := NewS3Client("region1", "access", "secret", "bucket2", "key3")
|
||||||
|
if c.String() != "s3://bucket2/key3" {
|
||||||
|
t.Fatalf("expected String() to be %q, got %q", "s3://bucket2/key3", c.String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestS3ClientUploadOK(t *testing.T) {
|
||||||
|
region := "us-west-2"
|
||||||
|
accessKey := "your-access-key"
|
||||||
|
secretKey := "your-secret-key"
|
||||||
|
bucket := "your-bucket"
|
||||||
|
key := "your/key/path"
|
||||||
|
|
||||||
|
mockUploader := &mockUploader{
|
||||||
|
uploadFn: func(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) {
|
||||||
|
if *input.Bucket != bucket {
|
||||||
|
t.Errorf("expected bucket to be %q, got %q", bucket, *input.Bucket)
|
||||||
|
}
|
||||||
|
if *input.Key != key {
|
||||||
|
t.Errorf("expected key to be %q, got %q", key, *input.Key)
|
||||||
|
}
|
||||||
|
if input.Body == nil {
|
||||||
|
t.Errorf("expected body to be non-nil")
|
||||||
|
}
|
||||||
|
return &s3manager.UploadOutput{}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &S3Client{
|
||||||
|
region: region,
|
||||||
|
accessKey: accessKey,
|
||||||
|
secretKey: secretKey,
|
||||||
|
bucket: bucket,
|
||||||
|
key: key,
|
||||||
|
uploader: mockUploader,
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := strings.NewReader("test data")
|
||||||
|
err := client.Upload(context.Background(), reader)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestS3ClientUploadFail(t *testing.T) {
|
||||||
|
region := "us-west-2"
|
||||||
|
accessKey := "your-access-key"
|
||||||
|
secretKey := "your-secret-key"
|
||||||
|
bucket := "your-bucket"
|
||||||
|
key := "your/key/path"
|
||||||
|
|
||||||
|
mockUploader := &mockUploader{
|
||||||
|
uploadFn: func(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) {
|
||||||
|
return &s3manager.UploadOutput{}, fmt.Errorf("some error related to S3")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
client := &S3Client{
|
||||||
|
region: region,
|
||||||
|
accessKey: accessKey,
|
||||||
|
secretKey: secretKey,
|
||||||
|
bucket: bucket,
|
||||||
|
key: key,
|
||||||
|
uploader: mockUploader,
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := strings.NewReader("test data")
|
||||||
|
err := client.Upload(context.Background(), reader)
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal("Expected error, got nil")
|
||||||
|
}
|
||||||
|
if !strings.Contains(err.Error(), "some error related to S3") {
|
||||||
|
t.Fatalf("Expected error to contain %q, got %q", "some error related to S3", err.Error())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockUploader struct {
|
||||||
|
uploadFn func(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockUploader) UploadWithContext(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) {
|
||||||
|
if m.uploadFn != nil {
|
||||||
|
return m.uploadFn(ctx, input, opts...)
|
||||||
|
}
|
||||||
|
return &s3manager.UploadOutput{}, nil
|
||||||
|
}
|
@ -0,0 +1,140 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
#
|
||||||
|
# End-to-end testing using actual rqlited binary.
|
||||||
|
#
|
||||||
|
# To run a specific test, execute
|
||||||
|
#
|
||||||
|
# python system_test/full_system_test.py Class.test
|
||||||
|
|
||||||
|
import os
|
||||||
|
import json
|
||||||
|
import unittest
|
||||||
|
import sqlite3
|
||||||
|
import time
|
||||||
|
|
||||||
|
from helpers import Node, deprovision_node, write_random_file, random_string, env_present, gunzip_file
|
||||||
|
from s3 import download_s3_object, delete_s3_object
|
||||||
|
|
||||||
|
S3_BUCKET = 'rqlite-testing-circleci'
|
||||||
|
S3_BUCKET_REGION = 'us-west-2'
|
||||||
|
|
||||||
|
RQLITED_PATH = os.environ['RQLITED_PATH']
|
||||||
|
|
||||||
|
class TestAutoBackupS3(unittest.TestCase):
|
||||||
|
@unittest.skipUnless(env_present('RQLITE_S3_ACCESS_KEY'), "S3 credentials not available")
|
||||||
|
def test_no_compress(self):
|
||||||
|
'''Test that automatic backups to AWS S3 work with compression off'''
|
||||||
|
node = None
|
||||||
|
cfg = None
|
||||||
|
path = None
|
||||||
|
backup_file = None
|
||||||
|
|
||||||
|
access_key_id = os.environ['RQLITE_S3_ACCESS_KEY']
|
||||||
|
secret_access_key_id = os.environ['RQLITE_S3_SECRET_ACCESS_KEY']
|
||||||
|
|
||||||
|
# Create the auto-backup config file
|
||||||
|
path = random_string(32)
|
||||||
|
auto_backup_cfg = {
|
||||||
|
"version": 1,
|
||||||
|
"type": "s3",
|
||||||
|
"interval": "1s",
|
||||||
|
"no_compress": True,
|
||||||
|
"sub" : {
|
||||||
|
"access_key_id": access_key_id,
|
||||||
|
"secret_access_key": secret_access_key_id,
|
||||||
|
"region": S3_BUCKET_REGION,
|
||||||
|
"bucket": S3_BUCKET,
|
||||||
|
"path": path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cfg = write_random_file(json.dumps(auto_backup_cfg))
|
||||||
|
|
||||||
|
# Create a node, enable automatic backups, and start it. Then
|
||||||
|
# create a table and insert a row. Wait for a backup to happen.
|
||||||
|
node = Node(RQLITED_PATH, '0', auto_backup=cfg)
|
||||||
|
node.start()
|
||||||
|
node.wait_for_leader()
|
||||||
|
node.execute('CREATE TABLE foo (id INTEGER NOT NULL PRIMARY KEY, name TEXT)')
|
||||||
|
node.execute('INSERT INTO foo(name) VALUES("fiona")')
|
||||||
|
node.wait_for_all_fsm()
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
# Download the backup file from S3 and check it.
|
||||||
|
backup_data = download_s3_object(access_key_id, secret_access_key_id,
|
||||||
|
S3_BUCKET, path)
|
||||||
|
backup_file = write_random_file(backup_data, mode='wb')
|
||||||
|
conn = sqlite3.connect(backup_file)
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute('SELECT * FROM foo')
|
||||||
|
rows = c.fetchall()
|
||||||
|
self.assertEqual(len(rows), 1)
|
||||||
|
self.assertEqual(rows[0][1], 'fiona')
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
deprovision_node(node)
|
||||||
|
os.remove(cfg)
|
||||||
|
os.remove(backup_file)
|
||||||
|
delete_s3_object(access_key_id, secret_access_key_id,
|
||||||
|
S3_BUCKET, path)
|
||||||
|
|
||||||
|
@unittest.skipUnless(env_present('RQLITE_S3_ACCESS_KEY'), "S3 credentials not available")
|
||||||
|
def test_compress(self):
|
||||||
|
'''Test that automatic backups to AWS S3 work with compression on'''
|
||||||
|
node = None
|
||||||
|
cfg = None
|
||||||
|
path = None
|
||||||
|
compressed_backup_file = None
|
||||||
|
backup_file = None
|
||||||
|
|
||||||
|
access_key_id = os.environ['RQLITE_S3_ACCESS_KEY']
|
||||||
|
secret_access_key_id = os.environ['RQLITE_S3_SECRET_ACCESS_KEY']
|
||||||
|
|
||||||
|
# Create the auto-backup config file
|
||||||
|
path = random_string(32)
|
||||||
|
auto_backup_cfg = {
|
||||||
|
"version": 1,
|
||||||
|
"type": "s3",
|
||||||
|
"interval": "1s",
|
||||||
|
"sub" : {
|
||||||
|
"access_key_id": access_key_id,
|
||||||
|
"secret_access_key": secret_access_key_id,
|
||||||
|
"region": S3_BUCKET_REGION,
|
||||||
|
"bucket": S3_BUCKET,
|
||||||
|
"path": path
|
||||||
|
}
|
||||||
|
}
|
||||||
|
cfg = write_random_file(json.dumps(auto_backup_cfg))
|
||||||
|
|
||||||
|
# Create a node, enable automatic backups, and start it. Then
|
||||||
|
# create a table and insert a row. Wait for a backup to happen.
|
||||||
|
node = Node(RQLITED_PATH, '0', auto_backup=cfg)
|
||||||
|
node.start()
|
||||||
|
node.wait_for_leader()
|
||||||
|
node.execute('CREATE TABLE foo (id INTEGER NOT NULL PRIMARY KEY, name TEXT)')
|
||||||
|
node.execute('INSERT INTO foo(name) VALUES("fiona")')
|
||||||
|
node.wait_for_all_fsm()
|
||||||
|
time.sleep(5)
|
||||||
|
|
||||||
|
# Download the backup file from S3 and check it.
|
||||||
|
backup_data = download_s3_object(access_key_id, secret_access_key_id,
|
||||||
|
S3_BUCKET, path)
|
||||||
|
compressed_backup_file = write_random_file(backup_data, mode='wb')
|
||||||
|
backup_file = gunzip_file(compressed_backup_file)
|
||||||
|
conn = sqlite3.connect(backup_file)
|
||||||
|
c = conn.cursor()
|
||||||
|
c.execute('SELECT * FROM foo')
|
||||||
|
rows = c.fetchall()
|
||||||
|
self.assertEqual(len(rows), 1)
|
||||||
|
self.assertEqual(rows[0][1], 'fiona')
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
deprovision_node(node)
|
||||||
|
os.remove(cfg)
|
||||||
|
os.remove(compressed_backup_file)
|
||||||
|
os.remove(backup_file)
|
||||||
|
delete_s3_object(access_key_id, secret_access_key_id,
|
||||||
|
S3_BUCKET, path)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main(verbosity=2)
|
@ -0,0 +1,29 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
import os
|
||||||
|
|
||||||
|
def delete_s3_object(access_key_id, secret_access_key_id, bucket_name, object_key):
|
||||||
|
"""
|
||||||
|
Delete an object from an S3 bucket.
|
||||||
|
"""
|
||||||
|
os.environ['AWS_ACCESS_KEY_ID'] = access_key_id
|
||||||
|
os.environ['AWS_SECRET_ACCESS_KEY'] = secret_access_key_id
|
||||||
|
|
||||||
|
s3_client = boto3.client('s3')
|
||||||
|
s3_client.delete_object(Bucket=bucket_name, Key=object_key)
|
||||||
|
|
||||||
|
def download_s3_object(access_key_id, secret_access_key_id, bucket_name, object_key):
|
||||||
|
"""
|
||||||
|
Download an object from an S3 bucket.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
bucket_name (str): The name of the S3 bucket.
|
||||||
|
object_key (str): The key of the object to download.
|
||||||
|
"""
|
||||||
|
os.environ['AWS_ACCESS_KEY_ID'] = access_key_id
|
||||||
|
os.environ['AWS_SECRET_ACCESS_KEY'] = secret_access_key_id
|
||||||
|
|
||||||
|
s3_client = boto3.client('s3')
|
||||||
|
response = s3_client.get_object(Bucket=bucket_name, Key=object_key)
|
||||||
|
return response['Body'].read()
|
@ -0,0 +1,28 @@
|
|||||||
|
package upload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AutoDeleteFile is a wrapper around os.File that deletes the file when it is
|
||||||
|
// closed.
|
||||||
|
type AutoDeleteFile struct {
|
||||||
|
*os.File
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close implements the io.Closer interface
|
||||||
|
func (f *AutoDeleteFile) Close() error {
|
||||||
|
if err := f.File.Close(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return os.Remove(f.Name())
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAutoDeleteFile takes a filename and wraps it in an AutoDeleteFile
|
||||||
|
func NewAutoDeleteFile(path string) (*AutoDeleteFile, error) {
|
||||||
|
f, err := os.Open(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return &AutoDeleteFile{f}, nil
|
||||||
|
}
|
@ -0,0 +1,57 @@
|
|||||||
|
package upload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_NewAutoDeleteTempFile(t *testing.T) {
|
||||||
|
adFile, err := NewAutoDeleteFile(mustCreateTempFilename())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewAutoDeleteFile() failed: %v", err)
|
||||||
|
}
|
||||||
|
defer adFile.Close()
|
||||||
|
|
||||||
|
if _, err := os.Stat(adFile.Name()); os.IsNotExist(err) {
|
||||||
|
t.Fatalf("Expected file to exist: %s", adFile.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_AutoDeleteFile_Name(t *testing.T) {
|
||||||
|
name := mustCreateTempFilename()
|
||||||
|
adFile, err := NewAutoDeleteFile(name)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewAutoDeleteFile() failed: %v", err)
|
||||||
|
}
|
||||||
|
defer adFile.Close()
|
||||||
|
|
||||||
|
if adFile.Name() != name {
|
||||||
|
t.Fatalf("Expected Name() to return %s, got %s", name, adFile.Name())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_AutoDeleteFile_Close(t *testing.T) {
|
||||||
|
adFile, err := NewAutoDeleteFile(mustCreateTempFilename())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewAutoDeleteFile() failed: %v", err)
|
||||||
|
}
|
||||||
|
filename := adFile.Name()
|
||||||
|
|
||||||
|
err = adFile.Close()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Close() failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := os.Stat(filename); !os.IsNotExist(err) {
|
||||||
|
t.Fatalf("Expected file to be deleted after Close(): %s", filename)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustCreateTempFilename() string {
|
||||||
|
f, err := os.CreateTemp("", "autodeletefile_test")
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
f.Close()
|
||||||
|
return f.Name()
|
||||||
|
}
|
@ -0,0 +1,129 @@
|
|||||||
|
package upload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"errors"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// version is the max version of the config file format supported
|
||||||
|
version = 1
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrInvalidVersion is returned when the config file version is not supported.
|
||||||
|
ErrInvalidVersion = errors.New("invalid version")
|
||||||
|
|
||||||
|
// ErrUnsupportedStorageType is returned when the storage type is not supported.
|
||||||
|
ErrUnsupportedStorageType = errors.New("unsupported storage type")
|
||||||
|
)
|
||||||
|
|
||||||
|
// Duration is a wrapper around time.Duration that allows us to unmarshal
|
||||||
|
type Duration time.Duration
|
||||||
|
|
||||||
|
// MarshalJSON marshals the duration as a string
|
||||||
|
func (d Duration) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(time.Duration(d).String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON unmarshals the duration from a string or a float64
|
||||||
|
func (d *Duration) UnmarshalJSON(b []byte) error {
|
||||||
|
var v interface{}
|
||||||
|
if err := json.Unmarshal(b, &v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
switch value := v.(type) {
|
||||||
|
case float64:
|
||||||
|
*d = Duration(time.Duration(value))
|
||||||
|
return nil
|
||||||
|
case string:
|
||||||
|
tmp, err := time.ParseDuration(value)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
*d = Duration(tmp)
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return errors.New("invalid duration")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// StorageType is a wrapper around string that allows us to unmarshal
|
||||||
|
type StorageType string
|
||||||
|
|
||||||
|
// UnmarshalJSON unmarshals the storage type from a string and validates it
|
||||||
|
func (s *StorageType) UnmarshalJSON(b []byte) error {
|
||||||
|
var v interface{}
|
||||||
|
if err := json.Unmarshal(b, &v); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
switch value := v.(type) {
|
||||||
|
case string:
|
||||||
|
*s = StorageType(value)
|
||||||
|
if *s != "s3" {
|
||||||
|
return ErrUnsupportedStorageType
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
default:
|
||||||
|
return ErrUnsupportedStorageType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config is the config file format for the upload service
|
||||||
|
type Config struct {
|
||||||
|
Version int `json:"version"`
|
||||||
|
Type StorageType `json:"type"`
|
||||||
|
NoCompress bool `json:"no_compress,omitempty"`
|
||||||
|
Interval Duration `json:"interval"`
|
||||||
|
Sub json.RawMessage `json:"sub"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// S3Config is the subconfig for the S3 storage type
|
||||||
|
type S3Config struct {
|
||||||
|
AccessKeyID string `json:"access_key_id"`
|
||||||
|
SecretAccessKey string `json:"secret_access_key"`
|
||||||
|
Region string `json:"region"`
|
||||||
|
Bucket string `json:"bucket"`
|
||||||
|
Path string `json:"path"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// Unmarshal unmarshals the config file and returns the config and subconfig
|
||||||
|
func Unmarshal(data []byte) (*Config, *S3Config, error) {
|
||||||
|
cfg := &Config{}
|
||||||
|
err := json.Unmarshal(data, cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg.Version > version {
|
||||||
|
return nil, nil, ErrInvalidVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
s3cfg := &S3Config{}
|
||||||
|
err = json.Unmarshal(cfg.Sub, s3cfg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
return cfg, s3cfg, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ReadConfigFile reads the config file and returns the data. It also expands
|
||||||
|
// any environment variables in the config file.
|
||||||
|
func ReadConfigFile(filename string) ([]byte, error) {
|
||||||
|
f, err := os.Open(filename)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
defer f.Close()
|
||||||
|
|
||||||
|
data, err := ioutil.ReadAll(f)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
data = []byte(os.ExpandEnv(string(data)))
|
||||||
|
return data, nil
|
||||||
|
}
|
@ -0,0 +1,252 @@
|
|||||||
|
package upload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"io/ioutil"
|
||||||
|
"os"
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_ReadConfigFile(t *testing.T) {
|
||||||
|
t.Run("valid config file", func(t *testing.T) {
|
||||||
|
// Create a temporary config file
|
||||||
|
tempFile, err := ioutil.TempFile("", "upload_config")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer os.Remove(tempFile.Name())
|
||||||
|
|
||||||
|
content := []byte("key=value")
|
||||||
|
if _, err := tempFile.Write(content); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tempFile.Close()
|
||||||
|
|
||||||
|
data, err := ReadConfigFile(tempFile.Name())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
if !bytes.Equal(data, content) {
|
||||||
|
t.Fatalf("Expected %v, got %v", content, data)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("non-existent file", func(t *testing.T) {
|
||||||
|
_, err := ReadConfigFile("nonexistentfile")
|
||||||
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
|
t.Fatalf("Expected os.ErrNotExist, got %v", err)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("file with environment variables", func(t *testing.T) {
|
||||||
|
// Set an environment variable
|
||||||
|
if err := os.Setenv("TEST_VAR", "test_value"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer os.Unsetenv("TEST_VAR")
|
||||||
|
|
||||||
|
// Create a temporary config file with an environment variable
|
||||||
|
tempFile, err := ioutil.TempFile("", "upload_config")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer os.Remove(tempFile.Name())
|
||||||
|
|
||||||
|
content := []byte("key=$TEST_VAR")
|
||||||
|
if _, err := tempFile.Write(content); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tempFile.Close()
|
||||||
|
|
||||||
|
data, err := ReadConfigFile(tempFile.Name())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
expectedContent := []byte("key=test_value")
|
||||||
|
if !bytes.Equal(data, expectedContent) {
|
||||||
|
t.Fatalf("Expected %v, got %v", expectedContent, data)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("longer file with environment variables", func(t *testing.T) {
|
||||||
|
// Set an environment variable
|
||||||
|
if err := os.Setenv("TEST_VAR1", "test_value"); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer os.Unsetenv("TEST_VAR1")
|
||||||
|
|
||||||
|
// Create a temporary config file with an environment variable
|
||||||
|
tempFile, err := ioutil.TempFile("", "upload_config")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
defer os.Remove(tempFile.Name())
|
||||||
|
|
||||||
|
content := []byte(`
|
||||||
|
key1=$TEST_VAR1
|
||||||
|
key2=TEST_VAR2`)
|
||||||
|
if _, err := tempFile.Write(content); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
tempFile.Close()
|
||||||
|
|
||||||
|
data, err := ReadConfigFile(tempFile.Name())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Expected no error, got %v", err)
|
||||||
|
}
|
||||||
|
expectedContent := []byte(`
|
||||||
|
key1=test_value
|
||||||
|
key2=TEST_VAR2`)
|
||||||
|
if !bytes.Equal(data, expectedContent) {
|
||||||
|
t.Fatalf("Expected %v, got %v", expectedContent, data)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUnmarshal(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
input []byte
|
||||||
|
expectedCfg *Config
|
||||||
|
expectedS3 *S3Config
|
||||||
|
expectedErr error
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "ValidS3Config",
|
||||||
|
input: []byte(`
|
||||||
|
{
|
||||||
|
"version": 1,
|
||||||
|
"type": "s3",
|
||||||
|
"no_compress": true,
|
||||||
|
"interval": "24h",
|
||||||
|
"sub": {
|
||||||
|
"access_key_id": "test_id",
|
||||||
|
"secret_access_key": "test_secret",
|
||||||
|
"region": "us-west-2",
|
||||||
|
"bucket": "test_bucket",
|
||||||
|
"path": "test/path"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`),
|
||||||
|
expectedCfg: &Config{
|
||||||
|
Version: 1,
|
||||||
|
Type: "s3",
|
||||||
|
NoCompress: true,
|
||||||
|
Interval: 24 * Duration(time.Hour),
|
||||||
|
},
|
||||||
|
expectedS3: &S3Config{
|
||||||
|
AccessKeyID: "test_id",
|
||||||
|
SecretAccessKey: "test_secret",
|
||||||
|
Region: "us-west-2",
|
||||||
|
Bucket: "test_bucket",
|
||||||
|
Path: "test/path",
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ValidS3ConfigNoptionalFields",
|
||||||
|
input: []byte(`
|
||||||
|
{
|
||||||
|
"version": 1,
|
||||||
|
"type": "s3",
|
||||||
|
"interval": "24h",
|
||||||
|
"sub": {
|
||||||
|
"access_key_id": "test_id",
|
||||||
|
"secret_access_key": "test_secret",
|
||||||
|
"region": "us-west-2",
|
||||||
|
"bucket": "test_bucket",
|
||||||
|
"path": "test/path"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`),
|
||||||
|
expectedCfg: &Config{
|
||||||
|
Version: 1,
|
||||||
|
Type: "s3",
|
||||||
|
NoCompress: false,
|
||||||
|
Interval: 24 * Duration(time.Hour),
|
||||||
|
},
|
||||||
|
expectedS3: &S3Config{
|
||||||
|
AccessKeyID: "test_id",
|
||||||
|
SecretAccessKey: "test_secret",
|
||||||
|
Region: "us-west-2",
|
||||||
|
Bucket: "test_bucket",
|
||||||
|
Path: "test/path",
|
||||||
|
},
|
||||||
|
expectedErr: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "InvalidVersion",
|
||||||
|
input: []byte(`
|
||||||
|
{
|
||||||
|
"version": 2,
|
||||||
|
"type": "s3",
|
||||||
|
"no_compress": false,
|
||||||
|
"interval": "24h",
|
||||||
|
"sub": {
|
||||||
|
"access_key_id": "test_id",
|
||||||
|
"secret_access_key": "test_secret",
|
||||||
|
"region": "us-west-2",
|
||||||
|
"bucket": "test_bucket",
|
||||||
|
"path": "test/path"
|
||||||
|
}
|
||||||
|
} `),
|
||||||
|
expectedCfg: nil,
|
||||||
|
expectedS3: nil,
|
||||||
|
expectedErr: ErrInvalidVersion,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UnsupportedType",
|
||||||
|
input: []byte(`
|
||||||
|
{
|
||||||
|
"version": 1,
|
||||||
|
"type": "unsupported",
|
||||||
|
"no_compress": true,
|
||||||
|
"interval": "24h",
|
||||||
|
"sub": {
|
||||||
|
"access_key_id": "test_id",
|
||||||
|
"secret_access_key": "test_secret",
|
||||||
|
"region": "us-west-2",
|
||||||
|
"bucket": "test_bucket",
|
||||||
|
"path": "test/path"
|
||||||
|
}
|
||||||
|
} `),
|
||||||
|
expectedCfg: nil,
|
||||||
|
expectedS3: nil,
|
||||||
|
expectedErr: ErrUnsupportedStorageType,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
cfg, s3Cfg, err := Unmarshal(tc.input)
|
||||||
|
_ = s3Cfg
|
||||||
|
|
||||||
|
if !errors.Is(err, tc.expectedErr) {
|
||||||
|
t.Fatalf("Test case %s failed, expected error %v, got %v", tc.name, tc.expectedErr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !compareConfig(cfg, tc.expectedCfg) {
|
||||||
|
t.Fatalf("Test case %s failed, expected config %+v, got %+v", tc.name, tc.expectedCfg, cfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.expectedS3 != nil {
|
||||||
|
if !reflect.DeepEqual(s3Cfg, tc.expectedS3) {
|
||||||
|
t.Fatalf("Test case %s failed, expected S3Config %+v, got %+v", tc.name, tc.expectedS3, s3Cfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareConfig(a, b *Config) bool {
|
||||||
|
if a == nil || b == nil {
|
||||||
|
return a == b
|
||||||
|
}
|
||||||
|
return a.Version == b.Version &&
|
||||||
|
a.Type == b.Type &&
|
||||||
|
a.NoCompress == b.NoCompress &&
|
||||||
|
a.Interval == b.Interval
|
||||||
|
}
|
@ -0,0 +1,163 @@
|
|||||||
|
package upload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"compress/gzip"
|
||||||
|
"context"
|
||||||
|
"expvar"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
// StorageClient is an interface for uploading data to a storage service.
|
||||||
|
type StorageClient interface {
|
||||||
|
Upload(ctx context.Context, reader io.Reader) error
|
||||||
|
fmt.Stringer
|
||||||
|
}
|
||||||
|
|
||||||
|
// DataProvider is an interface for providing data to be uploaded. The Uploader
|
||||||
|
// service will call Provide() to get a reader for the data to be uploaded. Once
|
||||||
|
// the upload completes the reader will be closed, regardless of whether the
|
||||||
|
// upload succeeded or failed.
|
||||||
|
type DataProvider interface {
|
||||||
|
Provide() (io.ReadCloser, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stats captures stats for the Uploader service.
|
||||||
|
var stats *expvar.Map
|
||||||
|
|
||||||
|
const (
|
||||||
|
numUploadsOK = "num_uploads_ok"
|
||||||
|
numUploadsFail = "num_uploads_fail"
|
||||||
|
totalUploadBytes = "total_upload_bytes"
|
||||||
|
lastUploadBytes = "last_upload_bytes"
|
||||||
|
|
||||||
|
UploadCompress = true
|
||||||
|
UploadNoCompress = false
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
stats = expvar.NewMap("uploader")
|
||||||
|
ResetStats()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResetStats resets the expvar stats for this module. Mostly for test purposes.
|
||||||
|
func ResetStats() {
|
||||||
|
stats.Init()
|
||||||
|
stats.Add(numUploadsOK, 0)
|
||||||
|
stats.Add(numUploadsFail, 0)
|
||||||
|
stats.Add(totalUploadBytes, 0)
|
||||||
|
stats.Add(lastUploadBytes, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Uploader is a service that periodically uploads data to a storage service.
|
||||||
|
type Uploader struct {
|
||||||
|
storageClient StorageClient
|
||||||
|
dataProvider DataProvider
|
||||||
|
interval time.Duration
|
||||||
|
compress bool
|
||||||
|
|
||||||
|
logger *log.Logger
|
||||||
|
lastUploadTime time.Time
|
||||||
|
lastUploadDuration time.Duration
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUploader creates a new Uploader service.
|
||||||
|
func NewUploader(storageClient StorageClient, dataProvider DataProvider, interval time.Duration, compress bool) *Uploader {
|
||||||
|
return &Uploader{
|
||||||
|
storageClient: storageClient,
|
||||||
|
dataProvider: dataProvider,
|
||||||
|
interval: interval,
|
||||||
|
compress: compress,
|
||||||
|
logger: log.New(os.Stderr, "[uploader] ", log.LstdFlags),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start starts the Uploader service.
|
||||||
|
func (u *Uploader) Start(ctx context.Context, isUploadEnabled func() bool) {
|
||||||
|
if isUploadEnabled == nil {
|
||||||
|
isUploadEnabled = func() bool { return true }
|
||||||
|
}
|
||||||
|
|
||||||
|
u.logger.Printf("starting upload to %s every %s", u.storageClient, u.interval)
|
||||||
|
ticker := time.NewTicker(u.interval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
u.logger.Println("upload service shutting down")
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
if !isUploadEnabled() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := u.upload(ctx); err != nil {
|
||||||
|
u.logger.Printf("failed to upload to %s: %v", u.storageClient, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stats returns the stats for the Uploader service.
|
||||||
|
func (u *Uploader) Stats() (map[string]interface{}, error) {
|
||||||
|
status := map[string]interface{}{
|
||||||
|
"upload_destination": u.storageClient.String(),
|
||||||
|
"upload_interval": u.interval.String(),
|
||||||
|
"compress": u.compress,
|
||||||
|
"last_upload_time": u.lastUploadTime.Format(time.RFC3339),
|
||||||
|
"last_upload_duration": u.lastUploadDuration.String(),
|
||||||
|
}
|
||||||
|
return status, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *Uploader) upload(ctx context.Context) error {
|
||||||
|
rc, err := u.dataProvider.Provide()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer rc.Close()
|
||||||
|
|
||||||
|
r := rc.(io.Reader)
|
||||||
|
if u.compress {
|
||||||
|
buffer := new(bytes.Buffer)
|
||||||
|
gw := gzip.NewWriter(buffer)
|
||||||
|
_, err = io.Copy(gw, rc)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
err = gw.Close()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
r = buffer
|
||||||
|
}
|
||||||
|
|
||||||
|
cr := &countingReader{reader: r}
|
||||||
|
startTime := time.Now()
|
||||||
|
err = u.storageClient.Upload(ctx, cr)
|
||||||
|
if err != nil {
|
||||||
|
stats.Add(numUploadsFail, 1)
|
||||||
|
} else {
|
||||||
|
stats.Add(numUploadsOK, 1)
|
||||||
|
stats.Add(totalUploadBytes, cr.count)
|
||||||
|
stats.Get(lastUploadBytes).(*expvar.Int).Set(cr.count)
|
||||||
|
u.lastUploadTime = time.Now()
|
||||||
|
u.lastUploadDuration = time.Since(startTime)
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
type countingReader struct {
|
||||||
|
reader io.Reader
|
||||||
|
count int64
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *countingReader) Read(p []byte) (int, error) {
|
||||||
|
n, err := c.reader.Read(p)
|
||||||
|
c.count += int64(n)
|
||||||
|
return n, err
|
||||||
|
}
|
@ -0,0 +1,307 @@
|
|||||||
|
package upload
|
||||||
|
|
||||||
|
import (
|
||||||
|
"compress/gzip"
|
||||||
|
"context"
|
||||||
|
"expvar"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_NewUploader(t *testing.T) {
|
||||||
|
storageClient := &mockStorageClient{}
|
||||||
|
dataProvider := &mockDataProvider{}
|
||||||
|
interval := time.Second
|
||||||
|
uploader := NewUploader(storageClient, dataProvider, interval, UploadNoCompress)
|
||||||
|
if uploader.storageClient != storageClient {
|
||||||
|
t.Errorf("expected storageClient to be %v, got %v", storageClient, uploader.storageClient)
|
||||||
|
}
|
||||||
|
if uploader.dataProvider != dataProvider {
|
||||||
|
t.Errorf("expected dataProvider to be %v, got %v", dataProvider, uploader.dataProvider)
|
||||||
|
}
|
||||||
|
if uploader.interval != interval {
|
||||||
|
t.Errorf("expected interval to be %v, got %v", interval, uploader.interval)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_UploaderSingleUpload(t *testing.T) {
|
||||||
|
ResetStats()
|
||||||
|
var uploadedData []byte
|
||||||
|
var err error
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
sc := &mockStorageClient{
|
||||||
|
uploadFn: func(ctx context.Context, reader io.Reader) error {
|
||||||
|
defer wg.Done()
|
||||||
|
uploadedData, err = io.ReadAll(reader)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dp := &mockDataProvider{data: "my upload data"}
|
||||||
|
uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadNoCompress)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go uploader.Start(ctx, nil)
|
||||||
|
wg.Wait()
|
||||||
|
cancel()
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if exp, got := "my upload data", string(uploadedData); exp != got {
|
||||||
|
t.Errorf("expected uploadedData to be %s, got %s", exp, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_UploaderSingleUploadCompress(t *testing.T) {
|
||||||
|
ResetStats()
|
||||||
|
var uploadedData []byte
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
sc := &mockStorageClient{
|
||||||
|
uploadFn: func(ctx context.Context, reader io.Reader) error {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
// Wrap a gzip reader about the reader.
|
||||||
|
gzReader, err := gzip.NewReader(reader)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer gzReader.Close()
|
||||||
|
|
||||||
|
uploadedData, err = io.ReadAll(gzReader)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dp := &mockDataProvider{data: "my upload data"}
|
||||||
|
uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadCompress)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go uploader.Start(ctx, nil)
|
||||||
|
wg.Wait()
|
||||||
|
cancel()
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if exp, got := "my upload data", string(uploadedData); exp != got {
|
||||||
|
t.Errorf("expected uploadedData to be %s, got %s", exp, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_UploaderDoubleUpload(t *testing.T) {
|
||||||
|
ResetStats()
|
||||||
|
|
||||||
|
var uploadedData []byte
|
||||||
|
var err error
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
sc := &mockStorageClient{
|
||||||
|
uploadFn: func(ctx context.Context, reader io.Reader) error {
|
||||||
|
defer wg.Done()
|
||||||
|
uploadedData = nil // Wipe out any previous state.
|
||||||
|
uploadedData, err = io.ReadAll(reader)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dp := &mockDataProvider{data: "my upload data"}
|
||||||
|
uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadNoCompress)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go uploader.Start(ctx, nil)
|
||||||
|
wg.Wait()
|
||||||
|
cancel()
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if exp, got := "my upload data", string(uploadedData); exp != got {
|
||||||
|
t.Errorf("expected uploadedData to be %s, got %s", exp, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_UploaderFailThenOK(t *testing.T) {
|
||||||
|
ResetStats()
|
||||||
|
|
||||||
|
var uploadedData []byte
|
||||||
|
uploadCount := 0
|
||||||
|
var err error
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
sc := &mockStorageClient{
|
||||||
|
uploadFn: func(ctx context.Context, reader io.Reader) error {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
if uploadCount == 0 {
|
||||||
|
uploadCount++
|
||||||
|
return fmt.Errorf("failed to upload")
|
||||||
|
}
|
||||||
|
|
||||||
|
uploadedData, err = io.ReadAll(reader)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dp := &mockDataProvider{data: "my upload data"}
|
||||||
|
uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadNoCompress)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go uploader.Start(ctx, nil)
|
||||||
|
wg.Wait()
|
||||||
|
cancel()
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if exp, got := "my upload data", string(uploadedData); exp != got {
|
||||||
|
t.Errorf("expected uploadedData to be %s, got %s", exp, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_UploaderOKThenFail(t *testing.T) {
|
||||||
|
var uploadedData []byte
|
||||||
|
uploadCount := 0
|
||||||
|
var err error
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
sc := &mockStorageClient{
|
||||||
|
uploadFn: func(ctx context.Context, reader io.Reader) error {
|
||||||
|
defer wg.Done()
|
||||||
|
|
||||||
|
if uploadCount == 1 {
|
||||||
|
return fmt.Errorf("failed to upload")
|
||||||
|
}
|
||||||
|
|
||||||
|
uploadCount++
|
||||||
|
uploadedData, err = io.ReadAll(reader)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dp := &mockDataProvider{data: "my upload data"}
|
||||||
|
uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadNoCompress)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go uploader.Start(ctx, nil)
|
||||||
|
wg.Wait()
|
||||||
|
cancel()
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if exp, got := "my upload data", string(uploadedData); exp != got {
|
||||||
|
t.Errorf("expected uploadedData to be %s, got %s", exp, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_UploaderContextCancellation(t *testing.T) {
|
||||||
|
var uploadCount int32
|
||||||
|
|
||||||
|
sc := &mockStorageClient{
|
||||||
|
uploadFn: func(ctx context.Context, reader io.Reader) error {
|
||||||
|
atomic.AddInt32(&uploadCount, 1)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dp := &mockDataProvider{data: "my upload data"}
|
||||||
|
uploader := NewUploader(sc, dp, time.Second, UploadNoCompress)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||||
|
|
||||||
|
go uploader.Start(ctx, nil)
|
||||||
|
<-ctx.Done()
|
||||||
|
cancel()
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
if exp, got := int32(0), atomic.LoadInt32(&uploadCount); exp != got {
|
||||||
|
t.Errorf("expected uploadCount to be %d, got %d", exp, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_UploaderEnabledFalse(t *testing.T) {
|
||||||
|
ResetStats()
|
||||||
|
|
||||||
|
sc := &mockStorageClient{}
|
||||||
|
dp := &mockDataProvider{data: "my upload data"}
|
||||||
|
uploader := NewUploader(sc, dp, 100*time.Millisecond, false)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go uploader.Start(ctx, func() bool { return false })
|
||||||
|
time.Sleep(time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
if exp, got := int64(0), stats.Get(numUploadsOK).(*expvar.Int); exp != got.Value() {
|
||||||
|
t.Errorf("expected numUploadsOK to be %d, got %d", exp, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_UploaderEnabledTrue(t *testing.T) {
|
||||||
|
var uploadedData []byte
|
||||||
|
var err error
|
||||||
|
ResetStats()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(1)
|
||||||
|
sc := &mockStorageClient{
|
||||||
|
uploadFn: func(ctx context.Context, reader io.Reader) error {
|
||||||
|
defer wg.Done()
|
||||||
|
uploadedData, err = io.ReadAll(reader)
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dp := &mockDataProvider{data: "my upload data"}
|
||||||
|
uploader := NewUploader(sc, dp, 100*time.Millisecond, UploadNoCompress)
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
go uploader.Start(ctx, func() bool { return true })
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
if exp, got := string(uploadedData), "my upload data"; exp != got {
|
||||||
|
t.Errorf("expected uploadedData to be %s, got %s", exp, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func Test_UploaderStats(t *testing.T) {
|
||||||
|
sc := &mockStorageClient{}
|
||||||
|
dp := &mockDataProvider{data: "my upload data"}
|
||||||
|
interval := 100 * time.Millisecond
|
||||||
|
uploader := NewUploader(sc, dp, interval, UploadNoCompress)
|
||||||
|
|
||||||
|
stats, err := uploader.Stats()
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if exp, got := sc.String(), stats["upload_destination"]; exp != got {
|
||||||
|
t.Errorf("expected upload_destination to be %s, got %s", exp, got)
|
||||||
|
}
|
||||||
|
|
||||||
|
if exp, got := interval.String(), stats["upload_interval"]; exp != got {
|
||||||
|
t.Errorf("expected upload_interval to be %s, got %s", exp, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockStorageClient struct {
|
||||||
|
uploadFn func(ctx context.Context, reader io.Reader) error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mockStorageClient) Upload(ctx context.Context, reader io.Reader) error {
|
||||||
|
if mc.uploadFn != nil {
|
||||||
|
return mc.uploadFn(ctx, reader)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mc *mockStorageClient) String() string {
|
||||||
|
return "mockStorageClient"
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockDataProvider struct {
|
||||||
|
data string
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (mp *mockDataProvider) Provide() (io.ReadCloser, error) {
|
||||||
|
if mp.err != nil {
|
||||||
|
return nil, mp.err
|
||||||
|
}
|
||||||
|
return io.NopCloser(strings.NewReader(mp.data)), nil
|
||||||
|
}
|
Loading…
Reference in New Issue