From be540a7ded484741e3b0891bf678f3b172bed9b5 Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Fri, 29 Sep 2023 09:34:22 +0000 Subject: [PATCH] Add net impls Also cleaned up error impls --- server/src/engine/config.rs | 48 +++- server/src/engine/error.rs | 130 +++++++++- server/src/engine/fractal/mgr.rs | 4 +- server/src/engine/fractal/mod.rs | 28 ++- server/src/engine/mod.rs | 6 + server/src/engine/net/mod.rs | 236 +++++++++++++++++- server/src/engine/net/protocol/mod.rs | 4 +- .../engine/storage/v1/batch_jrnl/persist.rs | 12 +- .../engine/storage/v1/batch_jrnl/restore.rs | 32 +-- server/src/engine/storage/v1/inf/map.rs | 25 +- server/src/engine/storage/v1/inf/mod.rs | 10 +- server/src/engine/storage/v1/inf/obj.rs | 8 +- server/src/engine/storage/v1/journal.rs | 18 +- server/src/engine/storage/v1/loader.rs | 15 +- server/src/engine/storage/v1/mod.rs | 80 ++---- server/src/engine/storage/v1/rw.rs | 13 +- server/src/engine/storage/v1/spec.rs | 10 +- server/src/engine/storage/v1/start_stop.rs | 150 ----------- server/src/engine/storage/v1/sysdb.rs | 12 +- server/src/engine/storage/v1/tests/tx.rs | 12 +- server/src/engine/tests/mod.rs | 28 ++- server/src/engine/txn/gns/model.rs | 4 +- 22 files changed, 555 insertions(+), 330 deletions(-) delete mode 100644 server/src/engine/storage/v1/start_stop.rs diff --git a/server/src/engine/config.rs b/server/src/engine/config.rs index 19397155..4c75f009 100644 --- a/server/src/engine/config.rs +++ b/server/src/engine/config.rs @@ -147,14 +147,21 @@ pub struct ConfigEndpointTls { tcp: ConfigEndpointTcp, cert: String, private_key: String, + pkey_pass: String, } impl ConfigEndpointTls { - pub fn new(tcp: ConfigEndpointTcp, cert: String, private_key: String) -> Self { + pub fn new( + tcp: ConfigEndpointTcp, + cert: String, + private_key: String, + pkey_pass: String, + ) -> Self { Self { tcp, cert, private_key, + pkey_pass, } } } @@ -264,6 +271,7 @@ pub struct DecodedEPSecureConfig { port: u16, cert: String, private_key: String, + pkey_passphrase: String, } #[derive(Debug, PartialEq, Deserialize)] @@ -383,6 +391,7 @@ pub(super) trait ConfigurationSource { const KEY_AUTH_ROOT_PASSWORD: &'static str; const KEY_TLS_CERT: &'static str; const KEY_TLS_KEY: &'static str; + const KEY_TLS_PKEY_PASS: &'static str; const KEY_ENDPOINTS: &'static str; const KEY_RUN_MODE: &'static str; const KEY_SERVICE_WINDOW: &'static str; @@ -463,16 +472,19 @@ fn parse_endpoint(source: ConfigSource, s: &str) -> ConfigResult<(ConnectionProt fn decode_tls_ep( cert_path: &str, key_path: &str, + pkey_pass: &str, host: &str, port: u16, ) -> ConfigResult { let tls_key = fs::read_to_string(key_path)?; let tls_cert = fs::read_to_string(cert_path)?; + let tls_priv_key_passphrase = fs::read_to_string(pkey_pass)?; Ok(DecodedEPSecureConfig { host: host.into(), port, cert: tls_cert, private_key: tls_key, + pkey_passphrase: tls_priv_key_passphrase, }) } @@ -484,22 +496,31 @@ fn arg_decode_tls_endpoint( ) -> ConfigResult { let _cert = args.remove(CS::KEY_TLS_CERT); let _key = args.remove(CS::KEY_TLS_KEY); - let (tls_cert, tls_key) = match (_cert, _key) { - (Some(cert), Some(key)) => (cert, key), + let _passphrase = args.remove(CS::KEY_TLS_PKEY_PASS); + let (tls_cert, tls_key, tls_passphrase) = match (_cert, _key, _passphrase) { + (Some(cert), Some(key), Some(pass)) => (cert, key, pass), _ => { return Err(ConfigError::with_src( ConfigSource::Cli, ConfigErrorKind::ErrorString(format!( - "must supply values for both `{}` and `{}` when using TLS", + "must supply values for `{}`, `{}` and `{}` when using TLS", CS::KEY_TLS_CERT, - CS::KEY_TLS_KEY + CS::KEY_TLS_KEY, + CS::KEY_TLS_PKEY_PASS, )), )); } }; argck_duplicate_values::(&tls_cert, CS::KEY_TLS_CERT)?; argck_duplicate_values::(&tls_key, CS::KEY_TLS_KEY)?; - Ok(decode_tls_ep(&tls_cert[0], &tls_key[0], host, port)?) + argck_duplicate_values::(&tls_passphrase, CS::KEY_TLS_PKEY_PASS)?; + Ok(decode_tls_ep( + &tls_cert[0], + &tls_key[0], + &tls_passphrase[0], + host, + port, + )?) } /* @@ -747,7 +768,7 @@ pub fn parse_cli_args<'a, T: 'a + AsRef>( /// Parse environment variables pub fn parse_env_args() -> ConfigResult> { - const KEYS: [&str; 7] = [ + const KEYS: [&str; 8] = [ CSEnvArgs::KEY_AUTH_DRIVER, CSEnvArgs::KEY_AUTH_ROOT_PASSWORD, CSEnvArgs::KEY_ENDPOINTS, @@ -755,6 +776,7 @@ pub fn parse_env_args() -> ConfigResult> { CSEnvArgs::KEY_SERVICE_WINDOW, CSEnvArgs::KEY_TLS_CERT, CSEnvArgs::KEY_TLS_KEY, + CSEnvArgs::KEY_TLS_PKEY_PASS, ]; let mut ret = HashMap::new(); for key in KEYS { @@ -853,6 +875,7 @@ impl ConfigurationSource for CSCommandLine { const KEY_AUTH_ROOT_PASSWORD: &'static str = "--auth-root-password"; const KEY_TLS_CERT: &'static str = "--tlscert"; const KEY_TLS_KEY: &'static str = "--tlskey"; + const KEY_TLS_PKEY_PASS: &'static str = "--tls-passphrase"; const KEY_ENDPOINTS: &'static str = "--endpoint"; const KEY_RUN_MODE: &'static str = "--mode"; const KEY_SERVICE_WINDOW: &'static str = "--service-window"; @@ -865,6 +888,7 @@ impl ConfigurationSource for CSEnvArgs { const KEY_AUTH_ROOT_PASSWORD: &'static str = "SKYDB_AUTH_ROOT_PASSWORD"; const KEY_TLS_CERT: &'static str = "SKYDB_TLS_CERT"; const KEY_TLS_KEY: &'static str = "SKYDB_TLS_KEY"; + const KEY_TLS_PKEY_PASS: &'static str = "SKYDB_TLS_PRIVATE_KEY_PASSWORD"; const KEY_ENDPOINTS: &'static str = "SKYDB_ENDPOINTS"; const KEY_RUN_MODE: &'static str = "SKYDB_RUN_MODE"; const KEY_SERVICE_WINDOW: &'static str = "SKYDB_SERVICE_WINDOW"; @@ -877,6 +901,7 @@ impl ConfigurationSource for CSConfigFile { const KEY_AUTH_ROOT_PASSWORD: &'static str = "auth.root_password"; const KEY_TLS_CERT: &'static str = "endpoints.secure.cert"; const KEY_TLS_KEY: &'static str = "endpoints.secure.key"; + const KEY_TLS_PKEY_PASS: &'static str = "endpoints.secure.pkey_passphrase"; const KEY_ENDPOINTS: &'static str = "endpoints"; const KEY_RUN_MODE: &'static str = "system.mode"; const KEY_SERVICE_WINDOW: &'static str = "system.service_window"; @@ -937,10 +962,11 @@ fn validate_configuration( let secure_ep = ConfigEndpointTls { tcp: ConfigEndpointTcp { host: secure.host, - port: secure.port + port: secure.port, }, cert: secure.cert, - private_key: secure.private_key + private_key: secure.private_key, + pkey_pass: secure.pkey_passphrase, }; match &config.endpoints { ConfigEndpoint::Insecure(is) => if has_insecure { @@ -961,7 +987,7 @@ fn validate_configuration( CS::SOURCE, ConfigErrorKind::ErrorString("invalid value for service window. must be nonzero".into()), ), - if config.auth.root_key.len() <= ROOT_PASSWORD_MIN_LEN => ConfigError::with_src( + if config.auth.root_key.len() < ROOT_PASSWORD_MIN_LEN => ConfigError::with_src( CS::SOURCE, ConfigErrorKind::ErrorString("the root password must have at least 16 characters".into()), ), @@ -1175,8 +1201,10 @@ fn check_config_file( Some(secure_ep) => { let cert = fs::read_to_string(&secure_ep.cert)?; let private_key = fs::read_to_string(&secure_ep.private_key)?; + let private_key_passphrase = fs::read_to_string(&secure_ep.pkey_passphrase)?; secure_ep.cert = cert; secure_ep.private_key = private_key; + secure_ep.pkey_passphrase = private_key_passphrase; } None => {} }, diff --git a/server/src/engine/error.rs b/server/src/engine/error.rs index ecf1c60e..72ebe463 100644 --- a/server/src/engine/error.rs +++ b/server/src/engine/error.rs @@ -24,8 +24,20 @@ * */ -use super::{storage::v1::SDSSError, txn::TransactionError}; +use { + super::{ + storage::v1::{SDSSError, SDSSErrorKind}, + txn::TransactionError, + }, + crate::util::os::SysIOError, + std::fmt, +}; + pub type QueryResult = Result; +// stack +pub type CtxResult = Result>; +pub type RuntimeResult = CtxResult; +pub type RuntimeError = CtxError; /// an enumeration of 'flat' errors that the server actually responds to the client with, since we do not want to send specific information /// about anything (as that will be a security hole). The variants correspond with their actual response codes @@ -102,3 +114,119 @@ direct_from! { TransactionError as TransactionalError, } } + +/* + contextual errors +*/ + +/// An error context +pub enum CtxErrorDescription { + A(&'static str), + B(Box), +} + +impl CtxErrorDescription { + fn inner(&self) -> &str { + match self { + Self::A(a) => a, + Self::B(b) => &b, + } + } +} + +impl PartialEq for CtxErrorDescription { + fn eq(&self, other: &Self) -> bool { + self.inner() == other.inner() + } +} + +impl fmt::Display for CtxErrorDescription { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.inner()) + } +} + +impl fmt::Debug for CtxErrorDescription { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.inner()) + } +} + +direct_from! { + CtxErrorDescription => { + &'static str as A, + String as B, + Box as B, + } +} + +#[derive(Debug)] +#[cfg_attr(test, derive(PartialEq))] +/// A contextual error +pub struct CtxError { + kind: E, + ctx: Option, +} + +impl CtxError { + fn _new(kind: E, ctx: Option) -> Self { + Self { kind, ctx } + } + pub fn new(kind: E) -> Self { + Self::_new(kind, None) + } + pub fn with_ctx(kind: E, ctx: impl Into) -> Self { + Self::_new(kind, Some(ctx.into())) + } + pub fn add_ctx(self, ctx: impl Into) -> Self { + Self::with_ctx(self.kind, ctx) + } + pub fn into_result(self) -> CtxResult { + Err(self) + } + pub fn result(result: Result) -> CtxResult + where + E: From, + { + result.map_err(|e| CtxError::new(e.into())) + } + pub fn result_ctx( + result: Result, + ctx: impl Into, + ) -> CtxResult + where + E: From, + { + result.map_err(|e| CtxError::with_ctx(e.into(), ctx)) + } +} + +macro_rules! impl_from_hack { + ($($ty:ty),*) => { + $(impl From for CtxError<$ty> where E: Into<$ty> {fn from(e: E) -> Self { CtxError::new(e.into()) }})* + } +} + +/* + Contextual error impls +*/ + +impl_from_hack!(RuntimeErrorKind, SDSSErrorKind); + +#[derive(Debug)] +pub enum RuntimeErrorKind { + StorageSubsytem(SDSSError), + IoError(SysIOError), + OSSLErrorMulti(openssl::error::ErrorStack), + OSSLError(openssl::ssl::Error), +} + +direct_from! { + RuntimeErrorKind => { + SDSSError as StorageSubsytem, + std::io::Error as IoError, + SysIOError as IoError, + openssl::error::ErrorStack as OSSLErrorMulti, + openssl::ssl::Error as OSSLError, + } +} diff --git a/server/src/engine/fractal/mgr.rs b/server/src/engine/fractal/mgr.rs index 72f46d29..72ea4638 100644 --- a/server/src/engine/fractal/mgr.rs +++ b/server/src/engine/fractal/mgr.rs @@ -174,8 +174,8 @@ impl FractalMgr { hp_receiver: UnboundedReceiver>, ) -> FractalServiceHandles { let fractal_mgr = global.get_state().fractal_mgr(); - let global_1 = global.__global_clone(); - let global_2 = global.__global_clone(); + let global_1 = global.clone(); + let global_2 = global.clone(); let hp_handle = tokio::spawn(async move { FractalMgr::hp_executor_svc(fractal_mgr, global_1, hp_receiver).await }); diff --git a/server/src/engine/fractal/mod.rs b/server/src/engine/fractal/mod.rs index 46159de8..25fa5866 100644 --- a/server/src/engine/fractal/mod.rs +++ b/server/src/engine/fractal/mod.rs @@ -53,8 +53,6 @@ pub use { pub type ModelDrivers = HashMap>; -static mut GLOBAL: MaybeUninit = MaybeUninit::uninit(); - /* global state init */ @@ -89,10 +87,10 @@ pub unsafe fn enable_and_start_all( mgr::FractalMgr::new(hp_sender, lp_sender, model_cnt_on_boot), config, ); - GLOBAL = MaybeUninit::new(global_state); + *Global::__gref_raw() = MaybeUninit::new(global_state); let token = Global::new(); GlobalStateStart { - global: token.__global_clone(), + global: token.clone(), mgr_handles: mgr::FractalMgr::start_all(token, lp_recv, hp_recv), } } @@ -187,7 +185,7 @@ impl GlobalInstanceLike for Global { } } -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] /// A handle to the global state pub struct Global(()); @@ -195,18 +193,12 @@ impl Global { unsafe fn new() -> Self { Self(()) } - fn __global_clone(&self) -> Self { - unsafe { - // UNSAFE(@ohsayan): safe to call within this module - Self::new() - } - } fn get_state(&self) -> &'static GlobalState { - unsafe { GLOBAL.assume_init_ref() } + unsafe { self.__gref() } } /// Returns a handle to the [`GlobalNS`] fn _namespace(&self) -> &'static GlobalNS { - &unsafe { GLOBAL.assume_init_ref() }.gns + &unsafe { self.__gref() }.gns } /// Post an urgent task fn _post_high_priority_task(&self, task: Task) { @@ -227,6 +219,16 @@ impl Global { .get_rt_stat() .per_mdl_delta_max_size() } + unsafe fn __gref_raw() -> &'static mut MaybeUninit { + static mut G: MaybeUninit = MaybeUninit::uninit(); + &mut G + } + unsafe fn __gref(&self) -> &'static GlobalState { + Self::__gref_raw().assume_init_ref() + } + pub unsafe fn unload_all(self) { + core::ptr::drop_in_place(Self::__gref_raw().as_mut_ptr()) + } } /* diff --git a/server/src/engine/mod.rs b/server/src/engine/mod.rs index db600b35..a7f43885 100644 --- a/server/src/engine/mod.rs +++ b/server/src/engine/mod.rs @@ -43,3 +43,9 @@ mod txn; // test #[cfg(test)] mod tests; + +use error::RuntimeResult; + +pub fn load_all() -> RuntimeResult { + todo!() +} diff --git a/server/src/engine/net/mod.rs b/server/src/engine/net/mod.rs index 213135b9..c90db058 100644 --- a/server/src/engine/net/mod.rs +++ b/server/src/engine/net/mod.rs @@ -24,13 +24,247 @@ * */ -use tokio::io::{AsyncRead, AsyncWrite}; mod protocol; +use { + crate::engine::{ + error::{RuntimeError, RuntimeResult}, + fractal::Global, + }, + bytes::BytesMut, + openssl::{ + pkey::PKey, + ssl::Ssl, + ssl::{SslAcceptor, SslMethod}, + x509::X509, + }, + std::{cell::Cell, net::SocketAddr, pin::Pin, time::Duration}, + tokio::{ + io::{AsyncRead, AsyncWrite, BufWriter}, + net::{TcpListener, TcpStream}, + sync::{broadcast, mpsc, Semaphore}, + }, + tokio_openssl::SslStream, +}; + pub trait Socket: AsyncWrite + AsyncRead + Unpin {} pub type IoResult = Result; +const BUF_WRITE_CAP: usize = 16384; +const BUF_READ_CAP: usize = 16384; +const CLIMIT: usize = 50000; + +static CLIM: Semaphore = Semaphore::const_new(CLIMIT); + +/* + socket definitions +*/ + +impl Socket for TcpStream {} +impl Socket for SslStream {} + pub enum QLoopReturn { Fin, ConnectionRst, } + +struct NetBackoff { + at: Cell, +} + +impl NetBackoff { + const BACKOFF_MAX: u8 = 64; + fn new() -> Self { + Self { at: Cell::new(1) } + } + async fn spin(&self) { + let current = self.at.get(); + self.at.set(current << 1); + tokio::time::sleep(Duration::from_secs(current as _)).await + } + fn should_disconnect(&self) -> bool { + self.at.get() >= Self::BACKOFF_MAX + } +} + +/* + listener +*/ + +/// Connection handler for a remote connection +pub struct ConnectionHandler { + socket: BufWriter, + buffer: BytesMut, + global: Global, + sig_terminate: broadcast::Receiver<()>, + _sig_inflight_complete: mpsc::Sender<()>, +} + +impl ConnectionHandler { + pub fn new( + socket: S, + global: Global, + term_sig: broadcast::Receiver<()>, + _inflight_complete: mpsc::Sender<()>, + ) -> Self { + Self { + socket: BufWriter::with_capacity(BUF_WRITE_CAP, socket), + buffer: BytesMut::with_capacity(BUF_READ_CAP), + global, + sig_terminate: term_sig, + _sig_inflight_complete: _inflight_complete, + } + } + pub async fn run(&mut self) -> IoResult<()> { + let Self { + socket, + buffer, + global, + .. + } = self; + loop { + tokio::select! { + _ = protocol::query_loop(socket, buffer, global) => {}, + _ = self.sig_terminate.recv() => { + return Ok(()) + } + } + } + } +} + +/// A TCP listener bound to a socket +pub struct Listener { + global: Global, + listener: TcpListener, + sig_shutdown: broadcast::Sender<()>, + sig_inflight: mpsc::Sender<()>, + sig_inflight_wait: mpsc::Receiver<()>, +} + +impl Listener { + pub async fn new( + binaddr: &str, + global: Global, + sig_shutdown: broadcast::Sender<()>, + ) -> RuntimeResult { + let (sig_inflight, sig_inflight_wait) = mpsc::channel(1); + let listener = RuntimeError::result_ctx( + TcpListener::bind(binaddr).await, + format!("failed to bind to port `{binaddr}`"), + )?; + Ok(Self { + global, + listener, + sig_shutdown, + sig_inflight, + sig_inflight_wait, + }) + } + pub async fn terminate(self) { + let Self { + mut sig_inflight_wait, + sig_inflight, + sig_shutdown, + .. + } = self; + drop(sig_shutdown); + drop(sig_inflight); // could be that we are the only ones holding this lol + let _ = sig_inflight_wait.recv().await; // wait + } + async fn accept(&mut self) -> IoResult<(TcpStream, SocketAddr)> { + let backoff = NetBackoff::new(); + loop { + match self.listener.accept().await { + Ok(s) => return Ok(s), + Err(e) => { + if backoff.should_disconnect() { + // that's enough of your crappy connection dear sir + return Err(e.into()); + } + } + } + backoff.spin().await; + } + } + async fn listen_tcp(&mut self) -> IoResult<()> { + loop { + // acquire a permit + let permit = CLIM.acquire().await.unwrap(); + let (stream, _) = match self.accept().await { + Ok(s) => s, + Err(e) => { + /* + SECURITY: IGNORE THIS ERROR + */ + log::error!("failed to accept connection on TCP socket: `{e}`"); + continue; + } + }; + let mut handler = ConnectionHandler::new( + stream, + self.global, + self.sig_shutdown.subscribe(), + self.sig_inflight.clone(), + ); + tokio::spawn(async move { + if let Err(e) = handler.run().await { + log::error!("error handling client connection: `{e}`"); + } + }); + // return the permit + drop(permit); + } + } + async fn listen_tls( + self: &mut Self, + tls_cert: String, + tls_priv_key: String, + tls_key_password: String, + ) -> RuntimeResult<()> { + let build_acceptor = || { + let cert = X509::from_pem(tls_cert.as_bytes())?; + let priv_key = PKey::private_key_from_pem_passphrase( + tls_priv_key.as_bytes(), + tls_key_password.as_bytes(), + )?; + let mut builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls())?; + builder.set_certificate(&cert)?; + builder.set_private_key(&priv_key)?; + builder.check_private_key()?; + Ok::<_, openssl::error::ErrorStack>(builder.build()) + }; + let acceptor = + RuntimeError::result_ctx(build_acceptor(), "failed to initialize TLS socket")?; + loop { + let stream = async { + let (stream, _) = self.accept().await?; + let ssl = Ssl::new(acceptor.context())?; + let mut stream = SslStream::new(ssl, stream)?; + Pin::new(&mut stream).accept().await?; + RuntimeResult::Ok(stream) + }; + let stream = match stream.await { + Ok(s) => s, + Err(e) => { + /* + SECURITY: Once again, ignore this error + */ + log::error!("failed to accept connection on TLS socket: `{e:#?}`"); + continue; + } + }; + let mut handler = ConnectionHandler::new( + stream, + self.global, + self.sig_shutdown.subscribe(), + self.sig_inflight.clone(), + ); + tokio::spawn(async move { + if let Err(e) = handler.run().await { + log::error!("error handling client TLS connection: `{e}`"); + } + }); + } + } +} diff --git a/server/src/engine/net/protocol/mod.rs b/server/src/engine/net/protocol/mod.rs index 81bc3f2f..538ab007 100644 --- a/server/src/engine/net/protocol/mod.rs +++ b/server/src/engine/net/protocol/mod.rs @@ -32,7 +32,7 @@ mod tests; use { self::handshake::{CHandshake, HandshakeResult, HandshakeState}, super::{IoResult, QLoopReturn, Socket}, - crate::engine::mem::BufferedScanner, + crate::engine::{fractal::Global, mem::BufferedScanner}, bytes::{Buf, BytesMut}, tokio::io::{AsyncReadExt, BufWriter}, }; @@ -40,6 +40,7 @@ use { pub async fn query_loop( con: &mut BufWriter, buf: &mut BytesMut, + _global: &Global, ) -> IoResult { // handshake match do_handshake(con, buf).await? { @@ -56,6 +57,7 @@ pub async fn query_loop( } } +#[inline(always)] fn see_if_connection_terminates(read_many: usize, buf: &[u8]) -> Option { if read_many == 0 { // that's a connection termination diff --git a/server/src/engine/storage/v1/batch_jrnl/persist.rs b/server/src/engine/storage/v1/batch_jrnl/persist.rs index f47ee07f..06ba5ad8 100644 --- a/server/src/engine/storage/v1/batch_jrnl/persist.rs +++ b/server/src/engine/storage/v1/batch_jrnl/persist.rs @@ -46,7 +46,7 @@ use { storage::v1::{ inf::PersistTypeDscr, rw::{RawFSInterface, SDSSFileIO, SDSSFileTrackedWriter}, - SDSSError, SDSSResult, + SDSSErrorKind, SDSSResult, }, }, util::EndianQW, @@ -76,7 +76,7 @@ impl DataBatchPersistDriver { { return Ok(()); } else { - return Err(SDSSError::DataBatchCloseError); + return Err(SDSSErrorKind::DataBatchCloseError.into()); } } pub fn write_new_batch(&mut self, model: &Model, observed_len: usize) -> SDSSResult<()> { @@ -154,7 +154,7 @@ impl DataBatchPersistDriver { schema_version: DeltaVersion, pk_tag: TagUnique, col_cnt: usize, - ) -> Result<(), SDSSError> { + ) -> SDSSResult<()> { self.f .unfsynced_write(&[MARKER_ACTUAL_BATCH_EVENT, pk_tag.value_u8()])?; let observed_len_bytes = observed_len.u64_bytes_le(); @@ -169,7 +169,7 @@ impl DataBatchPersistDriver { &mut self, observed_len: usize, inconsistent_reads: usize, - ) -> Result<(), SDSSError> { + ) -> SDSSResult<()> { // [0xFD][actual_commit][checksum] self.f.unfsynced_write(&[MARKER_END_OF_BATCH])?; let actual_commit = (observed_len - inconsistent_reads).u64_bytes_le(); @@ -189,7 +189,7 @@ impl DataBatchPersistDriver { if f.fsynced_write(&[MARKER_RECOVERY_EVENT]).is_ok() { return Ok(()); } - Err(SDSSError::DataBatchRecoveryFailStageOne) + Err(SDSSErrorKind::DataBatchRecoveryFailStageOne.into()) } } @@ -288,7 +288,7 @@ impl DataBatchPersistDriver { Ok(()) } /// Write the change type and txnid - fn write_batch_item_common_row_data(&mut self, delta: &DataDelta) -> Result<(), SDSSError> { + fn write_batch_item_common_row_data(&mut self, delta: &DataDelta) -> SDSSResult<()> { let change_type = [delta.change().value_u8()]; self.f.unfsynced_write(&change_type)?; let txn_id = delta.data_version().value_u64().to_le_bytes(); diff --git a/server/src/engine/storage/v1/batch_jrnl/restore.rs b/server/src/engine/storage/v1/batch_jrnl/restore.rs index 3475cd0b..d6bafbf6 100644 --- a/server/src/engine/storage/v1/batch_jrnl/restore.rs +++ b/server/src/engine/storage/v1/batch_jrnl/restore.rs @@ -42,7 +42,7 @@ use { storage::v1::{ inf::PersistTypeDscr, rw::{RawFSInterface, SDSSFileIO, SDSSFileTrackedReader}, - SDSSError, SDSSResult, + SDSSErrorKind, SDSSResult, }, }, crossbeam_epoch::pin, @@ -187,7 +187,7 @@ impl DataBatchRestoreDriver { } } // nope, this is a corrupted file - Err(SDSSError::DataBatchRestoreCorruptedBatchFile) + Err(SDSSErrorKind::DataBatchRestoreCorruptedBatchFile.into()) } fn handle_reopen_is_actual_close(&mut self) -> SDSSResult { if self.f.is_eof() { @@ -200,7 +200,7 @@ impl DataBatchRestoreDriver { Ok(false) } else { // that's just a nice bug - Err(SDSSError::DataBatchRestoreCorruptedBatchFile) + Err(SDSSErrorKind::DataBatchRestoreCorruptedBatchFile.into()) } } } @@ -303,7 +303,7 @@ impl DataBatchRestoreDriver { // we must read the batch termination signature let b = self.f.read_byte()?; if b != MARKER_END_OF_BATCH { - return Err(SDSSError::DataBatchRestoreCorruptedBatch); + return Err(SDSSErrorKind::DataBatchRestoreCorruptedBatch.into()); } } // read actual commit @@ -320,7 +320,7 @@ impl DataBatchRestoreDriver { if actual_checksum == u64::from_le_bytes(hardcoded_checksum) { Ok(actual_commit) } else { - Err(SDSSError::DataBatchRestoreCorruptedBatch) + Err(SDSSErrorKind::DataBatchRestoreCorruptedBatch.into()) } } fn read_batch(&mut self) -> SDSSResult { @@ -340,7 +340,7 @@ impl DataBatchRestoreDriver { } _ => { // this is the only singular byte that is expected to be intact. If this isn't intact either, I'm sorry - return Err(SDSSError::DataBatchRestoreCorruptedBatch); + return Err(SDSSErrorKind::DataBatchRestoreCorruptedBatch.into()); } } // decode batch start block @@ -384,7 +384,7 @@ impl DataBatchRestoreDriver { this_col_cnt -= 1; } if this_col_cnt != 0 { - return Err(SDSSError::DataBatchRestoreCorruptedEntry); + return Err(SDSSErrorKind::DataBatchRestoreCorruptedEntry.into()); } if change_type == 1 { this_batch.push(DecodedBatchEvent::new( @@ -402,7 +402,7 @@ impl DataBatchRestoreDriver { processed_in_this_batch += 1; } _ => { - return Err(SDSSError::DataBatchRestoreCorruptedBatch); + return Err(SDSSErrorKind::DataBatchRestoreCorruptedBatch.into()); } } } @@ -417,7 +417,7 @@ impl DataBatchRestoreDriver { if let Ok(MARKER_RECOVERY_EVENT) = self.f.inner_file().read_byte() { return Ok(()); } - Err(SDSSError::DataBatchRestoreCorruptedBatch) + Err(SDSSErrorKind::DataBatchRestoreCorruptedBatch.into()) } fn read_start_batch_block(&mut self) -> SDSSResult { let pk_tag = self.f.read_byte()?; @@ -467,7 +467,7 @@ impl BatchStartBlock { impl DataBatchRestoreDriver { fn decode_primary_key(&mut self, pk_type: u8) -> SDSSResult { let Some(pk_type) = TagUnique::try_from_raw(pk_type) else { - return Err(SDSSError::DataBatchRestoreCorruptedEntry); + return Err(SDSSErrorKind::DataBatchRestoreCorruptedEntry.into()); }; Ok(match pk_type { TagUnique::SignedInt | TagUnique::UnsignedInt => { @@ -483,7 +483,7 @@ impl DataBatchRestoreDriver { self.f.read_into_buffer(&mut data)?; if pk_type == TagUnique::Str { if core::str::from_utf8(&data).is_err() { - return Err(SDSSError::DataBatchRestoreCorruptedEntry); + return Err(SDSSErrorKind::DataBatchRestoreCorruptedEntry.into()); } } unsafe { @@ -501,14 +501,14 @@ impl DataBatchRestoreDriver { fn decode_cell(&mut self) -> SDSSResult { let cell_type_sig = self.f.read_byte()?; let Some(cell_type) = PersistTypeDscr::try_from_raw(cell_type_sig) else { - return Err(SDSSError::DataBatchRestoreCorruptedEntry); + return Err(SDSSErrorKind::DataBatchRestoreCorruptedEntry.into()); }; Ok(match cell_type { PersistTypeDscr::Null => Datacell::null(), PersistTypeDscr::Bool => { let bool = self.f.read_byte()?; if bool > 1 { - return Err(SDSSError::DataBatchRestoreCorruptedEntry); + return Err(SDSSErrorKind::DataBatchRestoreCorruptedEntry.into()); } Datacell::new_bool(bool == 1) } @@ -528,7 +528,7 @@ impl DataBatchRestoreDriver { // UNSAFE(@ohsayan): +tagck if cell_type == PersistTypeDscr::Str { if core::str::from_utf8(&data).is_err() { - return Err(SDSSError::DataBatchRestoreCorruptedEntry); + return Err(SDSSErrorKind::DataBatchRestoreCorruptedEntry.into()); } Datacell::new_str(String::from_utf8_unchecked(data).into_boxed_str()) } else { @@ -543,13 +543,13 @@ impl DataBatchRestoreDriver { list.push(self.decode_cell()?); } if len != list.len() as u64 { - return Err(SDSSError::DataBatchRestoreCorruptedEntry); + return Err(SDSSErrorKind::DataBatchRestoreCorruptedEntry.into()); } Datacell::new_list(list) } PersistTypeDscr::Dict => { // we don't support dicts just yet - return Err(SDSSError::DataBatchRestoreCorruptedEntry); + return Err(SDSSErrorKind::DataBatchRestoreCorruptedEntry.into()); } }) } diff --git a/server/src/engine/storage/v1/inf/map.rs b/server/src/engine/storage/v1/inf/map.rs index 380d5ca2..57a2699d 100644 --- a/server/src/engine/storage/v1/inf/map.rs +++ b/server/src/engine/storage/v1/inf/map.rs @@ -40,7 +40,7 @@ use { }, idx::{IndexBaseSpec, IndexSTSeqCns, STIndex, STIndexSeq}, mem::BufferedScanner, - storage::v1::{inf, SDSSError, SDSSResult}, + storage::v1::{inf, SDSSError, SDSSErrorKind, SDSSResult}, }, util::{copy_slice_to_array as memcpy, EndianQW}, }, @@ -92,11 +92,12 @@ where while M::pretest_entry_metadata(scanner) & (dict.st_len() != dict_size) { let md = unsafe { // UNSAFE(@ohsayan): +pretest - M::entry_md_dec(scanner) - .ok_or(SDSSError::InternalDecodeStructureCorruptedPayload)? + M::entry_md_dec(scanner).ok_or::( + SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into(), + )? }; if !M::pretest_entry_data(scanner, &md) { - return Err(SDSSError::InternalDecodeStructureCorruptedPayload); + return Err(SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into()); } let key; let val; @@ -107,7 +108,11 @@ where key = _k; val = _v; } - None => return Err(SDSSError::InternalDecodeStructureCorruptedPayload), + None => { + return Err( + SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into() + ) + } } } else { let _k = M::dec_key(scanner, &md); @@ -117,18 +122,22 @@ where key = _k; val = _v; } - _ => return Err(SDSSError::InternalDecodeStructureCorruptedPayload), + _ => { + return Err( + SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into() + ) + } } } } if !dict.st_insert(key, val) { - return Err(SDSSError::InternalDecodeStructureIllegalData); + return Err(SDSSErrorKind::InternalDecodeStructureIllegalData.into()); } } if dict.st_len() == dict_size { Ok(dict) } else { - Err(SDSSError::InternalDecodeStructureIllegalData) + Err(SDSSErrorKind::InternalDecodeStructureIllegalData.into()) } } } diff --git a/server/src/engine/storage/v1/inf/mod.rs b/server/src/engine/storage/v1/inf/mod.rs index afa17003..75db030e 100644 --- a/server/src/engine/storage/v1/inf/mod.rs +++ b/server/src/engine/storage/v1/inf/mod.rs @@ -42,7 +42,7 @@ use { }, idx::{AsKey, AsValue}, mem::BufferedScanner, - storage::v1::{SDSSError, SDSSResult}, + storage::v1::{SDSSErrorKind, SDSSResult}, }, std::mem, }; @@ -157,14 +157,14 @@ pub trait PersistObject { /// Default routine to decode an object + its metadata (however, the metadata is used and not returned) fn default_full_dec(scanner: &mut BufferedScanner) -> SDSSResult { if !Self::pretest_can_dec_metadata(scanner) { - return Err(SDSSError::InternalDecodeStructureCorrupted); + return Err(SDSSErrorKind::InternalDecodeStructureCorrupted.into()); } let md = unsafe { // UNSAFE(@ohsayan): +pretest Self::meta_dec(scanner)? }; if !Self::pretest_can_dec_object(scanner, &md) { - return Err(SDSSError::InternalDecodeStructureCorruptedPayload); + return Err(SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into()); } unsafe { // UNSAFE(@ohsayan): +obj pretest @@ -290,11 +290,11 @@ pub mod dec { pub mod utils { use crate::engine::{ mem::BufferedScanner, - storage::v1::{SDSSError, SDSSResult}, + storage::v1::{SDSSErrorKind, SDSSResult}, }; pub unsafe fn decode_string(s: &mut BufferedScanner, len: usize) -> SDSSResult { String::from_utf8(s.next_chunk_variable(len).to_owned()) - .map_err(|_| SDSSError::InternalDecodeStructureCorruptedPayload) + .map_err(|_| SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into()) } } } diff --git a/server/src/engine/storage/v1/inf/obj.rs b/server/src/engine/storage/v1/inf/obj.rs index 01afe7fd..51a12e5e 100644 --- a/server/src/engine/storage/v1/inf/obj.rs +++ b/server/src/engine/storage/v1/inf/obj.rs @@ -39,7 +39,7 @@ use { DictGeneric, }, mem::{BufferedScanner, VInline}, - storage::v1::{inf, SDSSError, SDSSResult}, + storage::v1::{inf, SDSSErrorKind, SDSSResult}, }, util::EndianQW, }, @@ -119,7 +119,7 @@ impl<'a> PersistObject for LayerRef<'a> { fn obj_enc(_: &mut VecU8, _: Self::InputType) {} unsafe fn obj_dec(_: &mut BufferedScanner, md: Self::Metadata) -> SDSSResult { if (md.type_selector > TagSelector::List.value_qword()) | (md.prop_set_arity != 0) { - return Err(SDSSError::InternalDecodeStructureCorruptedPayload); + return Err(SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into()); } Ok(Layer::new_empty_props( TagSelector::from_raw(md.type_selector as u8).into_full(), @@ -202,7 +202,7 @@ impl<'a> PersistObject for FieldRef<'a> { if (field.layers().len() as u64 == md.layer_c) & (md.null <= 1) & (md.prop_c == 0) & fin { Ok(field) } else { - Err(SDSSError::InternalDecodeStructureCorrupted) + Err(SDSSErrorKind::InternalDecodeStructureCorrupted.into()) } } } @@ -281,7 +281,7 @@ impl<'a> PersistObject for ModelLayoutRef<'a> { super::map::MapIndexSizeMD(md.field_c as usize), )?; let ptag = if md.p_key_tag > TagSelector::MAX as u64 { - return Err(SDSSError::InternalDecodeStructureCorruptedPayload); + return Err(SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into()); } else { TagSelector::from_raw(md.p_key_tag as u8) }; diff --git a/server/src/engine/storage/v1/journal.rs b/server/src/engine/storage/v1/journal.rs index 4dbaaadb..6cdfcebf 100644 --- a/server/src/engine/storage/v1/journal.rs +++ b/server/src/engine/storage/v1/journal.rs @@ -44,7 +44,7 @@ use { super::{ rw::{FileOpen, RawFSInterface, SDSSFileIO}, - spec, SDSSError, SDSSResult, + spec, SDSSErrorKind, SDSSResult, }, crate::util::{compiler, copy_a_into_b, copy_slice_to_array as memcpy}, std::marker::PhantomData, @@ -222,7 +222,7 @@ impl JournalReader { } match entry_metadata .event_source_marker() - .ok_or(SDSSError::JournalLogEntryCorrupted)? + .ok_or(SDSSErrorKind::JournalLogEntryCorrupted)? { EventSourceMarker::ServerStandard => {} EventSourceMarker::DriverClosed => { @@ -237,7 +237,7 @@ impl JournalReader { EventSourceMarker::DriverReopened | EventSourceMarker::RecoveryReverseLastJournal => { // these two are only taken in close and error paths (respectively) so we shouldn't see them here; this is bad // two special directives in the middle of nowhere? incredible - return Err(SDSSError::JournalCorrupted); + return Err(SDSSErrorKind::JournalCorrupted.into()); } } // read payload @@ -270,10 +270,10 @@ impl JournalReader { Ok(()) } else { // FIXME(@ohsayan): tolerate loss in this directive too - Err(SDSSError::JournalCorrupted) + Err(SDSSErrorKind::JournalCorrupted.into()) } } else { - Err(SDSSError::JournalCorrupted) + Err(SDSSErrorKind::JournalCorrupted.into()) } } #[cold] // FIXME(@ohsayan): how bad can prod systems be? (clue: pretty bad, so look for possible changes) @@ -286,7 +286,7 @@ impl JournalReader { self.__record_read_bytes(JournalEntryMetadata::SIZE); // FIXME(@ohsayan): don't assume read length? let mut entry_buf = [0u8; JournalEntryMetadata::SIZE]; if self.log_file.read_to_buffer(&mut entry_buf).is_err() { - return Err(SDSSError::JournalCorrupted); + return Err(SDSSErrorKind::JournalCorrupted.into()); } let entry = JournalEntryMetadata::decode(entry_buf); let okay = (entry.event_id == self.evid as u128) @@ -297,7 +297,7 @@ impl JournalReader { if okay { return Ok(()); } else { - Err(SDSSError::JournalCorrupted) + Err(SDSSErrorKind::JournalCorrupted.into()) } } /// Read and apply all events in the given log file to the global state, returning the (open file, last event ID) @@ -309,7 +309,7 @@ impl JournalReader { if slf.closed { Ok((slf.log_file, slf.evid)) } else { - Err(SDSSError::JournalCorrupted) + Err(SDSSErrorKind::JournalCorrupted.into()) } } } @@ -397,7 +397,7 @@ impl JournalWriter { if self.log_file.fsynced_write(&entry.encoded()).is_ok() { return Ok(()); } - Err(SDSSError::JournalWRecoveryStageOneFailCritical) + Err(SDSSErrorKind::JournalWRecoveryStageOneFailCritical.into()) } pub fn append_journal_reopen(&mut self) -> SDSSResult<()> { let id = self._incr_id() as u128; diff --git a/server/src/engine/storage/v1/loader.rs b/server/src/engine/storage/v1/loader.rs index acf7f248..5cebafdb 100644 --- a/server/src/engine/storage/v1/loader.rs +++ b/server/src/engine/storage/v1/loader.rs @@ -32,7 +32,7 @@ use crate::engine::{ batch_jrnl, journal::{self, JournalWriter}, rw::{FileOpen, RawFSInterface}, - spec, LocalFS, SDSSErrorContext, SDSSResult, + spec, LocalFS, SDSSResult, }, txn::gns::{GNSAdapter, GNSTransactionDriverAnyFS}, }; @@ -73,14 +73,11 @@ impl SEInitState { for (model_name, model) in space.models().read().iter() { let path = Self::model_path(space_name, space_uuid, model_name, model.get_uuid()); - let persist_driver = match batch_jrnl::reinit(&path, model) { - Ok(j) => j, - Err(e) => { - return Err(e.with_extra(format!( - "failed to restore model data from journal in `{path}`" - ))) - } - }; + let persist_driver = batch_jrnl::reinit(&path, model).map_err(|e| { + e.add_ctx(format!( + "failed to restore model data from journal in `{path}`" + )) + })?; let _ = model_drivers.insert( ModelUniqueID::new(space_name, model_name, model.get_uuid()), FractalModelDriver::init(persist_driver), diff --git a/server/src/engine/storage/v1/mod.rs b/server/src/engine/storage/v1/mod.rs index 9f11804d..6132a515 100644 --- a/server/src/engine/storage/v1/mod.rs +++ b/server/src/engine/storage/v1/mod.rs @@ -33,7 +33,6 @@ pub mod spec; mod sysdb; // hl pub mod inf; -mod start_stop; // test pub mod memfs; #[cfg(test)] @@ -49,49 +48,25 @@ pub mod data_batch { pub use super::batch_jrnl::{create, reinit, DataBatchPersistDriver, DataBatchRestoreDriver}; } -use crate::{engine::txn::TransactionError, util::os::SysIOError as IoError}; - -pub type SDSSResult = Result; - -pub trait SDSSErrorContext { - type ExtraData; - fn with_extra(self, extra: Self::ExtraData) -> SDSSError; -} - -impl SDSSErrorContext for IoError { - type ExtraData = &'static str; - fn with_extra(self, extra: Self::ExtraData) -> SDSSError { - SDSSError::IoErrorExtra(self, extra) - } -} - -impl SDSSErrorContext for std::io::Error { - type ExtraData = &'static str; - fn with_extra(self, extra: Self::ExtraData) -> SDSSError { - SDSSError::IoErrorExtra(self.into(), extra) - } -} +use crate::{ + engine::{ + error::{CtxError, CtxResult}, + txn::TransactionError, + }, + util::os::SysIOError as IoError, +}; -impl SDSSErrorContext for SDSSError { - type ExtraData = String; - - fn with_extra(self, extra: Self::ExtraData) -> SDSSError { - SDSSError::Extra(Box::new(self), extra) - } -} +pub type SDSSResult = CtxResult; +pub type SDSSError = CtxError; #[derive(Debug)] #[cfg_attr(test, derive(PartialEq))] -pub enum SDSSError { +pub enum SDSSErrorKind { // IO errors /// An IO err IoError(IoError), - /// An IO err with extra ctx - IoErrorExtra(IoError, &'static str), - /// A corrupted file - CorruptedFile(&'static str), - // process errors OtherError(&'static str), + CorruptedFile(&'static str), // header /// version mismatch HeaderDecodeVersionMismatch, @@ -127,41 +102,18 @@ pub enum SDSSError { DataBatchCloseError, DataBatchRestoreCorruptedBatchFile, JournalRestoreTxnError, - /// An error with more context - // TODO(@ohsayan): avoid the box; we'll clean this up soon - Extra(Box, String), SysDBCorrupted, } -impl From for SDSSError { +impl From for SDSSErrorKind { fn from(_: TransactionError) -> Self { Self::JournalRestoreTxnError } } -impl SDSSError { - pub const fn corrupted_file(fname: &'static str) -> Self { - Self::CorruptedFile(fname) - } - pub const fn ioerror_extra(error: IoError, extra: &'static str) -> Self { - Self::IoErrorExtra(error, extra) - } - pub fn with_ioerror_extra(self, extra: &'static str) -> Self { - match self { - Self::IoError(ioe) => Self::IoErrorExtra(ioe, extra), - x => x, - } - } -} - -impl From for SDSSError { - fn from(e: IoError) -> Self { - Self::IoError(e) - } -} - -impl From for SDSSError { - fn from(e: std::io::Error) -> Self { - Self::IoError(e.into()) +direct_from! { + SDSSErrorKind => { + std::io::Error as IoError, + IoError as IoError, } } diff --git a/server/src/engine/storage/v1/rw.rs b/server/src/engine/storage/v1/rw.rs index be4233d8..266b5371 100644 --- a/server/src/engine/storage/v1/rw.rs +++ b/server/src/engine/storage/v1/rw.rs @@ -29,10 +29,7 @@ use { spec::{FileSpec, Header}, SDSSResult, }, - crate::{ - engine::storage::{v1::SDSSError, SCrc}, - util::os::SysIOError, - }, + crate::{engine::storage::SCrc, util::os::SysIOError}, std::{ fs::{self, File}, io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write}, @@ -348,9 +345,7 @@ impl SDSSFileTrackedReader { Err(e) => return Err(e), } } else { - Err(SDSSError::IoError(SysIOError::from( - std::io::ErrorKind::InvalidInput, - ))) + Err(SysIOError::from(std::io::ErrorKind::InvalidInput).into()) } } pub fn read_byte(&mut self) -> SDSSResult { @@ -373,9 +368,7 @@ impl SDSSFileTrackedReader { } pub fn read_block(&mut self) -> SDSSResult<[u8; N]> { if !self.has_left(N as _) { - return Err(SDSSError::IoError(SysIOError::from( - std::io::ErrorKind::InvalidInput, - ))); + return Err(SysIOError::from(std::io::ErrorKind::InvalidInput).into()); } let mut buf = [0; N]; self.read_into_buffer(&mut buf)?; diff --git a/server/src/engine/storage/v1/spec.rs b/server/src/engine/storage/v1/spec.rs index 6a448f3b..c470badc 100644 --- a/server/src/engine/storage/v1/spec.rs +++ b/server/src/engine/storage/v1/spec.rs @@ -40,7 +40,7 @@ use { crate::{ engine::storage::{ header::{HostArch, HostEndian, HostOS, HostPointerWidth}, - v1::SDSSError, + v1::SDSSErrorKind, versions::{self, DriverVersion, HeaderVersion, ServerVersion}, }, util::os, @@ -375,12 +375,12 @@ impl SDSSStaticHeaderV1Compact { } else { let version_okay = okay_header_version & okay_server_version & okay_driver_version; let md = ManuallyDrop::new([ - SDSSError::HeaderDecodeCorruptedHeader, - SDSSError::HeaderDecodeVersionMismatch, + SDSSErrorKind::HeaderDecodeCorruptedHeader, + SDSSErrorKind::HeaderDecodeVersionMismatch, ]); Err(unsafe { // UNSAFE(@ohsayan): while not needed, md for drop safety + correct index - md.as_ptr().add(!version_okay as usize).read() + md.as_ptr().add(!version_okay as usize).read().into() }) } } @@ -511,7 +511,7 @@ impl Header for SDSSStaticHeaderV1Compact { { Ok(()) } else { - Err(SDSSError::HeaderDecodeDataMismatch) + Err(SDSSErrorKind::HeaderDecodeDataMismatch.into()) } } } diff --git a/server/src/engine/storage/v1/start_stop.rs b/server/src/engine/storage/v1/start_stop.rs deleted file mode 100644 index 11121523..00000000 --- a/server/src/engine/storage/v1/start_stop.rs +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Created on Mon May 29 2023 - * - * This file is a part of Skytable - * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source - * NoSQL database written by Sayan Nandan ("the Author") with the - * vision to provide flexibility in data modelling without compromising - * on performance, queryability or scalability. - * - * Copyright (c) 2023, Sayan Nandan - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - * -*/ - -use { - super::{SDSSError, SDSSErrorContext, SDSSResult}, - crate::util::os, - std::{ - fs::File, - io::{ErrorKind, Read, Write}, - }, -}; - -#[cfg(not(test))] -const START_FILE: &'static str = ".start"; -#[cfg(test)] -const START_FILE: &'static str = ".start_testmode"; -#[cfg(not(test))] -const STOP_FILE: &'static str = ".stop"; -#[cfg(test)] -const STOP_FILE: &'static str = ".stop_testmode"; - -const EMSG_FAILED_WRITE_START_FILE: &str = - concat_str_to_str!("failed to write to `", START_FILE, "` file"); -const EMSG_FAILED_WRITE_STOP_FILE: &str = - concat_str_to_str!("failed to write to `", STOP_FILE, "` file"); -const EMSG_FAILED_OPEN_START_FILE: &str = - concat_str_to_str!("failed to open `", START_FILE, "` file"); -const EMSG_FAILED_OPEN_STOP_FILE: &str = - concat_str_to_str!("failed to open `", STOP_FILE, "` file"); -const EMSG_FAILED_VERIFY: &str = concat_str_to_str!( - "failed to verify `", - START_FILE, - concat_str_to_str!("` and `", STOP_FILE, "` timestamps") -); - -#[derive(Debug)] -pub struct StartStop { - begin: u128, - stop_file: File, -} - -#[derive(Debug)] -enum ReadNX { - Created(File), - Read(File, u128), -} - -impl ReadNX { - const fn created(&self) -> bool { - matches!(self, Self::Created(_)) - } - fn file_mut(&mut self) -> &mut File { - match self { - Self::Created(ref mut f) => f, - Self::Read(ref mut f, _) => f, - } - } - fn into_file(self) -> File { - match self { - Self::Created(f) => f, - Self::Read(f, _) => f, - } - } -} - -impl StartStop { - fn read_time_file(f: &str, create_new_if_nx: bool) -> SDSSResult { - let mut f = match File::options().write(true).read(true).open(f) { - Ok(f) => f, - Err(e) if e.kind() == ErrorKind::NotFound && create_new_if_nx => { - let f = File::create(f)?; - return Ok(ReadNX::Created(f)); - } - Err(e) => return Err(e.into()), - }; - let len = f.metadata().map(|m| m.len())?; - if len != sizeof!(u128) as u64 { - return Err(SDSSError::corrupted_file(START_FILE)); - } - let mut buf = [0u8; sizeof!(u128)]; - f.read_exact(&mut buf)?; - Ok(ReadNX::Read(f, u128::from_le_bytes(buf))) - } - pub fn terminate(mut self) -> SDSSResult<()> { - self.stop_file - .write_all(self.begin.to_le_bytes().as_ref()) - .map_err(|e| e.with_extra(EMSG_FAILED_WRITE_STOP_FILE)) - } - pub fn verify_and_start() -> SDSSResult { - // read start file - let mut start_file = Self::read_time_file(START_FILE, true) - .map_err(|e| e.with_ioerror_extra(EMSG_FAILED_OPEN_START_FILE))?; - // read stop file - let stop_file = Self::read_time_file(STOP_FILE, start_file.created()) - .map_err(|e| e.with_ioerror_extra(EMSG_FAILED_OPEN_STOP_FILE))?; - // read current time - let ctime = os::get_epoch_time(); - match (&start_file, &stop_file) { - (ReadNX::Read(_, time_start), ReadNX::Read(_, time_stop)) - if time_start == time_stop => {} - (ReadNX::Created(_), ReadNX::Created(_)) => {} - _ => return Err(SDSSError::OtherError(EMSG_FAILED_VERIFY)), - } - start_file - .file_mut() - .write_all(&ctime.to_le_bytes()) - .map_err(|e| e.with_extra(EMSG_FAILED_WRITE_START_FILE))?; - Ok(Self { - stop_file: stop_file.into_file(), - begin: ctime, - }) - } -} - -#[test] -fn verify_test() { - let x = || -> SDSSResult<()> { - let ss = StartStop::verify_and_start()?; - ss.terminate()?; - let ss = StartStop::verify_and_start()?; - ss.terminate()?; - std::fs::remove_file(START_FILE)?; - std::fs::remove_file(STOP_FILE)?; - Ok(()) - }; - x().unwrap(); -} diff --git a/server/src/engine/storage/v1/sysdb.rs b/server/src/engine/storage/v1/sysdb.rs index 7bb70c55..b76e0218 100644 --- a/server/src/engine/storage/v1/sysdb.rs +++ b/server/src/engine/storage/v1/sysdb.rs @@ -25,7 +25,7 @@ */ use { - super::{rw::FileOpen, SDSSError}, + super::{rw::FileOpen, SDSSErrorKind}, crate::engine::{ config::ConfigAuth, data::{cell::Datacell, DictEntryGeneric, DictGeneric}, @@ -175,7 +175,7 @@ fn rkey( ) -> SDSSResult { match d.remove(key).map(transform) { Some(Some(k)) => Ok(k), - _ => Err(SDSSError::SysDBCorrupted), + _ => Err(SDSSErrorKind::SysDBCorrupted.into()), } } @@ -201,14 +201,14 @@ pub fn decode_system_database(mut f: SDSSFileIO) -> SDSS let mut userdata = userdata .into_data() .and_then(Datacell::into_list) - .ok_or(SDSSError::SysDBCorrupted)?; + .ok_or(SDSSErrorKind::SysDBCorrupted)?; if userdata.len() != 1 { - return Err(SDSSError::SysDBCorrupted); + return Err(SDSSErrorKind::SysDBCorrupted.into()); } let user_password = userdata .remove(0) .into_bin() - .ok_or(SDSSError::SysDBCorrupted)?; + .ok_or(SDSSErrorKind::SysDBCorrupted)?; loaded_users.insert(username, SysAuthUser::new(user_password.into_boxed_slice())); } let sys_auth = SysAuth::new(root_key.into_boxed_slice(), loaded_users); @@ -220,7 +220,7 @@ pub fn decode_system_database(mut f: SDSSFileIO) -> SDSS d.into_data()?.into_uint() })?; if !(sysdb_data.is_empty() & auth_store.is_empty() & sys_store.is_empty()) { - return Err(SDSSError::SysDBCorrupted); + return Err(SDSSErrorKind::SysDBCorrupted.into()); } Ok(SysConfig::new( RwLock::new(sys_auth), diff --git a/server/src/engine/storage/v1/tests/tx.rs b/server/src/engine/storage/v1/tests/tx.rs index 45140c4d..5e39b737 100644 --- a/server/src/engine/storage/v1/tests/tx.rs +++ b/server/src/engine/storage/v1/tests/tx.rs @@ -28,7 +28,7 @@ use { crate::{ engine::storage::v1::{ journal::{self, JournalAdapter, JournalWriter}, - spec, SDSSError, SDSSResult, + spec, SDSSError, SDSSErrorKind, SDSSResult, }, util, }, @@ -115,7 +115,9 @@ impl JournalAdapter for DatabaseTxnAdapter { fn decode_and_update_state(payload: &[u8], gs: &Self::GlobalState) -> Result<(), TxError> { if payload.len() != 10 { - return Err(SDSSError::CorruptedFile("testtxn.log").into()); + return Err(TxError::SDSS( + SDSSErrorKind::CorruptedFile("testtxn.log").into(), + )); } let opcode = payload[0]; let index = u64::from_le_bytes(util::copy_slice_to_array(&payload[1..9])); @@ -123,7 +125,11 @@ impl JournalAdapter for DatabaseTxnAdapter { match opcode { 0 if index == 0 && new_value == 0 => gs.reset(), 1 if index < 10 && index < isize::MAX as u64 => gs.set(index as usize, new_value), - _ => return Err(SDSSError::JournalLogEntryCorrupted.into()), + _ => { + return Err(TxError::SDSS( + SDSSErrorKind::JournalLogEntryCorrupted.into(), + )) + } } Ok(()) } diff --git a/server/src/engine/tests/mod.rs b/server/src/engine/tests/mod.rs index 092acf3a..1ac07cb3 100644 --- a/server/src/engine/tests/mod.rs +++ b/server/src/engine/tests/mod.rs @@ -90,8 +90,12 @@ mod cfg { #[test] fn parse_validate_cli_args() { with_files( - ["__cli_args_test_private.key", "__cli_args_test_cert.pem"], - |[pkey, cert]| { + [ + "__cli_args_test_private.key", + "__cli_args_test_cert.pem", + "__cli_args_test_passphrase.key", + ], + |[pkey, cert, pass]| { let payload = format!( "skyd --mode=dev \ --endpoint tcp@127.0.0.1:2003 \ @@ -99,6 +103,7 @@ mod cfg { --service-window=600 \ --tlskey {pkey} \ --tlscert {cert} \ + --tls-passphrase {pass} \ --auth-plugin pwd \ --auth-root-password password12345678 " @@ -115,6 +120,7 @@ mod cfg { ConfigEndpointTls::new( ConfigEndpointTcp::new("127.0.0.2".into(), 2004), "".into(), + "".into(), "".into() ) ), @@ -205,13 +211,18 @@ mod cfg { #[test] fn parse_validate_env_args() { with_files( - ["__env_args_test_cert.pem", "__env_args_test_private.key"], - |[cert, key]| { + [ + "__env_args_test_cert.pem", + "__env_args_test_private.key", + "__env_args_test_private.passphrase.txt", + ], + |[cert, key, pass]| { let variables = [ format!("SKYDB_AUTH_PLUGIN=pwd"), format!("SKYDB_AUTH_ROOT_PASSWORD=password12345678"), format!("SKYDB_TLS_CERT={cert}"), format!("SKYDB_TLS_KEY={key}"), + format!("SKYDB_TLS_PRIVATE_KEY_PASSWORD={pass}"), format!("SKYDB_ENDPOINTS=tcp@localhost:8080,tls@localhost:8081"), format!("SKYDB_RUN_MODE=dev"), format!("SKYDB_SERVICE_WINDOW=600"), @@ -226,6 +237,7 @@ mod cfg { ConfigEndpointTls::new( ConfigEndpointTcp::new("localhost".into(), 8081), "".into(), + "".into(), "".into() ) ), @@ -252,6 +264,7 @@ endpoints: port: 2004 cert: ._test_sample_cert.pem private_key: ._test_sample_private.key + pkey_passphrase: ._test_sample_private.pass.txt insecure: host: 127.0.0.1 port: 2003 @@ -259,7 +272,11 @@ endpoints: #[test] fn test_config_file() { with_files( - ["._test_sample_cert.pem", "._test_sample_private.key"], + [ + "._test_sample_cert.pem", + "._test_sample_private.key", + "._test_sample_private.pass.txt", + ], |_| { config::set_cli_src(vec!["skyd".into(), "--config=config.yml".into()]); config::set_file_src(CONFIG_FILE); @@ -272,6 +289,7 @@ endpoints: ConfigEndpointTls::new( ConfigEndpointTcp::new("127.0.0.1".into(), 2004), "".into(), + "".into(), "".into() ) ), diff --git a/server/src/engine/txn/gns/model.rs b/server/src/engine/txn/gns/model.rs index a31b28ac..f8097f15 100644 --- a/server/src/engine/txn/gns/model.rs +++ b/server/src/engine/txn/gns/model.rs @@ -39,7 +39,7 @@ use { ql::lex::Ident, storage::v1::{ inf::{self, map, obj, PersistObject}, - SDSSError, SDSSResult, + SDSSErrorKind, SDSSResult, }, txn::TransactionError, }, @@ -498,7 +498,7 @@ impl<'a> PersistObject for AlterModelRemoveTxn<'a> { removed_fields.push(inf::dec::utils::decode_string(s, len)?.into_boxed_str()); } if removed_fields.len() as u64 != md.remove_field_c { - return Err(SDSSError::InternalDecodeStructureCorruptedPayload); + return Err(SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into()); } Ok(AlterModelRemoveTxnRestorePL { model_id,