From c4d51ac8e7146e9a8eb30a95810640d9ea677f27 Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Sun, 24 Sep 2023 10:15:32 +0000 Subject: [PATCH] Ensure full auth config is read --- server/src/engine/config.rs | 260 +++++++++++++++++++++------------ server/src/engine/tests/mod.rs | 70 ++++++++- 2 files changed, 226 insertions(+), 104 deletions(-) diff --git a/server/src/engine/config.rs b/server/src/engine/config.rs index 75c0bdba..b6fd2606 100644 --- a/server/src/engine/config.rs +++ b/server/src/engine/config.rs @@ -27,10 +27,7 @@ use { crate::util::os::SysIOError, core::fmt, - serde::{ - de::{self, Deserializer, Visitor}, - Deserialize, - }, + serde::Deserialize, std::{collections::HashMap, fs}, }; @@ -39,6 +36,7 @@ use { */ pub type ParsedRawArgs = std::collections::HashMap>; +pub const ROOT_PASSWORD_MIN_LEN: usize = 16; #[derive(Debug, PartialEq)] pub struct ModifyGuard { @@ -85,14 +83,21 @@ pub struct Configuration { endpoints: ConfigEndpoint, mode: ConfigMode, system: ConfigSystem, + auth: Option, } impl Configuration { - pub fn new(endpoints: ConfigEndpoint, mode: ConfigMode, system: ConfigSystem) -> Self { + pub fn new( + endpoints: ConfigEndpoint, + mode: ConfigMode, + system: ConfigSystem, + auth: Option, + ) -> Self { Self { endpoints, mode, system, + auth, } } const DEFAULT_HOST: &'static str = "127.0.0.1"; @@ -107,8 +112,8 @@ impl Configuration { mode: ConfigMode::Dev, system: ConfigSystem { reliability_system_window: Self::DEFAULT_RELIABILITY_SVC_PING, - auth: false, }, + auth: None, } } } @@ -158,44 +163,17 @@ impl ConfigEndpointTls { config mode */ -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Deserialize)] /// The configuration mode pub enum ConfigMode { /// In [`ConfigMode::Dev`] we're allowed to be more relaxed with settings + #[serde(rename = "dev")] Dev, /// In [`ConfigMode::Prod`] we're more stringent with settings + #[serde(rename = "prod")] Prod, } -impl<'de> serde::Deserialize<'de> for ConfigMode { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct StringVisitor; - impl<'de> Visitor<'de> for StringVisitor { - type Value = ConfigMode; - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a string 'dev' or 'prod'") - } - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - match value { - "dev" => Ok(ConfigMode::Dev), - "prod" => Ok(ConfigMode::Prod), - _ => Err(de::Error::custom(format!( - "expected 'dev' or 'prod', got {}", - value - ))), - } - } - } - deserializer.deserialize_str(StringVisitor) - } -} - /* config system */ @@ -205,19 +183,38 @@ impl<'de> serde::Deserialize<'de> for ConfigMode { pub struct ConfigSystem { /// time window in seconds for the reliability system to kick-in automatically reliability_system_window: u64, - /// if or not auth is enabled - auth: bool, } impl ConfigSystem { - pub fn new(reliability_system_window: u64, auth: bool) -> Self { + pub fn new(reliability_system_window: u64) -> Self { Self { reliability_system_window, - auth, } } } +/* + config auth +*/ + +#[derive(Debug, PartialEq, Deserialize)] +pub enum AuthDriver { + #[serde(rename = "pwd")] + Pwd, +} + +#[derive(Debug, PartialEq, Deserialize)] +pub struct ConfigAuth { + plugin: AuthDriver, + root_key: String, +} + +impl ConfigAuth { + pub fn new(plugin: AuthDriver, root_key: String) -> Self { + Self { plugin, root_key } + } +} + /** decoded configuration --- @@ -227,6 +224,7 @@ impl ConfigSystem { pub struct DecodedConfiguration { system: Option, endpoints: Option, + auth: Option, } impl Default for DecodedConfiguration { @@ -234,14 +232,20 @@ impl Default for DecodedConfiguration { Self { system: Default::default(), endpoints: Default::default(), + auth: None, } } } +#[derive(Debug, PartialEq, Deserialize)] +pub struct DecodedAuth { + plugin: AuthDriver, + root_pass: String, +} + #[derive(Debug, PartialEq, Deserialize)] /// Decoded system configuration pub struct DecodedSystemConfig { - auth_enabled: Option, mode: Option, rs_window: Option, } @@ -259,7 +263,7 @@ pub struct DecodedEPSecureConfig { host: String, port: u16, cert: String, - pass: String, + private_key: String, } #[derive(Debug, PartialEq, Deserialize)] @@ -375,9 +379,10 @@ impl From for ConfigError { /// A configuration source implementation pub(super) trait ConfigurationSource { + const KEY_AUTH_DRIVER: &'static str; + const KEY_AUTH_ROOT_PASSWORD: &'static str; const KEY_TLS_CERT: &'static str; const KEY_TLS_KEY: &'static str; - const KEY_AUTH: &'static str; const KEY_ENDPOINTS: &'static str; const KEY_RUN_MODE: &'static str; const KEY_SERVICE_WINDOW: &'static str; @@ -467,7 +472,7 @@ fn decode_tls_ep( host: host.into(), port, cert: tls_cert, - pass: tls_key, + private_key: tls_key, }) } @@ -501,25 +506,33 @@ fn arg_decode_tls_endpoint( decode options */ -/// Check the auth mode. We currently only allow `pwd` fn arg_decode_auth( - args: &[String], + src_args: &mut ParsedRawArgs, config: &mut ModifyGuard, ) -> ConfigResult<()> { - argck_duplicate_values::(&args, CS::KEY_AUTH)?; - match args[0].as_str() { - "pwd" => match config.system.as_mut() { - Some(cfg) => cfg.auth_enabled = Some(true), - _ => { - config.system = Some(DecodedSystemConfig { - auth_enabled: Some(true), - mode: None, - rs_window: None, - }) - } - }, - _ => return Err(CS::err_invalid_value_for(CS::KEY_AUTH)), - } + let (Some(auth_driver), Some(mut root_key)) = ( + src_args.remove(CS::KEY_AUTH_DRIVER), + src_args.remove(CS::KEY_AUTH_ROOT_PASSWORD), + ) else { + return Err(ConfigError::with_src( + CS::SOURCE, + ConfigErrorKind::ErrorString(format!( + "to enable auth, you must provide values for both {} and {}", + CS::KEY_AUTH_DRIVER, + CS::KEY_AUTH_ROOT_PASSWORD + )), + )); + }; + argck_duplicate_values::(&auth_driver, CS::KEY_AUTH_DRIVER)?; + argck_duplicate_values::(&root_key, CS::KEY_AUTH_DRIVER)?; + let auth_plugin = match auth_driver[0].as_str() { + "pwd" => AuthDriver::Pwd, + _ => return Err(CS::err_invalid_value_for(CS::KEY_AUTH_DRIVER)), + }; + config.auth = Some(DecodedAuth { + plugin: auth_plugin, + root_pass: root_key.remove(0), + }); Ok(()) } @@ -576,7 +589,6 @@ fn arg_decode_mode( Some(s) => s.mode = Some(mode), None => { config.system = Some(DecodedSystemConfig { - auth_enabled: None, mode: Some(mode), rs_window: None, }) @@ -596,7 +608,6 @@ fn arg_decode_rs_window( Some(sys) => sys.rs_window = Some(n), None => { config.system = Some(DecodedSystemConfig { - auth_enabled: None, mode: None, rs_window: Some(n), }) @@ -612,7 +623,7 @@ fn arg_decode_rs_window( */ /// CLI help message -pub(super) const CLI_HELP: &str ="\ +pub(super) const CLI_HELP: &str = "\ Usage: skyd [OPTION]... skyd is the Skytable database server daemon and can be used to serve database requests. @@ -622,20 +633,23 @@ Flags: -v, --version Display the version number and exit. Options: - --tlscert Specify the path to the TLS certificate. - --tlskey Define the path to the TLS private key. - --endpoint Designate an endpoint. Format: protocol@host:port. - This option can be repeated to define multiple endpoints. - --service-window Establish the time window for the background service in seconds. - --auth Identify the authentication plugin by name. - --mode Set the operational mode. Note: This option is mandatory. + --tlscert Specify the path to the TLS certificate. + --tlskey Define the path to the TLS private key. + --endpoint Designate an endpoint. Format: protocol@host:port. + This option can be repeated to define multiple endpoints. + --service-window Establish the time window for the background service in seconds. + --auth Identify the authentication plugin by name. + --mode Set the operational mode. Note: This option is mandatory. + --auth-plugin Set the auth plugin. `pwd` is a supported option + --auth-root-password Set the root password Examples: - skyd --mode=dev --endpoint=tcp@127.0.0.1:2003 + skyd --mode=dev --endpoint tcp@127.0.0.1:2003 Notes: - Ensure the 'mode' is always provided, as it is essential for the application's correct functioning. - When either of `--help` or `--version` is provided, all other options and flags are ignored. + - When no mode is provided, `--mode=dev` is defaulted to + - When either of `-h` or `-v` is provided, all other options and flags are ignored. + - When `--auth-plugin` is provided, you must provide a value for `--auth-root-password` For further assistance, refer to the official documentation here: https://docs.skytable.org "; @@ -733,8 +747,9 @@ pub fn parse_cli_args<'a, T: 'a + AsRef>( /// Parse environment variables pub fn parse_env_args() -> ConfigResult> { - const KEYS: [&str; 6] = [ - CSEnvArgs::KEY_AUTH, + const KEYS: [&str; 7] = [ + CSEnvArgs::KEY_AUTH_DRIVER, + CSEnvArgs::KEY_AUTH_ROOT_PASSWORD, CSEnvArgs::KEY_ENDPOINTS, CSEnvArgs::KEY_RUN_MODE, CSEnvArgs::KEY_SERVICE_WINDOW, @@ -743,7 +758,7 @@ pub fn parse_env_args() -> ConfigResult> { ]; let mut ret = HashMap::new(); for key in KEYS { - let var = match get_var(key) { + let var = match get_var_from_store(key) { Ok(v) => v, Err(e) => match e { std::env::VarError::NotPresent => continue, @@ -785,8 +800,7 @@ fn apply_config_changes( } let decode_tasks = [ // auth - DecodeKind::Simple { - key: CS::KEY_AUTH, + DecodeKind::Complex { f: arg_decode_auth::, }, // mode @@ -835,9 +849,10 @@ impl CSCommandLine { const ARG_CONFIG_FILE: &'static str = "--config"; } impl ConfigurationSource for CSCommandLine { + const KEY_AUTH_DRIVER: &'static str = "--auth-plugin"; + const KEY_AUTH_ROOT_PASSWORD: &'static str = "--auth-root-password"; const KEY_TLS_CERT: &'static str = "--tlscert"; const KEY_TLS_KEY: &'static str = "--tlskey"; - const KEY_AUTH: &'static str = "--auth"; const KEY_ENDPOINTS: &'static str = "--endpoint"; const KEY_RUN_MODE: &'static str = "--mode"; const KEY_SERVICE_WINDOW: &'static str = "--service-window"; @@ -846,9 +861,10 @@ impl ConfigurationSource for CSCommandLine { pub struct CSEnvArgs; impl ConfigurationSource for CSEnvArgs { + const KEY_AUTH_DRIVER: &'static str = "SKYDB_AUTH_PLUGIN"; + const KEY_AUTH_ROOT_PASSWORD: &'static str = "SKYDB_AUTH_ROOT_PASSWORD"; const KEY_TLS_CERT: &'static str = "SKYDB_TLS_CERT"; const KEY_TLS_KEY: &'static str = "SKYDB_TLS_KEY"; - const KEY_AUTH: &'static str = "SKYDB_AUTH"; const KEY_ENDPOINTS: &'static str = "SKYDB_ENDPOINTS"; const KEY_RUN_MODE: &'static str = "SKYDB_RUN_MODE"; const KEY_SERVICE_WINDOW: &'static str = "SKYDB_SERVICE_WINDOW"; @@ -857,9 +873,10 @@ impl ConfigurationSource for CSEnvArgs { pub struct CSConfigFile; impl ConfigurationSource for CSConfigFile { + const KEY_AUTH_DRIVER: &'static str = "auth.plugin"; + const KEY_AUTH_ROOT_PASSWORD: &'static str = "auth.root_password"; const KEY_TLS_CERT: &'static str = "endpoints.secure.cert"; const KEY_TLS_KEY: &'static str = "endpoints.secure.key"; - const KEY_AUTH: &'static str = "system.auth"; const KEY_ENDPOINTS: &'static str = "endpoints"; const KEY_RUN_MODE: &'static str = "system.mode"; const KEY_SERVICE_WINDOW: &'static str = "system.service_window"; @@ -886,14 +903,17 @@ macro_rules! err_if { /// Validate the configuration, and prepare the final configuration fn validate_configuration( - DecodedConfiguration { system, endpoints }: DecodedConfiguration, + DecodedConfiguration { + system, + endpoints, + auth, + }: DecodedConfiguration, ) -> ConfigResult { // initialize our default configuration let mut config = Configuration::default_dev_mode(); // mutate if_some!( system => |system: DecodedSystemConfig| { - if_some!(system.auth_enabled => |auth| config.system.auth = auth); if_some!(system.mode => |mode| config.mode = mode); if_some!(system.rs_window => |window| config.system.reliability_system_window = window); } @@ -911,7 +931,7 @@ fn validate_configuration( port: secure.port }, cert: secure.cert, - private_key: secure.pass + private_key: secure.private_key }; match &config.endpoints { ConfigEndpoint::Insecure(is) => if has_insecure { @@ -926,6 +946,20 @@ fn validate_configuration( }) } ); + if let Some(auth) = auth { + if auth.root_pass.len() < ROOT_PASSWORD_MIN_LEN { + return Err(ConfigError::with_src( + CS::SOURCE, + ConfigErrorKind::ErrorString(format!( + "root password must have atleast {ROOT_PASSWORD_MIN_LEN} characters" + )), + )); + } + config.auth = Some(ConfigAuth { + plugin: auth.plugin, + root_key: auth.root_pass, + }); + } // now check a few things err_if!( if config.system.reliability_system_window == 0 => ConfigError::with_src( @@ -982,6 +1016,7 @@ pub(super) fn apply_and_validate( thread_local! { static CLI_SRC: std::cell::RefCell>> = std::cell::RefCell::new(None); static ENV_SRC: std::cell::RefCell>> = std::cell::RefCell::new(None); + static FILE_SRC: std::cell::RefCell> = std::cell::RefCell::new(None); } #[cfg(test)] pub(super) fn set_cli_src(cli: Vec) { @@ -1002,7 +1037,26 @@ pub(super) fn set_env_src(variables: Vec) { *env.borrow_mut() = Some(variables); }) } -fn get_var(name: &str) -> Result { +#[cfg(test)] +pub(super) fn set_file_src(src: &str) { + FILE_SRC.with(|s| { + s.borrow_mut().replace(src.to_string()); + }) +} +fn get_file_from_store(filename: &str) -> ConfigResult { + let _f = filename; + let f; + #[cfg(test)] + { + f = Ok(FILE_SRC.with(|f| f.borrow().clone().unwrap())); + } + #[cfg(not(test))] + { + f = Ok(fs::read_to_string(filename)?); + } + f +} +fn get_var_from_store(name: &str) -> Result { let var; #[cfg(test)] { @@ -1025,7 +1079,7 @@ fn get_var(name: &str) -> Result { } var } -fn get_cli_src() -> Vec { +fn get_cli_from_store() -> Vec { let src; #[cfg(test)] { @@ -1048,7 +1102,7 @@ pub fn check_configuration() -> ConfigResult { // read in our environment variables let env_args = parse_env_args()?; // read in our CLI args (since that can tell us whether we need a configuration file) - let read_cli_args = parse_cli_args(get_cli_src().into_iter())?; + let read_cli_args = parse_cli_args(get_cli_from_store().into_iter())?; let cli_args = match read_cli_args { CLIConfigParseReturn::Default => { // no options were provided in the CLI @@ -1107,15 +1161,29 @@ fn check_config_file( // yes, we only have the config file argck_duplicate_values::(&cfg_file, CSCommandLine::ARG_CONFIG_FILE)?; // read the config file - let file = fs::read_to_string(&cfg_file[0])?; - let config_from_file: DecodedConfiguration = serde_yaml::from_str(&file).map_err(|e| { - ConfigError::with_src( - ConfigSource::File, - ConfigErrorKind::ErrorString(format!( - "failed to parse YAML config file with error: `{e}`" - )), - ) - })?; + let file = get_file_from_store(&cfg_file[0])?; + let mut config_from_file: DecodedConfiguration = + serde_yaml::from_str(&file).map_err(|e| { + ConfigError::with_src( + ConfigSource::File, + ConfigErrorKind::ErrorString(format!( + "failed to parse YAML config file with error: `{e}`" + )), + ) + })?; + // read in the TLS certs (if any) + match config_from_file.endpoints.as_mut() { + Some(ep) => match ep.secure.as_mut() { + Some(secure_ep) => { + let cert = fs::read_to_string(&secure_ep.cert)?; + let private_key = fs::read_to_string(&secure_ep.private_key)?; + secure_ep.cert = cert; + secure_ep.private_key = private_key; + } + None => {} + }, + None => {} + } // done here return validate_configuration::(config_from_file).map(ConfigReturn::Config); } else { diff --git a/server/src/engine/tests/mod.rs b/server/src/engine/tests/mod.rs index 96c89d9b..5e421df3 100644 --- a/server/src/engine/tests/mod.rs +++ b/server/src/engine/tests/mod.rs @@ -27,8 +27,9 @@ mod cfg { use crate::{ engine::config::{ - self, CLIConfigParseReturn, ConfigEndpoint, ConfigEndpointTcp, ConfigEndpointTls, - ConfigMode, ConfigReturn, ConfigSystem, Configuration, ParsedRawArgs, + self, AuthDriver, CLIConfigParseReturn, ConfigAuth, ConfigEndpoint, ConfigEndpointTcp, + ConfigEndpointTls, ConfigMode, ConfigReturn, ConfigSystem, Configuration, + ParsedRawArgs, }, util::test_utils::with_files, }; @@ -98,7 +99,9 @@ mod cfg { --service-window=600 \ --tlskey {pkey} \ --tlscert {cert} \ - --auth pwd" + --auth-plugin pwd \ + --auth-root-password password12345678 + " ); let cfg = extract_cli_args(&payload); let ret = config::apply_and_validate::(cfg) @@ -116,7 +119,8 @@ mod cfg { ) ), ConfigMode::Dev, - ConfigSystem::new(600, true) + ConfigSystem::new(600), + Some(ConfigAuth::new(AuthDriver::Pwd, "password12345678".into())) ) ) }, @@ -171,7 +175,8 @@ mod cfg { let variables = [ format!("SKYDB_TLS_CERT=/var/skytable/keys/cert.pem"), format!("SKYDB_TLS_KEY=/var/skytable/keys/private.key"), - format!("SKYDB_AUTH=pwd"), + format!("SKYDB_AUTH_PLUGIN=pwd"), + format!("SKYDB_AUTH_ROOT_PASSWORD=password12345678"), format!("SKYDB_ENDPOINTS=tcp@localhost:8080"), format!("SKYDB_RUN_MODE=dev"), format!("SKYDB_SERVICE_WINDOW=600"), @@ -186,7 +191,8 @@ mod cfg { let variables = [ format!("SKYDB_TLS_CERT=/var/skytable/keys/cert.pem"), format!("SKYDB_TLS_KEY=/var/skytable/keys/private.key"), - format!("SKYDB_AUTH=pwd"), + format!("SKYDB_AUTH_PLUGIN=pwd"), + format!("SKYDB_AUTH_ROOT_PASSWORD=password12345678"), format!("SKYDB_ENDPOINTS=tcp@localhost:8080,tls@localhost:8081"), format!("SKYDB_RUN_MODE=dev"), format!("SKYDB_SERVICE_WINDOW=600"), @@ -202,9 +208,10 @@ mod cfg { ["__env_args_test_cert.pem", "__env_args_test_private.key"], |[cert, key]| { let variables = [ + format!("SKYDB_AUTH_PLUGIN=pwd"), + format!("SKYDB_AUTH_ROOT_PASSWORD=password12345678"), format!("SKYDB_TLS_CERT={cert}"), format!("SKYDB_TLS_KEY={key}"), - format!("SKYDB_AUTH=pwd"), format!("SKYDB_ENDPOINTS=tcp@localhost:8080,tls@localhost:8081"), format!("SKYDB_RUN_MODE=dev"), format!("SKYDB_SERVICE_WINDOW=600"), @@ -223,10 +230,57 @@ mod cfg { ) ), ConfigMode::Dev, - ConfigSystem::new(600, true) + ConfigSystem::new(600), + Some(ConfigAuth::new(AuthDriver::Pwd, "password12345678".into())) ) ) }, ); } + const CONFIG_FILE: &str = "\ +system: + mode: dev + rs_window: 600 + +auth: + plugin: pwd + root_pass: password12345678 + +endpoints: + secure: + host: 127.0.0.1 + port: 2004 + cert: ._test_sample_cert.pem + private_key: ._test_sample_private.key + insecure: + host: 127.0.0.1 + port: 2003 + "; + #[test] + fn test_config_file() { + with_files( + ["._test_sample_cert.pem", "._test_sample_private.key"], + |_| { + config::set_cli_src(vec!["skyd".into(), "--config=config.yml".into()]); + config::set_file_src(CONFIG_FILE); + let cfg = config::check_configuration().unwrap().into_config(); + assert_eq!( + cfg, + Configuration::new( + ConfigEndpoint::Multi( + ConfigEndpointTcp::new("127.0.0.1".into(), 2003), + ConfigEndpointTls::new( + ConfigEndpointTcp::new("127.0.0.1".into(), 2004), + "".into(), + "".into() + ) + ), + ConfigMode::Dev, + ConfigSystem::new(600), + Some(ConfigAuth::new(AuthDriver::Pwd, "password12345678".into())) + ) + ) + }, + ) + } }