1
0
Fork 0
You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

263 lines
8.0 KiB
Go

package aws
import (
"bytes"
"context"
"fmt"
"io"
"net/http"
"strings"
"testing"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
)
func Test_NewS3Client(t *testing.T) {
c, err := NewS3Client("endpoint1", "region1", "access", "secret", "bucket2", "key3", true)
if err != nil {
t.Fatalf("error while creating aws S3 client: %v", err)
}
if c.region != "region1" {
t.Fatalf("expected region to be %q, got %q", "region1", c.region)
}
if c.accessKey != "access" {
t.Fatalf("expected accessKey to be %q, got %q", "access", c.accessKey)
}
if c.secretKey != "secret" {
t.Fatalf("expected secretKey to be %q, got %q", "secret", c.secretKey)
}
if c.bucket != "bucket2" {
t.Fatalf("expected bucket to be %q, got %q", "bucket2", c.bucket)
}
if c.key != "key3" {
t.Fatalf("expected key to be %q, got %q", "key3", c.key)
}
if c.forcePathStyle != true {
t.Fatalf("expected forcePathStyle to be %v, got %v", true, c.forcePathStyle)
}
}
func Test_S3Client_String(t *testing.T) {
// Test native S3 with implicit endpoint
c, err := NewS3Client("", "region1", "access", "secret", "bucket2", "key3", false)
if err != nil {
t.Fatalf("error while creating aws S3 client: %v", err)
}
if c.String() != "s3://bucket2/key3" {
t.Fatalf("expected String() to be %q, got %q", "s3://bucket2/key3", c.String())
}
// Test native S3 with explicit endpoint
c, err = NewS3Client("s3.amazonaws.com", "region1", "access", "secret", "bucket2", "key3", false)
if err != nil {
t.Fatalf("error while creating aws S3 client: %v", err)
}
if c.String() != "s3://bucket2/key3" {
t.Fatalf("expected String() to be %q, got %q", "s3://bucket2/key3", c.String())
}
// Test non-native S3 (explicit endpoint) with non-path style (e.g. Wasabi)
c, err = NewS3Client("s3.ca-central-1.wasabisys.com", "region1", "access", "secret", "bucket2", "key3", false)
if err != nil {
t.Fatalf("error while creating aws S3 client: %v", err)
}
if c.String() != "s3://bucket2.s3.ca-central-1.wasabisys.com/key3" {
t.Fatalf("expected String() to be %q, got %q", "s3://bucket2.s3.ca-central-1.wasabisys.com/key3", c.String())
}
// Test non-native S3 (explicit endpoint) with forced path style (e.g. MinIO)
c, err = NewS3Client("s3.minio.example.com", "region1", "access", "secret", "bucket2", "key3", true)
if err != nil {
t.Fatalf("error while creating aws S3 client: %v", err)
}
if c.String() != "s3://s3.minio.example.com/bucket2/key3" {
t.Fatalf("expected String() to be %q, got %q", "s3://s3.minio.example.com/bucket2/key3", c.String())
}
}
func TestS3ClientUploadOK(t *testing.T) {
endpoint := "https://my-custom-s3-endpoint.com"
region := "us-west-2"
accessKey := "your-access-key"
secretKey := "your-secret-key"
bucket := "your-bucket"
key := "your/key/path"
expectedData := "test data"
uploadedData := new(bytes.Buffer)
mockUploader := &mockUploader{
uploadFn: func(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) {
if *input.Bucket != bucket {
t.Errorf("expected bucket to be %q, got %q", bucket, *input.Bucket)
}
if *input.Key != key {
t.Errorf("expected key to be %q, got %q", key, *input.Key)
}
if input.Body == nil {
t.Errorf("expected body to be non-nil")
}
_, err := uploadedData.ReadFrom(input.Body)
if err != nil {
t.Errorf("error reading from input body: %v", err)
}
if input.Metadata == nil {
t.Errorf("expected metadata to be non-nil")
}
exp, got := "some-id", *input.Metadata[http.CanonicalHeaderKey(AWSS3IDKey)]
if exp != got {
t.Errorf("expected metadata to contain %q, got %q", exp, got)
}
return &s3manager.UploadOutput{}, nil
},
}
client := &S3Client{
endpoint: endpoint,
region: region,
accessKey: accessKey,
secretKey: secretKey,
bucket: bucket,
key: key,
uploader: mockUploader,
}
reader := strings.NewReader("test data")
err := client.Upload(context.Background(), reader, "some-id")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if uploadedData.String() != expectedData {
t.Errorf("expected uploaded data to be %q, got %q", expectedData, uploadedData.String())
}
}
func TestS3ClientUploadFail(t *testing.T) {
region := "us-west-2"
accessKey := "your-access-key"
secretKey := "your-secret-key"
bucket := "your-bucket"
key := "your/key/path"
mockUploader := &mockUploader{
uploadFn: func(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) {
return &s3manager.UploadOutput{}, fmt.Errorf("some error related to S3")
},
}
client := &S3Client{
region: region,
accessKey: accessKey,
secretKey: secretKey,
bucket: bucket,
key: key,
uploader: mockUploader,
}
reader := strings.NewReader("test data")
err := client.Upload(context.Background(), reader, "")
if err == nil {
t.Fatal("Expected error, got nil")
}
if !strings.Contains(err.Error(), "some error related to S3") {
t.Fatalf("Expected error to contain %q, got %q", "some error related to S3", err.Error())
}
}
func TestS3ClientDownloadOK(t *testing.T) {
region := "us-west-2"
accessKey := "your-access-key"
secretKey := "your-secret-key"
bucket := "your-bucket"
key := "your/key/path"
expectedData := "test data"
mockDownloader := &mockDownloader{
downloadFn: func(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, opts ...func(*s3manager.Downloader)) (int64, error) {
if *input.Bucket != bucket {
t.Errorf("expected bucket to be %q, got %q", bucket, *input.Bucket)
}
if *input.Key != key {
t.Errorf("expected key to be %q, got %q", key, *input.Key)
}
n, err := w.WriteAt([]byte(expectedData), 0)
if err != nil {
t.Errorf("error writing to writer: %v", err)
}
return int64(n), nil
},
}
client := &S3Client{
region: region,
accessKey: accessKey,
secretKey: secretKey,
bucket: bucket,
key: key,
downloader: mockDownloader,
}
writer := aws.NewWriteAtBuffer(make([]byte, len(expectedData)))
err := client.Download(context.Background(), writer)
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if string(writer.Bytes()) != expectedData {
t.Errorf("expected downloaded data to be %q, got %q", expectedData, writer.Bytes())
}
}
func TestS3ClientDownloadFail(t *testing.T) {
endpoint := "https://my-custom-s3-endpoint.com"
region := "us-west-2"
accessKey := "your-access-key"
secretKey := "your-secret-key"
bucket := "your-bucket"
key := "your/key/path"
mockDownloader := &mockDownloader{
downloadFn: func(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, opts ...func(*s3manager.Downloader)) (n int64, err error) {
return 0, fmt.Errorf("some error related to S3")
},
}
client := &S3Client{
endpoint: endpoint,
region: region,
accessKey: accessKey,
secretKey: secretKey,
bucket: bucket,
key: key,
downloader: mockDownloader,
}
writer := aws.NewWriteAtBuffer(nil)
err := client.Download(context.Background(), writer)
if err == nil {
t.Fatal("Expected error, got nil")
}
if !strings.Contains(err.Error(), "some error related to S3") {
t.Fatalf("Expected error to contain %q, got %q", "some error related to S3", err.Error())
}
}
type mockDownloader struct {
downloadFn func(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, opts ...func(*s3manager.Downloader)) (n int64, err error)
}
func (m *mockDownloader) DownloadWithContext(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, opts ...func(*s3manager.Downloader)) (n int64, err error) {
if m.downloadFn != nil {
return m.downloadFn(ctx, w, input, opts...)
}
return 0, nil
}
type mockUploader struct {
uploadFn func(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
}
func (m *mockUploader) UploadWithContext(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) {
if m.uploadFn != nil {
return m.uploadFn(ctx, input, opts...)
}
return &s3manager.UploadOutput{}, nil
}