Ensure full auth config is read

next
Sayan Nandan 1 year ago
parent c35e35b9c8
commit c4d51ac8e7
No known key found for this signature in database
GPG Key ID: 42EEDF4AE9D96B54

@ -27,10 +27,7 @@
use {
crate::util::os::SysIOError,
core::fmt,
serde::{
de::{self, Deserializer, Visitor},
Deserialize,
},
serde::Deserialize,
std::{collections::HashMap, fs},
};
@ -39,6 +36,7 @@ use {
*/
pub type ParsedRawArgs = std::collections::HashMap<String, Vec<String>>;
pub const ROOT_PASSWORD_MIN_LEN: usize = 16;
#[derive(Debug, PartialEq)]
pub struct ModifyGuard<T> {
@ -85,14 +83,21 @@ pub struct Configuration {
endpoints: ConfigEndpoint,
mode: ConfigMode,
system: ConfigSystem,
auth: Option<ConfigAuth>,
}
impl Configuration {
pub fn new(endpoints: ConfigEndpoint, mode: ConfigMode, system: ConfigSystem) -> Self {
pub fn new(
endpoints: ConfigEndpoint,
mode: ConfigMode,
system: ConfigSystem,
auth: Option<ConfigAuth>,
) -> Self {
Self {
endpoints,
mode,
system,
auth,
}
}
const DEFAULT_HOST: &'static str = "127.0.0.1";
@ -107,8 +112,8 @@ impl Configuration {
mode: ConfigMode::Dev,
system: ConfigSystem {
reliability_system_window: Self::DEFAULT_RELIABILITY_SVC_PING,
auth: false,
},
auth: None,
}
}
}
@ -158,44 +163,17 @@ impl ConfigEndpointTls {
config mode
*/
#[derive(Debug, PartialEq)]
#[derive(Debug, PartialEq, Deserialize)]
/// The configuration mode
pub enum ConfigMode {
/// In [`ConfigMode::Dev`] we're allowed to be more relaxed with settings
#[serde(rename = "dev")]
Dev,
/// In [`ConfigMode::Prod`] we're more stringent with settings
#[serde(rename = "prod")]
Prod,
}
impl<'de> serde::Deserialize<'de> for ConfigMode {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
struct StringVisitor;
impl<'de> Visitor<'de> for StringVisitor {
type Value = ConfigMode;
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
formatter.write_str("a string 'dev' or 'prod'")
}
fn visit_str<E>(self, value: &str) -> Result<ConfigMode, E>
where
E: de::Error,
{
match value {
"dev" => Ok(ConfigMode::Dev),
"prod" => Ok(ConfigMode::Prod),
_ => Err(de::Error::custom(format!(
"expected 'dev' or 'prod', got {}",
value
))),
}
}
}
deserializer.deserialize_str(StringVisitor)
}
}
/*
config system
*/
@ -205,19 +183,38 @@ impl<'de> serde::Deserialize<'de> for ConfigMode {
pub struct ConfigSystem {
/// time window in seconds for the reliability system to kick-in automatically
reliability_system_window: u64,
/// if or not auth is enabled
auth: bool,
}
impl ConfigSystem {
pub fn new(reliability_system_window: u64, auth: bool) -> Self {
pub fn new(reliability_system_window: u64) -> Self {
Self {
reliability_system_window,
auth,
}
}
}
/*
config auth
*/
#[derive(Debug, PartialEq, Deserialize)]
pub enum AuthDriver {
#[serde(rename = "pwd")]
Pwd,
}
#[derive(Debug, PartialEq, Deserialize)]
pub struct ConfigAuth {
plugin: AuthDriver,
root_key: String,
}
impl ConfigAuth {
pub fn new(plugin: AuthDriver, root_key: String) -> Self {
Self { plugin, root_key }
}
}
/**
decoded configuration
---
@ -227,6 +224,7 @@ impl ConfigSystem {
pub struct DecodedConfiguration {
system: Option<DecodedSystemConfig>,
endpoints: Option<DecodedEPConfig>,
auth: Option<DecodedAuth>,
}
impl Default for DecodedConfiguration {
@ -234,14 +232,20 @@ impl Default for DecodedConfiguration {
Self {
system: Default::default(),
endpoints: Default::default(),
auth: None,
}
}
}
#[derive(Debug, PartialEq, Deserialize)]
pub struct DecodedAuth {
plugin: AuthDriver,
root_pass: String,
}
#[derive(Debug, PartialEq, Deserialize)]
/// Decoded system configuration
pub struct DecodedSystemConfig {
auth_enabled: Option<bool>,
mode: Option<ConfigMode>,
rs_window: Option<u64>,
}
@ -259,7 +263,7 @@ pub struct DecodedEPSecureConfig {
host: String,
port: u16,
cert: String,
pass: String,
private_key: String,
}
#[derive(Debug, PartialEq, Deserialize)]
@ -375,9 +379,10 @@ impl From<std::io::Error> for ConfigError {
/// A configuration source implementation
pub(super) trait ConfigurationSource {
const KEY_AUTH_DRIVER: &'static str;
const KEY_AUTH_ROOT_PASSWORD: &'static str;
const KEY_TLS_CERT: &'static str;
const KEY_TLS_KEY: &'static str;
const KEY_AUTH: &'static str;
const KEY_ENDPOINTS: &'static str;
const KEY_RUN_MODE: &'static str;
const KEY_SERVICE_WINDOW: &'static str;
@ -467,7 +472,7 @@ fn decode_tls_ep(
host: host.into(),
port,
cert: tls_cert,
pass: tls_key,
private_key: tls_key,
})
}
@ -501,25 +506,33 @@ fn arg_decode_tls_endpoint<CS: ConfigurationSource>(
decode options
*/
/// Check the auth mode. We currently only allow `pwd`
fn arg_decode_auth<CS: ConfigurationSource>(
args: &[String],
src_args: &mut ParsedRawArgs,
config: &mut ModifyGuard<DecodedConfiguration>,
) -> ConfigResult<()> {
argck_duplicate_values::<CS>(&args, CS::KEY_AUTH)?;
match args[0].as_str() {
"pwd" => match config.system.as_mut() {
Some(cfg) => cfg.auth_enabled = Some(true),
_ => {
config.system = Some(DecodedSystemConfig {
auth_enabled: Some(true),
mode: None,
rs_window: None,
})
}
},
_ => return Err(CS::err_invalid_value_for(CS::KEY_AUTH)),
}
let (Some(auth_driver), Some(mut root_key)) = (
src_args.remove(CS::KEY_AUTH_DRIVER),
src_args.remove(CS::KEY_AUTH_ROOT_PASSWORD),
) else {
return Err(ConfigError::with_src(
CS::SOURCE,
ConfigErrorKind::ErrorString(format!(
"to enable auth, you must provide values for both {} and {}",
CS::KEY_AUTH_DRIVER,
CS::KEY_AUTH_ROOT_PASSWORD
)),
));
};
argck_duplicate_values::<CS>(&auth_driver, CS::KEY_AUTH_DRIVER)?;
argck_duplicate_values::<CS>(&root_key, CS::KEY_AUTH_DRIVER)?;
let auth_plugin = match auth_driver[0].as_str() {
"pwd" => AuthDriver::Pwd,
_ => return Err(CS::err_invalid_value_for(CS::KEY_AUTH_DRIVER)),
};
config.auth = Some(DecodedAuth {
plugin: auth_plugin,
root_pass: root_key.remove(0),
});
Ok(())
}
@ -576,7 +589,6 @@ fn arg_decode_mode<CS: ConfigurationSource>(
Some(s) => s.mode = Some(mode),
None => {
config.system = Some(DecodedSystemConfig {
auth_enabled: None,
mode: Some(mode),
rs_window: None,
})
@ -596,7 +608,6 @@ fn arg_decode_rs_window<CS: ConfigurationSource>(
Some(sys) => sys.rs_window = Some(n),
None => {
config.system = Some(DecodedSystemConfig {
auth_enabled: None,
mode: None,
rs_window: Some(n),
})
@ -612,7 +623,7 @@ fn arg_decode_rs_window<CS: ConfigurationSource>(
*/
/// CLI help message
pub(super) const CLI_HELP: &str ="\
pub(super) const CLI_HELP: &str = "\
Usage: skyd [OPTION]...
skyd is the Skytable database server daemon and can be used to serve database requests.
@ -622,20 +633,23 @@ Flags:
-v, --version Display the version number and exit.
Options:
--tlscert <path> Specify the path to the TLS certificate.
--tlskey <path> Define the path to the TLS private key.
--endpoint <definition> Designate an endpoint. Format: protocol@host:port.
This option can be repeated to define multiple endpoints.
--service-window <seconds> Establish the time window for the background service in seconds.
--auth <plugin_name> Identify the authentication plugin by name.
--mode <dev/prod> Set the operational mode. Note: This option is mandatory.
--tlscert <path> Specify the path to the TLS certificate.
--tlskey <path> Define the path to the TLS private key.
--endpoint <definition> Designate an endpoint. Format: protocol@host:port.
This option can be repeated to define multiple endpoints.
--service-window <seconds> Establish the time window for the background service in seconds.
--auth <plugin_name> Identify the authentication plugin by name.
--mode <dev/prod> Set the operational mode. Note: This option is mandatory.
--auth-plugin <plugin> Set the auth plugin. `pwd` is a supported option
--auth-root-password <pass> Set the root password
Examples:
skyd --mode=dev --endpoint=tcp@127.0.0.1:2003
skyd --mode=dev --endpoint tcp@127.0.0.1:2003
Notes:
Ensure the 'mode' is always provided, as it is essential for the application's correct functioning.
When either of `--help` or `--version` is provided, all other options and flags are ignored.
- When no mode is provided, `--mode=dev` is defaulted to
- When either of `-h` or `-v` is provided, all other options and flags are ignored.
- When `--auth-plugin` is provided, you must provide a value for `--auth-root-password`
For further assistance, refer to the official documentation here: https://docs.skytable.org
";
@ -733,8 +747,9 @@ pub fn parse_cli_args<'a, T: 'a + AsRef<str>>(
/// Parse environment variables
pub fn parse_env_args() -> ConfigResult<Option<ParsedRawArgs>> {
const KEYS: [&str; 6] = [
CSEnvArgs::KEY_AUTH,
const KEYS: [&str; 7] = [
CSEnvArgs::KEY_AUTH_DRIVER,
CSEnvArgs::KEY_AUTH_ROOT_PASSWORD,
CSEnvArgs::KEY_ENDPOINTS,
CSEnvArgs::KEY_RUN_MODE,
CSEnvArgs::KEY_SERVICE_WINDOW,
@ -743,7 +758,7 @@ pub fn parse_env_args() -> ConfigResult<Option<ParsedRawArgs>> {
];
let mut ret = HashMap::new();
for key in KEYS {
let var = match get_var(key) {
let var = match get_var_from_store(key) {
Ok(v) => v,
Err(e) => match e {
std::env::VarError::NotPresent => continue,
@ -785,8 +800,7 @@ fn apply_config_changes<CS: ConfigurationSource>(
}
let decode_tasks = [
// auth
DecodeKind::Simple {
key: CS::KEY_AUTH,
DecodeKind::Complex {
f: arg_decode_auth::<CS>,
},
// mode
@ -835,9 +849,10 @@ impl CSCommandLine {
const ARG_CONFIG_FILE: &'static str = "--config";
}
impl ConfigurationSource for CSCommandLine {
const KEY_AUTH_DRIVER: &'static str = "--auth-plugin";
const KEY_AUTH_ROOT_PASSWORD: &'static str = "--auth-root-password";
const KEY_TLS_CERT: &'static str = "--tlscert";
const KEY_TLS_KEY: &'static str = "--tlskey";
const KEY_AUTH: &'static str = "--auth";
const KEY_ENDPOINTS: &'static str = "--endpoint";
const KEY_RUN_MODE: &'static str = "--mode";
const KEY_SERVICE_WINDOW: &'static str = "--service-window";
@ -846,9 +861,10 @@ impl ConfigurationSource for CSCommandLine {
pub struct CSEnvArgs;
impl ConfigurationSource for CSEnvArgs {
const KEY_AUTH_DRIVER: &'static str = "SKYDB_AUTH_PLUGIN";
const KEY_AUTH_ROOT_PASSWORD: &'static str = "SKYDB_AUTH_ROOT_PASSWORD";
const KEY_TLS_CERT: &'static str = "SKYDB_TLS_CERT";
const KEY_TLS_KEY: &'static str = "SKYDB_TLS_KEY";
const KEY_AUTH: &'static str = "SKYDB_AUTH";
const KEY_ENDPOINTS: &'static str = "SKYDB_ENDPOINTS";
const KEY_RUN_MODE: &'static str = "SKYDB_RUN_MODE";
const KEY_SERVICE_WINDOW: &'static str = "SKYDB_SERVICE_WINDOW";
@ -857,9 +873,10 @@ impl ConfigurationSource for CSEnvArgs {
pub struct CSConfigFile;
impl ConfigurationSource for CSConfigFile {
const KEY_AUTH_DRIVER: &'static str = "auth.plugin";
const KEY_AUTH_ROOT_PASSWORD: &'static str = "auth.root_password";
const KEY_TLS_CERT: &'static str = "endpoints.secure.cert";
const KEY_TLS_KEY: &'static str = "endpoints.secure.key";
const KEY_AUTH: &'static str = "system.auth";
const KEY_ENDPOINTS: &'static str = "endpoints";
const KEY_RUN_MODE: &'static str = "system.mode";
const KEY_SERVICE_WINDOW: &'static str = "system.service_window";
@ -886,14 +903,17 @@ macro_rules! err_if {
/// Validate the configuration, and prepare the final configuration
fn validate_configuration<CS: ConfigurationSource>(
DecodedConfiguration { system, endpoints }: DecodedConfiguration,
DecodedConfiguration {
system,
endpoints,
auth,
}: DecodedConfiguration,
) -> ConfigResult<Configuration> {
// initialize our default configuration
let mut config = Configuration::default_dev_mode();
// mutate
if_some!(
system => |system: DecodedSystemConfig| {
if_some!(system.auth_enabled => |auth| config.system.auth = auth);
if_some!(system.mode => |mode| config.mode = mode);
if_some!(system.rs_window => |window| config.system.reliability_system_window = window);
}
@ -911,7 +931,7 @@ fn validate_configuration<CS: ConfigurationSource>(
port: secure.port
},
cert: secure.cert,
private_key: secure.pass
private_key: secure.private_key
};
match &config.endpoints {
ConfigEndpoint::Insecure(is) => if has_insecure {
@ -926,6 +946,20 @@ fn validate_configuration<CS: ConfigurationSource>(
})
}
);
if let Some(auth) = auth {
if auth.root_pass.len() < ROOT_PASSWORD_MIN_LEN {
return Err(ConfigError::with_src(
CS::SOURCE,
ConfigErrorKind::ErrorString(format!(
"root password must have atleast {ROOT_PASSWORD_MIN_LEN} characters"
)),
));
}
config.auth = Some(ConfigAuth {
plugin: auth.plugin,
root_key: auth.root_pass,
});
}
// now check a few things
err_if!(
if config.system.reliability_system_window == 0 => ConfigError::with_src(
@ -982,6 +1016,7 @@ pub(super) fn apply_and_validate<CS: ConfigurationSource>(
thread_local! {
static CLI_SRC: std::cell::RefCell<Option<Vec<String>>> = std::cell::RefCell::new(None);
static ENV_SRC: std::cell::RefCell<Option<HashMap<String, String>>> = std::cell::RefCell::new(None);
static FILE_SRC: std::cell::RefCell<Option<String>> = std::cell::RefCell::new(None);
}
#[cfg(test)]
pub(super) fn set_cli_src(cli: Vec<String>) {
@ -1002,7 +1037,26 @@ pub(super) fn set_env_src(variables: Vec<String>) {
*env.borrow_mut() = Some(variables);
})
}
fn get_var(name: &str) -> Result<String, std::env::VarError> {
#[cfg(test)]
pub(super) fn set_file_src(src: &str) {
FILE_SRC.with(|s| {
s.borrow_mut().replace(src.to_string());
})
}
fn get_file_from_store(filename: &str) -> ConfigResult<String> {
let _f = filename;
let f;
#[cfg(test)]
{
f = Ok(FILE_SRC.with(|f| f.borrow().clone().unwrap()));
}
#[cfg(not(test))]
{
f = Ok(fs::read_to_string(filename)?);
}
f
}
fn get_var_from_store(name: &str) -> Result<String, std::env::VarError> {
let var;
#[cfg(test)]
{
@ -1025,7 +1079,7 @@ fn get_var(name: &str) -> Result<String, std::env::VarError> {
}
var
}
fn get_cli_src() -> Vec<String> {
fn get_cli_from_store() -> Vec<String> {
let src;
#[cfg(test)]
{
@ -1048,7 +1102,7 @@ pub fn check_configuration() -> ConfigResult<ConfigReturn> {
// read in our environment variables
let env_args = parse_env_args()?;
// read in our CLI args (since that can tell us whether we need a configuration file)
let read_cli_args = parse_cli_args(get_cli_src().into_iter())?;
let read_cli_args = parse_cli_args(get_cli_from_store().into_iter())?;
let cli_args = match read_cli_args {
CLIConfigParseReturn::Default => {
// no options were provided in the CLI
@ -1107,15 +1161,29 @@ fn check_config_file(
// yes, we only have the config file
argck_duplicate_values::<CSCommandLine>(&cfg_file, CSCommandLine::ARG_CONFIG_FILE)?;
// read the config file
let file = fs::read_to_string(&cfg_file[0])?;
let config_from_file: DecodedConfiguration = serde_yaml::from_str(&file).map_err(|e| {
ConfigError::with_src(
ConfigSource::File,
ConfigErrorKind::ErrorString(format!(
"failed to parse YAML config file with error: `{e}`"
)),
)
})?;
let file = get_file_from_store(&cfg_file[0])?;
let mut config_from_file: DecodedConfiguration =
serde_yaml::from_str(&file).map_err(|e| {
ConfigError::with_src(
ConfigSource::File,
ConfigErrorKind::ErrorString(format!(
"failed to parse YAML config file with error: `{e}`"
)),
)
})?;
// read in the TLS certs (if any)
match config_from_file.endpoints.as_mut() {
Some(ep) => match ep.secure.as_mut() {
Some(secure_ep) => {
let cert = fs::read_to_string(&secure_ep.cert)?;
let private_key = fs::read_to_string(&secure_ep.private_key)?;
secure_ep.cert = cert;
secure_ep.private_key = private_key;
}
None => {}
},
None => {}
}
// done here
return validate_configuration::<CSConfigFile>(config_from_file).map(ConfigReturn::Config);
} else {

@ -27,8 +27,9 @@
mod cfg {
use crate::{
engine::config::{
self, CLIConfigParseReturn, ConfigEndpoint, ConfigEndpointTcp, ConfigEndpointTls,
ConfigMode, ConfigReturn, ConfigSystem, Configuration, ParsedRawArgs,
self, AuthDriver, CLIConfigParseReturn, ConfigAuth, ConfigEndpoint, ConfigEndpointTcp,
ConfigEndpointTls, ConfigMode, ConfigReturn, ConfigSystem, Configuration,
ParsedRawArgs,
},
util::test_utils::with_files,
};
@ -98,7 +99,9 @@ mod cfg {
--service-window=600 \
--tlskey {pkey} \
--tlscert {cert} \
--auth pwd"
--auth-plugin pwd \
--auth-root-password password12345678
"
);
let cfg = extract_cli_args(&payload);
let ret = config::apply_and_validate::<config::CSCommandLine>(cfg)
@ -116,7 +119,8 @@ mod cfg {
)
),
ConfigMode::Dev,
ConfigSystem::new(600, true)
ConfigSystem::new(600),
Some(ConfigAuth::new(AuthDriver::Pwd, "password12345678".into()))
)
)
},
@ -171,7 +175,8 @@ mod cfg {
let variables = [
format!("SKYDB_TLS_CERT=/var/skytable/keys/cert.pem"),
format!("SKYDB_TLS_KEY=/var/skytable/keys/private.key"),
format!("SKYDB_AUTH=pwd"),
format!("SKYDB_AUTH_PLUGIN=pwd"),
format!("SKYDB_AUTH_ROOT_PASSWORD=password12345678"),
format!("SKYDB_ENDPOINTS=tcp@localhost:8080"),
format!("SKYDB_RUN_MODE=dev"),
format!("SKYDB_SERVICE_WINDOW=600"),
@ -186,7 +191,8 @@ mod cfg {
let variables = [
format!("SKYDB_TLS_CERT=/var/skytable/keys/cert.pem"),
format!("SKYDB_TLS_KEY=/var/skytable/keys/private.key"),
format!("SKYDB_AUTH=pwd"),
format!("SKYDB_AUTH_PLUGIN=pwd"),
format!("SKYDB_AUTH_ROOT_PASSWORD=password12345678"),
format!("SKYDB_ENDPOINTS=tcp@localhost:8080,tls@localhost:8081"),
format!("SKYDB_RUN_MODE=dev"),
format!("SKYDB_SERVICE_WINDOW=600"),
@ -202,9 +208,10 @@ mod cfg {
["__env_args_test_cert.pem", "__env_args_test_private.key"],
|[cert, key]| {
let variables = [
format!("SKYDB_AUTH_PLUGIN=pwd"),
format!("SKYDB_AUTH_ROOT_PASSWORD=password12345678"),
format!("SKYDB_TLS_CERT={cert}"),
format!("SKYDB_TLS_KEY={key}"),
format!("SKYDB_AUTH=pwd"),
format!("SKYDB_ENDPOINTS=tcp@localhost:8080,tls@localhost:8081"),
format!("SKYDB_RUN_MODE=dev"),
format!("SKYDB_SERVICE_WINDOW=600"),
@ -223,10 +230,57 @@ mod cfg {
)
),
ConfigMode::Dev,
ConfigSystem::new(600, true)
ConfigSystem::new(600),
Some(ConfigAuth::new(AuthDriver::Pwd, "password12345678".into()))
)
)
},
);
}
const CONFIG_FILE: &str = "\
system:
mode: dev
rs_window: 600
auth:
plugin: pwd
root_pass: password12345678
endpoints:
secure:
host: 127.0.0.1
port: 2004
cert: ._test_sample_cert.pem
private_key: ._test_sample_private.key
insecure:
host: 127.0.0.1
port: 2003
";
#[test]
fn test_config_file() {
with_files(
["._test_sample_cert.pem", "._test_sample_private.key"],
|_| {
config::set_cli_src(vec!["skyd".into(), "--config=config.yml".into()]);
config::set_file_src(CONFIG_FILE);
let cfg = config::check_configuration().unwrap().into_config();
assert_eq!(
cfg,
Configuration::new(
ConfigEndpoint::Multi(
ConfigEndpointTcp::new("127.0.0.1".into(), 2003),
ConfigEndpointTls::new(
ConfigEndpointTcp::new("127.0.0.1".into(), 2004),
"".into(),
"".into()
)
),
ConfigMode::Dev,
ConfigSystem::new(600),
Some(ConfigAuth::new(AuthDriver::Pwd, "password12345678".into()))
)
)
},
)
}
}

Loading…
Cancel
Save