1
0
Fork 0
You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

161 lines
4.4 KiB
Go

package aws
import (
"context"
"fmt"
"io"
"net/http"
"strings"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/aws/aws-sdk-go/service/s3/s3manager"
)
var (
AWSS3IDKey = http.CanonicalHeaderKey("x-rqlite-auto-backup-id")
)
// S3Config is the subconfig for the S3 storage type
type S3Config struct {
Endpoint string `json:"endpoint,omitempty"`
Region string `json:"region"`
AccessKeyID string `json:"access_key_id"`
SecretAccessKey string `json:"secret_access_key"`
Bucket string `json:"bucket"`
Path string `json:"path"`
ForcePathStyle bool `json:"force_path_style"`
}
// S3Client is a client for uploading data to S3.
type S3Client struct {
endpoint string
region string
accessKey string
secretKey string
bucket string
key string
forcePathStyle bool
session *session.Session
s3 *s3.S3
// These fields are used for testing via dependency injection.
uploader uploader
downloader downloader
}
// NewS3Client returns an instance of an S3Client.
func NewS3Client(endpoint, region, accessKey, secretKey, bucket, key string, forcePathStyle bool) (*S3Client, error) {
cfg := aws.Config{
Endpoint: aws.String(endpoint),
Region: aws.String(region),
S3ForcePathStyle: aws.Bool(forcePathStyle),
}
// If credentials aren't provided by the user, the AWS SDK will use the default
// credential provider chain, which supports environment variables, shared credentials
// file, and EC2 instance roles.
if accessKey != "" && secretKey != "" {
cfg.Credentials = credentials.NewStaticCredentials(accessKey, secretKey, "")
}
sess, err := session.NewSession(&cfg)
if err != nil {
return nil, err
}
s3 := s3.New(sess)
return &S3Client{
endpoint: endpoint,
region: region,
accessKey: accessKey,
secretKey: secretKey,
bucket: bucket,
key: key,
forcePathStyle: forcePathStyle,
session: sess,
s3: s3,
uploader: s3manager.NewUploaderWithClient(s3),
downloader: s3manager.NewDownloaderWithClient(s3),
}, nil
}
// String returns a string representation of the S3Client.
func (s *S3Client) String() string {
if s.endpoint == "" || strings.HasSuffix(s.endpoint, "amazonaws.com") {
// Native Amazon S3, use AWS's S3 URL format
return fmt.Sprintf("s3://%s/%s", s.bucket, s.key)
} else if !s.forcePathStyle {
// Endpoint specified but not using path style (e.g. Wasabi)
return fmt.Sprintf("s3://%s.%s/%s", s.bucket, s.endpoint, s.key)
} else {
// Endpoint specified and using path style (e.g. MinIO)
return fmt.Sprintf("s3://%s/%s/%s", s.endpoint, s.bucket, s.key)
}
}
// Upload uploads data to S3.
func (s *S3Client) Upload(ctx context.Context, reader io.Reader, id string) error {
input := &s3manager.UploadInput{
Bucket: aws.String(s.bucket),
Key: aws.String(s.key),
Body: reader,
}
if id != "" {
input.Metadata = map[string]*string{
AWSS3IDKey: aws.String(id),
}
}
_, err := s.uploader.UploadWithContext(ctx, input)
if err != nil {
return fmt.Errorf("failed to upload to %v: %w", s, err)
}
return nil
}
// CurrentID returns the last ID uploaded to S3.
func (s *S3Client) CurrentID(ctx context.Context) (string, error) {
input := &s3.HeadObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(s.key),
}
result, err := s.s3.HeadObjectWithContext(ctx, input)
if err != nil {
return "", fmt.Errorf("failed to get object head for %v: %w", s, err)
}
id, ok := result.Metadata[AWSS3IDKey]
if !ok {
return "", fmt.Errorf("sum metadata not found for %v", s)
}
return *id, nil
}
// Download downloads data from S3.
func (s *S3Client) Download(ctx context.Context, writer io.WriterAt) error {
_, err := s.downloader.DownloadWithContext(ctx, writer, &s3.GetObjectInput{
Bucket: aws.String(s.bucket),
Key: aws.String(s.key),
})
if err != nil {
return fmt.Errorf("failed to download from %v: %w", s, err)
}
return nil
}
type uploader interface {
UploadWithContext(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
}
type downloader interface {
DownloadWithContext(ctx aws.Context, w io.WriterAt, input *s3.GetObjectInput, opts ...func(*s3manager.Downloader)) (n int64, err error)
}