From 7655725d5da287a41135754bb61ed846d11b2cf5 Mon Sep 17 00:00:00 2001 From: Philip O'Toole Date: Sun, 30 Apr 2023 12:18:25 -0400 Subject: [PATCH] Try to simplify main --- auth/credential_store.go | 13 +++++++++ auth/credential_store_test.go | 53 +++++++++++++++++++++++++++++++++++ cmd/rqlited/flags.go | 7 +++++ cmd/rqlited/main.go | 29 ++++--------------- 4 files changed, 79 insertions(+), 23 deletions(-) diff --git a/auth/credential_store.go b/auth/credential_store.go index b2494f7f..0d1de7f8 100644 --- a/auth/credential_store.go +++ b/auth/credential_store.go @@ -5,6 +5,7 @@ package auth import ( "encoding/json" "io" + "os" "sync" "golang.org/x/crypto/bcrypt" @@ -105,6 +106,18 @@ func NewCredentialsStore() *CredentialsStore { } } +// NewCredentialsStoreFromFile returns a new instance of a CredentialStore loaded from a file. +func NewCredentialsStoreFromFile(path string) (*CredentialsStore, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + c := NewCredentialsStore() + return c, c.Load(f) +} + // Load loads credential information from a reader. func (c *CredentialsStore) Load(r io.Reader) error { dec := json.NewDecoder(r) diff --git a/auth/credential_store_test.go b/auth/credential_store_test.go index ba1b143b..f86a4945 100644 --- a/auth/credential_store_test.go +++ b/auth/credential_store_test.go @@ -1,6 +1,7 @@ package auth import ( + "os" "strings" "testing" ) @@ -433,6 +434,46 @@ func Test_AuthPermsRequestLoadSingle(t *testing.T) { t.Fatalf("single credential not loaded correctly") } + b1 := &testBasicAuther{ + username: "username1", + password: "password1", + ok: true, + } + if perm := store.HasPermRequest(b1, "foo"); !perm { + t.Fatalf("username1 does not has perm foo via request") + } + b2 := &testBasicAuther{ + username: "username2", + password: "password1", + ok: true, + } + if perm := store.HasPermRequest(b2, "foo"); perm { + t.Fatalf("username1 does have perm foo via request") + } +} + +func Test_AuthPermsRequestLoadSingleFromFile(t *testing.T) { + const jsonStream = ` + [ + { + "username": "username1", + "password": "password1", + "perms": ["foo", "bar"] + } + ] + ` + path := mustWriteTempFile(t, jsonStream) + defer os.Remove(path) + + store, err := NewCredentialsStoreFromFile(path) + if err != nil { + t.Fatalf("failed to load credential store from file: %s", err.Error()) + } + + if check := store.Check("username1", "password1"); !check { + t.Fatalf("single credential not loaded correctly") + } + b1 := &testBasicAuther{ username: "username1", password: "password1", @@ -553,3 +594,15 @@ func Test_AuthPermsAllUsers(t *testing.T) { t.Fatalf("username1 should have abc perm via *") } } + +func mustWriteTempFile(t *testing.T, s string) string { + f, err := os.CreateTemp(t.TempDir(), "rqlite-test") + if err != nil { + panic("failed to create temp file") + } + defer f.Close() + if _, err := f.WriteString(s); err != nil { + panic("failed to write to temp file") + } + return f.Name() +} diff --git a/cmd/rqlited/flags.go b/cmd/rqlited/flags.go index aedd323b..91b34fd5 100644 --- a/cmd/rqlited/flags.go +++ b/cmd/rqlited/flags.go @@ -9,6 +9,7 @@ import ( "net" "net/url" "os" + "path/filepath" "runtime" "strings" "time" @@ -233,6 +234,12 @@ func (c *Config) Validate() error { return errors.New("-on-disk-path is set, but -on-disk is not") } + dataPath, err := filepath.Abs(c.DataPath) + if err != nil { + return fmt.Errorf("failed to determine absolute data path: %s", err.Error()) + } + c.DataPath = dataPath + if !bothUnsetSet(c.HTTPx509Cert, c.HTTPx509Key) { return fmt.Errorf("either both -%s and -%s must be set, or neither", HTTPx509CertFlag, HTTPx509KeyFlag) } diff --git a/cmd/rqlited/main.go b/cmd/rqlited/main.go index b44822e6..7d4abbeb 100644 --- a/cmd/rqlited/main.go +++ b/cmd/rqlited/main.go @@ -9,7 +9,6 @@ import ( "net" "os" "os/signal" - "path/filepath" "runtime" "strings" "syscall" @@ -210,10 +209,6 @@ func startAutoBackups(ctx context.Context, cfg *Config, str *store.Store) (*uplo } func createStore(cfg *Config, ln *tcp.Layer) (*store.Store, error) { - dataPath, err := filepath.Abs(cfg.DataPath) - if err != nil { - return nil, fmt.Errorf("failed to determine absolute data path: %s", err.Error()) - } dbConf := store.NewDBConfig(!cfg.OnDisk) dbConf.OnDiskPath = cfg.OnDiskPath dbConf.FKConstraints = cfg.FKConstraints @@ -240,11 +235,10 @@ func createStore(cfg *Config, ln *tcp.Layer) (*store.Store, error) { str.ReapTimeout = cfg.RaftReapNodeTimeout str.ReapReadOnlyTimeout = cfg.RaftReapReadOnlyNodeTimeout - isNew := store.IsNewNode(dataPath) - if isNew { - log.Printf("no preexisting node state detected in %s, node may be bootstrapping", dataPath) + if store.IsNewNode(cfg.DataPath) { + log.Printf("no preexisting node state detected in %s, node may be bootstrapping", cfg.DataPath) } else { - log.Printf("preexisting node state detected in %s", dataPath) + log.Printf("preexisting node state detected in %s", cfg.DataPath) } return str, nil @@ -353,17 +347,7 @@ func credentialStore(cfg *Config) (*auth.CredentialsStore, error) { if cfg.AuthFile == "" { return nil, nil } - - f, err := os.Open(cfg.AuthFile) - if err != nil { - return nil, fmt.Errorf("failed to open authentication file %s: %s", cfg.AuthFile, err.Error()) - } - - cs := auth.NewCredentialsStore() - if err := cs.Load(f); err != nil { - return nil, err - } - return cs, nil + return auth.NewCredentialsStoreFromFile(cfg.AuthFile) } func createJoiner(cfg *Config, credStr *auth.CredentialsStore) (*cluster.Joiner, error) { @@ -386,7 +370,6 @@ func clusterService(cfg *Config, tn cluster.Transport, db cluster.Database, mgr c := cluster.New(tn, db, mgr, credStr) c.SetAPIAddr(cfg.HTTPAdv) c.EnableHTTPS(cfg.HTTPx509Cert != "" && cfg.HTTPx509Key != "") // Conditions met for an HTTPS API - if err := c.Open(); err != nil { return nil, err } @@ -400,13 +383,13 @@ func createClusterClient(cfg *Config, clstr *cluster.Service) (*cluster.Client, dialerTLSConfig, err = rtls.CreateClientConfig(cfg.NodeX509Cert, cfg.NodeX509Key, cfg.NodeX509CACert, cfg.NoNodeVerify, cfg.TLS1011) if err != nil { - log.Fatalf("failed to create TLS config for cluster dialer: %s", err.Error()) + return nil, fmt.Errorf("failed to create TLS config for cluster dialer: %s", err.Error()) } } clstrDialer := tcp.NewDialer(cluster.MuxClusterHeader, dialerTLSConfig) clstrClient := cluster.NewClient(clstrDialer, cfg.ClusterConnectTimeout) if err := clstrClient.SetLocal(cfg.RaftAdv, clstr); err != nil { - log.Fatalf("failed to set cluster client local parameters: %s", err.Error()) + return nil, fmt.Errorf("failed to set cluster client local parameters: %s", err.Error()) } return clstrClient, nil }