From b2e130626f1375177c7fe35afbfbc6c13c2debcc Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Mon, 25 Apr 2022 10:44:57 -0700 Subject: [PATCH 01/13] Simplify types --- server/src/dbnet/connection.rs | 42 ++++++++++++++++------------------ server/src/dbnet/mod.rs | 2 +- 2 files changed, 21 insertions(+), 23 deletions(-) diff --git a/server/src/dbnet/connection.rs b/server/src/dbnet/connection.rs index a249ca60..b1af6c3a 100644 --- a/server/src/dbnet/connection.rs +++ b/server/src/dbnet/connection.rs @@ -51,10 +51,8 @@ use crate::{ }; use bytes::{Buf, BytesMut}; use std::{ - future::Future, io::{Error as IoError, ErrorKind}, marker::PhantomData, - pin::Pin, sync::Arc, }; use tokio::{ @@ -109,16 +107,20 @@ pub mod prelude { //! //! This module is hollow itself, it only re-exports from `dbnet::con` and `tokio::io` pub use super::{AuthProviderHandle, ClientConnection, ProtocolConnectionExt, Stream}; - pub use crate::actions::{ensure_boolean_or_aerr, ensure_cond_or_err, ensure_length}; - pub use crate::corestore::{ - table::{KVEBlob, KVEList}, - Corestore, + pub use crate::{ + actions::{ensure_boolean_or_aerr, ensure_cond_or_err, ensure_length}, + aerr, conwrite, + corestore::{ + table::{KVEBlob, KVEList}, + Corestore, + }, + get_tbl, handle_entity, is_lowbit_set, + protocol::responses::{self, groups}, + queryengine::ActionIter, + registry, + resp::StringWrapper, + util::{self, FutureResult, UnwrapActionError, Unwrappable}, }; - pub use crate::protocol::responses::{self, groups}; - pub use crate::queryengine::ActionIter; - pub use crate::resp::StringWrapper; - pub use crate::util::{self, FutureResult, UnwrapActionError, Unwrappable}; - pub use crate::{aerr, conwrite, get_tbl, handle_entity, is_lowbit_set, registry}; pub use tokio::io::{AsyncReadExt, AsyncWriteExt}; } @@ -134,7 +136,7 @@ pub mod prelude { /// good trouble. pub trait ProtocolConnectionExt: ProtocolConnection + Send where - Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, + Strm: Stream, { /// Try to parse a query from the buffered data fn try_query(&self) -> Result { @@ -144,9 +146,7 @@ where /// /// This function asynchronously waits until all the data required /// for parsing the query is available - fn read_query<'r, 's>( - &'r mut self, - ) -> Pin> + Send + 's>> + fn read_query<'r, 's>(&'r mut self) -> FutureResult<'s, Result> where 'r: 's, Self: Sync + Send + 's, @@ -183,7 +183,7 @@ where fn write_response<'r, 's>( &'r mut self, streamer: impl Writable + 's + Send + Sync, - ) -> Pin> + Sync + Send + 's>> + ) -> FutureResult<'s, IoResult<()>> where 'r: 's, Self: Send + 's, @@ -199,10 +199,8 @@ where ret }) } - /// Write the simple query header `*1\n` to the stream - fn write_simple_query_header<'r, 's>( - &'r mut self, - ) -> Pin> + Send + Sync + 's>> + /// Write the simple query header `*` to the stream + fn write_simple_query_header<'r, 's>(&'r mut self) -> FutureResult<'s, IoResult<()>> where 'r: 's, Self: Send + Sync + 's, @@ -353,7 +351,7 @@ pub trait ProtocolConnection { impl ProtocolConnectionExt for T where T: ProtocolConnection + Send, - Strm: Sync + Send + Unpin + AsyncWriteExt + AsyncReadExt, + Strm: Stream, { } @@ -400,7 +398,7 @@ pub struct ConnectionHandler { impl ConnectionHandler where T: ProtocolConnectionExt + Send + Sync, - Strm: Sync + Send + Unpin + AsyncWriteExt + AsyncReadExt, + Strm: Stream, { pub fn new( db: Corestore, diff --git a/server/src/dbnet/mod.rs b/server/src/dbnet/mod.rs index 109b5570..37c1eb2c 100644 --- a/server/src/dbnet/mod.rs +++ b/server/src/dbnet/mod.rs @@ -261,6 +261,6 @@ pub async fn connect( MultiListener::new_multi(secure_listener, insecure_listener, ssl).await? } }; - log::info!("Server started on {}", description); + log::info!("Server started on {description}"); Ok(server) } From 02a3e9b4e9494b3aa286acf9d022e4135c6dda9b Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Mon, 25 Apr 2022 23:17:00 -0700 Subject: [PATCH 02/13] Make connections generic over protocols --- server/src/auth/mod.rs | 14 +++---- server/src/dbnet/connection.rs | 75 ++++++++++++++++++++++------------ server/src/dbnet/tcp.rs | 12 ++++-- server/src/dbnet/tls.rs | 17 ++++---- server/src/protocol/mod.rs | 24 ++++++----- server/src/queryengine/mod.rs | 19 +++++---- server/src/resp/writer.rs | 40 ++++++++++-------- server/src/util/macros.rs | 27 +++++++++--- 8 files changed, 144 insertions(+), 84 deletions(-) diff --git a/server/src/auth/mod.rs b/server/src/auth/mod.rs index 7fa772e0..f0ddc5b5 100644 --- a/server/src/auth/mod.rs +++ b/server/src/auth/mod.rs @@ -61,7 +61,7 @@ action! { /// Handle auth. Should have passed the `auth` token fn auth( con: &mut T, - auth: &mut AuthProviderHandle<'_, T, Strm>, + auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: ActionIter<'_> ) { let mut iter = iter; @@ -94,12 +94,12 @@ action! { _ => util::err(groups::UNKNOWN_ACTION), } } - fn auth_whoami(con: &mut T, auth: &mut AuthProviderHandle<'_, T, Strm>, iter: &mut ActionIter<'_>) { + fn auth_whoami(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) { ensure_boolean_or_aerr(ActionIter::is_empty(iter))?; con.write_response(StringWrapper(auth.provider().whoami()?)).await?; Ok(()) } - fn auth_listuser(con: &mut T, auth: &mut AuthProviderHandle<'_, T, Strm>, iter: &mut ActionIter<'_>) { + fn auth_listuser(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) { ensure_boolean_or_aerr(ActionIter::is_empty(iter))?; let usernames = auth.provider().collect_usernames()?; let mut array_writer = unsafe { @@ -111,7 +111,7 @@ action! { } Ok(()) } - fn auth_restore(con: &mut T, auth: &mut AuthProviderHandle<'_, T, Strm>, iter: &mut ActionIter<'_>) { + fn auth_restore(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) { let newkey = match iter.len() { 1 => { // so this fella thinks they're root @@ -128,7 +128,7 @@ action! { con.write_response(StringWrapper(newkey)).await?; Ok(()) } - fn _auth_claim(con: &mut T, auth: &mut AuthProviderHandle<'_, T, Strm>, iter: &mut ActionIter<'_>) { + fn _auth_claim(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) { ensure_boolean_or_aerr(iter.len() == 1)?; // just the origin key let origin_key = unsafe { iter.next_unchecked() }; let key = auth.provider_mut().claim_root(origin_key)?; @@ -139,7 +139,7 @@ action! { /// Handle a login operation only. The **`login` token is expected to be present** fn auth_login_only( con: &mut T, - auth: &mut AuthProviderHandle<'_, T, Strm>, + auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: ActionIter<'_> ) { let mut iter = iter; @@ -151,7 +151,7 @@ action! { _ => util::err(errors::AUTH_CODE_PERMS), } } - fn _auth_login(con: &mut T, auth: &mut AuthProviderHandle<'_, T, Strm>, iter: &mut ActionIter<'_>) { + fn _auth_login(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) { // sweet, where's our username and password ensure_boolean_or_aerr(iter.len() == 2)?; // just the uname and pass let (username, password) = unsafe { (iter.next_unchecked(), iter.next_unchecked()) }; diff --git a/server/src/dbnet/connection.rs b/server/src/dbnet/connection.rs index b1af6c3a..8dfc6583 100644 --- a/server/src/dbnet/connection.rs +++ b/server/src/dbnet/connection.rs @@ -44,7 +44,7 @@ use crate::{ tcp::{BufferedSocketStream, Connection}, Terminator, }, - protocol::{self, responses, ParseError, Query}, + protocol::{responses, ParseError, Query, Skyhash2}, queryengine, resp::Writable, IoResult, @@ -61,7 +61,19 @@ use tokio::{ }; pub const SIMPLE_QUERY_HEADER: [u8; 1] = [b'*']; -type QueryWithAdvance = (Query, usize); +pub(super) type QueryWithAdvance = (Query, usize); + +/// The [`ProtocolSpec`] trait implementation enables extremely easy switching between +/// protocols by being generic for the same base connection types +pub trait ProtocolSpec: Send + Sync { + fn parse(buf: &[u8]) -> Result; +} + +impl ProtocolSpec for Skyhash2 { + fn parse(buf: &[u8]) -> Result { + Skyhash2::parse(buf) + } +} pub enum QueryResult { Q(QueryWithAdvance), @@ -70,18 +82,19 @@ pub enum QueryResult { Disconnected, } -pub struct AuthProviderHandle<'a, T, Strm> { +pub struct AuthProviderHandle<'a, P: ProtocolSpec, T, Strm> { provider: &'a mut AuthProvider, - executor: &'a mut ExecutorFn, + executor: &'a mut ExecutorFn, _phantom: PhantomData<(T, Strm)>, } -impl<'a, T, Strm> AuthProviderHandle<'a, T, Strm> +impl<'a, P, T, Strm> AuthProviderHandle<'a, P, T, Strm> where - T: ClientConnection, + T: ClientConnection, Strm: Stream, + P: ProtocolSpec, { - pub fn new(provider: &'a mut AuthProvider, executor: &'a mut ExecutorFn) -> Self { + pub fn new(provider: &'a mut AuthProvider, executor: &'a mut ExecutorFn) -> Self { Self { provider, executor, @@ -106,7 +119,9 @@ pub mod prelude { //! A 'prelude' for callers that would like to use the `ProtocolConnection` and `ProtocolConnectionExt` traits //! //! This module is hollow itself, it only re-exports from `dbnet::con` and `tokio::io` - pub use super::{AuthProviderHandle, ClientConnection, ProtocolConnectionExt, Stream}; + pub use super::{ + AuthProviderHandle, ClientConnection, ProtocolConnectionExt, ProtocolSpec, Stream, + }; pub use crate::{ actions::{ensure_boolean_or_aerr, ensure_cond_or_err, ensure_length}, aerr, conwrite, @@ -134,13 +149,14 @@ pub mod prelude { /// The fact that this is a trait enables great flexibility in terms of visibility, but **DO NOT EVER CALL any function other than /// `read_query`, `close_conn_with_error` or `write_response`**. If you mess with functions like `read_again`, you're likely to pull yourself into some /// good trouble. -pub trait ProtocolConnectionExt: ProtocolConnection + Send +pub trait ProtocolConnectionExt: ProtocolConnection + Send where Strm: Stream, + P: ProtocolSpec, { /// Try to parse a query from the buffered data fn try_query(&self) -> Result { - protocol::Parser::parse(self.get_buffer()) + P::parse(self.get_buffer()) } /// Read a query from the remote end /// @@ -193,7 +209,7 @@ where let mv_self = self; let streamer = streamer; let ret: IoResult<()> = { - streamer.write(&mut mv_self.get_mut_stream()).await?; + streamer.write(mv_self.get_mut_stream()).await?; Ok(()) }; ret @@ -323,7 +339,7 @@ where /// ``` /// /// `Strm` should be a stream, i.e something like an SSL connection/TCP connection. -pub trait ProtocolConnection { +pub trait ProtocolConnection { /// Returns an **immutable** reference to the underlying read buffer fn get_buffer(&self) -> &BytesMut; /// Returns an **immutable** reference to the underlying stream @@ -348,16 +364,18 @@ pub trait ProtocolConnection { // Give ProtocolConnection implementors a free ProtocolConnectionExt impl -impl ProtocolConnectionExt for T +impl ProtocolConnectionExt for T where - T: ProtocolConnection + Send, + T: ProtocolConnection + Send, Strm: Stream, + P: ProtocolSpec, { } -impl ProtocolConnection for Connection +impl ProtocolConnection for Connection where T: BufferedSocketStream, + P: ProtocolSpec, { fn get_buffer(&self) -> &BytesMut { &self.buffer @@ -376,35 +394,36 @@ where } } -pub(super) type ExecutorFn = - for<'s> fn(&'s mut ConnectionHandler, Query) -> FutureResult<'s, ActionResult<()>>; +pub(super) type ExecutorFn = + for<'s> fn(&'s mut ConnectionHandler, Query) -> FutureResult<'s, ActionResult<()>>; /// # A generic connection handler /// /// A [`ConnectionHandler`] object is a generic connection handler for any object that implements the [`ProtocolConnection`] trait (or /// the [`ProtocolConnectionExt`] trait). This function will accept such a type `T`, possibly a listener object and then use it to read /// a query, parse it and return an appropriate response through [`corestore::Corestore::execute_query`] -pub struct ConnectionHandler { +pub struct ConnectionHandler { db: Corestore, con: T, climit: Arc, auth: AuthProvider, - executor: ExecutorFn, + executor: ExecutorFn, terminator: Terminator, _term_sig_tx: mpsc::Sender<()>, _marker: PhantomData, } -impl ConnectionHandler +impl ConnectionHandler where - T: ProtocolConnectionExt + Send + Sync, + T: ProtocolConnectionExt + Send + Sync, Strm: Stream, + P: ProtocolSpec, { pub fn new( db: Corestore, con: T, auth: AuthProvider, - executor: ExecutorFn, + executor: ExecutorFn, climit: Arc, terminator: Terminator, _term_sig_tx: mpsc::Sender<()>, @@ -513,7 +532,7 @@ where } } -impl Drop for ConnectionHandler { +impl Drop for ConnectionHandler { fn drop(&mut self) { // Make sure that the permit is returned to the semaphore // in the case that there is a panic inside @@ -526,10 +545,14 @@ pub trait Stream: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync {} impl Stream for T where T: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync {} /// A simple _shorthand trait_ for the insanely long definition of the connection generic type -pub trait ClientConnection: ProtocolConnectionExt + Send + Sync {} -impl ClientConnection for T +pub trait ClientConnection: + ProtocolConnectionExt + Send + Sync +{ +} +impl ClientConnection for T where - T: ProtocolConnectionExt + Send + Sync, + T: ProtocolConnectionExt + Send + Sync, Strm: Stream, + P: ProtocolSpec, { } diff --git a/server/src/dbnet/tcp.rs b/server/src/dbnet/tcp.rs index a146ebf5..3e603b15 100644 --- a/server/src/dbnet/tcp.rs +++ b/server/src/dbnet/tcp.rs @@ -24,6 +24,8 @@ * */ +use crate::dbnet::connection::ProtocolSpec; +use crate::protocol::Skyhash2; use crate::{ dbnet::{ connection::{ConnectionHandler, ExecutorFn}, @@ -45,7 +47,7 @@ pub trait BufferedSocketStream: AsyncWrite {} impl BufferedSocketStream for TcpStream {} -type TcpExecutorFn = ExecutorFn, TcpStream>; +type TcpExecutorFn

= ExecutorFn, TcpStream>; /// A TCP/SSL connection wrapper pub struct Connection @@ -94,13 +96,15 @@ impl TcpBackoff { } } +pub type Listener = RawListener; + /// A listener -pub struct Listener { +pub struct RawListener

{ pub base: BaseListener, - executor_fn: TcpExecutorFn, + executor_fn: TcpExecutorFn

, } -impl Listener { +impl RawListener

{ pub fn new(base: BaseListener) -> Self { Self { executor_fn: if base.auth.is_enabled() { diff --git a/server/src/dbnet/tls.rs b/server/src/dbnet/tls.rs index 916245e0..fcdaf2bb 100644 --- a/server/src/dbnet/tls.rs +++ b/server/src/dbnet/tls.rs @@ -26,10 +26,11 @@ use crate::{ dbnet::{ - connection::{ConnectionHandler, ExecutorFn}, + connection::{ConnectionHandler, ExecutorFn, ProtocolSpec}, tcp::{BufferedSocketStream, Connection, TcpBackoff}, BaseListener, Terminator, }, + protocol::Skyhash2, util::error::{Error, SkyResult}, IoResult, }; @@ -43,21 +44,23 @@ use tokio::net::TcpStream; use tokio_openssl::SslStream; impl BufferedSocketStream for SslStream {} -type SslExecutorFn = ExecutorFn>, SslStream>; +type SslExecutorFn

= ExecutorFn>, SslStream>; -pub struct SslListener { +pub type SslListener = SslListenerRaw; + +pub struct SslListenerRaw

{ pub base: BaseListener, acceptor: SslAcceptor, - executor_fn: SslExecutorFn, + executor_fn: SslExecutorFn

, } -impl SslListener { +impl SslListenerRaw

{ pub fn new_pem_based_ssl_connection( key_file: String, chain_file: String, base: BaseListener, tls_passfile: Option, - ) -> SkyResult { + ) -> SkyResult> { let mut acceptor_builder = SslAcceptor::mozilla_intermediate(SslMethod::tls())?; // cert is the same for both acceptor_builder.set_certificate_chain_file(chain_file)?; @@ -77,7 +80,7 @@ impl SslListener { // no passphrase, needs interactive acceptor_builder.set_private_key_file(key_file, SslFiletype::PEM)?; } - Ok(SslListener { + Ok(Self { acceptor: acceptor_builder.build(), executor_fn: if base.auth.is_enabled() { ConnectionHandler::execute_unauth diff --git a/server/src/protocol/mod.rs b/server/src/protocol/mod.rs index 21a23ca8..9c329221 100644 --- a/server/src/protocol/mod.rs +++ b/server/src/protocol/mod.rs @@ -25,7 +25,7 @@ */ use crate::corestore::heap_array::HeapArray; -use core::{fmt, marker::PhantomData, mem::transmute, slice}; +use core::{fmt, mem::transmute, slice}; #[cfg(feature = "nightly")] mod benches; #[cfg(test)] @@ -39,6 +39,7 @@ pub mod responses; pub const PROTOCOL_VERSION: f32 = 2.0; /// The Skyhash protocol version string (Skyhash-x.y) pub const PROTOCOL_VERSIONSTRING: &str = "Skyhash-2.0"; +pub type Skyhash2 = Parser; #[derive(PartialEq)] /// As its name says, an [`UnsafeSlice`] is a terribly unsafe slice. It's guarantess are @@ -165,27 +166,28 @@ struct OwnedPipelinedQuery { } /// A parser for Skyhash 2.0 -pub struct Parser<'a> { +pub struct Parser { end: *const u8, cursor: *const u8, - _lt: PhantomData<&'a ()>, } -impl<'a> Parser<'a> { +unsafe impl Sync for Parser {} +unsafe impl Send for Parser {} + +impl Parser { /// Initialize a new parser - pub fn new(slice: &[u8]) -> Self { + fn new(slice: &[u8]) -> Self { unsafe { Self { end: slice.as_ptr().add(slice.len()), cursor: slice.as_ptr(), - _lt: PhantomData, } } } } // basic methods -impl<'a> Parser<'a> { +impl Parser { /// Returns a ptr one byte past the allocation of the buffer const fn data_end_ptr(&self) -> *const u8 { self.end @@ -220,7 +222,7 @@ impl<'a> Parser<'a> { } // mutable refs -impl<'a> Parser<'a> { +impl Parser { /// Increment the cursor by `by` positions unsafe fn incr_cursor_by(&mut self, by: usize) { self.cursor = self.cursor.add(by); @@ -232,7 +234,7 @@ impl<'a> Parser<'a> { } // higher level abstractions -impl<'a> Parser<'a> { +impl Parser { /// Attempt to read `len` bytes fn read_until(&mut self, len: usize) -> ParseResult { if self.has_remaining(len) { @@ -308,7 +310,7 @@ impl<'a> Parser<'a> { } // query impls -impl<'a> Parser<'a> { +impl Parser { /// Parse the next simple query. This should have passed the `*` tsymbol /// /// Simple query structure (tokenized line-by-line): @@ -412,6 +414,8 @@ impl<'a> Parser<'a> { Err(ParseError::NotEnough) } } + // only expose this. don't expose Self::new since that'll be _relatively easier_ to + // invalidate invariants for pub fn parse(buf: &[u8]) -> ParseResult<(Query, usize)> { let mut slf = Self::new(buf); let body = slf._parse()?; diff --git a/server/src/queryengine/mod.rs b/server/src/queryengine/mod.rs index 78389afd..c98f18b0 100644 --- a/server/src/queryengine/mod.rs +++ b/server/src/queryengine/mod.rs @@ -78,7 +78,7 @@ action! { fn execute_simple_noauth( _db: &mut Corestore, con: &mut T, - auth: &mut AuthProviderHandle<'_, T, Strm>, + auth: &mut AuthProviderHandle<'_, P, T, Strm>, buf: SimpleQuery ) { let bufref = buf.as_slice(); @@ -96,17 +96,17 @@ action! { fn execute_simple( db: &mut Corestore, con: &mut T, - auth: &mut AuthProviderHandle<'_, T, Strm>, + auth: &mut AuthProviderHandle<'_, P, T, Strm>, buf: SimpleQuery ) { self::execute_stage(db, con, auth, buf.as_slice()).await } } -async fn execute_stage<'a, T: 'a + ClientConnection, Strm: Stream>( +async fn execute_stage<'a, P: ProtocolSpec, T: 'a + ClientConnection, Strm: Stream>( db: &mut Corestore, con: &'a mut T, - auth: &mut AuthProviderHandle<'_, T, Strm>, + auth: &mut AuthProviderHandle<'_, P, T, Strm>, buf: &[UnsafeSlice], ) -> ActionResult<()> { let mut iter = unsafe { @@ -171,10 +171,15 @@ action! { /// Execute a stage **completely**. This means that action errors are never propagated /// over the try operator -async fn execute_stage_pedantic<'a, T: ClientConnection + 'a, Strm: Stream + 'a>( +async fn execute_stage_pedantic< + 'a, + P: ProtocolSpec, + T: ClientConnection + 'a, + Strm: Stream + 'a, +>( handle: &mut Corestore, con: &mut T, - auth: &mut AuthProviderHandle<'_, T, Strm>, + auth: &mut AuthProviderHandle<'_, P, T, Strm>, stage: &[UnsafeSlice], ) -> crate::IoResult<()> { let ret = async { @@ -193,7 +198,7 @@ action! { fn execute_pipeline( handle: &mut Corestore, con: &mut T, - auth: &mut AuthProviderHandle<'_, T, Strm>, + auth: &mut AuthProviderHandle<'_, P, T, Strm>, pipeline: PipelinedQuery ) { for stage in pipeline.into_inner().iter() { diff --git a/server/src/resp/writer.rs b/server/src/resp/writer.rs index c2b54321..6b5d2302 100644 --- a/server/src/resp/writer.rs +++ b/server/src/resp/writer.rs @@ -26,7 +26,7 @@ use crate::corestore::buffers::Integer64; use crate::corestore::Data; -use crate::dbnet::connection::ProtocolConnectionExt; +use crate::dbnet::connection::{ProtocolConnectionExt, ProtocolSpec}; use crate::protocol::responses::groups; use crate::IoResult; use core::marker::PhantomData; @@ -34,13 +34,14 @@ use tokio::io::AsyncReadExt; use tokio::io::AsyncWriteExt; /// Write a raw mono group with a custom tsymbol -pub async unsafe fn write_raw_mono( +pub async unsafe fn write_raw_mono( con: &mut T, tsymbol: u8, payload: &Data, ) -> IoResult<()> where - T: ProtocolConnectionExt, + P: ProtocolSpec, + T: ProtocolConnectionExt, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, { let raw_stream = con.raw_stream(); @@ -54,16 +55,17 @@ where #[derive(Debug)] /// A writer for a flat array, which is a multi-typed non-recursive array -pub struct FlatArrayWriter<'a, T, Strm> { +pub struct FlatArrayWriter<'a, P, T, Strm> { tsymbol: u8, con: &'a mut T, - _owned: PhantomData, + _owned: PhantomData<(P, Strm)>, } #[allow(dead_code)] // TODO(@ohsayan): Remove this once we start using the flat array writer -impl<'a, T, Strm> FlatArrayWriter<'a, T, Strm> +impl<'a, P, T, Strm> FlatArrayWriter<'a, P, T, Strm> where - T: ProtocolConnectionExt, + P: ProtocolSpec, + T: ProtocolConnectionExt, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, { /// Intialize a new flat array writer. This will write out the tsymbol @@ -72,7 +74,7 @@ where con: &'a mut T, tsymbol: u8, len: usize, - ) -> IoResult> { + ) -> IoResult> { { let stream = con.raw_stream(); // first write _ @@ -121,14 +123,15 @@ where #[derive(Debug)] /// A writer for a typed array, which is a singly-typed array which either /// has a typed element or a `NULL` -pub struct TypedArrayWriter<'a, T, Strm> { +pub struct TypedArrayWriter<'a, P, T, Strm> { con: &'a mut T, - _owned: PhantomData, + _owned: PhantomData<(P, Strm)>, } -impl<'a, T, Strm> TypedArrayWriter<'a, T, Strm> +impl<'a, P, T, Strm> TypedArrayWriter<'a, P, T, Strm> where - T: ProtocolConnectionExt, + P: ProtocolSpec, + T: ProtocolConnectionExt, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, { /// Create a new `typedarraywriter`. This will write the tsymbol and @@ -137,7 +140,7 @@ where con: &'a mut T, tsymbol: u8, len: usize, - ) -> IoResult> { + ) -> IoResult> { { let stream = con.raw_stream(); // first write @ @@ -177,14 +180,15 @@ where #[derive(Debug)] /// A writer for a non-null typed array, which is a singly-typed array which either /// has a typed element or a `NULL` -pub struct NonNullArrayWriter<'a, T, Strm> { +pub struct NonNullArrayWriter<'a, P, T, Strm> { con: &'a mut T, - _owned: PhantomData, + _owned: PhantomData<(P, Strm)>, } -impl<'a, T, Strm> NonNullArrayWriter<'a, T, Strm> +impl<'a, P, T, Strm> NonNullArrayWriter<'a, P, T, Strm> where - T: ProtocolConnectionExt, + P: ProtocolSpec, + T: ProtocolConnectionExt, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, { /// Create a new `typedarraywriter`. This will write the tsymbol and @@ -193,7 +197,7 @@ where con: &'a mut T, tsymbol: u8, len: usize, - ) -> IoResult> { + ) -> IoResult> { { let stream = con.raw_stream(); // first write @ diff --git a/server/src/util/macros.rs b/server/src/util/macros.rs index b298af71..54a9d289 100644 --- a/server/src/util/macros.rs +++ b/server/src/util/macros.rs @@ -120,26 +120,43 @@ macro_rules! action { $block:block)* ) => { $($(#[$attr])* - pub async fn $fname<'a, T: 'a + ClientConnection, Strm:Stream>( + pub async fn $fname<'a, T: 'a + ClientConnection, Strm:Stream, P: crate::dbnet::connection::ProtocolSpec>( $($argname: $argty,)* ) -> $crate::actions::ActionResult<()> $block)* }; ( $($(#[$attr:meta])* - fn $fname:ident($argone:ident: $argonety:ty, + fn $fname:ident( + $argone:ident: $argonety:ty, $argtwo:ident: $argtwoty:ty, - mut $argthree:ident: $argthreety:ty) - $block:block)* + mut $argthree:ident: $argthreety:ty + ) $block:block)* ) => { $($(#[$attr])* - pub async fn $fname<'a, T: 'a + ClientConnection, Strm:Stream>( + pub async fn $fname<'a, T: 'a + ClientConnection, Strm:Stream, P: crate::dbnet::connection::ProtocolSpec>( $argone: $argonety, $argtwo: $argtwoty, mut $argthree: $argthreety ) -> $crate::actions::ActionResult<()> $block)* }; + ( + $($(#[$attr:meta])* + fn $fname:ident( + $argone:ident: $argonety:ty, + $argtwo:ident: $argtwoty:ty, + $argthree:ident: $argthreety:ty + ) $block:block)* + ) => { + $($(#[$attr])* + pub async fn $fname<'a, T: 'a + ClientConnection, Strm:Stream, P: crate::dbnet::connection::ProtocolSpec>( + $argone: $argonety, + $argtwo: $argtwoty, + $argthree: $argthreety + ) -> $crate::actions::ActionResult<()> + $block)* + }; } #[macro_export] From d31fc5855dc516bf6d57662ca666dac055135b1f Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Mon, 25 Apr 2022 23:37:38 -0700 Subject: [PATCH 03/13] Add memory safety assertions --- server/src/dbnet/connection.rs | 47 ++++++++++++++++++++++++++-------- 1 file changed, 37 insertions(+), 10 deletions(-) diff --git a/server/src/dbnet/connection.rs b/server/src/dbnet/connection.rs index 8dfc6583..4596f5fe 100644 --- a/server/src/dbnet/connection.rs +++ b/server/src/dbnet/connection.rs @@ -451,18 +451,45 @@ where Ok(QueryResult::Q((query, advance_by))) => { // the mutable reference to self ensures that the buffer is not modified // hence ensuring that the pointers will remain valid - match self.execute_query(query).await { - Ok(()) => {} - Err(ActionError::ActionError(e)) => { - self.con.close_conn_with_error(e).await?; - } - Err(ActionError::IoError(e)) => { - return Err(e); + #[cfg(debug_assertions)] + let len_at_start = self.con.get_buffer().len(); + #[cfg(debug_assertions)] + let sptr_at_start = self.con.get_buffer().as_ptr() as usize; + #[cfg(debug_assertions)] + let eptr_at_start = sptr_at_start + len_at_start; + { + match self.execute_query(query).await { + Ok(()) => {} + Err(ActionError::ActionError(e)) => { + self.con.close_conn_with_error(e).await?; + } + Err(ActionError::IoError(e)) => { + return Err(e); + } } } - // this is only when we clear the buffer. since execute_query is not called - // at this point, it's totally fine (so invalidating ptrs is totally cool) - self.con.advance_buffer(advance_by); + { + // do these assertions to ensure memory safety (this is just for sanity sake) + #[cfg(debug_assertions)] + // len should be unchanged. no functions should **ever** touch the buffer + debug_assert_eq!(self.con.get_buffer().len(), len_at_start); + #[cfg(debug_assertions)] + // start of allocation should be unchanged + debug_assert_eq!(self.con.get_buffer().as_ptr() as usize, sptr_at_start); + #[cfg(debug_assertions)] + // end of allocation should be unchanged. else we're entirely violating + // memory safety guarantees + debug_assert_eq!( + unsafe { + // UNSAFE(@ohsayan): THis is always okay + self.con.get_buffer().as_ptr().add(len_at_start) + } as usize, + eptr_at_start + ); + // this is only when we clear the buffer. since execute_query is not called + // at this point, it's totally fine (so invalidating ptrs is totally cool) + self.con.advance_buffer(advance_by); + } } Ok(QueryResult::E(r)) => self.con.close_conn_with_error(r).await?, Ok(QueryResult::Wrongtype) => { From 2bb7555e4ea69e161a082d3794e3411f579b7092 Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Fri, 29 Apr 2022 02:54:55 -0700 Subject: [PATCH 04/13] Add `ProtocolWrite` trait for writing data according to `ProtocolSpec` --- Cargo.lock | 1 + server/Cargo.toml | 1 + server/src/dbnet/connection.rs | 268 +++----------------------- server/src/dbnet/tcp.rs | 11 +- server/src/dbnet/tls.rs | 12 +- server/src/protocol/interface/mod.rs | 275 +++++++++++++++++++++++++++ server/src/protocol/mod.rs | 3 + server/src/protocol/v2/mod.rs | 113 +++++++++++ server/src/resp/mod.rs | 2 +- server/src/resp/writer.rs | 14 +- server/src/util/macros.rs | 21 +- 11 files changed, 467 insertions(+), 254 deletions(-) create mode 100644 server/src/protocol/interface/mod.rs create mode 100644 server/src/protocol/v2/mod.rs diff --git a/Cargo.lock b/Cargo.lock index bb8637bb..e2828fdc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1293,6 +1293,7 @@ name = "skyd" version = "0.7.5" dependencies = [ "ahash", + "async-trait", "base64", "bincode", "bytes", diff --git a/server/Cargo.toml b/server/Cargo.toml index 4bc4a7a9..7f1c207b 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -10,6 +10,7 @@ version = "0.7.5" libsky = { path = "../libsky" } sky_macros = { path = "../sky-macros" } # external deps +async-trait = "0.1.53" ahash = "0.7.6" bytes = "1.1.0" chrono = "0.4.19" diff --git a/server/src/dbnet/connection.rs b/server/src/dbnet/connection.rs index 4596f5fe..3605971e 100644 --- a/server/src/dbnet/connection.rs +++ b/server/src/dbnet/connection.rs @@ -25,55 +25,39 @@ */ //! # Generic connection traits -//! The `con` module defines the generic connection traits `ProtocolConnection` and `ProtocolConnectionExt`. +//! The `con` module defines the generic connection traits `RawConnection` and `ProtocolRead`. //! These two traits can be used to interface with sockets that are used for communication through the Skyhash //! protocol. //! -//! The `ProtocolConnection` trait provides a basic set of methods that are required by prospective connection +//! The `RawConnection` trait provides a basic set of methods that are required by prospective connection //! objects to be eligible for higher level protocol interactions (such as interactions with high-level query objects). -//! Once a type implements this trait, it automatically gets a free `ProtocolConnectionExt` implementation. This immediately +//! Once a type implements this trait, it automatically gets a free `ProtocolRead` implementation. This immediately //! enables this connection object/type to use methods like read_query enabling it to read and interact with queries and write //! respones in compliance with the Skyhash protocol. use crate::{ actions::{ActionError, ActionResult}, auth::{self, AuthProvider}, - corestore::{buffers::Integer64, Corestore}, + corestore::Corestore, dbnet::{ connection::prelude::FutureResult, tcp::{BufferedSocketStream, Connection}, Terminator, }, - protocol::{responses, ParseError, Query, Skyhash2}, - queryengine, - resp::Writable, - IoResult, + protocol::{ + interface::{ProtocolRead, ProtocolSpec}, + responses, Query, + }, + queryengine, IoResult, }; use bytes::{Buf, BytesMut}; -use std::{ - io::{Error as IoError, ErrorKind}, - marker::PhantomData, - sync::Arc, -}; +use std::{marker::PhantomData, sync::Arc}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt, BufWriter}, sync::{mpsc, Semaphore}, }; -pub const SIMPLE_QUERY_HEADER: [u8; 1] = [b'*']; -pub(super) type QueryWithAdvance = (Query, usize); - -/// The [`ProtocolSpec`] trait implementation enables extremely easy switching between -/// protocols by being generic for the same base connection types -pub trait ProtocolSpec: Send + Sync { - fn parse(buf: &[u8]) -> Result; -} - -impl ProtocolSpec for Skyhash2 { - fn parse(buf: &[u8]) -> Result { - Skyhash2::parse(buf) - } -} +pub type QueryWithAdvance = (Query, usize); pub enum QueryResult { Q(QueryWithAdvance), @@ -116,12 +100,10 @@ where } pub mod prelude { - //! A 'prelude' for callers that would like to use the `ProtocolConnection` and `ProtocolConnectionExt` traits + //! A 'prelude' for callers that would like to use the `RawConnection` and `ProtocolRead` traits //! //! This module is hollow itself, it only re-exports from `dbnet::con` and `tokio::io` - pub use super::{ - AuthProviderHandle, ClientConnection, ProtocolConnectionExt, ProtocolSpec, Stream, - }; + pub use super::{AuthProviderHandle, ClientConnection, Stream}; pub use crate::{ actions::{ensure_boolean_or_aerr, ensure_cond_or_err, ensure_length}, aerr, conwrite, @@ -130,7 +112,10 @@ pub mod prelude { Corestore, }, get_tbl, handle_entity, is_lowbit_set, - protocol::responses::{self, groups}, + protocol::{ + interface::{ProtocolRead, ProtocolSpec}, + responses::{self, groups}, + }, queryengine::ActionIter, registry, resp::StringWrapper, @@ -139,197 +124,14 @@ pub mod prelude { pub use tokio::io::{AsyncReadExt, AsyncWriteExt}; } -/// # The `ProtocolConnectionExt` trait +/// # The `RawConnection` trait /// -/// The `ProtocolConnectionExt` trait has default implementations and doesn't ever require explicit definitions, unless -/// there's some black magic that you want to do. All [`ProtocolConnection`] objects will get a free implementation for this trait. -/// Hence implementing [`ProtocolConnection`] alone is enough for you to get high-level methods to interface with the protocol. -/// -/// ## DO NOT -/// The fact that this is a trait enables great flexibility in terms of visibility, but **DO NOT EVER CALL any function other than -/// `read_query`, `close_conn_with_error` or `write_response`**. If you mess with functions like `read_again`, you're likely to pull yourself into some -/// good trouble. -pub trait ProtocolConnectionExt: ProtocolConnection + Send -where - Strm: Stream, - P: ProtocolSpec, -{ - /// Try to parse a query from the buffered data - fn try_query(&self) -> Result { - P::parse(self.get_buffer()) - } - /// Read a query from the remote end - /// - /// This function asynchronously waits until all the data required - /// for parsing the query is available - fn read_query<'r, 's>(&'r mut self) -> FutureResult<'s, Result> - where - 'r: 's, - Self: Sync + Send + 's, - { - Box::pin(async move { - let mv_self = self; - loop { - let (buffer, stream) = mv_self.get_mut_both(); - match stream.read_buf(buffer).await { - Ok(0) => { - if buffer.is_empty() { - return Ok(QueryResult::Disconnected); - } else { - return Err(IoError::from(ErrorKind::ConnectionReset)); - } - } - Ok(_) => {} - Err(e) => return Err(e), - } - match mv_self.try_query() { - Ok(query_with_advance) => { - return Ok(QueryResult::Q(query_with_advance)); - } - Err(ParseError::NotEnough) => (), - Err(ParseError::DatatypeParseFailure) => return Ok(QueryResult::Wrongtype), - Err(ParseError::UnexpectedByte) | Err(ParseError::BadPacket) => { - return Ok(QueryResult::E(responses::full_responses::R_PACKET_ERR)); - } - } - } - }) - } - /// Write a response to the stream - fn write_response<'r, 's>( - &'r mut self, - streamer: impl Writable + 's + Send + Sync, - ) -> FutureResult<'s, IoResult<()>> - where - 'r: 's, - Self: Send + 's, - Self: Sync, - { - Box::pin(async move { - let mv_self = self; - let streamer = streamer; - let ret: IoResult<()> = { - streamer.write(mv_self.get_mut_stream()).await?; - Ok(()) - }; - ret - }) - } - /// Write the simple query header `*` to the stream - fn write_simple_query_header<'r, 's>(&'r mut self) -> FutureResult<'s, IoResult<()>> - where - 'r: 's, - Self: Send + Sync + 's, - { - Box::pin(async move { - let mv_self = self; - let ret: IoResult<()> = { - mv_self.write_response(SIMPLE_QUERY_HEADER).await?; - Ok(()) - }; - ret - }) - } - /// Write the length of the pipeline query (*) - fn write_pipeline_query_header<'r, 's>( - &'r mut self, - len: usize, - ) -> FutureResult<'s, IoResult<()>> - where - 'r: 's, - Self: Send + Sync + 's, - { - Box::pin(async move { - let slf = self; - slf.write_response([b'$']).await?; - slf.get_mut_stream() - .write_all(&Integer64::init(len as u64)) - .await?; - slf.write_response([b'\n']).await?; - Ok(()) - }) - } - /// Write the flat array length (`_\n`) - fn write_flat_array_length<'r, 's>(&'r mut self, len: usize) -> FutureResult<'s, IoResult<()>> - where - 'r: 's, - Self: Send + Sync + 's, - { - Box::pin(async move { - let mv_self = self; - let ret: IoResult<()> = { - mv_self.write_response([b'_']).await?; - mv_self.write_response(len.to_string().into_bytes()).await?; - mv_self.write_response([b'\n']).await?; - Ok(()) - }; - ret - }) - } - /// Write the array length (`&\n`) - fn write_array_length<'r, 's>(&'r mut self, len: usize) -> FutureResult<'s, IoResult<()>> - where - 'r: 's, - Self: Send + Sync + 's, - { - Box::pin(async move { - let mv_self = self; - let ret: IoResult<()> = { - mv_self.write_response([b'&']).await?; - mv_self.write_response(len.to_string().into_bytes()).await?; - mv_self.write_response([b'\n']).await?; - Ok(()) - }; - ret - }) - } - /// Wraps around the `write_response` used to differentiate between a - /// success response and an error response - fn close_conn_with_error<'r, 's>( - &'r mut self, - resp: impl Writable + 's + Send + Sync, - ) -> FutureResult<'s, IoResult<()>> - where - 'r: 's, - Self: Send + Sync + 's, - { - Box::pin(async move { - let mv_self = self; - let ret: IoResult<()> = { - mv_self.write_response(resp).await?; - mv_self.flush_stream().await?; - Ok(()) - }; - ret - }) - } - fn flush_stream<'r, 's>(&'r mut self) -> FutureResult<'s, IoResult<()>> - where - 'r: 's, - Self: Sync + Send + 's, - { - Box::pin(async move { - let mv_self = self; - let ret: IoResult<()> = { - mv_self.get_mut_stream().flush().await?; - Ok(()) - }; - ret - }) - } - unsafe fn raw_stream(&mut self) -> &mut BufWriter { - self.get_mut_stream() - } -} - -/// # The `ProtocolConnection` trait -/// -/// The `ProtocolConnection` trait has low-level methods that can be used to interface with raw sockets. Any type -/// that successfully implements this trait will get an implementation for `ProtocolConnectionExt` which augments and +/// The `RawConnection` trait has low-level methods that can be used to interface with raw sockets. Any type +/// that successfully implements this trait will get an implementation for `ProtocolRead` which augments and /// builds on these fundamental methods to provide high-level interfacing with queries. /// -/// ## Example of a `ProtocolConnection` object -/// Ideally a `ProtocolConnection` object should look like (the generic parameter just exists for doc-tests, just think that +/// ## Example of a `RawConnection` object +/// Ideally a `RawConnection` object should look like (the generic parameter just exists for doc-tests, just think that /// there is a type `Strm`): /// ```no_run /// struct Connection { @@ -339,7 +141,7 @@ where /// ``` /// /// `Strm` should be a stream, i.e something like an SSL connection/TCP connection. -pub trait ProtocolConnection { +pub trait RawConnection: Send + Sync { /// Returns an **immutable** reference to the underlying read buffer fn get_buffer(&self) -> &BytesMut; /// Returns an **immutable** reference to the underlying stream @@ -362,19 +164,9 @@ pub trait ProtocolConnection { } } -// Give ProtocolConnection implementors a free ProtocolConnectionExt impl - -impl ProtocolConnectionExt for T -where - T: ProtocolConnection + Send, - Strm: Stream, - P: ProtocolSpec, -{ -} - -impl ProtocolConnection for Connection +impl RawConnection for Connection where - T: BufferedSocketStream, + T: BufferedSocketStream + Sync + Send, P: ProtocolSpec, { fn get_buffer(&self) -> &BytesMut { @@ -399,8 +191,8 @@ pub(super) type ExecutorFn = /// # A generic connection handler /// -/// A [`ConnectionHandler`] object is a generic connection handler for any object that implements the [`ProtocolConnection`] trait (or -/// the [`ProtocolConnectionExt`] trait). This function will accept such a type `T`, possibly a listener object and then use it to read +/// A [`ConnectionHandler`] object is a generic connection handler for any object that implements the [`RawConnection`] trait (or +/// the [`ProtocolRead`] trait). This function will accept such a type `T`, possibly a listener object and then use it to read /// a query, parse it and return an appropriate response through [`corestore::Corestore::execute_query`] pub struct ConnectionHandler { db: Corestore, @@ -415,7 +207,7 @@ pub struct ConnectionHandler { impl ConnectionHandler where - T: ProtocolConnectionExt + Send + Sync, + T: ProtocolRead + Send + Sync, Strm: Stream, P: ProtocolSpec, { @@ -573,12 +365,12 @@ impl Stream for T where T: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync /// A simple _shorthand trait_ for the insanely long definition of the connection generic type pub trait ClientConnection: - ProtocolConnectionExt + Send + Sync + ProtocolRead + Send + Sync { } impl ClientConnection for T where - T: ProtocolConnectionExt + Send + Sync, + T: ProtocolRead + Send + Sync, Strm: Stream, P: ProtocolSpec, { diff --git a/server/src/dbnet/tcp.rs b/server/src/dbnet/tcp.rs index 3e603b15..0c9df616 100644 --- a/server/src/dbnet/tcp.rs +++ b/server/src/dbnet/tcp.rs @@ -24,8 +24,10 @@ * */ -use crate::dbnet::connection::ProtocolSpec; -use crate::protocol::Skyhash2; +use crate::protocol::{ + interface::{ProtocolRead, ProtocolSpec}, + Skyhash2, +}; use crate::{ dbnet::{ connection::{ConnectionHandler, ExecutorFn}, @@ -104,7 +106,10 @@ pub struct RawListener

{ executor_fn: TcpExecutorFn

, } -impl RawListener

{ +impl RawListener

+where + Connection: ProtocolRead, +{ pub fn new(base: BaseListener) -> Self { Self { executor_fn: if base.auth.is_enabled() { diff --git a/server/src/dbnet/tls.rs b/server/src/dbnet/tls.rs index fcdaf2bb..3308f2d0 100644 --- a/server/src/dbnet/tls.rs +++ b/server/src/dbnet/tls.rs @@ -26,11 +26,14 @@ use crate::{ dbnet::{ - connection::{ConnectionHandler, ExecutorFn, ProtocolSpec}, + connection::{ConnectionHandler, ExecutorFn}, tcp::{BufferedSocketStream, Connection, TcpBackoff}, BaseListener, Terminator, }, - protocol::Skyhash2, + protocol::{ + interface::{ProtocolRead, ProtocolSpec}, + Skyhash2, + }, util::error::{Error, SkyResult}, IoResult, }; @@ -54,7 +57,10 @@ pub struct SslListenerRaw

{ executor_fn: SslExecutorFn

, } -impl SslListenerRaw

{ +impl SslListenerRaw

+where + Connection>: ProtocolRead>, +{ pub fn new_pem_based_ssl_connection( key_file: String, chain_file: String, diff --git a/server/src/protocol/interface/mod.rs b/server/src/protocol/interface/mod.rs new file mode 100644 index 00000000..35066cb3 --- /dev/null +++ b/server/src/protocol/interface/mod.rs @@ -0,0 +1,275 @@ +/* + * Created on Tue Apr 26 2022 + * + * This file is a part of Skytable + * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source + * NoSQL database written by Sayan Nandan ("the Author") with the + * vision to provide flexibility in data modelling without compromising + * on performance, queryability or scalability. + * + * Copyright (c) 2022, Sayan Nandan + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + * +*/ + +use super::{responses, ParseError}; +use crate::{ + corestore::buffers::Integer64, + dbnet::connection::{QueryResult, QueryWithAdvance, RawConnection, Stream}, + resp::Writable, + util::FutureResult, + IoResult, +}; +use std::io::{Error as IoError, ErrorKind}; +use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter}; + +pub const SIMPLE_QUERY_HEADER: [u8; 1] = [b'*']; + +pub trait ProtocolCharset { + const TSYMBOL_STRING: u8; + const TSYMBOL_BINARY: u8; + const TSYMBOL_FLOAT: u8; + const TSYMBOL_INT64: u8; + const TSYMBOL_TYPED_ARRAY: u8; + const TSYMBOL_TYPED_NON_NULL_ARRAY: u8; + const TSYMBOL_ARRAY: u8; + const TSYMBOL_FLAT_ARRAY: u8; + const LF: u8 = b'\n'; + const TYPE_TYPED_ARRAY_ELEMENT_NULL: &'static [u8]; +} + +/// The [`ProtocolSpec`] trait implementation enables extremely easy switching between +/// protocols by being generic for the same base connection types +pub trait ProtocolSpec: Send + Sync + Sized + ProtocolCharset { + fn parse(buf: &[u8]) -> Result; +} + +/// # The `ProtocolRead` trait +/// +/// The `ProtocolRead` trait enables read operations using the protocol for a given stream `Strm` and protocol +/// `P`. Both the stream and protocol must implement the appropriate traits for you to be able to use these +/// traits +/// +/// ## DO NOT +/// The fact that this is a trait enables great flexibility in terms of visibility, but **DO NOT EVER CALL any +/// function other than `read_query`, `close_conn_with_error` or `write_response`**. If you mess with functions +/// like `read_again`, you're likely to pull yourself into some good trouble. +#[async_trait::async_trait] +pub trait ProtocolRead: RawConnection +where + Strm: Stream, + P: ProtocolSpec, +{ + /// Try to parse a query from the buffered data + fn try_query(&self) -> Result { + P::parse(self.get_buffer()) + } + /// Read a query from the remote end + /// + /// This function asynchronously waits until all the data required + /// for parsing the query is available + fn read_query<'s, 'r: 's>(&'r mut self) -> FutureResult<'s, Result> { + Box::pin(async move { + let mv_self = self; + loop { + let (buffer, stream) = mv_self.get_mut_both(); + match stream.read_buf(buffer).await { + Ok(0) => { + if buffer.is_empty() { + return Ok(QueryResult::Disconnected); + } else { + return Err(IoError::from(ErrorKind::ConnectionReset)); + } + } + Ok(_) => {} + Err(e) => return Err(e), + } + match mv_self.try_query() { + Ok(query_with_advance) => { + return Ok(QueryResult::Q(query_with_advance)); + } + Err(ParseError::NotEnough) => (), + Err(ParseError::DatatypeParseFailure) => return Ok(QueryResult::Wrongtype), + Err(ParseError::UnexpectedByte | ParseError::BadPacket) => { + return Ok(QueryResult::E(responses::full_responses::R_PACKET_ERR)); + } + } + } + }) + } + /// Write a response to the stream + fn write_response<'s, 'r: 's>( + &'r mut self, + streamer: impl Writable + 's + Send + Sync, + ) -> FutureResult<'s, IoResult<()>> { + Box::pin(async move { + let mv_self = self; + let streamer = streamer; + let ret: IoResult<()> = { + streamer.write(mv_self.get_mut_stream()).await?; + Ok(()) + }; + ret + }) + } + /// Write the simple query header `*` to the stream + fn write_simple_query_header<'s, 'r: 's>(&'r mut self) -> FutureResult<'s, IoResult<()>> { + Box::pin(async move { + let mv_self = self; + let ret: IoResult<()> = { + mv_self.write_response(SIMPLE_QUERY_HEADER).await?; + Ok(()) + }; + ret + }) + } + /// Write the length of the pipeline query (*) + fn write_pipeline_query_header<'s, 'r: 's>( + &'r mut self, + len: usize, + ) -> FutureResult<'s, IoResult<()>> { + Box::pin(async move { + let slf = self; + slf.write_response([b'$']).await?; + slf.get_mut_stream() + .write_all(&Integer64::init(len as u64)) + .await?; + slf.write_response([b'\n']).await?; + Ok(()) + }) + } + /// Write the flat array length (`_\n`) + fn write_flat_array_length<'s, 'r: 's>( + &'r mut self, + len: usize, + ) -> FutureResult<'s, IoResult<()>> { + Box::pin(async move { + let mv_self = self; + let ret: IoResult<()> = { + mv_self.write_response([b'_']).await?; + mv_self.write_response(len.to_string().into_bytes()).await?; + mv_self.write_response([b'\n']).await?; + Ok(()) + }; + ret + }) + } + /// Write the array length (`&\n`) + fn write_array_length<'s, 'r: 's>(&'r mut self, len: usize) -> FutureResult<'s, IoResult<()>> { + Box::pin(async move { + let mv_self = self; + let ret: IoResult<()> = { + mv_self.write_response([b'&']).await?; + mv_self.write_response(len.to_string().into_bytes()).await?; + mv_self.write_response([b'\n']).await?; + Ok(()) + }; + ret + }) + } + /// Wraps around the `write_response` used to differentiate between a + /// success response and an error response + fn close_conn_with_error<'s, 'r: 's>( + &'r mut self, + resp: impl Writable + 's + Send + Sync, + ) -> FutureResult<'s, IoResult<()>> { + Box::pin(async move { + let mv_self = self; + let ret: IoResult<()> = { + mv_self.write_response(resp).await?; + mv_self.flush_stream().await?; + Ok(()) + }; + ret + }) + } + fn flush_stream<'s, 'r: 's>(&'r mut self) -> FutureResult<'s, IoResult<()>> { + Box::pin(async move { + let mv_self = self; + let ret: IoResult<()> = { + mv_self.get_mut_stream().flush().await?; + Ok(()) + }; + ret + }) + } + unsafe fn raw_stream(&mut self) -> &mut BufWriter { + self.get_mut_stream() + } +} + +impl ProtocolRead for T +where + T: RawConnection + Send + Sync, + Strm: Stream, + P: ProtocolSpec, +{ +} + +#[async_trait::async_trait] +pub trait ProtocolWrite: RawConnection +where + Strm: Stream, + P: ProtocolSpec, +{ + fn _get_raw_stream(&mut self) -> &mut BufWriter { + self.get_mut_stream() + } + + // monoelements + /// serialize and write an `&str` to the stream + async fn write_string(&mut self, string: &str) -> IoResult<()>; + /// serialize and write an `&[u8]` to the stream + async fn write_binary(&mut self, binary: &[u8]) -> IoResult<()>; + /// serialize and write an `usize` to the stream + async fn write_usize(&mut self, size: usize) -> IoResult<()>; + /// serialize and write an `f32` to the stream + async fn write_float(&mut self, float: f32) -> IoResult<()>; + + // typed array + async fn write_typed_array_header(&mut self, len: usize, tsymbol: u8) -> IoResult<()> { + // \n + self.get_mut_stream() + .write_all(&[P::TSYMBOL_TYPED_ARRAY, tsymbol]) + .await?; + self.get_mut_stream() + .write_all(&Integer64::from(len)) + .await?; + self.get_mut_stream().write_all(&[P::LF]).await?; + Ok(()) + } + async fn write_typed_array_element_null(&mut self) -> IoResult<()> { + self.get_mut_stream() + .write_all(P::TYPE_TYPED_ARRAY_ELEMENT_NULL) + .await + } + async fn write_typed_array_element(&mut self, element: &[u8]) -> IoResult<()>; + + // typed non-null array + async fn write_typed_non_null_array_header(&mut self, len: usize, tsymbol: u8) -> IoResult<()> { + // \n + self.get_mut_stream() + .write_all(&[P::TSYMBOL_TYPED_NON_NULL_ARRAY, tsymbol]) + .await?; + self.get_mut_stream() + .write_all(&Integer64::from(len)) + .await?; + self.get_mut_stream().write_all(&[P::LF]).await?; + Ok(()) + } + async fn write_typed_non_null_array_element(&mut self, element: &[u8]) -> IoResult<()> { + self.write_typed_array_element(element).await + } +} diff --git a/server/src/protocol/mod.rs b/server/src/protocol/mod.rs index 9c329221..bd70dff5 100644 --- a/server/src/protocol/mod.rs +++ b/server/src/protocol/mod.rs @@ -31,8 +31,11 @@ mod benches; #[cfg(test)] mod tests; // pub mods +pub mod interface; pub mod iter; pub mod responses; +// versions +mod v2; // endof pub mods /// The Skyhash protocol version diff --git a/server/src/protocol/v2/mod.rs b/server/src/protocol/v2/mod.rs new file mode 100644 index 00000000..d3c27abf --- /dev/null +++ b/server/src/protocol/v2/mod.rs @@ -0,0 +1,113 @@ +/* + * Created on Fri Apr 29 2022 + * + * This file is a part of Skytable + * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source + * NoSQL database written by Sayan Nandan ("the Author") with the + * vision to provide flexibility in data modelling without compromising + * on performance, queryability or scalability. + * + * Copyright (c) 2022, Sayan Nandan + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + * +*/ + +use super::{ + interface::{ProtocolCharset, ProtocolSpec, ProtocolWrite}, + ParseError, Skyhash2, +}; +use crate::{ + corestore::buffers::Integer64, + dbnet::connection::{QueryWithAdvance, RawConnection, Stream}, + IoResult, +}; +use tokio::io::AsyncWriteExt; + +impl ProtocolCharset for Skyhash2 { + const TSYMBOL_STRING: u8 = b'+'; + const TSYMBOL_BINARY: u8 = b'?'; + const TSYMBOL_FLOAT: u8 = b'%'; + const TSYMBOL_INT64: u8 = b':'; + const TSYMBOL_TYPED_ARRAY: u8 = b'@'; + const TSYMBOL_TYPED_NON_NULL_ARRAY: u8 = b'^'; + const TSYMBOL_ARRAY: u8 = b'&'; + const TSYMBOL_FLAT_ARRAY: u8 = b'_'; + const TYPE_TYPED_ARRAY_ELEMENT_NULL: &'static [u8] = b"\0"; +} + +impl ProtocolSpec for Skyhash2 { + fn parse(buf: &[u8]) -> Result { + Skyhash2::parse(buf) + } +} + +#[async_trait::async_trait] +impl ProtocolWrite for T +where + T: RawConnection + Send + Sync, + Strm: Stream, +{ + async fn write_string(&mut self, string: &str) -> IoResult<()> { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash2::TSYMBOL_STRING]).await?; + // length + let len_bytes = Integer64::from(string.len()); + stream.write_all(&len_bytes).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await?; + // payload + stream.write_all(string.as_bytes()).await + } + async fn write_binary(&mut self, binary: &[u8]) -> IoResult<()> { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash2::TSYMBOL_BINARY]).await?; + // length + let len_bytes = Integer64::from(binary.len()); + stream.write_all(&len_bytes).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await?; + // payload + stream.write_all(binary).await + } + async fn write_usize(&mut self, size: usize) -> IoResult<()> { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash2::TSYMBOL_INT64]).await?; + // body + stream.write_all(&Integer64::from(size)).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await + } + async fn write_float(&mut self, float: f32) -> IoResult<()> { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash2::TSYMBOL_FLOAT]).await?; + // body + stream.write_all(float.to_string().as_bytes()).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await + } + async fn write_typed_array_element(&mut self, element: &[u8]) -> IoResult<()> { + let stream = self.get_mut_stream(); + // len + stream.write_all(&Integer64::from(element.len())).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await?; + // body + stream.write_all(element).await + } +} diff --git a/server/src/resp/mod.rs b/server/src/resp/mod.rs index 3c6abc85..5dcd3f2e 100644 --- a/server/src/resp/mod.rs +++ b/server/src/resp/mod.rs @@ -60,7 +60,7 @@ pub trait Writable { fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s>; } -pub trait IsConnection: std::marker::Sync + std::marker::Send { +pub trait IsConnection: Sync + Send { fn write_lowlevel<'s>(&'s mut self, bytes: &'s [u8]) -> FutureIoResult<'s>; } diff --git a/server/src/resp/writer.rs b/server/src/resp/writer.rs index 6b5d2302..4590390f 100644 --- a/server/src/resp/writer.rs +++ b/server/src/resp/writer.rs @@ -26,8 +26,10 @@ use crate::corestore::buffers::Integer64; use crate::corestore::Data; -use crate::dbnet::connection::{ProtocolConnectionExt, ProtocolSpec}; -use crate::protocol::responses::groups; +use crate::protocol::{ + interface::{ProtocolRead, ProtocolSpec}, + responses::groups, +}; use crate::IoResult; use core::marker::PhantomData; use tokio::io::AsyncReadExt; @@ -41,7 +43,7 @@ pub async unsafe fn write_raw_mono( ) -> IoResult<()> where P: ProtocolSpec, - T: ProtocolConnectionExt, + T: ProtocolRead, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, { let raw_stream = con.raw_stream(); @@ -65,7 +67,7 @@ pub struct FlatArrayWriter<'a, P, T, Strm> { impl<'a, P, T, Strm> FlatArrayWriter<'a, P, T, Strm> where P: ProtocolSpec, - T: ProtocolConnectionExt, + T: ProtocolRead, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, { /// Intialize a new flat array writer. This will write out the tsymbol @@ -131,7 +133,7 @@ pub struct TypedArrayWriter<'a, P, T, Strm> { impl<'a, P, T, Strm> TypedArrayWriter<'a, P, T, Strm> where P: ProtocolSpec, - T: ProtocolConnectionExt, + T: ProtocolRead, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, { /// Create a new `typedarraywriter`. This will write the tsymbol and @@ -188,7 +190,7 @@ pub struct NonNullArrayWriter<'a, P, T, Strm> { impl<'a, P, T, Strm> NonNullArrayWriter<'a, P, T, Strm> where P: ProtocolSpec, - T: ProtocolConnectionExt, + T: ProtocolRead, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, { /// Create a new `typedarraywriter`. This will write the tsymbol and diff --git a/server/src/util/macros.rs b/server/src/util/macros.rs index 54a9d289..66feaac7 100644 --- a/server/src/util/macros.rs +++ b/server/src/util/macros.rs @@ -120,7 +120,12 @@ macro_rules! action { $block:block)* ) => { $($(#[$attr])* - pub async fn $fname<'a, T: 'a + ClientConnection, Strm:Stream, P: crate::dbnet::connection::ProtocolSpec>( + pub async fn $fname< + 'a, + T: 'a + $crate::dbnet::connection::ClientConnection, + Strm: $crate::dbnet::connection::Stream, + P: $crate::protocol::interface::ProtocolSpec + > ( $($argname: $argty,)* ) -> $crate::actions::ActionResult<()> $block)* @@ -134,7 +139,12 @@ macro_rules! action { ) $block:block)* ) => { $($(#[$attr])* - pub async fn $fname<'a, T: 'a + ClientConnection, Strm:Stream, P: crate::dbnet::connection::ProtocolSpec>( + pub async fn $fname< + 'a, + T: 'a + $crate::dbnet::connection::ClientConnection, + Strm: $crate::dbnet::connection::Stream, + P: $crate::protocol::interface::ProtocolSpec + >( $argone: $argonety, $argtwo: $argtwoty, mut $argthree: $argthreety @@ -150,7 +160,12 @@ macro_rules! action { ) $block:block)* ) => { $($(#[$attr])* - pub async fn $fname<'a, T: 'a + ClientConnection, Strm:Stream, P: crate::dbnet::connection::ProtocolSpec>( + pub async fn $fname< + 'a, + T: 'a + $crate::dbnet::connection::ClientConnection, + Strm: $crate::dbnet::connection::Stream, + P: $crate::protocol::interface::ProtocolSpec + >( $argone: $argonety, $argtwo: $argtwoty, $argthree: $argthreety From 9c15e100c8c37123f33bcec04f6ec46a8c85dfc3 Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Fri, 29 Apr 2022 03:06:21 -0700 Subject: [PATCH 05/13] Add metaframe methods to `ProtocolWrite` --- server/src/protocol/interface/mod.rs | 19 +++++++++++++++++++ server/src/protocol/v2/mod.rs | 2 ++ 2 files changed, 21 insertions(+) diff --git a/server/src/protocol/interface/mod.rs b/server/src/protocol/interface/mod.rs index 35066cb3..8ec82079 100644 --- a/server/src/protocol/interface/mod.rs +++ b/server/src/protocol/interface/mod.rs @@ -47,6 +47,8 @@ pub trait ProtocolCharset { const TSYMBOL_ARRAY: u8; const TSYMBOL_FLAT_ARRAY: u8; const LF: u8 = b'\n'; + const SIMPLE_QUERY_HEADER: &'static [u8]; + const PIPELINED_QUERY_FIRST_BYTE: u8; const TYPE_TYPED_ARRAY_ELEMENT_NULL: &'static [u8]; } @@ -224,10 +226,27 @@ where Strm: Stream, P: ProtocolSpec, { + // utility (intentionally underscored to avoid direct access) fn _get_raw_stream(&mut self) -> &mut BufWriter { self.get_mut_stream() } + // metaframe methods + async fn write_simple_query_header(&mut self) -> IoResult<()> { + self.get_mut_stream() + .write_all(P::SIMPLE_QUERY_HEADER) + .await + } + async fn write_pipelined_query_header(&mut self, qcount: usize) -> IoResult<()> { + self.get_mut_stream() + .write_all(&[P::PIPELINED_QUERY_FIRST_BYTE]) + .await?; + self.get_mut_stream() + .write_all(&Integer64::from(qcount)) + .await?; + self.get_mut_stream().write_all(&[P::LF]).await + } + // monoelements /// serialize and write an `&str` to the stream async fn write_string(&mut self, string: &str) -> IoResult<()>; diff --git a/server/src/protocol/v2/mod.rs b/server/src/protocol/v2/mod.rs index d3c27abf..6ae004ec 100644 --- a/server/src/protocol/v2/mod.rs +++ b/server/src/protocol/v2/mod.rs @@ -45,6 +45,8 @@ impl ProtocolCharset for Skyhash2 { const TSYMBOL_ARRAY: u8 = b'&'; const TSYMBOL_FLAT_ARRAY: u8 = b'_'; const TYPE_TYPED_ARRAY_ELEMENT_NULL: &'static [u8] = b"\0"; + const SIMPLE_QUERY_HEADER: &'static [u8] = b"*"; + const PIPELINED_QUERY_FIRST_BYTE: u8 = b'$'; } impl ProtocolSpec for Skyhash2 { From b047845cc58c8f70301ad9278f21e5b8293f57b6 Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Fri, 29 Apr 2022 11:58:33 -0700 Subject: [PATCH 06/13] Upgrade all interfaces to be generic over protocol --- Cargo.lock | 1 - server/Cargo.toml | 1 - server/src/actions/dbsize.rs | 5 +- server/src/actions/del.rs | 8 +- server/src/actions/exists.rs | 6 +- server/src/actions/flushdb.rs | 4 +- server/src/actions/get.rs | 10 +- server/src/actions/keylen.rs | 4 +- server/src/actions/lists/lget.rs | 65 +++-- server/src/actions/lists/lmod.rs | 36 +-- server/src/actions/lists/macros.rs | 7 +- server/src/actions/lists/mod.rs | 5 +- server/src/actions/lskeys.rs | 9 +- server/src/actions/macros.rs | 18 +- server/src/actions/mget.rs | 14 +- server/src/actions/mod.rs | 8 +- server/src/actions/mpop.rs | 17 +- server/src/actions/mset.rs | 6 +- server/src/actions/mupdate.rs | 6 +- server/src/actions/pop.rs | 16 +- server/src/actions/set.rs | 4 +- server/src/actions/strong/sdel.rs | 10 +- server/src/actions/strong/sset.rs | 10 +- server/src/actions/strong/supdate.rs | 10 +- server/src/actions/update.rs | 4 +- server/src/actions/uset.rs | 6 +- server/src/actions/whereami.rs | 11 +- server/src/admin/mksnap.rs | 24 +- server/src/admin/sys.rs | 12 +- server/src/auth/mod.rs | 22 +- server/src/dbnet/connection.rs | 22 +- server/src/dbnet/tcp.rs | 4 +- server/src/dbnet/tls.rs | 5 +- server/src/main.rs | 1 - server/src/protocol/interface/mod.rs | 368 +++++++++++++++------------ server/src/protocol/v2/mod.rs | 174 +++++++++---- server/src/queryengine/ddl.rs | 28 +- server/src/queryengine/inspect.rs | 36 ++- server/src/queryengine/mod.rs | 8 +- server/src/resp/mod.rs | 226 ---------------- server/src/resp/writer.rs | 231 ----------------- server/src/tests/inspect_tests.rs | 4 +- server/src/tests/kvengine.rs | 12 +- server/src/tests/kvengine_list.rs | 2 +- server/src/tests/macros.rs | 8 +- server/src/tests/persist/mod.rs | 8 +- 46 files changed, 545 insertions(+), 951 deletions(-) delete mode 100644 server/src/resp/mod.rs delete mode 100644 server/src/resp/writer.rs diff --git a/Cargo.lock b/Cargo.lock index e2828fdc..bb8637bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1293,7 +1293,6 @@ name = "skyd" version = "0.7.5" dependencies = [ "ahash", - "async-trait", "base64", "bincode", "bytes", diff --git a/server/Cargo.toml b/server/Cargo.toml index 7f1c207b..4bc4a7a9 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -10,7 +10,6 @@ version = "0.7.5" libsky = { path = "../libsky" } sky_macros = { path = "../sky-macros" } # external deps -async-trait = "0.1.53" ahash = "0.7.6" bytes = "1.1.0" chrono = "0.4.19" diff --git a/server/src/actions/dbsize.rs b/server/src/actions/dbsize.rs index 8bbbb843..a52cab79 100644 --- a/server/src/actions/dbsize.rs +++ b/server/src/actions/dbsize.rs @@ -32,11 +32,12 @@ action!( ensure_length(act.len(), |len| len < 2)?; if act.is_empty() { let len = get_tbl_ref!(handle, con).count(); - con.write_response(len).await?; + con.write_usize(len).await?; } else { let raw_entity = unsafe { act.next().unsafe_unwrap() }; let entity = handle_entity!(con, raw_entity); - conwrite!(con, get_tbl!(entity, handle, con).count())?; + con.write_usize(get_tbl!(entity, handle, con).count()) + .await?; } Ok(()) } diff --git a/server/src/actions/del.rs b/server/src/actions/del.rs index 0e4c548d..b56ec2b1 100644 --- a/server/src/actions/del.rs +++ b/server/src/actions/del.rs @@ -57,12 +57,12 @@ action!( } } if let Some(done_howmany) = done_howmany { - con.write_response(done_howmany).await?; + con.write_usize(done_howmany).await?; } else { - con.write_response(responses::groups::SERVER_ERR).await?; + con._write_raw(groups::SERVER_ERR).await?; } } else { - compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?; + return util::err(groups::ENCODING_ERROR); } }}; } @@ -74,7 +74,7 @@ action!( remove!(kvlmap) } #[allow(unreachable_patterns)] - _ => conwrite!(con, groups::WRONG_MODEL)?, + _ => return util::err(groups::WRONG_MODEL), } Ok(()) } diff --git a/server/src/actions/exists.rs b/server/src/actions/exists.rs index 4040278e..f2d69828 100644 --- a/server/src/actions/exists.rs +++ b/server/src/actions/exists.rs @@ -45,9 +45,9 @@ action!( act.for_each(|key| { how_many_of_them_exist += $engine.exists_unchecked(key) as usize; }); - conwrite!(con, how_many_of_them_exist)?; + con.write_usize(how_many_of_them_exist).await?; } else { - compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?; + return util::err(groups::ENCODING_ERROR); } }}; } @@ -56,7 +56,7 @@ action!( DataModel::KV(kve) => exists!(kve), DataModel::KVExtListmap(kve) => exists!(kve), #[allow(unreachable_patterns)] - _ => conwrite!(con, groups::WRONG_MODEL)?, + _ => return util::err(groups::WRONG_MODEL), } Ok(()) } diff --git a/server/src/actions/flushdb.rs b/server/src/actions/flushdb.rs index 6cbe97ce..d7dc7335 100644 --- a/server/src/actions/flushdb.rs +++ b/server/src/actions/flushdb.rs @@ -41,9 +41,9 @@ action!( let entity = handle_entity!(con, raw_entity); get_tbl!(entity, handle, con).truncate_table(); } - conwrite!(con, responses::groups::OKAY)?; + con._write_raw(groups::OKAY).await?; } else { - conwrite!(con, responses::groups::SERVER_ERR)?; + con._write_raw(groups::SERVER_ERR).await?; } Ok(()) } diff --git a/server/src/actions/get.rs b/server/src/actions/get.rs index 1d512563..d9739939 100644 --- a/server/src/actions/get.rs +++ b/server/src/actions/get.rs @@ -28,7 +28,6 @@ //! This module provides functions to work with `GET` queries use crate::dbnet::connection::prelude::*; -use crate::resp::writer; use crate::util::compiler; action!( @@ -38,9 +37,12 @@ action!( let kve = handle.get_table_with::()?; unsafe { match kve.get_cloned(act.next_unchecked()) { - Ok(Some(val)) => writer::write_raw_mono(con, kve.get_value_tsymbol(), &val).await?, - Err(_) => compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?, - Ok(_) => conwrite!(con, groups::NIL)?, + Ok(Some(val)) => { + con.write_mono_length_prefixed_with_tsymbol(&val, kve.get_value_tsymbol()) + .await? + } + Err(_) => compiler::cold_err(con._write_raw(groups::ENCODING_ERROR)).await?, + Ok(_) => con._write_raw(groups::NIL).await?, } } Ok(()) diff --git a/server/src/actions/keylen.rs b/server/src/actions/keylen.rs index 5b875c35..cf1cad8e 100644 --- a/server/src/actions/keylen.rs +++ b/server/src/actions/keylen.rs @@ -45,10 +45,10 @@ action!( }; if let Some(value) = res { // Good, we got the key's length, write it off to the stream - con.write_response(value).await?; + con.write_usize(value).await?; } else { // Ah, couldn't find that key - con.write_response(responses::groups::NIL).await?; + con._write_raw(groups::NIL).await?; } Ok(()) } diff --git a/server/src/actions/lists/lget.rs b/server/src/actions/lists/lget.rs index ecfd3149..e5900c30 100644 --- a/server/src/actions/lists/lget.rs +++ b/server/src/actions/lists/lget.rs @@ -26,9 +26,6 @@ use crate::corestore::Data; use crate::dbnet::connection::prelude::*; -use crate::resp::writer; -use crate::resp::writer::TypedArrayWriter; - const LEN: &[u8] = "LEN".as_bytes(); const LIMIT: &[u8] = "LIMIT".as_bytes(); const VALUEAT: &[u8] = "VALUEAT".as_bytes(); @@ -84,8 +81,8 @@ action! { // just return everything in the list let items = match listmap.list_cloned_full(listname) { Ok(Some(list)) => list, - Ok(None) => return conwrite!(con, groups::NIL), - Err(()) => return conwrite!(con, groups::ENCODING_ERROR), + Ok(None) => return Err(groups::NIL.into()), + Err(()) => return Err(groups::ENCODING_ERROR.into()), }; writelist!(con, listmap, items); } @@ -94,9 +91,9 @@ action! { LEN => { ensure_length(act.len(), |len| len == 0)?; match listmap.list_len(listname) { - Ok(Some(len)) => conwrite!(con, len)?, - Ok(None) => conwrite!(con, groups::NIL)?, - Err(()) => conwrite!(con, groups::ENCODING_ERROR)?, + Ok(Some(len)) => con.write_usize(len).await?, + Ok(None) => return Err(groups::NIL.into()), + Err(()) => return Err(groups::ENCODING_ERROR.into()), } } LIMIT => { @@ -104,8 +101,8 @@ action! { let count = get_numeric_count!(); match listmap.list_cloned(listname, count) { Ok(Some(items)) => writelist!(con, listmap, items), - Ok(None) => conwrite!(con, groups::NIL)?, - Err(()) => conwrite!(con, groups::ENCODING_ERROR)?, + Ok(None) => return Err(groups::NIL.into()), + Err(()) => return Err(groups::ENCODING_ERROR.into()), } } VALUEAT => { @@ -117,22 +114,20 @@ action! { match maybe_value { Ok(v) => match v { Some(Some(value)) => { - unsafe { - // tsymbol is verified - writer::write_raw_mono(con, listmap.get_value_tsymbol(), &value) - .await?; - } + con.write_mono_length_prefixed_with_tsymbol( + &value, listmap.get_value_tsymbol() + ).await?; } Some(None) => { // bad index - conwrite!(con, groups::LISTMAP_BAD_INDEX)?; + return Err(groups::LISTMAP_BAD_INDEX.into()); } None => { // not found - conwrite!(con, groups::NIL)?; + return Err(groups::NIL.into()); } } - Err(()) => conwrite!(con, groups::ENCODING_ERROR)?, + Err(()) => return Err(groups::ENCODING_ERROR.into()), } } LAST => { @@ -143,14 +138,14 @@ action! { match maybe_value { Ok(v) => match v { Some(Some(value)) => { - unsafe { - writer::write_raw_mono(con, listmap.get_value_tsymbol(), &value).await?; - } + con.write_mono_length_prefixed_with_tsymbol( + &value, listmap.get_value_tsymbol() + ).await?; }, - Some(None) => conwrite!(con, groups::LISTMAP_LIST_IS_EMPTY)?, - None => conwrite!(con, groups::NIL)?, + Some(None) => return Err(groups::LISTMAP_LIST_IS_EMPTY.into()), + None => return Err(groups::NIL.into()), } - Err(()) => conwrite!(con, groups::ENCODING_ERROR)?, + Err(()) => return Err(groups::ENCODING_ERROR.into()), } } FIRST => { @@ -161,14 +156,14 @@ action! { match maybe_value { Ok(v) => match v { Some(Some(value)) => { - unsafe { - writer::write_raw_mono(con, listmap.get_value_tsymbol(), &value).await?; - } + con.write_mono_length_prefixed_with_tsymbol( + &value, listmap.get_value_tsymbol() + ).await?; }, - Some(None) => conwrite!(con, groups::LISTMAP_LIST_IS_EMPTY)?, - None => conwrite!(con, groups::NIL)?, + Some(None) => return Err(groups::LISTMAP_LIST_IS_EMPTY.into()), + None => return Err(groups::NIL.into()), } - Err(()) => conwrite!(con, groups::ENCODING_ERROR)?, + Err(()) => return Err(groups::ENCODING_ERROR.into()), } } RANGE => { @@ -193,17 +188,17 @@ action! { Some(ret) => { writelist!(con, listmap, ret); }, - None => conwrite!(con, groups::LISTMAP_BAD_INDEX)?, + None => return Err(groups::LISTMAP_BAD_INDEX.into()), } } - Ok(None) => conwrite!(con, groups::NIL)?, - Err(()) => conwrite!(con, groups::ENCODING_ERROR)?, + Ok(None) => return Err(groups::NIL.into()), + Err(()) => return Err(groups::ENCODING_ERROR.into()), } } - None => aerr!(con), + None => return Err(groups::ACTION_ERR.into()), } } - _ => conwrite!(con, groups::UNKNOWN_ACTION)?, + _ => return Err(groups::UNKNOWN_ACTION.into()), } } } diff --git a/server/src/actions/lists/lmod.rs b/server/src/actions/lists/lmod.rs index cfd024bf..b5b090f2 100644 --- a/server/src/actions/lists/lmod.rs +++ b/server/src/actions/lists/lmod.rs @@ -24,7 +24,7 @@ * */ -use super::{writer, OKAY_BADIDX_NIL_NLUT}; +use super::OKAY_BADIDX_NIL_NLUT; use crate::corestore::Data; use crate::dbnet::connection::prelude::*; use crate::util::compiler; @@ -52,7 +52,7 @@ action! { () => { match unsafe { String::from_utf8_lossy(act.next_unchecked()) }.parse::() { Ok(int) => int, - Err(_) => return conwrite!(con, groups::WRONGTYPE_ERR), + Err(_) => return Err(groups::WRONGTYPE_ERR.into()), } }; } @@ -62,7 +62,7 @@ action! { ensure_length(act.len(), |len| len == 0)?; let list = match listmap.get_inner_ref().get(listname) { Some(l) => l, - _ => return conwrite!(con, groups::NIL), + _ => return Err(groups::NIL.into()), }; let okay = if registry::state_okay() { list.write().clear(); @@ -70,13 +70,13 @@ action! { } else { groups::SERVER_ERR }; - conwrite!(con, okay)?; + con._write_raw(okay).await? } PUSH => { ensure_boolean_or_aerr(!act.is_empty())?; let list = match listmap.get_inner_ref().get(listname) { Some(l) => l, - _ => return conwrite!(con, groups::NIL), + _ => return Err(groups::NIL.into()), }; let venc_ok = listmap.get_val_encoder(); let ret = if compiler::likely(act.as_ref().all(venc_ok)) { @@ -89,7 +89,7 @@ action! { } else { groups::ENCODING_ERROR }; - conwrite!(con, ret)?; + con._write_raw(ret).await? } REMOVE => { ensure_length(act.len(), |len| len == 1)?; @@ -104,9 +104,9 @@ action! { false } }); - conwrite!(con, OKAY_BADIDX_NIL_NLUT[maybe_value])?; + con._write_raw(OKAY_BADIDX_NIL_NLUT[maybe_value]).await? } else { - conwrite!(con, groups::SERVER_ERR)?; + return Err(groups::SERVER_ERR.into()); } } INSERT => { @@ -128,7 +128,7 @@ action! { false } }), - Err(()) => return conwrite!(con, groups::ENCODING_ERROR), + Err(()) => return Err(groups::ENCODING_ERROR.into()), }; OKAY_BADIDX_NIL_NLUT[maybe_insert] } else { @@ -139,7 +139,7 @@ action! { // encoding failed, uh groups::ENCODING_ERROR }; - conwrite!(con, ret)?; + con._write_raw(ret).await? } POP => { ensure_length(act.len(), |len| len < 2)?; @@ -165,24 +165,24 @@ action! { wlock.pop() } }), - Err(()) => return conwrite!(con, groups::ENCODING_ERROR), + Err(()) => return Err(groups::ENCODING_ERROR.into()), }; match maybe_pop { Some(Some(val)) => { - unsafe { - writer::write_raw_mono(con, listmap.get_value_tsymbol(), &val).await?; - } + con.write_mono_length_prefixed_with_tsymbol( + &val, listmap.get_value_tsymbol() + ).await?; } Some(None) => { - conwrite!(con, groups::LISTMAP_BAD_INDEX)?; + con._write_raw(groups::LISTMAP_BAD_INDEX).await?; } - None => conwrite!(con, groups::NIL)?, + None => con._write_raw(groups::NIL).await?, } } else { - conwrite!(con, groups::SERVER_ERR)?; + con._write_raw(groups::SERVER_ERR).await? } } - _ => conwrite!(con, groups::UNKNOWN_ACTION)?, + _ => con._write_raw(groups::UNKNOWN_ACTION).await?, } Ok(()) } diff --git a/server/src/actions/lists/macros.rs b/server/src/actions/lists/macros.rs index c2291d3b..fa428330 100644 --- a/server/src/actions/lists/macros.rs +++ b/server/src/actions/lists/macros.rs @@ -26,11 +26,10 @@ macro_rules! writelist { ($con:expr, $listmap:expr, $items:expr) => {{ - let mut typed_array_writer = - unsafe { TypedArrayWriter::new($con, $listmap.get_value_tsymbol(), $items.len()) } - .await?; + $con.write_typed_non_null_array_header($items.len(), $listmap.get_value_tsymbol()) + .await?; for item in $items { - typed_array_writer.write_element(item).await?; + $con.write_typed_non_null_array_element(&item).await?; } }}; } diff --git a/server/src/actions/lists/mod.rs b/server/src/actions/lists/mod.rs index 9f838d05..304a3799 100644 --- a/server/src/actions/lists/mod.rs +++ b/server/src/actions/lists/mod.rs @@ -35,7 +35,6 @@ use crate::corestore::booltable::BytesNicheLUT; use crate::corestore::Data; use crate::dbnet::connection::prelude::*; use crate::kvengine::LockedVec; -use crate::resp::writer; const OKAY_OVW_BLUT: BytesBoolTable = BytesBoolTable::new(groups::OKAY, groups::OVERWRITE_ERR); const OKAY_BADIDX_NIL_NLUT: BytesNicheLUT = @@ -57,9 +56,9 @@ action! { } else { false }; - conwrite!(con, OKAY_OVW_BLUT[did])?; + con._write_raw(OKAY_OVW_BLUT[did]).await? } else { - conwrite!(con, groups::SERVER_ERR)?; + con._write_raw(groups::SERVER_ERR).await? } Ok(()) } diff --git a/server/src/actions/lskeys.rs b/server/src/actions/lskeys.rs index 19aa032f..82ccc564 100644 --- a/server/src/actions/lskeys.rs +++ b/server/src/actions/lskeys.rs @@ -27,7 +27,6 @@ use crate::corestore::table::DataModel; use crate::corestore::Data; use crate::dbnet::connection::prelude::*; -use crate::resp::writer::TypedArrayWriter; const DEFAULT_COUNT: usize = 10; @@ -73,13 +72,9 @@ action!( DataModel::KV(kv) => kv.get_inner_ref().get_keys(count), DataModel::KVExtListmap(kv) => kv.get_inner_ref().get_keys(count), }; - let mut writer = unsafe { - // SAFETY: We have checked kty ourselves - TypedArrayWriter::new(con, tsymbol, items.len()) - } - .await?; + con.write_typed_non_null_array_header(items.len(), tsymbol).await?; for key in items { - writer.write_element(key).await?; + con.write_typed_non_null_array_element(&key).await?; } Ok(()) } diff --git a/server/src/actions/macros.rs b/server/src/actions/macros.rs index 5a69c6be..40be267a 100644 --- a/server/src/actions/macros.rs +++ b/server/src/actions/macros.rs @@ -46,22 +46,6 @@ macro_rules! is_lowbit_unset { }; } -#[macro_export] -macro_rules! conwrite { - ($con:expr, $what:expr) => { - $con.write_response($what) - .await - .map_err(|e| $crate::actions::ActionError::IoError(e)) - }; -} - -#[macro_export] -macro_rules! aerr { - ($con:expr) => { - return conwrite!($con, $crate::protocol::responses::groups::ACTION_ERR) - }; -} - #[macro_export] macro_rules! get_tbl { ($entity:expr, $store:expr, $con:expr) => {{ @@ -90,7 +74,7 @@ macro_rules! handle_entity { ($con:expr, $ident:expr) => {{ match $crate::queryengine::parser::Entity::from_slice(&$ident) { Ok(e) => e, - Err(e) => return conwrite!($con, e), + Err(e) => return Err(e.into()), } }}; } diff --git a/server/src/actions/mget.rs b/server/src/actions/mget.rs index 60c1ae17..a4789902 100644 --- a/server/src/actions/mget.rs +++ b/server/src/actions/mget.rs @@ -27,7 +27,6 @@ use crate::dbnet::connection::prelude::*; use crate::kvengine::encoding::ENCODING_LUT_ITER; use crate::queryengine::ActionIter; -use crate::resp::writer::TypedArrayWriter; use crate::util::compiler; action!( @@ -38,19 +37,16 @@ action!( let kve = handle.get_table_with::()?; let encoding_is_okay = ENCODING_LUT_ITER[kve.is_key_encoded()](act.as_ref()); if compiler::likely(encoding_is_okay) { - let mut writer = unsafe { - // SAFETY: We are getting the value type ourselves - TypedArrayWriter::new(con, kve.get_value_tsymbol(), act.len()) - } - .await?; + con.write_typed_array_header(act.len(), kve.get_value_tsymbol()) + .await?; for key in act { match kve.get_cloned_unchecked(key) { - Some(v) => writer.write_element(&v).await?, - None => writer.write_null().await?, + Some(v) => con.write_typed_array_element(&v).await?, + None => con.write_typed_array_element_null().await?, } } } else { - compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?; + return util::err(groups::ENCODING_ERROR); } Ok(()) } diff --git a/server/src/actions/mod.rs b/server/src/actions/mod.rs index 8fd8687f..22d99e88 100644 --- a/server/src/actions/mod.rs +++ b/server/src/actions/mod.rs @@ -121,16 +121,16 @@ pub fn ensure_cond_or_err(cond: bool, err: &'static [u8]) -> ActionResult<()> { pub mod heya { //! Respond to `HEYA` queries use crate::dbnet::connection::prelude::*; - use crate::resp::BytesWrapper; action!( /// Returns a `HEY!` `Response` fn heya(_handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) { ensure_length(act.len(), |len| len == 0 || len == 1)?; if act.len() == 1 { - let raw_byte = unsafe { act.next_unchecked_bytes() }; - con.write_response(BytesWrapper(raw_byte)).await?; + let raw_byte = unsafe { act.next_unchecked() }; + con.write_mono_length_prefixed_with_tsymbol(raw_byte, b'+') + .await?; } else { - con.write_response(responses::groups::HEYA).await?; + return util::err(groups::HEYA); } Ok(()) } diff --git a/server/src/actions/mpop.rs b/server/src/actions/mpop.rs index 535e45c2..4a856252 100644 --- a/server/src/actions/mpop.rs +++ b/server/src/actions/mpop.rs @@ -27,9 +27,7 @@ use crate::corestore; use crate::dbnet::connection::prelude::*; use crate::kvengine::encoding::ENCODING_LUT_ITER; -use crate::protocol::responses; use crate::queryengine::ActionIter; -use crate::resp::writer::TypedArrayWriter; use crate::util::compiler; action!( @@ -40,23 +38,20 @@ action!( let kve = handle.get_table_with::()?; let encoding_is_okay = ENCODING_LUT_ITER[kve.is_key_encoded()](act.as_ref()); if compiler::likely(encoding_is_okay) { - let mut writer = unsafe { - // SAFETY: We have verified the tsymbol ourselves - TypedArrayWriter::new(con, kve.get_value_tsymbol(), act.len()) - } - .await?; + con.write_typed_array_header(act.len(), kve.get_value_tsymbol()) + .await?; for key in act { match kve.pop_unchecked(key) { - Some(val) => writer.write_element(val).await?, - None => writer.write_null().await?, + Some(val) => con.write_typed_array_element(&val).await?, + None => con.write_typed_array_element_null().await?, } } } else { - compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?; + return util::err(groups::ENCODING_ERROR); } } else { // don't begin the operation at all if the database is poisoned - con.write_response(responses::groups::SERVER_ERR).await?; + return util::err(groups::SERVER_ERR); } Ok(()) } diff --git a/server/src/actions/mset.rs b/server/src/actions/mset.rs index 33857073..2093244d 100644 --- a/server/src/actions/mset.rs +++ b/server/src/actions/mset.rs @@ -49,12 +49,12 @@ action!( None }; if let Some(done_howmany) = done_howmany { - con.write_response(done_howmany as usize).await?; + con.write_usize(done_howmany).await?; } else { - con.write_response(responses::groups::SERVER_ERR).await?; + return util::err(groups::SERVER_ERR); } } else { - compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?; + return util::err(groups::ENCODING_ERROR); } Ok(()) } diff --git a/server/src/actions/mupdate.rs b/server/src/actions/mupdate.rs index 3f3c4b0d..8c802418 100644 --- a/server/src/actions/mupdate.rs +++ b/server/src/actions/mupdate.rs @@ -50,12 +50,12 @@ action!( done_howmany = None; } if let Some(done_howmany) = done_howmany { - con.write_response(done_howmany as usize).await?; + con.write_usize(done_howmany).await?; } else { - con.write_response(responses::groups::SERVER_ERR).await?; + return util::err(groups::SERVER_ERR); } } else { - compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?; + return util::err(groups::ENCODING_ERROR); } Ok(()) } diff --git a/server/src/actions/pop.rs b/server/src/actions/pop.rs index e7a43836..00524ec4 100644 --- a/server/src/actions/pop.rs +++ b/server/src/actions/pop.rs @@ -25,8 +25,6 @@ */ use crate::dbnet::connection::prelude::*; -use crate::resp::writer; -use crate::util::compiler; action! { fn pop(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) { @@ -37,17 +35,15 @@ action! { }; if registry::state_okay() { let kve = handle.get_table_with::()?; - let tsymbol = kve.get_value_tsymbol(); match kve.pop(key) { - Ok(Some(val)) => unsafe { - // SAFETY: We have verified the tsymbol ourselves - writer::write_raw_mono(con, tsymbol, &val).await? - }, - Ok(None) => conwrite!(con, groups::NIL)?, - Err(()) => compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?, + Ok(Some(val)) => con.write_mono_length_prefixed_with_tsymbol( + &val, kve.get_value_tsymbol() + ).await?, + Ok(None) => return util::err(groups::NIL), + Err(()) => return util::err(groups::ENCODING_ERROR), } } else { - conwrite!(con, groups::SERVER_ERR)?; + return util::err(groups::SERVER_ERR); } Ok(()) } diff --git a/server/src/actions/set.rs b/server/src/actions/set.rs index ca2fb45a..d5a0e799 100644 --- a/server/src/actions/set.rs +++ b/server/src/actions/set.rs @@ -56,9 +56,9 @@ action!( Err(()) => None, } }; - conwrite!(con, SET_NLUT[did_we])?; + con._write_raw(SET_NLUT[did_we]).await?; } else { - conwrite!(con, groups::SERVER_ERR)?; + con._write_raw(groups::SERVER_ERR).await?; } Ok(()) } diff --git a/server/src/actions/strong/sdel.rs b/server/src/actions/strong/sdel.rs index ce710c53..bb096788 100644 --- a/server/src/actions/strong/sdel.rs +++ b/server/src/actions/strong/sdel.rs @@ -48,15 +48,15 @@ action! { self::snapshot_and_del(kve, key_encoder, act.into_inner()) }; match outcome { - StrongActionResult::Okay => conwrite!(con, groups::OKAY)?, + StrongActionResult::Okay => con._write_raw(groups::OKAY).await?, StrongActionResult::Nil => { // good, it failed because some key didn't exist - conwrite!(con, groups::NIL)?; + return util::err(groups::NIL); }, - StrongActionResult::ServerError => conwrite!(con, groups::SERVER_ERR)?, + StrongActionResult::ServerError => return util::err(groups::SERVER_ERR), StrongActionResult::EncodingError => { // error we love to hate: encoding error, ugh - compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))? + return util::err(groups::ENCODING_ERROR); }, StrongActionResult::OverwriteError => unsafe { // SAFETY check: never the case @@ -64,7 +64,7 @@ action! { } } } else { - conwrite!(con, groups::SERVER_ERR)?; + return util::err(groups::SERVER_ERR); } Ok(()) } diff --git a/server/src/actions/strong/sset.rs b/server/src/actions/strong/sset.rs index bdf632a5..384e7a4c 100644 --- a/server/src/actions/strong/sset.rs +++ b/server/src/actions/strong/sset.rs @@ -50,12 +50,12 @@ action! { self::snapshot_and_insert(kve, encoder, act.into_inner()) }; match outcome { - StrongActionResult::Okay => conwrite!(con, groups::OKAY)?, - StrongActionResult::OverwriteError => conwrite!(con, groups::OVERWRITE_ERR)?, - StrongActionResult::ServerError => conwrite!(con, groups::SERVER_ERR)?, + StrongActionResult::Okay => con._write_raw(groups::OKAY).await?, + StrongActionResult::OverwriteError => return util::err(groups::OVERWRITE_ERR), + StrongActionResult::ServerError => return util::err(groups::SERVER_ERR), StrongActionResult::EncodingError => { // error we love to hate: encoding error, ugh - compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))? + return util::err(groups::ENCODING_ERROR); }, StrongActionResult::Nil => unsafe { // SAFETY check: never the case @@ -63,7 +63,7 @@ action! { } } } else { - conwrite!(con, groups::SERVER_ERR)?; + return util::err(groups::SERVER_ERR); } Ok(()) } diff --git a/server/src/actions/strong/supdate.rs b/server/src/actions/strong/supdate.rs index e217ebe8..c30c49d0 100644 --- a/server/src/actions/strong/supdate.rs +++ b/server/src/actions/strong/supdate.rs @@ -49,15 +49,15 @@ action! { self::snapshot_and_update(kve, encoder, act.into_inner()) }; match outcome { - StrongActionResult::Okay => conwrite!(con, groups::OKAY)?, + StrongActionResult::Okay => con._write_raw(groups::OKAY).await?, StrongActionResult::Nil => { // good, it failed because some key didn't exist - conwrite!(con, groups::NIL)?; + return util::err(groups::NIL); }, - StrongActionResult::ServerError => conwrite!(con, groups::SERVER_ERR)?, + StrongActionResult::ServerError => return util::err(groups::SERVER_ERR), StrongActionResult::EncodingError => { // error we love to hate: encoding error, ugh - compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))? + return util::err(groups::ENCODING_ERROR); }, StrongActionResult::OverwriteError => unsafe { // SAFETY check: never the case @@ -65,7 +65,7 @@ action! { } } } else { - conwrite!(con, groups::SERVER_ERR)?; + return util::err(groups::SERVER_ERR); } Ok(()) } diff --git a/server/src/actions/update.rs b/server/src/actions/update.rs index 6280e93d..e991d08b 100644 --- a/server/src/actions/update.rs +++ b/server/src/actions/update.rs @@ -55,9 +55,9 @@ action!( Err(()) => None, } }; - conwrite!(con, UPDATE_NLUT[did_we])?; + con._write_raw(UPDATE_NLUT[did_we]).await?; } else { - conwrite!(con, groups::SERVER_ERR)?; + return util::err(groups::SERVER_ERR); } Ok(()) } diff --git a/server/src/actions/uset.rs b/server/src/actions/uset.rs index ad3d58e9..29ca1dad 100644 --- a/server/src/actions/uset.rs +++ b/server/src/actions/uset.rs @@ -44,12 +44,12 @@ action!( while let (Some(key), Some(val)) = (act.next(), act.next()) { kve.upsert_unchecked(Data::copy_from_slice(key), Data::copy_from_slice(val)); } - conwrite!(con, howmany / 2)?; + con.write_usize(howmany / 2).await?; } else { - conwrite!(con, groups::SERVER_ERR)?; + return util::err(groups::SERVER_ERR); } } else { - compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?; + return util::err(groups::ENCODING_ERROR); } Ok(()) } diff --git a/server/src/actions/whereami.rs b/server/src/actions/whereami.rs index 70c94eaa..03053bb3 100644 --- a/server/src/actions/whereami.rs +++ b/server/src/actions/whereami.rs @@ -25,20 +25,19 @@ */ use crate::dbnet::connection::prelude::*; -use crate::resp::writer::NonNullArrayWriter; action! { fn whereami(store: &Corestore, con: &mut T, act: ActionIter<'a>) { ensure_length(act.len(), |len| len == 0)?; match store.get_ids() { (Some(ks), Some(tbl)) => { - let mut writer = unsafe { NonNullArrayWriter::new(con, b'+', 2).await? }; - writer.write_element(ks).await?; - writer.write_element(tbl).await?; + con.write_typed_non_null_array_header(2, b'+').await?; + con.write_typed_non_null_array_element(ks).await?; + con.write_typed_non_null_array_element(tbl).await?; }, (Some(ks), None) => { - let mut writer = unsafe { NonNullArrayWriter::new(con, b'+', 1).await? }; - writer.write_element(ks).await?; + con.write_typed_non_null_array_header(1, b'+').await?; + con.write_typed_non_null_array_element(ks).await?; }, _ => unsafe { impossible!() } } diff --git a/server/src/admin/mksnap.rs b/server/src/admin/mksnap.rs index 9e48be41..8a6c3d1b 100644 --- a/server/src/admin/mksnap.rs +++ b/server/src/admin/mksnap.rs @@ -38,10 +38,10 @@ action!( if act.is_empty() { // traditional mksnap match engine.mksnap(handle.clone_store()).await { - SnapshotActionResult::Ok => conwrite!(con, groups::OKAY)?, - SnapshotActionResult::Failure => conwrite!(con, groups::SERVER_ERR)?, - SnapshotActionResult::Disabled => conwrite!(con, groups::SNAPSHOT_DISABLED)?, - SnapshotActionResult::Busy => conwrite!(con, groups::SNAPSHOT_BUSY)?, + SnapshotActionResult::Ok => con._write_raw(groups::OKAY).await?, + SnapshotActionResult::Failure => return util::err(groups::SERVER_ERR), + SnapshotActionResult::Disabled => return util::err(groups::SNAPSHOT_DISABLED), + SnapshotActionResult::Busy => return util::err(groups::SNAPSHOT_BUSY), _ => unsafe { impossible!() }, } } else if act.len() == 1 { @@ -51,7 +51,7 @@ action!( act.next_unchecked_bytes() }; if !encoding::is_utf8(&name) { - return conwrite!(con, groups::ENCODING_ERROR); + return util::err(groups::ENCODING_ERROR); } // SECURITY: Check for directory traversal syntax @@ -72,19 +72,21 @@ action!( .count() != 0; if illegal_snapshot { - return conwrite!(con, groups::SNAPSHOT_ILLEGAL_NAME); + return util::err(groups::SNAPSHOT_ILLEGAL_NAME); } // now make the snapshot match engine.mkrsnap(name, handle.clone_store()).await { - SnapshotActionResult::Ok => conwrite!(con, groups::OKAY)?, - SnapshotActionResult::Failure => conwrite!(con, groups::SERVER_ERR)?, - SnapshotActionResult::Busy => conwrite!(con, groups::SNAPSHOT_BUSY)?, - SnapshotActionResult::AlreadyExists => conwrite!(con, groups::SNAPSHOT_DUPLICATE)?, + SnapshotActionResult::Ok => con._write_raw(groups::OKAY).await?, + SnapshotActionResult::Failure => return util::err(groups::SERVER_ERR), + SnapshotActionResult::Busy => return util::err(groups::SNAPSHOT_BUSY), + SnapshotActionResult::AlreadyExists => { + return util::err(groups::SNAPSHOT_DUPLICATE) + } _ => unsafe { impossible!() }, } } else { - conwrite!(con, groups::ACTION_ERR)?; + return util::err(groups::ACTION_ERR); } Ok(()) } diff --git a/server/src/admin/sys.rs b/server/src/admin/sys.rs index d432c1c7..8644c79e 100644 --- a/server/src/admin/sys.rs +++ b/server/src/admin/sys.rs @@ -56,9 +56,9 @@ action! { } fn sys_info(con: &mut T, iter: &mut ActionIter<'_>) { match unsafe { iter.next_lowercase_unchecked() }.as_ref() { - INFO_PROTOCOL => con.write_response(PROTOCOL_VERSIONSTRING).await?, - INFO_PROTOVER => con.write_response(PROTOCOL_VERSION).await?, - INFO_VERSION => con.write_response(VERSION).await?, + INFO_PROTOCOL => con.write_string(PROTOCOL_VERSIONSTRING).await?, + INFO_PROTOVER => con.write_float(PROTOCOL_VERSION).await?, + INFO_VERSION => con.write_string(VERSION).await?, _ => return util::err(ERR_UNKNOWN_PROPERTY), } Ok(()) @@ -66,14 +66,14 @@ action! { fn sys_metric(con: &mut T, iter: &mut ActionIter<'_>) { match unsafe { iter.next_lowercase_unchecked() }.as_ref() { METRIC_HEALTH => { - con.write_response(HEALTH_TABLE[registry::state_okay()]).await? + con.write_string(HEALTH_TABLE[registry::state_okay()]).await? } METRIC_STORAGE_USAGE => { match util::os::dirsize(DIR_ROOT) { - Ok(size) => con.write_response(size).await?, + Ok(size) => con.write_int64(size).await?, Err(e) => { log::error!("Failed to get storage usage with: {e}"); - con.write_response(groups::SERVER_ERR).await? + return util::err(groups::SERVER_ERR); }, } } diff --git a/server/src/auth/mod.rs b/server/src/auth/mod.rs index f0ddc5b5..6f234a92 100644 --- a/server/src/auth/mod.rs +++ b/server/src/auth/mod.rs @@ -38,7 +38,6 @@ mod keys; pub mod provider; -use crate::resp::{writer::NonNullArrayWriter, TSYMBOL_UNICODE_STRING}; pub use provider::{AuthProvider, AuthResult, Authmap}; pub mod errors; pub use errors::AuthError; @@ -72,20 +71,20 @@ action! { ensure_boolean_or_aerr(iter.len() == 1)?; // just the username let username = unsafe { iter.next_unchecked() }; let key = auth.provider_mut().claim_user(username)?; - con.write_response(StringWrapper(key)).await?; + con.write_string(&key).await?; Ok(()) } AUTH_LOGOUT => { ensure_boolean_or_aerr(iter.is_empty())?; // nothing else auth.provider_mut().logout()?; auth.swap_executor_to_anonymous(); - con.write_response(groups::OKAY).await?; + con._write_raw(groups::OKAY).await?; Ok(()) } AUTH_DELUSER => { ensure_boolean_or_aerr(iter.len() == 1)?; // just the username auth.provider_mut().delete_user(unsafe { iter.next_unchecked() })?; - con.write_response(groups::OKAY).await?; + con._write_raw(groups::OKAY).await?; Ok(()) } AUTH_RESTORE => self::auth_restore(con, auth, &mut iter).await, @@ -96,18 +95,15 @@ action! { } fn auth_whoami(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) { ensure_boolean_or_aerr(ActionIter::is_empty(iter))?; - con.write_response(StringWrapper(auth.provider().whoami()?)).await?; + con.write_string(&auth.provider().whoami()?).await?; Ok(()) } fn auth_listuser(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) { ensure_boolean_or_aerr(ActionIter::is_empty(iter))?; let usernames = auth.provider().collect_usernames()?; - let mut array_writer = unsafe { - // The symbol is definitely correct, obvious from this context - NonNullArrayWriter::new(con, TSYMBOL_UNICODE_STRING, usernames.len()) - }.await?; + con.write_typed_non_null_array_header(usernames.len(), b'+').await?; for username in usernames { - array_writer.write_element(username).await?; + con.write_typed_non_null_array_element(username.as_bytes()).await?; } Ok(()) } @@ -125,7 +121,7 @@ action! { } _ => return util::err(groups::ACTION_ERR), }; - con.write_response(StringWrapper(newkey)).await?; + con.write_string(&newkey).await?; Ok(()) } fn _auth_claim(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) { @@ -133,7 +129,7 @@ action! { let origin_key = unsafe { iter.next_unchecked() }; let key = auth.provider_mut().claim_root(origin_key)?; auth.swap_executor_to_authenticated(); - con.write_response(StringWrapper(key)).await?; + con.write_string(&key).await?; Ok(()) } /// Handle a login operation only. The **`login` token is expected to be present** @@ -157,7 +153,7 @@ action! { let (username, password) = unsafe { (iter.next_unchecked(), iter.next_unchecked()) }; auth.provider_mut().login(username, password)?; auth.swap_executor_to_authenticated(); - con.write_response(groups::OKAY).await?; + con._write_raw(groups::OKAY).await?; Ok(()) } } diff --git a/server/src/dbnet/connection.rs b/server/src/dbnet/connection.rs index 3605971e..b7a8e8f2 100644 --- a/server/src/dbnet/connection.rs +++ b/server/src/dbnet/connection.rs @@ -45,12 +45,14 @@ use crate::{ Terminator, }, protocol::{ - interface::{ProtocolRead, ProtocolSpec}, + interface::{ProtocolRead, ProtocolSpec, ProtocolWrite}, responses, Query, }, queryengine, IoResult, }; use bytes::{Buf, BytesMut}; +#[cfg(windows)] +use std::io::ErrorKind; use std::{marker::PhantomData, sync::Arc}; use tokio::{ io::{AsyncReadExt, AsyncWriteExt, BufWriter}, @@ -66,7 +68,7 @@ pub enum QueryResult { Disconnected, } -pub struct AuthProviderHandle<'a, P: ProtocolSpec, T, Strm> { +pub struct AuthProviderHandle<'a, P, T, Strm> { provider: &'a mut AuthProvider, executor: &'a mut ExecutorFn, _phantom: PhantomData<(T, Strm)>, @@ -106,7 +108,6 @@ pub mod prelude { pub use super::{AuthProviderHandle, ClientConnection, Stream}; pub use crate::{ actions::{ensure_boolean_or_aerr, ensure_cond_or_err, ensure_length}, - aerr, conwrite, corestore::{ table::{KVEBlob, KVEList}, Corestore, @@ -118,7 +119,6 @@ pub mod prelude { }, queryengine::ActionIter, registry, - resp::StringWrapper, util::{self, FutureResult, UnwrapActionError, Unwrappable}, }; pub use tokio::io::{AsyncReadExt, AsyncWriteExt}; @@ -207,7 +207,7 @@ pub struct ConnectionHandler { impl ConnectionHandler where - T: ProtocolRead + Send + Sync, + T: ProtocolRead + ProtocolWrite + Send + Sync, Strm: Stream, P: ProtocolSpec, { @@ -286,7 +286,7 @@ where Ok(QueryResult::E(r)) => self.con.close_conn_with_error(r).await?, Ok(QueryResult::Wrongtype) => { self.con - .close_conn_with_error(responses::groups::WRONGTYPE_ERR.to_owned()) + .close_conn_with_error(responses::groups::WRONGTYPE_ERR) .await? } Ok(QueryResult::Disconnected) => return Ok(()), @@ -315,7 +315,7 @@ where } Query::Pipelined(_) => { con.write_simple_query_header().await?; - con.write_response(auth::errors::AUTH_CODE_BAD_CREDENTIALS) + con._write_raw(auth::errors::AUTH_CODE_BAD_CREDENTIALS) .await?; } } @@ -335,7 +335,7 @@ where queryengine::execute_simple(db, con, &mut auth_provider, q).await?; } Query::Pipelined(pipeline) => { - con.write_pipeline_query_header(pipeline.len()).await?; + con.write_pipelined_query_header(pipeline.len()).await?; queryengine::execute_pipeline(db, con, &mut auth_provider, pipeline).await?; } } @@ -346,7 +346,7 @@ where /// Execute a query that has already been validated by `Connection::read_query` async fn execute_query(&mut self, query: Query) -> ActionResult<()> { (self.executor)(self, query).await?; - self.con.flush_stream().await?; + self.con._flush_stream().await?; Ok(()) } } @@ -365,12 +365,12 @@ impl Stream for T where T: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync /// A simple _shorthand trait_ for the insanely long definition of the connection generic type pub trait ClientConnection: - ProtocolRead + Send + Sync + ProtocolWrite + ProtocolRead + Send + Sync { } impl ClientConnection for T where - T: ProtocolRead + Send + Sync, + T: ProtocolWrite + ProtocolRead + Send + Sync, Strm: Stream, P: ProtocolSpec, { diff --git a/server/src/dbnet/tcp.rs b/server/src/dbnet/tcp.rs index 0c9df616..64f7c9ad 100644 --- a/server/src/dbnet/tcp.rs +++ b/server/src/dbnet/tcp.rs @@ -25,7 +25,7 @@ */ use crate::protocol::{ - interface::{ProtocolRead, ProtocolSpec}, + interface::{ProtocolRead, ProtocolSpec, ProtocolWrite}, Skyhash2, }; use crate::{ @@ -108,7 +108,7 @@ pub struct RawListener

{ impl RawListener

where - Connection: ProtocolRead, + Connection: ProtocolRead + ProtocolWrite, { pub fn new(base: BaseListener) -> Self { Self { diff --git a/server/src/dbnet/tls.rs b/server/src/dbnet/tls.rs index 3308f2d0..d258c257 100644 --- a/server/src/dbnet/tls.rs +++ b/server/src/dbnet/tls.rs @@ -31,7 +31,7 @@ use crate::{ BaseListener, Terminator, }, protocol::{ - interface::{ProtocolRead, ProtocolSpec}, + interface::{ProtocolRead, ProtocolSpec, ProtocolWrite}, Skyhash2, }, util::error::{Error, SkyResult}, @@ -59,7 +59,8 @@ pub struct SslListenerRaw

{ impl SslListenerRaw

where - Connection>: ProtocolRead>, + Connection>: + ProtocolRead> + ProtocolWrite>, { pub fn new_pem_based_ssl_connection( key_file: String, diff --git a/server/src/main.rs b/server/src/main.rs index be9abc35..34d62a2c 100644 --- a/server/src/main.rs +++ b/server/src/main.rs @@ -56,7 +56,6 @@ mod kvengine; mod protocol; mod queryengine; pub mod registry; -mod resp; mod services; mod storage; #[cfg(test)] diff --git a/server/src/protocol/interface/mod.rs b/server/src/protocol/interface/mod.rs index 8ec82079..a4feb8e8 100644 --- a/server/src/protocol/interface/mod.rs +++ b/server/src/protocol/interface/mod.rs @@ -28,16 +28,13 @@ use super::{responses, ParseError}; use crate::{ corestore::buffers::Integer64, dbnet::connection::{QueryResult, QueryWithAdvance, RawConnection, Stream}, - resp::Writable, util::FutureResult, IoResult, }; use std::io::{Error as IoError, ErrorKind}; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter}; -pub const SIMPLE_QUERY_HEADER: [u8; 1] = [b'*']; - -pub trait ProtocolCharset { +pub trait ProtocolSpec { const TSYMBOL_STRING: u8; const TSYMBOL_BINARY: u8; const TSYMBOL_FLOAT: u8; @@ -52,12 +49,6 @@ pub trait ProtocolCharset { const TYPE_TYPED_ARRAY_ELEMENT_NULL: &'static [u8]; } -/// The [`ProtocolSpec`] trait implementation enables extremely easy switching between -/// protocols by being generic for the same base connection types -pub trait ProtocolSpec: Send + Sync + Sized + ProtocolCharset { - fn parse(buf: &[u8]) -> Result; -} - /// # The `ProtocolRead` trait /// /// The `ProtocolRead` trait enables read operations using the protocol for a given stream `Strm` and protocol @@ -68,16 +59,13 @@ pub trait ProtocolSpec: Send + Sync + Sized + ProtocolCharset { /// The fact that this is a trait enables great flexibility in terms of visibility, but **DO NOT EVER CALL any /// function other than `read_query`, `close_conn_with_error` or `write_response`**. If you mess with functions /// like `read_again`, you're likely to pull yourself into some good trouble. -#[async_trait::async_trait] pub trait ProtocolRead: RawConnection where Strm: Stream, P: ProtocolSpec, { /// Try to parse a query from the buffered data - fn try_query(&self) -> Result { - P::parse(self.get_buffer()) - } + fn try_query(&self) -> Result; /// Read a query from the remote end /// /// This function asynchronously waits until all the data required @@ -111,184 +99,230 @@ where } }) } - /// Write a response to the stream - fn write_response<'s, 'r: 's>( - &'r mut self, - streamer: impl Writable + 's + Send + Sync, - ) -> FutureResult<'s, IoResult<()>> { - Box::pin(async move { - let mv_self = self; - let streamer = streamer; - let ret: IoResult<()> = { - streamer.write(mv_self.get_mut_stream()).await?; - Ok(()) - }; - ret - }) +} + +pub trait ProtocolWrite: RawConnection +where + Strm: Stream, + P: ProtocolSpec, +{ + // utility + fn _get_raw_stream(&mut self) -> &mut BufWriter { + self.get_mut_stream() } - /// Write the simple query header `*` to the stream - fn write_simple_query_header<'s, 'r: 's>(&'r mut self) -> FutureResult<'s, IoResult<()>> { - Box::pin(async move { - let mv_self = self; - let ret: IoResult<()> = { - mv_self.write_response(SIMPLE_QUERY_HEADER).await?; - Ok(()) - }; - ret - }) + fn _flush_stream<'life0, 'ret_life>(&'life0 mut self) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: Send + 'ret_life, + { + Box::pin(async move { self.get_mut_stream().flush().await }) } - /// Write the length of the pipeline query (*) - fn write_pipeline_query_header<'s, 'r: 's>( - &'r mut self, - len: usize, - ) -> FutureResult<'s, IoResult<()>> { - Box::pin(async move { - let slf = self; - slf.write_response([b'$']).await?; - slf.get_mut_stream() - .write_all(&Integer64::init(len as u64)) - .await?; - slf.write_response([b'\n']).await?; - Ok(()) - }) + fn _write_raw<'life0, 'life1, 'ret_life>( + &'life0 mut self, + data: &'life1 [u8], + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: Send + 'ret_life, + { + Box::pin(async move { self.get_mut_stream().write_all(data).await }) } - /// Write the flat array length (`_\n`) - fn write_flat_array_length<'s, 'r: 's>( - &'r mut self, - len: usize, - ) -> FutureResult<'s, IoResult<()>> { + fn _write_raw_flushed<'life0, 'life1, 'ret_life>( + &'life0 mut self, + data: &'life1 [u8], + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: Send + 'ret_life, + { Box::pin(async move { - let mv_self = self; - let ret: IoResult<()> = { - mv_self.write_response([b'_']).await?; - mv_self.write_response(len.to_string().into_bytes()).await?; - mv_self.write_response([b'\n']).await?; - Ok(()) - }; - ret + self._write_raw(data).await?; + self._flush_stream().await }) } - /// Write the array length (`&\n`) - fn write_array_length<'s, 'r: 's>(&'r mut self, len: usize) -> FutureResult<'s, IoResult<()>> { - Box::pin(async move { - let mv_self = self; - let ret: IoResult<()> = { - mv_self.write_response([b'&']).await?; - mv_self.write_response(len.to_string().into_bytes()).await?; - mv_self.write_response([b'\n']).await?; - Ok(()) - }; - ret - }) + fn close_conn_with_error<'life0, 'life1, 'ret_life>( + &'life0 mut self, + resp: &'life1 [u8], + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: Send + 'ret_life, + { + Box::pin(async move { self._write_raw_flushed(resp).await }) } - /// Wraps around the `write_response` used to differentiate between a - /// success response and an error response - fn close_conn_with_error<'s, 'r: 's>( - &'r mut self, - resp: impl Writable + 's + Send + Sync, - ) -> FutureResult<'s, IoResult<()>> { + + // metaframe + fn write_simple_query_header<'life0, 'ret_life>( + &'life0 mut self, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: Send + 'ret_life, + { Box::pin(async move { - let mv_self = self; - let ret: IoResult<()> = { - mv_self.write_response(resp).await?; - mv_self.flush_stream().await?; - Ok(()) - }; - ret + self.get_mut_stream() + .write_all(P::SIMPLE_QUERY_HEADER) + .await }) } - fn flush_stream<'s, 'r: 's>(&'r mut self) -> FutureResult<'s, IoResult<()>> { + fn write_pipelined_query_header<'life0, 'ret_life>( + &'life0 mut self, + qcount: usize, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: Send + 'ret_life, + { Box::pin(async move { - let mv_self = self; - let ret: IoResult<()> = { - mv_self.get_mut_stream().flush().await?; - Ok(()) - }; - ret + self.get_mut_stream() + .write_all(&[P::PIPELINED_QUERY_FIRST_BYTE]) + .await?; + self.get_mut_stream() + .write_all(&Integer64::from(qcount)) + .await?; + self.get_mut_stream().write_all(&[P::LF]).await }) } - unsafe fn raw_stream(&mut self) -> &mut BufWriter { - self.get_mut_stream() - } -} - -impl ProtocolRead for T -where - T: RawConnection + Send + Sync, - Strm: Stream, - P: ProtocolSpec, -{ -} - -#[async_trait::async_trait] -pub trait ProtocolWrite: RawConnection -where - Strm: Stream, - P: ProtocolSpec, -{ - // utility (intentionally underscored to avoid direct access) - fn _get_raw_stream(&mut self) -> &mut BufWriter { - self.get_mut_stream() - } - // metaframe methods - async fn write_simple_query_header(&mut self) -> IoResult<()> { - self.get_mut_stream() - .write_all(P::SIMPLE_QUERY_HEADER) - .await - } - async fn write_pipelined_query_header(&mut self, qcount: usize) -> IoResult<()> { - self.get_mut_stream() - .write_all(&[P::PIPELINED_QUERY_FIRST_BYTE]) - .await?; - self.get_mut_stream() - .write_all(&Integer64::from(qcount)) - .await?; - self.get_mut_stream().write_all(&[P::LF]).await + // monoelement + fn write_mono_length_prefixed_with_tsymbol<'life0, 'life1, 'ret_life>( + &'life0 mut self, + data: &'life1 [u8], + tsymbol: u8, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: Send + 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // + stream.write_all(&[tsymbol]).await?; + stream.write_all(&Integer64::from(data.len())).await?; + stream.write_all(&[P::LF]).await?; + stream.write_all(data).await + }) } - - // monoelements /// serialize and write an `&str` to the stream - async fn write_string(&mut self, string: &str) -> IoResult<()>; + fn write_string<'life0, 'life1, 'ret_life>( + &'life0 mut self, + string: &'life1 str, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: 'ret_life; /// serialize and write an `&[u8]` to the stream - async fn write_binary(&mut self, binary: &[u8]) -> IoResult<()>; + fn write_binary<'life0, 'life1, 'ret_life>( + &'life0 mut self, + binary: &'life1 [u8], + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: 'ret_life; /// serialize and write an `usize` to the stream - async fn write_usize(&mut self, size: usize) -> IoResult<()>; + fn write_usize<'life0, 'ret_life>( + &'life0 mut self, + size: usize, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: 'ret_life; + /// serialize and write an `u64` to the stream + fn write_int64<'life0, 'ret_life>( + &'life0 mut self, + int: u64, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: 'ret_life; /// serialize and write an `f32` to the stream - async fn write_float(&mut self, float: f32) -> IoResult<()>; + fn write_float<'life0, 'ret_life>( + &'life0 mut self, + float: f32, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: 'ret_life; // typed array - async fn write_typed_array_header(&mut self, len: usize, tsymbol: u8) -> IoResult<()> { - // \n - self.get_mut_stream() - .write_all(&[P::TSYMBOL_TYPED_ARRAY, tsymbol]) - .await?; - self.get_mut_stream() - .write_all(&Integer64::from(len)) - .await?; - self.get_mut_stream().write_all(&[P::LF]).await?; - Ok(()) + fn write_typed_array_header<'life0, 'ret_life>( + &'life0 mut self, + len: usize, + tsymbol: u8, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: Send + 'ret_life, + { + Box::pin(async move { + self.get_mut_stream() + .write_all(&[P::TSYMBOL_TYPED_ARRAY, tsymbol]) + .await?; + self.get_mut_stream() + .write_all(&Integer64::from(len)) + .await?; + self.get_mut_stream().write_all(&[P::LF]).await?; + Ok(()) + }) } - async fn write_typed_array_element_null(&mut self) -> IoResult<()> { - self.get_mut_stream() - .write_all(P::TYPE_TYPED_ARRAY_ELEMENT_NULL) - .await + fn write_typed_array_element_null<'life0, 'ret_life>( + &'life0 mut self, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: Send + 'ret_life, + { + Box::pin(async move { + self.get_mut_stream() + .write_all(P::TYPE_TYPED_ARRAY_ELEMENT_NULL) + .await + }) } - async fn write_typed_array_element(&mut self, element: &[u8]) -> IoResult<()>; + fn write_typed_array_element<'life0, 'life1, 'ret_life>( + &'life0 mut self, + element: &'life1 [u8], + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: 'ret_life; // typed non-null array - async fn write_typed_non_null_array_header(&mut self, len: usize, tsymbol: u8) -> IoResult<()> { - // \n - self.get_mut_stream() - .write_all(&[P::TSYMBOL_TYPED_NON_NULL_ARRAY, tsymbol]) - .await?; - self.get_mut_stream() - .write_all(&Integer64::from(len)) - .await?; - self.get_mut_stream().write_all(&[P::LF]).await?; - Ok(()) + fn write_typed_non_null_array_header<'life0, 'ret_life>( + &'life0 mut self, + len: usize, + tsymbol: u8, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: Send + 'ret_life, + { + Box::pin(async move { + self.get_mut_stream() + .write_all(&[P::TSYMBOL_TYPED_NON_NULL_ARRAY, tsymbol]) + .await?; + self.get_mut_stream() + .write_all(&Integer64::from(len)) + .await?; + self.get_mut_stream().write_all(&[P::LF]).await?; + Ok(()) + }) } - async fn write_typed_non_null_array_element(&mut self, element: &[u8]) -> IoResult<()> { - self.write_typed_array_element(element).await + fn write_typed_non_null_array_element<'life0, 'life1, 'ret_life>( + &'life0 mut self, + element: &'life1 [u8], + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: Send + 'ret_life, + { + Box::pin(async move { self.write_typed_array_element(element).await }) } } diff --git a/server/src/protocol/v2/mod.rs b/server/src/protocol/v2/mod.rs index 6ae004ec..7a66de08 100644 --- a/server/src/protocol/v2/mod.rs +++ b/server/src/protocol/v2/mod.rs @@ -25,17 +25,18 @@ */ use super::{ - interface::{ProtocolCharset, ProtocolSpec, ProtocolWrite}, + interface::{ProtocolRead, ProtocolSpec, ProtocolWrite}, ParseError, Skyhash2, }; use crate::{ corestore::buffers::Integer64, dbnet::connection::{QueryWithAdvance, RawConnection, Stream}, + util::FutureResult, IoResult, }; use tokio::io::AsyncWriteExt; -impl ProtocolCharset for Skyhash2 { +impl ProtocolSpec for Skyhash2 { const TSYMBOL_STRING: u8 = b'+'; const TSYMBOL_BINARY: u8 = b'?'; const TSYMBOL_FLOAT: u8 = b'%'; @@ -49,67 +50,136 @@ impl ProtocolCharset for Skyhash2 { const PIPELINED_QUERY_FIRST_BYTE: u8 = b'$'; } -impl ProtocolSpec for Skyhash2 { - fn parse(buf: &[u8]) -> Result { - Skyhash2::parse(buf) +impl ProtocolRead for T +where + T: RawConnection + Send + Sync, + Strm: Stream, +{ + fn try_query(&self) -> Result { + Skyhash2::parse(self.get_buffer()) } } -#[async_trait::async_trait] impl ProtocolWrite for T where T: RawConnection + Send + Sync, Strm: Stream, { - async fn write_string(&mut self, string: &str) -> IoResult<()> { - let stream = self.get_mut_stream(); - // tsymbol - stream.write_all(&[Skyhash2::TSYMBOL_STRING]).await?; - // length - let len_bytes = Integer64::from(string.len()); - stream.write_all(&len_bytes).await?; - // LF - stream.write_all(&[Skyhash2::LF]).await?; - // payload - stream.write_all(string.as_bytes()).await + fn write_string<'life0, 'life1, 'ret_life>( + &'life0 mut self, + string: &'life1 str, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash2::TSYMBOL_STRING]).await?; + // length + let len_bytes = Integer64::from(string.len()); + stream.write_all(&len_bytes).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await?; + // payload + stream.write_all(string.as_bytes()).await + }) + } + fn write_binary<'life0, 'life1, 'ret_life>( + &'life0 mut self, + binary: &'life1 [u8], + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash2::TSYMBOL_BINARY]).await?; + // length + let len_bytes = Integer64::from(binary.len()); + stream.write_all(&len_bytes).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await?; + // payload + stream.write_all(binary).await + }) } - async fn write_binary(&mut self, binary: &[u8]) -> IoResult<()> { - let stream = self.get_mut_stream(); - // tsymbol - stream.write_all(&[Skyhash2::TSYMBOL_BINARY]).await?; - // length - let len_bytes = Integer64::from(binary.len()); - stream.write_all(&len_bytes).await?; - // LF - stream.write_all(&[Skyhash2::LF]).await?; - // payload - stream.write_all(binary).await + fn write_usize<'life0, 'ret_life>( + &'life0 mut self, + size: usize, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash2::TSYMBOL_INT64]).await?; + // body + stream.write_all(&Integer64::from(size)).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await + }) } - async fn write_usize(&mut self, size: usize) -> IoResult<()> { - let stream = self.get_mut_stream(); - // tsymbol - stream.write_all(&[Skyhash2::TSYMBOL_INT64]).await?; - // body - stream.write_all(&Integer64::from(size)).await?; - // LF - stream.write_all(&[Skyhash2::LF]).await + fn write_int64<'life0, 'ret_life>( + &'life0 mut self, + int: u64, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash2::TSYMBOL_INT64]).await?; + // body + stream.write_all(&Integer64::from(int)).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await + }) } - async fn write_float(&mut self, float: f32) -> IoResult<()> { - let stream = self.get_mut_stream(); - // tsymbol - stream.write_all(&[Skyhash2::TSYMBOL_FLOAT]).await?; - // body - stream.write_all(float.to_string().as_bytes()).await?; - // LF - stream.write_all(&[Skyhash2::LF]).await + fn write_float<'life0, 'ret_life>( + &'life0 mut self, + float: f32, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash2::TSYMBOL_FLOAT]).await?; + // body + stream.write_all(float.to_string().as_bytes()).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await + }) } - async fn write_typed_array_element(&mut self, element: &[u8]) -> IoResult<()> { - let stream = self.get_mut_stream(); - // len - stream.write_all(&Integer64::from(element.len())).await?; - // LF - stream.write_all(&[Skyhash2::LF]).await?; - // body - stream.write_all(element).await + fn write_typed_array_element<'life0, 'life1, 'ret_life>( + &'life0 mut self, + element: &'life1 [u8], + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // len + stream.write_all(&Integer64::from(element.len())).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await?; + // body + stream.write_all(element).await + }) } } diff --git a/server/src/queryengine/ddl.rs b/server/src/queryengine/ddl.rs index 9f35e765..058a0004 100644 --- a/server/src/queryengine/ddl.rs +++ b/server/src/queryengine/ddl.rs @@ -49,8 +49,7 @@ action! { TABLE => create_table(handle, con, act).await?, KEYSPACE => create_keyspace(handle, con, act).await?, _ => { - con.write_response(responses::groups::UNKNOWN_DDL_QUERY) - .await?; + con._write_raw(groups::UNKNOWN_DDL_QUERY).await?; } } Ok(()) @@ -67,8 +66,7 @@ action! { TABLE => drop_table(handle, con, act).await?, KEYSPACE => drop_keyspace(handle, con, act).await?, _ => { - con.write_response(responses::groups::UNKNOWN_DDL_QUERY) - .await?; + con._write_raw(groups::UNKNOWN_DDL_QUERY).await?; } } Ok(()) @@ -89,9 +87,9 @@ action! { }; if registry::state_okay() { handle.create_table(table_entity, model_code, is_volatile)?; - con.write_response(responses::groups::OKAY).await?; + con._write_raw(groups::OKAY).await?; } else { - conwrite!(con, responses::groups::SERVER_ERR)?; + return util::err(groups::SERVER_ERR); } Ok(()) } @@ -108,12 +106,12 @@ action! { let ksid = unsafe { ObjectID::from_slice(ksid_str) }; if registry::state_okay() { handle.create_keyspace(ksid)?; - con.write_response(responses::groups::OKAY).await? + con._write_raw(groups::OKAY).await? } else { - conwrite!(con, responses::groups::SERVER_ERR)?; + return util::err(groups::SERVER_ERR); } } - None => con.write_response(responses::groups::ACTION_ERR).await?, + None => return util::err(groups::ACTION_ERR), } Ok(()) } @@ -126,12 +124,12 @@ action! { let entity_group = parser::Entity::from_slice(eg)?; if registry::state_okay() { handle.drop_table(entity_group)?; - con.write_response(responses::groups::OKAY).await?; + con._write_raw(groups::OKAY).await?; } else { - conwrite!(con, responses::groups::SERVER_ERR)?; + return util::err(groups::SERVER_ERR); } }, - None => con.write_response(responses::groups::ACTION_ERR).await?, + None => return util::err(groups::ACTION_ERR), } Ok(()) } @@ -157,12 +155,12 @@ action! { handle.drop_keyspace(objid) }; result?; - con.write_response(responses::groups::OKAY).await?; + con._write_raw(groups::OKAY).await?; } else { - conwrite!(con, responses::groups::SERVER_ERR)?; + return util::err(groups::SERVER_ERR); } }, - None => con.write_response(responses::groups::ACTION_ERR).await?, + None => return util::err(groups::ACTION_ERR), } Ok(()) } diff --git a/server/src/queryengine/inspect.rs b/server/src/queryengine/inspect.rs index b1829663..25768d31 100644 --- a/server/src/queryengine/inspect.rs +++ b/server/src/queryengine/inspect.rs @@ -27,9 +27,9 @@ use super::ddl::{KEYSPACE, TABLE}; use crate::corestore::memstore::ObjectID; use crate::dbnet::connection::prelude::*; -use crate::resp::writer::TypedArrayWriter; const KEYSPACES: &[u8] = "KEYSPACES".as_bytes(); + action! { /// Runs an inspect query: /// - `INSPECT KEYSPACES` is run by this function itself @@ -52,17 +52,15 @@ action! { .iter() .map(|kv| kv.key().clone()) .collect(); - let mut writer = unsafe { - TypedArrayWriter::new(con, b'+', ks_list.len()) - }.await?; - for tbl in ks_list { - writer.write_element(tbl).await?; + con.write_typed_non_null_array_header(ks_list.len(), b'+').await?; + for ks in ks_list { + con.write_typed_non_null_array_element(&ks).await?; } } - _ => conwrite!(con, responses::groups::UNKNOWN_INSPECT_QUERY)?, + _ => return util::err(groups::UNKNOWN_INSPECT_QUERY), } } - None => aerr!(con), + None => return util::err(groups::ACTION_ERR), } Ok(()) } @@ -70,32 +68,30 @@ action! { /// INSPECT a keyspace. This should only have the keyspace ID fn inspect_keyspace(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) { ensure_length(act.len(), |len| len < 2)?; - let tbl_list: Vec; + let tbl_list: Vec = match act.next() { Some(keyspace_name) => { // inspect the provided keyspace let ksid = if keyspace_name.len() > 64 { - return conwrite!(con, responses::groups::BAD_CONTAINER_NAME); + return util::err(groups::BAD_CONTAINER_NAME); } else { keyspace_name }; let ks = match handle.get_keyspace(ksid) { Some(kspace) => kspace, - None => return conwrite!(con, responses::groups::CONTAINER_NOT_FOUND), + None => return util::err(groups::CONTAINER_NOT_FOUND), }; - tbl_list = ks.tables.iter().map(|kv| kv.key().clone()).collect(); + ks.tables.iter().map(|kv| kv.key().clone()).collect() }, None => { // inspect the current keyspace let cks = handle.get_cks()?; - tbl_list = cks.tables.iter().map(|kv| kv.key().clone()).collect(); + cks.tables.iter().map(|kv| kv.key().clone()).collect() }, - } - let mut writer = unsafe { - TypedArrayWriter::new(con, b'+', tbl_list.len()) - }.await?; + }; + con.write_typed_non_null_array_header(tbl_list.len(), b'+').await?; for tbl in tbl_list { - writer.write_element(tbl).await?; + con.write_typed_non_null_array_element(&tbl).await?; } Ok(()) } @@ -106,12 +102,12 @@ action! { match act.next() { Some(entity) => { let entity = handle_entity!(con, entity); - conwrite!(con, get_tbl!(entity, handle, con).describe_self())?; + con.write_string(get_tbl!(entity, handle, con).describe_self()).await?; }, None => { // inspect the current table let tbl = handle.get_table_result()?; - con.write_response(tbl.describe_self()).await?; + con.write_string(tbl.describe_self()).await?; }, } Ok(()) diff --git a/server/src/queryengine/mod.rs b/server/src/queryengine/mod.rs index c98f18b0..14d4a738 100644 --- a/server/src/queryengine/mod.rs +++ b/server/src/queryengine/mod.rs @@ -30,7 +30,7 @@ use crate::actions::{ActionError, ActionResult}; use crate::auth; use crate::corestore::Corestore; use crate::dbnet::connection::prelude::*; -use crate::protocol::{iter::AnyArrayIter, responses, PipelinedQuery, SimpleQuery, UnsafeSlice}; +use crate::protocol::{iter::AnyArrayIter, PipelinedQuery, SimpleQuery, UnsafeSlice}; use crate::queryengine::parser::Entity; use crate::{actions, admin}; mod ddl; @@ -67,7 +67,7 @@ macro_rules! gen_constants_and_matches { tags::$action2 => $fns2.await?, )* _ => { - $con.write_response(responses::groups::UNKNOWN_ACTION).await?; + $con._write_raw(groups::UNKNOWN_ACTION).await?; } } }; @@ -164,7 +164,7 @@ action! { act.next_unchecked() }; handle.swap_entity(Entity::from_slice(entity)?)?; - con.write_response(groups::OKAY).await?; + con._write_raw(groups::OKAY).await?; Ok(()) } } @@ -188,7 +188,7 @@ async fn execute_stage_pedantic< }; match ret.await { Ok(()) => Ok(()), - Err(ActionError::ActionError(e)) => con.write_response(e).await, + Err(ActionError::ActionError(e)) => con._write_raw(e).await, Err(ActionError::IoError(ioe)) => Err(ioe), } } diff --git a/server/src/resp/mod.rs b/server/src/resp/mod.rs deleted file mode 100644 index 5dcd3f2e..00000000 --- a/server/src/resp/mod.rs +++ /dev/null @@ -1,226 +0,0 @@ -/* - * Created on Mon Aug 17 2020 - * - * This file is a part of Skytable - * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source - * NoSQL database written by Sayan Nandan ("the Author") with the - * vision to provide flexibility in data modelling without compromising - * on performance, queryability or scalability. - * - * Copyright (c) 2020, Sayan Nandan - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - * -*/ - -#![allow(clippy::needless_lifetimes)] - -//! Utilities for generating responses, which are only used by the `server` -//! -use crate::corestore::buffers::Integer64; -use crate::corestore::memstore::ObjectID; -use crate::util::FutureResult; -use bytes::Bytes; -use std::io::Error as IoError; -use tokio::io::AsyncReadExt; -use tokio::io::AsyncWriteExt; -pub mod writer; - -pub const TSYMBOL_UNICODE_STRING: u8 = b'+'; -pub const TSYMBOL_FLOAT: u8 = b'%'; - -type FutureIoResult<'s> = FutureResult<'s, Result<(), IoError>>; - -/// # The `Writable` trait -/// All trait implementors are given access to an asynchronous stream to which -/// they must write a response. -/// -/// Every `write()` call makes a call to the [`IsConnection`](./IsConnection)'s -/// `write_lowlevel` function, which in turn writes something to the underlying stream. -/// -/// Do note that this write **doesn't gurantee immediate completion** as the underlying -/// stream might use buffering. So, the best idea would be to use to use the `flush()` -/// call on the stream. -pub trait Writable { - /* - HACK(@ohsayan): Since `async` is not supported in traits just yet, we will have to - use explicit declarations for asynchoronous functions - */ - fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s>; -} - -pub trait IsConnection: Sync + Send { - fn write_lowlevel<'s>(&'s mut self, bytes: &'s [u8]) -> FutureIoResult<'s>; -} - -impl IsConnection for T -where - T: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, -{ - fn write_lowlevel<'s>(&'s mut self, bytes: &'s [u8]) -> FutureIoResult<'s> { - Box::pin(self.write_all(bytes)) - } -} - -/// A `BytesWrapper` object wraps around a `Bytes` object that might have been pulled -/// from `Corestore`. -/// -/// This wrapper exists to prevent trait implementation conflicts when -/// an impl for `fmt::Display` may be implemented upstream -#[derive(Debug, PartialEq)] -pub struct BytesWrapper(pub Bytes); - -impl BytesWrapper { - pub fn finish_into_bytes(self) -> Bytes { - self.0 - } -} - -#[derive(Debug, PartialEq)] -pub struct StringWrapper(pub String); - -impl Writable for StringWrapper { - fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> { - Box::pin(async move { - con.write_lowlevel(&[TSYMBOL_UNICODE_STRING]).await?; - // Now get the size of the Bytes object as bytes - let size = Integer64::from(self.0.len()); - // Write this to the stream - con.write_lowlevel(&size).await?; - // Now write a LF character - con.write_lowlevel(&[b'\n']).await?; - // Now write the REAL bytes (of the object) - con.write_lowlevel(self.0.as_bytes()).await?; - Ok(()) - }) - } -} - -impl Writable for Vec { - fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> { - Box::pin(async move { con.write_lowlevel(&self).await }) - } -} - -impl Writable for [u8; N] { - fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> { - Box::pin(async move { con.write_lowlevel(&self).await }) - } -} - -impl Writable for &'static [u8] { - fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> { - Box::pin(async move { con.write_lowlevel(self).await }) - } -} - -impl Writable for &'static str { - fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> { - Box::pin(async move { - // First write a `+` character to the stream since this is a - // string (we represent `String`s as `Byte` objects internally) - // and since `Bytes` are effectively `String`s we will append the - // type operator `+` to the stream - con.write_lowlevel(&[TSYMBOL_UNICODE_STRING]).await?; - // Now get the size of the Bytes object as bytes - let size = Integer64::from(self.len()); - // Write this to the stream - con.write_lowlevel(&size).await?; - // Now write a LF character - con.write_lowlevel(&[b'\n']).await?; - // Now write the REAL bytes (of the object) - con.write_lowlevel(self.as_bytes()).await?; - Ok(()) - }) - } -} - -impl Writable for BytesWrapper { - fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> { - Box::pin(async move { - // First write a `+` character to the stream since this is a - // string (we represent `String`s as `Byte` objects internally) - // and since `Bytes` are effectively `String`s we will append the - // type operator `+` to the stream - let bytes = self.finish_into_bytes(); - con.write_lowlevel(&[TSYMBOL_UNICODE_STRING]).await?; - // Now get the size of the Bytes object as bytes - let size = Integer64::from(bytes.len()); - // Write this to the stream - con.write_lowlevel(&size).await?; - // Now write a LF character - con.write_lowlevel(&[b'\n']).await?; - // Now write the REAL bytes (of the object) - con.write_lowlevel(&bytes).await?; - Ok(()) - }) - } -} - -impl Writable for usize { - fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> { - Box::pin(async move { - con.write_lowlevel(b":").await?; - let usize_bytes = Integer64::from(self); - con.write_lowlevel(&usize_bytes).await?; - con.write_lowlevel(b"\n").await?; - Ok(()) - }) - } -} - -impl Writable for u64 { - fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> { - Box::pin(async move { - con.write_lowlevel(b":").await?; - let usize_bytes = Integer64::from(self); - con.write_lowlevel(&usize_bytes).await?; - con.write_lowlevel(b"\n").await?; - Ok(()) - }) - } -} - -impl Writable for ObjectID { - fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> { - Box::pin(async move { - // First write a `+` character to the stream since this is a - // string (we represent `String`s as `Byte` objects internally) - // and since `Bytes` are effectively `String`s we will append the - // type operator `+` to the stream - con.write_lowlevel(&[TSYMBOL_UNICODE_STRING]).await?; - // Now get the size of the Bytes object as bytes - let size = Integer64::from(self.len()); - // Write this to the stream - con.write_lowlevel(&size).await?; - // Now write a LF character - con.write_lowlevel(&[b'\n']).await?; - // Now write the REAL bytes (of the object) - con.write_lowlevel(&self).await?; - Ok(()) - }) - } -} - -impl Writable for f32 { - fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> { - Box::pin(async move { - let payload = self.to_string(); - con.write_lowlevel(&[TSYMBOL_FLOAT]).await?; - con.write_lowlevel(payload.as_bytes()).await?; - con.write_lowlevel(&[b'\n']).await?; - Ok(()) - }) - } -} diff --git a/server/src/resp/writer.rs b/server/src/resp/writer.rs deleted file mode 100644 index 4590390f..00000000 --- a/server/src/resp/writer.rs +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Created on Thu Aug 12 2021 - * - * This file is a part of Skytable - * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source - * NoSQL database written by Sayan Nandan ("the Author") with the - * vision to provide flexibility in data modelling without compromising - * on performance, queryability or scalability. - * - * Copyright (c) 2021, Sayan Nandan - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - * -*/ - -use crate::corestore::buffers::Integer64; -use crate::corestore::Data; -use crate::protocol::{ - interface::{ProtocolRead, ProtocolSpec}, - responses::groups, -}; -use crate::IoResult; -use core::marker::PhantomData; -use tokio::io::AsyncReadExt; -use tokio::io::AsyncWriteExt; - -/// Write a raw mono group with a custom tsymbol -pub async unsafe fn write_raw_mono( - con: &mut T, - tsymbol: u8, - payload: &Data, -) -> IoResult<()> -where - P: ProtocolSpec, - T: ProtocolRead, - Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, -{ - let raw_stream = con.raw_stream(); - raw_stream.write_all(&[tsymbol; 1]).await?; // first write tsymbol - let bytes = Integer64::from(payload.len()); - raw_stream.write_all(&bytes).await?; // then len - raw_stream.write_all(&[b'\n']).await?; // LF - raw_stream.write_all(payload).await?; // payload - Ok(()) -} - -#[derive(Debug)] -/// A writer for a flat array, which is a multi-typed non-recursive array -pub struct FlatArrayWriter<'a, P, T, Strm> { - tsymbol: u8, - con: &'a mut T, - _owned: PhantomData<(P, Strm)>, -} - -#[allow(dead_code)] // TODO(@ohsayan): Remove this once we start using the flat array writer -impl<'a, P, T, Strm> FlatArrayWriter<'a, P, T, Strm> -where - P: ProtocolSpec, - T: ProtocolRead, - Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, -{ - /// Intialize a new flat array writer. This will write out the tsymbol - /// and length for the flat array - pub async unsafe fn new( - con: &'a mut T, - tsymbol: u8, - len: usize, - ) -> IoResult> { - { - let stream = con.raw_stream(); - // first write _ - stream.write_all(&[b'_']).await?; - let bytes = Integer64::from(len); - // now write len - stream.write_all(&bytes).await?; - // first LF - stream.write_all(&[b'\n']).await?; - } - Ok(Self { - con, - tsymbol, - _owned: PhantomData, - }) - } - /// Write an element - pub async fn write_element(&mut self, bytes: impl AsRef<[u8]>) -> IoResult<()> { - let stream = unsafe { self.con.raw_stream() }; - let bytes = bytes.as_ref(); - // first write - stream.write_all(&[self.tsymbol]).await?; - // now len - let len = Integer64::from(bytes.len()); - stream.write_all(&len).await?; - // now LF - stream.write_all(&[b'\n']).await?; - // now element - stream.write_all(bytes).await?; - Ok(()) - } - /// Write the NIL response code - pub async fn write_nil(&mut self) -> IoResult<()> { - let stream = unsafe { self.con.raw_stream() }; - stream.write_all(groups::NIL).await?; - Ok(()) - } - /// Write the SERVER_ERR (5) response code - pub async fn write_server_error(&mut self) -> IoResult<()> { - let stream = unsafe { self.con.raw_stream() }; - stream.write_all(groups::NIL).await?; - Ok(()) - } -} - -#[derive(Debug)] -/// A writer for a typed array, which is a singly-typed array which either -/// has a typed element or a `NULL` -pub struct TypedArrayWriter<'a, P, T, Strm> { - con: &'a mut T, - _owned: PhantomData<(P, Strm)>, -} - -impl<'a, P, T, Strm> TypedArrayWriter<'a, P, T, Strm> -where - P: ProtocolSpec, - T: ProtocolRead, - Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, -{ - /// Create a new `typedarraywriter`. This will write the tsymbol and - /// the array length - pub async unsafe fn new( - con: &'a mut T, - tsymbol: u8, - len: usize, - ) -> IoResult> { - { - let stream = con.raw_stream(); - // first write @ - stream.write_all(&[b'@', tsymbol]).await?; - let bytes = Integer64::from(len); - // now write len - stream.write_all(&bytes).await?; - // first LF - stream.write_all(&[b'\n']).await?; - } - Ok(Self { - con, - _owned: PhantomData, - }) - } - /// Write an element - pub async fn write_element(&mut self, bytes: impl AsRef<[u8]>) -> IoResult<()> { - let stream = unsafe { self.con.raw_stream() }; - let bytes = bytes.as_ref(); - // write len - let len = Integer64::from(bytes.len()); - stream.write_all(&len).await?; - // now LF - stream.write_all(&[b'\n']).await?; - // now element - stream.write_all(bytes).await?; - Ok(()) - } - /// Write a null - pub async fn write_null(&mut self) -> IoResult<()> { - let stream = unsafe { self.con.raw_stream() }; - stream.write_all(&[b'\0']).await?; - Ok(()) - } -} - -#[derive(Debug)] -/// A writer for a non-null typed array, which is a singly-typed array which either -/// has a typed element or a `NULL` -pub struct NonNullArrayWriter<'a, P, T, Strm> { - con: &'a mut T, - _owned: PhantomData<(P, Strm)>, -} - -impl<'a, P, T, Strm> NonNullArrayWriter<'a, P, T, Strm> -where - P: ProtocolSpec, - T: ProtocolRead, - Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, -{ - /// Create a new `typedarraywriter`. This will write the tsymbol and - /// the array length - pub async unsafe fn new( - con: &'a mut T, - tsymbol: u8, - len: usize, - ) -> IoResult> { - { - let stream = con.raw_stream(); - // first write @ - stream.write_all(&[b'^', tsymbol]).await?; - let bytes = Integer64::from(len); - // now write len - stream.write_all(&bytes).await?; - // first LF - stream.write_all(&[b'\n']).await?; - } - Ok(Self { - con, - _owned: PhantomData, - }) - } - /// Write an element - pub async fn write_element(&mut self, bytes: impl AsRef<[u8]>) -> IoResult<()> { - let stream = unsafe { self.con.raw_stream() }; - let bytes = bytes.as_ref(); - // write len - let len = Integer64::from(bytes.len()); - stream.write_all(&len).await?; - // now LF - stream.write_all(&[b'\n']).await?; - // now element - stream.write_all(bytes).await?; - Ok(()) - } -} diff --git a/server/src/tests/inspect_tests.rs b/server/src/tests/inspect_tests.rs index 4ed4e86d..f137fc00 100644 --- a/server/src/tests/inspect_tests.rs +++ b/server/src/tests/inspect_tests.rs @@ -34,7 +34,7 @@ mod __private { query.push("KEYSPACES"); assert!(matches!( con.run_query_raw(&query).await.unwrap(), - Element::Array(Array::Str(_)) + Element::Array(Array::NonNullStr(_)) )) } async fn test_inspect_keyspace() { @@ -43,7 +43,7 @@ mod __private { query.push(&__MYKS__); assert!(matches!( con.run_query_raw(&query).await.unwrap(), - Element::Array(Array::Str(_)) + Element::Array(Array::NonNullStr(_)) )) } async fn test_inspect_current_keyspace() { diff --git a/server/src/tests/kvengine.rs b/server/src/tests/kvengine.rs index b055557e..728d56a0 100644 --- a/server/src/tests/kvengine.rs +++ b/server/src/tests/kvengine.rs @@ -1035,8 +1035,7 @@ mod __private { .into_iter() .map(|element| element.to_owned()) .collect(); - if let Element::Array(Array::Str(arr)) = ret { - let arr: Vec = arr.into_iter().map(|v| v.unwrap()).collect(); + if let Element::Array(Array::NonNullStr(arr)) = ret { assert_eq!(ret_should_have.len(), arr.len()); assert!(ret_should_have.into_iter().all(|key| arr.contains(&key))); } else { @@ -1070,8 +1069,7 @@ mod __private { .into_iter() .map(|element| element.to_owned()) .collect(); - if let Element::Array(Array::Str(arr)) = ret { - let arr: Vec = arr.into_iter().map(|v| v.unwrap()).collect(); + if let Element::Array(Array::NonNullStr(arr)) = ret { assert_eq!(ret_should_have.len(), arr.len()); assert!(ret_should_have.into_iter().all(|key| arr.contains(&key))); } else { @@ -1092,8 +1090,7 @@ mod __private { .into_iter() .map(|element| element.to_owned()) .collect(); - if let Element::Array(Array::Str(arr)) = ret { - let arr: Vec = arr.into_iter().map(|v| v.unwrap()).collect(); + if let Element::Array(Array::NonNullStr(arr)) = ret { assert_eq!(ret_should_have.len(), arr.len()); assert!(ret_should_have.into_iter().all(|key| arr.contains(&key))); } else { @@ -1115,8 +1112,7 @@ mod __private { .into_iter() .map(|element| element.to_owned()) .collect(); - if let Element::Array(Array::Str(arr)) = ret { - let arr: Vec = arr.into_iter().map(|v| v.unwrap()).collect(); + if let Element::Array(Array::NonNullStr(arr)) = ret { assert_eq!(ret_should_have.len(), arr.len()); assert!(ret_should_have.into_iter().all(|key| arr.contains(&key))); } else { diff --git a/server/src/tests/kvengine_list.rs b/server/src/tests/kvengine_list.rs index 63307153..ca056cec 100644 --- a/server/src/tests/kvengine_list.rs +++ b/server/src/tests/kvengine_list.rs @@ -61,7 +61,7 @@ mod __private { async fn test_lget_emptylist_okay() { lset!(con, "mysuperlist"); let q = query!("lget", "mysuperlist"); - runeq!(con, q, Element::Array(Array::Str(vec![]))); + runeq!(con, q, Element::Array(Array::NonNullStr(vec![]))); } async fn test_lget_list_with_elements_okay() { lset!(con, "mysuperlist", "elementa", "elementb", "elementc"); diff --git a/server/src/tests/macros.rs b/server/src/tests/macros.rs index 174e8eb0..49870639 100644 --- a/server/src/tests/macros.rs +++ b/server/src/tests/macros.rs @@ -104,7 +104,7 @@ macro_rules! assert_okay { } macro_rules! assert_skyhash_arrayeq { - (str, $con:expr, $query:expr, $($val:expr),*) => { + (!str, $con:expr, $query:expr, $($val:expr),*) => { runeq!( $con, $query, @@ -115,13 +115,13 @@ macro_rules! assert_skyhash_arrayeq { )) ) }; - (bin, $con:expr, $query:expr, $($val:expr),*) => { + (str, $con:expr, $query:expr, $($val:expr),*) => { runeq!( $con, $query, - skytable::Element::Array(skytable::types::Array::Bin( + skytable::Element::Array(skytable::types::Array::NonNullStr( vec![ - $(Some($val.into()),)* + $($val.into(),)* ] )) ) diff --git a/server/src/tests/persist/mod.rs b/server/src/tests/persist/mod.rs index 95cd20c6..7b94dd72 100644 --- a/server/src/tests/persist/mod.rs +++ b/server/src/tests/persist/mod.rs @@ -161,9 +161,9 @@ impl PersistValue for [&[u8]; N] { fn response_load(&self) -> Element { let mut flat = Vec::with_capacity(N); for item in self { - flat.push(Some(item.to_vec())); + flat.push(item.to_vec()); } - Element::Array(Array::Bin(flat)) + Element::Array(Array::NonNullBin(flat)) } } @@ -174,9 +174,9 @@ impl PersistValue for [&str; N] { fn response_load(&self) -> Element { let mut flat = Vec::with_capacity(N); for item in self { - flat.push(Some(item.to_string())); + flat.push(item.to_string()); } - Element::Array(Array::Str(flat)) + Element::Array(Array::NonNullStr(flat)) } } From 67b19602b953da4927b6ba8c45102df363307501 Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Fri, 29 Apr 2022 13:23:22 -0700 Subject: [PATCH 07/13] Make all responses/groups generic over protocol --- server/src/actions/dbsize.rs | 2 +- server/src/actions/del.rs | 8 +- server/src/actions/exists.rs | 6 +- server/src/actions/flushdb.rs | 6 +- server/src/actions/get.rs | 8 +- server/src/actions/keylen.rs | 6 +- server/src/actions/lists/lget.rs | 60 +-- server/src/actions/lists/lmod.rs | 53 ++- server/src/actions/lists/mod.rs | 14 +- server/src/actions/lskeys.rs | 6 +- server/src/actions/macros.rs | 10 +- server/src/actions/mget.rs | 6 +- server/src/actions/mod.rs | 47 ++- server/src/actions/mpop.rs | 8 +- server/src/actions/mset.rs | 8 +- server/src/actions/mupdate.rs | 8 +- server/src/actions/pop.rs | 10 +- server/src/actions/set.rs | 12 +- server/src/actions/strong/sdel.rs | 14 +- server/src/actions/strong/sset.rs | 14 +- server/src/actions/strong/supdate.rs | 14 +- server/src/actions/update.rs | 12 +- server/src/actions/uset.rs | 8 +- server/src/actions/whereami.rs | 2 +- server/src/admin/mksnap.rs | 22 +- server/src/admin/sys.rs | 6 +- server/src/auth/mod.rs | 28 +- server/src/corestore/memstore.rs | 1 + server/src/corestore/mod.rs | 5 +- server/src/corestore/table.rs | 8 +- server/src/dbnet/connection.rs | 11 +- .../{interface/mod.rs => interface.rs} | 80 +++- server/src/protocol/iter.rs | 15 - server/src/protocol/mod.rs | 266 +----------- server/src/protocol/responses.rs | 153 ------- server/src/protocol/v2/interface_impls.rs | 274 ++++++++++++ server/src/protocol/v2/mod.rs | 395 +++++++++++------- server/src/protocol/{ => v2}/tests.rs | 17 +- server/src/queryengine/ddl.rs | 62 +-- server/src/queryengine/inspect.rs | 23 +- server/src/queryengine/mod.rs | 12 +- server/src/queryengine/parser.rs | 46 +- server/src/queryengine/tests.rs | 129 +++--- server/src/util/mod.rs | 10 +- 44 files changed, 976 insertions(+), 929 deletions(-) rename server/src/protocol/{interface/mod.rs => interface.rs} (78%) delete mode 100644 server/src/protocol/responses.rs create mode 100644 server/src/protocol/v2/interface_impls.rs rename server/src/protocol/{ => v2}/tests.rs (96%) diff --git a/server/src/actions/dbsize.rs b/server/src/actions/dbsize.rs index a52cab79..98ba2cf1 100644 --- a/server/src/actions/dbsize.rs +++ b/server/src/actions/dbsize.rs @@ -29,7 +29,7 @@ use crate::dbnet::connection::prelude::*; action!( /// Returns the number of keys in the database fn dbsize(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |len| len < 2)?; + ensure_length::

(act.len(), |len| len < 2)?; if act.is_empty() { let len = get_tbl_ref!(handle, con).count(); con.write_usize(len).await?; diff --git a/server/src/actions/del.rs b/server/src/actions/del.rs index b56ec2b1..e6edc277 100644 --- a/server/src/actions/del.rs +++ b/server/src/actions/del.rs @@ -38,7 +38,7 @@ action!( /// Do note that this function is blocking since it acquires a write lock. /// It will write an entire datagroup, for this `del` action fn del(handle: &Corestore, con: &'a mut T, act: ActionIter<'a>) { - ensure_length(act.len(), |size| size != 0)?; + ensure_length::

(act.len(), |size| size != 0)?; let table = get_tbl_ref!(handle, con); macro_rules! remove { ($engine:expr) => {{ @@ -59,10 +59,10 @@ action!( if let Some(done_howmany) = done_howmany { con.write_usize(done_howmany).await?; } else { - con._write_raw(groups::SERVER_ERR).await?; + con._write_raw(P::RCODE_SERVER_ERR).await?; } } else { - return util::err(groups::ENCODING_ERROR); + return util::err(P::RCODE_ENCODING_ERROR); } }}; } @@ -74,7 +74,7 @@ action!( remove!(kvlmap) } #[allow(unreachable_patterns)] - _ => return util::err(groups::WRONG_MODEL), + _ => return util::err(P::RSTRING_WRONG_MODEL), } Ok(()) } diff --git a/server/src/actions/exists.rs b/server/src/actions/exists.rs index f2d69828..04e9a654 100644 --- a/server/src/actions/exists.rs +++ b/server/src/actions/exists.rs @@ -36,7 +36,7 @@ use crate::util::compiler; action!( /// Run an `EXISTS` query fn exists(handle: &Corestore, con: &'a mut T, act: ActionIter<'a>) { - ensure_length(act.len(), |len| len != 0)?; + ensure_length::

(act.len(), |len| len != 0)?; let mut how_many_of_them_exist = 0usize; macro_rules! exists { ($engine:expr) => {{ @@ -47,7 +47,7 @@ action!( }); con.write_usize(how_many_of_them_exist).await?; } else { - return util::err(groups::ENCODING_ERROR); + return util::err(P::RCODE_ENCODING_ERROR); } }}; } @@ -56,7 +56,7 @@ action!( DataModel::KV(kve) => exists!(kve), DataModel::KVExtListmap(kve) => exists!(kve), #[allow(unreachable_patterns)] - _ => return util::err(groups::WRONG_MODEL), + _ => return util::err(P::RSTRING_WRONG_MODEL), } Ok(()) } diff --git a/server/src/actions/flushdb.rs b/server/src/actions/flushdb.rs index d7dc7335..ef717334 100644 --- a/server/src/actions/flushdb.rs +++ b/server/src/actions/flushdb.rs @@ -30,7 +30,7 @@ use crate::queryengine::ActionIter; action!( /// Delete all the keys in the database fn flushdb(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |len| len < 2)?; + ensure_length::

(act.len(), |len| len < 2)?; if registry::state_okay() { if act.is_empty() { // flush the current table @@ -41,9 +41,9 @@ action!( let entity = handle_entity!(con, raw_entity); get_tbl!(entity, handle, con).truncate_table(); } - con._write_raw(groups::OKAY).await?; + con._write_raw(P::RCODE_OKAY).await?; } else { - con._write_raw(groups::SERVER_ERR).await?; + con._write_raw(P::RCODE_SERVER_ERR).await?; } Ok(()) } diff --git a/server/src/actions/get.rs b/server/src/actions/get.rs index d9739939..4a599e2c 100644 --- a/server/src/actions/get.rs +++ b/server/src/actions/get.rs @@ -33,16 +33,16 @@ use crate::util::compiler; action!( /// Run a `GET` query fn get(handle: &crate::corestore::Corestore, con: &mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |len| len == 1)?; - let kve = handle.get_table_with::()?; + ensure_length::

(act.len(), |len| len == 1)?; + let kve = handle.get_table_with::()?; unsafe { match kve.get_cloned(act.next_unchecked()) { Ok(Some(val)) => { con.write_mono_length_prefixed_with_tsymbol(&val, kve.get_value_tsymbol()) .await? } - Err(_) => compiler::cold_err(con._write_raw(groups::ENCODING_ERROR)).await?, - Ok(_) => con._write_raw(groups::NIL).await?, + Err(_) => compiler::cold_err(con._write_raw(P::RCODE_ENCODING_ERROR)).await?, + Ok(_) => con._write_raw(P::RCODE_NIL).await?, } } Ok(()) diff --git a/server/src/actions/keylen.rs b/server/src/actions/keylen.rs index cf1cad8e..41c3bd8e 100644 --- a/server/src/actions/keylen.rs +++ b/server/src/actions/keylen.rs @@ -31,9 +31,9 @@ action!( /// /// At this moment, `keylen` only supports a single key fn keylen(handle: &crate::corestore::Corestore, con: &mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |len| len == 1)?; + ensure_length::

(act.len(), |len| len == 1)?; let res: Option = { - let reader = handle.get_table_with::()?; + let reader = handle.get_table_with::()?; unsafe { // UNSAFE(@ohsayan): this is completely safe as we've already checked // the number of arguments is one @@ -48,7 +48,7 @@ action!( con.write_usize(value).await?; } else { // Ah, couldn't find that key - con._write_raw(groups::NIL).await?; + con._write_raw(P::RCODE_NIL).await?; } Ok(()) } diff --git a/server/src/actions/lists/lget.rs b/server/src/actions/lists/lget.rs index e5900c30..b98dde40 100644 --- a/server/src/actions/lists/lget.rs +++ b/server/src/actions/lists/lget.rs @@ -63,8 +63,8 @@ action! { /// - `LGET LAST` will return the last item /// if it exists fn lget(handle: &Corestore, con: &mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |len| len != 0)?; - let listmap = handle.get_table_with::()?; + ensure_length::

(act.len(), |len| len != 0)?; + let listmap = handle.get_table_with::()?; // get the list name let listname = unsafe { act.next_unchecked() }; // now let us see what we need to do @@ -72,7 +72,7 @@ action! { () => { match unsafe { String::from_utf8_lossy(act.next_unchecked()) }.parse::() { Ok(int) => int, - Err(_) => return util::err(groups::WRONGTYPE_ERR), + Err(_) => return util::err(P::RCODE_WRONGTYPE_ERR), } }; } @@ -81,32 +81,32 @@ action! { // just return everything in the list let items = match listmap.list_cloned_full(listname) { Ok(Some(list)) => list, - Ok(None) => return Err(groups::NIL.into()), - Err(()) => return Err(groups::ENCODING_ERROR.into()), + Ok(None) => return Err(P::RCODE_NIL.into()), + Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()), }; writelist!(con, listmap, items); } Some(subaction) => { match subaction.as_ref() { LEN => { - ensure_length(act.len(), |len| len == 0)?; + ensure_length::

(act.len(), |len| len == 0)?; match listmap.list_len(listname) { Ok(Some(len)) => con.write_usize(len).await?, - Ok(None) => return Err(groups::NIL.into()), - Err(()) => return Err(groups::ENCODING_ERROR.into()), + Ok(None) => return Err(P::RCODE_NIL.into()), + Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()), } } LIMIT => { - ensure_length(act.len(), |len| len == 1)?; + ensure_length::

(act.len(), |len| len == 1)?; let count = get_numeric_count!(); match listmap.list_cloned(listname, count) { Ok(Some(items)) => writelist!(con, listmap, items), - Ok(None) => return Err(groups::NIL.into()), - Err(()) => return Err(groups::ENCODING_ERROR.into()), + Ok(None) => return Err(P::RCODE_NIL.into()), + Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()), } } VALUEAT => { - ensure_length(act.len(), |len| len == 1)?; + ensure_length::

(act.len(), |len| len == 1)?; let idx = get_numeric_count!(); let maybe_value = listmap.get(listname).map(|list| { list.map(|lst| lst.read().get(idx).cloned()) @@ -120,18 +120,18 @@ action! { } Some(None) => { // bad index - return Err(groups::LISTMAP_BAD_INDEX.into()); + return Err(P::RSTRING_LISTMAP_BAD_INDEX.into()); } None => { // not found - return Err(groups::NIL.into()); + return Err(P::RCODE_NIL.into()); } } - Err(()) => return Err(groups::ENCODING_ERROR.into()), + Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()), } } LAST => { - ensure_length(act.len(), |len| len == 0)?; + ensure_length::

(act.len(), |len| len == 0)?; let maybe_value = listmap.get(listname).map(|list| { list.map(|lst| lst.read().last().cloned()) }); @@ -142,14 +142,14 @@ action! { &value, listmap.get_value_tsymbol() ).await?; }, - Some(None) => return Err(groups::LISTMAP_LIST_IS_EMPTY.into()), - None => return Err(groups::NIL.into()), + Some(None) => return Err(P::RSTRING_LISTMAP_LIST_IS_EMPTY.into()), + None => return Err(P::RCODE_NIL.into()), } - Err(()) => return Err(groups::ENCODING_ERROR.into()), + Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()), } } FIRST => { - ensure_length(act.len(), |len| len == 0)?; + ensure_length::

(act.len(), |len| len == 0)?; let maybe_value = listmap.get(listname).map(|list| { list.map(|lst| lst.read().first().cloned()) }); @@ -160,10 +160,10 @@ action! { &value, listmap.get_value_tsymbol() ).await?; }, - Some(None) => return Err(groups::LISTMAP_LIST_IS_EMPTY.into()), - None => return Err(groups::NIL.into()), + Some(None) => return Err(P::RSTRING_LISTMAP_LIST_IS_EMPTY.into()), + None => return Err(P::RCODE_NIL.into()), } - Err(()) => return Err(groups::ENCODING_ERROR.into()), + Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()), } } RANGE => { @@ -171,13 +171,13 @@ action! { Some(start) => { let start: usize = match start.parse() { Ok(v) => v, - Err(_) => return util::err(groups::WRONGTYPE_ERR), + Err(_) => return util::err(P::RCODE_WRONGTYPE_ERR), }; let mut range = Range::new(start); if let Some(stop) = act.next_string_owned() { let stop: usize = match stop.parse() { Ok(v) => v, - Err(_) => return util::err(groups::WRONGTYPE_ERR), + Err(_) => return util::err(P::RCODE_WRONGTYPE_ERR), }; range.set_stop(stop); }; @@ -188,17 +188,17 @@ action! { Some(ret) => { writelist!(con, listmap, ret); }, - None => return Err(groups::LISTMAP_BAD_INDEX.into()), + None => return Err(P::RSTRING_LISTMAP_BAD_INDEX.into()), } } - Ok(None) => return Err(groups::NIL.into()), - Err(()) => return Err(groups::ENCODING_ERROR.into()), + Ok(None) => return Err(P::RCODE_NIL.into()), + Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()), } } - None => return Err(groups::ACTION_ERR.into()), + None => return Err(P::RCODE_ACTION_ERR.into()), } } - _ => return Err(groups::UNKNOWN_ACTION.into()), + _ => return Err(P::RCODE_UNKNOWN_ACTION.into()), } } } diff --git a/server/src/actions/lists/lmod.rs b/server/src/actions/lists/lmod.rs index b5b090f2..20901bb6 100644 --- a/server/src/actions/lists/lmod.rs +++ b/server/src/actions/lists/lmod.rs @@ -24,7 +24,6 @@ * */ -use super::OKAY_BADIDX_NIL_NLUT; use crate::corestore::Data; use crate::dbnet::connection::prelude::*; use crate::util::compiler; @@ -44,55 +43,55 @@ action! { /// - `LMOD remove ` /// - `LMOD clear` fn lmod(handle: &Corestore, con: &mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |len| len > 1)?; - let listmap = handle.get_table_with::()?; + ensure_length::

(act.len(), |len| len > 1)?; + let listmap = handle.get_table_with::()?; // get the list name let listname = unsafe { act.next_unchecked() }; macro_rules! get_numeric_count { () => { match unsafe { String::from_utf8_lossy(act.next_unchecked()) }.parse::() { Ok(int) => int, - Err(_) => return Err(groups::WRONGTYPE_ERR.into()), + Err(_) => return Err(P::RCODE_WRONGTYPE_ERR.into()), } }; } // now let us see what we need to do match unsafe { act.next_uppercase_unchecked() }.as_ref() { CLEAR => { - ensure_length(act.len(), |len| len == 0)?; + ensure_length::

(act.len(), |len| len == 0)?; let list = match listmap.get_inner_ref().get(listname) { Some(l) => l, - _ => return Err(groups::NIL.into()), + _ => return Err(P::RCODE_NIL.into()), }; let okay = if registry::state_okay() { list.write().clear(); - groups::OKAY + P::RCODE_OKAY } else { - groups::SERVER_ERR + P::RCODE_SERVER_ERR }; con._write_raw(okay).await? } PUSH => { - ensure_boolean_or_aerr(!act.is_empty())?; + ensure_boolean_or_aerr::

(!act.is_empty())?; let list = match listmap.get_inner_ref().get(listname) { Some(l) => l, - _ => return Err(groups::NIL.into()), + _ => return Err(P::RCODE_NIL.into()), }; let venc_ok = listmap.get_val_encoder(); let ret = if compiler::likely(act.as_ref().all(venc_ok)) { if registry::state_okay() { list.write().extend(act.map(Data::copy_from_slice)); - groups::OKAY + P::RCODE_OKAY } else { - groups::SERVER_ERR + P::RCODE_SERVER_ERR } } else { - groups::ENCODING_ERROR + P::RCODE_ENCODING_ERROR }; con._write_raw(ret).await? } REMOVE => { - ensure_length(act.len(), |len| len == 1)?; + ensure_length::

(act.len(), |len| len == 1)?; let idx_to_remove = get_numeric_count!(); if registry::state_okay() { let maybe_value = listmap.get_inner_ref().get(listname).map(|list| { @@ -104,13 +103,13 @@ action! { false } }); - con._write_raw(OKAY_BADIDX_NIL_NLUT[maybe_value]).await? + con._write_raw(P::OKAY_BADIDX_NIL_NLUT[maybe_value]).await? } else { - return Err(groups::SERVER_ERR.into()); + return Err(P::RCODE_SERVER_ERR.into()); } } INSERT => { - ensure_length(act.len(), |len| len == 2)?; + ensure_length::

(act.len(), |len| len == 2)?; let idx_to_insert_at = get_numeric_count!(); let bts = unsafe { act.next_unchecked() }; let ret = if compiler::likely(listmap.is_val_ok(bts)) { @@ -128,21 +127,21 @@ action! { false } }), - Err(()) => return Err(groups::ENCODING_ERROR.into()), + Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()), }; - OKAY_BADIDX_NIL_NLUT[maybe_insert] + P::OKAY_BADIDX_NIL_NLUT[maybe_insert] } else { // flush broken; server err - groups::SERVER_ERR + P::RCODE_SERVER_ERR } } else { // encoding failed, uh - groups::ENCODING_ERROR + P::RCODE_ENCODING_ERROR }; con._write_raw(ret).await? } POP => { - ensure_length(act.len(), |len| len < 2)?; + ensure_length::

(act.len(), |len| len < 2)?; let idx = if act.len() == 1 { // we have an idx Some(get_numeric_count!()) @@ -165,7 +164,7 @@ action! { wlock.pop() } }), - Err(()) => return Err(groups::ENCODING_ERROR.into()), + Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()), }; match maybe_pop { Some(Some(val)) => { @@ -174,15 +173,15 @@ action! { ).await?; } Some(None) => { - con._write_raw(groups::LISTMAP_BAD_INDEX).await?; + con._write_raw(P::RSTRING_LISTMAP_BAD_INDEX).await?; } - None => con._write_raw(groups::NIL).await?, + None => con._write_raw(P::RCODE_NIL).await?, } } else { - con._write_raw(groups::SERVER_ERR).await? + con._write_raw(P::RCODE_SERVER_ERR).await? } } - _ => con._write_raw(groups::UNKNOWN_ACTION).await?, + _ => con._write_raw(P::RCODE_UNKNOWN_ACTION).await?, } Ok(()) } diff --git a/server/src/actions/lists/mod.rs b/server/src/actions/lists/mod.rs index 304a3799..3f522304 100644 --- a/server/src/actions/lists/mod.rs +++ b/server/src/actions/lists/mod.rs @@ -30,22 +30,16 @@ mod macros; pub mod lget; pub mod lmod; -use crate::corestore::booltable::BytesBoolTable; -use crate::corestore::booltable::BytesNicheLUT; use crate::corestore::Data; use crate::dbnet::connection::prelude::*; use crate::kvengine::LockedVec; -const OKAY_OVW_BLUT: BytesBoolTable = BytesBoolTable::new(groups::OKAY, groups::OVERWRITE_ERR); -const OKAY_BADIDX_NIL_NLUT: BytesNicheLUT = - BytesNicheLUT::new(groups::NIL, groups::OKAY, groups::LISTMAP_BAD_INDEX); - action! { /// Handle an `LSET` query for the list model /// Syntax: `LSET ` fn lset(handle: &Corestore, con: &mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |len| len > 0)?; - let listmap = handle.get_table_with::()?; + ensure_length::

(act.len(), |len| len > 0)?; + let listmap = handle.get_table_with::()?; let listname = unsafe { act.next_unchecked_bytes() }; let list = listmap.get_inner_ref(); if registry::state_okay() { @@ -56,9 +50,9 @@ action! { } else { false }; - con._write_raw(OKAY_OVW_BLUT[did]).await? + con._write_raw(P::OKAY_OVW_BLUT[did]).await? } else { - con._write_raw(groups::SERVER_ERR).await? + con._write_raw(P::RCODE_SERVER_ERR).await? } Ok(()) } diff --git a/server/src/actions/lskeys.rs b/server/src/actions/lskeys.rs index 82ccc564..772265d5 100644 --- a/server/src/actions/lskeys.rs +++ b/server/src/actions/lskeys.rs @@ -33,7 +33,7 @@ const DEFAULT_COUNT: usize = 10; action!( /// Run an `LSKEYS` query fn lskeys(handle: &crate::corestore::Corestore, con: &mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |size| size < 4)?; + ensure_length::

(act.len(), |size| size < 4)?; let (table, count) = if act.is_empty() { (get_tbl!(handle, con), DEFAULT_COUNT) } else if act.len() == 1 { @@ -44,7 +44,7 @@ action!( let count = if let Ok(cnt) = String::from_utf8_lossy(nextret).parse::() { cnt } else { - return util::err(groups::WRONGTYPE_ERR); + return util::err(P::RCODE_WRONGTYPE_ERR); }; (get_tbl!(handle, con), count) } else { @@ -60,7 +60,7 @@ action!( let count = if let Ok(cnt) = String::from_utf8_lossy(count_ret).parse::() { cnt } else { - return util::err(groups::WRONGTYPE_ERR); + return util::err(P::RCODE_WRONGTYPE_ERR); }; (get_tbl!(entity, handle, con), count) }; diff --git a/server/src/actions/macros.rs b/server/src/actions/macros.rs index 40be267a..3acd02a7 100644 --- a/server/src/actions/macros.rs +++ b/server/src/actions/macros.rs @@ -49,12 +49,14 @@ macro_rules! is_lowbit_unset { #[macro_export] macro_rules! get_tbl { ($entity:expr, $store:expr, $con:expr) => {{ - $store.get_table($entity)? + $crate::actions::translate_ddl_error::>( + $store.get_table($entity), + )? }}; ($store:expr, $con:expr) => {{ match $store.get_ctable() { Some(tbl) => tbl, - None => return $crate::util::err($crate::protocol::responses::groups::DEFAULT_UNSET), + None => return $crate::util::err(P::RSTRING_DEFAULT_UNSET), } }}; } @@ -64,7 +66,7 @@ macro_rules! get_tbl_ref { ($store:expr, $con:expr) => {{ match $store.get_ctable_ref() { Some(tbl) => tbl, - None => return $crate::util::err($crate::protocol::responses::groups::DEFAULT_UNSET), + None => return $crate::util::err(P::RSTRING_DEFAULT_UNSET), } }}; } @@ -72,7 +74,7 @@ macro_rules! get_tbl_ref { #[macro_export] macro_rules! handle_entity { ($con:expr, $ident:expr) => {{ - match $crate::queryengine::parser::Entity::from_slice(&$ident) { + match $crate::queryengine::parser::Entity::from_slice::

(&$ident) { Ok(e) => e, Err(e) => return Err(e.into()), } diff --git a/server/src/actions/mget.rs b/server/src/actions/mget.rs index a4789902..2daba2d1 100644 --- a/server/src/actions/mget.rs +++ b/server/src/actions/mget.rs @@ -33,8 +33,8 @@ action!( /// Run an `MGET` query /// fn mget(handle: &crate::corestore::Corestore, con: &mut T, act: ActionIter<'a>) { - ensure_length(act.len(), |size| size != 0)?; - let kve = handle.get_table_with::()?; + ensure_length::

(act.len(), |size| size != 0)?; + let kve = handle.get_table_with::()?; let encoding_is_okay = ENCODING_LUT_ITER[kve.is_key_encoded()](act.as_ref()); if compiler::likely(encoding_is_okay) { con.write_typed_array_header(act.len(), kve.get_value_tsymbol()) @@ -46,7 +46,7 @@ action!( } } } else { - return util::err(groups::ENCODING_ERROR); + return util::err(P::RCODE_ENCODING_ERROR); } Ok(()) } diff --git a/server/src/actions/mod.rs b/server/src/actions/mod.rs index 22d99e88..3577964a 100644 --- a/server/src/actions/mod.rs +++ b/server/src/actions/mod.rs @@ -51,7 +51,7 @@ pub mod update; pub mod uset; pub mod whereami; use crate::corestore::memstore::DdlError; -use crate::protocol::responses::groups; +use crate::protocol::interface::ProtocolSpec; use crate::util; use std::io::Error as IoError; @@ -77,36 +77,41 @@ impl From for ActionError { } } -impl From for ActionError { - fn from(e: DdlError) -> Self { - let ret = match e { - DdlError::AlreadyExists => groups::ALREADY_EXISTS, - DdlError::DdlTransactionFailure => groups::DDL_TRANSACTIONAL_FAILURE, - DdlError::DefaultNotFound => groups::DEFAULT_UNSET, - DdlError::NotEmpty => groups::KEYSPACE_NOT_EMPTY, - DdlError::NotReady => groups::NOT_READY, - DdlError::ObjectNotFound => groups::CONTAINER_NOT_FOUND, - DdlError::ProtectedObject => groups::PROTECTED_OBJECT, - DdlError::StillInUse => groups::STILL_IN_USE, - DdlError::WrongModel => groups::WRONG_MODEL, - }; - Self::ActionError(ret) +#[cold] +#[inline(never)] +pub fn translate_ddl_error(r: Result) -> Result { + match r { + Ok(r) => Ok(r), + Err(e) => { + let err = match e { + DdlError::AlreadyExists => P::RSTRING_ALREADY_EXISTS, + DdlError::DdlTransactionFailure => P::RSTRING_DDL_TRANSACTIONAL_FAILURE, + DdlError::DefaultNotFound => P::RSTRING_DEFAULT_UNSET, + DdlError::NotEmpty => P::RSTRING_KEYSPACE_NOT_EMPTY, + DdlError::NotReady => P::RSTRING_NOT_READY, + DdlError::ObjectNotFound => P::RSTRING_CONTAINER_NOT_FOUND, + DdlError::ProtectedObject => P::RSTRING_PROTECTED_OBJECT, + DdlError::StillInUse => P::RSTRING_STILL_IN_USE, + DdlError::WrongModel => P::RSTRING_WRONG_MODEL, + }; + Err(ActionError::ActionError(err)) + } } } -pub fn ensure_length(len: usize, is_valid: fn(usize) -> bool) -> ActionResult<()> { +pub fn ensure_length(len: usize, is_valid: fn(usize) -> bool) -> ActionResult<()> { if util::compiler::likely(is_valid(len)) { Ok(()) } else { - util::err(groups::ACTION_ERR) + util::err(P::RCODE_ACTION_ERR) } } -pub fn ensure_boolean_or_aerr(boolean: bool) -> ActionResult<()> { +pub fn ensure_boolean_or_aerr(boolean: bool) -> ActionResult<()> { if util::compiler::likely(boolean) { Ok(()) } else { - util::err(groups::ACTION_ERR) + util::err(P::RCODE_ACTION_ERR) } } @@ -124,13 +129,13 @@ pub mod heya { action!( /// Returns a `HEY!` `Response` fn heya(_handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |len| len == 0 || len == 1)?; + ensure_length::

(act.len(), |len| len == 0 || len == 1)?; if act.len() == 1 { let raw_byte = unsafe { act.next_unchecked() }; con.write_mono_length_prefixed_with_tsymbol(raw_byte, b'+') .await?; } else { - return util::err(groups::HEYA); + return util::err(P::FULLRESP_HEYA); } Ok(()) } diff --git a/server/src/actions/mpop.rs b/server/src/actions/mpop.rs index 4a856252..0a443bdc 100644 --- a/server/src/actions/mpop.rs +++ b/server/src/actions/mpop.rs @@ -33,9 +33,9 @@ use crate::util::compiler; action!( /// Run an MPOP action fn mpop(handle: &corestore::Corestore, con: &mut T, act: ActionIter<'a>) { - ensure_length(act.len(), |len| len != 0)?; + ensure_length::

(act.len(), |len| len != 0)?; if registry::state_okay() { - let kve = handle.get_table_with::()?; + let kve = handle.get_table_with::()?; let encoding_is_okay = ENCODING_LUT_ITER[kve.is_key_encoded()](act.as_ref()); if compiler::likely(encoding_is_okay) { con.write_typed_array_header(act.len(), kve.get_value_tsymbol()) @@ -47,11 +47,11 @@ action!( } } } else { - return util::err(groups::ENCODING_ERROR); + return util::err(P::RCODE_ENCODING_ERROR); } } else { // don't begin the operation at all if the database is poisoned - return util::err(groups::SERVER_ERR); + return util::err(P::RCODE_SERVER_ERR); } Ok(()) } diff --git a/server/src/actions/mset.rs b/server/src/actions/mset.rs index 2093244d..d5bc3052 100644 --- a/server/src/actions/mset.rs +++ b/server/src/actions/mset.rs @@ -33,8 +33,8 @@ action!( /// Run an `MSET` query fn mset(handle: &crate::corestore::Corestore, con: &mut T, mut act: ActionIter<'a>) { let howmany = act.len(); - ensure_length(howmany, |size| size & 1 == 0 && size != 0)?; - let kve = handle.get_table_with::()?; + ensure_length::

(howmany, |size| size & 1 == 0 && size != 0)?; + let kve = handle.get_table_with::()?; let encoding_is_okay = ENCODING_LUT_ITER_PAIR[kve.get_encoding_tuple()](&act); if compiler::likely(encoding_is_okay) { let done_howmany: Option = if registry::state_okay() { @@ -51,10 +51,10 @@ action!( if let Some(done_howmany) = done_howmany { con.write_usize(done_howmany).await?; } else { - return util::err(groups::SERVER_ERR); + return util::err(P::RCODE_SERVER_ERR); } } else { - return util::err(groups::ENCODING_ERROR); + return util::err(P::RCODE_ENCODING_ERROR); } Ok(()) } diff --git a/server/src/actions/mupdate.rs b/server/src/actions/mupdate.rs index 8c802418..afedb6e6 100644 --- a/server/src/actions/mupdate.rs +++ b/server/src/actions/mupdate.rs @@ -33,8 +33,8 @@ action!( /// Run an `MUPDATE` query fn mupdate(handle: &crate::corestore::Corestore, con: &mut T, mut act: ActionIter<'a>) { let howmany = act.len(); - ensure_length(howmany, |size| size & 1 == 0 && size != 0)?; - let kve = handle.get_table_with::()?; + ensure_length::

(howmany, |size| size & 1 == 0 && size != 0)?; + let kve = handle.get_table_with::()?; let encoding_is_okay = ENCODING_LUT_ITER_PAIR[kve.get_encoding_tuple()](&act); let done_howmany: Option; if compiler::likely(encoding_is_okay) { @@ -52,10 +52,10 @@ action!( if let Some(done_howmany) = done_howmany { con.write_usize(done_howmany).await?; } else { - return util::err(groups::SERVER_ERR); + return util::err(P::RCODE_SERVER_ERR); } } else { - return util::err(groups::ENCODING_ERROR); + return util::err(P::RCODE_ENCODING_ERROR); } Ok(()) } diff --git a/server/src/actions/pop.rs b/server/src/actions/pop.rs index 00524ec4..cac72529 100644 --- a/server/src/actions/pop.rs +++ b/server/src/actions/pop.rs @@ -28,22 +28,22 @@ use crate::dbnet::connection::prelude::*; action! { fn pop(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |len| len == 1)?; + ensure_length::

(act.len(), |len| len == 1)?; let key = unsafe { // SAFETY: We have checked for there to be one arg act.next_unchecked() }; if registry::state_okay() { - let kve = handle.get_table_with::()?; + let kve = handle.get_table_with::()?; match kve.pop(key) { Ok(Some(val)) => con.write_mono_length_prefixed_with_tsymbol( &val, kve.get_value_tsymbol() ).await?, - Ok(None) => return util::err(groups::NIL), - Err(()) => return util::err(groups::ENCODING_ERROR), + Ok(None) => return util::err(P::RCODE_NIL), + Err(()) => return util::err(P::RCODE_ENCODING_ERROR), } } else { - return util::err(groups::SERVER_ERR); + return util::err(P::RCODE_SERVER_ERR); } Ok(()) } diff --git a/server/src/actions/set.rs b/server/src/actions/set.rs index d5a0e799..af047876 100644 --- a/server/src/actions/set.rs +++ b/server/src/actions/set.rs @@ -28,21 +28,17 @@ //! This module provides functions to work with `SET` queries use crate::corestore; -use crate::corestore::booltable::BytesNicheLUT; use crate::dbnet::connection::prelude::*; use crate::queryengine::ActionIter; use corestore::Data; -const SET_NLUT: BytesNicheLUT = - BytesNicheLUT::new(groups::ENCODING_ERROR, groups::OKAY, groups::OVERWRITE_ERR); - action!( /// Run a `SET` query fn set(handle: &crate::corestore::Corestore, con: &mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |len| len == 2)?; + ensure_length::

(act.len(), |len| len == 2)?; if registry::state_okay() { let did_we = { - let writer = handle.get_table_with::()?; + let writer = handle.get_table_with::()?; match unsafe { // UNSAFE(@ohsayan): This is completely safe as we've already checked // that there are exactly 2 arguments @@ -56,9 +52,9 @@ action!( Err(()) => None, } }; - con._write_raw(SET_NLUT[did_we]).await?; + con._write_raw(P::SET_NLUT[did_we]).await?; } else { - con._write_raw(groups::SERVER_ERR).await?; + con._write_raw(P::RCODE_SERVER_ERR).await?; } Ok(()) } diff --git a/server/src/actions/strong/sdel.rs b/server/src/actions/strong/sdel.rs index bb096788..f2183b62 100644 --- a/server/src/actions/strong/sdel.rs +++ b/server/src/actions/strong/sdel.rs @@ -37,8 +37,8 @@ action! { /// This either returns `Okay` if all the keys were `del`eted, or it returns a /// `Nil`, which is code `1` fn sdel(handle: &crate::corestore::Corestore, con: &mut T, act: ActionIter<'a>) { - ensure_length(act.len(), |len| len != 0)?; - let kve = handle.get_table_with::()?; + ensure_length::

(act.len(), |len| len != 0)?; + let kve = handle.get_table_with::()?; if registry::state_okay() { // guarantee one check: consistency let key_encoder = kve.get_key_encoder(); @@ -48,15 +48,15 @@ action! { self::snapshot_and_del(kve, key_encoder, act.into_inner()) }; match outcome { - StrongActionResult::Okay => con._write_raw(groups::OKAY).await?, + StrongActionResult::Okay => con._write_raw(P::RCODE_OKAY).await?, StrongActionResult::Nil => { // good, it failed because some key didn't exist - return util::err(groups::NIL); + return util::err(P::RCODE_NIL); }, - StrongActionResult::ServerError => return util::err(groups::SERVER_ERR), + StrongActionResult::ServerError => return util::err(P::RCODE_SERVER_ERR), StrongActionResult::EncodingError => { // error we love to hate: encoding error, ugh - return util::err(groups::ENCODING_ERROR); + return util::err(P::RCODE_ENCODING_ERROR); }, StrongActionResult::OverwriteError => unsafe { // SAFETY check: never the case @@ -64,7 +64,7 @@ action! { } } } else { - return util::err(groups::SERVER_ERR); + return util::err(P::RCODE_SERVER_ERR); } Ok(()) } diff --git a/server/src/actions/strong/sset.rs b/server/src/actions/strong/sset.rs index 384e7a4c..ddc57dcd 100644 --- a/server/src/actions/strong/sset.rs +++ b/server/src/actions/strong/sset.rs @@ -40,8 +40,8 @@ action! { /// `Overwrite Error` or code `2` fn sset(handle: &crate::corestore::Corestore, con: &mut T, act: ActionIter<'a>) { let howmany = act.len(); - ensure_length(howmany, |size| size & 1 == 0 && size != 0)?; - let kve = handle.get_table_with::()?; + ensure_length::

(howmany, |size| size & 1 == 0 && size != 0)?; + let kve = handle.get_table_with::()?; if registry::state_okay() { let encoder = kve.get_double_encoder(); let outcome = unsafe { @@ -50,12 +50,12 @@ action! { self::snapshot_and_insert(kve, encoder, act.into_inner()) }; match outcome { - StrongActionResult::Okay => con._write_raw(groups::OKAY).await?, - StrongActionResult::OverwriteError => return util::err(groups::OVERWRITE_ERR), - StrongActionResult::ServerError => return util::err(groups::SERVER_ERR), + StrongActionResult::Okay => con._write_raw(P::RCODE_OKAY).await?, + StrongActionResult::OverwriteError => return util::err(P::RCODE_OVERWRITE_ERR), + StrongActionResult::ServerError => return util::err(P::RCODE_SERVER_ERR), StrongActionResult::EncodingError => { // error we love to hate: encoding error, ugh - return util::err(groups::ENCODING_ERROR); + return util::err(P::RCODE_ENCODING_ERROR); }, StrongActionResult::Nil => unsafe { // SAFETY check: never the case @@ -63,7 +63,7 @@ action! { } } } else { - return util::err(groups::SERVER_ERR); + return util::err(P::RCODE_SERVER_ERR); } Ok(()) } diff --git a/server/src/actions/strong/supdate.rs b/server/src/actions/strong/supdate.rs index c30c49d0..8d1d3fb6 100644 --- a/server/src/actions/strong/supdate.rs +++ b/server/src/actions/strong/supdate.rs @@ -40,8 +40,8 @@ action! { /// or code `1` fn supdate(handle: &crate::corestore::Corestore, con: &mut T, act: ActionIter<'a>) { let howmany = act.len(); - ensure_length(howmany, |size| size & 1 == 0 && size != 0)?; - let kve = handle.get_table_with::()?; + ensure_length::

(howmany, |size| size & 1 == 0 && size != 0)?; + let kve = handle.get_table_with::()?; if registry::state_okay() { let encoder = kve.get_double_encoder(); let outcome = unsafe { @@ -49,15 +49,15 @@ action! { self::snapshot_and_update(kve, encoder, act.into_inner()) }; match outcome { - StrongActionResult::Okay => con._write_raw(groups::OKAY).await?, + StrongActionResult::Okay => con._write_raw(P::RCODE_OKAY).await?, StrongActionResult::Nil => { // good, it failed because some key didn't exist - return util::err(groups::NIL); + return util::err(P::RCODE_NIL); }, - StrongActionResult::ServerError => return util::err(groups::SERVER_ERR), + StrongActionResult::ServerError => return util::err(P::RCODE_SERVER_ERR), StrongActionResult::EncodingError => { // error we love to hate: encoding error, ugh - return util::err(groups::ENCODING_ERROR); + return util::err(P::RCODE_ENCODING_ERROR); }, StrongActionResult::OverwriteError => unsafe { // SAFETY check: never the case @@ -65,7 +65,7 @@ action! { } } } else { - return util::err(groups::SERVER_ERR); + return util::err(P::RCODE_SERVER_ERR); } Ok(()) } diff --git a/server/src/actions/update.rs b/server/src/actions/update.rs index e991d08b..29511ec7 100644 --- a/server/src/actions/update.rs +++ b/server/src/actions/update.rs @@ -28,20 +28,16 @@ //! This module provides functions to work with `UPDATE` queries //! -use crate::corestore::booltable::BytesNicheLUT; use crate::corestore::Data; use crate::dbnet::connection::prelude::*; -const UPDATE_NLUT: BytesNicheLUT = - BytesNicheLUT::new(groups::ENCODING_ERROR, groups::OKAY, groups::NIL); - action!( /// Run an `UPDATE` query fn update(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |len| len == 2)?; + ensure_length::

(act.len(), |len| len == 2)?; if registry::state_okay() { let did_we = { - let writer = handle.get_table_with::()?; + let writer = handle.get_table_with::()?; match unsafe { // UNSAFE(@ohsayan): This is completely safe as we've already checked // that there are exactly 2 arguments @@ -55,9 +51,9 @@ action!( Err(()) => None, } }; - con._write_raw(UPDATE_NLUT[did_we]).await?; + con._write_raw(P::UPDATE_NLUT[did_we]).await?; } else { - return util::err(groups::SERVER_ERR); + return util::err(P::RCODE_SERVER_ERR); } Ok(()) } diff --git a/server/src/actions/uset.rs b/server/src/actions/uset.rs index 29ca1dad..b8e0101b 100644 --- a/server/src/actions/uset.rs +++ b/server/src/actions/uset.rs @@ -36,8 +36,8 @@ action!( /// This is like "INSERT or UPDATE" fn uset(handle: &crate::corestore::Corestore, con: &mut T, mut act: ActionIter<'a>) { let howmany = act.len(); - ensure_length(howmany, |size| size & 1 == 0 && size != 0)?; - let kve = handle.get_table_with::()?; + ensure_length::

(howmany, |size| size & 1 == 0 && size != 0)?; + let kve = handle.get_table_with::()?; let encoding_is_okay = ENCODING_LUT_ITER_PAIR[kve.get_encoding_tuple()](&act); if compiler::likely(encoding_is_okay) { if registry::state_okay() { @@ -46,10 +46,10 @@ action!( } con.write_usize(howmany / 2).await?; } else { - return util::err(groups::SERVER_ERR); + return util::err(P::RCODE_SERVER_ERR); } } else { - return util::err(groups::ENCODING_ERROR); + return util::err(P::RCODE_ENCODING_ERROR); } Ok(()) } diff --git a/server/src/actions/whereami.rs b/server/src/actions/whereami.rs index 03053bb3..bc958d02 100644 --- a/server/src/actions/whereami.rs +++ b/server/src/actions/whereami.rs @@ -28,7 +28,7 @@ use crate::dbnet::connection::prelude::*; action! { fn whereami(store: &Corestore, con: &mut T, act: ActionIter<'a>) { - ensure_length(act.len(), |len| len == 0)?; + ensure_length::

(act.len(), |len| len == 0)?; match store.get_ids() { (Some(ks), Some(tbl)) => { con.write_typed_non_null_array_header(2, b'+').await?; diff --git a/server/src/admin/mksnap.rs b/server/src/admin/mksnap.rs index 8a6c3d1b..cce96361 100644 --- a/server/src/admin/mksnap.rs +++ b/server/src/admin/mksnap.rs @@ -38,10 +38,10 @@ action!( if act.is_empty() { // traditional mksnap match engine.mksnap(handle.clone_store()).await { - SnapshotActionResult::Ok => con._write_raw(groups::OKAY).await?, - SnapshotActionResult::Failure => return util::err(groups::SERVER_ERR), - SnapshotActionResult::Disabled => return util::err(groups::SNAPSHOT_DISABLED), - SnapshotActionResult::Busy => return util::err(groups::SNAPSHOT_BUSY), + SnapshotActionResult::Ok => con._write_raw(P::RCODE_OKAY).await?, + SnapshotActionResult::Failure => return util::err(P::RCODE_SERVER_ERR), + SnapshotActionResult::Disabled => return util::err(P::RSTRING_SNAPSHOT_DISABLED), + SnapshotActionResult::Busy => return util::err(P::RSTRING_SNAPSHOT_BUSY), _ => unsafe { impossible!() }, } } else if act.len() == 1 { @@ -51,7 +51,7 @@ action!( act.next_unchecked_bytes() }; if !encoding::is_utf8(&name) { - return util::err(groups::ENCODING_ERROR); + return util::err(P::RCODE_ENCODING_ERROR); } // SECURITY: Check for directory traversal syntax @@ -72,21 +72,21 @@ action!( .count() != 0; if illegal_snapshot { - return util::err(groups::SNAPSHOT_ILLEGAL_NAME); + return util::err(P::RSTRING_SNAPSHOT_ILLEGAL_NAME); } // now make the snapshot match engine.mkrsnap(name, handle.clone_store()).await { - SnapshotActionResult::Ok => con._write_raw(groups::OKAY).await?, - SnapshotActionResult::Failure => return util::err(groups::SERVER_ERR), - SnapshotActionResult::Busy => return util::err(groups::SNAPSHOT_BUSY), + SnapshotActionResult::Ok => con._write_raw(P::RCODE_OKAY).await?, + SnapshotActionResult::Failure => return util::err(P::RCODE_SERVER_ERR), + SnapshotActionResult::Busy => return util::err(P::RSTRING_SNAPSHOT_BUSY), SnapshotActionResult::AlreadyExists => { - return util::err(groups::SNAPSHOT_DUPLICATE) + return util::err(P::RSTRING_SNAPSHOT_DUPLICATE) } _ => unsafe { impossible!() }, } } else { - return util::err(groups::ACTION_ERR); + return util::err(P::RCODE_ACTION_ERR); } Ok(()) } diff --git a/server/src/admin/sys.rs b/server/src/admin/sys.rs index 8644c79e..923e5434 100644 --- a/server/src/admin/sys.rs +++ b/server/src/admin/sys.rs @@ -47,11 +47,11 @@ const HEALTH_TABLE: BoolTable<&str> = BoolTable::new("good", "critical"); action! { fn sys(_handle: &Corestore, con: &mut T, iter: ActionIter<'_>) { let mut iter = iter; - ensure_boolean_or_aerr(iter.len() == 2)?; + ensure_boolean_or_aerr::

(iter.len() == 2)?; match unsafe { iter.next_lowercase_unchecked() }.as_ref() { INFO => sys_info(con, &mut iter).await, METRIC => sys_metric(con, &mut iter).await, - _ => util::err(groups::UNKNOWN_ACTION), + _ => util::err(P::RCODE_UNKNOWN_ACTION), } } fn sys_info(con: &mut T, iter: &mut ActionIter<'_>) { @@ -73,7 +73,7 @@ action! { Ok(size) => con.write_int64(size).await?, Err(e) => { log::error!("Failed to get storage usage with: {e}"); - return util::err(groups::SERVER_ERR); + return util::err(P::RCODE_SERVER_ERR); }, } } diff --git a/server/src/auth/mod.rs b/server/src/auth/mod.rs index 6f234a92..652b4e7f 100644 --- a/server/src/auth/mod.rs +++ b/server/src/auth/mod.rs @@ -64,42 +64,42 @@ action! { iter: ActionIter<'_> ) { let mut iter = iter; - match iter.next_lowercase().unwrap_or_aerr()?.as_ref() { + match iter.next_lowercase().unwrap_or_aerr::

()?.as_ref() { AUTH_LOGIN => self::_auth_login(con, auth, &mut iter).await, AUTH_CLAIM => self::_auth_claim(con, auth, &mut iter).await, AUTH_ADDUSER => { - ensure_boolean_or_aerr(iter.len() == 1)?; // just the username + ensure_boolean_or_aerr::

(iter.len() == 1)?; // just the username let username = unsafe { iter.next_unchecked() }; let key = auth.provider_mut().claim_user(username)?; con.write_string(&key).await?; Ok(()) } AUTH_LOGOUT => { - ensure_boolean_or_aerr(iter.is_empty())?; // nothing else + ensure_boolean_or_aerr::

(iter.is_empty())?; // nothing else auth.provider_mut().logout()?; auth.swap_executor_to_anonymous(); - con._write_raw(groups::OKAY).await?; + con._write_raw(P::RCODE_OKAY).await?; Ok(()) } AUTH_DELUSER => { - ensure_boolean_or_aerr(iter.len() == 1)?; // just the username + ensure_boolean_or_aerr::

(iter.len() == 1)?; // just the username auth.provider_mut().delete_user(unsafe { iter.next_unchecked() })?; - con._write_raw(groups::OKAY).await?; + con._write_raw(P::RCODE_OKAY).await?; Ok(()) } AUTH_RESTORE => self::auth_restore(con, auth, &mut iter).await, AUTH_LISTUSER => self::auth_listuser(con, auth, &mut iter).await, AUTH_WHOAMI => self::auth_whoami(con, auth, &mut iter).await, - _ => util::err(groups::UNKNOWN_ACTION), + _ => util::err(P::RCODE_UNKNOWN_ACTION), } } fn auth_whoami(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) { - ensure_boolean_or_aerr(ActionIter::is_empty(iter))?; + ensure_boolean_or_aerr::

(ActionIter::is_empty(iter))?; con.write_string(&auth.provider().whoami()?).await?; Ok(()) } fn auth_listuser(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) { - ensure_boolean_or_aerr(ActionIter::is_empty(iter))?; + ensure_boolean_or_aerr::

(ActionIter::is_empty(iter))?; let usernames = auth.provider().collect_usernames()?; con.write_typed_non_null_array_header(usernames.len(), b'+').await?; for username in usernames { @@ -119,13 +119,13 @@ action! { let id = unsafe { iter.next_unchecked() }; auth.provider().regenerate_using_origin(origin, id)? } - _ => return util::err(groups::ACTION_ERR), + _ => return util::err(P::RCODE_ACTION_ERR), }; con.write_string(&newkey).await?; Ok(()) } fn _auth_claim(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) { - ensure_boolean_or_aerr(iter.len() == 1)?; // just the origin key + ensure_boolean_or_aerr::

(iter.len() == 1)?; // just the origin key let origin_key = unsafe { iter.next_unchecked() }; let key = auth.provider_mut().claim_root(origin_key)?; auth.swap_executor_to_authenticated(); @@ -139,7 +139,7 @@ action! { iter: ActionIter<'_> ) { let mut iter = iter; - match iter.next_lowercase().unwrap_or_aerr()?.as_ref() { + match iter.next_lowercase().unwrap_or_aerr::

()?.as_ref() { AUTH_LOGIN => self::_auth_login(con, auth, &mut iter).await, AUTH_CLAIM => self::_auth_claim(con, auth, &mut iter).await, AUTH_RESTORE => self::auth_restore(con, auth, &mut iter).await, @@ -149,11 +149,11 @@ action! { } fn _auth_login(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) { // sweet, where's our username and password - ensure_boolean_or_aerr(iter.len() == 2)?; // just the uname and pass + ensure_boolean_or_aerr::

(iter.len() == 2)?; // just the uname and pass let (username, password) = unsafe { (iter.next_unchecked(), iter.next_unchecked()) }; auth.provider_mut().login(username, password)?; auth.swap_executor_to_authenticated(); - con._write_raw(groups::OKAY).await?; + con._write_raw(P::RCODE_OKAY).await?; Ok(()) } } diff --git a/server/src/corestore/memstore.rs b/server/src/corestore/memstore.rs index 10caafea..468b91b4 100644 --- a/server/src/corestore/memstore.rs +++ b/server/src/corestore/memstore.rs @@ -132,6 +132,7 @@ mod cluster { #[derive(Debug, PartialEq)] /// Errors arising from trying to modify/access containers +#[allow(dead_code)] pub enum DdlError { /// The object is still in use StillInUse, diff --git a/server/src/corestore/mod.rs b/server/src/corestore/mod.rs index 5b6e6646..ad5689dc 100644 --- a/server/src/corestore/mod.rs +++ b/server/src/corestore/mod.rs @@ -29,6 +29,7 @@ use crate::corestore::{ memstore::{DdlError, Keyspace, Memstore, ObjectID, DEFAULT}, table::{DescribeTable, Table}, }; +use crate::protocol::interface::ProtocolSpec; use crate::queryengine::parser::{Entity, OwnedEntity}; use crate::registry; use crate::storage; @@ -210,8 +211,8 @@ impl Corestore { self.estate.table.as_ref().map(|(_, tbl)| tbl.as_ref()) } /// Returns a table with the provided specification - pub fn get_table_with(&self) -> ActionResult<&T::Table> { - T::get(self) + pub fn get_table_with(&self) -> ActionResult<&T::Table> { + T::get::

(self) } /// Create a table: in-memory; **no transactional guarantees**. Two tables can be created /// simultaneously, but are never flushed unless we are very lucky. If the global flush diff --git a/server/src/corestore/table.rs b/server/src/corestore/table.rs index fb908093..c9c8959c 100644 --- a/server/src/corestore/table.rs +++ b/server/src/corestore/table.rs @@ -32,22 +32,22 @@ use crate::corestore::Data; use crate::corestore::{memstore::DdlError, KeyspaceResult}; use crate::dbnet::connection::prelude::Corestore; use crate::kvengine::{KVEListmap, KVEStandard, LockedVec}; -use crate::protocol::responses::groups; +use crate::protocol::interface::ProtocolSpec; use crate::util; pub trait DescribeTable { type Table; fn try_get(table: &Table) -> Option<&Self::Table>; - fn get(store: &Corestore) -> ActionResult<&Self::Table> { + fn get(store: &Corestore) -> ActionResult<&Self::Table> { match store.estate.table { Some((_, ref table)) => { // so we do have a table match Self::try_get(table) { Some(tbl) => Ok(tbl), - None => util::err(groups::WRONG_MODEL), + None => util::err(P::RSTRING_WRONG_MODEL), } } - None => util::err(groups::DEFAULT_UNSET), + None => util::err(P::RSTRING_DEFAULT_UNSET), } } } diff --git a/server/src/dbnet/connection.rs b/server/src/dbnet/connection.rs index b7a8e8f2..339e18b7 100644 --- a/server/src/dbnet/connection.rs +++ b/server/src/dbnet/connection.rs @@ -46,7 +46,7 @@ use crate::{ }, protocol::{ interface::{ProtocolRead, ProtocolSpec, ProtocolWrite}, - responses, Query, + Query, }, queryengine, IoResult, }; @@ -107,16 +107,13 @@ pub mod prelude { //! This module is hollow itself, it only re-exports from `dbnet::con` and `tokio::io` pub use super::{AuthProviderHandle, ClientConnection, Stream}; pub use crate::{ - actions::{ensure_boolean_or_aerr, ensure_cond_or_err, ensure_length}, + actions::{ensure_boolean_or_aerr, ensure_cond_or_err, ensure_length, translate_ddl_error}, corestore::{ table::{KVEBlob, KVEList}, Corestore, }, get_tbl, handle_entity, is_lowbit_set, - protocol::{ - interface::{ProtocolRead, ProtocolSpec}, - responses::{self, groups}, - }, + protocol::interface::{ProtocolRead, ProtocolSpec}, queryengine::ActionIter, registry, util::{self, FutureResult, UnwrapActionError, Unwrappable}, @@ -286,7 +283,7 @@ where Ok(QueryResult::E(r)) => self.con.close_conn_with_error(r).await?, Ok(QueryResult::Wrongtype) => { self.con - .close_conn_with_error(responses::groups::WRONGTYPE_ERR) + .close_conn_with_error(P::RCODE_WRONGTYPE_ERR) .await? } Ok(QueryResult::Disconnected) => return Ok(()), diff --git a/server/src/protocol/interface/mod.rs b/server/src/protocol/interface.rs similarity index 78% rename from server/src/protocol/interface/mod.rs rename to server/src/protocol/interface.rs index a4feb8e8..9c28afba 100644 --- a/server/src/protocol/interface/mod.rs +++ b/server/src/protocol/interface.rs @@ -24,9 +24,12 @@ * */ -use super::{responses, ParseError}; +use super::ParseError; use crate::{ - corestore::buffers::Integer64, + corestore::{ + booltable::{BytesBoolTable, BytesNicheLUT}, + buffers::Integer64, + }, dbnet::connection::{QueryResult, QueryWithAdvance, RawConnection, Stream}, util::FutureResult, IoResult, @@ -35,6 +38,7 @@ use std::io::{Error as IoError, ErrorKind}; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter}; pub trait ProtocolSpec { + // type symbols const TSYMBOL_STRING: u8; const TSYMBOL_BINARY: u8; const TSYMBOL_FLOAT: u8; @@ -43,10 +47,80 @@ pub trait ProtocolSpec { const TSYMBOL_TYPED_NON_NULL_ARRAY: u8; const TSYMBOL_ARRAY: u8; const TSYMBOL_FLAT_ARRAY: u8; + + // charset const LF: u8 = b'\n'; + + // metaframe const SIMPLE_QUERY_HEADER: &'static [u8]; const PIPELINED_QUERY_FIRST_BYTE: u8; + + // typed array const TYPE_TYPED_ARRAY_ELEMENT_NULL: &'static [u8]; + + // respcodes + const RCODE_OKAY: &'static [u8]; + const RCODE_NIL: &'static [u8]; + const RCODE_OVERWRITE_ERR: &'static [u8]; + const RCODE_ACTION_ERR: &'static [u8]; + const RCODE_PACKET_ERR: &'static [u8]; + const RCODE_SERVER_ERR: &'static [u8]; + const RCODE_OTHER_ERR_EMPTY: &'static [u8]; + const RCODE_UNKNOWN_ACTION: &'static [u8]; + const RCODE_WRONGTYPE_ERR: &'static [u8]; + const RCODE_UNKNOWN_DATA_TYPE: &'static [u8]; + const RCODE_ENCODING_ERROR: &'static [u8]; + + // respstrings + const RSTRING_SNAPSHOT_BUSY: &'static [u8]; + const RSTRING_SNAPSHOT_DISABLED: &'static [u8]; + const RSTRING_SNAPSHOT_DUPLICATE: &'static [u8]; + const RSTRING_SNAPSHOT_ILLEGAL_NAME: &'static [u8]; + const RSTRING_ERR_ACCESS_AFTER_TERMSIG: &'static [u8]; + const RSTRING_DEFAULT_UNSET: &'static [u8]; + const RSTRING_CONTAINER_NOT_FOUND: &'static [u8]; + const RSTRING_STILL_IN_USE: &'static [u8]; + const RSTRING_PROTECTED_OBJECT: &'static [u8]; + const RSTRING_WRONG_MODEL: &'static [u8]; + const RSTRING_ALREADY_EXISTS: &'static [u8]; + const RSTRING_NOT_READY: &'static [u8]; + const RSTRING_DDL_TRANSACTIONAL_FAILURE: &'static [u8]; + const RSTRING_UNKNOWN_DDL_QUERY: &'static [u8]; + const RSTRING_BAD_EXPRESSION: &'static [u8]; + const RSTRING_UNKNOWN_MODEL: &'static [u8]; + const RSTRING_TOO_MANY_ARGUMENTS: &'static [u8]; + const RSTRING_CONTAINER_NAME_TOO_LONG: &'static [u8]; + const RSTRING_BAD_CONTAINER_NAME: &'static [u8]; + const RSTRING_UNKNOWN_INSPECT_QUERY: &'static [u8]; + const RSTRING_UNKNOWN_PROPERTY: &'static [u8]; + const RSTRING_KEYSPACE_NOT_EMPTY: &'static [u8]; + const RSTRING_BAD_TYPE_FOR_KEY: &'static [u8]; + const RSTRING_LISTMAP_BAD_INDEX: &'static [u8]; + const RSTRING_LISTMAP_LIST_IS_EMPTY: &'static [u8]; + + // full responses + const FULLRESP_RCODE_PACKET_ERR: &'static [u8]; + const FULLRESP_HEYA: &'static [u8]; + + // LUTs + const SET_NLUT: BytesNicheLUT = BytesNicheLUT::new( + Self::RCODE_ENCODING_ERROR, + Self::RCODE_OKAY, + Self::RCODE_OVERWRITE_ERR, + ); + const OKAY_BADIDX_NIL_NLUT: BytesNicheLUT = BytesNicheLUT::new( + Self::RCODE_NIL, + Self::RCODE_OKAY, + Self::RSTRING_LISTMAP_BAD_INDEX, + ); + const OKAY_OVW_BLUT: BytesBoolTable = + BytesBoolTable::new(Self::RCODE_OKAY, Self::RCODE_OVERWRITE_ERR); + + const UPDATE_NLUT: BytesNicheLUT = BytesNicheLUT::new( + Self::RCODE_ENCODING_ERROR, + Self::RCODE_OKAY, + Self::RCODE_NIL, + ); } /// # The `ProtocolRead` trait @@ -93,7 +167,7 @@ where Err(ParseError::NotEnough) => (), Err(ParseError::DatatypeParseFailure) => return Ok(QueryResult::Wrongtype), Err(ParseError::UnexpectedByte | ParseError::BadPacket) => { - return Ok(QueryResult::E(responses::full_responses::R_PACKET_ERR)); + return Ok(QueryResult::E(P::FULLRESP_RCODE_PACKET_ERR)); } } } diff --git a/server/src/protocol/iter.rs b/server/src/protocol/iter.rs index 69e471a2..f34f1e01 100644 --- a/server/src/protocol/iter.rs +++ b/server/src/protocol/iter.rs @@ -174,18 +174,3 @@ impl<'a> DoubleEndedIterator for BorrowedAnyArrayIter<'a> { impl<'a> ExactSizeIterator for BorrowedAnyArrayIter<'a> {} impl<'a> FusedIterator for BorrowedAnyArrayIter<'a> {} - -#[test] -fn test_iter() { - use super::{Parser, Query}; - let (q, _fwby) = Parser::parse(b"*3\n3\nset1\nx3\n100").unwrap(); - let r = match q { - Query::Simple(q) => q, - _ => panic!("Wrong query"), - }; - let it = r.as_slice().iter(); - let mut iter = unsafe { AnyArrayIter::new(it) }; - assert_eq!(iter.next_uppercase().unwrap().as_ref(), "SET".as_bytes()); - assert_eq!(iter.next().unwrap(), "x".as_bytes()); - assert_eq!(iter.next().unwrap(), "100".as_bytes()); -} diff --git a/server/src/protocol/mod.rs b/server/src/protocol/mod.rs index bd70dff5..964c48fb 100644 --- a/server/src/protocol/mod.rs +++ b/server/src/protocol/mod.rs @@ -25,15 +25,12 @@ */ use crate::corestore::heap_array::HeapArray; -use core::{fmt, mem::transmute, slice}; +use core::{fmt, slice}; #[cfg(feature = "nightly")] mod benches; -#[cfg(test)] -mod tests; // pub mods pub mod interface; pub mod iter; -pub mod responses; // versions mod v2; // endof pub mods @@ -42,7 +39,7 @@ mod v2; pub const PROTOCOL_VERSION: f32 = 2.0; /// The Skyhash protocol version string (Skyhash-x.y) pub const PROTOCOL_VERSIONSTRING: &str = "Skyhash-2.0"; -pub type Skyhash2 = Parser; +pub type Skyhash2 = v2::Parser; #[derive(PartialEq)] /// As its name says, an [`UnsafeSlice`] is a terribly unsafe slice. It's guarantess are @@ -167,262 +164,3 @@ impl PipelinedQuery { struct OwnedPipelinedQuery { data: Vec>>, } - -/// A parser for Skyhash 2.0 -pub struct Parser { - end: *const u8, - cursor: *const u8, -} - -unsafe impl Sync for Parser {} -unsafe impl Send for Parser {} - -impl Parser { - /// Initialize a new parser - fn new(slice: &[u8]) -> Self { - unsafe { - Self { - end: slice.as_ptr().add(slice.len()), - cursor: slice.as_ptr(), - } - } - } -} - -// basic methods -impl Parser { - /// Returns a ptr one byte past the allocation of the buffer - const fn data_end_ptr(&self) -> *const u8 { - self.end - } - /// Returns the position of the cursor - /// WARNING: Deref might led to a segfault - const fn cursor_ptr(&self) -> *const u8 { - self.cursor - } - /// Check how many bytes we have left - fn remaining(&self) -> usize { - self.data_end_ptr() as usize - self.cursor_ptr() as usize - } - /// Check if we have `size` bytes remaining - fn has_remaining(&self, size: usize) -> bool { - self.remaining() >= size - } - #[cfg(test)] - /// Check if we have exhausted the buffer - fn exhausted(&self) -> bool { - self.cursor_ptr() >= self.data_end_ptr() - } - /// Check if the buffer is not exhausted - fn not_exhausted(&self) -> bool { - self.cursor_ptr() < self.data_end_ptr() - } - /// Attempts to return the byte pointed at by the cursor. - /// WARNING: The same segfault warning - const unsafe fn get_byte_at_cursor(&self) -> u8 { - *self.cursor_ptr() - } -} - -// mutable refs -impl Parser { - /// Increment the cursor by `by` positions - unsafe fn incr_cursor_by(&mut self, by: usize) { - self.cursor = self.cursor.add(by); - } - /// Increment the position of the cursor by one position - unsafe fn incr_cursor(&mut self) { - self.incr_cursor_by(1); - } -} - -// higher level abstractions -impl Parser { - /// Attempt to read `len` bytes - fn read_until(&mut self, len: usize) -> ParseResult { - if self.has_remaining(len) { - unsafe { - // UNSAFE(@ohsayan): Already verified lengths - let slice = UnsafeSlice::new(self.cursor_ptr(), len); - self.incr_cursor_by(len); - Ok(slice) - } - } else { - Err(ParseError::NotEnough) - } - } - #[cfg(test)] - /// Attempt to read a byte slice terminated by an LF - fn read_line(&mut self) -> ParseResult { - let start_ptr = self.cursor_ptr(); - unsafe { - while self.not_exhausted() && self.get_byte_at_cursor() != b'\n' { - self.incr_cursor(); - } - if self.not_exhausted() && self.get_byte_at_cursor() == b'\n' { - let len = self.cursor_ptr() as usize - start_ptr as usize; - self.incr_cursor(); // skip LF - Ok(UnsafeSlice::new(start_ptr, len)) - } else { - Err(ParseError::NotEnough) - } - } - } - /// Attempt to read a line, **rejecting an empty payload** - fn read_line_pedantic(&mut self) -> ParseResult { - let start_ptr = self.cursor_ptr(); - unsafe { - while self.not_exhausted() && self.get_byte_at_cursor() != b'\n' { - self.incr_cursor(); - } - let len = self.cursor_ptr() as usize - start_ptr as usize; - let has_lf = self.not_exhausted() && self.get_byte_at_cursor() == b'\n'; - if has_lf && len != 0 { - self.incr_cursor(); // skip LF - Ok(UnsafeSlice::new(start_ptr, len)) - } else { - // just some silly hackery - Err(transmute(has_lf)) - } - } - } - /// Attempt to read an `usize` from the buffer - fn read_usize(&mut self) -> ParseResult { - let line = self.read_line_pedantic()?; - let bytes = unsafe { - // UNSAFE(@ohsayan): We just extracted the slice - line.as_slice() - }; - let mut ret = 0usize; - for byte in bytes { - if byte.is_ascii_digit() { - ret = match ret.checked_mul(10) { - Some(r) => r, - None => return Err(ParseError::DatatypeParseFailure), - }; - ret = match ret.checked_add((byte & 0x0F) as _) { - Some(r) => r, - None => return Err(ParseError::DatatypeParseFailure), - }; - } else { - return Err(ParseError::DatatypeParseFailure); - } - } - Ok(ret) - } -} - -// query impls -impl Parser { - /// Parse the next simple query. This should have passed the `*` tsymbol - /// - /// Simple query structure (tokenized line-by-line): - /// ```text - /// * -> Simple Query Header - /// \n -> Count of elements in the simple query - /// \n -> Length of element 1 - /// -> element 1 itself - /// \n -> Length of element 2 - /// -> element 2 itself - /// ... - /// ``` - fn _next_simple_query(&mut self) -> ParseResult> { - let element_count = self.read_usize()?; - unsafe { - let mut data = HeapArray::new_writer(element_count); - for i in 0..element_count { - let element_size = self.read_usize()?; - let element = self.read_until(element_size)?; - data.write_to_index(i, element); - } - Ok(data.finish()) - } - } - /// Parse a simple query - fn next_simple_query(&mut self) -> ParseResult { - Ok(SimpleQuery { - data: self._next_simple_query()?, - }) - } - /// Parse a pipelined query. This should have passed the `$` tsymbol - /// - /// Pipelined query structure (tokenized line-by-line): - /// ```text - /// $ -> Pipeline - /// \n -> Pipeline has n queries - /// \n -> Query 1 has 3 elements - /// \n -> Q1E1 has 3 bytes - /// -> Q1E1 itself - /// \n -> Q1E2 has 1 byte - /// -> Q1E2 itself - /// \n -> Q1E3 has 3 bytes - /// -> Q1E3 itself - /// \n -> Query 2 has 2 elements - /// \n -> Q2E1 has 3 bytes - /// -> Q2E1 itself - /// \n -> Q2E2 has 1 byte - /// -> Q2E2 itself - /// ... - /// ``` - /// - /// Example: - /// ```text - /// $ -> Pipeline - /// 2\n -> Pipeline has 2 queries - /// 3\n -> Query 1 has 3 elements - /// 3\n -> Q1E1 has 3 bytes - /// SET -> Q1E1 itself - /// 1\n -> Q1E2 has 1 byte - /// x -> Q1E2 itself - /// 3\n -> Q1E3 has 3 bytes - /// 100 -> Q1E3 itself - /// 2\n -> Query 2 has 2 elements - /// 3\n -> Q2E1 has 3 bytes - /// GET -> Q2E1 itself - /// 1\n -> Q2E2 has 1 byte - /// x -> Q2E2 itself - /// ``` - fn next_pipeline(&mut self) -> ParseResult { - let query_count = self.read_usize()?; - unsafe { - let mut queries = HeapArray::new_writer(query_count); - for i in 0..query_count { - let sq = self._next_simple_query()?; - queries.write_to_index(i, sq); - } - Ok(PipelinedQuery { - data: queries.finish(), - }) - } - } - fn _parse(&mut self) -> ParseResult { - if self.not_exhausted() { - unsafe { - let first_byte = self.get_byte_at_cursor(); - self.incr_cursor(); - let data = match first_byte { - b'*' => { - // a simple query - Query::Simple(self.next_simple_query()?) - } - b'$' => { - // a pipelined query - Query::Pipelined(self.next_pipeline()?) - } - _ => return Err(ParseError::UnexpectedByte), - }; - Ok(data) - } - } else { - Err(ParseError::NotEnough) - } - } - // only expose this. don't expose Self::new since that'll be _relatively easier_ to - // invalidate invariants for - pub fn parse(buf: &[u8]) -> ParseResult<(Query, usize)> { - let mut slf = Self::new(buf); - let body = slf._parse()?; - let consumed = slf.cursor_ptr() as usize - buf.as_ptr() as usize; - Ok((body, consumed)) - } -} diff --git a/server/src/protocol/responses.rs b/server/src/protocol/responses.rs deleted file mode 100644 index 6023c24d..00000000 --- a/server/src/protocol/responses.rs +++ /dev/null @@ -1,153 +0,0 @@ -/* - * Created on Sat Aug 22 2020 - * - * This file is a part of Skytable - * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source - * NoSQL database written by Sayan Nandan ("the Author") with the - * vision to provide flexibility in data modelling without compromising - * on performance, queryability or scalability. - * - * Copyright (c) 2020, Sayan Nandan - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - * -*/ - -//! Primitives for generating Skyhash compatible responses - -pub mod groups { - #![allow(unused)] - //! # Pre-compiled response **elements** - //! These are pre-compiled response groups and **not** complete responses. If complete - //! responses are required, user protocol::responses::fresp - use ::sky_macros::compiled_eresp_bytes as eresp; - /// Response code 0 as a array element - pub const OKAY: &[u8] = eresp!("0"); - /// Response code 1 as a array element - pub const NIL: &[u8] = eresp!("1"); - /// Response code 2 as a array element - pub const OVERWRITE_ERR: &[u8] = eresp!("2"); - /// Response code 3 as a array element - pub const ACTION_ERR: &[u8] = eresp!("3"); - /// Response code 4 as a array element - pub const PACKET_ERR: &[u8] = eresp!("4"); - /// Response code 5 as a array element - pub const SERVER_ERR: &[u8] = eresp!("5"); - /// Response code 6 as a array element - pub const OTHER_ERR_EMPTY: &[u8] = eresp!("6"); - /// Response group element with string "HEYA" - pub const HEYA: &[u8] = "+4\nHEY!".as_bytes(); - /// "Unknown action" error response - pub const UNKNOWN_ACTION: &[u8] = eresp!("Unknown action"); - /// Response code 7 - pub const WRONGTYPE_ERR: &[u8] = eresp!("7"); - /// Response code 8 - pub const UNKNOWN_DATA_TYPE: &[u8] = eresp!("8"); - /// Response code 9 as an array element - pub const ENCODING_ERROR: &[u8] = eresp!("9"); - /// Snapshot busy error - pub const SNAPSHOT_BUSY: &[u8] = eresp!("err-snapshot-busy"); - /// Snapshot disabled (other error) - pub const SNAPSHOT_DISABLED: &[u8] = eresp!("err-snapshot-disabled"); - /// Duplicate snapshot - pub const SNAPSHOT_DUPLICATE: &[u8] = eresp!("duplicate-snapshot"); - /// Snapshot has illegal name (other error) - pub const SNAPSHOT_ILLEGAL_NAME: &[u8] = eresp!("err-invalid-snapshot-name"); - /// Access after termination signal (other error) - pub const ERR_ACCESS_AFTER_TERMSIG: &[u8] = eresp!("err-access-after-termsig"); - - // keyspace related resps - /// The default container was not set - pub const DEFAULT_UNSET: &[u8] = eresp!("default-container-unset"); - /// The container was not found - pub const CONTAINER_NOT_FOUND: &[u8] = eresp!("container-not-found"); - /// The container is still in use and so cannot be removed - pub const STILL_IN_USE: &[u8] = eresp!("still-in-use"); - /// This is a protected object and hence cannot be accessed - pub const PROTECTED_OBJECT: &[u8] = eresp!("err-protected-object"); - /// The action was applied against the wrong model - pub const WRONG_MODEL: &[u8] = eresp!("wrong-model"); - /// The container already exists - pub const ALREADY_EXISTS: &[u8] = eresp!("err-already-exists"); - /// The container is not ready - pub const NOT_READY: &[u8] = eresp!("not-ready"); - /// A transactional failure occurred - pub const DDL_TRANSACTIONAL_FAILURE: &[u8] = eresp!("transactional-failure"); - /// An unknown DDL query was run - pub const UNKNOWN_DDL_QUERY: &[u8] = eresp!("unknown-ddl-query"); - /// The expression for a DDL query was malformed - pub const BAD_EXPRESSION: &[u8] = eresp!("malformed-expression"); - /// An unknown model was passed in a DDL query - pub const UNKNOWN_MODEL: &[u8] = eresp!("unknown-model"); - /// Too many arguments were passed to model constructor - pub const TOO_MANY_ARGUMENTS: &[u8] = eresp!("too-many-args"); - /// The container name is too long - pub const CONTAINER_NAME_TOO_LONG: &[u8] = eresp!("container-name-too-long"); - /// The container name contains invalid characters - pub const BAD_CONTAINER_NAME: &[u8] = eresp!("bad-container-name"); - /// An unknown inspect query - pub const UNKNOWN_INSPECT_QUERY: &[u8] = eresp!("unknown-inspect-query"); - /// An unknown table property was passed - pub const UNKNOWN_PROPERTY: &[u8] = eresp!("unknown-property"); - /// The keyspace is not empty and hence cannot be removed - pub const KEYSPACE_NOT_EMPTY: &[u8] = eresp!("keyspace-not-empty"); - /// Bad type supplied in a DDL query for the key - pub const BAD_TYPE_FOR_KEY: &[u8] = eresp!("bad-type-for-key"); - /// The index for the provided list was non-existent - pub const LISTMAP_BAD_INDEX: &[u8] = eresp!("bad-list-index"); - /// The list is empty - pub const LISTMAP_LIST_IS_EMPTY: &[u8] = eresp!("list-is-empty"); -} - -pub mod full_responses { - #![allow(unused)] - //! # Pre-compiled **responses** - //! These are pre-compiled **complete** responses. This means that they should - //! be written off directly to the stream and should **not be preceded by any response metaframe** - - /// Response code: 0 (Okay) - pub const R_OKAY: &[u8] = "*!1\n0\n".as_bytes(); - /// Response code: 1 (Nil) - pub const R_NIL: &[u8] = "*!1\n1\n".as_bytes(); - /// Response code: 2 (Overwrite Error) - pub const R_OVERWRITE_ERR: &[u8] = "*!1\n2\n".as_bytes(); - /// Response code: 3 (Action Error) - pub const R_ACTION_ERR: &[u8] = "*!1\n3\n".as_bytes(); - /// Response code: 4 (Packet Error) - pub const R_PACKET_ERR: &[u8] = "*!1\n4\n".as_bytes(); - /// Response code: 5 (Server Error) - pub const R_SERVER_ERR: &[u8] = "*!1\n5\n".as_bytes(); - /// Response code: 6 (Other Error _without description_) - pub const R_OTHER_ERR_EMPTY: &[u8] = "*!1\n6\n".as_bytes(); - /// Response code: 7; wrongtype - pub const R_WRONGTYPE_ERR: &[u8] = "*!1\n7".as_bytes(); - /// Response code: 8; unknown data type - pub const R_UNKNOWN_DATA_TYPE: &[u8] = "*!1\n8\n".as_bytes(); - /// A heya response - pub const R_HEYA: &[u8] = "*+4\nHEY!\n".as_bytes(); - /// An other response with description: "Unknown action" - pub const R_UNKNOWN_ACTION: &[u8] = "*!14\nUnknown action\n".as_bytes(); - /// A 0 uint64 reply - pub const R_ONE_INT_REPLY: &[u8] = "*:1\n1\n".as_bytes(); - /// A 1 uint64 reply - pub const R_ZERO_INT_REPLY: &[u8] = "*:1\n0\n".as_bytes(); - /// Snapshot busy (other error) - pub const R_SNAPSHOT_BUSY: &[u8] = "*!17\nerr-snapshot-busy\n".as_bytes(); - /// Snapshot disabled (other error) - pub const R_SNAPSHOT_DISABLED: &[u8] = "*!21\nerr-snapshot-disabled\n".as_bytes(); - /// Snapshot has illegal name (other error) - pub const R_SNAPSHOT_ILLEGAL_NAME: &[u8] = "*!25\nerr-invalid-snapshot-name\n".as_bytes(); - /// Access after termination signal (other error) - pub const R_ERR_ACCESS_AFTER_TERMSIG: &[u8] = "*!24\nerr-access-after-termsig\n".as_bytes(); -} diff --git a/server/src/protocol/v2/interface_impls.rs b/server/src/protocol/v2/interface_impls.rs new file mode 100644 index 00000000..fc60c04a --- /dev/null +++ b/server/src/protocol/v2/interface_impls.rs @@ -0,0 +1,274 @@ +/* + * Created on Sat Apr 30 2022 + * + * This file is a part of Skytable + * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source + * NoSQL database written by Sayan Nandan ("the Author") with the + * vision to provide flexibility in data modelling without compromising + * on performance, queryability or scalability. + * + * Copyright (c) 2022, Sayan Nandan + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + * +*/ + +use crate::{ + corestore::buffers::Integer64, + dbnet::connection::{QueryWithAdvance, RawConnection, Stream}, + protocol::{ + interface::{ProtocolRead, ProtocolSpec, ProtocolWrite}, + ParseError, Skyhash2, + }, + util::FutureResult, + IoResult, +}; +use ::sky_macros::compiled_eresp_bytes as eresp; +use tokio::io::AsyncWriteExt; + +impl ProtocolSpec for Skyhash2 { + // type symbols + const TSYMBOL_STRING: u8 = b'+'; + const TSYMBOL_BINARY: u8 = b'?'; + const TSYMBOL_FLOAT: u8 = b'%'; + const TSYMBOL_INT64: u8 = b':'; + const TSYMBOL_TYPED_ARRAY: u8 = b'@'; + const TSYMBOL_TYPED_NON_NULL_ARRAY: u8 = b'^'; + const TSYMBOL_ARRAY: u8 = b'&'; + const TSYMBOL_FLAT_ARRAY: u8 = b'_'; + + // typed array + const TYPE_TYPED_ARRAY_ELEMENT_NULL: &'static [u8] = b"\0"; + + // metaframe + const SIMPLE_QUERY_HEADER: &'static [u8] = b"*"; + const PIPELINED_QUERY_FIRST_BYTE: u8 = b'$'; + + // respcodes + /// Response code 0 as a array element + const RCODE_OKAY: &'static [u8] = eresp!("0"); + /// Response code 1 as a array element + const RCODE_NIL: &'static [u8] = eresp!("1"); + /// Response code 2 as a array element + const RCODE_OVERWRITE_ERR: &'static [u8] = eresp!("2"); + /// Response code 3 as a array element + const RCODE_ACTION_ERR: &'static [u8] = eresp!("3"); + /// Response code 4 as a array element + const RCODE_PACKET_ERR: &'static [u8] = eresp!("4"); + /// Response code 5 as a array element + const RCODE_SERVER_ERR: &'static [u8] = eresp!("5"); + /// Response code 6 as a array element + const RCODE_OTHER_ERR_EMPTY: &'static [u8] = eresp!("6"); + /// "Unknown action" error response + const RCODE_UNKNOWN_ACTION: &'static [u8] = eresp!("Unknown action"); + /// Response code 7 + const RCODE_WRONGTYPE_ERR: &'static [u8] = eresp!("7"); + /// Response code 8 + const RCODE_UNKNOWN_DATA_TYPE: &'static [u8] = eresp!("8"); + /// Response code 9 as an array element + const RCODE_ENCODING_ERROR: &'static [u8] = eresp!("9"); + + // respstrings + + /// Snapshot busy error + const RSTRING_SNAPSHOT_BUSY: &'static [u8] = eresp!("err-snapshot-busy"); + /// Snapshot disabled (other error) + const RSTRING_SNAPSHOT_DISABLED: &'static [u8] = eresp!("err-snapshot-disabled"); + /// Duplicate snapshot + const RSTRING_SNAPSHOT_DUPLICATE: &'static [u8] = eresp!("duplicate-snapshot"); + /// Snapshot has illegal name (other error) + const RSTRING_SNAPSHOT_ILLEGAL_NAME: &'static [u8] = eresp!("err-invalid-snapshot-name"); + /// Access after termination signal (other error) + const RSTRING_ERR_ACCESS_AFTER_TERMSIG: &'static [u8] = eresp!("err-access-after-termsig"); + + // keyspace related resps + /// The default container was not set + const RSTRING_DEFAULT_UNSET: &'static [u8] = eresp!("default-container-unset"); + /// The container was not found + const RSTRING_CONTAINER_NOT_FOUND: &'static [u8] = eresp!("container-not-found"); + /// The container is still in use and so cannot be removed + const RSTRING_STILL_IN_USE: &'static [u8] = eresp!("still-in-use"); + /// This is a protected object and hence cannot be accessed + const RSTRING_PROTECTED_OBJECT: &'static [u8] = eresp!("err-protected-object"); + /// The action was applied against the wrong model + const RSTRING_WRONG_MODEL: &'static [u8] = eresp!("wrong-model"); + /// The container already exists + const RSTRING_ALREADY_EXISTS: &'static [u8] = eresp!("err-already-exists"); + /// The container is not ready + const RSTRING_NOT_READY: &'static [u8] = eresp!("not-ready"); + /// A transactional failure occurred + const RSTRING_DDL_TRANSACTIONAL_FAILURE: &'static [u8] = eresp!("transactional-failure"); + /// An unknown DDL query was run + const RSTRING_UNKNOWN_DDL_QUERY: &'static [u8] = eresp!("unknown-ddl-query"); + /// The expression for a DDL query was malformed + const RSTRING_BAD_EXPRESSION: &'static [u8] = eresp!("malformed-expression"); + /// An unknown model was passed in a DDL query + const RSTRING_UNKNOWN_MODEL: &'static [u8] = eresp!("unknown-model"); + /// Too many arguments were passed to model constructor + const RSTRING_TOO_MANY_ARGUMENTS: &'static [u8] = eresp!("too-many-args"); + /// The container name is too long + const RSTRING_CONTAINER_NAME_TOO_LONG: &'static [u8] = eresp!("container-name-too-long"); + /// The container name contains invalid characters + const RSTRING_BAD_CONTAINER_NAME: &'static [u8] = eresp!("bad-container-name"); + /// An unknown inspect query + const RSTRING_UNKNOWN_INSPECT_QUERY: &'static [u8] = eresp!("unknown-inspect-query"); + /// An unknown table property was passed + const RSTRING_UNKNOWN_PROPERTY: &'static [u8] = eresp!("unknown-property"); + /// The keyspace is not empty and hence cannot be removed + const RSTRING_KEYSPACE_NOT_EMPTY: &'static [u8] = eresp!("keyspace-not-empty"); + /// Bad type supplied in a DDL query for the key + const RSTRING_BAD_TYPE_FOR_KEY: &'static [u8] = eresp!("bad-type-for-key"); + /// The index for the provided list was non-existent + const RSTRING_LISTMAP_BAD_INDEX: &'static [u8] = eresp!("bad-list-index"); + /// The list is empty + const RSTRING_LISTMAP_LIST_IS_EMPTY: &'static [u8] = eresp!("list-is-empty"); + + // full responses + const FULLRESP_RCODE_PACKET_ERR: &'static [u8] = b"*!1\n4\n"; + const FULLRESP_HEYA: &'static [u8] = b"+4\nHEY!"; +} + +impl ProtocolRead for T +where + T: RawConnection + Send + Sync, + Strm: Stream, +{ + fn try_query(&self) -> Result { + Skyhash2::parse(self.get_buffer()) + } +} + +impl ProtocolWrite for T +where + T: RawConnection + Send + Sync, + Strm: Stream, +{ + fn write_string<'life0, 'life1, 'ret_life>( + &'life0 mut self, + string: &'life1 str, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash2::TSYMBOL_STRING]).await?; + // length + let len_bytes = Integer64::from(string.len()); + stream.write_all(&len_bytes).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await?; + // payload + stream.write_all(string.as_bytes()).await + }) + } + fn write_binary<'life0, 'life1, 'ret_life>( + &'life0 mut self, + binary: &'life1 [u8], + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash2::TSYMBOL_BINARY]).await?; + // length + let len_bytes = Integer64::from(binary.len()); + stream.write_all(&len_bytes).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await?; + // payload + stream.write_all(binary).await + }) + } + fn write_usize<'life0, 'ret_life>( + &'life0 mut self, + size: usize, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash2::TSYMBOL_INT64]).await?; + // body + stream.write_all(&Integer64::from(size)).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await + }) + } + fn write_int64<'life0, 'ret_life>( + &'life0 mut self, + int: u64, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash2::TSYMBOL_INT64]).await?; + // body + stream.write_all(&Integer64::from(int)).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await + }) + } + fn write_float<'life0, 'ret_life>( + &'life0 mut self, + float: f32, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash2::TSYMBOL_FLOAT]).await?; + // body + stream.write_all(float.to_string().as_bytes()).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await + }) + } + fn write_typed_array_element<'life0, 'life1, 'ret_life>( + &'life0 mut self, + element: &'life1 [u8], + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // len + stream.write_all(&Integer64::from(element.len())).await?; + // LF + stream.write_all(&[Skyhash2::LF]).await?; + // body + stream.write_all(element).await + }) + } +} diff --git a/server/src/protocol/v2/mod.rs b/server/src/protocol/v2/mod.rs index 7a66de08..1a358ece 100644 --- a/server/src/protocol/v2/mod.rs +++ b/server/src/protocol/v2/mod.rs @@ -24,162 +24,271 @@ * */ -use super::{ - interface::{ProtocolRead, ProtocolSpec, ProtocolWrite}, - ParseError, Skyhash2, -}; +mod interface_impls; + use crate::{ - corestore::buffers::Integer64, - dbnet::connection::{QueryWithAdvance, RawConnection, Stream}, - util::FutureResult, - IoResult, + corestore::heap_array::HeapArray, + protocol::{ParseError, ParseResult, PipelinedQuery, Query, SimpleQuery, UnsafeSlice}, }; -use tokio::io::AsyncWriteExt; +use core::mem::transmute; +#[cfg(test)] +mod tests; -impl ProtocolSpec for Skyhash2 { - const TSYMBOL_STRING: u8 = b'+'; - const TSYMBOL_BINARY: u8 = b'?'; - const TSYMBOL_FLOAT: u8 = b'%'; - const TSYMBOL_INT64: u8 = b':'; - const TSYMBOL_TYPED_ARRAY: u8 = b'@'; - const TSYMBOL_TYPED_NON_NULL_ARRAY: u8 = b'^'; - const TSYMBOL_ARRAY: u8 = b'&'; - const TSYMBOL_FLAT_ARRAY: u8 = b'_'; - const TYPE_TYPED_ARRAY_ELEMENT_NULL: &'static [u8] = b"\0"; - const SIMPLE_QUERY_HEADER: &'static [u8] = b"*"; - const PIPELINED_QUERY_FIRST_BYTE: u8 = b'$'; +/// A parser for Skyhash 2.0 +pub struct Parser { + end: *const u8, + cursor: *const u8, } -impl ProtocolRead for T -where - T: RawConnection + Send + Sync, - Strm: Stream, -{ - fn try_query(&self) -> Result { - Skyhash2::parse(self.get_buffer()) +unsafe impl Sync for Parser {} +unsafe impl Send for Parser {} + +impl Parser { + /// Initialize a new parser + fn new(slice: &[u8]) -> Self { + unsafe { + Self { + end: slice.as_ptr().add(slice.len()), + cursor: slice.as_ptr(), + } + } } } -impl ProtocolWrite for T -where - T: RawConnection + Send + Sync, - Strm: Stream, -{ - fn write_string<'life0, 'life1, 'ret_life>( - &'life0 mut self, - string: &'life1 str, - ) -> FutureResult<'ret_life, IoResult<()>> - where - 'life0: 'ret_life, - 'life1: 'ret_life, - Self: 'ret_life, - { - Box::pin(async move { - let stream = self.get_mut_stream(); - // tsymbol - stream.write_all(&[Skyhash2::TSYMBOL_STRING]).await?; - // length - let len_bytes = Integer64::from(string.len()); - stream.write_all(&len_bytes).await?; - // LF - stream.write_all(&[Skyhash2::LF]).await?; - // payload - stream.write_all(string.as_bytes()).await - }) +// basic methods +impl Parser { + /// Returns a ptr one byte past the allocation of the buffer + const fn data_end_ptr(&self) -> *const u8 { + self.end } - fn write_binary<'life0, 'life1, 'ret_life>( - &'life0 mut self, - binary: &'life1 [u8], - ) -> FutureResult<'ret_life, IoResult<()>> - where - 'life0: 'ret_life, - 'life1: 'ret_life, - Self: 'ret_life, - { - Box::pin(async move { - let stream = self.get_mut_stream(); - // tsymbol - stream.write_all(&[Skyhash2::TSYMBOL_BINARY]).await?; - // length - let len_bytes = Integer64::from(binary.len()); - stream.write_all(&len_bytes).await?; - // LF - stream.write_all(&[Skyhash2::LF]).await?; - // payload - stream.write_all(binary).await - }) + /// Returns the position of the cursor + /// WARNING: Deref might led to a segfault + const fn cursor_ptr(&self) -> *const u8 { + self.cursor } - fn write_usize<'life0, 'ret_life>( - &'life0 mut self, - size: usize, - ) -> FutureResult<'ret_life, IoResult<()>> - where - 'life0: 'ret_life, - Self: 'ret_life, - { - Box::pin(async move { - let stream = self.get_mut_stream(); - // tsymbol - stream.write_all(&[Skyhash2::TSYMBOL_INT64]).await?; - // body - stream.write_all(&Integer64::from(size)).await?; - // LF - stream.write_all(&[Skyhash2::LF]).await - }) + /// Check how many bytes we have left + fn remaining(&self) -> usize { + self.data_end_ptr() as usize - self.cursor_ptr() as usize } - fn write_int64<'life0, 'ret_life>( - &'life0 mut self, - int: u64, - ) -> FutureResult<'ret_life, IoResult<()>> - where - 'life0: 'ret_life, - Self: 'ret_life, - { - Box::pin(async move { - let stream = self.get_mut_stream(); - // tsymbol - stream.write_all(&[Skyhash2::TSYMBOL_INT64]).await?; - // body - stream.write_all(&Integer64::from(int)).await?; - // LF - stream.write_all(&[Skyhash2::LF]).await - }) + /// Check if we have `size` bytes remaining + fn has_remaining(&self, size: usize) -> bool { + self.remaining() >= size } - fn write_float<'life0, 'ret_life>( - &'life0 mut self, - float: f32, - ) -> FutureResult<'ret_life, IoResult<()>> - where - 'life0: 'ret_life, - Self: 'ret_life, - { - Box::pin(async move { - let stream = self.get_mut_stream(); - // tsymbol - stream.write_all(&[Skyhash2::TSYMBOL_FLOAT]).await?; - // body - stream.write_all(float.to_string().as_bytes()).await?; - // LF - stream.write_all(&[Skyhash2::LF]).await - }) + #[cfg(test)] + /// Check if we have exhausted the buffer + fn exhausted(&self) -> bool { + self.cursor_ptr() >= self.data_end_ptr() + } + /// Check if the buffer is not exhausted + fn not_exhausted(&self) -> bool { + self.cursor_ptr() < self.data_end_ptr() + } + /// Attempts to return the byte pointed at by the cursor. + /// WARNING: The same segfault warning + const unsafe fn get_byte_at_cursor(&self) -> u8 { + *self.cursor_ptr() + } +} + +// mutable refs +impl Parser { + /// Increment the cursor by `by` positions + unsafe fn incr_cursor_by(&mut self, by: usize) { + self.cursor = self.cursor.add(by); } - fn write_typed_array_element<'life0, 'life1, 'ret_life>( - &'life0 mut self, - element: &'life1 [u8], - ) -> FutureResult<'ret_life, IoResult<()>> - where - 'life0: 'ret_life, - 'life1: 'ret_life, - Self: 'ret_life, - { - Box::pin(async move { - let stream = self.get_mut_stream(); - // len - stream.write_all(&Integer64::from(element.len())).await?; - // LF - stream.write_all(&[Skyhash2::LF]).await?; - // body - stream.write_all(element).await + /// Increment the position of the cursor by one position + unsafe fn incr_cursor(&mut self) { + self.incr_cursor_by(1); + } +} + +// higher level abstractions +impl Parser { + /// Attempt to read `len` bytes + fn read_until(&mut self, len: usize) -> ParseResult { + if self.has_remaining(len) { + unsafe { + // UNSAFE(@ohsayan): Already verified lengths + let slice = UnsafeSlice::new(self.cursor_ptr(), len); + self.incr_cursor_by(len); + Ok(slice) + } + } else { + Err(ParseError::NotEnough) + } + } + #[cfg(test)] + /// Attempt to read a byte slice terminated by an LF + fn read_line(&mut self) -> ParseResult { + let start_ptr = self.cursor_ptr(); + unsafe { + while self.not_exhausted() && self.get_byte_at_cursor() != b'\n' { + self.incr_cursor(); + } + if self.not_exhausted() && self.get_byte_at_cursor() == b'\n' { + let len = self.cursor_ptr() as usize - start_ptr as usize; + self.incr_cursor(); // skip LF + Ok(UnsafeSlice::new(start_ptr, len)) + } else { + Err(ParseError::NotEnough) + } + } + } + /// Attempt to read a line, **rejecting an empty payload** + fn read_line_pedantic(&mut self) -> ParseResult { + let start_ptr = self.cursor_ptr(); + unsafe { + while self.not_exhausted() && self.get_byte_at_cursor() != b'\n' { + self.incr_cursor(); + } + let len = self.cursor_ptr() as usize - start_ptr as usize; + let has_lf = self.not_exhausted() && self.get_byte_at_cursor() == b'\n'; + if has_lf && len != 0 { + self.incr_cursor(); // skip LF + Ok(UnsafeSlice::new(start_ptr, len)) + } else { + // just some silly hackery + Err(transmute(has_lf)) + } + } + } + /// Attempt to read an `usize` from the buffer + fn read_usize(&mut self) -> ParseResult { + let line = self.read_line_pedantic()?; + let bytes = unsafe { + // UNSAFE(@ohsayan): We just extracted the slice + line.as_slice() + }; + let mut ret = 0usize; + for byte in bytes { + if byte.is_ascii_digit() { + ret = match ret.checked_mul(10) { + Some(r) => r, + None => return Err(ParseError::DatatypeParseFailure), + }; + ret = match ret.checked_add((byte & 0x0F) as _) { + Some(r) => r, + None => return Err(ParseError::DatatypeParseFailure), + }; + } else { + return Err(ParseError::DatatypeParseFailure); + } + } + Ok(ret) + } +} + +// query impls +impl Parser { + /// Parse the next simple query. This should have passed the `*` tsymbol + /// + /// Simple query structure (tokenized line-by-line): + /// ```text + /// * -> Simple Query Header + /// \n -> Count of elements in the simple query + /// \n -> Length of element 1 + /// -> element 1 itself + /// \n -> Length of element 2 + /// -> element 2 itself + /// ... + /// ``` + fn _next_simple_query(&mut self) -> ParseResult> { + let element_count = self.read_usize()?; + unsafe { + let mut data = HeapArray::new_writer(element_count); + for i in 0..element_count { + let element_size = self.read_usize()?; + let element = self.read_until(element_size)?; + data.write_to_index(i, element); + } + Ok(data.finish()) + } + } + /// Parse a simple query + fn next_simple_query(&mut self) -> ParseResult { + Ok(SimpleQuery { + data: self._next_simple_query()?, }) } + /// Parse a pipelined query. This should have passed the `$` tsymbol + /// + /// Pipelined query structure (tokenized line-by-line): + /// ```text + /// $ -> Pipeline + /// \n -> Pipeline has n queries + /// \n -> Query 1 has 3 elements + /// \n -> Q1E1 has 3 bytes + /// -> Q1E1 itself + /// \n -> Q1E2 has 1 byte + /// -> Q1E2 itself + /// \n -> Q1E3 has 3 bytes + /// -> Q1E3 itself + /// \n -> Query 2 has 2 elements + /// \n -> Q2E1 has 3 bytes + /// -> Q2E1 itself + /// \n -> Q2E2 has 1 byte + /// -> Q2E2 itself + /// ... + /// ``` + /// + /// Example: + /// ```text + /// $ -> Pipeline + /// 2\n -> Pipeline has 2 queries + /// 3\n -> Query 1 has 3 elements + /// 3\n -> Q1E1 has 3 bytes + /// SET -> Q1E1 itself + /// 1\n -> Q1E2 has 1 byte + /// x -> Q1E2 itself + /// 3\n -> Q1E3 has 3 bytes + /// 100 -> Q1E3 itself + /// 2\n -> Query 2 has 2 elements + /// 3\n -> Q2E1 has 3 bytes + /// GET -> Q2E1 itself + /// 1\n -> Q2E2 has 1 byte + /// x -> Q2E2 itself + /// ``` + fn next_pipeline(&mut self) -> ParseResult { + let query_count = self.read_usize()?; + unsafe { + let mut queries = HeapArray::new_writer(query_count); + for i in 0..query_count { + let sq = self._next_simple_query()?; + queries.write_to_index(i, sq); + } + Ok(PipelinedQuery { + data: queries.finish(), + }) + } + } + fn _parse(&mut self) -> ParseResult { + if self.not_exhausted() { + unsafe { + let first_byte = self.get_byte_at_cursor(); + self.incr_cursor(); + let data = match first_byte { + b'*' => { + // a simple query + Query::Simple(self.next_simple_query()?) + } + b'$' => { + // a pipelined query + Query::Pipelined(self.next_pipeline()?) + } + _ => return Err(ParseError::UnexpectedByte), + }; + Ok(data) + } + } else { + Err(ParseError::NotEnough) + } + } + // only expose this. don't expose Self::new since that'll be _relatively easier_ to + // invalidate invariants for + pub fn parse(buf: &[u8]) -> ParseResult<(Query, usize)> { + let mut slf = Self::new(buf); + let body = slf._parse()?; + let consumed = slf.cursor_ptr() as usize - buf.as_ptr() as usize; + Ok((body, consumed)) + } } diff --git a/server/src/protocol/tests.rs b/server/src/protocol/v2/tests.rs similarity index 96% rename from server/src/protocol/tests.rs rename to server/src/protocol/v2/tests.rs index dfb537b3..8046ed8b 100644 --- a/server/src/protocol/tests.rs +++ b/server/src/protocol/v2/tests.rs @@ -25,7 +25,7 @@ */ use super::{Parser, PipelinedQuery, Query, SimpleQuery}; -use crate::protocol::ParseError; +use crate::protocol::{iter::AnyArrayIter, ParseError}; use std::iter::Map; use std::vec::IntoIter as VecIntoIter; @@ -641,3 +641,18 @@ fn pipelined_query_fail_because_not_enough() { assert_eq!(ret, ParseError::NotEnough) } } + +#[test] +fn test_iter() { + use super::{Parser, Query}; + let (q, _fwby) = Parser::parse(b"*3\n3\nset1\nx3\n100").unwrap(); + let r = match q { + Query::Simple(q) => q, + _ => panic!("Wrong query"), + }; + let it = r.as_slice().iter(); + let mut iter = unsafe { AnyArrayIter::new(it) }; + assert_eq!(iter.next_uppercase().unwrap().as_ref(), "SET".as_bytes()); + assert_eq!(iter.next().unwrap(), "x".as_bytes()); + assert_eq!(iter.next().unwrap(), "100".as_bytes()); +} diff --git a/server/src/queryengine/ddl.rs b/server/src/queryengine/ddl.rs index 058a0004..f6dfac1a 100644 --- a/server/src/queryengine/ddl.rs +++ b/server/src/queryengine/ddl.rs @@ -42,14 +42,14 @@ action! { /// like queries fn create(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) { // minlength is 2 (create has already been checked) - ensure_length(act.len(), |size| size > 1)?; + ensure_length::

(act.len(), |size| size > 1)?; let mut create_what = unsafe { act.next().unsafe_unwrap() }.to_vec(); create_what.make_ascii_uppercase(); match create_what.as_ref() { TABLE => create_table(handle, con, act).await?, KEYSPACE => create_keyspace(handle, con, act).await?, _ => { - con._write_raw(groups::UNKNOWN_DDL_QUERY).await?; + con._write_raw(P::RSTRING_UNKNOWN_DDL_QUERY).await?; } } Ok(()) @@ -59,14 +59,14 @@ action! { /// like queries fn ddl_drop(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) { // minlength is 2 (create has already been checked) - ensure_length(act.len(), |size| size > 1)?; + ensure_length::

(act.len(), |size| size > 1)?; let mut create_what = unsafe { act.next().unsafe_unwrap() }.to_vec(); create_what.make_ascii_uppercase(); match create_what.as_ref() { TABLE => drop_table(handle, con, act).await?, KEYSPACE => drop_keyspace(handle, con, act).await?, _ => { - con._write_raw(groups::UNKNOWN_DDL_QUERY).await?; + con._write_raw(P::RSTRING_UNKNOWN_DDL_QUERY).await?; } } Ok(()) @@ -74,77 +74,77 @@ action! { /// We should have ` (args) properties` fn create_table(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |size| size > 1 && size < 4)?; + ensure_length::

(act.len(), |size| size > 1 && size < 4)?; let table_name = unsafe { act.next().unsafe_unwrap() }; let model_name = unsafe { act.next().unsafe_unwrap() }; - let (table_entity, model_code) = parser::parse_table_args(table_name, model_name)?; + let (table_entity, model_code) = parser::parse_table_args::

(table_name, model_name)?; let is_volatile = match act.next() { Some(maybe_volatile) => { - ensure_cond_or_err(maybe_volatile.eq(VOLATILE), responses::groups::UNKNOWN_PROPERTY)?; + ensure_cond_or_err(maybe_volatile.eq(VOLATILE), P::RSTRING_UNKNOWN_PROPERTY)?; true } None => false, }; if registry::state_okay() { - handle.create_table(table_entity, model_code, is_volatile)?; - con._write_raw(groups::OKAY).await?; + translate_ddl_error::(handle.create_table(table_entity, model_code, is_volatile))?; + con._write_raw(P::RCODE_OKAY).await?; } else { - return util::err(groups::SERVER_ERR); + return util::err(P::RCODE_SERVER_ERR); } Ok(()) } /// We should have `` fn create_keyspace(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |len| len == 1)?; + ensure_length::

(act.len(), |len| len == 1)?; match act.next() { Some(ksid) => { - ensure_cond_or_err(encoding::is_utf8(&ksid), responses::groups::ENCODING_ERROR)?; + ensure_cond_or_err(encoding::is_utf8(&ksid), P::RCODE_ENCODING_ERROR)?; let ksid_str = unsafe { str::from_utf8_unchecked(ksid) }; - ensure_cond_or_err(VALID_CONTAINER_NAME.is_match(ksid_str), responses::groups::BAD_EXPRESSION)?; - ensure_cond_or_err(ksid.len() < 64, responses::groups::CONTAINER_NAME_TOO_LONG)?; + ensure_cond_or_err(VALID_CONTAINER_NAME.is_match(ksid_str), P::RSTRING_BAD_EXPRESSION)?; + ensure_cond_or_err(ksid.len() < 64, P::RSTRING_CONTAINER_NAME_TOO_LONG)?; let ksid = unsafe { ObjectID::from_slice(ksid_str) }; if registry::state_okay() { - handle.create_keyspace(ksid)?; - con._write_raw(groups::OKAY).await? + translate_ddl_error::(handle.create_keyspace(ksid))?; + con._write_raw(P::RCODE_OKAY).await? } else { - return util::err(groups::SERVER_ERR); + return util::err(P::RCODE_SERVER_ERR); } } - None => return util::err(groups::ACTION_ERR), + None => return util::err(P::RCODE_ACTION_ERR), } Ok(()) } /// Drop a table (`` only) fn drop_table(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |size| size == 1)?; + ensure_length::

(act.len(), |size| size == 1)?; match act.next() { Some(eg) => { - let entity_group = parser::Entity::from_slice(eg)?; + let entity_group = parser::Entity::from_slice::

(eg)?; if registry::state_okay() { - handle.drop_table(entity_group)?; - con._write_raw(groups::OKAY).await?; + translate_ddl_error::(handle.drop_table(entity_group))?; + con._write_raw(P::RCODE_OKAY).await?; } else { - return util::err(groups::SERVER_ERR); + return util::err(P::RCODE_SERVER_ERR); } }, - None => return util::err(groups::ACTION_ERR), + None => return util::err(P::RCODE_ACTION_ERR), } Ok(()) } /// Drop a keyspace (`` only) fn drop_keyspace(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |size| size == 1)?; + ensure_length::

(act.len(), |size| size == 1)?; match act.next() { Some(ksid) => { - ensure_cond_or_err(ksid.len() < 64, responses::groups::CONTAINER_NAME_TOO_LONG)?; + ensure_cond_or_err(ksid.len() < 64, P::RSTRING_CONTAINER_NAME_TOO_LONG)?; let force_remove = match act.next() { Some(bts) if bts.eq(FORCE_REMOVE) => true, None => false, _ => { - return util::err(responses::groups::UNKNOWN_ACTION); + return util::err(P::RCODE_UNKNOWN_ACTION); } }; if registry::state_okay() { @@ -154,13 +154,13 @@ action! { } else { handle.drop_keyspace(objid) }; - result?; - con._write_raw(groups::OKAY).await?; + translate_ddl_error::(result)?; + con._write_raw(P::RCODE_OKAY).await?; } else { - return util::err(groups::SERVER_ERR); + return util::err(P::RCODE_SERVER_ERR); } }, - None => return util::err(groups::ACTION_ERR), + None => return util::err(P::RCODE_ACTION_ERR), } Ok(()) } diff --git a/server/src/queryengine/inspect.rs b/server/src/queryengine/inspect.rs index 25768d31..365dce2d 100644 --- a/server/src/queryengine/inspect.rs +++ b/server/src/queryengine/inspect.rs @@ -25,7 +25,10 @@ */ use super::ddl::{KEYSPACE, TABLE}; -use crate::corestore::memstore::ObjectID; +use crate::corestore::{ + memstore::{Keyspace, ObjectID}, + table::Table, +}; use crate::dbnet::connection::prelude::*; const KEYSPACES: &[u8] = "KEYSPACES".as_bytes(); @@ -44,7 +47,7 @@ action! { KEYSPACE => inspect_keyspace(handle, con, act).await?, TABLE => inspect_table(handle, con, act).await?, KEYSPACES => { - ensure_length(act.len(), |len| len == 0)?; + ensure_length::

(act.len(), |len| len == 0)?; // let's return what all keyspaces exist let ks_list: Vec = handle .get_store() @@ -57,35 +60,35 @@ action! { con.write_typed_non_null_array_element(&ks).await?; } } - _ => return util::err(groups::UNKNOWN_INSPECT_QUERY), + _ => return util::err(P::RSTRING_UNKNOWN_INSPECT_QUERY), } } - None => return util::err(groups::ACTION_ERR), + None => return util::err(P::RCODE_ACTION_ERR), } Ok(()) } /// INSPECT a keyspace. This should only have the keyspace ID fn inspect_keyspace(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |len| len < 2)?; + ensure_length::

(act.len(), |len| len < 2)?; let tbl_list: Vec = match act.next() { Some(keyspace_name) => { // inspect the provided keyspace let ksid = if keyspace_name.len() > 64 { - return util::err(groups::BAD_CONTAINER_NAME); + return util::err(P::RSTRING_BAD_CONTAINER_NAME); } else { keyspace_name }; let ks = match handle.get_keyspace(ksid) { Some(kspace) => kspace, - None => return util::err(groups::CONTAINER_NOT_FOUND), + None => return util::err(P::RSTRING_CONTAINER_NOT_FOUND), }; ks.tables.iter().map(|kv| kv.key().clone()).collect() }, None => { // inspect the current keyspace - let cks = handle.get_cks()?; + let cks = translate_ddl_error::(handle.get_cks())?; cks.tables.iter().map(|kv| kv.key().clone()).collect() }, }; @@ -98,7 +101,7 @@ action! { /// INSPECT a table. This should only have the table ID fn inspect_table(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |len| len < 2)?; + ensure_length::

(act.len(), |len| len < 2)?; match act.next() { Some(entity) => { let entity = handle_entity!(con, entity); @@ -106,7 +109,7 @@ action! { }, None => { // inspect the current table - let tbl = handle.get_table_result()?; + let tbl = translate_ddl_error::(handle.get_table_result())?; con.write_string(tbl.describe_self()).await?; }, } diff --git a/server/src/queryengine/mod.rs b/server/src/queryengine/mod.rs index 14d4a738..b18fe03b 100644 --- a/server/src/queryengine/mod.rs +++ b/server/src/queryengine/mod.rs @@ -58,7 +58,7 @@ macro_rules! gen_constants_and_matches { pub const $action2: &[u8] = stringify!($action2).as_bytes(); )* } - let first = $buf.next_uppercase().unwrap_or_custom_aerr(groups::PACKET_ERR)?; + let first = $buf.next_uppercase().unwrap_or_custom_aerr(P::RCODE_PACKET_ERR)?; match first.as_ref() { $( tags::$action => $fns($db, $con, $buf).await?, @@ -67,7 +67,7 @@ macro_rules! gen_constants_and_matches { tags::$action2 => $fns2.await?, )* _ => { - $con._write_raw(groups::UNKNOWN_ACTION).await?; + $con._write_raw(P::RCODE_UNKNOWN_ACTION).await?; } } }; @@ -87,7 +87,7 @@ action! { // won't suddenly become invalid AnyArrayIter::new(bufref.iter()) }; - match iter.next_lowercase().unwrap_or_custom_aerr(groups::PACKET_ERR)?.as_ref() { + match iter.next_lowercase().unwrap_or_custom_aerr(P::RCODE_PACKET_ERR)?.as_ref() { ACTION_AUTH => auth::auth_login_only(con, auth, iter).await, _ => util::err(auth::errors::AUTH_CODE_BAD_CREDENTIALS), } @@ -158,13 +158,13 @@ async fn execute_stage<'a, P: ProtocolSpec, T: 'a + ClientConnection, S action! { /// Handle `use ` like queries fn entity_swap(handle: &mut Corestore, con: &mut T, mut act: ActionIter<'a>) { - ensure_length(act.len(), |len| len == 1)?; + ensure_length::

(act.len(), |len| len == 1)?; let entity = unsafe { // SAFETY: Already checked len act.next_unchecked() }; - handle.swap_entity(Entity::from_slice(entity)?)?; - con._write_raw(groups::OKAY).await?; + translate_ddl_error::(handle.swap_entity(Entity::from_slice::

(entity)?))?; + con._write_raw(P::RCODE_OKAY).await?; Ok(()) } } diff --git a/server/src/queryengine/parser.rs b/server/src/queryengine/parser.rs index 0e0eda61..bcb2a269 100644 --- a/server/src/queryengine/parser.rs +++ b/server/src/queryengine/parser.rs @@ -26,7 +26,7 @@ use crate::corestore::{lazy::Lazy, memstore::ObjectID}; use crate::kvengine::encoding; -use crate::protocol::responses; +use crate::queryengine::ProtocolSpec; use crate::util::{ self, compiler::{self, cold_err}, @@ -47,20 +47,20 @@ pub(super) static VALID_CONTAINER_NAME: LazyRegexFn = pub(super) static VALID_TYPENAME: LazyRegexFn = LazyRegexFn::new(|| Regex::new("^<[a-zA-Z][a-zA-Z0-9]+[^>\\s]?>{1}$").unwrap()); -pub(super) fn parse_table_args<'a>( +pub(super) fn parse_table_args<'a, P: ProtocolSpec>( table_name: &'a [u8], model_name: &'a [u8], ) -> Result<(Entity<'a>, u8), &'static [u8]> { if compiler::unlikely(!encoding::is_utf8(&table_name) || !encoding::is_utf8(&model_name)) { - return Err(responses::groups::ENCODING_ERROR); + return Err(P::RCODE_ENCODING_ERROR); } let model_name_str = unsafe { str::from_utf8_unchecked(model_name) }; // get the entity group - let entity_group = Entity::from_slice(table_name)?; + let entity_group = Entity::from_slice::

(table_name)?; let splits: Vec<&str> = model_name_str.split('(').collect(); if compiler::unlikely(splits.len() != 2) { - return Err(responses::groups::BAD_EXPRESSION); + return Err(P::RSTRING_BAD_EXPRESSION); } let model_name_split = unsafe { ucidx!(splits, 0) }; @@ -69,19 +69,19 @@ pub(super) fn parse_table_args<'a>( // model name has to have at least one char while model args should have // atleast `)` 1 chars (for example if the model takes no arguments: `smh()`) if compiler::unlikely(model_name_split.is_empty() || model_args_split.is_empty()) { - return Err(responses::groups::BAD_EXPRESSION); + return Err(P::RSTRING_BAD_EXPRESSION); } // THIS IS WHERE WE HANDLE THE NEWER MODELS if model_name_split.as_bytes() != KEYMAP { - return Err(responses::groups::UNKNOWN_MODEL); + return Err(P::RSTRING_UNKNOWN_MODEL); } let non_bracketed_end = unsafe { ucidx!(*model_args_split.as_bytes(), model_args_split.len() - 1) != b')' }; if compiler::unlikely(non_bracketed_end) { - return Err(responses::groups::BAD_EXPRESSION); + return Err(P::RSTRING_BAD_EXPRESSION); } // should be (ty1, ty2) @@ -96,10 +96,10 @@ pub(super) fn parse_table_args<'a>( let all_nonzero = model_args.into_iter().all(|v| !v.is_empty()); if all_nonzero { // arg fun - Err(responses::groups::TOO_MANY_ARGUMENTS) + Err(P::RSTRING_TOO_MANY_ARGUMENTS) } else { // comma fun - Err(responses::groups::BAD_EXPRESSION) + Err(P::RSTRING_BAD_EXPRESSION) } }); } @@ -116,7 +116,7 @@ pub(super) fn parse_table_args<'a>( VALID_CONTAINER_NAME.is_match(val_ty) }; if compiler::unlikely(!(valid_key_ty || valid_val_ty)) { - return Err(responses::groups::BAD_EXPRESSION); + return Err(P::RSTRING_BAD_EXPRESSION); } let key_ty = key_ty.as_bytes(); let val_ty = val_ty.as_bytes(); @@ -132,8 +132,8 @@ pub(super) fn parse_table_args<'a>( (STR, LIST_BINSTR) => 6, (STR, LIST_STR) => 7, // KVExt bad keytypes (we can't use lists as keys for obvious reasons) - (LIST_STR, _) | (LIST_BINSTR, _) => return Err(responses::groups::BAD_TYPE_FOR_KEY), - _ => return Err(responses::groups::UNKNOWN_DATA_TYPE), + (LIST_STR, _) | (LIST_BINSTR, _) => return Err(P::RSTRING_BAD_TYPE_FOR_KEY), + _ => return Err(P::RCODE_UNKNOWN_DATA_TYPE), }; Ok((entity_group, model_code)) } @@ -179,29 +179,29 @@ impl<'a> fmt::Debug for Entity<'a> { } impl<'a> Entity<'a> { - pub fn from_slice(input: ByteSlice<'a>) -> Result, &'static [u8]> { + pub fn from_slice(input: ByteSlice<'a>) -> Result, &'static [u8]> { let parts: Vec<&[u8]> = input.split(|b| *b == b':').collect(); if compiler::unlikely(parts.is_empty() || parts.len() > 2) { - return util::err(responses::groups::BAD_EXPRESSION); + return util::err(P::RSTRING_BAD_EXPRESSION); } // just the table let first_entity = unsafe { ucidx!(parts, 0) }; if parts.len() == 1 { - Ok(Entity::Single(Self::verify_entity_name(first_entity)?)) + Ok(Entity::Single(Self::verify_entity_name::

(first_entity)?)) } else { - let second_entity = Self::verify_entity_name(unsafe { ucidx!(parts, 1) })?; + let second_entity = Self::verify_entity_name::

(unsafe { ucidx!(parts, 1) })?; if first_entity.is_empty() { // partial syntax; so the table is in the second position Ok(Entity::Partial(second_entity)) } else { - let keyspace = Self::verify_entity_name(first_entity)?; - let table = Self::verify_entity_name(second_entity)?; + let keyspace = Self::verify_entity_name::

(first_entity)?; + let table = Self::verify_entity_name::

(second_entity)?; Ok(Entity::Full(keyspace, table)) } } } #[inline(always)] - fn verify_entity_name(input: &[u8]) -> Result<&[u8], &'static [u8]> { + fn verify_entity_name(input: &[u8]) -> Result<&[u8], &'static [u8]> { let mut valid_name = input.len() < 65 && encoding::is_utf8(input) && unsafe { VALID_CONTAINER_NAME.is_match(str::from_utf8_unchecked(input)) }; @@ -220,13 +220,13 @@ impl<'a> Entity<'a> { Ok(input) } else if compiler::unlikely(input.is_empty()) { // bad expression (something like `:`) - util::err(responses::groups::BAD_EXPRESSION) + util::err(P::RSTRING_BAD_EXPRESSION) } else if compiler::unlikely(input.eq(b"system")) { // system cannot be switched to - util::err(responses::groups::PROTECTED_OBJECT) + util::err(P::RSTRING_PROTECTED_OBJECT) } else { // the container has a bad name - util::err(responses::groups::BAD_CONTAINER_NAME) + util::err(P::RSTRING_BAD_CONTAINER_NAME) } } pub fn as_owned(&self) -> OwnedEntity { diff --git a/server/src/queryengine/tests.rs b/server/src/queryengine/tests.rs index 5bcdf9e8..c6b6e7d1 100644 --- a/server/src/queryengine/tests.rs +++ b/server/src/queryengine/tests.rs @@ -28,6 +28,8 @@ use super::parser; mod parser_ddl_tests { use super::parser::Entity; + use crate::protocol::interface::ProtocolSpec; + use crate::protocol::Skyhash2; macro_rules! byvec { ($($element:expr),*) => { vec![ @@ -38,9 +40,8 @@ mod parser_ddl_tests { }; } fn parse_table_args_test(input: Vec<&'static [u8]>) -> Result<(Entity<'_>, u8), &'static [u8]> { - super::parser::parse_table_args(input[0], input[1]) + super::parser::parse_table_args::(input[0], input[1]) } - use crate::protocol::responses; #[test] fn test_table_args_valid() { // binstr, binstr @@ -97,12 +98,12 @@ mod parser_ddl_tests { let it = byvec!("1one", "keymap(binstr,binstr)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_CONTAINER_NAME + Skyhash2::RSTRING_BAD_CONTAINER_NAME ); let it = byvec!("%whywouldsomeone", "keymap(binstr,binstr)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_CONTAINER_NAME + Skyhash2::RSTRING_BAD_CONTAINER_NAME ); } #[test] @@ -133,22 +134,22 @@ mod parser_ddl_tests { let it = byvec!("mycooltbl", "keymap(wth, str)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::UNKNOWN_DATA_TYPE + Skyhash2::RCODE_UNKNOWN_DATA_TYPE ); let it = byvec!("mycooltbl", "keymap(wth, wth)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::UNKNOWN_DATA_TYPE + Skyhash2::RCODE_UNKNOWN_DATA_TYPE ); let it = byvec!("mycooltbl", "keymap(str, wth)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::UNKNOWN_DATA_TYPE + Skyhash2::RCODE_UNKNOWN_DATA_TYPE ); let it = byvec!("mycooltbl", "keymap(wth1, wth2)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::UNKNOWN_DATA_TYPE + Skyhash2::RCODE_UNKNOWN_DATA_TYPE ); } #[test] @@ -156,17 +157,17 @@ mod parser_ddl_tests { let it = byvec!("mycooltbl", "wthmap(wth, wth)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::UNKNOWN_MODEL + Skyhash2::RSTRING_UNKNOWN_MODEL ); let it = byvec!("mycooltbl", "wthmap(str, str)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::UNKNOWN_MODEL + Skyhash2::RSTRING_UNKNOWN_MODEL ); let it = byvec!("mycooltbl", "wthmap()"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::UNKNOWN_MODEL + Skyhash2::RSTRING_UNKNOWN_MODEL ); } #[test] @@ -174,82 +175,82 @@ mod parser_ddl_tests { let it = byvec!("mycooltbl", "keymap("); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_EXPRESSION + Skyhash2::RSTRING_BAD_EXPRESSION ); let it = byvec!("mycooltbl", "keymap(,"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_EXPRESSION + Skyhash2::RSTRING_BAD_EXPRESSION ); let it = byvec!("mycooltbl", "keymap(,,"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_EXPRESSION + Skyhash2::RSTRING_BAD_EXPRESSION ); let it = byvec!("mycooltbl", "keymap),"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_EXPRESSION + Skyhash2::RSTRING_BAD_EXPRESSION ); let it = byvec!("mycooltbl", "keymap),,"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_EXPRESSION + Skyhash2::RSTRING_BAD_EXPRESSION ); let it = byvec!("mycooltbl", "keymap),,)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_EXPRESSION + Skyhash2::RSTRING_BAD_EXPRESSION ); let it = byvec!("mycooltbl", "keymap(,)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_EXPRESSION + Skyhash2::RSTRING_BAD_EXPRESSION ); let it = byvec!("mycooltbl", "keymap(,,)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_EXPRESSION + Skyhash2::RSTRING_BAD_EXPRESSION ); let it = byvec!("mycooltbl", "keymap,"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_EXPRESSION + Skyhash2::RSTRING_BAD_EXPRESSION ); let it = byvec!("mycooltbl", "keymap,,"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_EXPRESSION + Skyhash2::RSTRING_BAD_EXPRESSION ); let it = byvec!("mycooltbl", "keymap,,)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_EXPRESSION + Skyhash2::RSTRING_BAD_EXPRESSION ); let it = byvec!("mycooltbl", "keymap(str,"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_EXPRESSION + Skyhash2::RSTRING_BAD_EXPRESSION ); let it = byvec!("mycooltbl", "keymap(str,str"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_EXPRESSION + Skyhash2::RSTRING_BAD_EXPRESSION ); let it = byvec!("mycooltbl", "keymap(str,str,"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_EXPRESSION + Skyhash2::RSTRING_BAD_EXPRESSION ); let it = byvec!("mycooltbl", "keymap(str,str,)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_EXPRESSION + Skyhash2::RSTRING_BAD_EXPRESSION ); let it = byvec!("mycooltbl", "keymap(str,str,),"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_EXPRESSION + Skyhash2::RSTRING_BAD_EXPRESSION ); } @@ -258,14 +259,14 @@ mod parser_ddl_tests { let it = byvec!("mycooltbl", "keymap(str, str, str)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::TOO_MANY_ARGUMENTS + Skyhash2::RSTRING_TOO_MANY_ARGUMENTS ); // this should be valid for not-yet-known data types too let it = byvec!("mycooltbl", "keymap(wth, wth, wth)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::TOO_MANY_ARGUMENTS + Skyhash2::RSTRING_TOO_MANY_ARGUMENTS ); } @@ -274,86 +275,96 @@ mod parser_ddl_tests { let it = byvec!("myverycooltbl", "keymap(list, str)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_TYPE_FOR_KEY + Skyhash2::RSTRING_BAD_TYPE_FOR_KEY ); let it = byvec!("myverycooltbl", "keymap(list, binstr)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_TYPE_FOR_KEY + Skyhash2::RSTRING_BAD_TYPE_FOR_KEY ); // for consistency checks let it = byvec!("myverycooltbl", "keymap(list, binstr)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_TYPE_FOR_KEY + Skyhash2::RSTRING_BAD_TYPE_FOR_KEY ); let it = byvec!("myverycooltbl", "keymap(list, str)"); assert_eq!( parse_table_args_test(it).unwrap_err(), - responses::groups::BAD_TYPE_FOR_KEY + Skyhash2::RSTRING_BAD_TYPE_FOR_KEY ); } } mod entity_parser_tests { use super::parser::Entity; - use crate::protocol::responses; + use crate::protocol::interface::ProtocolSpec; + use crate::protocol::Skyhash2; #[test] fn test_query_full_entity_okay() { let x = byt!("ks:tbl"); - assert_eq!(Entity::from_slice(&x).unwrap(), Entity::Full(b"ks", b"tbl")); + assert_eq!( + Entity::from_slice::(&x).unwrap(), + Entity::Full(b"ks", b"tbl") + ); } #[test] fn test_query_half_entity() { let x = byt!("tbl"); - assert_eq!(Entity::from_slice(&x).unwrap(), Entity::Single(b"tbl")) + assert_eq!( + Entity::from_slice::(&x).unwrap(), + Entity::Single(b"tbl") + ) } #[test] fn test_query_partial_entity() { let x = byt!(":tbl"); - assert_eq!(Entity::from_slice(&x).unwrap(), Entity::Partial(b"tbl")) + assert_eq!( + Entity::from_slice::(&x).unwrap(), + Entity::Partial(b"tbl") + ) } #[test] fn test_query_entity_badexpr() { let x = byt!("ks:"); assert_eq!( - Entity::from_slice(&x).unwrap_err(), - responses::groups::BAD_EXPRESSION + Entity::from_slice::(&x).unwrap_err(), + Skyhash2::RSTRING_BAD_EXPRESSION ); let x = byt!(":"); assert_eq!( - Entity::from_slice(&x).unwrap_err(), - responses::groups::BAD_EXPRESSION + Entity::from_slice::(&x).unwrap_err(), + Skyhash2::RSTRING_BAD_EXPRESSION ); let x = byt!("::"); assert_eq!( - Entity::from_slice(&x).unwrap_err(), - responses::groups::BAD_EXPRESSION + Entity::from_slice::(&x).unwrap_err(), + Skyhash2::RSTRING_BAD_EXPRESSION ); let x = byt!("::ks"); assert_eq!( - Entity::from_slice(&x).unwrap_err(), - responses::groups::BAD_EXPRESSION + Entity::from_slice::(&x).unwrap_err(), + Skyhash2::RSTRING_BAD_EXPRESSION ); let x = byt!("ks::tbl"); assert_eq!( - Entity::from_slice(&x).unwrap_err(), - responses::groups::BAD_EXPRESSION + Entity::from_slice::(&x).unwrap_err(), + Skyhash2::RSTRING_BAD_EXPRESSION ); let x = byt!("ks::"); assert_eq!( - Entity::from_slice(&x).unwrap_err(), - responses::groups::BAD_EXPRESSION + Entity::from_slice::(&x).unwrap_err(), + Skyhash2::RSTRING_BAD_EXPRESSION ); let x = byt!("ks::tbl::"); assert_eq!( - Entity::from_slice(&x).unwrap_err(), - responses::groups::BAD_EXPRESSION + Entity::from_slice::(&x).unwrap_err(), + Skyhash2::RSTRING_BAD_EXPRESSION ); let x = byt!("::ks::tbl::"); assert_eq!( - Entity::from_slice(&x).unwrap_err(), - responses::groups::BAD_EXPRESSION + Entity::from_slice::(&x).unwrap_err(), + Skyhash2::RSTRING_BAD_EXPRESSION ); } @@ -361,21 +372,21 @@ mod entity_parser_tests { fn test_bad_entity_name() { let ename = byt!("$var"); assert_eq!( - Entity::from_slice(&ename).unwrap_err(), - responses::groups::BAD_CONTAINER_NAME + Entity::from_slice::(&ename).unwrap_err(), + Skyhash2::RSTRING_BAD_CONTAINER_NAME ); } #[test] fn ks_or_table_with_preload_or_partmap() { let badname = byt!("PARTMAP"); assert_eq!( - Entity::from_slice(&badname).unwrap_err(), - responses::groups::BAD_CONTAINER_NAME + Entity::from_slice::(&badname).unwrap_err(), + Skyhash2::RSTRING_BAD_CONTAINER_NAME ); let badname = byt!("PRELOAD"); assert_eq!( - Entity::from_slice(&badname).unwrap_err(), - responses::groups::BAD_CONTAINER_NAME + Entity::from_slice::(&badname).unwrap_err(), + Skyhash2::RSTRING_BAD_CONTAINER_NAME ); } } diff --git a/server/src/util/mod.rs b/server/src/util/mod.rs index d7fafcad..efd62261 100644 --- a/server/src/util/mod.rs +++ b/server/src/util/mod.rs @@ -27,10 +27,10 @@ #[macro_use] mod macros; pub mod compiler; -pub mod os; pub mod error; +pub mod os; use crate::actions::{ActionError, ActionResult}; -use crate::protocol::responses::groups; +use crate::protocol::interface::ProtocolSpec; use core::fmt::Debug; use core::future::Future; use core::ops::Deref; @@ -79,15 +79,15 @@ unsafe impl Unwrappable for Option { pub trait UnwrapActionError { fn unwrap_or_custom_aerr(self, e: impl Into) -> ActionResult; - fn unwrap_or_aerr(self) -> ActionResult; + fn unwrap_or_aerr(self) -> ActionResult; } impl UnwrapActionError for Option { fn unwrap_or_custom_aerr(self, e: impl Into) -> ActionResult { self.ok_or_else(|| e.into()) } - fn unwrap_or_aerr(self) -> ActionResult { - self.ok_or_else(|| groups::ACTION_ERR.into()) + fn unwrap_or_aerr(self) -> ActionResult { + self.ok_or_else(|| P::RCODE_ACTION_ERR.into()) } } From 7ec599edcbc8e552d5f33699778af0818c97a1ae Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Sat, 30 Apr 2022 06:58:32 -0700 Subject: [PATCH 08/13] Update bench suite for protocol Turns out that the original PR implementing Skyhash 2.0 did not update the benchmark code. --- server/src/actions/macros.rs | 2 +- server/src/auth/provider.rs | 4 +- server/src/dbnet/connection.rs | 6 +-- server/src/protocol/benches.rs | 79 ----------------------------- server/src/protocol/iter.rs | 22 ++++---- server/src/protocol/mod.rs | 27 ++++------ server/src/protocol/v2/benches.rs | 80 +++++++++++++++++++++++++++++ server/src/protocol/v2/mod.rs | 7 ++- server/src/protocol/v2/tests.rs | 84 ++++++++++++------------------- 9 files changed, 141 insertions(+), 170 deletions(-) delete mode 100644 server/src/protocol/benches.rs create mode 100644 server/src/protocol/v2/benches.rs diff --git a/server/src/actions/macros.rs b/server/src/actions/macros.rs index 3acd02a7..f502c86a 100644 --- a/server/src/actions/macros.rs +++ b/server/src/actions/macros.rs @@ -49,7 +49,7 @@ macro_rules! is_lowbit_unset { #[macro_export] macro_rules! get_tbl { ($entity:expr, $store:expr, $con:expr) => {{ - $crate::actions::translate_ddl_error::>( + $crate::actions::translate_ddl_error::>( $store.get_table($entity), )? }}; diff --git a/server/src/auth/provider.rs b/server/src/auth/provider.rs index 6223e294..e2a9600f 100644 --- a/server/src/auth/provider.rs +++ b/server/src/auth/provider.rs @@ -35,14 +35,12 @@ pub const AUTHKEY_SIZE: usize = 40; /// Size of an authn ID in bytes pub const AUTHID_SIZE: usize = 40; -#[cfg(debug_assertions)] pub mod testsuite_data { + #![allow(unused)] //! Temporary users created by the testsuite in debug mode pub const TESTSUITE_ROOT_USER: &str = "root"; pub const TESTSUITE_TEST_USER: &str = "testuser"; - #[cfg(test)] pub const TESTSUITE_ROOT_TOKEN: &str = "XUOdVKhEONnnGwNwT7WeLqbspDgVtKex0/nwFwBSW7XJxioHwpg6H."; - #[cfg(all(not(feature = "persist-suite"), test))] pub const TESTSUITE_TEST_TOKEN: &str = "mpobAB7EY8vnBs70d/..h1VvfinKIeEJgt1rg4wUkwF6aWCvGGR9le"; } diff --git a/server/src/dbnet/connection.rs b/server/src/dbnet/connection.rs index 339e18b7..eac2f7d6 100644 --- a/server/src/dbnet/connection.rs +++ b/server/src/dbnet/connection.rs @@ -113,7 +113,7 @@ pub mod prelude { Corestore, }, get_tbl, handle_entity, is_lowbit_set, - protocol::interface::{ProtocolRead, ProtocolSpec}, + protocol::interface::ProtocolSpec, queryengine::ActionIter, registry, util::{self, FutureResult, UnwrapActionError, Unwrappable}, @@ -124,8 +124,8 @@ pub mod prelude { /// # The `RawConnection` trait /// /// The `RawConnection` trait has low-level methods that can be used to interface with raw sockets. Any type -/// that successfully implements this trait will get an implementation for `ProtocolRead` which augments and -/// builds on these fundamental methods to provide high-level interfacing with queries. +/// that successfully implements this trait will get an implementation for `ProtocolRead` and `ProtocolWrite` +/// provided that it uses a protocol that implements the `ProtocolSpec` trait. /// /// ## Example of a `RawConnection` object /// Ideally a `RawConnection` object should look like (the generic parameter just exists for doc-tests, just think that diff --git a/server/src/protocol/benches.rs b/server/src/protocol/benches.rs deleted file mode 100644 index f9839349..00000000 --- a/server/src/protocol/benches.rs +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Created on Tue Nov 02 2021 - * - * This file is a part of Skytable - * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source - * NoSQL database written by Sayan Nandan ("the Author") with the - * vision to provide flexibility in data modelling without compromising - * on performance, queryability or scalability. - * - * Copyright (c) 2021, Sayan Nandan - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - * -*/ - -/* - Do note that the result of the benches might actually be slower, than faster! The reason it is so, is simply because of - the fact that we generate owned queries, by copying bytes which adds an overhead, but offers simplicity in writing tests - and/or benches -*/ - -extern crate test; -use super::{element::OwnedElement, OwnedQuery, Parser}; -use bytes::Bytes; -use test::Bencher; - -#[bench] -fn bench_simple_query_string(b: &mut Bencher) { - const PAYLOAD: &[u8] = b"*1\n+5\nsayan\n"; - unsafe { - b.iter(|| { - assert_eq!( - Parser::new(PAYLOAD).parse().unwrap().0.into_owned_query(), - OwnedQuery::SimpleQuery(OwnedElement::String(Bytes::from("sayan"))) - ); - }) - } -} - -#[bench] -fn bench_simple_query_uint(b: &mut Bencher) { - const PAYLOAD: &[u8] = b"*1\n:5\n12345\n"; - unsafe { - b.iter(|| { - assert_eq!( - Parser::new(PAYLOAD).parse().unwrap().0.into_owned_query(), - OwnedQuery::SimpleQuery(OwnedElement::UnsignedInt(12345)) - ); - }) - } -} - -#[bench] -fn bench_simple_query_any_array(b: &mut Bencher) { - const PAYLOAD: &[u8] = b"*1\n~3\n3\nthe\n3\ncat\n6\nmeowed\n"; - unsafe { - b.iter(|| { - assert_eq!( - Parser::new(PAYLOAD).parse().unwrap().0.into_owned_query(), - OwnedQuery::SimpleQuery(OwnedElement::AnyArray(vec![ - "the".into(), - "cat".into(), - "meowed".into() - ])) - ) - }) - } -} diff --git a/server/src/protocol/iter.rs b/server/src/protocol/iter.rs index f34f1e01..172f21fc 100644 --- a/server/src/protocol/iter.rs +++ b/server/src/protocol/iter.rs @@ -75,16 +75,14 @@ impl<'a> AnyArrayIter<'a> { } /// Returns the next value in uppercase pub fn next_uppercase(&mut self) -> Option> { - self.iter.next().map(|v| unsafe { - // SAFETY: Only construction is unsafe, forwarding is not - v.as_slice().to_ascii_uppercase().into_boxed_slice() - }) + self.iter + .next() + .map(|v| v.as_slice().to_ascii_uppercase().into_boxed_slice()) } pub fn next_lowercase(&mut self) -> Option> { - self.iter.next().map(|v| unsafe { - // SAFETY: Only construction is unsafe, forwarding is not - v.as_slice().to_ascii_lowercase().into_boxed_slice() - }) + self.iter + .next() + .map(|v| v.as_slice().to_ascii_lowercase().into_boxed_slice()) } pub unsafe fn next_lowercase_unchecked(&mut self) -> Box<[u8]> { self.next_lowercase().unwrap_or_else(|| impossible!()) @@ -143,7 +141,7 @@ unsafe impl DerefUnsafeSlice for Bytes { impl<'a> Iterator for AnyArrayIter<'a> { type Item = &'a [u8]; fn next(&mut self) -> Option { - self.iter.next().map(|v| unsafe { v.as_slice() }) + self.iter.next().map(|v| v.as_slice()) } fn size_hint(&self) -> (usize, Option) { self.iter.size_hint() @@ -152,7 +150,7 @@ impl<'a> Iterator for AnyArrayIter<'a> { impl<'a> DoubleEndedIterator for AnyArrayIter<'a> { fn next_back(&mut self) -> Option<::Item> { - self.iter.next_back().map(|v| unsafe { v.as_slice() }) + self.iter.next_back().map(|v| v.as_slice()) } } @@ -162,13 +160,13 @@ impl<'a> FusedIterator for AnyArrayIter<'a> {} impl<'a> Iterator for BorrowedAnyArrayIter<'a> { type Item = &'a [u8]; fn next(&mut self) -> Option { - self.iter.next().map(|v| unsafe { v.as_slice() }) + self.iter.next().map(|v| v.as_slice()) } } impl<'a> DoubleEndedIterator for BorrowedAnyArrayIter<'a> { fn next_back(&mut self) -> Option<::Item> { - self.iter.next_back().map(|v| unsafe { v.as_slice() }) + self.iter.next_back().map(|v| v.as_slice()) } } diff --git a/server/src/protocol/mod.rs b/server/src/protocol/mod.rs index 964c48fb..f42ecaca 100644 --- a/server/src/protocol/mod.rs +++ b/server/src/protocol/mod.rs @@ -26,8 +26,6 @@ use crate::corestore::heap_array::HeapArray; use core::{fmt, slice}; -#[cfg(feature = "nightly")] -mod benches; // pub mods pub mod interface; pub mod iter; @@ -73,8 +71,13 @@ impl UnsafeSlice { Self { start_ptr, len } } /// Return self as a slice - pub unsafe fn as_slice(&self) -> &[u8] { - slice::from_raw_parts(self.start_ptr, self.len) + pub fn as_slice(&self) -> &[u8] { + unsafe { + // UNSAFE(@ohsayan): Just like core::slice, we resemble the same idea: + // we assume that the unsafe construction was correct and hence *assume* + // that calling this is safe + slice::from_raw_parts(self.start_ptr, self.len) + } } } @@ -115,11 +118,7 @@ impl SimpleQuery { #[cfg(test)] fn into_owned(self) -> OwnedSimpleQuery { OwnedSimpleQuery { - data: self - .data - .iter() - .map(|v| unsafe { v.as_slice().to_owned() }) - .collect(), + data: self.data.iter().map(|v| v.as_slice().to_owned()).collect(), } } pub fn as_slice(&self) -> &[UnsafeSlice] { @@ -129,7 +128,7 @@ impl SimpleQuery { #[cfg(test)] struct OwnedSimpleQuery { - data: Vec>, + pub data: Vec>, } #[derive(Debug)] @@ -150,11 +149,7 @@ impl PipelinedQuery { data: self .data .iter() - .map(|v| { - v.iter() - .map(|v| unsafe { v.as_slice().to_owned() }) - .collect() - }) + .map(|v| v.iter().map(|v| v.as_slice().to_owned()).collect()) .collect(), } } @@ -162,5 +157,5 @@ impl PipelinedQuery { #[cfg(test)] struct OwnedPipelinedQuery { - data: Vec>>, + pub data: Vec>>, } diff --git a/server/src/protocol/v2/benches.rs b/server/src/protocol/v2/benches.rs new file mode 100644 index 00000000..c1e195f4 --- /dev/null +++ b/server/src/protocol/v2/benches.rs @@ -0,0 +1,80 @@ +/* + * Created on Sat Apr 30 2022 + * + * This file is a part of Skytable + * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source + * NoSQL database written by Sayan Nandan ("the Author") with the + * vision to provide flexibility in data modelling without compromising + * on performance, queryability or scalability. + * + * Copyright (c) 2022, Sayan Nandan + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + * +*/ + +extern crate test; +use super::{super::Query, Parser}; +use test::Bencher; + +#[bench] +fn simple_query(b: &mut Bencher) { + const PAYLOAD: &[u8] = b"*3\n3\nSET1\nx3\n100"; + let expected = vec!["SET".to_owned(), "x".to_owned(), "100".to_owned()]; + b.iter(|| { + let (query, forward) = Parser::parse(PAYLOAD).unwrap(); + assert_eq!(forward, PAYLOAD.len()); + let query = if let Query::Simple(sq) = query { + sq + } else { + panic!("Got pipeline instead of simple query"); + }; + let ret: Vec = query + .as_slice() + .iter() + .map(|s| String::from_utf8_lossy(s.as_slice()).to_string()) + .collect(); + assert_eq!(ret, expected) + }); +} + +#[bench] +fn pipelined_query(b: &mut Bencher) { + const PAYLOAD: &[u8] = b"$2\n3\n3\nSET1\nx3\n1002\n3\nGET1\nx"; + let expected = vec![ + vec!["SET".to_owned(), "x".to_owned(), "100".to_owned()], + vec!["GET".to_owned(), "x".to_owned()], + ]; + b.iter(|| { + let (query, forward) = Parser::parse(PAYLOAD).unwrap(); + assert_eq!(forward, PAYLOAD.len()); + let query = if let Query::Pipelined(sq) = query { + sq + } else { + panic!("Got simple instead of pipeline query"); + }; + let ret: Vec> = query + .into_inner() + .iter() + .map(|query| { + query + .as_slice() + .iter() + .map(|v| String::from_utf8_lossy(v.as_slice()).to_string()) + .collect() + }) + .collect(); + assert_eq!(ret, expected) + }); +} diff --git a/server/src/protocol/v2/mod.rs b/server/src/protocol/v2/mod.rs index 1a358ece..8d533e88 100644 --- a/server/src/protocol/v2/mod.rs +++ b/server/src/protocol/v2/mod.rs @@ -31,6 +31,8 @@ use crate::{ protocol::{ParseError, ParseResult, PipelinedQuery, Query, SimpleQuery, UnsafeSlice}, }; use core::mem::transmute; +#[cfg(feature = "nightly")] +mod benches; #[cfg(test)] mod tests; @@ -155,10 +157,7 @@ impl Parser { /// Attempt to read an `usize` from the buffer fn read_usize(&mut self) -> ParseResult { let line = self.read_line_pedantic()?; - let bytes = unsafe { - // UNSAFE(@ohsayan): We just extracted the slice - line.as_slice() - }; + let bytes = line.as_slice(); let mut ret = 0usize; for byte in bytes { if byte.is_ascii_digit() { diff --git a/server/src/protocol/v2/tests.rs b/server/src/protocol/v2/tests.rs index 8046ed8b..5352f96b 100644 --- a/server/src/protocol/v2/tests.rs +++ b/server/src/protocol/v2/tests.rs @@ -67,11 +67,9 @@ fn get_slices(slices: &[&[u8]]) -> Packets { fn ensure_zero_reads(parser: &mut Parser) { let r = parser.read_until(0).unwrap(); - unsafe { - let slice = r.as_slice(); - assert_eq!(slice, b""); - assert!(slice.is_empty()); - } + let slice = r.as_slice(); + assert_eq!(slice, b""); + assert!(slice.is_empty()); } // We do this intentionally for "heap simulation" @@ -317,11 +315,9 @@ fn read_until_nonempty() { ensure_zero_reads(&mut parser); // now read the entire length; should always work let r = parser.read_until(len).unwrap(); - unsafe { - let slice = r.as_slice(); - assert_eq!(slice, src.as_slice()); - assert_eq!(slice.len(), len); - } + let slice = r.as_slice(); + assert_eq!(slice, src.as_slice()); + assert_eq!(slice.len(), len); // even after the buffer is exhausted, `0` should always work ensure_zero_reads(&mut parser); } @@ -346,23 +342,19 @@ fn read_until_not_enough() { fn read_until_more_bytes() { let sample1 = v!(b"abcd1"); let mut p1 = Parser::new(&sample1); - unsafe { - assert_eq!( - p1.read_until(&sample1.len() - 1).unwrap().as_slice(), - &sample1[..&sample1.len() - 1] - ); - // ensure we have not exhasuted - ensure_not_exhausted(&p1); - ensure_remaining(&p1, 1); - } + assert_eq!( + p1.read_until(&sample1.len() - 1).unwrap().as_slice(), + &sample1[..&sample1.len() - 1] + ); + // ensure we have not exhasuted + ensure_not_exhausted(&p1); + ensure_remaining(&p1, 1); let sample2 = v!(b"abcd1234567890!@#$"); let mut p2 = Parser::new(&sample2); - unsafe { - assert_eq!(p2.read_until(4).unwrap().as_slice(), &sample2[..4]); - // ensure we have not exhasuted - ensure_not_exhausted(&p2); - ensure_remaining(&p2, sample2.len() - 4); - } + assert_eq!(p2.read_until(4).unwrap().as_slice(), &sample2[..4]); + // ensure we have not exhasuted + ensure_not_exhausted(&p2); + ensure_remaining(&p2, sample2.len() - 4); } // read_line @@ -370,12 +362,10 @@ fn read_until_more_bytes() { fn read_line_special_case_only_lf() { let b = v!(b"\n"); let mut parser = Parser::new(&b); - unsafe { - let r = parser.read_line().unwrap(); - let slice = r.as_slice(); - assert_eq!(slice, b""); - assert!(slice.is_empty()); - }; + let r = parser.read_line().unwrap(); + let slice = r.as_slice(); + assert_eq!(slice, b""); + assert!(slice.is_empty()); // ensure it is exhausted ensure_exhausted(&parser); } @@ -389,12 +379,10 @@ fn read_line() { assert_eq!(parser.read_line().unwrap_err(), ParseError::NotEnough); } else { // should work - unsafe { - assert_eq!( - parser.read_line().unwrap().as_slice(), - &src.as_slice()[..len - 1] - ); - } + assert_eq!( + parser.read_line().unwrap().as_slice(), + &src.as_slice()[..len - 1] + ); // now, we attempt to read which should work ensure_zero_reads(&mut parser); } @@ -414,9 +402,7 @@ fn read_line_more_bytes() { let sample1 = v!(b"abcd\n1"); let mut p1 = Parser::new(&sample1); let line = p1.read_line().unwrap(); - unsafe { - assert_eq!(line.as_slice(), b"abcd"); - } + assert_eq!(line.as_slice(), b"abcd"); // we should still have one remaining ensure_not_exhausted(&p1); ensure_remaining(&p1, 1); @@ -427,17 +413,13 @@ fn read_line_subsequent_lf() { let sample1 = v!(b"abcd\n1\n"); let mut p1 = Parser::new(&sample1); let line = p1.read_line().unwrap(); - unsafe { - assert_eq!(line.as_slice(), b"abcd"); - } + assert_eq!(line.as_slice(), b"abcd"); // we should still have two octets remaining ensure_not_exhausted(&p1); ensure_remaining(&p1, 2); // and we should be able to read in another line let line = p1.read_line().unwrap(); - unsafe { - assert_eq!(line.as_slice(), b"1"); - } + assert_eq!(line.as_slice(), b"1"); ensure_exhausted(&p1); } @@ -453,12 +435,10 @@ fn read_line_pedantic_okay() { ); } else { // should work - unsafe { - assert_eq!( - parser.read_line_pedantic().unwrap().as_slice(), - &src.as_slice()[..len - 1] - ); - } + assert_eq!( + parser.read_line_pedantic().unwrap().as_slice(), + &src.as_slice()[..len - 1] + ); // now, we attempt to read which should work ensure_zero_reads(&mut parser); } From b5e0f68c884b3f9f9b0ae935bc92a1b0a549a1d7 Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Mon, 2 May 2022 10:25:10 -0700 Subject: [PATCH 09/13] Add support for Skyhash 1.0 --- server/src/admin/sys.rs | 8 +- server/src/protocol/interface.rs | 11 + server/src/protocol/mod.rs | 29 +- server/src/protocol/v1/benches.rs | 80 ++++++ server/src/protocol/v1/interface_impls.rs | 294 ++++++++++++++++++++ server/src/protocol/v1/mod.rs | 323 ++++++++++++++++++++++ server/src/protocol/v1/tests.rs | 93 +++++++ server/src/protocol/v2/interface_impls.rs | 17 +- server/src/protocol/v2/mod.rs | 7 +- server/src/tests/mod.rs | 6 +- sky-macros/src/lib.rs | 40 ++- 11 files changed, 876 insertions(+), 32 deletions(-) create mode 100644 server/src/protocol/v1/benches.rs create mode 100644 server/src/protocol/v1/interface_impls.rs create mode 100644 server/src/protocol/v1/mod.rs create mode 100644 server/src/protocol/v1/tests.rs diff --git a/server/src/admin/sys.rs b/server/src/admin/sys.rs index 923e5434..2bb280f2 100644 --- a/server/src/admin/sys.rs +++ b/server/src/admin/sys.rs @@ -25,9 +25,7 @@ */ use crate::{ - corestore::booltable::BoolTable, - dbnet::connection::prelude::*, - protocol::{PROTOCOL_VERSION, PROTOCOL_VERSIONSTRING}, + corestore::booltable::BoolTable, dbnet::connection::prelude::*, storage::v1::interface::DIR_ROOT, }; use ::libsky::VERSION; @@ -56,8 +54,8 @@ action! { } fn sys_info(con: &mut T, iter: &mut ActionIter<'_>) { match unsafe { iter.next_lowercase_unchecked() }.as_ref() { - INFO_PROTOCOL => con.write_string(PROTOCOL_VERSIONSTRING).await?, - INFO_PROTOVER => con.write_float(PROTOCOL_VERSION).await?, + INFO_PROTOCOL => con.write_string(P::PROTOCOL_VERSIONSTRING).await?, + INFO_PROTOVER => con.write_float(P::PROTOCOL_VERSION).await?, INFO_VERSION => con.write_string(VERSION).await?, _ => return util::err(ERR_UNKNOWN_PROPERTY), } diff --git a/server/src/protocol/interface.rs b/server/src/protocol/interface.rs index 9c28afba..48252ce9 100644 --- a/server/src/protocol/interface.rs +++ b/server/src/protocol/interface.rs @@ -38,6 +38,13 @@ use std::io::{Error as IoError, ErrorKind}; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter}; pub trait ProtocolSpec { + // spec information + + /// The Skyhash protocol version + const PROTOCOL_VERSION: f32; + /// The Skyhash protocol version string (Skyhash-x.y) + const PROTOCOL_VERSIONSTRING: &'static str; + // type symbols const TSYMBOL_STRING: u8; const TSYMBOL_BINARY: u8; @@ -100,6 +107,7 @@ pub trait ProtocolSpec { // full responses const FULLRESP_RCODE_PACKET_ERR: &'static [u8]; + const FULLRESP_RCODE_WRONG_TYPE: &'static [u8]; const FULLRESP_HEYA: &'static [u8]; // LUTs @@ -169,6 +177,9 @@ where Err(ParseError::UnexpectedByte | ParseError::BadPacket) => { return Ok(QueryResult::E(P::FULLRESP_RCODE_PACKET_ERR)); } + Err(ParseError::WrongType) => { + return Ok(QueryResult::E(P::FULLRESP_RCODE_WRONG_TYPE)); + } } } }) diff --git a/server/src/protocol/mod.rs b/server/src/protocol/mod.rs index f42ecaca..100241f4 100644 --- a/server/src/protocol/mod.rs +++ b/server/src/protocol/mod.rs @@ -24,20 +24,28 @@ * */ -use crate::corestore::heap_array::HeapArray; -use core::{fmt, slice}; +#[cfg(test)] +use self::interface::ProtocolSpec; +use { + crate::corestore::heap_array::HeapArray, + core::{fmt, slice}, +}; // pub mods pub mod interface; pub mod iter; // versions +mod v1; mod v2; // endof pub mods -/// The Skyhash protocol version -pub const PROTOCOL_VERSION: f32 = 2.0; -/// The Skyhash protocol version string (Skyhash-x.y) -pub const PROTOCOL_VERSIONSTRING: &str = "Skyhash-2.0"; pub type Skyhash2 = v2::Parser; +pub type Skyhash1 = v1::Parser; +#[cfg(test)] +/// The latest protocol version supported by this version +pub const LATEST_PROTOCOL_VERSION: f32 = Skyhash2::PROTOCOL_VERSION; +#[cfg(test)] +/// The latest protocol version supported by this version (`Skyhash-x.y`) +pub const LATEST_PROTOCOL_VERSIONSTRING: &str = Skyhash2::PROTOCOL_VERSIONSTRING; #[derive(PartialEq)] /// As its name says, an [`UnsafeSlice`] is a terribly unsafe slice. It's guarantess are @@ -90,7 +98,6 @@ pub enum ParseError { /// Didn't get the number of expected bytes NotEnough = 0u8, /// The packet simply contains invalid data - #[allow(dead_code)] // HACK(@ohsayan): rustc can't "guess" the transmutation BadPacket = 1u8, /// The query contains an unexpected byte UnexpectedByte = 2u8, @@ -98,6 +105,8 @@ pub enum ParseError { /// /// This can happen not just for elements but can also happen for their sizes ([`Self::parse_into_u64`]) DatatypeParseFailure = 3u8, + /// The client supplied the wrong query data type for the given query + WrongType = 4u8, } /// A generic result to indicate parsing errors thorugh the [`ParseError`] enum @@ -121,6 +130,9 @@ impl SimpleQuery { data: self.data.iter().map(|v| v.as_slice().to_owned()).collect(), } } + pub const fn new(data: HeapArray) -> Self { + Self { data } + } pub fn as_slice(&self) -> &[UnsafeSlice] { &self.data } @@ -137,6 +149,9 @@ pub struct PipelinedQuery { } impl PipelinedQuery { + pub const fn new(data: HeapArray>) -> Self { + Self { data } + } pub fn len(&self) -> usize { self.data.len() } diff --git a/server/src/protocol/v1/benches.rs b/server/src/protocol/v1/benches.rs new file mode 100644 index 00000000..a47bb970 --- /dev/null +++ b/server/src/protocol/v1/benches.rs @@ -0,0 +1,80 @@ +/* + * Created on Mon May 02 2022 + * + * This file is a part of Skytable + * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source + * NoSQL database written by Sayan Nandan ("the Author") with the + * vision to provide flexibility in data modelling without compromising + * on performance, queryability or scalability. + * + * Copyright (c) 2022, Sayan Nandan + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + * +*/ + +extern crate test; +use super::{super::Query, Parser}; +use test::Bencher; + +#[bench] +fn simple_query(b: &mut Bencher) { + const PAYLOAD: &[u8] = b"*1\n~3\n3\nSET\n1\nx\n3\n100\n"; + let expected = vec!["SET".to_owned(), "x".to_owned(), "100".to_owned()]; + b.iter(|| { + let (query, forward) = Parser::parse(PAYLOAD).unwrap(); + assert_eq!(forward, PAYLOAD.len()); + let query = if let Query::Simple(sq) = query { + sq + } else { + panic!("Got pipeline instead of simple query"); + }; + let ret: Vec = query + .as_slice() + .iter() + .map(|s| String::from_utf8_lossy(s.as_slice()).to_string()) + .collect(); + assert_eq!(ret, expected) + }); +} + +#[bench] +fn pipelined_query(b: &mut Bencher) { + const PAYLOAD: &[u8] = b"*2\n~3\n3\nSET\n1\nx\n3\n100\n~2\n3\nGET\n1\nx\n"; + let expected = vec![ + vec!["SET".to_owned(), "x".to_owned(), "100".to_owned()], + vec!["GET".to_owned(), "x".to_owned()], + ]; + b.iter(|| { + let (query, forward) = Parser::parse(PAYLOAD).unwrap(); + assert_eq!(forward, PAYLOAD.len()); + let query = if let Query::Pipelined(sq) = query { + sq + } else { + panic!("Got simple instead of pipeline query"); + }; + let ret: Vec> = query + .into_inner() + .iter() + .map(|query| { + query + .as_slice() + .iter() + .map(|v| String::from_utf8_lossy(v.as_slice()).to_string()) + .collect() + }) + .collect(); + assert_eq!(ret, expected) + }); +} diff --git a/server/src/protocol/v1/interface_impls.rs b/server/src/protocol/v1/interface_impls.rs new file mode 100644 index 00000000..f11a8a21 --- /dev/null +++ b/server/src/protocol/v1/interface_impls.rs @@ -0,0 +1,294 @@ +/* + * Created on Mon May 02 2022 + * + * This file is a part of Skytable + * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source + * NoSQL database written by Sayan Nandan ("the Author") with the + * vision to provide flexibility in data modelling without compromising + * on performance, queryability or scalability. + * + * Copyright (c) 2022, Sayan Nandan + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + * +*/ + +use { + crate::{ + corestore::buffers::Integer64, + dbnet::connection::{QueryWithAdvance, RawConnection, Stream}, + protocol::{ + interface::{ProtocolRead, ProtocolSpec, ProtocolWrite}, + ParseError, Skyhash1, + }, + util::FutureResult, + IoResult, + }, + ::sky_macros::compiled_eresp_bytes_v1 as eresp, + tokio::io::AsyncWriteExt, +}; + +impl ProtocolSpec for Skyhash1 { + // spec information + const PROTOCOL_VERSION: f32 = 1.0; + const PROTOCOL_VERSIONSTRING: &'static str = "Skyhash-1.0"; + + // type symbols + const TSYMBOL_STRING: u8 = b'+'; + const TSYMBOL_BINARY: u8 = b'?'; + const TSYMBOL_FLOAT: u8 = b'%'; + const TSYMBOL_INT64: u8 = b':'; + const TSYMBOL_TYPED_ARRAY: u8 = b'@'; + const TSYMBOL_TYPED_NON_NULL_ARRAY: u8 = b'^'; + const TSYMBOL_ARRAY: u8 = b'&'; + const TSYMBOL_FLAT_ARRAY: u8 = b'_'; + + // typed array + const TYPE_TYPED_ARRAY_ELEMENT_NULL: &'static [u8] = b"\0\n"; + + // metaframe + const SIMPLE_QUERY_HEADER: &'static [u8] = b"*1\n"; + const PIPELINED_QUERY_FIRST_BYTE: u8 = b'*'; + + // respcodes + /// Response code 0 as a array element + const RCODE_OKAY: &'static [u8] = eresp!("0"); + /// Response code 1 as a array element + const RCODE_NIL: &'static [u8] = eresp!("1"); + /// Response code 2 as a array element + const RCODE_OVERWRITE_ERR: &'static [u8] = eresp!("2"); + /// Response code 3 as a array element + const RCODE_ACTION_ERR: &'static [u8] = eresp!("3"); + /// Response code 4 as a array element + const RCODE_PACKET_ERR: &'static [u8] = eresp!("4"); + /// Response code 5 as a array element + const RCODE_SERVER_ERR: &'static [u8] = eresp!("5"); + /// Response code 6 as a array element + const RCODE_OTHER_ERR_EMPTY: &'static [u8] = eresp!("6"); + /// "Unknown action" error response + const RCODE_UNKNOWN_ACTION: &'static [u8] = eresp!("Unknown action"); + /// Response code 7 + const RCODE_WRONGTYPE_ERR: &'static [u8] = eresp!("7"); + /// Response code 8 + const RCODE_UNKNOWN_DATA_TYPE: &'static [u8] = eresp!("8"); + /// Response code 9 as an array element + const RCODE_ENCODING_ERROR: &'static [u8] = eresp!("9"); + + // respstrings + + /// Snapshot busy error + const RSTRING_SNAPSHOT_BUSY: &'static [u8] = eresp!("err-snapshot-busy"); + /// Snapshot disabled (other error) + const RSTRING_SNAPSHOT_DISABLED: &'static [u8] = eresp!("err-snapshot-disabled"); + /// Duplicate snapshot + const RSTRING_SNAPSHOT_DUPLICATE: &'static [u8] = eresp!("duplicate-snapshot"); + /// Snapshot has illegal name (other error) + const RSTRING_SNAPSHOT_ILLEGAL_NAME: &'static [u8] = eresp!("err-invalid-snapshot-name"); + /// Access after termination signal (other error) + const RSTRING_ERR_ACCESS_AFTER_TERMSIG: &'static [u8] = eresp!("err-access-after-termsig"); + + // keyspace related resps + /// The default container was not set + const RSTRING_DEFAULT_UNSET: &'static [u8] = eresp!("default-container-unset"); + /// The container was not found + const RSTRING_CONTAINER_NOT_FOUND: &'static [u8] = eresp!("container-not-found"); + /// The container is still in use and so cannot be removed + const RSTRING_STILL_IN_USE: &'static [u8] = eresp!("still-in-use"); + /// This is a protected object and hence cannot be accessed + const RSTRING_PROTECTED_OBJECT: &'static [u8] = eresp!("err-protected-object"); + /// The action was applied against the wrong model + const RSTRING_WRONG_MODEL: &'static [u8] = eresp!("wrong-model"); + /// The container already exists + const RSTRING_ALREADY_EXISTS: &'static [u8] = eresp!("err-already-exists"); + /// The container is not ready + const RSTRING_NOT_READY: &'static [u8] = eresp!("not-ready"); + /// A transactional failure occurred + const RSTRING_DDL_TRANSACTIONAL_FAILURE: &'static [u8] = eresp!("transactional-failure"); + /// An unknown DDL query was run + const RSTRING_UNKNOWN_DDL_QUERY: &'static [u8] = eresp!("unknown-ddl-query"); + /// The expression for a DDL query was malformed + const RSTRING_BAD_EXPRESSION: &'static [u8] = eresp!("malformed-expression"); + /// An unknown model was passed in a DDL query + const RSTRING_UNKNOWN_MODEL: &'static [u8] = eresp!("unknown-model"); + /// Too many arguments were passed to model constructor + const RSTRING_TOO_MANY_ARGUMENTS: &'static [u8] = eresp!("too-many-args"); + /// The container name is too long + const RSTRING_CONTAINER_NAME_TOO_LONG: &'static [u8] = eresp!("container-name-too-long"); + /// The container name contains invalid characters + const RSTRING_BAD_CONTAINER_NAME: &'static [u8] = eresp!("bad-container-name"); + /// An unknown inspect query + const RSTRING_UNKNOWN_INSPECT_QUERY: &'static [u8] = eresp!("unknown-inspect-query"); + /// An unknown table property was passed + const RSTRING_UNKNOWN_PROPERTY: &'static [u8] = eresp!("unknown-property"); + /// The keyspace is not empty and hence cannot be removed + const RSTRING_KEYSPACE_NOT_EMPTY: &'static [u8] = eresp!("keyspace-not-empty"); + /// Bad type supplied in a DDL query for the key + const RSTRING_BAD_TYPE_FOR_KEY: &'static [u8] = eresp!("bad-type-for-key"); + /// The index for the provided list was non-existent + const RSTRING_LISTMAP_BAD_INDEX: &'static [u8] = eresp!("bad-list-index"); + /// The list is empty + const RSTRING_LISTMAP_LIST_IS_EMPTY: &'static [u8] = eresp!("list-is-empty"); + + // full responses + const FULLRESP_RCODE_PACKET_ERR: &'static [u8] = b"*1\n!1\n4\n"; + const FULLRESP_RCODE_WRONG_TYPE: &'static [u8] = b"*1\n!1\n7\n"; + const FULLRESP_HEYA: &'static [u8] = b"*1\n+4\nHEY!\n"; +} + +impl ProtocolRead for T +where + T: RawConnection + Send + Sync, + Strm: Stream, +{ + fn try_query(&self) -> Result { + Skyhash1::parse(self.get_buffer()) + } +} + +impl ProtocolWrite for T +where + T: RawConnection + Send + Sync, + Strm: Stream, +{ + fn write_string<'life0, 'life1, 'ret_life>( + &'life0 mut self, + string: &'life1 str, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash1::TSYMBOL_STRING]).await?; + // length + let len_bytes = Integer64::from(string.len()); + stream.write_all(&len_bytes).await?; + // LF + stream.write_all(&[Skyhash1::LF]).await?; + // payload + stream.write_all(string.as_bytes()).await?; + // final LF + stream.write_all(&[Skyhash1::LF]).await + }) + } + fn write_binary<'life0, 'life1, 'ret_life>( + &'life0 mut self, + binary: &'life1 [u8], + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash1::TSYMBOL_BINARY]).await?; + // length + let len_bytes = Integer64::from(binary.len()); + stream.write_all(&len_bytes).await?; + // LF + stream.write_all(&[Skyhash1::LF]).await?; + // payload + stream.write_all(binary).await?; + // final LF + stream.write_all(&[Skyhash1::LF]).await + }) + } + fn write_usize<'life0, 'ret_life>( + &'life0 mut self, + size: usize, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { self.write_int64(size as _).await }) + } + fn write_int64<'life0, 'ret_life>( + &'life0 mut self, + int: u64, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash1::TSYMBOL_INT64]).await?; + // get body and sizeline + let body = Integer64::from(int); + let body_len = Integer64::from(body.len()); + // len of body + stream.write_all(&body_len).await?; + // sizeline LF + stream.write_all(&[Skyhash1::LF]).await?; + // body + stream.write_all(&body).await?; + // LF + stream.write_all(&[Skyhash1::LF]).await + }) + } + fn write_float<'life0, 'ret_life>( + &'life0 mut self, + float: f32, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // tsymbol + stream.write_all(&[Skyhash1::TSYMBOL_FLOAT]).await?; + // get body and sizeline + let body = float.to_string(); + let body = body.as_bytes(); + let sizeline = Integer64::from(body.len()); + // sizeline + stream.write_all(&sizeline).await?; + // sizeline LF + stream.write_all(&[Skyhash1::LF]).await?; + // body + stream.write_all(body).await?; + // LF + stream.write_all(&[Skyhash1::LF]).await + }) + } + fn write_typed_array_element<'life0, 'life1, 'ret_life>( + &'life0 mut self, + element: &'life1 [u8], + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // len + stream.write_all(&Integer64::from(element.len())).await?; + // LF + stream.write_all(&[Skyhash1::LF]).await?; + // body + stream.write_all(element).await?; + // LF + stream.write_all(&[Skyhash1::LF]).await + }) + } +} diff --git a/server/src/protocol/v1/mod.rs b/server/src/protocol/v1/mod.rs new file mode 100644 index 00000000..23545309 --- /dev/null +++ b/server/src/protocol/v1/mod.rs @@ -0,0 +1,323 @@ +/* + * Created on Sat Apr 30 2022 + * + * This file is a part of Skytable + * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source + * NoSQL database written by Sayan Nandan ("the Author") with the + * vision to provide flexibility in data modelling without compromising + * on performance, queryability or scalability. + * + * Copyright (c) 2022, Sayan Nandan + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + * +*/ + +use super::{ParseError, ParseResult, PipelinedQuery, Query, SimpleQuery, UnsafeSlice}; +use crate::{ + corestore::heap_array::{HeapArray, HeapArrayWriter}, + dbnet::connection::QueryWithAdvance, +}; +use core::mem::transmute; + +mod interface_impls; +// test and bench modules +#[cfg(feature = "nightly")] +mod benches; +#[cfg(test)] +mod tests; + +/// A parser for Skyhash 1.0 +/// +/// Packet structure example (simple query): +/// ```text +/// *1\n +/// ~3\n +/// 3\n +/// SET\n +/// 1\n +/// x\n +/// 3\n +/// 100\n +/// ``` +pub struct Parser { + end: *const u8, + cursor: *const u8, +} + +unsafe impl Send for Parser {} +unsafe impl Sync for Parser {} + +impl Parser { + /// Initialize a new parser + fn new(slice: &[u8]) -> Self { + unsafe { + Self { + end: slice.as_ptr().add(slice.len()), + cursor: slice.as_ptr(), + } + } + } +} + +// basic methods +impl Parser { + /// Returns a ptr one byte past the allocation of the buffer + const fn data_end_ptr(&self) -> *const u8 { + self.end + } + /// Returns the position of the cursor + /// WARNING: Deref might led to a segfault + const fn cursor_ptr(&self) -> *const u8 { + self.cursor + } + /// Check how many bytes we have left + fn remaining(&self) -> usize { + self.data_end_ptr() as usize - self.cursor_ptr() as usize + } + /// Check if we have `size` bytes remaining + fn has_remaining(&self, size: usize) -> bool { + self.remaining() >= size + } + /// Check if we have exhausted the buffer + fn exhausted(&self) -> bool { + self.cursor_ptr() >= self.data_end_ptr() + } + /// Check if the buffer is not exhausted + fn not_exhausted(&self) -> bool { + self.cursor_ptr() < self.data_end_ptr() + } + /// Attempts to return the byte pointed at by the cursor. + /// WARNING: The same segfault warning + const unsafe fn get_byte_at_cursor(&self) -> u8 { + *self.cursor_ptr() + } +} + +// mutable refs +impl Parser { + /// Increment the cursor by `by` positions + unsafe fn incr_cursor_by(&mut self, by: usize) { + self.cursor = self.cursor.add(by); + } + /// Increment the position of the cursor by one position + unsafe fn incr_cursor(&mut self) { + self.incr_cursor_by(1); + } +} + +// utility methods +impl Parser { + /// Returns true if the cursor will give a char, but if `this_if_nothing_ahead` is set + /// to true, then if no byte is ahead, it will still return true + fn will_cursor_give_char(&self, ch: u8, true_if_nothing_ahead: bool) -> ParseResult { + if self.exhausted() { + // nothing left + if true_if_nothing_ahead { + Ok(true) + } else { + Err(ParseError::NotEnough) + } + } else if unsafe { self.get_byte_at_cursor().eq(&ch) } { + Ok(true) + } else { + Ok(false) + } + } + /// Check if the current cursor will give an LF + fn will_cursor_give_linefeed(&self) -> ParseResult { + self.will_cursor_give_char(b'\n', false) + } + /// Gets the _next element. **The cursor should be at the tsymbol (passed)** + fn _next(&mut self) -> ParseResult { + let element_size = self.read_usize()?; + self.read_until(element_size) + } +} + +// higher level abstractions +impl Parser { + /// Attempt to read `len` bytes + fn read_until(&mut self, len: usize) -> ParseResult { + if self.has_remaining(len) { + unsafe { + // UNSAFE(@ohsayan): Already verified lengths + let slice = UnsafeSlice::new(self.cursor_ptr(), len); + self.incr_cursor_by(len); + Ok(slice) + } + } else { + Err(ParseError::NotEnough) + } + } + /// Attempt to read a line, **rejecting an empty payload** + fn read_line_pedantic(&mut self) -> ParseResult { + let start_ptr = self.cursor_ptr(); + unsafe { + while self.not_exhausted() && self.get_byte_at_cursor() != b'\n' { + self.incr_cursor(); + } + let len = self.cursor_ptr() as usize - start_ptr as usize; + let has_lf = self.not_exhausted() && self.get_byte_at_cursor() == b'\n'; + if has_lf && len != 0 { + self.incr_cursor(); // skip LF + Ok(UnsafeSlice::new(start_ptr, len)) + } else { + // just some silly hackery + Err(transmute(has_lf)) + } + } + } + /// Attempt to read an `usize` from the buffer + fn read_usize(&mut self) -> ParseResult { + let line = self.read_line_pedantic()?; + let bytes = line.as_slice(); + let mut ret = 0usize; + for byte in bytes { + if byte.is_ascii_digit() { + ret = match ret.checked_mul(10) { + Some(r) => r, + None => return Err(ParseError::DatatypeParseFailure), + }; + ret = match ret.checked_add((byte & 0x0F) as _) { + Some(r) => r, + None => return Err(ParseError::DatatypeParseFailure), + }; + } else { + return Err(ParseError::DatatypeParseFailure); + } + } + Ok(ret) + } + /// Parse the next blob. **The cursor should be at the tsymbol (passed)** + fn parse_next_blob(&mut self) -> ParseResult { + { + let chunk = self._next()?; + if self.will_cursor_give_linefeed()? { + unsafe { + // UNSAFE(@ohsayan): We know that the buffer is not exhausted + // due to the above condition + self.incr_cursor(); + } + Ok(chunk) + } else { + Err(ParseError::UnexpectedByte) + } + } + } +} + +// query abstractions +impl Parser { + /// The buffer should resemble the below structure: + /// ``` + /// ~\n + /// \n + /// \n + /// \n + /// \n + /// ... + /// ``` + fn _parse_simple_query(&mut self) -> ParseResult> { + if self.not_exhausted() { + if unsafe { self.get_byte_at_cursor() } != b'~' { + // we need an any array + return Err(ParseError::WrongType); + } + unsafe { + // UNSAFE(@ohsayan): Just checked length + self.incr_cursor(); + } + let query_count = self.read_usize()?; + let mut writer = HeapArrayWriter::with_capacity(query_count); + for i in 0..query_count { + unsafe { + // UNSAFE(@ohsayan): The index of the for loop ensures that + // we never attempt to write to a bad memory location + writer.write_to_index(i, self.parse_next_blob()?); + } + } + Ok(unsafe { + // UNSAFE(@ohsayan): If we've reached here, then we have initialized + // all the queries + writer.finish() + }) + } else { + Err(ParseError::NotEnough) + } + } + fn parse_simple_query(&mut self) -> ParseResult { + Ok(SimpleQuery::new(self._parse_simple_query()?)) + } + /// The buffer should resemble the following structure: + /// ```text + /// # query 1 + /// ~\n + /// \n + /// \n + /// \n + /// \n + /// # query 2 + /// ~\n + /// \n + /// \n + /// \n + /// \n + /// ... + /// ``` + fn parse_pipelined_query(&mut self, length: usize) -> ParseResult { + let mut writer = HeapArrayWriter::with_capacity(length); + for i in 0..length { + unsafe { + // UNSAFE(@ohsayan): The above condition guarantees that the index + // never causes an overflow + writer.write_to_index(i, self._parse_simple_query()?); + } + } + unsafe { + // UNSAFE(@ohsayan): if we reached here, then we have inited everything + Ok(PipelinedQuery::new(writer.finish())) + } + } + fn _parse(&mut self) -> ParseResult { + if self.not_exhausted() { + let first_byte = unsafe { + // UNSAFE(@ohsayan): Just checked if buffer is exhausted or not + self.get_byte_at_cursor() + }; + if first_byte != b'*' { + // unknown query scheme, so it's a bad packet + return Err(ParseError::BadPacket); + } + unsafe { + // UNSAFE(@ohsayan): Checked buffer len and incremented, so we're good + self.incr_cursor() + }; + let query_count = self.read_usize()?; // get the length + if query_count == 1 { + Ok(Query::Simple(self.parse_simple_query()?)) + } else { + Ok(Query::Pipelined(self.parse_pipelined_query(query_count)?)) + } + } else { + Err(ParseError::NotEnough) + } + } + pub fn parse(buf: &[u8]) -> ParseResult { + let mut slf = Self::new(buf); + let body = slf._parse()?; + let consumed = slf.cursor_ptr() as usize - buf.as_ptr() as usize; + Ok((body, consumed)) + } +} diff --git a/server/src/protocol/v1/tests.rs b/server/src/protocol/v1/tests.rs new file mode 100644 index 00000000..fb2205c5 --- /dev/null +++ b/server/src/protocol/v1/tests.rs @@ -0,0 +1,93 @@ +/* + * Created on Mon May 02 2022 + * + * This file is a part of Skytable + * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source + * NoSQL database written by Sayan Nandan ("the Author") with the + * vision to provide flexibility in data modelling without compromising + * on performance, queryability or scalability. + * + * Copyright (c) 2022, Sayan Nandan + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + * +*/ + +use { + super::Parser, + crate::protocol::{ParseError, Query}, +}; + +#[cfg(test)] +const SQPAYLOAD: &[u8] = b"*1\n~3\n3\nSET\n1\nx\n3\n100\n"; +#[cfg(test)] +const PQPAYLOAD: &[u8] = b"*2\n~3\n3\nSET\n1\nx\n3\n100\n~2\n3\nGET\n1\nx\n"; + +#[test] +fn parse_simple_query() { + let payload = SQPAYLOAD.to_vec(); + let (q, f) = Parser::parse(&payload).unwrap(); + let q: Vec = if let Query::Simple(q) = q { + q.as_slice() + .iter() + .map(|v| String::from_utf8_lossy(v.as_slice()).to_string()) + .collect() + } else { + panic!("Expected simple query") + }; + assert_eq!(f, payload.len()); + assert_eq!(q, vec!["SET".to_owned(), "x".into(), "100".into()]); +} + +#[test] +fn parse_simple_query_incomplete() { + for i in 0..SQPAYLOAD.len() - 1 { + let slice = &SQPAYLOAD[..i]; + assert_eq!(Parser::parse(slice).unwrap_err(), ParseError::NotEnough); + } +} + +#[test] +fn parse_pipelined_query() { + let payload = PQPAYLOAD.to_vec(); + let (q, f) = Parser::parse(&payload).unwrap(); + let q: Vec> = if let Query::Pipelined(q) = q { + q.into_inner() + .iter() + .map(|sq| { + sq.iter() + .map(|v| String::from_utf8_lossy(v.as_slice()).to_string()) + .collect() + }) + .collect() + } else { + panic!("Expected pipelined query query") + }; + assert_eq!(f, payload.len()); + assert_eq!( + q, + vec![ + vec!["SET".to_owned(), "x".into(), "100".into()], + vec!["GET".into(), "x".into()] + ] + ); +} + +#[test] +fn parse_pipelined_query_incomplete() { + for i in 0..PQPAYLOAD.len() - 1 { + let slice = &PQPAYLOAD[..i]; + assert_eq!(Parser::parse(slice).unwrap_err(), ParseError::NotEnough); + } +} diff --git a/server/src/protocol/v2/interface_impls.rs b/server/src/protocol/v2/interface_impls.rs index fc60c04a..962543a5 100644 --- a/server/src/protocol/v2/interface_impls.rs +++ b/server/src/protocol/v2/interface_impls.rs @@ -38,6 +38,10 @@ use ::sky_macros::compiled_eresp_bytes as eresp; use tokio::io::AsyncWriteExt; impl ProtocolSpec for Skyhash2 { + // spec information + const PROTOCOL_VERSION: f32 = 2.0; + const PROTOCOL_VERSIONSTRING: &'static str = "Skyhash-2.0"; + // type symbols const TSYMBOL_STRING: u8 = b'+'; const TSYMBOL_BINARY: u8 = b'?'; @@ -135,7 +139,8 @@ impl ProtocolSpec for Skyhash2 { const RSTRING_LISTMAP_LIST_IS_EMPTY: &'static [u8] = eresp!("list-is-empty"); // full responses - const FULLRESP_RCODE_PACKET_ERR: &'static [u8] = b"*!1\n4\n"; + const FULLRESP_RCODE_PACKET_ERR: &'static [u8] = b"*!4\n"; + const FULLRESP_RCODE_WRONG_TYPE: &'static [u8] = b"*!7\n"; const FULLRESP_HEYA: &'static [u8] = b"+4\nHEY!"; } @@ -206,15 +211,7 @@ where 'life0: 'ret_life, Self: 'ret_life, { - Box::pin(async move { - let stream = self.get_mut_stream(); - // tsymbol - stream.write_all(&[Skyhash2::TSYMBOL_INT64]).await?; - // body - stream.write_all(&Integer64::from(size)).await?; - // LF - stream.write_all(&[Skyhash2::LF]).await - }) + Box::pin(async move { self.write_int64(size as _).await }) } fn write_int64<'life0, 'ret_life>( &'life0 mut self, diff --git a/server/src/protocol/v2/mod.rs b/server/src/protocol/v2/mod.rs index 8d533e88..c57e47e7 100644 --- a/server/src/protocol/v2/mod.rs +++ b/server/src/protocol/v2/mod.rs @@ -28,6 +28,7 @@ mod interface_impls; use crate::{ corestore::heap_array::HeapArray, + dbnet::connection::QueryWithAdvance, protocol::{ParseError, ParseResult, PipelinedQuery, Query, SimpleQuery, UnsafeSlice}, }; use core::mem::transmute; @@ -205,9 +206,7 @@ impl Parser { } /// Parse a simple query fn next_simple_query(&mut self) -> ParseResult { - Ok(SimpleQuery { - data: self._next_simple_query()?, - }) + Ok(SimpleQuery::new(self._next_simple_query()?)) } /// Parse a pipelined query. This should have passed the `$` tsymbol /// @@ -284,7 +283,7 @@ impl Parser { } // only expose this. don't expose Self::new since that'll be _relatively easier_ to // invalidate invariants for - pub fn parse(buf: &[u8]) -> ParseResult<(Query, usize)> { + pub fn parse(buf: &[u8]) -> ParseResult { let mut slf = Self::new(buf); let body = slf._parse()?; let consumed = slf.cursor_ptr() as usize - buf.as_ptr() as usize; diff --git a/server/src/tests/mod.rs b/server/src/tests/mod.rs index ac4c243a..3f2a5274 100644 --- a/server/src/tests/mod.rs +++ b/server/src/tests/mod.rs @@ -52,7 +52,7 @@ mod tls { } mod sys { - use crate::protocol::{PROTOCOL_VERSION, PROTOCOL_VERSIONSTRING}; + use crate::protocol::{LATEST_PROTOCOL_VERSION, LATEST_PROTOCOL_VERSIONSTRING}; use libsky::VERSION; use sky_macros::dbtest_func as dbtest; use skytable::{query, Element, RespCode}; @@ -79,7 +79,7 @@ mod sys { runeq!( con, query!("sys", "info", "protocol"), - Element::String(PROTOCOL_VERSIONSTRING.to_owned()) + Element::String(LATEST_PROTOCOL_VERSIONSTRING.to_owned()) ) } #[dbtest] @@ -87,7 +87,7 @@ mod sys { runeq!( con, query!("sys", "info", "protover"), - Element::Float(PROTOCOL_VERSION) + Element::Float(LATEST_PROTOCOL_VERSION) ) } #[dbtest] diff --git a/sky-macros/src/lib.rs b/sky-macros/src/lib.rs index 349aa48d..2d66332f 100644 --- a/sky-macros/src/lib.rs +++ b/sky-macros/src/lib.rs @@ -106,18 +106,40 @@ pub fn dbtest_func(args: TokenStream, item: TokenStream) -> TokenStream { /// Get a compile time respcode/respstring array. For example, if you pass: "Unknown action", /// it will return: `!14\nUnknown Action\n` pub fn compiled_eresp_array(tokens: TokenStream) -> TokenStream { - _get_eresp_array(tokens) + _get_eresp_array(tokens, false) } -fn _get_eresp_array(tokens: TokenStream) -> TokenStream { +#[proc_macro] +/// Get a compile time respcode/respstring array. For example, if you pass: "Unknown action", +/// it will return: `!14\n14\nUnknown Action\n` +pub fn compiled_eresp_array_v1(tokens: TokenStream) -> TokenStream { + _get_eresp_array(tokens, true) +} + +fn _get_eresp_array(tokens: TokenStream, sizeline: bool) -> TokenStream { let payload_str = match syn::parse_macro_input!(tokens as Lit) { Lit::Str(st) => st.value(), _ => panic!("Expected a string literal"), }; - let payload_bytes = payload_str.as_bytes(); let mut processed = quote! { b'!', }; + if sizeline { + let payload_len = payload_str.as_bytes().len(); + let payload_len_str = payload_len.to_string(); + let payload_len_bytes = payload_len_str.as_bytes(); + for byte in payload_len_bytes { + processed = quote! { + #processed + #byte, + }; + } + processed = quote! { + #processed + b'\n', + }; + } + let payload_bytes = payload_str.as_bytes(); for byte in payload_bytes { processed = quote! { #processed @@ -145,3 +167,15 @@ pub fn compiled_eresp_bytes(tokens: TokenStream) -> TokenStream { } .into() } + +#[proc_macro] +/// Get a compile time respcode/respstring slice. For example, if you pass: "Unknown action", +/// it will return: `!14\nUnknown Action\n` +pub fn compiled_eresp_bytes_v1(tokens: TokenStream) -> TokenStream { + let ret = compiled_eresp_array_v1(tokens); + let ret = syn::parse_macro_input!(ret as syn::Expr); + quote! { + &#ret + } + .into() +} From 20f039cb8585f0987c05df3dbfc0f8853b8fa8b8 Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Tue, 3 May 2022 00:57:16 -0700 Subject: [PATCH 10/13] Enable setting protocol version in configuration --- server/src/actions/mod.rs | 2 +- server/src/arbiter.rs | 12 +- server/src/cli.yml | 6 + server/src/config/cfgcli.rs | 6 + server/src/config/cfgenv.rs | 2 + server/src/config/cfgfile.rs | 5 +- server/src/config/definitions.rs | 58 +++++++- server/src/config/feedback.rs | 17 ++- server/src/config/mod.rs | 55 +++++++- server/src/config/tests.rs | 11 +- server/src/dbnet/mod.rs | 156 ++++++++++++++++------ server/src/dbnet/tcp.rs | 40 +++--- server/src/dbnet/tls.rs | 39 +++--- server/src/protocol/interface.rs | 4 +- server/src/protocol/v1/interface_impls.rs | 4 +- server/src/protocol/v2/interface_impls.rs | 4 +- 16 files changed, 324 insertions(+), 97 deletions(-) diff --git a/server/src/actions/mod.rs b/server/src/actions/mod.rs index 3577964a..346eb367 100644 --- a/server/src/actions/mod.rs +++ b/server/src/actions/mod.rs @@ -135,7 +135,7 @@ pub mod heya { con.write_mono_length_prefixed_with_tsymbol(raw_byte, b'+') .await?; } else { - return util::err(P::FULLRESP_HEYA); + con._write_raw(P::ELEMRESP_HEYA).await?; } Ok(()) } diff --git a/server/src/arbiter.rs b/server/src/arbiter.rs index 1b8e0a3e..663d4399 100644 --- a/server/src/arbiter.rs +++ b/server/src/arbiter.rs @@ -57,6 +57,7 @@ pub async fn run( snapshot, maxcon, auth, + protocol, .. }: ConfigurationSet, restore_filepath: Option, @@ -100,8 +101,15 @@ pub async fn run( let termsig = TerminationSignal::init().map_err(|e| Error::ioerror_extra(e, "binding to signals"))?; // start the server (single or multiple listeners) - let mut server = - dbnet::connect(ports, maxcon, db.clone(), auth_provider, signal.clone()).await?; + let mut server = dbnet::connect( + ports, + protocol, + maxcon, + db.clone(), + auth_provider, + signal.clone(), + ) + .await?; tokio::select! { _ = server.run_server() => {}, diff --git a/server/src/cli.yml b/server/src/cli.yml index cfda65a4..637e8f4c 100644 --- a/server/src/cli.yml +++ b/server/src/cli.yml @@ -115,3 +115,9 @@ args: takes_value: true help: Set the authentication origin key value_name: origin_key + - protover: + required: false + long: protover + takes_value: true + help: Set the protocol version + value_name: protover diff --git a/server/src/config/cfgcli.rs b/server/src/config/cfgcli.rs index 7066af0d..ae369610 100644 --- a/server/src/config/cfgcli.rs +++ b/server/src/config/cfgcli.rs @@ -72,6 +72,12 @@ pub(super) fn parse_cli_args(matches: ArgMatches) -> Configset { ) }; } + // protocol settings + fcli! { + protocol_settings, + matches.value_of("protover"), + "--protover" + }; // server settings fcli!( server_tcp, diff --git a/server/src/config/cfgenv.rs b/server/src/config/cfgenv.rs index b2e4e825..23ead026 100644 --- a/server/src/config/cfgenv.rs +++ b/server/src/config/cfgenv.rs @@ -44,6 +44,8 @@ pub(super) fn parse_env_config() -> Configset { ); }; } + // protocol settings + fenv!(protocol_settings, SKY_PROTOCOL_VERSION); // server settings fenv!(server_tcp, SKY_SYSTEM_HOST, SKY_SYSTEM_PORT); fenv!(server_noart, SKY_SYSTEM_NOART); diff --git a/server/src/config/cfgfile.rs b/server/src/config/cfgfile.rs index dc7dddc8..3748a4b0 100644 --- a/server/src/config/cfgfile.rs +++ b/server/src/config/cfgfile.rs @@ -25,7 +25,8 @@ */ use super::{ - AuthSettings, ConfigSourceParseResult, Configset, Modeset, OptString, TryFromConfigSource, + AuthSettings, ConfigSourceParseResult, Configset, Modeset, OptString, ProtocolVersion, + TryFromConfigSource, }; use serde::Deserialize; use std::net::IpAddr; @@ -59,6 +60,7 @@ pub struct ConfigKeyServer { pub(super) maxclient: Option, /// The deployment mode pub(super) mode: Option, + pub(super) protocol: Option, } /// The BGSAVE section in the config file @@ -175,6 +177,7 @@ pub fn from_file(file: ConfigFile) -> Configset { Optional::some(server.port), "server.port", ); + set.protocol_settings(server.protocol, "server.protocol"); set.server_maxcon(Optional::from(server.maxclient), "server.maxcon"); set.server_noart(Optional::from(server.noart), "server.noart"); set.server_mode(Optional::from(server.mode), "server.mode"); diff --git a/server/src/config/definitions.rs b/server/src/config/definitions.rs index 43853ba2..7eda5e9d 100644 --- a/server/src/config/definitions.rs +++ b/server/src/config/definitions.rs @@ -68,6 +68,54 @@ impl BGSave { } } +#[repr(u8)] +#[derive(Debug, PartialEq)] +pub enum ProtocolVersion { + V1, + V2, +} + +impl Default for ProtocolVersion { + fn default() -> Self { + Self::V2 + } +} + +impl ToString for ProtocolVersion { + fn to_string(&self) -> String { + match self { + Self::V1 => "Skyhash 1.0".to_owned(), + Self::V2 => "Skyhash 2.0".to_owned(), + } + } +} + +struct ProtocolVersionVisitor; + +impl<'de> Visitor<'de> for ProtocolVersionVisitor { + type Value = ProtocolVersion; + fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "a 40 character ASCII string") + } + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + value.parse().map_err(|_| { + E::custom("Invalid value for protocol version. Valid inputs: 1.0, 1.1, 1.2, 2.0") + }) + } +} + +impl<'de> Deserialize<'de> for ProtocolVersion { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_str(ProtocolVersionVisitor) + } +} + /// A `ConfigurationSet` which can be used by main::check_args_or_connect() to bind /// to a `TcpListener` and show the corresponding terminal output for the given /// configuration @@ -87,9 +135,12 @@ pub struct ConfigurationSet { pub mode: Modeset, /// The auth settings pub auth: AuthSettings, + /// The protocol version + pub protocol: ProtocolVersion, } impl ConfigurationSet { + #[allow(clippy::too_many_arguments)] pub const fn new( noart: bool, bgsave: BGSave, @@ -98,6 +149,7 @@ impl ConfigurationSet { maxcon: usize, mode: Modeset, auth: AuthSettings, + protocol: ProtocolVersion, ) -> Self { Self { noart, @@ -107,6 +159,7 @@ impl ConfigurationSet { maxcon, mode, auth, + protocol, } } /// Create a default `ConfigurationSet` with the following setup defaults: @@ -125,6 +178,7 @@ impl ConfigurationSet { MAXIMUM_CONNECTION_LIMIT, Modeset::Dev, AuthSettings::default(), + ProtocolVersion::V2, ) } /// Returns `false` if `noart` is enabled. Otherwise it returns `true` @@ -207,14 +261,14 @@ impl PortConfig { Self::Multi { host, port, ssl } => { format!( "skyhash://{host}:{port} and skyhash-secure://{host}:{tlsport}", - tlsport = ssl.get_port() + tlsport = ssl.get_port(), ) } Self::SecureOnly { host, ssl: SslOpts { port, .. }, } => format!("skyhash-secure://{host}:{port}"), - Self::InsecureOnly { host, port } => format!("skyhash://{host}:{port}"), + Self::InsecureOnly { host, port } => format!("skyhash://{host}:{port}",), } } } diff --git a/server/src/config/feedback.rs b/server/src/config/feedback.rs index 3793dd17..e0d0374d 100644 --- a/server/src/config/feedback.rs +++ b/server/src/config/feedback.rs @@ -66,9 +66,17 @@ impl FeedbackStack { impl fmt::Display for FeedbackStack { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { if !self.is_empty() { - write!(f, "{} {}:", self.feedback_source, self.feedback_type)?; - for err in self.stack.iter() { - write!(f, "\n{}- {}", TAB, err)?; + if self.len() == 1 { + write!( + f, + "{} {}: {}", + self.feedback_source, self.feedback_type, self.stack[0] + )?; + } else { + write!(f, "{} {}:", self.feedback_source, self.feedback_type)?; + for err in self.stack.iter() { + write!(f, "\n{}- {}", TAB, err)?; + } } } Ok(()) @@ -265,8 +273,7 @@ mod test { #[test] fn errorstack_fmt() { const EXPECTED: &str = "\ -Environment errors: - - Invalid value for `SKY_SYSTEM_PORT`. Expected a 16-bit integer\ +Environment errors: Invalid value for `SKY_SYSTEM_PORT`. Expected a 16-bit integer\ "; let mut estk = ErrorStack::new(EMSG_ENV); estk.push("Invalid value for `SKY_SYSTEM_PORT`. Expected a 16-bit integer"); diff --git a/server/src/config/mod.rs b/server/src/config/mod.rs index 72899354..cf3ca67b 100644 --- a/server/src/config/mod.rs +++ b/server/src/config/mod.rs @@ -211,6 +211,34 @@ impl FromStr for OptString { } } +impl FromStr for ProtocolVersion { + type Err = (); + fn from_str(st: &str) -> Result { + match st { + "1" | "1.0" | "1.1" | "1.2" => Ok(Self::V1), + "2" | "2.0" => Ok(Self::V2), + _ => Err(()), + } + } +} + +impl TryFromConfigSource for Option { + fn is_present(&self) -> bool { + self.is_some() + } + fn mutate_failed(self, target: &mut ProtocolVersion, trip: &mut bool) -> bool { + if let Some(v) = self { + *target = v; + *trip = true; + } + false + } + fn try_parse(self) -> ConfigSourceParseResult { + self.map(ConfigSourceParseResult::Okay) + .unwrap_or(ConfigSourceParseResult::Absent) + } +} + impl TryFromConfigSource for OptString { fn is_present(&self) -> bool { self.base.is_some() @@ -225,7 +253,7 @@ impl TryFromConfigSource for OptString { fn try_parse(self) -> ConfigSourceParseResult { self.base .map(|v| ConfigSourceParseResult::Okay(OptString { base: Some(v) })) - .unwrap_or(ConfigSourceParseResult::Okay(OptString::new_null())) + .unwrap_or(ConfigSourceParseResult::Absent) } } @@ -365,6 +393,13 @@ impl Configset { } else { return Err(ConfigError::CfgError(self.estack)); }; + if target.config.protocol != ProtocolVersion::default() { + target.wpush(format!( + "{} is deprecated. Switch to {}", + target.config.protocol.to_string(), + ProtocolVersion::default().to_string() + )); + } if target.is_prod_mode() { self::feedback::evaluate_prod_settings(&target.config).map(|_| target) } else { @@ -374,6 +409,24 @@ impl Configset { } } +// protocol settings +impl Configset { + pub fn protocol_settings( + &mut self, + nproto: impl TryFromConfigSource, + nproto_key: StaticStr, + ) { + let mut proto = ProtocolVersion::default(); + self.try_mutate( + nproto, + &mut proto, + nproto_key, + "a protocol version like 2.0 or 1.0", + ); + self.cfg.protocol = proto; + } +} + // server settings impl Configset { pub fn server_tcp( diff --git a/server/src/config/tests.rs b/server/src/config/tests.rs index becc8305..5822bd5a 100644 --- a/server/src/config/tests.rs +++ b/server/src/config/tests.rs @@ -345,7 +345,7 @@ mod cfg_file_tests { use crate::config::AuthkeyWrapper; use crate::config::{ cfgfile, AuthSettings, BGSave, Configset, ConfigurationSet, Modeset, PortConfig, - SnapshotConfig, SnapshotPref, SslOpts, DEFAULT_IPV4, DEFAULT_PORT, + ProtocolVersion, SnapshotConfig, SnapshotPref, SslOpts, DEFAULT_IPV4, DEFAULT_PORT, }; use crate::dbnet::MAXIMUM_CONNECTION_LIMIT; use std::net::{IpAddr, Ipv6Addr}; @@ -401,6 +401,7 @@ mod cfg_file_tests { maxcon: MAXIMUM_CONNECTION_LIMIT, mode: Modeset::Dev, auth: AuthSettings::default(), + protocol: ProtocolVersion::default(), } ); } @@ -422,6 +423,7 @@ mod cfg_file_tests { maxcon: MAXIMUM_CONNECTION_LIMIT, mode: Modeset::Dev, auth: AuthSettings::default(), + protocol: ProtocolVersion::default(), } ); } @@ -447,7 +449,8 @@ mod cfg_file_tests { ), MAXIMUM_CONNECTION_LIMIT, Modeset::Dev, - AuthSettings::new(AuthkeyWrapper::try_new(crate::TEST_AUTH_ORIGIN_KEY).unwrap()) + AuthSettings::new(AuthkeyWrapper::try_new(crate::TEST_AUTH_ORIGIN_KEY).unwrap()), + ProtocolVersion::default() ) ); } @@ -473,6 +476,7 @@ mod cfg_file_tests { maxcon: MAXIMUM_CONNECTION_LIMIT, mode: Modeset::Dev, auth: AuthSettings::default(), + protocol: ProtocolVersion::default(), } ); } @@ -495,6 +499,7 @@ mod cfg_file_tests { maxcon: MAXIMUM_CONNECTION_LIMIT, mode: Modeset::Dev, auth: AuthSettings::default(), + protocol: ProtocolVersion::default(), } ) } @@ -517,6 +522,7 @@ mod cfg_file_tests { maxcon: MAXIMUM_CONNECTION_LIMIT, mode: Modeset::Dev, auth: AuthSettings::default(), + protocol: ProtocolVersion::default(), } ) } @@ -535,6 +541,7 @@ mod cfg_file_tests { maxcon: MAXIMUM_CONNECTION_LIMIT, mode: Modeset::Dev, auth: AuthSettings::default(), + protocol: ProtocolVersion::default(), } ); } diff --git a/server/src/dbnet/mod.rs b/server/src/dbnet/mod.rs index 37c1eb2c..3eb71665 100644 --- a/server/src/dbnet/mod.rs +++ b/server/src/dbnet/mod.rs @@ -39,19 +39,24 @@ //! 5. Now errors are handled if they occur. Otherwise, the query is executed by `Corestore::execute_query()` //! -use self::tcp::Listener; -use crate::{ - auth::AuthProvider, - config::{PortConfig, SslOpts}, - corestore::Corestore, - util::error::{Error, SkyResult}, - IoResult, -}; -use std::{net::IpAddr, sync::Arc}; -use tls::SslListener; -use tokio::{ - net::TcpListener, - sync::{broadcast, mpsc, Semaphore}, +use { + self::{ + tcp::{Listener, ListenerV1}, + tls::{SslListener, SslListenerV1}, + }, + crate::{ + auth::AuthProvider, + config::{PortConfig, ProtocolVersion, SslOpts}, + corestore::Corestore, + util::error::{Error, SkyResult}, + IoResult, + }, + core::future::Future, + std::{net::IpAddr, sync::Arc}, + tokio::{ + net::TcpListener, + sync::{broadcast, mpsc, Semaphore}, + }, }; pub mod connection; #[macro_use] @@ -160,35 +165,93 @@ impl BaseListener { #[allow(clippy::large_enum_variant)] pub enum MultiListener { SecureOnly(SslListener), + SecureOnlyV1(SslListenerV1), InsecureOnly(Listener), + InsecureOnlyV1(ListenerV1), Multi(Listener, SslListener), + MultiV1(ListenerV1, SslListenerV1), +} + +async fn wait_on_port_futures( + a: impl Future>, + b: impl Future>, +) -> IoResult<()> { + let (e1, e2) = tokio::join!(a, b); + if let Err(e) = e1 { + log::error!("Insecure listener failed with: {}", e); + } + if let Err(e) = e2 { + log::error!("Secure listener failed with: {}", e); + } + Ok(()) } impl MultiListener { /// Create a new `InsecureOnly` listener - pub fn new_insecure_only(base: BaseListener) -> Self { - MultiListener::InsecureOnly(Listener::new(base)) + pub fn new_insecure_only(base: BaseListener, protocol: ProtocolVersion) -> Self { + match protocol { + ProtocolVersion::V2 => MultiListener::InsecureOnly(Listener::new(base)), + ProtocolVersion::V1 => MultiListener::InsecureOnlyV1(ListenerV1::new(base)), + } } /// Create a new `SecureOnly` listener - pub fn new_secure_only(base: BaseListener, ssl: SslOpts) -> SkyResult { - let listener = - SslListener::new_pem_based_ssl_connection(ssl.key, ssl.chain, base, ssl.passfile)?; - Ok(MultiListener::SecureOnly(listener)) + pub fn new_secure_only( + base: BaseListener, + ssl: SslOpts, + protocol: ProtocolVersion, + ) -> SkyResult { + let listener = match protocol { + ProtocolVersion::V2 => { + let listener = SslListener::new_pem_based_ssl_connection( + ssl.key, + ssl.chain, + base, + ssl.passfile, + )?; + MultiListener::SecureOnly(listener) + } + ProtocolVersion::V1 => { + let listener = SslListenerV1::new_pem_based_ssl_connection( + ssl.key, + ssl.chain, + base, + ssl.passfile, + )?; + MultiListener::SecureOnlyV1(listener) + } + }; + Ok(listener) } /// Create a new `Multi` listener that has both a secure and an insecure listener pub async fn new_multi( ssl_base_listener: BaseListener, tcp_base_listener: BaseListener, ssl: SslOpts, + protocol: ProtocolVersion, ) -> SkyResult { - let secure_listener = SslListener::new_pem_based_ssl_connection( - ssl.key, - ssl.chain, - ssl_base_listener, - ssl.passfile, - )?; - let insecure_listener = Listener::new(tcp_base_listener); - Ok(MultiListener::Multi(insecure_listener, secure_listener)) + let mls = match protocol { + ProtocolVersion::V2 => { + let secure_listener = SslListener::new_pem_based_ssl_connection( + ssl.key, + ssl.chain, + ssl_base_listener, + ssl.passfile, + )?; + let insecure_listener = Listener::new(tcp_base_listener); + MultiListener::Multi(insecure_listener, secure_listener) + } + ProtocolVersion::V1 => { + let secure_listener = SslListenerV1::new_pem_based_ssl_connection( + ssl.key, + ssl.chain, + ssl_base_listener, + ssl.passfile, + )?; + let insecure_listener = ListenerV1::new(tcp_base_listener); + MultiListener::MultiV1(insecure_listener, secure_listener) + } + }; + Ok(mls) } /// Start the server /// @@ -197,18 +260,14 @@ impl MultiListener { pub async fn run_server(&mut self) -> IoResult<()> { match self { MultiListener::SecureOnly(secure_listener) => secure_listener.run().await, + MultiListener::SecureOnlyV1(secure_listener) => secure_listener.run().await, MultiListener::InsecureOnly(insecure_listener) => insecure_listener.run().await, + MultiListener::InsecureOnlyV1(insecure_listener) => insecure_listener.run().await, MultiListener::Multi(insecure_listener, secure_listener) => { - let insec = insecure_listener.run(); - let sec = secure_listener.run(); - let (e1, e2) = tokio::join!(insec, sec); - if let Err(e) = e1 { - log::error!("Insecure listener failed with: {}", e); - } - if let Err(e) = e2 { - log::error!("Secure listener failed with: {}", e); - } - Ok(()) + wait_on_port_futures(insecure_listener.run(), secure_listener.run()).await + } + MultiListener::MultiV1(insecure_listener, secure_listener) => { + wait_on_port_futures(insecure_listener.run(), secure_listener.run()).await } } } @@ -218,12 +277,18 @@ impl MultiListener { /// make sure that the data is saved!** pub async fn finish_with_termsig(self) { match self { - MultiListener::InsecureOnly(server) => server.base.release_self().await, - MultiListener::SecureOnly(server) => server.base.release_self().await, + MultiListener::InsecureOnly(Listener { base, .. }) + | MultiListener::SecureOnly(SslListener { base, .. }) + | MultiListener::InsecureOnlyV1(ListenerV1 { base, .. }) + | MultiListener::SecureOnlyV1(SslListenerV1 { base, .. }) => base.release_self().await, MultiListener::Multi(insecure, secure) => { insecure.base.release_self().await; secure.base.release_self().await; } + MultiListener::MultiV1(insecure, secure) => { + insecure.base.release_self().await; + secure.base.release_self().await; + } } } } @@ -231,6 +296,7 @@ impl MultiListener { /// Initialize the database networking pub async fn connect( ports: PortConfig, + protocol: ProtocolVersion, maxcon: usize, db: Corestore, auth: AuthProvider, @@ -250,15 +316,17 @@ pub async fn connect( let description = ports.get_description(); let server = match ports { PortConfig::InsecureOnly { host, port } => { - MultiListener::new_insecure_only(base_listener_init(host, port).await?) - } - PortConfig::SecureOnly { host, ssl } => { - MultiListener::new_secure_only(base_listener_init(host, ssl.port).await?, ssl)? + MultiListener::new_insecure_only(base_listener_init(host, port).await?, protocol) } + PortConfig::SecureOnly { host, ssl } => MultiListener::new_secure_only( + base_listener_init(host, ssl.port).await?, + ssl, + protocol, + )?, PortConfig::Multi { host, port, ssl } => { let secure_listener = base_listener_init(host, ssl.port).await?; let insecure_listener = base_listener_init(host, port).await?; - MultiListener::new_multi(secure_listener, insecure_listener, ssl).await? + MultiListener::new_multi(secure_listener, insecure_listener, ssl, protocol).await? } }; log::info!("Server started on {description}"); diff --git a/server/src/dbnet/tcp.rs b/server/src/dbnet/tcp.rs index 64f7c9ad..0c51f5bc 100644 --- a/server/src/dbnet/tcp.rs +++ b/server/src/dbnet/tcp.rs @@ -24,25 +24,28 @@ * */ -use crate::protocol::{ - interface::{ProtocolRead, ProtocolSpec, ProtocolWrite}, - Skyhash2, -}; -use crate::{ - dbnet::{ - connection::{ConnectionHandler, ExecutorFn}, - BaseListener, Terminator, - }, - protocol, IoResult, -}; -use bytes::BytesMut; -use libsky::BUF_CAP; pub use protocol::{ParseResult, Query}; -use std::{cell::Cell, time::Duration}; -use tokio::{ - io::{AsyncWrite, BufWriter}, - net::TcpStream, - time, +use { + crate::{ + dbnet::{ + connection::{ConnectionHandler, ExecutorFn}, + BaseListener, Terminator, + }, + protocol::{ + self, + interface::{ProtocolRead, ProtocolSpec, ProtocolWrite}, + Skyhash1, Skyhash2, + }, + IoResult, + }, + bytes::BytesMut, + libsky::BUF_CAP, + std::{cell::Cell, time::Duration}, + tokio::{ + io::{AsyncWrite, BufWriter}, + net::TcpStream, + time, + }, }; pub trait BufferedSocketStream: AsyncWrite {} @@ -99,6 +102,7 @@ impl TcpBackoff { } pub type Listener = RawListener; +pub type ListenerV1 = RawListener; /// A listener pub struct RawListener

{ diff --git a/server/src/dbnet/tls.rs b/server/src/dbnet/tls.rs index d258c257..fce6925d 100644 --- a/server/src/dbnet/tls.rs +++ b/server/src/dbnet/tls.rs @@ -24,32 +24,35 @@ * */ -use crate::{ - dbnet::{ - connection::{ConnectionHandler, ExecutorFn}, - tcp::{BufferedSocketStream, Connection, TcpBackoff}, - BaseListener, Terminator, +use { + crate::{ + dbnet::{ + connection::{ConnectionHandler, ExecutorFn}, + tcp::{BufferedSocketStream, Connection, TcpBackoff}, + BaseListener, Terminator, + }, + protocol::{ + interface::{ProtocolRead, ProtocolSpec, ProtocolWrite}, + Skyhash1, Skyhash2, + }, + util::error::{Error, SkyResult}, + IoResult, }, - protocol::{ - interface::{ProtocolRead, ProtocolSpec, ProtocolWrite}, - Skyhash2, + openssl::{ + pkey::PKey, + rsa::Rsa, + ssl::{Ssl, SslAcceptor, SslFiletype, SslMethod}, }, - util::error::{Error, SkyResult}, - IoResult, + std::{fs, pin::Pin}, + tokio::net::TcpStream, + tokio_openssl::SslStream, }; -use openssl::{ - pkey::PKey, - rsa::Rsa, - ssl::{Ssl, SslAcceptor, SslFiletype, SslMethod}, -}; -use std::{fs, pin::Pin}; -use tokio::net::TcpStream; -use tokio_openssl::SslStream; impl BufferedSocketStream for SslStream {} type SslExecutorFn

= ExecutorFn>, SslStream>; pub type SslListener = SslListenerRaw; +pub type SslListenerV1 = SslListenerRaw; pub struct SslListenerRaw

{ pub base: BaseListener, diff --git a/server/src/protocol/interface.rs b/server/src/protocol/interface.rs index 48252ce9..ba657bce 100644 --- a/server/src/protocol/interface.rs +++ b/server/src/protocol/interface.rs @@ -105,10 +105,12 @@ pub trait ProtocolSpec { const RSTRING_LISTMAP_BAD_INDEX: &'static [u8]; const RSTRING_LISTMAP_LIST_IS_EMPTY: &'static [u8]; + // element responses + const ELEMRESP_HEYA: &'static [u8]; + // full responses const FULLRESP_RCODE_PACKET_ERR: &'static [u8]; const FULLRESP_RCODE_WRONG_TYPE: &'static [u8]; - const FULLRESP_HEYA: &'static [u8]; // LUTs const SET_NLUT: BytesNicheLUT = BytesNicheLUT::new( diff --git a/server/src/protocol/v1/interface_impls.rs b/server/src/protocol/v1/interface_impls.rs index f11a8a21..4d230466 100644 --- a/server/src/protocol/v1/interface_impls.rs +++ b/server/src/protocol/v1/interface_impls.rs @@ -140,10 +140,12 @@ impl ProtocolSpec for Skyhash1 { /// The list is empty const RSTRING_LISTMAP_LIST_IS_EMPTY: &'static [u8] = eresp!("list-is-empty"); + // elements + const ELEMRESP_HEYA: &'static [u8] = b"+4\nHEY!\n"; + // full responses const FULLRESP_RCODE_PACKET_ERR: &'static [u8] = b"*1\n!1\n4\n"; const FULLRESP_RCODE_WRONG_TYPE: &'static [u8] = b"*1\n!1\n7\n"; - const FULLRESP_HEYA: &'static [u8] = b"*1\n+4\nHEY!\n"; } impl ProtocolRead for T diff --git a/server/src/protocol/v2/interface_impls.rs b/server/src/protocol/v2/interface_impls.rs index 962543a5..a4325ec9 100644 --- a/server/src/protocol/v2/interface_impls.rs +++ b/server/src/protocol/v2/interface_impls.rs @@ -138,10 +138,12 @@ impl ProtocolSpec for Skyhash2 { /// The list is empty const RSTRING_LISTMAP_LIST_IS_EMPTY: &'static [u8] = eresp!("list-is-empty"); + // elements + const ELEMRESP_HEYA: &'static [u8] = b"+4\nHEY!"; + // full responses const FULLRESP_RCODE_PACKET_ERR: &'static [u8] = b"*!4\n"; const FULLRESP_RCODE_WRONG_TYPE: &'static [u8] = b"*!7\n"; - const FULLRESP_HEYA: &'static [u8] = b"+4\nHEY!"; } impl ProtocolRead for T From 231dd53341d6a9bf4c5ef7eb3773d5ea4cee8620 Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Tue, 3 May 2022 10:26:33 -0700 Subject: [PATCH 11/13] Reduce code redundancy by using `RawParser` and `RawParserExt` Also added changelog --- CHANGELOG.md | 12 +++ server/src/protocol/mod.rs | 2 + server/src/protocol/raw_parser.rs | 148 ++++++++++++++++++++++++++++++ server/src/protocol/v1/mod.rs | 124 +++++-------------------- server/src/protocol/v2/mod.rs | 144 ++++------------------------- server/src/protocol/v2/tests.rs | 5 +- 6 files changed, 206 insertions(+), 229 deletions(-) create mode 100644 server/src/protocol/raw_parser.rs diff --git a/CHANGELOG.md b/CHANGELOG.md index d0b48bfb..f9b2b130 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,18 @@ All changes in this project will be noted in this file. +## Version 0.8.0 + +### Additions + +- New protocol: Skyhash 2.0 + - Reduced bandwidth usage (as much as 50%) + - Even simpler client implementations +- Backward compatibility with Skyhash 1.0: + - Simply set the protocol version you want to use in the config file, env vars or pass it as a CLI + argument + - Even faster implementation, even for Skyhash 1.0 + ## Version 0.7.5 ### Additions diff --git a/server/src/protocol/mod.rs b/server/src/protocol/mod.rs index 100241f4..6c5e7fd1 100644 --- a/server/src/protocol/mod.rs +++ b/server/src/protocol/mod.rs @@ -33,6 +33,8 @@ use { // pub mods pub mod interface; pub mod iter; +// internal mods +mod raw_parser; // versions mod v1; mod v2; diff --git a/server/src/protocol/raw_parser.rs b/server/src/protocol/raw_parser.rs new file mode 100644 index 00000000..332614ec --- /dev/null +++ b/server/src/protocol/raw_parser.rs @@ -0,0 +1,148 @@ +/* + * Created on Tue May 03 2022 + * + * This file is a part of Skytable + * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source + * NoSQL database written by Sayan Nandan ("the Author") with the + * vision to provide flexibility in data modelling without compromising + * on performance, queryability or scalability. + * + * Copyright (c) 2022, Sayan Nandan + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + * +*/ + +use { + super::{ParseError, ParseResult, UnsafeSlice}, + core::mem::transmute, +}; + +/// The `RawParser` trait has three methods that implementors must define: +/// - `cursor_ptr` -> Should point to the current position in the buffer for the parser +/// - `cursor_ptr_mut` -> a mutable reference to the cursor +/// - `data_end_ptr` -> a ptr to one byte past the allocated area of the buffer +/// +/// # Safety +/// - `cursor_ptr` must point to a valid location in memory +/// - `data_end_ptr` must point to a valid location in memory, in the **same allocated area** +pub(super) unsafe trait RawParser { + fn cursor_ptr(&self) -> *const u8; + fn cursor_ptr_mut(&mut self) -> &mut *const u8; + fn data_end_ptr(&self) -> *const u8; + /// Check how many bytes we have left + fn remaining(&self) -> usize { + self.data_end_ptr() as usize - self.cursor_ptr() as usize + } + /// Check if we have `size` bytes remaining + fn has_remaining(&self, size: usize) -> bool { + self.remaining() >= size + } + /// Check if we have exhausted the buffer + fn exhausted(&self) -> bool { + self.cursor_ptr() >= self.data_end_ptr() + } + /// Check if the buffer is not exhausted + fn not_exhausted(&self) -> bool { + self.cursor_ptr() < self.data_end_ptr() + } + /// Attempts to return the byte pointed at by the cursor. + /// WARNING: The same segfault warning + unsafe fn get_byte_at_cursor(&self) -> u8 { + *self.cursor_ptr() + } + /// Increment the cursor by `by` positions + unsafe fn incr_cursor_by(&mut self, by: usize) { + let current = *self.cursor_ptr_mut(); + *self.cursor_ptr_mut() = current.add(by); + } + /// Increment the position of the cursor by one position + unsafe fn incr_cursor(&mut self) { + self.incr_cursor_by(1); + } +} + +pub(super) trait RawParserExt: RawParser { + /// Attempt to read `len` bytes + fn read_until(&mut self, len: usize) -> ParseResult { + if self.has_remaining(len) { + unsafe { + // UNSAFE(@ohsayan): Already verified lengths + let slice = UnsafeSlice::new(self.cursor_ptr(), len); + self.incr_cursor_by(len); + Ok(slice) + } + } else { + Err(ParseError::NotEnough) + } + } + #[cfg(test)] + /// Attempt to read a byte slice terminated by an LF + fn read_line(&mut self) -> ParseResult { + let start_ptr = self.cursor_ptr(); + unsafe { + while self.not_exhausted() && self.get_byte_at_cursor() != b'\n' { + self.incr_cursor(); + } + if self.not_exhausted() && self.get_byte_at_cursor() == b'\n' { + let len = self.cursor_ptr() as usize - start_ptr as usize; + self.incr_cursor(); // skip LF + Ok(UnsafeSlice::new(start_ptr, len)) + } else { + Err(ParseError::NotEnough) + } + } + } + /// Attempt to read a line, **rejecting an empty payload** + fn read_line_pedantic(&mut self) -> ParseResult { + let start_ptr = self.cursor_ptr(); + unsafe { + while self.not_exhausted() && self.get_byte_at_cursor() != b'\n' { + self.incr_cursor(); + } + let len = self.cursor_ptr() as usize - start_ptr as usize; + let has_lf = self.not_exhausted() && self.get_byte_at_cursor() == b'\n'; + if has_lf && len != 0 { + self.incr_cursor(); // skip LF + Ok(UnsafeSlice::new(start_ptr, len)) + } else { + // just some silly hackery + Err(transmute(has_lf)) + } + } + } + /// Attempt to read an `usize` from the buffer + fn read_usize(&mut self) -> ParseResult { + let line = self.read_line_pedantic()?; + let bytes = line.as_slice(); + let mut ret = 0usize; + for byte in bytes { + if byte.is_ascii_digit() { + ret = match ret.checked_mul(10) { + Some(r) => r, + None => return Err(ParseError::DatatypeParseFailure), + }; + ret = match ret.checked_add((byte & 0x0F) as _) { + Some(r) => r, + None => return Err(ParseError::DatatypeParseFailure), + }; + } else { + return Err(ParseError::DatatypeParseFailure); + } + } + Ok(ret) + } +} + +impl RawParserExt for T where T: RawParser {} diff --git a/server/src/protocol/v1/mod.rs b/server/src/protocol/v1/mod.rs index 23545309..c0319c74 100644 --- a/server/src/protocol/v1/mod.rs +++ b/server/src/protocol/v1/mod.rs @@ -24,12 +24,16 @@ * */ -use super::{ParseError, ParseResult, PipelinedQuery, Query, SimpleQuery, UnsafeSlice}; -use crate::{ - corestore::heap_array::{HeapArray, HeapArrayWriter}, - dbnet::connection::QueryWithAdvance, +use { + super::{ + raw_parser::{RawParser, RawParserExt}, + ParseError, ParseResult, PipelinedQuery, Query, SimpleQuery, UnsafeSlice, + }, + crate::{ + corestore::heap_array::{HeapArray, HeapArrayWriter}, + dbnet::connection::QueryWithAdvance, + }, }; -use core::mem::transmute; mod interface_impls; // test and bench modules @@ -56,6 +60,18 @@ pub struct Parser { cursor: *const u8, } +unsafe impl RawParser for Parser { + fn cursor_ptr(&self) -> *const u8 { + self.cursor + } + fn cursor_ptr_mut(&mut self) -> &mut *const u8 { + &mut self.cursor + } + fn data_end_ptr(&self) -> *const u8 { + self.end + } +} + unsafe impl Send for Parser {} unsafe impl Sync for Parser {} @@ -71,52 +87,6 @@ impl Parser { } } -// basic methods -impl Parser { - /// Returns a ptr one byte past the allocation of the buffer - const fn data_end_ptr(&self) -> *const u8 { - self.end - } - /// Returns the position of the cursor - /// WARNING: Deref might led to a segfault - const fn cursor_ptr(&self) -> *const u8 { - self.cursor - } - /// Check how many bytes we have left - fn remaining(&self) -> usize { - self.data_end_ptr() as usize - self.cursor_ptr() as usize - } - /// Check if we have `size` bytes remaining - fn has_remaining(&self, size: usize) -> bool { - self.remaining() >= size - } - /// Check if we have exhausted the buffer - fn exhausted(&self) -> bool { - self.cursor_ptr() >= self.data_end_ptr() - } - /// Check if the buffer is not exhausted - fn not_exhausted(&self) -> bool { - self.cursor_ptr() < self.data_end_ptr() - } - /// Attempts to return the byte pointed at by the cursor. - /// WARNING: The same segfault warning - const unsafe fn get_byte_at_cursor(&self) -> u8 { - *self.cursor_ptr() - } -} - -// mutable refs -impl Parser { - /// Increment the cursor by `by` positions - unsafe fn incr_cursor_by(&mut self, by: usize) { - self.cursor = self.cursor.add(by); - } - /// Increment the position of the cursor by one position - unsafe fn incr_cursor(&mut self) { - self.incr_cursor_by(1); - } -} - // utility methods impl Parser { /// Returns true if the cursor will give a char, but if `this_if_nothing_ahead` is set @@ -148,58 +118,6 @@ impl Parser { // higher level abstractions impl Parser { - /// Attempt to read `len` bytes - fn read_until(&mut self, len: usize) -> ParseResult { - if self.has_remaining(len) { - unsafe { - // UNSAFE(@ohsayan): Already verified lengths - let slice = UnsafeSlice::new(self.cursor_ptr(), len); - self.incr_cursor_by(len); - Ok(slice) - } - } else { - Err(ParseError::NotEnough) - } - } - /// Attempt to read a line, **rejecting an empty payload** - fn read_line_pedantic(&mut self) -> ParseResult { - let start_ptr = self.cursor_ptr(); - unsafe { - while self.not_exhausted() && self.get_byte_at_cursor() != b'\n' { - self.incr_cursor(); - } - let len = self.cursor_ptr() as usize - start_ptr as usize; - let has_lf = self.not_exhausted() && self.get_byte_at_cursor() == b'\n'; - if has_lf && len != 0 { - self.incr_cursor(); // skip LF - Ok(UnsafeSlice::new(start_ptr, len)) - } else { - // just some silly hackery - Err(transmute(has_lf)) - } - } - } - /// Attempt to read an `usize` from the buffer - fn read_usize(&mut self) -> ParseResult { - let line = self.read_line_pedantic()?; - let bytes = line.as_slice(); - let mut ret = 0usize; - for byte in bytes { - if byte.is_ascii_digit() { - ret = match ret.checked_mul(10) { - Some(r) => r, - None => return Err(ParseError::DatatypeParseFailure), - }; - ret = match ret.checked_add((byte & 0x0F) as _) { - Some(r) => r, - None => return Err(ParseError::DatatypeParseFailure), - }; - } else { - return Err(ParseError::DatatypeParseFailure); - } - } - Ok(ret) - } /// Parse the next blob. **The cursor should be at the tsymbol (passed)** fn parse_next_blob(&mut self) -> ParseResult { { diff --git a/server/src/protocol/v2/mod.rs b/server/src/protocol/v2/mod.rs index c57e47e7..066c12c2 100644 --- a/server/src/protocol/v2/mod.rs +++ b/server/src/protocol/v2/mod.rs @@ -26,12 +26,14 @@ mod interface_impls; -use crate::{ - corestore::heap_array::HeapArray, - dbnet::connection::QueryWithAdvance, - protocol::{ParseError, ParseResult, PipelinedQuery, Query, SimpleQuery, UnsafeSlice}, +use { + super::{ + raw_parser::{RawParser, RawParserExt}, + ParseError, ParseResult, PipelinedQuery, Query, SimpleQuery, UnsafeSlice, + }, + crate::{corestore::heap_array::HeapArray, dbnet::connection::QueryWithAdvance}, }; -use core::mem::transmute; + #[cfg(feature = "nightly")] mod benches; #[cfg(test)] @@ -43,6 +45,18 @@ pub struct Parser { cursor: *const u8, } +unsafe impl RawParser for Parser { + fn cursor_ptr(&self) -> *const u8 { + self.cursor + } + fn cursor_ptr_mut(&mut self) -> &mut *const u8 { + &mut self.cursor + } + fn data_end_ptr(&self) -> *const u8 { + self.end + } +} + unsafe impl Sync for Parser {} unsafe impl Send for Parser {} @@ -58,126 +72,6 @@ impl Parser { } } -// basic methods -impl Parser { - /// Returns a ptr one byte past the allocation of the buffer - const fn data_end_ptr(&self) -> *const u8 { - self.end - } - /// Returns the position of the cursor - /// WARNING: Deref might led to a segfault - const fn cursor_ptr(&self) -> *const u8 { - self.cursor - } - /// Check how many bytes we have left - fn remaining(&self) -> usize { - self.data_end_ptr() as usize - self.cursor_ptr() as usize - } - /// Check if we have `size` bytes remaining - fn has_remaining(&self, size: usize) -> bool { - self.remaining() >= size - } - #[cfg(test)] - /// Check if we have exhausted the buffer - fn exhausted(&self) -> bool { - self.cursor_ptr() >= self.data_end_ptr() - } - /// Check if the buffer is not exhausted - fn not_exhausted(&self) -> bool { - self.cursor_ptr() < self.data_end_ptr() - } - /// Attempts to return the byte pointed at by the cursor. - /// WARNING: The same segfault warning - const unsafe fn get_byte_at_cursor(&self) -> u8 { - *self.cursor_ptr() - } -} - -// mutable refs -impl Parser { - /// Increment the cursor by `by` positions - unsafe fn incr_cursor_by(&mut self, by: usize) { - self.cursor = self.cursor.add(by); - } - /// Increment the position of the cursor by one position - unsafe fn incr_cursor(&mut self) { - self.incr_cursor_by(1); - } -} - -// higher level abstractions -impl Parser { - /// Attempt to read `len` bytes - fn read_until(&mut self, len: usize) -> ParseResult { - if self.has_remaining(len) { - unsafe { - // UNSAFE(@ohsayan): Already verified lengths - let slice = UnsafeSlice::new(self.cursor_ptr(), len); - self.incr_cursor_by(len); - Ok(slice) - } - } else { - Err(ParseError::NotEnough) - } - } - #[cfg(test)] - /// Attempt to read a byte slice terminated by an LF - fn read_line(&mut self) -> ParseResult { - let start_ptr = self.cursor_ptr(); - unsafe { - while self.not_exhausted() && self.get_byte_at_cursor() != b'\n' { - self.incr_cursor(); - } - if self.not_exhausted() && self.get_byte_at_cursor() == b'\n' { - let len = self.cursor_ptr() as usize - start_ptr as usize; - self.incr_cursor(); // skip LF - Ok(UnsafeSlice::new(start_ptr, len)) - } else { - Err(ParseError::NotEnough) - } - } - } - /// Attempt to read a line, **rejecting an empty payload** - fn read_line_pedantic(&mut self) -> ParseResult { - let start_ptr = self.cursor_ptr(); - unsafe { - while self.not_exhausted() && self.get_byte_at_cursor() != b'\n' { - self.incr_cursor(); - } - let len = self.cursor_ptr() as usize - start_ptr as usize; - let has_lf = self.not_exhausted() && self.get_byte_at_cursor() == b'\n'; - if has_lf && len != 0 { - self.incr_cursor(); // skip LF - Ok(UnsafeSlice::new(start_ptr, len)) - } else { - // just some silly hackery - Err(transmute(has_lf)) - } - } - } - /// Attempt to read an `usize` from the buffer - fn read_usize(&mut self) -> ParseResult { - let line = self.read_line_pedantic()?; - let bytes = line.as_slice(); - let mut ret = 0usize; - for byte in bytes { - if byte.is_ascii_digit() { - ret = match ret.checked_mul(10) { - Some(r) => r, - None => return Err(ParseError::DatatypeParseFailure), - }; - ret = match ret.checked_add((byte & 0x0F) as _) { - Some(r) => r, - None => return Err(ParseError::DatatypeParseFailure), - }; - } else { - return Err(ParseError::DatatypeParseFailure); - } - } - Ok(ret) - } -} - // query impls impl Parser { /// Parse the next simple query. This should have passed the `*` tsymbol diff --git a/server/src/protocol/v2/tests.rs b/server/src/protocol/v2/tests.rs index 5352f96b..d60c9e38 100644 --- a/server/src/protocol/v2/tests.rs +++ b/server/src/protocol/v2/tests.rs @@ -24,7 +24,10 @@ * */ -use super::{Parser, PipelinedQuery, Query, SimpleQuery}; +use super::{ + super::raw_parser::{RawParser, RawParserExt}, + Parser, PipelinedQuery, Query, SimpleQuery, +}; use crate::protocol::{iter::AnyArrayIter, ParseError}; use std::iter::Map; use std::vec::IntoIter as VecIntoIter; From 89067c1fd57fcc07d08ff8ac7b33486a7e828230 Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Tue, 3 May 2022 10:58:41 -0700 Subject: [PATCH 12/13] Revise trait definitions --- server/src/protocol/interface.rs | 15 +++++++++++++++ server/src/protocol/raw_parser.rs | 30 ++++++++++++++++++++++++++++-- server/src/protocol/v1/mod.rs | 2 +- server/src/protocol/v2/mod.rs | 2 +- server/src/protocol/v2/tests.rs | 2 +- 5 files changed, 46 insertions(+), 5 deletions(-) diff --git a/server/src/protocol/interface.rs b/server/src/protocol/interface.rs index ba657bce..4ece4b8c 100644 --- a/server/src/protocol/interface.rs +++ b/server/src/protocol/interface.rs @@ -37,6 +37,21 @@ use crate::{ use std::io::{Error as IoError, ErrorKind}; use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter}; +/* +NOTE TO SELF (@ohsayan): Why do we split everything into separate traits? To avoid mistakes +in the future. We don't want any action to randomly call `read_query`, which was possible +with the earlier `ProtcolConnectionExt` trait, since it was imported by every action from +the prelude. +- `ProtocolSpec`: this is like a charset definition of the protocol along with some other +good stuff +- `ProtocolRead`: should only read from the stream and never write +- `ProtocolWrite`: should only write data and never read + +These distinctions reduce the likelihood of making mistakes while implementing the traits + +-- Sayan (May, 2022) +*/ + pub trait ProtocolSpec { // spec information diff --git a/server/src/protocol/raw_parser.rs b/server/src/protocol/raw_parser.rs index 332614ec..67606ae9 100644 --- a/server/src/protocol/raw_parser.rs +++ b/server/src/protocol/raw_parser.rs @@ -29,11 +29,26 @@ use { core::mem::transmute, }; +/* +NOTE TO SELF (@ohsayan): The reason we split this into three traits is because: +- `RawParser` is the only one that is to be implemented. Just provide information about the cursor +- `RawParserMeta` provides information about the buffer based on cursor and end ptr information +- `RawParserExt` provides high-level abstractions over `RawParserMeta`. It is like the "super trait" + +These distinctions reduce the likelihood of "accidentally incorrect impls" (we could've easily included +`RawParserMeta` inside `RawParser`). + +-- Sayan (May, 2022) +*/ + /// The `RawParser` trait has three methods that implementors must define: +/// /// - `cursor_ptr` -> Should point to the current position in the buffer for the parser /// - `cursor_ptr_mut` -> a mutable reference to the cursor /// - `data_end_ptr` -> a ptr to one byte past the allocated area of the buffer /// +/// All implementors of `RawParser` get a free implementation for `RawParserMeta` and `RawParserExt` +/// /// # Safety /// - `cursor_ptr` must point to a valid location in memory /// - `data_end_ptr` must point to a valid location in memory, in the **same allocated area** @@ -41,6 +56,12 @@ pub(super) unsafe trait RawParser { fn cursor_ptr(&self) -> *const u8; fn cursor_ptr_mut(&mut self) -> &mut *const u8; fn data_end_ptr(&self) -> *const u8; +} + +/// The `RawParserMeta` trait builds on top of the `RawParser` trait to provide low-level interactions +/// and information with the parser's buffer. It is implemented for any type that implements the `RawParser` +/// trait. Manual implementation is discouraged +pub(super) trait RawParserMeta: RawParser { /// Check how many bytes we have left fn remaining(&self) -> usize { self.data_end_ptr() as usize - self.cursor_ptr() as usize @@ -73,7 +94,12 @@ pub(super) unsafe trait RawParser { } } -pub(super) trait RawParserExt: RawParser { +impl RawParserMeta for T where T: RawParser {} + +/// `RawParserExt` builds on the `RawParser` and `RawParserMeta` traits to provide high level abstractions +/// like reading lines, or a slice of a given length. It is implemented for any type that +/// implements the `RawParser` trait. Manual implementation is discouraged +pub(super) trait RawParserExt: RawParser + RawParserMeta { /// Attempt to read `len` bytes fn read_until(&mut self, len: usize) -> ParseResult { if self.has_remaining(len) { @@ -145,4 +171,4 @@ pub(super) trait RawParserExt: RawParser { } } -impl RawParserExt for T where T: RawParser {} +impl RawParserExt for T where T: RawParser + RawParserMeta {} diff --git a/server/src/protocol/v1/mod.rs b/server/src/protocol/v1/mod.rs index c0319c74..fd0799e9 100644 --- a/server/src/protocol/v1/mod.rs +++ b/server/src/protocol/v1/mod.rs @@ -26,7 +26,7 @@ use { super::{ - raw_parser::{RawParser, RawParserExt}, + raw_parser::{RawParser, RawParserExt, RawParserMeta}, ParseError, ParseResult, PipelinedQuery, Query, SimpleQuery, UnsafeSlice, }, crate::{ diff --git a/server/src/protocol/v2/mod.rs b/server/src/protocol/v2/mod.rs index 066c12c2..7e7bb916 100644 --- a/server/src/protocol/v2/mod.rs +++ b/server/src/protocol/v2/mod.rs @@ -28,7 +28,7 @@ mod interface_impls; use { super::{ - raw_parser::{RawParser, RawParserExt}, + raw_parser::{RawParser, RawParserExt, RawParserMeta}, ParseError, ParseResult, PipelinedQuery, Query, SimpleQuery, UnsafeSlice, }, crate::{corestore::heap_array::HeapArray, dbnet::connection::QueryWithAdvance}, diff --git a/server/src/protocol/v2/tests.rs b/server/src/protocol/v2/tests.rs index d60c9e38..defe4a0c 100644 --- a/server/src/protocol/v2/tests.rs +++ b/server/src/protocol/v2/tests.rs @@ -25,7 +25,7 @@ */ use super::{ - super::raw_parser::{RawParser, RawParserExt}, + super::raw_parser::{RawParser, RawParserExt, RawParserMeta}, Parser, PipelinedQuery, Query, SimpleQuery, }; use crate::protocol::{iter::AnyArrayIter, ParseError}; From 31bdc83108aada97ae688524613d8957910f5af4 Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Tue, 3 May 2022 12:17:23 -0700 Subject: [PATCH 13/13] Make auth errors generic over protocol --- server/src/actions/mod.rs | 41 +++++--- server/src/auth/errors.rs | 70 ------------- server/src/auth/mod.rs | 26 ++--- server/src/auth/provider.rs | 115 ++++++++++++---------- server/src/auth/tests.rs | 55 ++++++----- server/src/dbnet/connection.rs | 5 +- server/src/protocol/interface.rs | 88 ++++++++++++++--- server/src/protocol/v1/interface_impls.rs | 72 ++++++-------- server/src/protocol/v2/interface_impls.rs | 64 +++++------- server/src/queryengine/mod.rs | 2 +- 10 files changed, 273 insertions(+), 265 deletions(-) delete mode 100644 server/src/auth/errors.rs diff --git a/server/src/actions/mod.rs b/server/src/actions/mod.rs index 346eb367..eb1daf31 100644 --- a/server/src/actions/mod.rs +++ b/server/src/actions/mod.rs @@ -65,6 +65,16 @@ pub enum ActionError { IoError(std::io::Error), } +impl PartialEq for ActionError { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Self::ActionError(a1), Self::ActionError(a2)) => a1 == a2, + (Self::IoError(ioe1), Self::IoError(ioe2)) => ioe1.to_string() == ioe2.to_string(), + _ => false, + } + } +} + impl From<&'static [u8]> for ActionError { fn from(e: &'static [u8]) -> Self { Self::ActionError(e) @@ -79,23 +89,26 @@ impl From for ActionError { #[cold] #[inline(never)] +fn map_ddl_error_to_status(e: DdlError) -> ActionError { + let r = match e { + DdlError::AlreadyExists => P::RSTRING_ALREADY_EXISTS, + DdlError::DdlTransactionFailure => P::RSTRING_DDL_TRANSACTIONAL_FAILURE, + DdlError::DefaultNotFound => P::RSTRING_DEFAULT_UNSET, + DdlError::NotEmpty => P::RSTRING_KEYSPACE_NOT_EMPTY, + DdlError::NotReady => P::RSTRING_NOT_READY, + DdlError::ObjectNotFound => P::RSTRING_CONTAINER_NOT_FOUND, + DdlError::ProtectedObject => P::RSTRING_PROTECTED_OBJECT, + DdlError::StillInUse => P::RSTRING_STILL_IN_USE, + DdlError::WrongModel => P::RSTRING_WRONG_MODEL, + }; + ActionError::ActionError(r) +} + +#[inline(always)] pub fn translate_ddl_error(r: Result) -> Result { match r { Ok(r) => Ok(r), - Err(e) => { - let err = match e { - DdlError::AlreadyExists => P::RSTRING_ALREADY_EXISTS, - DdlError::DdlTransactionFailure => P::RSTRING_DDL_TRANSACTIONAL_FAILURE, - DdlError::DefaultNotFound => P::RSTRING_DEFAULT_UNSET, - DdlError::NotEmpty => P::RSTRING_KEYSPACE_NOT_EMPTY, - DdlError::NotReady => P::RSTRING_NOT_READY, - DdlError::ObjectNotFound => P::RSTRING_CONTAINER_NOT_FOUND, - DdlError::ProtectedObject => P::RSTRING_PROTECTED_OBJECT, - DdlError::StillInUse => P::RSTRING_STILL_IN_USE, - DdlError::WrongModel => P::RSTRING_WRONG_MODEL, - }; - Err(ActionError::ActionError(err)) - } + Err(e) => Err(map_ddl_error_to_status::

(e)), } } diff --git a/server/src/auth/errors.rs b/server/src/auth/errors.rs deleted file mode 100644 index 012bfc99..00000000 --- a/server/src/auth/errors.rs +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Created on Sun Mar 06 2022 - * - * This file is a part of Skytable - * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source - * NoSQL database written by Sayan Nandan ("the Author") with the - * vision to provide flexibility in data modelling without compromising - * on performance, queryability or scalability. - * - * Copyright (c) 2022, Sayan Nandan - * - * This program is free software: you can redistribute it and/or modify - * it under the terms of the GNU Affero General Public License as published by - * the Free Software Foundation, either version 3 of the License, or - * (at your option) any later version. - * - * This program is distributed in the hope that it will be useful, - * but WITHOUT ANY WARRANTY; without even the implied warranty of - * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - * GNU Affero General Public License for more details. - * - * You should have received a copy of the GNU Affero General Public License - * along with this program. If not, see . - * -*/ - -use crate::actions::ActionError; - -/// Skyhash respstring: already claimed (user was already claimed) -pub const AUTH_ERROR_ALREADYCLAIMED: &[u8] = b"!err-auth-already-claimed\n"; -/// Skyhash respcode(10): bad credentials (either bad creds or invalid user) -pub const AUTH_CODE_BAD_CREDENTIALS: &[u8] = b"!10\n"; -/// Skyhash respstring: auth is disabled -pub const AUTH_ERROR_DISABLED: &[u8] = b"!err-auth-disabled\n"; -/// Skyhash respcode(11): Insufficient permissions (same for anonymous user) -pub const AUTH_CODE_PERMS: &[u8] = b"!11\n"; -/// Skyhash respstring: ID is too long -pub const AUTH_ERROR_ILLEGAL_USERNAME: &[u8] = b"!err-auth-illegal-username\n"; -/// Skyhash respstring: ID is protected/in use -pub const AUTH_ERROR_FAILED_TO_DELETE_USER: &[u8] = b"!err-auth-deluser-fail\n"; - -/// Auth erros -#[derive(PartialEq, Debug)] -pub enum AuthError { - /// The auth slot was already claimed - AlreadyClaimed, - /// Bad userid/tokens/keys - BadCredentials, - /// Auth is disabled - Disabled, - /// The action is not available to the current account - PermissionDenied, - /// The user is anonymous and doesn't have the right to execute this - Anonymous, - /// Some other error - Other(&'static [u8]), -} - -impl From for ActionError { - fn from(e: AuthError) -> Self { - let r = match e { - AuthError::AlreadyClaimed => AUTH_ERROR_ALREADYCLAIMED, - AuthError::Anonymous | AuthError::PermissionDenied => AUTH_CODE_PERMS, - AuthError::BadCredentials => AUTH_CODE_BAD_CREDENTIALS, - AuthError::Disabled => AUTH_ERROR_DISABLED, - AuthError::Other(e) => e, - }; - ActionError::ActionError(r) - } -} diff --git a/server/src/auth/mod.rs b/server/src/auth/mod.rs index 652b4e7f..db9d50ea 100644 --- a/server/src/auth/mod.rs +++ b/server/src/auth/mod.rs @@ -38,9 +38,7 @@ mod keys; pub mod provider; -pub use provider::{AuthProvider, AuthResult, Authmap}; -pub mod errors; -pub use errors::AuthError; +pub use provider::{AuthProvider, Authmap}; #[cfg(test)] mod tests; @@ -70,20 +68,20 @@ action! { AUTH_ADDUSER => { ensure_boolean_or_aerr::

(iter.len() == 1)?; // just the username let username = unsafe { iter.next_unchecked() }; - let key = auth.provider_mut().claim_user(username)?; + let key = auth.provider_mut().claim_user::

(username)?; con.write_string(&key).await?; Ok(()) } AUTH_LOGOUT => { ensure_boolean_or_aerr::

(iter.is_empty())?; // nothing else - auth.provider_mut().logout()?; + auth.provider_mut().logout::

()?; auth.swap_executor_to_anonymous(); con._write_raw(P::RCODE_OKAY).await?; Ok(()) } AUTH_DELUSER => { ensure_boolean_or_aerr::

(iter.len() == 1)?; // just the username - auth.provider_mut().delete_user(unsafe { iter.next_unchecked() })?; + auth.provider_mut().delete_user::

(unsafe { iter.next_unchecked() })?; con._write_raw(P::RCODE_OKAY).await?; Ok(()) } @@ -95,12 +93,12 @@ action! { } fn auth_whoami(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) { ensure_boolean_or_aerr::

(ActionIter::is_empty(iter))?; - con.write_string(&auth.provider().whoami()?).await?; + con.write_string(&auth.provider().whoami::

()?).await?; Ok(()) } fn auth_listuser(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) { ensure_boolean_or_aerr::

(ActionIter::is_empty(iter))?; - let usernames = auth.provider().collect_usernames()?; + let usernames = auth.provider().collect_usernames::

()?; con.write_typed_non_null_array_header(usernames.len(), b'+').await?; for username in usernames { con.write_typed_non_null_array_element(username.as_bytes()).await?; @@ -111,13 +109,15 @@ action! { let newkey = match iter.len() { 1 => { // so this fella thinks they're root - auth.provider().regenerate(unsafe {iter.next_unchecked()})? + auth.provider().regenerate::

( + unsafe { iter.next_unchecked() } + )? } 2 => { // so this fella is giving us the origin key let origin = unsafe { iter.next_unchecked() }; let id = unsafe { iter.next_unchecked() }; - auth.provider().regenerate_using_origin(origin, id)? + auth.provider().regenerate_using_origin::

(origin, id)? } _ => return util::err(P::RCODE_ACTION_ERR), }; @@ -127,7 +127,7 @@ action! { fn _auth_claim(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) { ensure_boolean_or_aerr::

(iter.len() == 1)?; // just the origin key let origin_key = unsafe { iter.next_unchecked() }; - let key = auth.provider_mut().claim_root(origin_key)?; + let key = auth.provider_mut().claim_root::

(origin_key)?; auth.swap_executor_to_authenticated(); con.write_string(&key).await?; Ok(()) @@ -144,14 +144,14 @@ action! { AUTH_CLAIM => self::_auth_claim(con, auth, &mut iter).await, AUTH_RESTORE => self::auth_restore(con, auth, &mut iter).await, AUTH_WHOAMI => self::auth_whoami(con, auth, &mut iter).await, - _ => util::err(errors::AUTH_CODE_PERMS), + _ => util::err(P::AUTH_CODE_PERMS), } } fn _auth_login(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) { // sweet, where's our username and password ensure_boolean_or_aerr::

(iter.len() == 2)?; // just the uname and pass let (username, password) = unsafe { (iter.next_unchecked(), iter.next_unchecked()) }; - auth.provider_mut().login(username, password)?; + auth.provider_mut().login::

(username, password)?; auth.swap_executor_to_authenticated(); con._write_raw(P::RCODE_OKAY).await?; Ok(()) diff --git a/server/src/auth/provider.rs b/server/src/auth/provider.rs index e2a9600f..35ac27c3 100644 --- a/server/src/auth/provider.rs +++ b/server/src/auth/provider.rs @@ -24,9 +24,12 @@ * */ -use super::{errors, keys, AuthError}; +use super::keys; +use crate::actions::{ActionError, ActionResult}; use crate::corestore::array::Array; use crate::corestore::htable::Coremap; +use crate::protocol::interface::ProtocolSpec; +use crate::util::err; use std::sync::Arc; // constants @@ -54,8 +57,6 @@ const USER_ROOT: AuthID = unsafe { AuthID::from_const(USER_ROOT_ARRAY, 4) }; type AuthID = Array; /// An authn key pub type Authkey = [u8; AUTHKEY_SIZE]; -/// Result of an auth operation -pub type AuthResult = Result; /// Authmap pub type Authmap = Arc>; @@ -119,8 +120,8 @@ impl AuthProvider { pub const fn is_enabled(&self) -> bool { matches!(self.origin, Some(_)) } - pub fn claim_root(&mut self, origin_key: &[u8]) -> AuthResult { - self.verify_origin(origin_key)?; + pub fn claim_root(&mut self, origin_key: &[u8]) -> ActionResult { + self.verify_origin::

(origin_key)?; // the origin key was good, let's try claiming root let (key, store) = keys::generate_full(); if self.authmap.true_if_insert(USER_ROOT, store) { @@ -128,33 +129,33 @@ impl AuthProvider { self.whoami = Some(USER_ROOT); Ok(key) } else { - Err(AuthError::AlreadyClaimed) + err(P::AUTH_ERROR_ALREADYCLAIMED) } } - fn are_you_root(&self) -> AuthResult { - self.ensure_enabled()?; + fn are_you_root(&self) -> ActionResult { + self.ensure_enabled::

()?; match self.whoami.as_ref().map(|v| v.eq(&USER_ROOT)) { Some(v) => Ok(v), - None => Err(AuthError::Anonymous), + None => err(P::AUTH_CODE_PERMS), } } - pub fn claim_user(&self, claimant: &[u8]) -> AuthResult { - self.ensure_root()?; - self._claim_user(claimant) + pub fn claim_user(&self, claimant: &[u8]) -> ActionResult { + self.ensure_root::

()?; + self._claim_user::

(claimant) } - pub fn _claim_user(&self, claimant: &[u8]) -> AuthResult { + pub fn _claim_user(&self, claimant: &[u8]) -> ActionResult { let (key, store) = keys::generate_full(); if self .authmap - .true_if_insert(Self::try_auth_id(claimant)?, store) + .true_if_insert(Self::try_auth_id::

(claimant)?, store) { Ok(key) } else { - Err(AuthError::AlreadyClaimed) + err(P::AUTH_ERROR_ALREADYCLAIMED) } } - pub fn login(&mut self, account: &[u8], token: &[u8]) -> AuthResult<()> { - self.ensure_enabled()?; + pub fn login(&mut self, account: &[u8], token: &[u8]) -> ActionResult<()> { + self.ensure_enabled::

()?; match self .authmap .get(account) @@ -162,84 +163,94 @@ impl AuthProvider { { Some(Some(true)) => { // great, authenticated - self.whoami = Some(Self::try_auth_id(account)?); + self.whoami = Some(Self::try_auth_id::

(account)?); Ok(()) } _ => { // either the password was wrong, or the username was wrong - Err(AuthError::BadCredentials) + err(P::AUTH_CODE_BAD_CREDENTIALS) } } } - pub fn regenerate_using_origin(&self, origin: &[u8], account: &[u8]) -> AuthResult { - self.verify_origin(origin)?; - self._regenerate(account) + pub fn regenerate_using_origin( + &self, + origin: &[u8], + account: &[u8], + ) -> ActionResult { + self.verify_origin::

(origin)?; + self._regenerate::

(account) } - pub fn regenerate(&self, account: &[u8]) -> AuthResult { - self.ensure_root()?; - self._regenerate(account) + pub fn regenerate(&self, account: &[u8]) -> ActionResult { + self.ensure_root::

()?; + self._regenerate::

(account) } /// Regenerate the token for the given user. This returns a new token - fn _regenerate(&self, account: &[u8]) -> AuthResult { - let id = Self::try_auth_id(account)?; + fn _regenerate(&self, account: &[u8]) -> ActionResult { + let id = Self::try_auth_id::

(account)?; let (key, store) = keys::generate_full(); if self.authmap.true_if_update(id, store) { Ok(key) } else { - Err(AuthError::BadCredentials) + err(P::AUTH_CODE_BAD_CREDENTIALS) } } - fn try_auth_id(authid: &[u8]) -> AuthResult { + fn try_auth_id(authid: &[u8]) -> ActionResult { if authid.is_ascii() && authid.len() <= AUTHID_SIZE { Ok(unsafe { // We just verified the length AuthID::from_slice(authid) }) } else { - Err(AuthError::Other(errors::AUTH_ERROR_ILLEGAL_USERNAME)) + err(P::AUTH_ERROR_ILLEGAL_USERNAME) } } - pub fn logout(&mut self) -> AuthResult<()> { - self.ensure_enabled()?; - self.whoami.take().map(|_| ()).ok_or(AuthError::Anonymous) + pub fn logout(&mut self) -> ActionResult<()> { + self.ensure_enabled::

()?; + self.whoami + .take() + .map(|_| ()) + .ok_or(ActionError::ActionError(P::AUTH_CODE_PERMS)) } - fn ensure_enabled(&self) -> AuthResult<()> { - self.origin.as_ref().map(|_| ()).ok_or(AuthError::Disabled) + fn ensure_enabled(&self) -> ActionResult<()> { + self.origin + .as_ref() + .map(|_| ()) + .ok_or(ActionError::ActionError(P::AUTH_ERROR_DISABLED)) } - pub fn verify_origin(&self, origin: &[u8]) -> AuthResult<()> { - if self.get_origin()?.eq(origin) { + pub fn verify_origin(&self, origin: &[u8]) -> ActionResult<()> { + if self.get_origin::

()?.eq(origin) { Ok(()) } else { - Err(AuthError::BadCredentials) + err(P::AUTH_CODE_BAD_CREDENTIALS) } } - fn get_origin(&self) -> AuthResult<&Authkey> { + fn get_origin(&self) -> ActionResult<&Authkey> { match self.origin.as_ref() { Some(key) => Ok(key), - None => Err(AuthError::Disabled), + None => err(P::AUTH_ERROR_DISABLED), } } - fn ensure_root(&self) -> AuthResult<()> { - if self.are_you_root()? { + fn ensure_root(&self) -> ActionResult<()> { + if self.are_you_root::

()? { Ok(()) } else { - Err(AuthError::PermissionDenied) + err(P::AUTH_CODE_PERMS) } } - pub fn delete_user(&self, user: &[u8]) -> AuthResult<()> { - self.ensure_root()?; + pub fn delete_user(&self, user: &[u8]) -> ActionResult<()> { + self.ensure_root::

()?; if user.eq(&USER_ROOT) { // can't delete root! - Err(AuthError::Other(errors::AUTH_ERROR_FAILED_TO_DELETE_USER)) + err(P::AUTH_ERROR_FAILED_TO_DELETE_USER) } else if self.authmap.true_if_removed(user) { Ok(()) } else { - Err(AuthError::BadCredentials) + err(P::AUTH_CODE_BAD_CREDENTIALS) } } /// List all the users - pub fn collect_usernames(&self) -> AuthResult> { - self.ensure_root()?; + pub fn collect_usernames(&self) -> ActionResult> { + self.ensure_root::

()?; Ok(self .authmap .iter() @@ -247,12 +258,12 @@ impl AuthProvider { .collect()) } /// Return the AuthID of the current user - pub fn whoami(&self) -> AuthResult { - self.ensure_enabled()?; + pub fn whoami(&self) -> ActionResult { + self.ensure_enabled::

()?; self.whoami .as_ref() .map(|v| String::from_utf8_lossy(v).to_string()) - .ok_or(AuthError::Anonymous) + .ok_or(ActionError::ActionError(P::AUTH_CODE_PERMS)) } } diff --git a/server/src/auth/tests.rs b/server/src/auth/tests.rs index c083130a..a067eefe 100644 --- a/server/src/auth/tests.rs +++ b/server/src/auth/tests.rs @@ -35,77 +35,88 @@ mod keys { } mod authn { - use crate::auth::{AuthError, AuthProvider}; + use crate::actions::ActionError; + use crate::auth::AuthProvider; + use crate::protocol::{interface::ProtocolSpec, Skyhash2}; const ORIG: &[u8; 40] = b"c4299d190fb9a00626797fcc138c56eae9971664"; #[test] fn claim_root_okay() { let mut provider = AuthProvider::new_blank(Some(*ORIG)); - let _ = provider.claim_root(ORIG).unwrap(); + let _ = provider.claim_root::(ORIG).unwrap(); } #[test] fn claim_root_wrongkey() { let mut provider = AuthProvider::new_blank(Some(*ORIG)); - let claim_err = provider.claim_root(&ORIG[1..]).unwrap_err(); - assert_eq!(claim_err, AuthError::BadCredentials); + let claim_err = provider.claim_root::(&ORIG[1..]).unwrap_err(); + assert_eq!( + claim_err, + ActionError::ActionError(Skyhash2::AUTH_CODE_BAD_CREDENTIALS) + ); } #[test] fn claim_root_disabled() { let mut provider = AuthProvider::new_disabled(); assert_eq!( - provider.claim_root(b"abcd").unwrap_err(), - AuthError::Disabled + provider.claim_root::(b"abcd").unwrap_err(), + ActionError::ActionError(Skyhash2::AUTH_ERROR_DISABLED) ); } #[test] fn claim_root_already_claimed() { let mut provider = AuthProvider::new_blank(Some(*ORIG)); - let _ = provider.claim_root(ORIG).unwrap(); + let _ = provider.claim_root::(ORIG).unwrap(); assert_eq!( - provider.claim_root(ORIG).unwrap_err(), - AuthError::AlreadyClaimed + provider.claim_root::(ORIG).unwrap_err(), + ActionError::ActionError(Skyhash2::AUTH_ERROR_ALREADYCLAIMED) ); } #[test] fn claim_user_okay_with_login() { let mut provider = AuthProvider::new_blank(Some(*ORIG)); // claim root - let rootkey = provider.claim_root(ORIG).unwrap(); + let rootkey = provider.claim_root::(ORIG).unwrap(); // login as root - provider.login(b"root", rootkey.as_bytes()).unwrap(); + provider + .login::(b"root", rootkey.as_bytes()) + .unwrap(); // claim user - let _ = provider.claim_user(b"sayan").unwrap(); + let _ = provider.claim_user::(b"sayan").unwrap(); } #[test] fn claim_user_fail_not_root_with_login() { let mut provider = AuthProvider::new_blank(Some(*ORIG)); // claim root - let rootkey = provider.claim_root(ORIG).unwrap(); + let rootkey = provider.claim_root::(ORIG).unwrap(); // login as root - provider.login(b"root", rootkey.as_bytes()).unwrap(); + provider + .login::(b"root", rootkey.as_bytes()) + .unwrap(); // claim user - let userkey = provider.claim_user(b"user").unwrap(); + let userkey = provider.claim_user::(b"user").unwrap(); // login as user - provider.login(b"user", userkey.as_bytes()).unwrap(); + provider + .login::(b"user", userkey.as_bytes()) + .unwrap(); // now try to claim an user being a non-root account assert_eq!( - provider.claim_user(b"otheruser").unwrap_err(), - AuthError::PermissionDenied + provider.claim_user::(b"otheruser").unwrap_err(), + ActionError::ActionError(Skyhash2::AUTH_CODE_PERMS) ); } #[test] fn claim_user_fail_anonymous() { let mut provider = AuthProvider::new_blank(Some(*ORIG)); // claim root - let _ = provider.claim_root(ORIG).unwrap(); + let _ = provider.claim_root::(ORIG).unwrap(); // logout - provider.logout().unwrap(); + provider.logout::().unwrap(); // try to claim as an anonymous user assert_eq!( - provider.claim_user(b"newuser").unwrap_err(), - AuthError::Anonymous + provider.claim_user::(b"newuser").unwrap_err(), + ActionError::ActionError(Skyhash2::AUTH_CODE_PERMS) ); } } diff --git a/server/src/dbnet/connection.rs b/server/src/dbnet/connection.rs index eac2f7d6..8eda6092 100644 --- a/server/src/dbnet/connection.rs +++ b/server/src/dbnet/connection.rs @@ -37,7 +37,7 @@ use crate::{ actions::{ActionError, ActionResult}, - auth::{self, AuthProvider}, + auth::AuthProvider, corestore::Corestore, dbnet::{ connection::prelude::FutureResult, @@ -312,8 +312,7 @@ where } Query::Pipelined(_) => { con.write_simple_query_header().await?; - con._write_raw(auth::errors::AUTH_CODE_BAD_CREDENTIALS) - .await?; + con._write_raw(P::AUTH_CODE_BAD_CREDENTIALS).await?; } } Ok(()) diff --git a/server/src/protocol/interface.rs b/server/src/protocol/interface.rs index 4ece4b8c..b1e95a34 100644 --- a/server/src/protocol/interface.rs +++ b/server/src/protocol/interface.rs @@ -52,7 +52,10 @@ These distinctions reduce the likelihood of making mistakes while implementing t -- Sayan (May, 2022) */ -pub trait ProtocolSpec { +/// The `ProtocolSpec` trait is used to define the character set and pre-generated elements +/// and responses for a protocol version. To make any actual use of it, you need to implement +/// both the `ProtocolRead` and `ProtocolWrite` for the protocol +pub trait ProtocolSpec: Send + Sync { // spec information /// The Skyhash protocol version @@ -61,91 +64,160 @@ pub trait ProtocolSpec { const PROTOCOL_VERSIONSTRING: &'static str; // type symbols + /// Type symbol for unicode strings const TSYMBOL_STRING: u8; + /// Type symbol for blobs const TSYMBOL_BINARY: u8; + /// Type symbol for float const TSYMBOL_FLOAT: u8; + /// Type symbok for int64 const TSYMBOL_INT64: u8; + /// Type symbol for typed array const TSYMBOL_TYPED_ARRAY: u8; + /// Type symbol for typed non-null array const TSYMBOL_TYPED_NON_NULL_ARRAY: u8; + /// Type symbol for an array const TSYMBOL_ARRAY: u8; + /// Type symbol for a flat array const TSYMBOL_FLAT_ARRAY: u8; // charset + /// The line-feed character or separator const LF: u8 = b'\n'; // metaframe + /// The header for simple queries const SIMPLE_QUERY_HEADER: &'static [u8]; + /// The header for pipelined queries (excluding length, obviously) const PIPELINED_QUERY_FIRST_BYTE: u8; // typed array + /// Null element represenation for a typed array const TYPE_TYPED_ARRAY_ELEMENT_NULL: &'static [u8]; // respcodes + /// Respcode 0: Okay const RCODE_OKAY: &'static [u8]; + /// Respcode 1: Nil const RCODE_NIL: &'static [u8]; + /// Respcode 2: Overwrite error const RCODE_OVERWRITE_ERR: &'static [u8]; + /// Respcode 3: Action error const RCODE_ACTION_ERR: &'static [u8]; + /// Respcode 4: Packet error const RCODE_PACKET_ERR: &'static [u8]; + /// Respcode 5: Server error const RCODE_SERVER_ERR: &'static [u8]; + /// Respcode 6: Other error const RCODE_OTHER_ERR_EMPTY: &'static [u8]; + /// Respcode 7: Unknown action const RCODE_UNKNOWN_ACTION: &'static [u8]; + /// Respcode 8: Wrongtype error const RCODE_WRONGTYPE_ERR: &'static [u8]; + /// Respcode 9: Unknown data type error const RCODE_UNKNOWN_DATA_TYPE: &'static [u8]; + /// Respcode 10: Encoding error const RCODE_ENCODING_ERROR: &'static [u8]; // respstrings + /// Respstring when snapshot engine is busy const RSTRING_SNAPSHOT_BUSY: &'static [u8]; + /// Respstring when snapshots are disabled const RSTRING_SNAPSHOT_DISABLED: &'static [u8]; + /// Respstring when duplicate snapshot creation is attempted const RSTRING_SNAPSHOT_DUPLICATE: &'static [u8]; + /// Respstring when snapshot has illegal chars const RSTRING_SNAPSHOT_ILLEGAL_NAME: &'static [u8]; + /// Respstring when a **very bad error** happens (use after termsig) const RSTRING_ERR_ACCESS_AFTER_TERMSIG: &'static [u8]; + /// Respstring when the default container is unset const RSTRING_DEFAULT_UNSET: &'static [u8]; + /// Respstring when the container is not found const RSTRING_CONTAINER_NOT_FOUND: &'static [u8]; + /// Respstring when the container is still in use, but a _free_ op is attempted const RSTRING_STILL_IN_USE: &'static [u8]; + /// Respstring when a protected container is attempted to be accessed/modified const RSTRING_PROTECTED_OBJECT: &'static [u8]; + /// Respstring when an action is not suitable for the current table model const RSTRING_WRONG_MODEL: &'static [u8]; + /// Respstring when the container already exists const RSTRING_ALREADY_EXISTS: &'static [u8]; + /// Respstring when the container is not ready const RSTRING_NOT_READY: &'static [u8]; + /// Respstring when a DDL transaction fails const RSTRING_DDL_TRANSACTIONAL_FAILURE: &'static [u8]; + /// Respstring when an unknow DDL query is run (`CREATE BLAH`, for example) const RSTRING_UNKNOWN_DDL_QUERY: &'static [u8]; + /// Respstring when a bad DDL expression is run const RSTRING_BAD_EXPRESSION: &'static [u8]; + /// Respstring when an unsupported model is attempted to be used during table creation const RSTRING_UNKNOWN_MODEL: &'static [u8]; + /// Respstring when too many arguments are passed to a DDL query const RSTRING_TOO_MANY_ARGUMENTS: &'static [u8]; + /// Respstring when the container name is too long const RSTRING_CONTAINER_NAME_TOO_LONG: &'static [u8]; + /// Respstring when the container name const RSTRING_BAD_CONTAINER_NAME: &'static [u8]; + /// Respstring when an unknown inspect query is run (`INSPECT blah`, for example) const RSTRING_UNKNOWN_INSPECT_QUERY: &'static [u8]; + /// Respstring when an unknown table property is passed during table creation const RSTRING_UNKNOWN_PROPERTY: &'static [u8]; + /// Respstring when a non-empty keyspace is attempted to be dropped const RSTRING_KEYSPACE_NOT_EMPTY: &'static [u8]; + /// Respstring when a bad type is provided for a key in the K/V engine (like using a `list` + /// for the key) const RSTRING_BAD_TYPE_FOR_KEY: &'static [u8]; + /// Respstring when a non-existent index is attempted to be accessed in a list const RSTRING_LISTMAP_BAD_INDEX: &'static [u8]; + /// Respstring when a list is empty and we attempt to access/modify it const RSTRING_LISTMAP_LIST_IS_EMPTY: &'static [u8]; // element responses + /// A string element containing the text "HEY!" const ELEMRESP_HEYA: &'static [u8]; // full responses + /// A **full response** for a packet error const FULLRESP_RCODE_PACKET_ERR: &'static [u8]; + /// A **full response** for a wrongtype error const FULLRESP_RCODE_WRONG_TYPE: &'static [u8]; // LUTs + /// A LUT for SET operations const SET_NLUT: BytesNicheLUT = BytesNicheLUT::new( Self::RCODE_ENCODING_ERROR, Self::RCODE_OKAY, Self::RCODE_OVERWRITE_ERR, ); + /// A LUT for lists const OKAY_BADIDX_NIL_NLUT: BytesNicheLUT = BytesNicheLUT::new( Self::RCODE_NIL, Self::RCODE_OKAY, Self::RSTRING_LISTMAP_BAD_INDEX, ); + /// A LUT for SET operations const OKAY_OVW_BLUT: BytesBoolTable = BytesBoolTable::new(Self::RCODE_OKAY, Self::RCODE_OVERWRITE_ERR); - + /// A LUT for UPDATE operations const UPDATE_NLUT: BytesNicheLUT = BytesNicheLUT::new( Self::RCODE_ENCODING_ERROR, Self::RCODE_OKAY, Self::RCODE_NIL, ); + + // auth error respstrings + /// respstring: already claimed (user was already claimed) + const AUTH_ERROR_ALREADYCLAIMED: &'static [u8]; + /// respcode(10): bad credentials (either bad creds or invalid user) + const AUTH_CODE_BAD_CREDENTIALS: &'static [u8]; + /// respstring: auth is disabled + const AUTH_ERROR_DISABLED: &'static [u8]; + /// respcode(11): Insufficient permissions (same for anonymous user) + const AUTH_CODE_PERMS: &'static [u8]; + /// respstring: ID is too long + const AUTH_ERROR_ILLEGAL_USERNAME: &'static [u8]; + /// respstring: ID is protected/in use + const AUTH_ERROR_FAILED_TO_DELETE_USER: &'static [u8]; } /// # The `ProtocolRead` trait @@ -298,17 +370,7 @@ where where 'life0: 'ret_life, 'life1: 'ret_life, - Self: Send + 'ret_life, - { - Box::pin(async move { - let stream = self.get_mut_stream(); - // - stream.write_all(&[tsymbol]).await?; - stream.write_all(&Integer64::from(data.len())).await?; - stream.write_all(&[P::LF]).await?; - stream.write_all(data).await - }) - } + Self: Send + 'ret_life; /// serialize and write an `&str` to the stream fn write_string<'life0, 'life1, 'ret_life>( &'life0 mut self, diff --git a/server/src/protocol/v1/interface_impls.rs b/server/src/protocol/v1/interface_impls.rs index 4d230466..17bc6626 100644 --- a/server/src/protocol/v1/interface_impls.rs +++ b/server/src/protocol/v1/interface_impls.rs @@ -55,89 +55,52 @@ impl ProtocolSpec for Skyhash1 { const TSYMBOL_FLAT_ARRAY: u8 = b'_'; // typed array - const TYPE_TYPED_ARRAY_ELEMENT_NULL: &'static [u8] = b"\0\n"; + const TYPE_TYPED_ARRAY_ELEMENT_NULL: &'static [u8] = b"\0"; // metaframe - const SIMPLE_QUERY_HEADER: &'static [u8] = b"*1\n"; - const PIPELINED_QUERY_FIRST_BYTE: u8 = b'*'; + const SIMPLE_QUERY_HEADER: &'static [u8] = b"*"; + const PIPELINED_QUERY_FIRST_BYTE: u8 = b'$'; // respcodes - /// Response code 0 as a array element const RCODE_OKAY: &'static [u8] = eresp!("0"); - /// Response code 1 as a array element const RCODE_NIL: &'static [u8] = eresp!("1"); - /// Response code 2 as a array element const RCODE_OVERWRITE_ERR: &'static [u8] = eresp!("2"); - /// Response code 3 as a array element const RCODE_ACTION_ERR: &'static [u8] = eresp!("3"); - /// Response code 4 as a array element const RCODE_PACKET_ERR: &'static [u8] = eresp!("4"); - /// Response code 5 as a array element const RCODE_SERVER_ERR: &'static [u8] = eresp!("5"); - /// Response code 6 as a array element const RCODE_OTHER_ERR_EMPTY: &'static [u8] = eresp!("6"); - /// "Unknown action" error response const RCODE_UNKNOWN_ACTION: &'static [u8] = eresp!("Unknown action"); - /// Response code 7 const RCODE_WRONGTYPE_ERR: &'static [u8] = eresp!("7"); - /// Response code 8 const RCODE_UNKNOWN_DATA_TYPE: &'static [u8] = eresp!("8"); - /// Response code 9 as an array element const RCODE_ENCODING_ERROR: &'static [u8] = eresp!("9"); // respstrings - - /// Snapshot busy error const RSTRING_SNAPSHOT_BUSY: &'static [u8] = eresp!("err-snapshot-busy"); - /// Snapshot disabled (other error) const RSTRING_SNAPSHOT_DISABLED: &'static [u8] = eresp!("err-snapshot-disabled"); - /// Duplicate snapshot const RSTRING_SNAPSHOT_DUPLICATE: &'static [u8] = eresp!("duplicate-snapshot"); - /// Snapshot has illegal name (other error) const RSTRING_SNAPSHOT_ILLEGAL_NAME: &'static [u8] = eresp!("err-invalid-snapshot-name"); - /// Access after termination signal (other error) const RSTRING_ERR_ACCESS_AFTER_TERMSIG: &'static [u8] = eresp!("err-access-after-termsig"); // keyspace related resps - /// The default container was not set const RSTRING_DEFAULT_UNSET: &'static [u8] = eresp!("default-container-unset"); - /// The container was not found const RSTRING_CONTAINER_NOT_FOUND: &'static [u8] = eresp!("container-not-found"); - /// The container is still in use and so cannot be removed const RSTRING_STILL_IN_USE: &'static [u8] = eresp!("still-in-use"); - /// This is a protected object and hence cannot be accessed const RSTRING_PROTECTED_OBJECT: &'static [u8] = eresp!("err-protected-object"); - /// The action was applied against the wrong model const RSTRING_WRONG_MODEL: &'static [u8] = eresp!("wrong-model"); - /// The container already exists const RSTRING_ALREADY_EXISTS: &'static [u8] = eresp!("err-already-exists"); - /// The container is not ready const RSTRING_NOT_READY: &'static [u8] = eresp!("not-ready"); - /// A transactional failure occurred const RSTRING_DDL_TRANSACTIONAL_FAILURE: &'static [u8] = eresp!("transactional-failure"); - /// An unknown DDL query was run const RSTRING_UNKNOWN_DDL_QUERY: &'static [u8] = eresp!("unknown-ddl-query"); - /// The expression for a DDL query was malformed const RSTRING_BAD_EXPRESSION: &'static [u8] = eresp!("malformed-expression"); - /// An unknown model was passed in a DDL query const RSTRING_UNKNOWN_MODEL: &'static [u8] = eresp!("unknown-model"); - /// Too many arguments were passed to model constructor const RSTRING_TOO_MANY_ARGUMENTS: &'static [u8] = eresp!("too-many-args"); - /// The container name is too long const RSTRING_CONTAINER_NAME_TOO_LONG: &'static [u8] = eresp!("container-name-too-long"); - /// The container name contains invalid characters const RSTRING_BAD_CONTAINER_NAME: &'static [u8] = eresp!("bad-container-name"); - /// An unknown inspect query const RSTRING_UNKNOWN_INSPECT_QUERY: &'static [u8] = eresp!("unknown-inspect-query"); - /// An unknown table property was passed const RSTRING_UNKNOWN_PROPERTY: &'static [u8] = eresp!("unknown-property"); - /// The keyspace is not empty and hence cannot be removed const RSTRING_KEYSPACE_NOT_EMPTY: &'static [u8] = eresp!("keyspace-not-empty"); - /// Bad type supplied in a DDL query for the key const RSTRING_BAD_TYPE_FOR_KEY: &'static [u8] = eresp!("bad-type-for-key"); - /// The index for the provided list was non-existent const RSTRING_LISTMAP_BAD_INDEX: &'static [u8] = eresp!("bad-list-index"); - /// The list is empty const RSTRING_LISTMAP_LIST_IS_EMPTY: &'static [u8] = eresp!("list-is-empty"); // elements @@ -146,6 +109,14 @@ impl ProtocolSpec for Skyhash1 { // full responses const FULLRESP_RCODE_PACKET_ERR: &'static [u8] = b"*1\n!1\n4\n"; const FULLRESP_RCODE_WRONG_TYPE: &'static [u8] = b"*1\n!1\n7\n"; + + // auth rcodes/strings + const AUTH_ERROR_ALREADYCLAIMED: &'static [u8] = eresp!("err-auth-already-claimed"); + const AUTH_CODE_BAD_CREDENTIALS: &'static [u8] = eresp!("10"); + const AUTH_ERROR_DISABLED: &'static [u8] = eresp!("err-auth-disabled"); + const AUTH_CODE_PERMS: &'static [u8] = eresp!("11"); + const AUTH_ERROR_ILLEGAL_USERNAME: &'static [u8] = eresp!("err-auth-illegal-username"); + const AUTH_ERROR_FAILED_TO_DELETE_USER: &'static [u8] = eresp!("err-auth-deluser-fail"); } impl ProtocolRead for T @@ -163,6 +134,27 @@ where T: RawConnection + Send + Sync, Strm: Stream, { + fn write_mono_length_prefixed_with_tsymbol<'life0, 'life1, 'ret_life>( + &'life0 mut self, + data: &'life1 [u8], + tsymbol: u8, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: Send + 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // + stream.write_all(&[tsymbol]).await?; + stream.write_all(&Integer64::from(data.len())).await?; + stream.write_all(&[Skyhash1::LF]).await?; + // + stream.write_all(data).await?; + stream.write_all(&[Skyhash1::LF]).await + }) + } fn write_string<'life0, 'life1, 'ret_life>( &'life0 mut self, string: &'life1 str, diff --git a/server/src/protocol/v2/interface_impls.rs b/server/src/protocol/v2/interface_impls.rs index a4325ec9..1f4d72d0 100644 --- a/server/src/protocol/v2/interface_impls.rs +++ b/server/src/protocol/v2/interface_impls.rs @@ -60,82 +60,45 @@ impl ProtocolSpec for Skyhash2 { const PIPELINED_QUERY_FIRST_BYTE: u8 = b'$'; // respcodes - /// Response code 0 as a array element const RCODE_OKAY: &'static [u8] = eresp!("0"); - /// Response code 1 as a array element const RCODE_NIL: &'static [u8] = eresp!("1"); - /// Response code 2 as a array element const RCODE_OVERWRITE_ERR: &'static [u8] = eresp!("2"); - /// Response code 3 as a array element const RCODE_ACTION_ERR: &'static [u8] = eresp!("3"); - /// Response code 4 as a array element const RCODE_PACKET_ERR: &'static [u8] = eresp!("4"); - /// Response code 5 as a array element const RCODE_SERVER_ERR: &'static [u8] = eresp!("5"); - /// Response code 6 as a array element const RCODE_OTHER_ERR_EMPTY: &'static [u8] = eresp!("6"); - /// "Unknown action" error response const RCODE_UNKNOWN_ACTION: &'static [u8] = eresp!("Unknown action"); - /// Response code 7 const RCODE_WRONGTYPE_ERR: &'static [u8] = eresp!("7"); - /// Response code 8 const RCODE_UNKNOWN_DATA_TYPE: &'static [u8] = eresp!("8"); - /// Response code 9 as an array element const RCODE_ENCODING_ERROR: &'static [u8] = eresp!("9"); // respstrings - - /// Snapshot busy error const RSTRING_SNAPSHOT_BUSY: &'static [u8] = eresp!("err-snapshot-busy"); - /// Snapshot disabled (other error) const RSTRING_SNAPSHOT_DISABLED: &'static [u8] = eresp!("err-snapshot-disabled"); - /// Duplicate snapshot const RSTRING_SNAPSHOT_DUPLICATE: &'static [u8] = eresp!("duplicate-snapshot"); - /// Snapshot has illegal name (other error) const RSTRING_SNAPSHOT_ILLEGAL_NAME: &'static [u8] = eresp!("err-invalid-snapshot-name"); - /// Access after termination signal (other error) const RSTRING_ERR_ACCESS_AFTER_TERMSIG: &'static [u8] = eresp!("err-access-after-termsig"); // keyspace related resps - /// The default container was not set const RSTRING_DEFAULT_UNSET: &'static [u8] = eresp!("default-container-unset"); - /// The container was not found const RSTRING_CONTAINER_NOT_FOUND: &'static [u8] = eresp!("container-not-found"); - /// The container is still in use and so cannot be removed const RSTRING_STILL_IN_USE: &'static [u8] = eresp!("still-in-use"); - /// This is a protected object and hence cannot be accessed const RSTRING_PROTECTED_OBJECT: &'static [u8] = eresp!("err-protected-object"); - /// The action was applied against the wrong model const RSTRING_WRONG_MODEL: &'static [u8] = eresp!("wrong-model"); - /// The container already exists const RSTRING_ALREADY_EXISTS: &'static [u8] = eresp!("err-already-exists"); - /// The container is not ready const RSTRING_NOT_READY: &'static [u8] = eresp!("not-ready"); - /// A transactional failure occurred const RSTRING_DDL_TRANSACTIONAL_FAILURE: &'static [u8] = eresp!("transactional-failure"); - /// An unknown DDL query was run const RSTRING_UNKNOWN_DDL_QUERY: &'static [u8] = eresp!("unknown-ddl-query"); - /// The expression for a DDL query was malformed const RSTRING_BAD_EXPRESSION: &'static [u8] = eresp!("malformed-expression"); - /// An unknown model was passed in a DDL query const RSTRING_UNKNOWN_MODEL: &'static [u8] = eresp!("unknown-model"); - /// Too many arguments were passed to model constructor const RSTRING_TOO_MANY_ARGUMENTS: &'static [u8] = eresp!("too-many-args"); - /// The container name is too long const RSTRING_CONTAINER_NAME_TOO_LONG: &'static [u8] = eresp!("container-name-too-long"); - /// The container name contains invalid characters const RSTRING_BAD_CONTAINER_NAME: &'static [u8] = eresp!("bad-container-name"); - /// An unknown inspect query const RSTRING_UNKNOWN_INSPECT_QUERY: &'static [u8] = eresp!("unknown-inspect-query"); - /// An unknown table property was passed const RSTRING_UNKNOWN_PROPERTY: &'static [u8] = eresp!("unknown-property"); - /// The keyspace is not empty and hence cannot be removed const RSTRING_KEYSPACE_NOT_EMPTY: &'static [u8] = eresp!("keyspace-not-empty"); - /// Bad type supplied in a DDL query for the key const RSTRING_BAD_TYPE_FOR_KEY: &'static [u8] = eresp!("bad-type-for-key"); - /// The index for the provided list was non-existent const RSTRING_LISTMAP_BAD_INDEX: &'static [u8] = eresp!("bad-list-index"); - /// The list is empty const RSTRING_LISTMAP_LIST_IS_EMPTY: &'static [u8] = eresp!("list-is-empty"); // elements @@ -144,6 +107,14 @@ impl ProtocolSpec for Skyhash2 { // full responses const FULLRESP_RCODE_PACKET_ERR: &'static [u8] = b"*!4\n"; const FULLRESP_RCODE_WRONG_TYPE: &'static [u8] = b"*!7\n"; + + // auth respcodes/strings + const AUTH_ERROR_ALREADYCLAIMED: &'static [u8] = eresp!("err-auth-already-claimed"); + const AUTH_CODE_BAD_CREDENTIALS: &'static [u8] = eresp!("10"); + const AUTH_ERROR_DISABLED: &'static [u8] = eresp!("err-auth-disabled"); + const AUTH_CODE_PERMS: &'static [u8] = eresp!("11"); + const AUTH_ERROR_ILLEGAL_USERNAME: &'static [u8] = eresp!("err-auth-illegal-username"); + const AUTH_ERROR_FAILED_TO_DELETE_USER: &'static [u8] = eresp!("err-auth-deluser-fail"); } impl ProtocolRead for T @@ -161,6 +132,25 @@ where T: RawConnection + Send + Sync, Strm: Stream, { + fn write_mono_length_prefixed_with_tsymbol<'life0, 'life1, 'ret_life>( + &'life0 mut self, + data: &'life1 [u8], + tsymbol: u8, + ) -> FutureResult<'ret_life, IoResult<()>> + where + 'life0: 'ret_life, + 'life1: 'ret_life, + Self: Send + 'ret_life, + { + Box::pin(async move { + let stream = self.get_mut_stream(); + // + stream.write_all(&[tsymbol]).await?; + stream.write_all(&Integer64::from(data.len())).await?; + stream.write_all(&[Skyhash2::LF]).await?; + stream.write_all(data).await + }) + } fn write_string<'life0, 'life1, 'ret_life>( &'life0 mut self, string: &'life1 str, diff --git a/server/src/queryengine/mod.rs b/server/src/queryengine/mod.rs index b18fe03b..f4c56c7f 100644 --- a/server/src/queryengine/mod.rs +++ b/server/src/queryengine/mod.rs @@ -89,7 +89,7 @@ action! { }; match iter.next_lowercase().unwrap_or_custom_aerr(P::RCODE_PACKET_ERR)?.as_ref() { ACTION_AUTH => auth::auth_login_only(con, auth, iter).await, - _ => util::err(auth::errors::AUTH_CODE_BAD_CREDENTIALS), + _ => util::err(P::AUTH_CODE_BAD_CREDENTIALS), } } //// Execute a simple query