1
0
Fork 0

aws: reuse session and S3 manager

The AWS session and s3 manager are concurrent safe, it should be reused
whenever possible:

    Sessions are safe to use concurrently as long as the Session is not
    being modified. Sessions should be cached when possible, because
    creating a new Session will load all configuration values from the
    environment, and config files each time the Session is created.

See https://pkg.go.dev/github.com/aws/aws-sdk-go/aws/session

Currently, an aws session and s3 client/manager are created every time a
call to Upload, CurrentID or Download is made. I changed it so it creates
one session and S3 manager during app startup and reuse it afterwards.
master
Mauri de Souza Meneguzzo 8 months ago
parent b76e37eae0
commit 69933cbe35

@ -38,8 +38,11 @@ func DownloadFile(ctx context.Context, cfgPath string) (path string, errOK bool,
if err != nil { if err != nil {
return "", false, fmt.Errorf("failed to parse auto-restore file: %s", err.Error()) return "", false, fmt.Errorf("failed to parse auto-restore file: %s", err.Error())
} }
sc := aws.NewS3Client(s3cfg.Endpoint, s3cfg.Region, s3cfg.AccessKeyID, s3cfg.SecretAccessKey, sc, err := aws.NewS3Client(s3cfg.Endpoint, s3cfg.Region, s3cfg.AccessKeyID, s3cfg.SecretAccessKey,
s3cfg.Bucket, s3cfg.Path, s3cfg.ForcePathStyle) s3cfg.Bucket, s3cfg.Path, s3cfg.ForcePathStyle)
if err != nil {
return "", false, fmt.Errorf("failed to create aws S3 client: %s", err.Error())
}
d := NewDownloader(sc) d := NewDownloader(sc)
// Create a temporary file to download to. // Create a temporary file to download to.

@ -39,13 +39,34 @@ type S3Client struct {
key string key string
forcePathStyle bool forcePathStyle bool
session *session.Session
s3 *s3.S3
// These fields are used for testing via dependency injection. // These fields are used for testing via dependency injection.
uploader uploader uploader uploader
downloader downloader downloader downloader
} }
// NewS3Client returns an instance of an S3Client. // NewS3Client returns an instance of an S3Client.
func NewS3Client(endpoint, region, accessKey, secretKey, bucket, key string, forcePathStyle bool) *S3Client { func NewS3Client(endpoint, region, accessKey, secretKey, bucket, key string, forcePathStyle bool) (*S3Client, error) {
cfg := aws.Config{
Endpoint: aws.String(endpoint),
Region: aws.String(region),
S3ForcePathStyle: aws.Bool(forcePathStyle),
}
// If credentials aren't provided by the user, the AWS SDK will use the default
// credential provider chain, which supports environment variables, shared credentials
// file, and EC2 instance roles.
if accessKey != "" && secretKey != "" {
cfg.Credentials = credentials.NewStaticCredentials(accessKey, secretKey, "")
}
sess, err := session.NewSession(&cfg)
if err != nil {
return nil, err
}
s3 := s3.New(sess)
return &S3Client{ return &S3Client{
endpoint: endpoint, endpoint: endpoint,
region: region, region: region,
@ -54,7 +75,13 @@ func NewS3Client(endpoint, region, accessKey, secretKey, bucket, key string, for
bucket: bucket, bucket: bucket,
key: key, key: key,
forcePathStyle: forcePathStyle, forcePathStyle: forcePathStyle,
}
session: sess,
s3: s3,
uploader: s3manager.NewUploaderWithClient(s3),
downloader: s3manager.NewDownloaderWithClient(s3),
}, nil
} }
// String returns a string representation of the S3Client. // String returns a string representation of the S3Client.
@ -73,19 +100,6 @@ func (s *S3Client) String() string {
// Upload uploads data to S3. // Upload uploads data to S3.
func (s *S3Client) Upload(ctx context.Context, reader io.Reader, id string) error { func (s *S3Client) Upload(ctx context.Context, reader io.Reader, id string) error {
sess, err := s.createSession()
if err != nil {
return err
}
// If an uploader was not provided, use a real S3 uploader.
var uploader uploader
if s.uploader == nil {
uploader = s3manager.NewUploader(sess)
} else {
uploader = s.uploader
}
input := &s3manager.UploadInput{ input := &s3manager.UploadInput{
Bucket: aws.String(s.bucket), Bucket: aws.String(s.bucket),
Key: aws.String(s.key), Key: aws.String(s.key),
@ -97,7 +111,7 @@ func (s *S3Client) Upload(ctx context.Context, reader io.Reader, id string) erro
AWSS3IDKey: aws.String(id), AWSS3IDKey: aws.String(id),
} }
} }
_, err = uploader.UploadWithContext(ctx, input) _, err := s.uploader.UploadWithContext(ctx, input)
if err != nil { if err != nil {
return fmt.Errorf("failed to upload to %v: %w", s, err) return fmt.Errorf("failed to upload to %v: %w", s, err)
} }
@ -107,18 +121,12 @@ func (s *S3Client) Upload(ctx context.Context, reader io.Reader, id string) erro
// CurrentID returns the last ID uploaded to S3. // CurrentID returns the last ID uploaded to S3.
func (s *S3Client) CurrentID(ctx context.Context) (string, error) { func (s *S3Client) CurrentID(ctx context.Context) (string, error) {
sess, err := s.createSession()
if err != nil {
return "", err
}
svc := s3.New(sess)
input := &s3.HeadObjectInput{ input := &s3.HeadObjectInput{
Bucket: aws.String(s.bucket), Bucket: aws.String(s.bucket),
Key: aws.String(s.key), Key: aws.String(s.key),
} }
result, err := svc.HeadObjectWithContext(ctx, input) result, err := s.s3.HeadObjectWithContext(ctx, input)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to get object head for %v: %w", s, err) return "", fmt.Errorf("failed to get object head for %v: %w", s, err)
} }
@ -132,20 +140,7 @@ func (s *S3Client) CurrentID(ctx context.Context) (string, error) {
// Download downloads data from S3. // Download downloads data from S3.
func (s *S3Client) Download(ctx context.Context, writer io.WriterAt) error { func (s *S3Client) Download(ctx context.Context, writer io.WriterAt) error {
sess, err := s.createSession() _, err := s.downloader.DownloadWithContext(ctx, writer, &s3.GetObjectInput{
if err != nil {
return err
}
// If a downloader was not provided, use a real S3 downloader.
var downloader downloader
if s.downloader == nil {
downloader = s3manager.NewDownloader(sess)
} else {
downloader = s.downloader
}
_, err = downloader.DownloadWithContext(ctx, writer, &s3.GetObjectInput{
Bucket: aws.String(s.bucket), Bucket: aws.String(s.bucket),
Key: aws.String(s.key), Key: aws.String(s.key),
}) })
@ -156,25 +151,6 @@ func (s *S3Client) Download(ctx context.Context, writer io.WriterAt) error {
return nil return nil
} }
func (s *S3Client) createSession() (*session.Session, error) {
cfg := aws.Config{
Endpoint: aws.String(s.endpoint),
Region: aws.String(s.region),
S3ForcePathStyle: aws.Bool(s.forcePathStyle),
}
// If credentials aren't provided by the user, the AWS SDK will use the default
// credential provider chain, which supports environment variables, shared credentials
// file, and EC2 instance roles.
if s.accessKey != "" && s.secretKey != "" {
cfg.Credentials = credentials.NewStaticCredentials(s.accessKey, s.secretKey, "")
}
sess, err := session.NewSession(&cfg)
if err != nil {
return nil, fmt.Errorf("failed to create S3 session: %w", err)
}
return sess, nil
}
type uploader interface { type uploader interface {
UploadWithContext(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) UploadWithContext(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error)
} }

@ -15,7 +15,10 @@ import (
) )
func Test_NewS3Client(t *testing.T) { func Test_NewS3Client(t *testing.T) {
c := NewS3Client("endpoint1", "region1", "access", "secret", "bucket2", "key3", true) c, err := NewS3Client("endpoint1", "region1", "access", "secret", "bucket2", "key3", true)
if err != nil {
t.Fatalf("error while creating aws S3 client: %v", err)
}
if c.region != "region1" { if c.region != "region1" {
t.Fatalf("expected region to be %q, got %q", "region1", c.region) t.Fatalf("expected region to be %q, got %q", "region1", c.region)
} }
@ -38,22 +41,34 @@ func Test_NewS3Client(t *testing.T) {
func Test_S3Client_String(t *testing.T) { func Test_S3Client_String(t *testing.T) {
// Test native S3 with implicit endpoint // Test native S3 with implicit endpoint
c := NewS3Client("", "region1", "access", "secret", "bucket2", "key3", false) c, err := NewS3Client("", "region1", "access", "secret", "bucket2", "key3", false)
if err != nil {
t.Fatalf("error while creating aws S3 client: %v", err)
}
if c.String() != "s3://bucket2/key3" { if c.String() != "s3://bucket2/key3" {
t.Fatalf("expected String() to be %q, got %q", "s3://bucket2/key3", c.String()) t.Fatalf("expected String() to be %q, got %q", "s3://bucket2/key3", c.String())
} }
// Test native S3 with explicit endpoint // Test native S3 with explicit endpoint
c = NewS3Client("s3.amazonaws.com", "region1", "access", "secret", "bucket2", "key3", false) c, err = NewS3Client("s3.amazonaws.com", "region1", "access", "secret", "bucket2", "key3", false)
if err != nil {
t.Fatalf("error while creating aws S3 client: %v", err)
}
if c.String() != "s3://bucket2/key3" { if c.String() != "s3://bucket2/key3" {
t.Fatalf("expected String() to be %q, got %q", "s3://bucket2/key3", c.String()) t.Fatalf("expected String() to be %q, got %q", "s3://bucket2/key3", c.String())
} }
// Test non-native S3 (explicit endpoint) with non-path style (e.g. Wasabi) // Test non-native S3 (explicit endpoint) with non-path style (e.g. Wasabi)
c = NewS3Client("s3.ca-central-1.wasabisys.com", "region1", "access", "secret", "bucket2", "key3", false) c, err = NewS3Client("s3.ca-central-1.wasabisys.com", "region1", "access", "secret", "bucket2", "key3", false)
if err != nil {
t.Fatalf("error while creating aws S3 client: %v", err)
}
if c.String() != "s3://bucket2.s3.ca-central-1.wasabisys.com/key3" { if c.String() != "s3://bucket2.s3.ca-central-1.wasabisys.com/key3" {
t.Fatalf("expected String() to be %q, got %q", "s3://bucket2.s3.ca-central-1.wasabisys.com/key3", c.String()) t.Fatalf("expected String() to be %q, got %q", "s3://bucket2.s3.ca-central-1.wasabisys.com/key3", c.String())
} }
// Test non-native S3 (explicit endpoint) with forced path style (e.g. MinIO) // Test non-native S3 (explicit endpoint) with forced path style (e.g. MinIO)
c = NewS3Client("s3.minio.example.com", "region1", "access", "secret", "bucket2", "key3", true) c, err = NewS3Client("s3.minio.example.com", "region1", "access", "secret", "bucket2", "key3", true)
if err != nil {
t.Fatalf("error while creating aws S3 client: %v", err)
}
if c.String() != "s3://s3.minio.example.com/bucket2/key3" { if c.String() != "s3://s3.minio.example.com/bucket2/key3" {
t.Fatalf("expected String() to be %q, got %q", "s3://s3.minio.example.com/bucket2/key3", c.String()) t.Fatalf("expected String() to be %q, got %q", "s3://s3.minio.example.com/bucket2/key3", c.String())
} }

@ -250,8 +250,11 @@ func startAutoBackups(ctx context.Context, cfg *Config, str *store.Store) (*back
return nil, fmt.Errorf("failed to parse auto-backup file: %s", err.Error()) return nil, fmt.Errorf("failed to parse auto-backup file: %s", err.Error())
} }
provider := store.NewProvider(str, uCfg.Vacuum, !uCfg.NoCompress) provider := store.NewProvider(str, uCfg.Vacuum, !uCfg.NoCompress)
sc := aws.NewS3Client(s3cfg.Endpoint, s3cfg.Region, s3cfg.AccessKeyID, s3cfg.SecretAccessKey, sc, err := aws.NewS3Client(s3cfg.Endpoint, s3cfg.Region, s3cfg.AccessKeyID, s3cfg.SecretAccessKey,
s3cfg.Bucket, s3cfg.Path, s3cfg.ForcePathStyle) s3cfg.Bucket, s3cfg.Path, s3cfg.ForcePathStyle)
if err != nil {
return nil, fmt.Errorf("failed to create aws S3 client: %s", err.Error())
}
u := backup.NewUploader(sc, provider, time.Duration(uCfg.Interval), !uCfg.NoCompress) u := backup.NewUploader(sc, provider, time.Duration(uCfg.Interval), !uCfg.NoCompress)
u.Start(ctx, str.IsLeader) u.Start(ctx, str.IsLeader)
return u, nil return u, nil

Loading…
Cancel
Save