1
0
Fork 0

Downloader config

master
Philip O'Toole 1 year ago
parent a9151ec299
commit a299f2d552

@ -0,0 +1,121 @@
package download
import (
"encoding/json"
"errors"
"io/ioutil"
"os"
"time"
"github.com/rqlite/rqlite/aws"
)
const (
// version is the max version of the config file format supported
version = 1
)
var (
// ErrInvalidVersion is returned when the config file version is not supported.
ErrInvalidVersion = errors.New("invalid version")
// ErrUnsupportedStorageType is returned when the storage type is not supported.
ErrUnsupportedStorageType = errors.New("unsupported storage type")
)
// Duration is a wrapper around time.Duration that allows us to unmarshal
type Duration time.Duration
// MarshalJSON marshals the duration as a string
func (d Duration) MarshalJSON() ([]byte, error) {
return json.Marshal(time.Duration(d).String())
}
// UnmarshalJSON unmarshals the duration from a string or a float64
func (d *Duration) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
switch value := v.(type) {
case float64:
*d = Duration(time.Duration(value))
return nil
case string:
tmp, err := time.ParseDuration(value)
if err != nil {
return err
}
*d = Duration(tmp)
return nil
default:
return errors.New("invalid duration")
}
}
// StorageType is a wrapper around string that allows us to unmarshal
type StorageType string
// UnmarshalJSON unmarshals the storage type from a string and validates it
func (s *StorageType) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
switch value := v.(type) {
case string:
*s = StorageType(value)
if *s != "s3" {
return ErrUnsupportedStorageType
}
return nil
default:
return ErrUnsupportedStorageType
}
}
// Config is the config file format for the upload service
type Config struct {
Version int `json:"version"`
Type StorageType `json:"type"`
Timeout Duration `json:"timeout"`
Sub json.RawMessage `json:"sub"`
}
// Unmarshal unmarshals the config file and returns the config and subconfig
func Unmarshal(data []byte) (*Config, *aws.S3Config, error) {
cfg := &Config{}
err := json.Unmarshal(data, cfg)
if err != nil {
return nil, nil, err
}
if cfg.Version > version {
return nil, nil, ErrInvalidVersion
}
s3cfg := &aws.S3Config{}
err = json.Unmarshal(cfg.Sub, s3cfg)
if err != nil {
return nil, nil, err
}
return cfg, s3cfg, nil
}
// ReadConfigFile reads the config file and returns the data. It also expands
// any environment variables in the config file.
func ReadConfigFile(filename string) ([]byte, error) {
f, err := os.Open(filename)
if err != nil {
return nil, err
}
defer f.Close()
data, err := ioutil.ReadAll(f)
if err != nil {
return nil, err
}
data = []byte(os.ExpandEnv(string(data)))
return data, nil
}

@ -0,0 +1,248 @@
package download
import (
"bytes"
"errors"
"io/ioutil"
"os"
"reflect"
"testing"
"time"
"github.com/rqlite/rqlite/aws"
)
func Test_ReadConfigFile(t *testing.T) {
t.Run("valid config file", func(t *testing.T) {
// Create a temporary config file
tempFile, err := ioutil.TempFile("", "upload_config")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tempFile.Name())
content := []byte("key=value")
if _, err := tempFile.Write(content); err != nil {
t.Fatal(err)
}
tempFile.Close()
data, err := ReadConfigFile(tempFile.Name())
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
if !bytes.Equal(data, content) {
t.Fatalf("Expected %v, got %v", content, data)
}
})
t.Run("non-existent file", func(t *testing.T) {
_, err := ReadConfigFile("nonexistentfile")
if !errors.Is(err, os.ErrNotExist) {
t.Fatalf("Expected os.ErrNotExist, got %v", err)
}
})
t.Run("file with environment variables", func(t *testing.T) {
// Set an environment variable
if err := os.Setenv("TEST_VAR", "test_value"); err != nil {
t.Fatal(err)
}
defer os.Unsetenv("TEST_VAR")
// Create a temporary config file with an environment variable
tempFile, err := ioutil.TempFile("", "upload_config")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tempFile.Name())
content := []byte("key=$TEST_VAR")
if _, err := tempFile.Write(content); err != nil {
t.Fatal(err)
}
tempFile.Close()
data, err := ReadConfigFile(tempFile.Name())
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
expectedContent := []byte("key=test_value")
if !bytes.Equal(data, expectedContent) {
t.Fatalf("Expected %v, got %v", expectedContent, data)
}
})
t.Run("longer file with environment variables", func(t *testing.T) {
// Set an environment variable
if err := os.Setenv("TEST_VAR1", "test_value"); err != nil {
t.Fatal(err)
}
defer os.Unsetenv("TEST_VAR1")
// Create a temporary config file with an environment variable
tempFile, err := ioutil.TempFile("", "upload_config")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tempFile.Name())
content := []byte(`
key1=$TEST_VAR1
key2=TEST_VAR2`)
if _, err := tempFile.Write(content); err != nil {
t.Fatal(err)
}
tempFile.Close()
data, err := ReadConfigFile(tempFile.Name())
if err != nil {
t.Fatalf("Expected no error, got %v", err)
}
expectedContent := []byte(`
key1=test_value
key2=TEST_VAR2`)
if !bytes.Equal(data, expectedContent) {
t.Fatalf("Expected %v, got %v", expectedContent, data)
}
})
}
func TestUnmarshal(t *testing.T) {
testCases := []struct {
name string
input []byte
expectedCfg *Config
expectedS3 *aws.S3Config
expectedErr error
}{
{
name: "ValidS3Config",
input: []byte(`
{
"version": 1,
"type": "s3",
"timeout": "30s",
"sub": {
"access_key_id": "test_id",
"secret_access_key": "test_secret",
"region": "us-west-2",
"bucket": "test_bucket",
"path": "test/path"
}
}
`),
expectedCfg: &Config{
Version: 1,
Type: "s3",
Timeout: 30 * Duration(time.Second),
},
expectedS3: &aws.S3Config{
AccessKeyID: "test_id",
SecretAccessKey: "test_secret",
Region: "us-west-2",
Bucket: "test_bucket",
Path: "test/path",
},
expectedErr: nil,
},
{
name: "ValidS3ConfigNoptionalFields",
input: []byte(`
{
"version": 1,
"type": "s3",
"timeout": "1m",
"sub": {
"access_key_id": "test_id",
"secret_access_key": "test_secret",
"region": "us-west-2",
"bucket": "test_bucket",
"path": "test/path"
}
}
`),
expectedCfg: &Config{
Version: 1,
Type: "s3",
Timeout: Duration(time.Minute),
},
expectedS3: &aws.S3Config{
AccessKeyID: "test_id",
SecretAccessKey: "test_secret",
Region: "us-west-2",
Bucket: "test_bucket",
Path: "test/path",
},
expectedErr: nil,
},
{
name: "InvalidVersion",
input: []byte(`
{
"version": 2,
"type": "s3",
"timeout": "5m",
"sub": {
"access_key_id": "test_id",
"secret_access_key": "test_secret",
"region": "us-west-2",
"bucket": "test_bucket",
"path": "test/path"
}
} `),
expectedCfg: nil,
expectedS3: nil,
expectedErr: ErrInvalidVersion,
},
{
name: "UnsupportedType",
input: []byte(`
{
"version": 1,
"type": "unsupported",
"timeout": "24h",
"sub": {
"access_key_id": "test_id",
"secret_access_key": "test_secret",
"region": "us-west-2",
"bucket": "test_bucket",
"path": "test/path"
}
} `),
expectedCfg: nil,
expectedS3: nil,
expectedErr: ErrUnsupportedStorageType,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
cfg, s3Cfg, err := Unmarshal(tc.input)
_ = s3Cfg
if !errors.Is(err, tc.expectedErr) {
t.Fatalf("Test case %s failed, expected error %v, got %v", tc.name, tc.expectedErr, err)
}
if !compareConfig(cfg, tc.expectedCfg) {
t.Fatalf("Test case %s failed, expected config %+v, got %+v", tc.name, tc.expectedCfg, cfg)
}
if tc.expectedS3 != nil {
if !reflect.DeepEqual(s3Cfg, tc.expectedS3) {
t.Fatalf("Test case %s failed, expected S3Config %+v, got %+v", tc.name, tc.expectedS3, s3Cfg)
}
}
})
}
}
func compareConfig(a, b *Config) bool {
if a == nil || b == nil {
return a == b
}
return a.Version == b.Version &&
a.Type == b.Type &&
a.Timeout == b.Timeout
}
Loading…
Cancel
Save