Add net impls

Also cleaned up error impls
next
Sayan Nandan 12 months ago
parent 5ba82a6cf0
commit be540a7ded
No known key found for this signature in database
GPG Key ID: 42EEDF4AE9D96B54

@ -147,14 +147,21 @@ pub struct ConfigEndpointTls {
tcp: ConfigEndpointTcp,
cert: String,
private_key: String,
pkey_pass: String,
}
impl ConfigEndpointTls {
pub fn new(tcp: ConfigEndpointTcp, cert: String, private_key: String) -> Self {
pub fn new(
tcp: ConfigEndpointTcp,
cert: String,
private_key: String,
pkey_pass: String,
) -> Self {
Self {
tcp,
cert,
private_key,
pkey_pass,
}
}
}
@ -264,6 +271,7 @@ pub struct DecodedEPSecureConfig {
port: u16,
cert: String,
private_key: String,
pkey_passphrase: String,
}
#[derive(Debug, PartialEq, Deserialize)]
@ -383,6 +391,7 @@ pub(super) trait ConfigurationSource {
const KEY_AUTH_ROOT_PASSWORD: &'static str;
const KEY_TLS_CERT: &'static str;
const KEY_TLS_KEY: &'static str;
const KEY_TLS_PKEY_PASS: &'static str;
const KEY_ENDPOINTS: &'static str;
const KEY_RUN_MODE: &'static str;
const KEY_SERVICE_WINDOW: &'static str;
@ -463,16 +472,19 @@ fn parse_endpoint(source: ConfigSource, s: &str) -> ConfigResult<(ConnectionProt
fn decode_tls_ep(
cert_path: &str,
key_path: &str,
pkey_pass: &str,
host: &str,
port: u16,
) -> ConfigResult<DecodedEPSecureConfig> {
let tls_key = fs::read_to_string(key_path)?;
let tls_cert = fs::read_to_string(cert_path)?;
let tls_priv_key_passphrase = fs::read_to_string(pkey_pass)?;
Ok(DecodedEPSecureConfig {
host: host.into(),
port,
cert: tls_cert,
private_key: tls_key,
pkey_passphrase: tls_priv_key_passphrase,
})
}
@ -484,22 +496,31 @@ fn arg_decode_tls_endpoint<CS: ConfigurationSource>(
) -> ConfigResult<DecodedEPSecureConfig> {
let _cert = args.remove(CS::KEY_TLS_CERT);
let _key = args.remove(CS::KEY_TLS_KEY);
let (tls_cert, tls_key) = match (_cert, _key) {
(Some(cert), Some(key)) => (cert, key),
let _passphrase = args.remove(CS::KEY_TLS_PKEY_PASS);
let (tls_cert, tls_key, tls_passphrase) = match (_cert, _key, _passphrase) {
(Some(cert), Some(key), Some(pass)) => (cert, key, pass),
_ => {
return Err(ConfigError::with_src(
ConfigSource::Cli,
ConfigErrorKind::ErrorString(format!(
"must supply values for both `{}` and `{}` when using TLS",
"must supply values for `{}`, `{}` and `{}` when using TLS",
CS::KEY_TLS_CERT,
CS::KEY_TLS_KEY
CS::KEY_TLS_KEY,
CS::KEY_TLS_PKEY_PASS,
)),
));
}
};
argck_duplicate_values::<CS>(&tls_cert, CS::KEY_TLS_CERT)?;
argck_duplicate_values::<CS>(&tls_key, CS::KEY_TLS_KEY)?;
Ok(decode_tls_ep(&tls_cert[0], &tls_key[0], host, port)?)
argck_duplicate_values::<CS>(&tls_passphrase, CS::KEY_TLS_PKEY_PASS)?;
Ok(decode_tls_ep(
&tls_cert[0],
&tls_key[0],
&tls_passphrase[0],
host,
port,
)?)
}
/*
@ -747,7 +768,7 @@ pub fn parse_cli_args<'a, T: 'a + AsRef<str>>(
/// Parse environment variables
pub fn parse_env_args() -> ConfigResult<Option<ParsedRawArgs>> {
const KEYS: [&str; 7] = [
const KEYS: [&str; 8] = [
CSEnvArgs::KEY_AUTH_DRIVER,
CSEnvArgs::KEY_AUTH_ROOT_PASSWORD,
CSEnvArgs::KEY_ENDPOINTS,
@ -755,6 +776,7 @@ pub fn parse_env_args() -> ConfigResult<Option<ParsedRawArgs>> {
CSEnvArgs::KEY_SERVICE_WINDOW,
CSEnvArgs::KEY_TLS_CERT,
CSEnvArgs::KEY_TLS_KEY,
CSEnvArgs::KEY_TLS_PKEY_PASS,
];
let mut ret = HashMap::new();
for key in KEYS {
@ -853,6 +875,7 @@ impl ConfigurationSource for CSCommandLine {
const KEY_AUTH_ROOT_PASSWORD: &'static str = "--auth-root-password";
const KEY_TLS_CERT: &'static str = "--tlscert";
const KEY_TLS_KEY: &'static str = "--tlskey";
const KEY_TLS_PKEY_PASS: &'static str = "--tls-passphrase";
const KEY_ENDPOINTS: &'static str = "--endpoint";
const KEY_RUN_MODE: &'static str = "--mode";
const KEY_SERVICE_WINDOW: &'static str = "--service-window";
@ -865,6 +888,7 @@ impl ConfigurationSource for CSEnvArgs {
const KEY_AUTH_ROOT_PASSWORD: &'static str = "SKYDB_AUTH_ROOT_PASSWORD";
const KEY_TLS_CERT: &'static str = "SKYDB_TLS_CERT";
const KEY_TLS_KEY: &'static str = "SKYDB_TLS_KEY";
const KEY_TLS_PKEY_PASS: &'static str = "SKYDB_TLS_PRIVATE_KEY_PASSWORD";
const KEY_ENDPOINTS: &'static str = "SKYDB_ENDPOINTS";
const KEY_RUN_MODE: &'static str = "SKYDB_RUN_MODE";
const KEY_SERVICE_WINDOW: &'static str = "SKYDB_SERVICE_WINDOW";
@ -877,6 +901,7 @@ impl ConfigurationSource for CSConfigFile {
const KEY_AUTH_ROOT_PASSWORD: &'static str = "auth.root_password";
const KEY_TLS_CERT: &'static str = "endpoints.secure.cert";
const KEY_TLS_KEY: &'static str = "endpoints.secure.key";
const KEY_TLS_PKEY_PASS: &'static str = "endpoints.secure.pkey_passphrase";
const KEY_ENDPOINTS: &'static str = "endpoints";
const KEY_RUN_MODE: &'static str = "system.mode";
const KEY_SERVICE_WINDOW: &'static str = "system.service_window";
@ -937,10 +962,11 @@ fn validate_configuration<CS: ConfigurationSource>(
let secure_ep = ConfigEndpointTls {
tcp: ConfigEndpointTcp {
host: secure.host,
port: secure.port
port: secure.port,
},
cert: secure.cert,
private_key: secure.private_key
private_key: secure.private_key,
pkey_pass: secure.pkey_passphrase,
};
match &config.endpoints {
ConfigEndpoint::Insecure(is) => if has_insecure {
@ -961,7 +987,7 @@ fn validate_configuration<CS: ConfigurationSource>(
CS::SOURCE,
ConfigErrorKind::ErrorString("invalid value for service window. must be nonzero".into()),
),
if config.auth.root_key.len() <= ROOT_PASSWORD_MIN_LEN => ConfigError::with_src(
if config.auth.root_key.len() < ROOT_PASSWORD_MIN_LEN => ConfigError::with_src(
CS::SOURCE,
ConfigErrorKind::ErrorString("the root password must have at least 16 characters".into()),
),
@ -1175,8 +1201,10 @@ fn check_config_file(
Some(secure_ep) => {
let cert = fs::read_to_string(&secure_ep.cert)?;
let private_key = fs::read_to_string(&secure_ep.private_key)?;
let private_key_passphrase = fs::read_to_string(&secure_ep.pkey_passphrase)?;
secure_ep.cert = cert;
secure_ep.private_key = private_key;
secure_ep.pkey_passphrase = private_key_passphrase;
}
None => {}
},

@ -24,8 +24,20 @@
*
*/
use super::{storage::v1::SDSSError, txn::TransactionError};
use {
super::{
storage::v1::{SDSSError, SDSSErrorKind},
txn::TransactionError,
},
crate::util::os::SysIOError,
std::fmt,
};
pub type QueryResult<T> = Result<T, Error>;
// stack
pub type CtxResult<T, E> = Result<T, CtxError<E>>;
pub type RuntimeResult<T> = CtxResult<T, RuntimeErrorKind>;
pub type RuntimeError = CtxError<RuntimeErrorKind>;
/// an enumeration of 'flat' errors that the server actually responds to the client with, since we do not want to send specific information
/// about anything (as that will be a security hole). The variants correspond with their actual response codes
@ -102,3 +114,119 @@ direct_from! {
TransactionError as TransactionalError,
}
}
/*
contextual errors
*/
/// An error context
pub enum CtxErrorDescription {
A(&'static str),
B(Box<str>),
}
impl CtxErrorDescription {
fn inner(&self) -> &str {
match self {
Self::A(a) => a,
Self::B(b) => &b,
}
}
}
impl PartialEq for CtxErrorDescription {
fn eq(&self, other: &Self) -> bool {
self.inner() == other.inner()
}
}
impl fmt::Display for CtxErrorDescription {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.inner())
}
}
impl fmt::Debug for CtxErrorDescription {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(self.inner())
}
}
direct_from! {
CtxErrorDescription => {
&'static str as A,
String as B,
Box<str> as B,
}
}
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
/// A contextual error
pub struct CtxError<E> {
kind: E,
ctx: Option<CtxErrorDescription>,
}
impl<E> CtxError<E> {
fn _new(kind: E, ctx: Option<CtxErrorDescription>) -> Self {
Self { kind, ctx }
}
pub fn new(kind: E) -> Self {
Self::_new(kind, None)
}
pub fn with_ctx(kind: E, ctx: impl Into<CtxErrorDescription>) -> Self {
Self::_new(kind, Some(ctx.into()))
}
pub fn add_ctx(self, ctx: impl Into<CtxErrorDescription>) -> Self {
Self::with_ctx(self.kind, ctx)
}
pub fn into_result<T>(self) -> CtxResult<T, E> {
Err(self)
}
pub fn result<T, F>(result: Result<T, F>) -> CtxResult<T, E>
where
E: From<F>,
{
result.map_err(|e| CtxError::new(e.into()))
}
pub fn result_ctx<T, F>(
result: Result<T, F>,
ctx: impl Into<CtxErrorDescription>,
) -> CtxResult<T, E>
where
E: From<F>,
{
result.map_err(|e| CtxError::with_ctx(e.into(), ctx))
}
}
macro_rules! impl_from_hack {
($($ty:ty),*) => {
$(impl<E> From<E> for CtxError<$ty> where E: Into<$ty> {fn from(e: E) -> Self { CtxError::new(e.into()) }})*
}
}
/*
Contextual error impls
*/
impl_from_hack!(RuntimeErrorKind, SDSSErrorKind);
#[derive(Debug)]
pub enum RuntimeErrorKind {
StorageSubsytem(SDSSError),
IoError(SysIOError),
OSSLErrorMulti(openssl::error::ErrorStack),
OSSLError(openssl::ssl::Error),
}
direct_from! {
RuntimeErrorKind => {
SDSSError as StorageSubsytem,
std::io::Error as IoError,
SysIOError as IoError,
openssl::error::ErrorStack as OSSLErrorMulti,
openssl::ssl::Error as OSSLError,
}
}

@ -174,8 +174,8 @@ impl FractalMgr {
hp_receiver: UnboundedReceiver<Task<CriticalTask>>,
) -> FractalServiceHandles {
let fractal_mgr = global.get_state().fractal_mgr();
let global_1 = global.__global_clone();
let global_2 = global.__global_clone();
let global_1 = global.clone();
let global_2 = global.clone();
let hp_handle = tokio::spawn(async move {
FractalMgr::hp_executor_svc(fractal_mgr, global_1, hp_receiver).await
});

@ -53,8 +53,6 @@ pub use {
pub type ModelDrivers<Fs> = HashMap<ModelUniqueID, drivers::FractalModelDriver<Fs>>;
static mut GLOBAL: MaybeUninit<GlobalState> = MaybeUninit::uninit();
/*
global state init
*/
@ -89,10 +87,10 @@ pub unsafe fn enable_and_start_all(
mgr::FractalMgr::new(hp_sender, lp_sender, model_cnt_on_boot),
config,
);
GLOBAL = MaybeUninit::new(global_state);
*Global::__gref_raw() = MaybeUninit::new(global_state);
let token = Global::new();
GlobalStateStart {
global: token.__global_clone(),
global: token.clone(),
mgr_handles: mgr::FractalMgr::start_all(token, lp_recv, hp_recv),
}
}
@ -187,7 +185,7 @@ impl GlobalInstanceLike for Global {
}
}
#[derive(Debug)]
#[derive(Debug, Clone, Copy)]
/// A handle to the global state
pub struct Global(());
@ -195,18 +193,12 @@ impl Global {
unsafe fn new() -> Self {
Self(())
}
fn __global_clone(&self) -> Self {
unsafe {
// UNSAFE(@ohsayan): safe to call within this module
Self::new()
}
}
fn get_state(&self) -> &'static GlobalState {
unsafe { GLOBAL.assume_init_ref() }
unsafe { self.__gref() }
}
/// Returns a handle to the [`GlobalNS`]
fn _namespace(&self) -> &'static GlobalNS {
&unsafe { GLOBAL.assume_init_ref() }.gns
&unsafe { self.__gref() }.gns
}
/// Post an urgent task
fn _post_high_priority_task(&self, task: Task<CriticalTask>) {
@ -227,6 +219,16 @@ impl Global {
.get_rt_stat()
.per_mdl_delta_max_size()
}
unsafe fn __gref_raw() -> &'static mut MaybeUninit<GlobalState> {
static mut G: MaybeUninit<GlobalState> = MaybeUninit::uninit();
&mut G
}
unsafe fn __gref(&self) -> &'static GlobalState {
Self::__gref_raw().assume_init_ref()
}
pub unsafe fn unload_all(self) {
core::ptr::drop_in_place(Self::__gref_raw().as_mut_ptr())
}
}
/*

@ -43,3 +43,9 @@ mod txn;
// test
#[cfg(test)]
mod tests;
use error::RuntimeResult;
pub fn load_all() -> RuntimeResult<fractal::Global> {
todo!()
}

@ -24,13 +24,247 @@
*
*/
use tokio::io::{AsyncRead, AsyncWrite};
mod protocol;
use {
crate::engine::{
error::{RuntimeError, RuntimeResult},
fractal::Global,
},
bytes::BytesMut,
openssl::{
pkey::PKey,
ssl::Ssl,
ssl::{SslAcceptor, SslMethod},
x509::X509,
},
std::{cell::Cell, net::SocketAddr, pin::Pin, time::Duration},
tokio::{
io::{AsyncRead, AsyncWrite, BufWriter},
net::{TcpListener, TcpStream},
sync::{broadcast, mpsc, Semaphore},
},
tokio_openssl::SslStream,
};
pub trait Socket: AsyncWrite + AsyncRead + Unpin {}
pub type IoResult<T> = Result<T, std::io::Error>;
const BUF_WRITE_CAP: usize = 16384;
const BUF_READ_CAP: usize = 16384;
const CLIMIT: usize = 50000;
static CLIM: Semaphore = Semaphore::const_new(CLIMIT);
/*
socket definitions
*/
impl Socket for TcpStream {}
impl Socket for SslStream<TcpStream> {}
pub enum QLoopReturn {
Fin,
ConnectionRst,
}
struct NetBackoff {
at: Cell<u8>,
}
impl NetBackoff {
const BACKOFF_MAX: u8 = 64;
fn new() -> Self {
Self { at: Cell::new(1) }
}
async fn spin(&self) {
let current = self.at.get();
self.at.set(current << 1);
tokio::time::sleep(Duration::from_secs(current as _)).await
}
fn should_disconnect(&self) -> bool {
self.at.get() >= Self::BACKOFF_MAX
}
}
/*
listener
*/
/// Connection handler for a remote connection
pub struct ConnectionHandler<S> {
socket: BufWriter<S>,
buffer: BytesMut,
global: Global,
sig_terminate: broadcast::Receiver<()>,
_sig_inflight_complete: mpsc::Sender<()>,
}
impl<S: Socket> ConnectionHandler<S> {
pub fn new(
socket: S,
global: Global,
term_sig: broadcast::Receiver<()>,
_inflight_complete: mpsc::Sender<()>,
) -> Self {
Self {
socket: BufWriter::with_capacity(BUF_WRITE_CAP, socket),
buffer: BytesMut::with_capacity(BUF_READ_CAP),
global,
sig_terminate: term_sig,
_sig_inflight_complete: _inflight_complete,
}
}
pub async fn run(&mut self) -> IoResult<()> {
let Self {
socket,
buffer,
global,
..
} = self;
loop {
tokio::select! {
_ = protocol::query_loop(socket, buffer, global) => {},
_ = self.sig_terminate.recv() => {
return Ok(())
}
}
}
}
}
/// A TCP listener bound to a socket
pub struct Listener {
global: Global,
listener: TcpListener,
sig_shutdown: broadcast::Sender<()>,
sig_inflight: mpsc::Sender<()>,
sig_inflight_wait: mpsc::Receiver<()>,
}
impl Listener {
pub async fn new(
binaddr: &str,
global: Global,
sig_shutdown: broadcast::Sender<()>,
) -> RuntimeResult<Self> {
let (sig_inflight, sig_inflight_wait) = mpsc::channel(1);
let listener = RuntimeError::result_ctx(
TcpListener::bind(binaddr).await,
format!("failed to bind to port `{binaddr}`"),
)?;
Ok(Self {
global,
listener,
sig_shutdown,
sig_inflight,
sig_inflight_wait,
})
}
pub async fn terminate(self) {
let Self {
mut sig_inflight_wait,
sig_inflight,
sig_shutdown,
..
} = self;
drop(sig_shutdown);
drop(sig_inflight); // could be that we are the only ones holding this lol
let _ = sig_inflight_wait.recv().await; // wait
}
async fn accept(&mut self) -> IoResult<(TcpStream, SocketAddr)> {
let backoff = NetBackoff::new();
loop {
match self.listener.accept().await {
Ok(s) => return Ok(s),
Err(e) => {
if backoff.should_disconnect() {
// that's enough of your crappy connection dear sir
return Err(e.into());
}
}
}
backoff.spin().await;
}
}
async fn listen_tcp(&mut self) -> IoResult<()> {
loop {
// acquire a permit
let permit = CLIM.acquire().await.unwrap();
let (stream, _) = match self.accept().await {
Ok(s) => s,
Err(e) => {
/*
SECURITY: IGNORE THIS ERROR
*/
log::error!("failed to accept connection on TCP socket: `{e}`");
continue;
}
};
let mut handler = ConnectionHandler::new(
stream,
self.global,
self.sig_shutdown.subscribe(),
self.sig_inflight.clone(),
);
tokio::spawn(async move {
if let Err(e) = handler.run().await {
log::error!("error handling client connection: `{e}`");
}
});
// return the permit
drop(permit);
}
}
async fn listen_tls(
self: &mut Self,
tls_cert: String,
tls_priv_key: String,
tls_key_password: String,
) -> RuntimeResult<()> {
let build_acceptor = || {
let cert = X509::from_pem(tls_cert.as_bytes())?;
let priv_key = PKey::private_key_from_pem_passphrase(
tls_priv_key.as_bytes(),
tls_key_password.as_bytes(),
)?;
let mut builder = SslAcceptor::mozilla_intermediate_v5(SslMethod::tls())?;
builder.set_certificate(&cert)?;
builder.set_private_key(&priv_key)?;
builder.check_private_key()?;
Ok::<_, openssl::error::ErrorStack>(builder.build())
};
let acceptor =
RuntimeError::result_ctx(build_acceptor(), "failed to initialize TLS socket")?;
loop {
let stream = async {
let (stream, _) = self.accept().await?;
let ssl = Ssl::new(acceptor.context())?;
let mut stream = SslStream::new(ssl, stream)?;
Pin::new(&mut stream).accept().await?;
RuntimeResult::Ok(stream)
};
let stream = match stream.await {
Ok(s) => s,
Err(e) => {
/*
SECURITY: Once again, ignore this error
*/
log::error!("failed to accept connection on TLS socket: `{e:#?}`");
continue;
}
};
let mut handler = ConnectionHandler::new(
stream,
self.global,
self.sig_shutdown.subscribe(),
self.sig_inflight.clone(),
);
tokio::spawn(async move {
if let Err(e) = handler.run().await {
log::error!("error handling client TLS connection: `{e}`");
}
});
}
}
}

@ -32,7 +32,7 @@ mod tests;
use {
self::handshake::{CHandshake, HandshakeResult, HandshakeState},
super::{IoResult, QLoopReturn, Socket},
crate::engine::mem::BufferedScanner,
crate::engine::{fractal::Global, mem::BufferedScanner},
bytes::{Buf, BytesMut},
tokio::io::{AsyncReadExt, BufWriter},
};
@ -40,6 +40,7 @@ use {
pub async fn query_loop<S: Socket>(
con: &mut BufWriter<S>,
buf: &mut BytesMut,
_global: &Global,
) -> IoResult<QLoopReturn> {
// handshake
match do_handshake(con, buf).await? {
@ -56,6 +57,7 @@ pub async fn query_loop<S: Socket>(
}
}
#[inline(always)]
fn see_if_connection_terminates(read_many: usize, buf: &[u8]) -> Option<QLoopReturn> {
if read_many == 0 {
// that's a connection termination

@ -46,7 +46,7 @@ use {
storage::v1::{
inf::PersistTypeDscr,
rw::{RawFSInterface, SDSSFileIO, SDSSFileTrackedWriter},
SDSSError, SDSSResult,
SDSSErrorKind, SDSSResult,
},
},
util::EndianQW,
@ -76,7 +76,7 @@ impl<Fs: RawFSInterface> DataBatchPersistDriver<Fs> {
{
return Ok(());
} else {
return Err(SDSSError::DataBatchCloseError);
return Err(SDSSErrorKind::DataBatchCloseError.into());
}
}
pub fn write_new_batch(&mut self, model: &Model, observed_len: usize) -> SDSSResult<()> {
@ -154,7 +154,7 @@ impl<Fs: RawFSInterface> DataBatchPersistDriver<Fs> {
schema_version: DeltaVersion,
pk_tag: TagUnique,
col_cnt: usize,
) -> Result<(), SDSSError> {
) -> SDSSResult<()> {
self.f
.unfsynced_write(&[MARKER_ACTUAL_BATCH_EVENT, pk_tag.value_u8()])?;
let observed_len_bytes = observed_len.u64_bytes_le();
@ -169,7 +169,7 @@ impl<Fs: RawFSInterface> DataBatchPersistDriver<Fs> {
&mut self,
observed_len: usize,
inconsistent_reads: usize,
) -> Result<(), SDSSError> {
) -> SDSSResult<()> {
// [0xFD][actual_commit][checksum]
self.f.unfsynced_write(&[MARKER_END_OF_BATCH])?;
let actual_commit = (observed_len - inconsistent_reads).u64_bytes_le();
@ -189,7 +189,7 @@ impl<Fs: RawFSInterface> DataBatchPersistDriver<Fs> {
if f.fsynced_write(&[MARKER_RECOVERY_EVENT]).is_ok() {
return Ok(());
}
Err(SDSSError::DataBatchRecoveryFailStageOne)
Err(SDSSErrorKind::DataBatchRecoveryFailStageOne.into())
}
}
@ -288,7 +288,7 @@ impl<Fs: RawFSInterface> DataBatchPersistDriver<Fs> {
Ok(())
}
/// Write the change type and txnid
fn write_batch_item_common_row_data(&mut self, delta: &DataDelta) -> Result<(), SDSSError> {
fn write_batch_item_common_row_data(&mut self, delta: &DataDelta) -> SDSSResult<()> {
let change_type = [delta.change().value_u8()];
self.f.unfsynced_write(&change_type)?;
let txn_id = delta.data_version().value_u64().to_le_bytes();

@ -42,7 +42,7 @@ use {
storage::v1::{
inf::PersistTypeDscr,
rw::{RawFSInterface, SDSSFileIO, SDSSFileTrackedReader},
SDSSError, SDSSResult,
SDSSErrorKind, SDSSResult,
},
},
crossbeam_epoch::pin,
@ -187,7 +187,7 @@ impl<F: RawFSInterface> DataBatchRestoreDriver<F> {
}
}
// nope, this is a corrupted file
Err(SDSSError::DataBatchRestoreCorruptedBatchFile)
Err(SDSSErrorKind::DataBatchRestoreCorruptedBatchFile.into())
}
fn handle_reopen_is_actual_close(&mut self) -> SDSSResult<bool> {
if self.f.is_eof() {
@ -200,7 +200,7 @@ impl<F: RawFSInterface> DataBatchRestoreDriver<F> {
Ok(false)
} else {
// that's just a nice bug
Err(SDSSError::DataBatchRestoreCorruptedBatchFile)
Err(SDSSErrorKind::DataBatchRestoreCorruptedBatchFile.into())
}
}
}
@ -303,7 +303,7 @@ impl<F: RawFSInterface> DataBatchRestoreDriver<F> {
// we must read the batch termination signature
let b = self.f.read_byte()?;
if b != MARKER_END_OF_BATCH {
return Err(SDSSError::DataBatchRestoreCorruptedBatch);
return Err(SDSSErrorKind::DataBatchRestoreCorruptedBatch.into());
}
}
// read actual commit
@ -320,7 +320,7 @@ impl<F: RawFSInterface> DataBatchRestoreDriver<F> {
if actual_checksum == u64::from_le_bytes(hardcoded_checksum) {
Ok(actual_commit)
} else {
Err(SDSSError::DataBatchRestoreCorruptedBatch)
Err(SDSSErrorKind::DataBatchRestoreCorruptedBatch.into())
}
}
fn read_batch(&mut self) -> SDSSResult<Batch> {
@ -340,7 +340,7 @@ impl<F: RawFSInterface> DataBatchRestoreDriver<F> {
}
_ => {
// this is the only singular byte that is expected to be intact. If this isn't intact either, I'm sorry
return Err(SDSSError::DataBatchRestoreCorruptedBatch);
return Err(SDSSErrorKind::DataBatchRestoreCorruptedBatch.into());
}
}
// decode batch start block
@ -384,7 +384,7 @@ impl<F: RawFSInterface> DataBatchRestoreDriver<F> {
this_col_cnt -= 1;
}
if this_col_cnt != 0 {
return Err(SDSSError::DataBatchRestoreCorruptedEntry);
return Err(SDSSErrorKind::DataBatchRestoreCorruptedEntry.into());
}
if change_type == 1 {
this_batch.push(DecodedBatchEvent::new(
@ -402,7 +402,7 @@ impl<F: RawFSInterface> DataBatchRestoreDriver<F> {
processed_in_this_batch += 1;
}
_ => {
return Err(SDSSError::DataBatchRestoreCorruptedBatch);
return Err(SDSSErrorKind::DataBatchRestoreCorruptedBatch.into());
}
}
}
@ -417,7 +417,7 @@ impl<F: RawFSInterface> DataBatchRestoreDriver<F> {
if let Ok(MARKER_RECOVERY_EVENT) = self.f.inner_file().read_byte() {
return Ok(());
}
Err(SDSSError::DataBatchRestoreCorruptedBatch)
Err(SDSSErrorKind::DataBatchRestoreCorruptedBatch.into())
}
fn read_start_batch_block(&mut self) -> SDSSResult<BatchStartBlock> {
let pk_tag = self.f.read_byte()?;
@ -467,7 +467,7 @@ impl BatchStartBlock {
impl<F: RawFSInterface> DataBatchRestoreDriver<F> {
fn decode_primary_key(&mut self, pk_type: u8) -> SDSSResult<PrimaryIndexKey> {
let Some(pk_type) = TagUnique::try_from_raw(pk_type) else {
return Err(SDSSError::DataBatchRestoreCorruptedEntry);
return Err(SDSSErrorKind::DataBatchRestoreCorruptedEntry.into());
};
Ok(match pk_type {
TagUnique::SignedInt | TagUnique::UnsignedInt => {
@ -483,7 +483,7 @@ impl<F: RawFSInterface> DataBatchRestoreDriver<F> {
self.f.read_into_buffer(&mut data)?;
if pk_type == TagUnique::Str {
if core::str::from_utf8(&data).is_err() {
return Err(SDSSError::DataBatchRestoreCorruptedEntry);
return Err(SDSSErrorKind::DataBatchRestoreCorruptedEntry.into());
}
}
unsafe {
@ -501,14 +501,14 @@ impl<F: RawFSInterface> DataBatchRestoreDriver<F> {
fn decode_cell(&mut self) -> SDSSResult<Datacell> {
let cell_type_sig = self.f.read_byte()?;
let Some(cell_type) = PersistTypeDscr::try_from_raw(cell_type_sig) else {
return Err(SDSSError::DataBatchRestoreCorruptedEntry);
return Err(SDSSErrorKind::DataBatchRestoreCorruptedEntry.into());
};
Ok(match cell_type {
PersistTypeDscr::Null => Datacell::null(),
PersistTypeDscr::Bool => {
let bool = self.f.read_byte()?;
if bool > 1 {
return Err(SDSSError::DataBatchRestoreCorruptedEntry);
return Err(SDSSErrorKind::DataBatchRestoreCorruptedEntry.into());
}
Datacell::new_bool(bool == 1)
}
@ -528,7 +528,7 @@ impl<F: RawFSInterface> DataBatchRestoreDriver<F> {
// UNSAFE(@ohsayan): +tagck
if cell_type == PersistTypeDscr::Str {
if core::str::from_utf8(&data).is_err() {
return Err(SDSSError::DataBatchRestoreCorruptedEntry);
return Err(SDSSErrorKind::DataBatchRestoreCorruptedEntry.into());
}
Datacell::new_str(String::from_utf8_unchecked(data).into_boxed_str())
} else {
@ -543,13 +543,13 @@ impl<F: RawFSInterface> DataBatchRestoreDriver<F> {
list.push(self.decode_cell()?);
}
if len != list.len() as u64 {
return Err(SDSSError::DataBatchRestoreCorruptedEntry);
return Err(SDSSErrorKind::DataBatchRestoreCorruptedEntry.into());
}
Datacell::new_list(list)
}
PersistTypeDscr::Dict => {
// we don't support dicts just yet
return Err(SDSSError::DataBatchRestoreCorruptedEntry);
return Err(SDSSErrorKind::DataBatchRestoreCorruptedEntry.into());
}
})
}

@ -40,7 +40,7 @@ use {
},
idx::{IndexBaseSpec, IndexSTSeqCns, STIndex, STIndexSeq},
mem::BufferedScanner,
storage::v1::{inf, SDSSError, SDSSResult},
storage::v1::{inf, SDSSError, SDSSErrorKind, SDSSResult},
},
util::{copy_slice_to_array as memcpy, EndianQW},
},
@ -92,11 +92,12 @@ where
while M::pretest_entry_metadata(scanner) & (dict.st_len() != dict_size) {
let md = unsafe {
// UNSAFE(@ohsayan): +pretest
M::entry_md_dec(scanner)
.ok_or(SDSSError::InternalDecodeStructureCorruptedPayload)?
M::entry_md_dec(scanner).ok_or::<SDSSError>(
SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into(),
)?
};
if !M::pretest_entry_data(scanner, &md) {
return Err(SDSSError::InternalDecodeStructureCorruptedPayload);
return Err(SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into());
}
let key;
let val;
@ -107,7 +108,11 @@ where
key = _k;
val = _v;
}
None => return Err(SDSSError::InternalDecodeStructureCorruptedPayload),
None => {
return Err(
SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into()
)
}
}
} else {
let _k = M::dec_key(scanner, &md);
@ -117,18 +122,22 @@ where
key = _k;
val = _v;
}
_ => return Err(SDSSError::InternalDecodeStructureCorruptedPayload),
_ => {
return Err(
SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into()
)
}
}
}
}
if !dict.st_insert(key, val) {
return Err(SDSSError::InternalDecodeStructureIllegalData);
return Err(SDSSErrorKind::InternalDecodeStructureIllegalData.into());
}
}
if dict.st_len() == dict_size {
Ok(dict)
} else {
Err(SDSSError::InternalDecodeStructureIllegalData)
Err(SDSSErrorKind::InternalDecodeStructureIllegalData.into())
}
}
}

@ -42,7 +42,7 @@ use {
},
idx::{AsKey, AsValue},
mem::BufferedScanner,
storage::v1::{SDSSError, SDSSResult},
storage::v1::{SDSSErrorKind, SDSSResult},
},
std::mem,
};
@ -157,14 +157,14 @@ pub trait PersistObject {
/// Default routine to decode an object + its metadata (however, the metadata is used and not returned)
fn default_full_dec(scanner: &mut BufferedScanner) -> SDSSResult<Self::OutputType> {
if !Self::pretest_can_dec_metadata(scanner) {
return Err(SDSSError::InternalDecodeStructureCorrupted);
return Err(SDSSErrorKind::InternalDecodeStructureCorrupted.into());
}
let md = unsafe {
// UNSAFE(@ohsayan): +pretest
Self::meta_dec(scanner)?
};
if !Self::pretest_can_dec_object(scanner, &md) {
return Err(SDSSError::InternalDecodeStructureCorruptedPayload);
return Err(SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into());
}
unsafe {
// UNSAFE(@ohsayan): +obj pretest
@ -290,11 +290,11 @@ pub mod dec {
pub mod utils {
use crate::engine::{
mem::BufferedScanner,
storage::v1::{SDSSError, SDSSResult},
storage::v1::{SDSSErrorKind, SDSSResult},
};
pub unsafe fn decode_string(s: &mut BufferedScanner, len: usize) -> SDSSResult<String> {
String::from_utf8(s.next_chunk_variable(len).to_owned())
.map_err(|_| SDSSError::InternalDecodeStructureCorruptedPayload)
.map_err(|_| SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into())
}
}
}

@ -39,7 +39,7 @@ use {
DictGeneric,
},
mem::{BufferedScanner, VInline},
storage::v1::{inf, SDSSError, SDSSResult},
storage::v1::{inf, SDSSErrorKind, SDSSResult},
},
util::EndianQW,
},
@ -119,7 +119,7 @@ impl<'a> PersistObject for LayerRef<'a> {
fn obj_enc(_: &mut VecU8, _: Self::InputType) {}
unsafe fn obj_dec(_: &mut BufferedScanner, md: Self::Metadata) -> SDSSResult<Self::OutputType> {
if (md.type_selector > TagSelector::List.value_qword()) | (md.prop_set_arity != 0) {
return Err(SDSSError::InternalDecodeStructureCorruptedPayload);
return Err(SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into());
}
Ok(Layer::new_empty_props(
TagSelector::from_raw(md.type_selector as u8).into_full(),
@ -202,7 +202,7 @@ impl<'a> PersistObject for FieldRef<'a> {
if (field.layers().len() as u64 == md.layer_c) & (md.null <= 1) & (md.prop_c == 0) & fin {
Ok(field)
} else {
Err(SDSSError::InternalDecodeStructureCorrupted)
Err(SDSSErrorKind::InternalDecodeStructureCorrupted.into())
}
}
}
@ -281,7 +281,7 @@ impl<'a> PersistObject for ModelLayoutRef<'a> {
super::map::MapIndexSizeMD(md.field_c as usize),
)?;
let ptag = if md.p_key_tag > TagSelector::MAX as u64 {
return Err(SDSSError::InternalDecodeStructureCorruptedPayload);
return Err(SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into());
} else {
TagSelector::from_raw(md.p_key_tag as u8)
};

@ -44,7 +44,7 @@
use {
super::{
rw::{FileOpen, RawFSInterface, SDSSFileIO},
spec, SDSSError, SDSSResult,
spec, SDSSErrorKind, SDSSResult,
},
crate::util::{compiler, copy_a_into_b, copy_slice_to_array as memcpy},
std::marker::PhantomData,
@ -222,7 +222,7 @@ impl<TA: JournalAdapter, Fs: RawFSInterface> JournalReader<TA, Fs> {
}
match entry_metadata
.event_source_marker()
.ok_or(SDSSError::JournalLogEntryCorrupted)?
.ok_or(SDSSErrorKind::JournalLogEntryCorrupted)?
{
EventSourceMarker::ServerStandard => {}
EventSourceMarker::DriverClosed => {
@ -237,7 +237,7 @@ impl<TA: JournalAdapter, Fs: RawFSInterface> JournalReader<TA, Fs> {
EventSourceMarker::DriverReopened | EventSourceMarker::RecoveryReverseLastJournal => {
// these two are only taken in close and error paths (respectively) so we shouldn't see them here; this is bad
// two special directives in the middle of nowhere? incredible
return Err(SDSSError::JournalCorrupted);
return Err(SDSSErrorKind::JournalCorrupted.into());
}
}
// read payload
@ -270,10 +270,10 @@ impl<TA: JournalAdapter, Fs: RawFSInterface> JournalReader<TA, Fs> {
Ok(())
} else {
// FIXME(@ohsayan): tolerate loss in this directive too
Err(SDSSError::JournalCorrupted)
Err(SDSSErrorKind::JournalCorrupted.into())
}
} else {
Err(SDSSError::JournalCorrupted)
Err(SDSSErrorKind::JournalCorrupted.into())
}
}
#[cold] // FIXME(@ohsayan): how bad can prod systems be? (clue: pretty bad, so look for possible changes)
@ -286,7 +286,7 @@ impl<TA: JournalAdapter, Fs: RawFSInterface> JournalReader<TA, Fs> {
self.__record_read_bytes(JournalEntryMetadata::SIZE); // FIXME(@ohsayan): don't assume read length?
let mut entry_buf = [0u8; JournalEntryMetadata::SIZE];
if self.log_file.read_to_buffer(&mut entry_buf).is_err() {
return Err(SDSSError::JournalCorrupted);
return Err(SDSSErrorKind::JournalCorrupted.into());
}
let entry = JournalEntryMetadata::decode(entry_buf);
let okay = (entry.event_id == self.evid as u128)
@ -297,7 +297,7 @@ impl<TA: JournalAdapter, Fs: RawFSInterface> JournalReader<TA, Fs> {
if okay {
return Ok(());
} else {
Err(SDSSError::JournalCorrupted)
Err(SDSSErrorKind::JournalCorrupted.into())
}
}
/// Read and apply all events in the given log file to the global state, returning the (open file, last event ID)
@ -309,7 +309,7 @@ impl<TA: JournalAdapter, Fs: RawFSInterface> JournalReader<TA, Fs> {
if slf.closed {
Ok((slf.log_file, slf.evid))
} else {
Err(SDSSError::JournalCorrupted)
Err(SDSSErrorKind::JournalCorrupted.into())
}
}
}
@ -397,7 +397,7 @@ impl<Fs: RawFSInterface, TA> JournalWriter<Fs, TA> {
if self.log_file.fsynced_write(&entry.encoded()).is_ok() {
return Ok(());
}
Err(SDSSError::JournalWRecoveryStageOneFailCritical)
Err(SDSSErrorKind::JournalWRecoveryStageOneFailCritical.into())
}
pub fn append_journal_reopen(&mut self) -> SDSSResult<()> {
let id = self._incr_id() as u128;

@ -32,7 +32,7 @@ use crate::engine::{
batch_jrnl,
journal::{self, JournalWriter},
rw::{FileOpen, RawFSInterface},
spec, LocalFS, SDSSErrorContext, SDSSResult,
spec, LocalFS, SDSSResult,
},
txn::gns::{GNSAdapter, GNSTransactionDriverAnyFS},
};
@ -73,14 +73,11 @@ impl SEInitState {
for (model_name, model) in space.models().read().iter() {
let path =
Self::model_path(space_name, space_uuid, model_name, model.get_uuid());
let persist_driver = match batch_jrnl::reinit(&path, model) {
Ok(j) => j,
Err(e) => {
return Err(e.with_extra(format!(
"failed to restore model data from journal in `{path}`"
)))
}
};
let persist_driver = batch_jrnl::reinit(&path, model).map_err(|e| {
e.add_ctx(format!(
"failed to restore model data from journal in `{path}`"
))
})?;
let _ = model_drivers.insert(
ModelUniqueID::new(space_name, model_name, model.get_uuid()),
FractalModelDriver::init(persist_driver),

@ -33,7 +33,6 @@ pub mod spec;
mod sysdb;
// hl
pub mod inf;
mod start_stop;
// test
pub mod memfs;
#[cfg(test)]
@ -49,49 +48,25 @@ pub mod data_batch {
pub use super::batch_jrnl::{create, reinit, DataBatchPersistDriver, DataBatchRestoreDriver};
}
use crate::{engine::txn::TransactionError, util::os::SysIOError as IoError};
pub type SDSSResult<T> = Result<T, SDSSError>;
pub trait SDSSErrorContext {
type ExtraData;
fn with_extra(self, extra: Self::ExtraData) -> SDSSError;
}
impl SDSSErrorContext for IoError {
type ExtraData = &'static str;
fn with_extra(self, extra: Self::ExtraData) -> SDSSError {
SDSSError::IoErrorExtra(self, extra)
}
}
impl SDSSErrorContext for std::io::Error {
type ExtraData = &'static str;
fn with_extra(self, extra: Self::ExtraData) -> SDSSError {
SDSSError::IoErrorExtra(self.into(), extra)
}
}
use crate::{
engine::{
error::{CtxError, CtxResult},
txn::TransactionError,
},
util::os::SysIOError as IoError,
};
impl SDSSErrorContext for SDSSError {
type ExtraData = String;
fn with_extra(self, extra: Self::ExtraData) -> SDSSError {
SDSSError::Extra(Box::new(self), extra)
}
}
pub type SDSSResult<T> = CtxResult<T, SDSSErrorKind>;
pub type SDSSError = CtxError<SDSSErrorKind>;
#[derive(Debug)]
#[cfg_attr(test, derive(PartialEq))]
pub enum SDSSError {
pub enum SDSSErrorKind {
// IO errors
/// An IO err
IoError(IoError),
/// An IO err with extra ctx
IoErrorExtra(IoError, &'static str),
/// A corrupted file
CorruptedFile(&'static str),
// process errors
OtherError(&'static str),
CorruptedFile(&'static str),
// header
/// version mismatch
HeaderDecodeVersionMismatch,
@ -127,41 +102,18 @@ pub enum SDSSError {
DataBatchCloseError,
DataBatchRestoreCorruptedBatchFile,
JournalRestoreTxnError,
/// An error with more context
// TODO(@ohsayan): avoid the box; we'll clean this up soon
Extra(Box<Self>, String),
SysDBCorrupted,
}
impl From<TransactionError> for SDSSError {
impl From<TransactionError> for SDSSErrorKind {
fn from(_: TransactionError) -> Self {
Self::JournalRestoreTxnError
}
}
impl SDSSError {
pub const fn corrupted_file(fname: &'static str) -> Self {
Self::CorruptedFile(fname)
}
pub const fn ioerror_extra(error: IoError, extra: &'static str) -> Self {
Self::IoErrorExtra(error, extra)
}
pub fn with_ioerror_extra(self, extra: &'static str) -> Self {
match self {
Self::IoError(ioe) => Self::IoErrorExtra(ioe, extra),
x => x,
}
}
}
impl From<IoError> for SDSSError {
fn from(e: IoError) -> Self {
Self::IoError(e)
}
}
impl From<std::io::Error> for SDSSError {
fn from(e: std::io::Error) -> Self {
Self::IoError(e.into())
direct_from! {
SDSSErrorKind => {
std::io::Error as IoError,
IoError as IoError,
}
}

@ -29,10 +29,7 @@ use {
spec::{FileSpec, Header},
SDSSResult,
},
crate::{
engine::storage::{v1::SDSSError, SCrc},
util::os::SysIOError,
},
crate::{engine::storage::SCrc, util::os::SysIOError},
std::{
fs::{self, File},
io::{BufReader, BufWriter, Read, Seek, SeekFrom, Write},
@ -348,9 +345,7 @@ impl<Fs: RawFSInterface> SDSSFileTrackedReader<Fs> {
Err(e) => return Err(e),
}
} else {
Err(SDSSError::IoError(SysIOError::from(
std::io::ErrorKind::InvalidInput,
)))
Err(SysIOError::from(std::io::ErrorKind::InvalidInput).into())
}
}
pub fn read_byte(&mut self) -> SDSSResult<u8> {
@ -373,9 +368,7 @@ impl<Fs: RawFSInterface> SDSSFileTrackedReader<Fs> {
}
pub fn read_block<const N: usize>(&mut self) -> SDSSResult<[u8; N]> {
if !self.has_left(N as _) {
return Err(SDSSError::IoError(SysIOError::from(
std::io::ErrorKind::InvalidInput,
)));
return Err(SysIOError::from(std::io::ErrorKind::InvalidInput).into());
}
let mut buf = [0; N];
self.read_into_buffer(&mut buf)?;

@ -40,7 +40,7 @@ use {
crate::{
engine::storage::{
header::{HostArch, HostEndian, HostOS, HostPointerWidth},
v1::SDSSError,
v1::SDSSErrorKind,
versions::{self, DriverVersion, HeaderVersion, ServerVersion},
},
util::os,
@ -375,12 +375,12 @@ impl SDSSStaticHeaderV1Compact {
} else {
let version_okay = okay_header_version & okay_server_version & okay_driver_version;
let md = ManuallyDrop::new([
SDSSError::HeaderDecodeCorruptedHeader,
SDSSError::HeaderDecodeVersionMismatch,
SDSSErrorKind::HeaderDecodeCorruptedHeader,
SDSSErrorKind::HeaderDecodeVersionMismatch,
]);
Err(unsafe {
// UNSAFE(@ohsayan): while not needed, md for drop safety + correct index
md.as_ptr().add(!version_okay as usize).read()
md.as_ptr().add(!version_okay as usize).read().into()
})
}
}
@ -511,7 +511,7 @@ impl Header for SDSSStaticHeaderV1Compact {
{
Ok(())
} else {
Err(SDSSError::HeaderDecodeDataMismatch)
Err(SDSSErrorKind::HeaderDecodeDataMismatch.into())
}
}
}

@ -1,150 +0,0 @@
/*
* Created on Mon May 29 2023
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2023, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
use {
super::{SDSSError, SDSSErrorContext, SDSSResult},
crate::util::os,
std::{
fs::File,
io::{ErrorKind, Read, Write},
},
};
#[cfg(not(test))]
const START_FILE: &'static str = ".start";
#[cfg(test)]
const START_FILE: &'static str = ".start_testmode";
#[cfg(not(test))]
const STOP_FILE: &'static str = ".stop";
#[cfg(test)]
const STOP_FILE: &'static str = ".stop_testmode";
const EMSG_FAILED_WRITE_START_FILE: &str =
concat_str_to_str!("failed to write to `", START_FILE, "` file");
const EMSG_FAILED_WRITE_STOP_FILE: &str =
concat_str_to_str!("failed to write to `", STOP_FILE, "` file");
const EMSG_FAILED_OPEN_START_FILE: &str =
concat_str_to_str!("failed to open `", START_FILE, "` file");
const EMSG_FAILED_OPEN_STOP_FILE: &str =
concat_str_to_str!("failed to open `", STOP_FILE, "` file");
const EMSG_FAILED_VERIFY: &str = concat_str_to_str!(
"failed to verify `",
START_FILE,
concat_str_to_str!("` and `", STOP_FILE, "` timestamps")
);
#[derive(Debug)]
pub struct StartStop {
begin: u128,
stop_file: File,
}
#[derive(Debug)]
enum ReadNX {
Created(File),
Read(File, u128),
}
impl ReadNX {
const fn created(&self) -> bool {
matches!(self, Self::Created(_))
}
fn file_mut(&mut self) -> &mut File {
match self {
Self::Created(ref mut f) => f,
Self::Read(ref mut f, _) => f,
}
}
fn into_file(self) -> File {
match self {
Self::Created(f) => f,
Self::Read(f, _) => f,
}
}
}
impl StartStop {
fn read_time_file(f: &str, create_new_if_nx: bool) -> SDSSResult<ReadNX> {
let mut f = match File::options().write(true).read(true).open(f) {
Ok(f) => f,
Err(e) if e.kind() == ErrorKind::NotFound && create_new_if_nx => {
let f = File::create(f)?;
return Ok(ReadNX::Created(f));
}
Err(e) => return Err(e.into()),
};
let len = f.metadata().map(|m| m.len())?;
if len != sizeof!(u128) as u64 {
return Err(SDSSError::corrupted_file(START_FILE));
}
let mut buf = [0u8; sizeof!(u128)];
f.read_exact(&mut buf)?;
Ok(ReadNX::Read(f, u128::from_le_bytes(buf)))
}
pub fn terminate(mut self) -> SDSSResult<()> {
self.stop_file
.write_all(self.begin.to_le_bytes().as_ref())
.map_err(|e| e.with_extra(EMSG_FAILED_WRITE_STOP_FILE))
}
pub fn verify_and_start() -> SDSSResult<Self> {
// read start file
let mut start_file = Self::read_time_file(START_FILE, true)
.map_err(|e| e.with_ioerror_extra(EMSG_FAILED_OPEN_START_FILE))?;
// read stop file
let stop_file = Self::read_time_file(STOP_FILE, start_file.created())
.map_err(|e| e.with_ioerror_extra(EMSG_FAILED_OPEN_STOP_FILE))?;
// read current time
let ctime = os::get_epoch_time();
match (&start_file, &stop_file) {
(ReadNX::Read(_, time_start), ReadNX::Read(_, time_stop))
if time_start == time_stop => {}
(ReadNX::Created(_), ReadNX::Created(_)) => {}
_ => return Err(SDSSError::OtherError(EMSG_FAILED_VERIFY)),
}
start_file
.file_mut()
.write_all(&ctime.to_le_bytes())
.map_err(|e| e.with_extra(EMSG_FAILED_WRITE_START_FILE))?;
Ok(Self {
stop_file: stop_file.into_file(),
begin: ctime,
})
}
}
#[test]
fn verify_test() {
let x = || -> SDSSResult<()> {
let ss = StartStop::verify_and_start()?;
ss.terminate()?;
let ss = StartStop::verify_and_start()?;
ss.terminate()?;
std::fs::remove_file(START_FILE)?;
std::fs::remove_file(STOP_FILE)?;
Ok(())
};
x().unwrap();
}

@ -25,7 +25,7 @@
*/
use {
super::{rw::FileOpen, SDSSError},
super::{rw::FileOpen, SDSSErrorKind},
crate::engine::{
config::ConfigAuth,
data::{cell::Datacell, DictEntryGeneric, DictGeneric},
@ -175,7 +175,7 @@ fn rkey<T>(
) -> SDSSResult<T> {
match d.remove(key).map(transform) {
Some(Some(k)) => Ok(k),
_ => Err(SDSSError::SysDBCorrupted),
_ => Err(SDSSErrorKind::SysDBCorrupted.into()),
}
}
@ -201,14 +201,14 @@ pub fn decode_system_database<Fs: RawFSInterface>(mut f: SDSSFileIO<Fs>) -> SDSS
let mut userdata = userdata
.into_data()
.and_then(Datacell::into_list)
.ok_or(SDSSError::SysDBCorrupted)?;
.ok_or(SDSSErrorKind::SysDBCorrupted)?;
if userdata.len() != 1 {
return Err(SDSSError::SysDBCorrupted);
return Err(SDSSErrorKind::SysDBCorrupted.into());
}
let user_password = userdata
.remove(0)
.into_bin()
.ok_or(SDSSError::SysDBCorrupted)?;
.ok_or(SDSSErrorKind::SysDBCorrupted)?;
loaded_users.insert(username, SysAuthUser::new(user_password.into_boxed_slice()));
}
let sys_auth = SysAuth::new(root_key.into_boxed_slice(), loaded_users);
@ -220,7 +220,7 @@ pub fn decode_system_database<Fs: RawFSInterface>(mut f: SDSSFileIO<Fs>) -> SDSS
d.into_data()?.into_uint()
})?;
if !(sysdb_data.is_empty() & auth_store.is_empty() & sys_store.is_empty()) {
return Err(SDSSError::SysDBCorrupted);
return Err(SDSSErrorKind::SysDBCorrupted.into());
}
Ok(SysConfig::new(
RwLock::new(sys_auth),

@ -28,7 +28,7 @@ use {
crate::{
engine::storage::v1::{
journal::{self, JournalAdapter, JournalWriter},
spec, SDSSError, SDSSResult,
spec, SDSSError, SDSSErrorKind, SDSSResult,
},
util,
},
@ -115,7 +115,9 @@ impl JournalAdapter for DatabaseTxnAdapter {
fn decode_and_update_state(payload: &[u8], gs: &Self::GlobalState) -> Result<(), TxError> {
if payload.len() != 10 {
return Err(SDSSError::CorruptedFile("testtxn.log").into());
return Err(TxError::SDSS(
SDSSErrorKind::CorruptedFile("testtxn.log").into(),
));
}
let opcode = payload[0];
let index = u64::from_le_bytes(util::copy_slice_to_array(&payload[1..9]));
@ -123,7 +125,11 @@ impl JournalAdapter for DatabaseTxnAdapter {
match opcode {
0 if index == 0 && new_value == 0 => gs.reset(),
1 if index < 10 && index < isize::MAX as u64 => gs.set(index as usize, new_value),
_ => return Err(SDSSError::JournalLogEntryCorrupted.into()),
_ => {
return Err(TxError::SDSS(
SDSSErrorKind::JournalLogEntryCorrupted.into(),
))
}
}
Ok(())
}

@ -90,8 +90,12 @@ mod cfg {
#[test]
fn parse_validate_cli_args() {
with_files(
["__cli_args_test_private.key", "__cli_args_test_cert.pem"],
|[pkey, cert]| {
[
"__cli_args_test_private.key",
"__cli_args_test_cert.pem",
"__cli_args_test_passphrase.key",
],
|[pkey, cert, pass]| {
let payload = format!(
"skyd --mode=dev \
--endpoint tcp@127.0.0.1:2003 \
@ -99,6 +103,7 @@ mod cfg {
--service-window=600 \
--tlskey {pkey} \
--tlscert {cert} \
--tls-passphrase {pass} \
--auth-plugin pwd \
--auth-root-password password12345678
"
@ -115,6 +120,7 @@ mod cfg {
ConfigEndpointTls::new(
ConfigEndpointTcp::new("127.0.0.2".into(), 2004),
"".into(),
"".into(),
"".into()
)
),
@ -205,13 +211,18 @@ mod cfg {
#[test]
fn parse_validate_env_args() {
with_files(
["__env_args_test_cert.pem", "__env_args_test_private.key"],
|[cert, key]| {
[
"__env_args_test_cert.pem",
"__env_args_test_private.key",
"__env_args_test_private.passphrase.txt",
],
|[cert, key, pass]| {
let variables = [
format!("SKYDB_AUTH_PLUGIN=pwd"),
format!("SKYDB_AUTH_ROOT_PASSWORD=password12345678"),
format!("SKYDB_TLS_CERT={cert}"),
format!("SKYDB_TLS_KEY={key}"),
format!("SKYDB_TLS_PRIVATE_KEY_PASSWORD={pass}"),
format!("SKYDB_ENDPOINTS=tcp@localhost:8080,tls@localhost:8081"),
format!("SKYDB_RUN_MODE=dev"),
format!("SKYDB_SERVICE_WINDOW=600"),
@ -226,6 +237,7 @@ mod cfg {
ConfigEndpointTls::new(
ConfigEndpointTcp::new("localhost".into(), 8081),
"".into(),
"".into(),
"".into()
)
),
@ -252,6 +264,7 @@ endpoints:
port: 2004
cert: ._test_sample_cert.pem
private_key: ._test_sample_private.key
pkey_passphrase: ._test_sample_private.pass.txt
insecure:
host: 127.0.0.1
port: 2003
@ -259,7 +272,11 @@ endpoints:
#[test]
fn test_config_file() {
with_files(
["._test_sample_cert.pem", "._test_sample_private.key"],
[
"._test_sample_cert.pem",
"._test_sample_private.key",
"._test_sample_private.pass.txt",
],
|_| {
config::set_cli_src(vec!["skyd".into(), "--config=config.yml".into()]);
config::set_file_src(CONFIG_FILE);
@ -272,6 +289,7 @@ endpoints:
ConfigEndpointTls::new(
ConfigEndpointTcp::new("127.0.0.1".into(), 2004),
"".into(),
"".into(),
"".into()
)
),

@ -39,7 +39,7 @@ use {
ql::lex::Ident,
storage::v1::{
inf::{self, map, obj, PersistObject},
SDSSError, SDSSResult,
SDSSErrorKind, SDSSResult,
},
txn::TransactionError,
},
@ -498,7 +498,7 @@ impl<'a> PersistObject for AlterModelRemoveTxn<'a> {
removed_fields.push(inf::dec::utils::decode_string(s, len)?.into_boxed_str());
}
if removed_fields.len() as u64 != md.remove_field_c {
return Err(SDSSError::InternalDecodeStructureCorruptedPayload);
return Err(SDSSErrorKind::InternalDecodeStructureCorruptedPayload.into());
}
Ok(AlterModelRemoveTxnRestorePL {
model_id,

Loading…
Cancel
Save