From a299f2d552ef367af93c128f712b37df09c833ce Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Tue, 2 May 2023 22:57:30 -0400 Subject: [PATCH] Downloader config --- download/config.go | 121 ++++++++++++++++++++ download/config_test.go | 248 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 369 insertions(+) create mode 100644 download/config.go create mode 100644 download/config_test.go diff --git a/download/config.go b/download/config.go new file mode 100644 index 00000000..905212a1 --- /dev/null +++ b/download/config.go @@ -0,0 +1,121 @@ +package download + +import ( + "encoding/json" + "errors" + "io/ioutil" + "os" + "time" + + "github.com/rqlite/rqlite/aws" +) + +const ( + // version is the max version of the config file format supported + version = 1 +) + +var ( + // ErrInvalidVersion is returned when the config file version is not supported. + ErrInvalidVersion = errors.New("invalid version") + + // ErrUnsupportedStorageType is returned when the storage type is not supported. + ErrUnsupportedStorageType = errors.New("unsupported storage type") +) + +// Duration is a wrapper around time.Duration that allows us to unmarshal +type Duration time.Duration + +// MarshalJSON marshals the duration as a string +func (d Duration) MarshalJSON() ([]byte, error) { + return json.Marshal(time.Duration(d).String()) +} + +// UnmarshalJSON unmarshals the duration from a string or a float64 +func (d *Duration) UnmarshalJSON(b []byte) error { + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + switch value := v.(type) { + case float64: + *d = Duration(time.Duration(value)) + return nil + case string: + tmp, err := time.ParseDuration(value) + if err != nil { + return err + } + *d = Duration(tmp) + return nil + default: + return errors.New("invalid duration") + } +} + +// StorageType is a wrapper around string that allows us to unmarshal +type StorageType string + +// UnmarshalJSON unmarshals the storage type from a string and validates it +func (s *StorageType) UnmarshalJSON(b []byte) error { + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + switch value := v.(type) { + case string: + *s = StorageType(value) + if *s != "s3" { + return ErrUnsupportedStorageType + } + return nil + default: + return ErrUnsupportedStorageType + } +} + +// Config is the config file format for the upload service +type Config struct { + Version int `json:"version"` + Type StorageType `json:"type"` + Timeout Duration `json:"timeout"` + Sub json.RawMessage `json:"sub"` +} + +// Unmarshal unmarshals the config file and returns the config and subconfig +func Unmarshal(data []byte) (*Config, *aws.S3Config, error) { + cfg := &Config{} + err := json.Unmarshal(data, cfg) + if err != nil { + return nil, nil, err + } + + if cfg.Version > version { + return nil, nil, ErrInvalidVersion + } + + s3cfg := &aws.S3Config{} + err = json.Unmarshal(cfg.Sub, s3cfg) + if err != nil { + return nil, nil, err + } + return cfg, s3cfg, nil +} + +// ReadConfigFile reads the config file and returns the data. It also expands +// any environment variables in the config file. +func ReadConfigFile(filename string) ([]byte, error) { + f, err := os.Open(filename) + if err != nil { + return nil, err + } + defer f.Close() + + data, err := ioutil.ReadAll(f) + if err != nil { + return nil, err + } + + data = []byte(os.ExpandEnv(string(data))) + return data, nil +} diff --git a/download/config_test.go b/download/config_test.go new file mode 100644 index 00000000..2bfc5488 --- /dev/null +++ b/download/config_test.go @@ -0,0 +1,248 @@ +package download + +import ( + "bytes" + "errors" + "io/ioutil" + "os" + "reflect" + "testing" + "time" + + "github.com/rqlite/rqlite/aws" +) + +func Test_ReadConfigFile(t *testing.T) { + t.Run("valid config file", func(t *testing.T) { + // Create a temporary config file + tempFile, err := ioutil.TempFile("", "upload_config") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tempFile.Name()) + + content := []byte("key=value") + if _, err := tempFile.Write(content); err != nil { + t.Fatal(err) + } + tempFile.Close() + + data, err := ReadConfigFile(tempFile.Name()) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + if !bytes.Equal(data, content) { + t.Fatalf("Expected %v, got %v", content, data) + } + }) + + t.Run("non-existent file", func(t *testing.T) { + _, err := ReadConfigFile("nonexistentfile") + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("Expected os.ErrNotExist, got %v", err) + } + }) + + t.Run("file with environment variables", func(t *testing.T) { + // Set an environment variable + if err := os.Setenv("TEST_VAR", "test_value"); err != nil { + t.Fatal(err) + } + defer os.Unsetenv("TEST_VAR") + + // Create a temporary config file with an environment variable + tempFile, err := ioutil.TempFile("", "upload_config") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tempFile.Name()) + + content := []byte("key=$TEST_VAR") + if _, err := tempFile.Write(content); err != nil { + t.Fatal(err) + } + tempFile.Close() + + data, err := ReadConfigFile(tempFile.Name()) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + expectedContent := []byte("key=test_value") + if !bytes.Equal(data, expectedContent) { + t.Fatalf("Expected %v, got %v", expectedContent, data) + } + }) + + t.Run("longer file with environment variables", func(t *testing.T) { + // Set an environment variable + if err := os.Setenv("TEST_VAR1", "test_value"); err != nil { + t.Fatal(err) + } + defer os.Unsetenv("TEST_VAR1") + + // Create a temporary config file with an environment variable + tempFile, err := ioutil.TempFile("", "upload_config") + if err != nil { + t.Fatal(err) + } + defer os.Remove(tempFile.Name()) + + content := []byte(` +key1=$TEST_VAR1 +key2=TEST_VAR2`) + if _, err := tempFile.Write(content); err != nil { + t.Fatal(err) + } + tempFile.Close() + + data, err := ReadConfigFile(tempFile.Name()) + if err != nil { + t.Fatalf("Expected no error, got %v", err) + } + expectedContent := []byte(` +key1=test_value +key2=TEST_VAR2`) + if !bytes.Equal(data, expectedContent) { + t.Fatalf("Expected %v, got %v", expectedContent, data) + } + }) +} + +func TestUnmarshal(t *testing.T) { + testCases := []struct { + name string + input []byte + expectedCfg *Config + expectedS3 *aws.S3Config + expectedErr error + }{ + { + name: "ValidS3Config", + input: []byte(` + { + "version": 1, + "type": "s3", + "timeout": "30s", + "sub": { + "access_key_id": "test_id", + "secret_access_key": "test_secret", + "region": "us-west-2", + "bucket": "test_bucket", + "path": "test/path" + } + } + `), + expectedCfg: &Config{ + Version: 1, + Type: "s3", + Timeout: 30 * Duration(time.Second), + }, + expectedS3: &aws.S3Config{ + AccessKeyID: "test_id", + SecretAccessKey: "test_secret", + Region: "us-west-2", + Bucket: "test_bucket", + Path: "test/path", + }, + expectedErr: nil, + }, + { + name: "ValidS3ConfigNoptionalFields", + input: []byte(` + { + "version": 1, + "type": "s3", + "timeout": "1m", + "sub": { + "access_key_id": "test_id", + "secret_access_key": "test_secret", + "region": "us-west-2", + "bucket": "test_bucket", + "path": "test/path" + } + } + `), + expectedCfg: &Config{ + Version: 1, + Type: "s3", + Timeout: Duration(time.Minute), + }, + expectedS3: &aws.S3Config{ + AccessKeyID: "test_id", + SecretAccessKey: "test_secret", + Region: "us-west-2", + Bucket: "test_bucket", + Path: "test/path", + }, + expectedErr: nil, + }, + { + name: "InvalidVersion", + input: []byte(` + { + "version": 2, + "type": "s3", + "timeout": "5m", + "sub": { + "access_key_id": "test_id", + "secret_access_key": "test_secret", + "region": "us-west-2", + "bucket": "test_bucket", + "path": "test/path" + } + } `), + expectedCfg: nil, + expectedS3: nil, + expectedErr: ErrInvalidVersion, + }, + { + name: "UnsupportedType", + input: []byte(` + { + "version": 1, + "type": "unsupported", + "timeout": "24h", + "sub": { + "access_key_id": "test_id", + "secret_access_key": "test_secret", + "region": "us-west-2", + "bucket": "test_bucket", + "path": "test/path" + } + } `), + expectedCfg: nil, + expectedS3: nil, + expectedErr: ErrUnsupportedStorageType, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + cfg, s3Cfg, err := Unmarshal(tc.input) + _ = s3Cfg + + if !errors.Is(err, tc.expectedErr) { + t.Fatalf("Test case %s failed, expected error %v, got %v", tc.name, tc.expectedErr, err) + } + + if !compareConfig(cfg, tc.expectedCfg) { + t.Fatalf("Test case %s failed, expected config %+v, got %+v", tc.name, tc.expectedCfg, cfg) + } + + if tc.expectedS3 != nil { + if !reflect.DeepEqual(s3Cfg, tc.expectedS3) { + t.Fatalf("Test case %s failed, expected S3Config %+v, got %+v", tc.name, tc.expectedS3, s3Cfg) + } + } + }) + } +} + +func compareConfig(a, b *Config) bool { + if a == nil || b == nil { + return a == b + } + return a.Version == b.Version && + a.Type == b.Type && + a.Timeout == b.Timeout +}