1
0
Fork 0
You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

251 lines
6.0 KiB
Go

package restore
import (
"bytes"
"errors"
"os"
"reflect"
"testing"
"time"
"github.com/rqlite/rqlite/v8/auto"
"github.com/rqlite/rqlite/v8/aws"
)
// Test_ReadConfigFile checks that ReadConfigFile returns a file's bytes
// unchanged, reports a missing file with an error matching os.ErrNotExist,
// and substitutes $VAR references with values from the environment while
// leaving text without a leading '$' untouched.
func Test_ReadConfigFile(t *testing.T) {
	// writeConfig stores content in a fresh file under a per-subtest
	// temporary directory and returns the file's path. t.TempDir removes the
	// directory (and the file) automatically when the subtest finishes.
	writeConfig := func(t *testing.T, content []byte) string {
		t.Helper()
		f, err := os.CreateTemp(t.TempDir(), "upload_config")
		if err != nil {
			t.Fatal(err)
		}
		if _, err := f.Write(content); err != nil {
			_ = f.Close() // best effort; the write error is the one to report
			t.Fatal(err)
		}
		if err := f.Close(); err != nil {
			t.Fatal(err)
		}
		return f.Name()
	}

	t.Run("valid config file", func(t *testing.T) {
		content := []byte("key=value")
		data, err := ReadConfigFile(writeConfig(t, content))
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		if !bytes.Equal(data, content) {
			t.Fatalf("Expected %q, got %q", content, data)
		}
	})

	t.Run("non-existent file", func(t *testing.T) {
		// Use a name inside an empty temporary directory so the path is
		// guaranteed not to exist, whatever the working directory contains.
		missing := t.TempDir() + string(os.PathSeparator) + "nonexistentfile"
		_, err := ReadConfigFile(missing)
		if !errors.Is(err, os.ErrNotExist) {
			t.Fatalf("Expected os.ErrNotExist, got %v", err)
		}
	})

	t.Run("file with environment variables", func(t *testing.T) {
		// t.Setenv restores any previous value of the variable when the
		// subtest ends, unlike a deferred os.Unsetenv.
		t.Setenv("TEST_VAR", "test_value")

		data, err := ReadConfigFile(writeConfig(t, []byte("key=$TEST_VAR")))
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		expectedContent := []byte("key=test_value")
		if !bytes.Equal(data, expectedContent) {
			t.Fatalf("Expected %q, got %q", expectedContent, data)
		}
	})

	t.Run("longer file with environment variables", func(t *testing.T) {
		t.Setenv("TEST_VAR1", "test_value")

		// key2 has no '$', so its value must be passed through verbatim.
		content := []byte(`
key1=$TEST_VAR1
key2=TEST_VAR2`)
		data, err := ReadConfigFile(writeConfig(t, content))
		if err != nil {
			t.Fatalf("Expected no error, got %v", err)
		}
		expectedContent := []byte(`
key1=test_value
key2=TEST_VAR2`)
		if !bytes.Equal(data, expectedContent) {
			t.Fatalf("Expected %q, got %q", expectedContent, data)
		}
	})
}
// TestUnmarshal checks that Unmarshal decodes a restore config together with
// its embedded S3 sub-config, fills in a 30s timeout when none is given, and
// rejects an unknown version or storage type with auto.ErrInvalidVersion or
// auto.ErrUnsupportedStorageType respectively.
func TestUnmarshal(t *testing.T) {
	testCases := []struct {
		name        string
		input       []byte
		expectedCfg *Config
		expectedS3  *aws.S3Config
		expectedErr error
	}{
		{
			name: "ValidS3Config",
			input: []byte(`
{
	"version": 1,
	"type": "s3",
	"timeout": "30s",
	"sub": {
		"access_key_id": "test_id",
		"secret_access_key": "test_secret",
		"region": "us-west-2",
		"bucket": "test_bucket",
		"path": "test/path"
	}
}
`),
			expectedCfg: &Config{
				Version:           1,
				Type:              "s3",
				Timeout:           auto.Duration(30 * time.Second),
				ContinueOnFailure: false,
			},
			expectedS3: &aws.S3Config{
				AccessKeyID:     "test_id",
				SecretAccessKey: "test_secret",
				Region:          "us-west-2",
				Bucket:          "test_bucket",
				Path:            "test/path",
			},
			expectedErr: nil,
		},
		{
			name: "ValidS3ConfigNoOptionalFields",
			input: []byte(`
{
	"version": 1,
	"type": "s3",
	"continue_on_failure": true,
	"sub": {
		"access_key_id": "test_id",
		"secret_access_key": "test_secret",
		"region": "us-west-2",
		"bucket": "test_bucket",
		"path": "test/path"
	}
}
`),
			// No "timeout" in the input: Unmarshal is expected to apply a
			// 30s default.
			expectedCfg: &Config{
				Version:           1,
				Type:              "s3",
				Timeout:           auto.Duration(30 * time.Second),
				ContinueOnFailure: true,
			},
			expectedS3: &aws.S3Config{
				AccessKeyID:     "test_id",
				SecretAccessKey: "test_secret",
				Region:          "us-west-2",
				Bucket:          "test_bucket",
				Path:            "test/path",
			},
			expectedErr: nil,
		},
		{
			name: "InvalidVersion",
			input: []byte(`
{
	"version": 2,
	"type": "s3",
	"timeout": "5m",
	"sub": {
		"access_key_id": "test_id",
		"secret_access_key": "test_secret",
		"region": "us-west-2",
		"bucket": "test_bucket",
		"path": "test/path"
	}
} `),
			expectedCfg: nil,
			expectedS3:  nil,
			expectedErr: auto.ErrInvalidVersion,
		},
		{
			name: "UnsupportedType",
			input: []byte(`
{
	"version": 1,
	"type": "unsupported",
	"timeout": "24h",
	"sub": {
		"access_key_id": "test_id",
		"secret_access_key": "test_secret",
		"region": "us-west-2",
		"bucket": "test_bucket",
		"path": "test/path"
	}
} `),
			expectedCfg: nil,
			expectedS3:  nil,
			expectedErr: auto.ErrUnsupportedStorageType,
		},
	}

	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			cfg, s3Cfg, err := Unmarshal(tc.input)
			if !errors.Is(err, tc.expectedErr) {
				t.Fatalf("expected error %v, got %v", tc.expectedErr, err)
			}
			// For error cases expectedCfg is nil, so this also asserts that
			// no config is returned alongside the error.
			if !compareConfig(cfg, tc.expectedCfg) {
				t.Fatalf("expected config %+v, got %+v", tc.expectedCfg, cfg)
			}
			// The S3 sub-config is only checked for cases that expect one.
			if tc.expectedS3 != nil && !reflect.DeepEqual(s3Cfg, tc.expectedS3) {
				t.Fatalf("expected S3Config %+v, got %+v", tc.expectedS3, s3Cfg)
			}
		})
	}
}
// compareConfig reports whether a and b describe the same restore settings.
// Two nil configs are equal; a nil and a non-nil config are not. Otherwise
// Version, Type, Timeout and ContinueOnFailure are compared; any other
// fields of Config are ignored.
func compareConfig(a, b *Config) bool {
	if a == nil || b == nil {
		return a == b
	}
	return a.Version == b.Version &&
		a.Type == b.Type &&
		a.Timeout == b.Timeout &&
		a.ContinueOnFailure == b.ContinueOnFailure
}