diff --git a/auto/restore/downloader.go b/auto/restore/downloader.go index a554cf72..31fb9088 100644 --- a/auto/restore/downloader.go +++ b/auto/restore/downloader.go @@ -38,8 +38,11 @@ func DownloadFile(ctx context.Context, cfgPath string) (path string, errOK bool, if err != nil { return "", false, fmt.Errorf("failed to parse auto-restore file: %s", err.Error()) } - sc := aws.NewS3Client(s3cfg.Endpoint, s3cfg.Region, s3cfg.AccessKeyID, s3cfg.SecretAccessKey, + sc, err := aws.NewS3Client(s3cfg.Endpoint, s3cfg.Region, s3cfg.AccessKeyID, s3cfg.SecretAccessKey, s3cfg.Bucket, s3cfg.Path, s3cfg.ForcePathStyle) + if err != nil { + return "", false, fmt.Errorf("failed to create aws S3 client: %s", err.Error()) + } d := NewDownloader(sc) // Create a temporary file to download to. diff --git a/aws/s3.go b/aws/s3.go index b632dbeb..390666ab 100644 --- a/aws/s3.go +++ b/aws/s3.go @@ -39,13 +39,34 @@ type S3Client struct { key string forcePathStyle bool + session *session.Session + s3 *s3.S3 + // These fields are used for testing via dependency injection. uploader uploader downloader downloader } // NewS3Client returns an instance of an S3Client. -func NewS3Client(endpoint, region, accessKey, secretKey, bucket, key string, forcePathStyle bool) *S3Client { +func NewS3Client(endpoint, region, accessKey, secretKey, bucket, key string, forcePathStyle bool) (*S3Client, error) { + cfg := aws.Config{ + Endpoint: aws.String(endpoint), + Region: aws.String(region), + S3ForcePathStyle: aws.Bool(forcePathStyle), + } + // If credentials aren't provided by the user, the AWS SDK will use the default + // credential provider chain, which supports environment variables, shared credentials + // file, and EC2 instance roles. + if accessKey != "" && secretKey != "" { + cfg.Credentials = credentials.NewStaticCredentials(accessKey, secretKey, "") + } + sess, err := session.NewSession(&cfg) + if err != nil { + return nil, err + } + + s3 := s3.New(sess) + return &S3Client{ endpoint: endpoint, region: region, @@ -54,7 +75,13 @@ func NewS3Client(endpoint, region, accessKey, secretKey, bucket, key string, for bucket: bucket, key: key, forcePathStyle: forcePathStyle, - } + + session: sess, + s3: s3, + + uploader: s3manager.NewUploaderWithClient(s3), + downloader: s3manager.NewDownloaderWithClient(s3), + }, nil } // String returns a string representation of the S3Client. @@ -73,19 +100,6 @@ func (s *S3Client) String() string { // Upload uploads data to S3. func (s *S3Client) Upload(ctx context.Context, reader io.Reader, id string) error { - sess, err := s.createSession() - if err != nil { - return err - } - - // If an uploader was not provided, use a real S3 uploader. - var uploader uploader - if s.uploader == nil { - uploader = s3manager.NewUploader(sess) - } else { - uploader = s.uploader - } - input := &s3manager.UploadInput{ Bucket: aws.String(s.bucket), Key: aws.String(s.key), @@ -97,7 +111,7 @@ func (s *S3Client) Upload(ctx context.Context, reader io.Reader, id string) erro AWSS3IDKey: aws.String(id), } } - _, err = uploader.UploadWithContext(ctx, input) + _, err := s.uploader.UploadWithContext(ctx, input) if err != nil { return fmt.Errorf("failed to upload to %v: %w", s, err) } @@ -107,18 +121,12 @@ func (s *S3Client) Upload(ctx context.Context, reader io.Reader, id string) erro // CurrentID returns the last ID uploaded to S3. func (s *S3Client) CurrentID(ctx context.Context) (string, error) { - sess, err := s.createSession() - if err != nil { - return "", err - } - - svc := s3.New(sess) input := &s3.HeadObjectInput{ Bucket: aws.String(s.bucket), Key: aws.String(s.key), } - result, err := svc.HeadObjectWithContext(ctx, input) + result, err := s.s3.HeadObjectWithContext(ctx, input) if err != nil { return "", fmt.Errorf("failed to get object head for %v: %w", s, err) } @@ -132,20 +140,7 @@ func (s *S3Client) CurrentID(ctx context.Context) (string, error) { // Download downloads data from S3. func (s *S3Client) Download(ctx context.Context, writer io.WriterAt) error { - sess, err := s.createSession() - if err != nil { - return err - } - - // If a downloader was not provided, use a real S3 downloader. - var downloader downloader - if s.downloader == nil { - downloader = s3manager.NewDownloader(sess) - } else { - downloader = s.downloader - } - - _, err = downloader.DownloadWithContext(ctx, writer, &s3.GetObjectInput{ + _, err := s.downloader.DownloadWithContext(ctx, writer, &s3.GetObjectInput{ Bucket: aws.String(s.bucket), Key: aws.String(s.key), }) @@ -156,25 +151,6 @@ func (s *S3Client) Download(ctx context.Context, writer io.WriterAt) error { return nil } -func (s *S3Client) createSession() (*session.Session, error) { - cfg := aws.Config{ - Endpoint: aws.String(s.endpoint), - Region: aws.String(s.region), - S3ForcePathStyle: aws.Bool(s.forcePathStyle), - } - // If credentials aren't provided by the user, the AWS SDK will use the default - // credential provider chain, which supports environment variables, shared credentials - // file, and EC2 instance roles. - if s.accessKey != "" && s.secretKey != "" { - cfg.Credentials = credentials.NewStaticCredentials(s.accessKey, s.secretKey, "") - } - sess, err := session.NewSession(&cfg) - if err != nil { - return nil, fmt.Errorf("failed to create S3 session: %w", err) - } - return sess, nil -} - type uploader interface { UploadWithContext(ctx aws.Context, input *s3manager.UploadInput, opts ...func(*s3manager.Uploader)) (*s3manager.UploadOutput, error) } diff --git a/aws/s3_test.go b/aws/s3_test.go index 95b6350c..783424c3 100644 --- a/aws/s3_test.go +++ b/aws/s3_test.go @@ -15,7 +15,10 @@ import ( ) func Test_NewS3Client(t *testing.T) { - c := NewS3Client("endpoint1", "region1", "access", "secret", "bucket2", "key3", true) + c, err := NewS3Client("endpoint1", "region1", "access", "secret", "bucket2", "key3", true) + if err != nil { + t.Fatalf("error while creating aws S3 client: %v", err) + } if c.region != "region1" { t.Fatalf("expected region to be %q, got %q", "region1", c.region) } @@ -38,22 +41,34 @@ func Test_NewS3Client(t *testing.T) { func Test_S3Client_String(t *testing.T) { // Test native S3 with implicit endpoint - c := NewS3Client("", "region1", "access", "secret", "bucket2", "key3", false) + c, err := NewS3Client("", "region1", "access", "secret", "bucket2", "key3", false) + if err != nil { + t.Fatalf("error while creating aws S3 client: %v", err) + } if c.String() != "s3://bucket2/key3" { t.Fatalf("expected String() to be %q, got %q", "s3://bucket2/key3", c.String()) } // Test native S3 with explicit endpoint - c = NewS3Client("s3.amazonaws.com", "region1", "access", "secret", "bucket2", "key3", false) + c, err = NewS3Client("s3.amazonaws.com", "region1", "access", "secret", "bucket2", "key3", false) + if err != nil { + t.Fatalf("error while creating aws S3 client: %v", err) + } if c.String() != "s3://bucket2/key3" { t.Fatalf("expected String() to be %q, got %q", "s3://bucket2/key3", c.String()) } // Test non-native S3 (explicit endpoint) with non-path style (e.g. Wasabi) - c = NewS3Client("s3.ca-central-1.wasabisys.com", "region1", "access", "secret", "bucket2", "key3", false) + c, err = NewS3Client("s3.ca-central-1.wasabisys.com", "region1", "access", "secret", "bucket2", "key3", false) + if err != nil { + t.Fatalf("error while creating aws S3 client: %v", err) + } if c.String() != "s3://bucket2.s3.ca-central-1.wasabisys.com/key3" { t.Fatalf("expected String() to be %q, got %q", "s3://bucket2.s3.ca-central-1.wasabisys.com/key3", c.String()) } // Test non-native S3 (explicit endpoint) with forced path style (e.g. MinIO) - c = NewS3Client("s3.minio.example.com", "region1", "access", "secret", "bucket2", "key3", true) + c, err = NewS3Client("s3.minio.example.com", "region1", "access", "secret", "bucket2", "key3", true) + if err != nil { + t.Fatalf("error while creating aws S3 client: %v", err) + } if c.String() != "s3://s3.minio.example.com/bucket2/key3" { t.Fatalf("expected String() to be %q, got %q", "s3://s3.minio.example.com/bucket2/key3", c.String()) } diff --git a/cmd/rqlited/main.go b/cmd/rqlited/main.go index 7d514fe9..506e1f83 100644 --- a/cmd/rqlited/main.go +++ b/cmd/rqlited/main.go @@ -250,8 +250,11 @@ func startAutoBackups(ctx context.Context, cfg *Config, str *store.Store) (*back return nil, fmt.Errorf("failed to parse auto-backup file: %s", err.Error()) } provider := store.NewProvider(str, uCfg.Vacuum, !uCfg.NoCompress) - sc := aws.NewS3Client(s3cfg.Endpoint, s3cfg.Region, s3cfg.AccessKeyID, s3cfg.SecretAccessKey, + sc, err := aws.NewS3Client(s3cfg.Endpoint, s3cfg.Region, s3cfg.AccessKeyID, s3cfg.SecretAccessKey, s3cfg.Bucket, s3cfg.Path, s3cfg.ForcePathStyle) + if err != nil { + return nil, fmt.Errorf("failed to create aws S3 client: %s", err.Error()) + } u := backup.NewUploader(sc, provider, time.Duration(uCfg.Interval), !uCfg.NoCompress) u.Start(ctx, str.IsLeader) return u, nil