1
0
Fork 0

Complete AWS S3 downloader

master
Philip O'Toole 1 year ago
parent 3a52259744
commit bf14cefb03

@ -28,7 +28,8 @@ type S3Client struct {
bucket string
key string
uploader uploader // for testing via dependency injection
uploader uploader // for testing via dependency injection
downloader downloader // for testing via dependency injection
}
// NewS3Client returns an instance of an S3Client.
@ -81,7 +82,13 @@ func (s *S3Client) Download(ctx context.Context, writer io.WriterAt) error {
return err
}
downloader := s3manager.NewDownloader(sess)
// If a downloader was not provided, use a real S3 downloader.
var downloader downloader
if s.downloader == nil {
downloader = s3manager.NewDownloader(sess)
} else {
downloader = s.downloader
}
_, err = downloader.DownloadWithContext(ctx, writer, &s3.GetObjectInput{
Bucket: aws.String(s.bucket),

@ -4,10 +4,12 @@ import (
"bytes"
"context"
"fmt"
"io"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
)
@ -116,6 +118,92 @@ func TestS3ClientUploadFail(t *testing.T) {
}
}
func TestS3ClientDownloadOK(t *testing.T) {
region := "us-west-2"
accessKey := "your-access-key"
secretKey := "your-secret-key"
bucket := "your-bucket"
key := "your/key/path"
expectedData := "test data"
mockDownloader := &mockDownloader{
downloadFn: func(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, opts ...func(*s3manager.Downloader)) (int64, error) {
if *input.Bucket != bucket {
t.Errorf("expected bucket to be %q, got %q", bucket, *input.Bucket)
}
if *input.Key != key {
t.Errorf("expected key to be %q, got %q", key, *input.Key)
}
n, err := w.WriteAt([]byte(expectedData), 0)
if err != nil {
t.Errorf("error writing to writer: %v", err)
}
return int64(n), nil
},
}
client := &S3Client{
region: region,
accessKey: accessKey,
secretKey: secretKey,
bucket: bucket,
key: key,
downloader: mockDownloader,
}
writer := aws.NewWriteAtBuffer(make([]byte, len(expectedData)))
err := client.Download(context.Background(), writer)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if string(writer.Bytes()) != expectedData {
t.Errorf("expected downloaded data to be %q, got %q", expectedData, writer.Bytes())
}
}
func TestS3ClientDownloadFail(t *testing.T) {
region := "us-west-2"
accessKey := "your-access-key"
secretKey := "your-secret-key"
bucket := "your-bucket"
key := "your/key/path"
mockDownloader := &mockDownloader{
downloadFn: func(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, opts ...func(*s3manager.Downloader)) (n int64, err error) {
return 0, fmt.Errorf("some error related to S3")
},
}
client := &S3Client{
region: region,
accessKey: accessKey,
secretKey: secretKey,
bucket: bucket,
key: key,
downloader: mockDownloader,
}
writer := aws.NewWriteAtBuffer(nil)
err := client.Download(context.Background(), writer)
if err == nil {
t.Fatal("Expected error, got nil")
}
if !strings.Contains(err.Error(), "some error related to S3") {
t.Fatalf("Expected error to contain %q, got %q", "some error related to S3", err.Error())
}
}
type mockDownloader struct {
downloadFn func(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, opts ...func(*s3manager.Downloader)) (n int64, err error)
}
func (m *mockDownloader) DownloadWithContext(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, opts ...func(*s3manager.Downloader)) (n int64, err error) {
if m.downloadFn != nil {
return m.downloadFn(ctx, w, input, opts...)
}
return 0, nil
}
type mockUploader struct {
uploadFn func(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
}

Loading…
Cancel
Save