Upgrade all interfaces to use the Skyhash protocol

next
Sayan Nandan 3 years ago
parent d6a3cc2acb
commit 78067d15eb

@ -37,19 +37,19 @@ use std::path::{Component, PathBuf};
pub async fn mksnap<T, Strm>(
handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let howmany = act.howmany();
let howmany = act.len() - 1;
if howmany == 0 {
if !handle.is_snapshot_enabled() {
// Since snapshotting is disabled, we can't create a snapshot!
// We'll just return an error returning the same
return con
.write_response(&**responses::fresp::R_SNAPSHOT_DISABLED)
.write_response(&**responses::groups::SNAPSHOT_DISABLED)
.await;
}
// We will just follow the standard convention of creating snapshots
@ -87,7 +87,7 @@ where
}
if engine_was_busy {
return con
.write_response(&**responses::fresp::R_SNAPSHOT_BUSY)
.write_response(&**responses::groups::SNAPSHOT_BUSY)
.await;
}
if let Some(succeeded) = snap_result {
@ -107,13 +107,13 @@ where
// We shouldn't ever reach here if all our logic is correct
// but if we do, something is wrong with the runtime
return con
.write_response(&**responses::fresp::R_ERR_ACCESS_AFTER_TERMSIG)
.write_response(&**responses::groups::ERR_ACCESS_AFTER_TERMSIG)
.await;
}
} else {
if howmany == 1 {
// This means that the user wants to create a 'named' snapshot
let snapname = act.get_ref().get(1).unwrap_or_else(|| unsafe {
let snapname = act.get(1).unwrap_or_else(|| unsafe {
// UNSAFE(@ohsayan): We've already checked that the action
// contains a second argument, so this can't be reached
unreachable_unchecked()
@ -134,7 +134,7 @@ where
!= 0;
if illegal_snapshot {
return con
.write_response(&**responses::fresp::R_SNAPSHOT_ILLEGAL_NAME)
.write_response(&**responses::groups::SNAPSHOT_ILLEGAL_NAME)
.await;
}
let failed;

@ -29,6 +29,7 @@
use crate::config::BGSave;
use crate::config::SnapshotConfig;
use crate::config::SnapshotPref;
use crate::coredb::htable::HTable;
use crate::dbnet::connection::prelude::*;
use crate::diskstore;
use crate::protocol::Query;
@ -40,7 +41,6 @@ use libsky::TResult;
use parking_lot::RwLock;
use parking_lot::RwLockReadGuard;
use parking_lot::RwLockWriteGuard;
use crate::coredb::htable::HTable;
use std::sync::Arc;
use tokio;
pub mod htable;
@ -268,12 +268,13 @@ impl CoreDB {
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
match query {
Query::Simple(q) => {
Query::SimpleQuery(q) => {
con.write_simple_query_header().await?;
queryengine::execute_simple(&self, con, q).await?;
con.flush_stream().await?;
}
// TODO(@ohsayan): Pipeline commands haven't been implemented yet
Query::Pipelined(_) => unimplemented!(),
Query::PipelinedQuery(_) => unimplemented!(),
}
Ok(())
}
@ -386,7 +387,7 @@ impl CoreDB {
/// **⚠ Do note**: This is super inefficient since it performs an actual
/// clone of the `HTable` and doesn't do any `Arc`-business! This function
/// can be used by test functions and the server, but **use with caution!**
pub fn get_HTable_deep_clone(&self) -> HTable<String, Data> {
pub fn get_htable_deep_clone(&self) -> HTable<String, Data> {
(*self.acquire_read().get_ref()).clone()
}

@ -40,13 +40,13 @@ use crate::dbnet::tcp::BufferedSocketStream;
use crate::dbnet::Terminator;
use crate::protocol;
use crate::protocol::responses;
use crate::protocol::ParseError;
use crate::protocol::Query;
use crate::resp::Writable;
use crate::CoreDB;
use bytes::Buf;
use bytes::BytesMut;
use libsky::TResult;
use protocol::ParseResult;
use protocol::QueryResult;
use std::future::Future;
use std::io::Error as IoError;
use std::io::ErrorKind;
@ -60,6 +60,15 @@ use tokio::io::BufWriter;
use tokio::sync::mpsc;
use tokio::sync::Semaphore;
pub const SIMPLE_QUERY_HEADER: [u8; 3] = [b'*', b'1', b'\n'];
pub enum QueryResult {
Q(Query),
E(Vec<u8>),
Empty,
Wrongtype,
}
pub mod prelude {
//! A 'prelude' for callers that would like to use the `ProtocolConnection` and `ProtocolConnectionExt` traits
//!
@ -108,11 +117,11 @@ where
})
}
/// Try to parse a query from the buffered data
fn try_query(&self) -> Result<ParseResult, ()> {
fn try_query(&self) -> Result<(Query, usize), ParseError> {
if self.get_buffer().is_empty() {
return Err(());
return Err(ParseError::Empty);
}
Ok(protocol::parse(&self.get_buffer()[..]))
protocol::Parser::new(&self.get_buffer()).parse()
}
/// Read a query from the remote end
///
@ -131,18 +140,19 @@ where
loop {
mv_self.read_again().await?;
match mv_self.try_query() {
Ok(ParseResult::Query(query, forward)) => {
mv_self.advance_buffer(forward);
Ok((query, forward_by)) => {
mv_self.advance_buffer(forward_by);
return Ok(QueryResult::Q(query));
}
Ok(ParseResult::BadPacket(discard_len)) => {
mv_self.advance_buffer(discard_len);
Err(ParseError::Empty) => return Ok(QueryResult::Empty),
Err(ParseError::NotEnough) => (),
Err(ParseError::DataTypeParseError) => return Ok(QueryResult::Wrongtype),
Err(ParseError::UnexpectedByte) | Err(ParseError::BadPacket) => {
return Ok(QueryResult::E(responses::fresp::R_PACKET_ERR.to_owned()));
}
Err(_) => {
return Ok(QueryResult::Empty);
Err(ParseError::UnknownDatatype) => {
unimplemented!()
}
_ => (),
}
}
};
@ -167,6 +177,23 @@ where
ret
})
}
/// Write the simple query header `*1\n` to the stream
fn write_simple_query_header<'r, 's>(
&'r mut self,
) -> Pin<Box<dyn Future<Output = IoResult<()>> + Send + 's>>
where
'r: 's,
Self: Send + 's,
{
Box::pin(async move {
let mv_self = self;
let ret: IoResult<()> = {
mv_self.write_response(&SIMPLE_QUERY_HEADER[..]).await?;
Ok(())
};
ret
})
}
/// Wraps around the `write_response` used to differentiate between a
/// success response and an error response
fn close_conn_with_error<'r, 's>(
@ -330,6 +357,11 @@ where
log::debug!("Failed to read query!");
self.con.close_conn_with_error(r).await?
}
Ok(QueryResult::Wrongtype) => {
self.con
.close_conn_with_error(responses::groups::WRONGTYPE_ERR.to_owned())
.await?
}
Ok(QueryResult::Empty) => return Ok(()),
#[cfg(windows)]
Err(e) => match e.kind() {

@ -246,7 +246,7 @@ fn test_snapshot() {
let _ = snapengine.mksnap();
let current = snapengine.get_snapshots().next().unwrap();
let read_hmap = diskstore::test_deserialize(fs::read(PathBuf::from(current)).unwrap()).unwrap();
let dbhmap = db.get_HTable_deep_clone();
let dbhmap = db.get_htable_deep_clone();
assert_eq!(read_hmap, dbhmap);
snapengine.clearall().unwrap();
fs::remove_dir_all(ourdir).unwrap();

@ -25,28 +25,23 @@
*/
use crate::dbnet::connection::prelude::*;
use crate::protocol::responses;
use crate::resp::GroupBegin;
/// Get the number of keys in the database
pub async fn dbsize<T, Strm>(
handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
if act.howmany() != 0 {
return con.write_response(&**responses::fresp::R_ACTION_ERR).await;
}
crate::err_if_len_is!(act, con, != 0);
let len;
{
len = handle.acquire_read().get_ref().len();
}
con.write_response(GroupBegin(1)).await?;
con.write_simple_query_header().await?;
con.write_response(len).await?;
Ok(())
}

@ -27,11 +27,8 @@
//! # `DEL` queries
//! This module provides functions to work with `DEL` queries
use crate::dbnet::connection::prelude::*;
use crate::protocol::responses;
use crate::resp::GroupBegin;
/// Run a `DEL` query
///
@ -40,24 +37,19 @@ use crate::resp::GroupBegin;
pub async fn del<T, Strm>(
handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let howmany = act.howmany();
if howmany == 0 {
return con.write_response(&**responses::fresp::R_ACTION_ERR).await;
}
// Write #<m>\n#<n>\n&<howmany>\n to the stream
con.write_response(GroupBegin(1)).await?;
crate::err_if_len_is!(act, con, == 0);
let done_howmany: Option<usize>;
{
if let Some(mut whandle) = handle.acquire_write() {
let mut many = 0;
let cmap = (*whandle).get_mut_ref();
act.into_iter().for_each(|key| {
act.into_iter().skip(1).for_each(|key| {
if cmap.remove(&key).is_some() {
many += 1
}

@ -27,33 +27,24 @@
//! # `EXISTS` queries
//! This module provides functions to work with `EXISTS` queries
use crate::dbnet::connection::prelude::*;
use crate::protocol::responses;
use crate::resp::GroupBegin;
/// Run an `EXISTS` query
pub async fn exists<T, Strm>(
handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let howmany = act.howmany();
if howmany == 0 {
return con.write_response(&**responses::fresp::R_ACTION_ERR).await;
}
// Write #<m>\n#<n>\n&1\n to the stream
con.write_response(GroupBegin(1)).await?;
crate::err_if_len_is!(act, con, == 0);
let mut how_many_of_them_exist = 0usize;
{
let rhandle = handle.acquire_read();
let cmap = rhandle.get_ref();
act.into_iter().for_each(|key| {
act.into_iter().skip(1).for_each(|key| {
if cmap.contains_key(&key) {
how_many_of_them_exist += 1;
}

@ -24,24 +24,20 @@
*
*/
use crate::dbnet::connection::prelude::*;
use crate::protocol::responses;
/// Delete all the keys in the database
pub async fn flushdb<T, Strm>(
handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
if act.howmany() != 0 {
return con.write_response(&**responses::fresp::R_ACTION_ERR).await;
}
crate::err_if_len_is!(act, con, != 0);
let failed;
{
if let Some(mut table) = handle.acquire_write() {
@ -52,8 +48,8 @@ where
}
}
if failed {
con.write_response(&**responses::fresp::R_SERVER_ERR).await
con.write_response(&**responses::groups::SERVER_ERR).await
} else {
con.write_response(&**responses::fresp::R_OKAY).await
con.write_response(&**responses::groups::OKAY).await
}
}

@ -27,29 +27,22 @@
//! # `GET` queries
//! This module provides functions to work with `GET` queries
use crate::dbnet::connection::prelude::*;
use crate::protocol::responses;
use crate::resp::{BytesWrapper, GroupBegin};
use crate::resp::BytesWrapper;
use bytes::Bytes;
/// Run a `GET` query
pub async fn get<T, Strm>(
handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let howmany = act.howmany();
if howmany != 1 {
return con.write_response(&**responses::fresp::R_ACTION_ERR).await;
}
// Write #<m>\n#<n>\n&1\n to the stream
con.write_response(GroupBegin(1)).await?;
crate::err_if_len_is!(act, con, != 1);
let res: Option<Bytes> = {
let rhandle = handle.acquire_read();
let reader = rhandle.get_ref();
@ -57,7 +50,7 @@ where
// UNSAFE(@ohsayan): act.get_ref().get_unchecked() is safe because we've already if the action
// group contains one argument (excluding the action itself)
reader
.get(act.get_ref().get_unchecked(1))
.get(act.get_unchecked(1))
.map(|b| b.get_blob().clone())
}
};

@ -28,10 +28,7 @@
//! #`JGET` queries
//! Functions for handling `JGET` queries
use crate::dbnet::connection::prelude::*;
use crate::protocol::responses;
/// Run a `JGET` query
/// This returns a JSON key/value pair of keys and values
@ -45,16 +42,13 @@ use crate::protocol::responses;
pub async fn jget<T, Strm>(
_handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let howmany = act.howmany();
if howmany != 1 {
return con.write_response(&**responses::fresp::R_ACTION_ERR).await;
}
crate::err_if_len_is!(act, con, != 1);
todo!()
}

@ -26,8 +26,6 @@
use crate::dbnet::connection::prelude::*;
use crate::protocol::responses;
use crate::resp::GroupBegin;
/// Run a `KEYLEN` query
///
@ -35,18 +33,13 @@ use crate::resp::GroupBegin;
pub async fn keylen<T, Strm>(
handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let howmany = act.howmany();
if howmany != 1 {
return con.write_response(&**responses::fresp::R_ACTION_ERR).await;
}
// Write #<m>\n#<n>\n&1\n to the stream
con.write_response(GroupBegin(1)).await?;
crate::err_if_len_is!(act, con, != 1);
let res: Option<usize> = {
let rhandle = handle.acquire_read();
let reader = rhandle.get_ref();
@ -54,7 +47,7 @@ where
// UNSAFE(@ohsayan): get_unchecked() is completely safe as we've already checked
// the number of arguments is one
reader
.get(act.get_ref().get_unchecked(1))
.get(act.get_unchecked(1))
.map(|b| b.get_blob().len())
}
};

@ -24,32 +24,24 @@
*
*/
use crate::dbnet::connection::prelude::*;
use crate::protocol::responses;
use crate::resp::{BytesWrapper, GroupBegin};
use crate::resp::BytesWrapper;
use bytes::Bytes;
use libsky::terrapipe::RespCodes;
/// Run an `MGET` query
///
pub async fn mget<T, Strm>(
handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let howmany = act.howmany();
if howmany == 0 {
return con.write_response(&**responses::fresp::R_ACTION_ERR).await;
}
// Write #<m>\n#<n>\n&<howmany>\n to the stream
con.write_response(GroupBegin(howmany)).await?;
let mut keys = act.into_iter();
crate::err_if_len_is!(act, con, == 0);
let mut keys = act.into_iter().skip(1);
while let Some(key) = keys.next() {
let res: Option<Bytes> = {
let rhandle = handle.acquire_read();

@ -44,19 +44,72 @@ pub mod update;
pub mod uset;
pub mod heya {
//! Respond to `HEYA` queries
use crate::protocol;
use crate::dbnet::connection::prelude::*;
use crate::protocol;
use protocol::responses;
/// Returns a `HEY!` `Response`
pub async fn heya<T, Strm>(
_handle: &crate::coredb::CoreDB,
con: &mut T,
_act: crate::protocol::ActionGroup,
_act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
con.write_response(&**responses::fresp::R_HEYA).await
con.write_response(&**responses::groups::HEYA).await
}
}
#[macro_export]
macro_rules! err_if_len_is {
($buf:ident, $con:ident, == $len:literal) => {
if $buf.len() - 1 == $len {
return $con
.write_response(&**crate::protocol::responses::groups::ACTION_ERR)
.await;
}
};
($buf:ident, $con:ident, != $len:literal) => {
if $buf.len() - 1 != $len {
return $con
.write_response(&**crate::protocol::responses::groups::ACTION_ERR)
.await;
}
};
($buf:ident, $con:ident, > $len:literal) => {
if $buf.len() - 1 > $len {
return $con
.write_response(&**crate::protocol::responses::groups::ACTION_ERR)
.await;
}
};
($buf:ident, $con:ident, < $len:literal) => {
if $buf.len() - 1 < $len {
return $con
.write_response(&**crate::protocol::responses::groups::ACTION_ERR)
.await;
}
};
($buf:ident, $con:ident, >= $len:literal) => {
if $buf.len() - 1 >= $len {
return $con
.write_response(&**crate::protocol::responses::groups::ACTION_ERR)
.await;
}
};
($buf:ident, $con:ident, <= $len:literal) => {
if $buf.len() - 1 <= $len {
return $con
.write_response(&**crate::protocol::responses::groups::ACTION_ERR)
.await;
}
};
($buf:ident, $con:ident, & $len:literal) => {
if $buf.len() - 1 & $len {
return $con
.write_response(&**crate::protocol::responses::groups::ACTION_ERR)
.await;
}
};
}

@ -24,33 +24,29 @@
*
*/
use crate::coredb::{self};
use crate::coredb;
use crate::coredb::htable::Entry;
use crate::dbnet::connection::prelude::*;
use crate::protocol::responses;
use crate::resp::GroupBegin;
use crate::coredb::htable::Entry;
/// Run an `MSET` query
pub async fn mset<T, Strm>(
handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let howmany = act.howmany();
let howmany = act.len() - 1;
if howmany & 1 == 1 || howmany == 0 {
// An odd number of arguments means that the number of keys
// is not the same as the number of values, we won't run this
// action at all
return con.write_response(&**responses::fresp::R_ACTION_ERR).await;
return con.write_response(&**responses::groups::ACTION_ERR).await;
}
// Write #<m>\n#<n>\n&<howmany>\n to the stream
// It is howmany/2 since we will be writing howmany/2 number of responses
con.write_response(GroupBegin(1)).await?;
let mut kviter = act.into_iter();
let mut kviter = act.into_iter().skip(1);
let done_howmany: Option<usize>;
{
if let Some(mut whandle) = handle.acquire_write() {
@ -72,6 +68,6 @@ where
if let Some(done_howmany) = done_howmany {
return con.write_response(done_howmany as usize).await;
} else {
return con.write_response(&**responses::fresp::R_SERVER_ERR).await;
return con.write_response(&**responses::groups::SERVER_ERR).await;
}
}

@ -24,33 +24,29 @@
*
*/
use crate::coredb::{self};
use crate::coredb;
use crate::coredb::htable::Entry;
use crate::dbnet::connection::prelude::*;
use crate::protocol::responses;
use crate::resp::GroupBegin;
use crate::coredb::htable::Entry;
/// Run an `MUPDATE` query
pub async fn mupdate<T, Strm>(
handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let howmany = act.howmany();
let howmany = act.len() - 1;
if howmany & 1 == 1 || howmany == 0 {
// An odd number of arguments means that the number of keys
// is not the same as the number of values, we won't run this
// action at all
return con.write_response(&**responses::fresp::R_ACTION_ERR).await;
return con.write_response(&**responses::groups::ACTION_ERR).await;
}
// Write #<m>\n#<n>\n&<howmany>\n to the stream
// It is howmany/2 since we will be writing howmany/2 number of responses
con.write_response(GroupBegin(1)).await?;
let mut kviter = act.into_iter();
let mut kviter = act.into_iter().skip(1);
let done_howmany: Option<usize>;
{
if let Some(mut whandle) = handle.acquire_write() {
@ -72,6 +68,6 @@ where
if let Some(done_howmany) = done_howmany {
return con.write_response(done_howmany as usize).await;
} else {
return con.write_response(&**responses::fresp::R_SERVER_ERR).await;
return con.write_response(&**responses::groups::SERVER_ERR).await;
}
}

@ -27,29 +27,29 @@
//! # `SET` queries
//! This module provides functions to work with `SET` queries
use crate::coredb::htable::Entry;
use crate::coredb::{self};
use crate::dbnet::connection::prelude::*;
use crate::protocol::responses;
use coredb::Data;
use crate::coredb::htable::Entry;
use std::hint::unreachable_unchecked;
/// Run a `SET` query
pub async fn set<T, Strm>(
handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let howmany = act.howmany();
let howmany = act.len() - 1;
if howmany != 2 {
// There should be exactly 2 arguments
return con.write_response(&**responses::fresp::R_ACTION_ERR).await;
return con.write_response(&**responses::groups::ACTION_ERR).await;
}
let mut it = act.into_iter();
let mut it = act.into_iter().skip(1);
let did_we = {
if let Some(mut writer) = handle.acquire_write() {
let writer = writer.get_mut_ref();
@ -73,14 +73,13 @@ where
};
if let Some(did_we) = did_we {
if did_we {
con.write_response(&**responses::fresp::R_OKAY).await?;
con.write_response(&**responses::groups::OKAY).await?;
} else {
con.write_response(&**responses::fresp::R_OVERWRITE_ERR)
con.write_response(&**responses::groups::OVERWRITE_ERR)
.await?;
}
} else {
con.write_response(&**responses::fresp::R_SERVER_ERR)
.await?;
con.write_response(&**responses::groups::SERVER_ERR).await?;
}
Ok(())
}

@ -48,15 +48,15 @@ use std::hint::unreachable_unchecked;
pub async fn sset<T, Strm>(
handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let howmany = act.howmany();
let howmany = act.len() - 1;
if howmany & 1 == 1 || howmany == 0 {
return con.write_response(&**responses::fresp::R_ACTION_ERR).await;
return con.write_response(&**responses::groups::ACTION_ERR).await;
}
let mut failed = Some(false);
{
@ -67,7 +67,6 @@ where
// This iterator gives us the keys and values, skipping the first argument which
// is the action name
let mut key_iter = act
.get_ref()
.get(1..)
.unwrap_or_else(|| unsafe {
// UNSAFE(@ohsayan): We've already checked if the action group contains more than one arugment
@ -110,13 +109,13 @@ where
}
if let Some(failed) = failed {
if failed {
con.write_response(&**responses::fresp::R_OVERWRITE_ERR)
con.write_response(&**responses::groups::OVERWRITE_ERR)
.await
} else {
con.write_response(&**responses::fresp::R_OKAY).await
con.write_response(&**responses::groups::OKAY).await
}
} else {
con.write_response(&**responses::fresp::R_SERVER_ERR).await
con.write_response(&**responses::groups::SERVER_ERR).await
}
}
@ -127,15 +126,15 @@ where
pub async fn sdel<T, Strm>(
handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let howmany = act.howmany();
let howmany = act.len() - 1;
if howmany == 0 {
return con.write_response(&**responses::fresp::R_ACTION_ERR).await;
return con.write_response(&**responses::groups::ACTION_ERR).await;
}
let mut failed = Some(false);
{
@ -143,7 +142,6 @@ where
// doesn't go beyond the scope of this function - and is never used across
// an await: cause, the compiler ain't as smart as we are ;)
let mut key_iter = act
.get_ref()
.get(1..)
.unwrap_or_else(|| unsafe {
// UNSAFE(@ohsayan): This is safe as we've already checked that there are arguments
@ -185,12 +183,12 @@ where
}
if let Some(failed) = failed {
if failed {
con.write_response(&**responses::fresp::R_NIL).await
con.write_response(&**responses::groups::NIL).await
} else {
con.write_response(&**responses::fresp::R_OKAY).await
con.write_response(&**responses::groups::OKAY).await
}
} else {
con.write_response(&**responses::fresp::R_SERVER_ERR).await
con.write_response(&**responses::groups::SERVER_ERR).await
}
}
@ -201,15 +199,15 @@ where
pub async fn supdate<T, Strm>(
handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let howmany = act.howmany();
let howmany = act.len() - 1;
if howmany & 1 == 1 || howmany == 0 {
return con.write_response(&**responses::fresp::R_ACTION_ERR).await;
return con.write_response(&**responses::groups::ACTION_ERR).await;
}
let mut failed = Some(false);
{
@ -217,7 +215,6 @@ where
// doesn't go beyond the scope of this function - and is never used across
// an await: cause, the compiler ain't as smart as we are ;)
let mut key_iter = act
.get_ref()
.get(1..)
.unwrap_or_else(|| unsafe {
// UNSAFE(@ohsayan): We've already checked that the action group contains more
@ -262,11 +259,11 @@ where
}
if let Some(failed) = failed {
if failed {
con.write_response(&**responses::fresp::R_NIL).await
con.write_response(&**responses::groups::NIL).await
} else {
con.write_response(&**responses::fresp::R_OKAY).await
con.write_response(&**responses::groups::OKAY).await
}
} else {
con.write_response(&**responses::fresp::R_SERVER_ERR).await
con.write_response(&**responses::groups::SERVER_ERR).await
}
}

@ -38,18 +38,18 @@ use std::hint::unreachable_unchecked;
pub async fn update<T, Strm>(
handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let howmany = act.howmany();
let howmany = act.len() - 1;
if howmany != 2 {
// There should be exactly 2 arguments
return con.write_response(&**responses::fresp::R_ACTION_ERR).await;
return con.write_response(&**responses::groups::ACTION_ERR).await;
}
let mut it = act.into_iter();
let mut it = act.into_iter().skip(1);
let did_we = {
if let Some(mut whandle) = handle.acquire_write() {
let writer = whandle.get_mut_ref();
@ -73,13 +73,12 @@ where
};
if let Some(did_we) = did_we {
if did_we {
con.write_response(&**responses::fresp::R_OKAY).await?;
con.write_response(&**responses::groups::OKAY).await?;
} else {
con.write_response(&**responses::fresp::R_NIL).await?;
con.write_response(&**responses::groups::NIL).await?;
}
} else {
con.write_response(&**responses::fresp::R_SERVER_ERR)
.await?;
con.write_response(&**responses::groups::SERVER_ERR).await?;
}
Ok(())
}

@ -27,8 +27,6 @@
use crate::coredb::{self};
use crate::dbnet::connection::prelude::*;
use crate::protocol::responses;
use crate::resp::GroupBegin;
/// Run an `USET` query
///
@ -36,23 +34,20 @@ use crate::resp::GroupBegin;
pub async fn uset<T, Strm>(
handle: &crate::coredb::CoreDB,
con: &mut T,
act: crate::protocol::ActionGroup,
act: Vec<String>,
) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let howmany = act.howmany();
let howmany = act.len() - 1;
if howmany & 1 == 1 || howmany == 0 {
// An odd number of arguments means that the number of keys
// is not the same as the number of values, we won't run this
// action at all
return con.write_response(&**responses::fresp::R_ACTION_ERR).await;
}
// Write #<m>\n#<n>\n&<howmany>\n to the stream
// It is howmany/2 since we will be writing howmany/2 number of responses
con.write_response(GroupBegin(1)).await?;
let mut kviter = act.into_iter();
let mut kviter = act.into_iter().skip(1);
let failed = {
if let Some(mut whandle) = handle.acquire_write() {
let writer = whandle.get_mut_ref();

@ -0,0 +1,88 @@
/*
* Created on Tue May 11 2021
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2021, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
use std::borrow::Cow;
#[derive(Debug, PartialEq)]
#[non_exhaustive]
/// # Data Types
///
/// This enum represents the data types supported by the Skyhash Protocol
pub enum Element {
/// Arrays can be nested! Their `<tsymbol>` is `&`
Array(Vec<Element>),
/// A String value; `<tsymbol>` is `+`
String(String),
/// An unsigned integer value; `<tsymbol>` is `:`
UnsignedInt(u64),
/// A non-recursive String array; tsymbol: `_`
FlatArray(Vec<String>),
}
impl Element {
/// This will return a reference to the first element in the element
///
/// If this element is a compound type, it will return a reference to the first element in the compound
/// type
pub fn get_first(&self) -> Option<Cow<String>> {
match self {
Self::Array(elem) => match elem.first() {
Some(el) => match el {
Element::String(st) => Some(Cow::Borrowed(&st)),
_ => None,
},
None => None,
},
Self::FlatArray(elem) => match elem.first() {
Some(el) => Some(Cow::Borrowed(&el)),
None => None,
},
Self::String(ref st) => Some(Cow::Borrowed(&st)),
_ => None,
}
}
pub fn is_flat_array(&self) -> bool {
if let Self::FlatArray(_) = self {
true
} else {
false
}
}
pub fn get_flat_array_size(&self) -> Option<usize> {
if let Self::FlatArray(a) = self {
Some(a.len())
} else {
None
}
}
pub fn is_flat_array_len_eq(&self, size: usize) -> Option<bool> {
if let Self::FlatArray(a) = self {
Some(a.len() == size)
} else {
None
}
}
}

File diff suppressed because it is too large Load Diff

@ -1,679 +0,0 @@
/*
* Created on Mon May 10 2021
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2021, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
//! # The Skyhash Protocol
//!
//! ## Introduction
//! The Skyhash Protocol is a serialization protocol that is used by Skytable for client/server communication.
//! It works in a query/response action similar to HTTP's request/response action. Skyhash supersedes the Terrapipe
//! protocol as a more simple, reliable, robust and scalable protocol.
//!
//! This module contains the [`Parser`] for the Skyhash protocol and it's enough to just pass a query packet as
//! a slice of unsigned 8-bit integers and the parser will do everything else. The Skyhash protocol was designed
//! and implemented by the Author (Sayan Nandan)
//!
use std::hint::unreachable_unchecked;
#[derive(Debug)]
/// # Skyhash Deserializer (Parser)
///
/// The [`Parser`] object can be used to deserialized a packet serialized by Skyhash which in turn serializes
/// it into data structures native to the Rust Language (and some Compound Types built on top of them).
///
/// ## Evaluation
///
/// The parser is pessimistic in most cases and will readily throw out any errors. On non-recusrive types
/// there is no recursion, but the parser will use implicit recursion for nested arrays. The parser will
/// happily not report any errors if some part of the next query was passed. This is very much a possibility
/// and so has been accounted for
///
/// ## Important note
///
/// All developers willing to modify the deserializer must keep this in mind: the cursor is always Ahead-Of-Position
/// that is the cursor should always point at the next character that can be read.
///
pub(super) struct Parser<'a> {
/// The internal cursor position
///
/// Do not even think of touching this externally
cursor: usize,
/// The buffer slice
buffer: &'a [u8],
}
#[derive(Debug, PartialEq)]
/// # Parser Errors
///
/// Several errors can arise during parsing and this enum accounts for them
pub enum ParseError {
/// Didn't get the number of expected bytes
NotEnough,
/// The query contains an unexpected byte
UnexpectedByte,
/// The packet simply contains invalid data
///
/// This is rarely returned and only in the special cases where a bad client sends `0` as
/// the query count
BadPacket,
/// A data type was given but the parser failed to serialize it into this type
///
/// This can happen not just for elements but can also happen for their sizes ([`Self::parse_into_u64`])
DataTypeParseError,
/// A data type that the server doesn't know was passed into the query
///
/// This is a frequent problem that can arise between different server editions as more data types
/// can be added with changing server versions
UnknownDatatype,
}
#[derive(Debug, PartialEq)]
/// # Types of Queries
///
/// A simple query carries out one action while a complex query executes multiple actions
pub enum Query {
/// A simple query will just hold one element
SimpleQuery(DataType),
/// A pipelined/batch query will hold multiple elements
PipelinedQuery(Vec<DataType>),
}
#[derive(Debug, PartialEq)]
#[non_exhaustive]
/// # Data Types
///
/// This enum represents the data types supported by the Skyhash Protocol
pub enum DataType {
/// Arrays can be nested! Their `<tsymbol>` is `&`
Array(Vec<DataType>),
/// A String value; `<tsymbol>` is `+`
String(String),
/// An unsigned integer value; `<tsymbol>` is `:`
UnsignedInt(u64),
}
/// A generic result to indicate parsing errors thorugh the [`ParseError`] enum
type ParseResult<T> = Result<T, ParseError>;
impl<'a> Parser<'a> {
/// Initialize a new parser instance
pub const fn new(buffer: &'a [u8]) -> Self {
Parser {
cursor: 0usize,
buffer,
}
}
/// Read from the current cursor position to `until` number of positions ahead
/// This **will forward the cursor itself** if the bytes exist or it will just return a `NotEnough` error
fn read_until(&mut self, until: usize) -> ParseResult<&[u8]> {
if let Some(b) = self.buffer.get(self.cursor..self.cursor + until) {
self.cursor += until;
Ok(b)
} else {
Err(ParseError::NotEnough)
}
}
/// This returns the position at which the line parsing began and the position at which the line parsing
/// stopped, in other words, you should be able to do self.buffer[started_at..stopped_at] to get a line
/// and do it unchecked. This **will move the internal cursor ahead** and place it **at the `\n` byte**
fn read_line(&mut self) -> (usize, usize) {
let started_at = self.cursor;
let mut stopped_at = self.cursor;
while self.cursor < self.buffer.len() {
if self.buffer[self.cursor] == b'\n' {
// Oh no! Newline reached, time to break the loop
// But before that ... we read the newline, so let's advance the cursor
self.incr_cursor();
break;
}
// So this isn't an LF, great! Let's forward the stopped_at position
stopped_at += 1;
self.incr_cursor();
}
(started_at, stopped_at)
}
/// Push the internal cursor ahead by one
fn incr_cursor(&mut self) {
self.cursor += 1;
}
/// This function will evaluate if the byte at the current cursor position equals the `ch` argument, i.e
/// the expression `*v == ch` is evaluated. However, if no element is present ahead, then the function
/// will return `Ok(_this_if_nothing_ahead_)`
fn will_cursor_give_char(&self, ch: u8, this_if_nothing_ahead: bool) -> ParseResult<bool> {
self.buffer.get(self.cursor).map_or(
if this_if_nothing_ahead {
Ok(true)
} else {
Err(ParseError::NotEnough)
},
|v| Ok(*v == ch),
)
}
/// Will the current cursor position give a linefeed? This will return `ParseError::NotEnough` if
/// the current cursor points at a non-existent index in `self.buffer`
fn will_cursor_give_linefeed(&self) -> ParseResult<bool> {
self.will_cursor_give_char(b'\n', false)
}
/// Parse a stream of bytes into [`usize`]
fn parse_into_usize(bytes: &[u8]) -> ParseResult<usize> {
if bytes.len() == 0 {
return Err(ParseError::NotEnough);
}
let mut byte_iter = bytes.into_iter();
let mut item_usize = 0usize;
while let Some(dig) = byte_iter.next() {
if !dig.is_ascii_digit() {
// dig has to be an ASCII digit
return Err(ParseError::DataTypeParseError);
}
// 48 is the ASCII code for 0, and 57 is the ascii code for 9
// so if 0 is given, the subtraction should give 0; similarly
// if 9 is given, the subtraction should give us 9!
let curdig: usize = dig
.checked_sub(48)
.unwrap_or_else(|| unsafe { unreachable_unchecked() })
.into();
// The usize can overflow; check that case
let product = match item_usize.checked_mul(10) {
Some(not_overflowed) => not_overflowed,
None => return Err(ParseError::DataTypeParseError),
};
let sum = match product.checked_add(curdig) {
Some(not_overflowed) => not_overflowed,
None => return Err(ParseError::DataTypeParseError),
};
item_usize = sum;
}
Ok(item_usize)
}
/// Pasre a stream of bytes into an [`u64`]
fn parse_into_u64(bytes: &[u8]) -> ParseResult<u64> {
if bytes.len() == 0 {
return Err(ParseError::NotEnough);
}
let mut byte_iter = bytes.into_iter();
let mut item_u64 = 0u64;
while let Some(dig) = byte_iter.next() {
if !dig.is_ascii_digit() {
// dig has to be an ASCII digit
return Err(ParseError::DataTypeParseError);
}
// 48 is the ASCII code for 0, and 57 is the ascii code for 9
// so if 0 is given, the subtraction should give 0; similarly
// if 9 is given, the subtraction should give us 9!
let curdig: u64 = dig
.checked_sub(48)
.unwrap_or_else(|| unsafe { unreachable_unchecked() })
.into();
// Now the entire u64 can overflow, so let's attempt to check it
let product = match item_u64.checked_mul(10) {
Some(not_overflowed) => not_overflowed,
None => return Err(ParseError::DataTypeParseError),
};
let sum = match product.checked_add(curdig) {
Some(not_overflowed) => not_overflowed,
None => return Err(ParseError::DataTypeParseError),
};
item_u64 = sum;
}
Ok(item_u64)
}
/// This will return the number of datagroups present in this query packet
///
/// This **will forward the cursor itself**
fn parse_metaframe_get_datagroup_count(&mut self) -> ParseResult<usize> {
// the smallest query we can have is: *1\n or 3 chars
if self.buffer.len() < 3 {
return Err(ParseError::NotEnough);
}
// Now we want to read `*<n>\n`
let (start, stop) = self.read_line();
if let Some(our_chunk) = self.buffer.get(start..stop) {
if our_chunk[0] == b'*' {
// Good, this will tell us the number of actions
// Let us attempt to read the usize from this point onwards
// that is excluding the '*' (so 1..)
let ret = Self::parse_into_usize(&our_chunk[1..])?;
Ok(ret)
} else {
Err(ParseError::UnexpectedByte)
}
} else {
Err(ParseError::NotEnough)
}
}
/// Get the next element **without** the tsymbol
///
/// This function **does not forward the newline**
fn __get_next_element(&mut self) -> ParseResult<&[u8]> {
let string_sizeline = self.read_line();
if let Some(line) = self.buffer.get(string_sizeline.0..string_sizeline.1) {
let string_size = Self::parse_into_usize(line)?;
let our_chunk = self.read_until(string_size)?;
Ok(our_chunk)
} else {
Err(ParseError::NotEnough)
}
}
/// The cursor should have passed the `+` tsymbol
fn parse_next_string(&mut self) -> ParseResult<String> {
let our_string_chunk = self.__get_next_element()?;
let our_string = String::from_utf8_lossy(&our_string_chunk).to_string();
if self.will_cursor_give_linefeed()? {
// there is a lf after the end of the string; great!
// let's skip that now
self.incr_cursor();
// let's return our string
Ok(our_string)
} else {
Err(ParseError::UnexpectedByte)
}
}
/// The cursor should have passed the `:` tsymbol
fn parse_next_u64(&mut self) -> ParseResult<u64> {
let our_u64_chunk = self.__get_next_element()?;
let our_u64 = Self::parse_into_u64(our_u64_chunk)?;
if self.will_cursor_give_linefeed()? {
// line feed after u64; heck yeah!
self.incr_cursor();
// return it
Ok(our_u64)
} else {
Err(ParseError::UnexpectedByte)
}
}
/// The cursor should be **at the tsymbol**
fn parse_next_element(&mut self) -> ParseResult<DataType> {
if let Some(tsymbol) = self.buffer.get(self.cursor) {
// so we have a tsymbol; nice, let's match it
// but advance the cursor before doing that
self.incr_cursor();
let ret = match *tsymbol {
b'+' => DataType::String(self.parse_next_string()?),
b':' => DataType::UnsignedInt(self.parse_next_u64()?),
b'&' => DataType::Array(self.parse_next_array()?),
_ => return Err(ParseError::UnknownDatatype),
};
Ok(ret)
} else {
// Not enough bytes to read an element
Err(ParseError::NotEnough)
}
}
/// The tsymbol `&` should have been passed!
fn parse_next_array(&mut self) -> ParseResult<Vec<DataType>> {
let (start, stop) = self.read_line();
if let Some(our_size_chunk) = self.buffer.get(start..stop) {
let array_size = Self::parse_into_usize(our_size_chunk)?;
let mut array = Vec::with_capacity(array_size);
for _ in 0..array_size {
array.push(self.parse_next_element()?);
}
Ok(array)
} else {
Err(ParseError::NotEnough)
}
}
/// Parse a query and return the [`Query`] and an `usize` indicating the number of bytes that
/// can be safely discarded from the buffer. It will otherwise return errors if they are found.
///
/// This object will drop `Self`
pub fn parse(mut self) -> Result<(Query, usize), ParseError> {
let number_of_queries = self.parse_metaframe_get_datagroup_count()?;
println!("Got count: {}", number_of_queries);
if number_of_queries == 0 {
// how on earth do you expect us to execute 0 queries? waste of bandwidth
return Err(ParseError::BadPacket);
}
if number_of_queries == 1 {
// This is a simple query
let single_group = self.parse_next_element()?;
// The below line defaults to false if no item is there in the buffer
// or it checks if the next time is a \r char; if it is, then it is the beginning
// of the next query
if self
.will_cursor_give_char(b'*', true)
.unwrap_or_else(|_| unsafe {
// This will never be the case because we'll always get a result and no error value
// as we've passed true which will yield Ok(true) even if there is no byte ahead
unreachable_unchecked()
})
{
Ok((Query::SimpleQuery(single_group), self.cursor))
} else {
// the next item isn't the beginning of a query but something else?
// that doesn't look right!
Err(ParseError::UnexpectedByte)
}
} else {
// This is a pipelined query
// We'll first make space for all the actiongroups
let mut queries = Vec::with_capacity(number_of_queries);
for _ in 0..number_of_queries {
queries.push(self.parse_next_element()?);
}
if self.will_cursor_give_char(b'*', true)? {
Ok((Query::PipelinedQuery(queries), self.cursor))
} else {
Err(ParseError::UnexpectedByte)
}
}
}
}
#[test]
fn test_metaframe_parse() {
let metaframe = "*2\n".as_bytes();
let mut parser = Parser::new(&metaframe);
assert_eq!(2, parser.parse_metaframe_get_datagroup_count().unwrap());
assert_eq!(parser.cursor, metaframe.len());
}
#[test]
fn test_cursor_next_char() {
let bytes = &[b'\n'];
assert!(Parser::new(&bytes[..])
.will_cursor_give_char(b'\n', false)
.unwrap());
let bytes = &[];
assert!(Parser::new(&bytes[..])
.will_cursor_give_char(b'\r', true)
.unwrap());
let bytes = &[];
assert!(
Parser::new(&bytes[..])
.will_cursor_give_char(b'\n', false)
.unwrap_err()
== ParseError::NotEnough
);
}
#[test]
fn test_metaframe_parse_fail() {
// First byte should be CR and not $
let metaframe = "$2\n*2\n".as_bytes();
let mut parser = Parser::new(&metaframe);
assert_eq!(
parser.parse_metaframe_get_datagroup_count().unwrap_err(),
ParseError::UnexpectedByte
);
// Give a wrong length approximation
let metaframe = "\r1\n*2\n".as_bytes();
assert_eq!(
Parser::new(&metaframe)
.parse_metaframe_get_datagroup_count()
.unwrap_err(),
ParseError::UnexpectedByte
);
}
#[test]
fn test_query_fail_not_enough() {
let query_packet = "*".as_bytes();
assert_eq!(
Parser::new(&query_packet).parse().err().unwrap(),
ParseError::NotEnough
);
let metaframe = "*2".as_bytes();
assert_eq!(
Parser::new(&metaframe)
.parse_metaframe_get_datagroup_count()
.unwrap_err(),
ParseError::NotEnough
);
}
#[test]
fn test_parse_next_string() {
let bytes = "5\nsayan\n".as_bytes();
let st = Parser::new(&bytes).parse_next_string().unwrap();
assert_eq!(st, "sayan".to_owned());
}
#[test]
fn test_parse_next_u64() {
let max = 18446744073709551615;
assert!(u64::MAX == max);
let bytes = "20\n18446744073709551615\n".as_bytes();
let our_u64 = Parser::new(&bytes).parse_next_u64().unwrap();
assert_eq!(our_u64, max);
// now overflow the u64
let bytes = "21\n184467440737095516156\n".as_bytes();
let our_u64 = Parser::new(&bytes).parse_next_u64().unwrap_err();
assert_eq!(our_u64, ParseError::DataTypeParseError);
}
#[test]
fn test_parse_next_element_string() {
let bytes = "+5\nsayan\n".as_bytes();
let next_element = Parser::new(&bytes).parse_next_element().unwrap();
assert_eq!(next_element, DataType::String("sayan".to_owned()));
}
#[test]
fn test_parse_next_element_string_fail() {
let bytes = "+5\nsayan".as_bytes();
assert_eq!(
Parser::new(&bytes).parse_next_element().unwrap_err(),
ParseError::NotEnough
);
}
#[test]
fn test_parse_next_element_u64() {
let bytes = ":20\n18446744073709551615\n".as_bytes();
let our_u64 = Parser::new(&bytes).parse_next_element().unwrap();
assert_eq!(our_u64, DataType::UnsignedInt(u64::MAX));
}
#[test]
fn test_parse_next_element_u64_fail() {
let bytes = ":20\n18446744073709551615".as_bytes();
assert_eq!(
Parser::new(&bytes).parse_next_element().unwrap_err(),
ParseError::NotEnough
);
}
#[test]
fn test_parse_next_element_array() {
let bytes = "&3\n+4\nMGET\n+3\nfoo\n+3\nbar\n".as_bytes();
let mut parser = Parser::new(&bytes);
let array = parser.parse_next_element().unwrap();
assert_eq!(
array,
DataType::Array(vec![
DataType::String("MGET".to_owned()),
DataType::String("foo".to_owned()),
DataType::String("bar".to_owned())
])
);
assert_eq!(parser.cursor, bytes.len());
}
#[test]
fn test_parse_next_element_array_fail() {
// should've been three elements, but there are two!
let bytes = "&3\n+4\nMGET\n+3\nfoo\n+3\n".as_bytes();
let mut parser = Parser::new(&bytes);
assert_eq!(
parser.parse_next_element().unwrap_err(),
ParseError::NotEnough
);
}
#[test]
fn test_parse_nested_array() {
// let's test a nested array
let bytes =
"&3\n+3\nACT\n+3\nfoo\n&4\n+5\nsayan\n+2\nis\n+7\nworking\n&2\n+6\nreally\n+4\nhard\n"
.as_bytes();
let mut parser = Parser::new(&bytes);
let array = parser.parse_next_element().unwrap();
assert_eq!(
array,
DataType::Array(vec![
DataType::String("ACT".to_owned()),
DataType::String("foo".to_owned()),
DataType::Array(vec![
DataType::String("sayan".to_owned()),
DataType::String("is".to_owned()),
DataType::String("working".to_owned()),
DataType::Array(vec![
DataType::String("really".to_owned()),
DataType::String("hard".to_owned())
])
])
])
);
assert_eq!(parser.cursor, bytes.len());
}
#[test]
fn test_parse_multitype_array() {
// let's test a nested array
let bytes = "&3\n+3\nACT\n+3\nfoo\n&4\n+5\nsayan\n+2\nis\n+7\nworking\n&2\n:2\n23\n+5\napril\n"
.as_bytes();
let mut parser = Parser::new(&bytes);
let array = parser.parse_next_element().unwrap();
assert_eq!(
array,
DataType::Array(vec![
DataType::String("ACT".to_owned()),
DataType::String("foo".to_owned()),
DataType::Array(vec![
DataType::String("sayan".to_owned()),
DataType::String("is".to_owned()),
DataType::String("working".to_owned()),
DataType::Array(vec![
DataType::UnsignedInt(23),
DataType::String("april".to_owned())
])
])
])
);
assert_eq!(parser.cursor, bytes.len());
}
#[test]
fn test_parse_a_query() {
let bytes =
"*1\n&3\n+3\nACT\n+3\nfoo\n&4\n+5\nsayan\n+2\nis\n+7\nworking\n&2\n:2\n23\n+5\napril\n"
.as_bytes();
let parser = Parser::new(&bytes);
let (resp, forward_by) = parser.parse().unwrap();
assert_eq!(
resp,
Query::SimpleQuery(DataType::Array(vec![
DataType::String("ACT".to_owned()),
DataType::String("foo".to_owned()),
DataType::Array(vec![
DataType::String("sayan".to_owned()),
DataType::String("is".to_owned()),
DataType::String("working".to_owned()),
DataType::Array(vec![
DataType::UnsignedInt(23),
DataType::String("april".to_owned())
])
])
]))
);
assert_eq!(forward_by, bytes.len());
}
#[test]
fn test_parse_a_query_fail_moredata() {
let bytes =
"*1\n&3\n+3\nACT\n+3\nfoo\n&4\n+5\nsayan\n+2\nis\n+7\nworking\n&1\n:2\n23\n+5\napril\n"
.as_bytes();
let parser = Parser::new(&bytes);
assert_eq!(parser.parse().unwrap_err(), ParseError::UnexpectedByte);
}
#[test]
fn test_pipelined_query_incomplete() {
// this was a pipelined query: we expected two queries but got one!
let bytes =
"*2\n&3\n+3\nACT\n+3\nfoo\n&4\n+5\nsayan\n+2\nis\n+7\nworking\n&2\n:2\n23\n+5\napril\n"
.as_bytes();
assert_eq!(
Parser::new(&bytes).parse().unwrap_err(),
ParseError::NotEnough
)
}
#[test]
fn test_pipelined_query() {
let bytes =
"*2\n&3\n+3\nACT\n+3\nfoo\n&3\n+5\nsayan\n+2\nis\n+7\nworking\n+4\nHEYA\n".as_bytes();
/*
(\r2\n*2\n)(&3\n)({+3\nACT\n}{+3\nfoo\n}{[&3\n][+5\nsayan\n][+2\nis\n][+7\nworking\n]})(+4\nHEYA\n)
*/
let (res, forward_by) = Parser::new(&bytes).parse().unwrap();
assert_eq!(
res,
Query::PipelinedQuery(vec![
DataType::Array(vec![
DataType::String("ACT".to_owned()),
DataType::String("foo".to_owned()),
DataType::Array(vec![
DataType::String("sayan".to_owned()),
DataType::String("is".to_owned()),
DataType::String("working".to_owned())
])
]),
DataType::String("HEYA".to_owned())
])
);
assert_eq!(forward_by, bytes.len());
}
#[test]
fn test_query_with_part_of_next_query() {
let bytes =
"*1\n&3\n+3\nACT\n+3\nfoo\n&4\n+5\nsayan\n+2\nis\n+7\nworking\n&2\n:2\n23\n+5\napril\n*1\n"
.as_bytes();
let (res, forward_by) = Parser::new(&bytes).parse().unwrap();
assert_eq!(
res,
Query::SimpleQuery(DataType::Array(vec![
DataType::String("ACT".to_owned()),
DataType::String("foo".to_owned()),
DataType::Array(vec![
DataType::String("sayan".to_owned()),
DataType::String("is".to_owned()),
DataType::String("working".to_owned()),
DataType::Array(vec![
DataType::UnsignedInt(23),
DataType::String("april".to_owned())
])
])
]))
);
// there are some ingenious folks on this planet who might just go bombing one query after the other
// we happily ignore those excess queries and leave it to the next round of parsing
assert_eq!(forward_by, bytes.len() - "*1\n".len());
}

@ -27,9 +27,8 @@
//! Primitives for generating Terrapipe compatible responses
pub mod groups {
//! # Pre-compiled response **groups**
//! These are pre-compiled response groups and **not** complete responses, that is, this is
//! to be sent after a `GroupBegin(n)` has been written to the stream. If complete
//! # Pre-compiled response **elements**
//! These are pre-compiled response groups and **not** complete responses. If complete
//! responses are required, user protocol::responses::fresp
use lazy_static::lazy_static;
lazy_static! {
@ -51,46 +50,56 @@ pub mod groups {
pub static ref HEYA: Vec<u8> = "+4\nHEY!\n".as_bytes().to_owned();
/// "Unknown action" error response
pub static ref UNKNOWN_ACTION: Vec<u8> = "!14\nUnknown action\n".as_bytes().to_owned();
pub static ref WRONGTYPE_ERR: Vec<u8> = "!1\n7\n".as_bytes().to_owned();
pub static ref SNAPSHOT_BUSY: Vec<u8> = "!17\nerr-snapshot-busy\n".as_bytes().to_owned();
/// Snapshot disabled (other error)
pub static ref SNAPSHOT_DISABLED: Vec<u8> = "!21\nerr-snapshot-disabled\n".as_bytes().to_owned();
/// Snapshot has illegal name (other error)
pub static ref SNAPSHOT_ILLEGAL_NAME: Vec<u8> = "!25\nerr-invalid-snapshot-name\n".as_bytes().to_owned();
/// Access after termination signal (other error)
pub static ref ERR_ACCESS_AFTER_TERMSIG: Vec<u8> = "!24\nerr-access-after-termsig\n".as_bytes().to_owned();
}
}
pub mod fresp {
//! # Pre-compiled **responses**
//! These are pre-compiled **complete** responses. This means that they should
//! be written off directly to the stream and should **not be preceded by a `GroupBegin(n)`**
//! be written off directly to the stream and should **not be preceded by any response metaframe**
use lazy_static::lazy_static;
lazy_static! {
/// Response code: 0 (Okay)
pub static ref R_OKAY: Vec<u8> = "#2\n*1\n#2\n&1\n!1\n0\n".as_bytes().to_owned();
pub static ref R_OKAY: Vec<u8> = "*1\n!1\n0\n".as_bytes().to_owned();
/// Response code: 1 (Nil)
pub static ref R_NIL: Vec<u8> = "#2\n*1\n#2\n&1\n!1\n1\n".as_bytes().to_owned();
pub static ref R_NIL: Vec<u8> = "*1\n!1\n1\n".as_bytes().to_owned();
/// Response code: 2 (Overwrite Error)
pub static ref R_OVERWRITE_ERR: Vec<u8> = "#2\n*1\n#2\n&1\n!1\n2\n".as_bytes().to_owned();
pub static ref R_OVERWRITE_ERR: Vec<u8> = "*1\n!1\n2\n".as_bytes().to_owned();
/// Response code: 3 (Action Error)
pub static ref R_ACTION_ERR: Vec<u8> = "#2\n*1\n#2\n&1\n!1\n3\n".as_bytes().to_owned();
pub static ref R_ACTION_ERR: Vec<u8> = "*1\n!1\n3\n".as_bytes().to_owned();
/// Response code: 4 (Packet Error)
pub static ref R_PACKET_ERR: Vec<u8> = "#2\n*1\n#2\n&1\n!1\n4\n".as_bytes().to_owned();
pub static ref R_PACKET_ERR: Vec<u8> = "*1\n!1\n4\n".as_bytes().to_owned();
/// Response code: 5 (Server Error)
pub static ref R_SERVER_ERR: Vec<u8> = "#2\n*1\n#2\n&1\n!1\n5\n".as_bytes().to_owned();
pub static ref R_SERVER_ERR: Vec<u8> = "*1\n!1\n5\n".as_bytes().to_owned();
/// Response code: 6 (Other Error _without description_)
pub static ref R_OTHER_ERR_EMPTY: Vec<u8> = "#2\n*1\n#2\n&1\n!1\n6\n".as_bytes().to_owned();
pub static ref R_OTHER_ERR_EMPTY: Vec<u8> = "*1\n!1\n6\n".as_bytes().to_owned();
/// A heya response
pub static ref R_HEYA: Vec<u8> = "#2\n*1\n#2\n&1\n+4\nHEY!\n".as_bytes().to_owned();
pub static ref R_HEYA: Vec<u8> = "*1\n+4\nHEY!\n".as_bytes().to_owned();
/// An other response with description: "Unknown action"
pub static ref R_UNKNOWN_ACTION: Vec<u8> = "#2\n*1\n#2\n&1\n!14\nUnknown action\n"
pub static ref R_UNKNOWN_ACTION: Vec<u8> = "*1\n!14\nUnknown action\n"
.as_bytes()
.to_owned();
/// A 0 uint64 reply
pub static ref R_ONE_INT_REPLY: Vec<u8> = "#2\n*1\n#2\n&1\n:1\n1\n".as_bytes().to_owned();
pub static ref R_ONE_INT_REPLY: Vec<u8> = "*1\n:1\n1\n".as_bytes().to_owned();
/// A 1 uint64 reply
pub static ref R_ZERO_INT_REPLY: Vec<u8> = "#2\n*1\n#2\n&1\n:1\n0\n".as_bytes().to_owned();
pub static ref R_ZERO_INT_REPLY: Vec<u8> = "*1\n:1\n0\n".as_bytes().to_owned();
/// Snapshot busy (other error)
pub static ref R_SNAPSHOT_BUSY: Vec<u8> = "#2\n*1\n#2\n&1\n!17\nerr-snapshot-busy\n".as_bytes().to_owned();
pub static ref R_SNAPSHOT_BUSY: Vec<u8> = "*1\n!17\nerr-snapshot-busy\n".as_bytes().to_owned();
/// Snapshot disabled (other error)
pub static ref R_SNAPSHOT_DISABLED: Vec<u8> = "#2\n*1\n#2\n&1\n!21\nerr-snapshot-disabled\n".as_bytes().to_owned();
pub static ref R_SNAPSHOT_DISABLED: Vec<u8> = "*1\n!21\nerr-snapshot-disabled\n".as_bytes().to_owned();
/// Snapshot has illegal name (other error)
pub static ref R_SNAPSHOT_ILLEGAL_NAME: Vec<u8> = "#2\n*1\n#2\n&1\n!25\nerr-invalid-snapshot-name\n".as_bytes().to_owned();
pub static ref R_SNAPSHOT_ILLEGAL_NAME: Vec<u8> = "*1\n!25\nerr-invalid-snapshot-name\n".as_bytes().to_owned();
/// Access after termination signal (other error)
pub static ref R_ERR_ACCESS_AFTER_TERMSIG: Vec<u8> = "#2\n*1\n#2\n&1\n!24\nerr-access-after-termsig\n".as_bytes().to_owned();
pub static ref R_ERR_ACCESS_AFTER_TERMSIG: Vec<u8> = "*1\n!24\nerr-access-after-termsig\n".as_bytes().to_owned();
/// Response code: 7; wrongtype
pub static ref R_WRONGTYPE_ERR: Vec<u8> = "*1\n!1\n7".as_bytes().to_owned();
}
}

@ -30,7 +30,7 @@ use crate::coredb::CoreDB;
use crate::dbnet::connection::prelude::*;
use crate::gen_match;
use crate::protocol::responses;
use crate::protocol::ActionGroup;
use crate::protocol::Element;
use crate::{admin, kvengine};
mod tags {
@ -73,23 +73,14 @@ mod tags {
}
/// Execute a simple(*) query
pub async fn execute_simple<T, Strm>(
db: &CoreDB,
con: &mut T,
buf: ActionGroup,
) -> std::io::Result<()>
pub async fn execute_simple<T, Strm>(db: &CoreDB, con: &mut T, buf: Element) -> std::io::Result<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let first = match buf.get_first() {
None => {
return con
.write_response(responses::fresp::R_PACKET_ERR.to_owned())
.await
.map_err(|e| e.into());
}
Some(f) => f.to_uppercase(),
Some(element) => element.to_lowercase(),
None => return con.write_response(&**responses::groups::PACKET_ERR).await,
};
gen_match!(
first,
@ -123,16 +114,21 @@ where
/// **NOTE:** This macro needs _paths_ for both sides of the $x => $y, to produce something sensible
macro_rules! gen_match {
($pre:ident, $db:ident, $con:ident, $buf:ident, $($x:path => $y:path),*) => {
let flat_array = if let crate::protocol::Element::FlatArray(array) = $buf {
array
} else {
return $con.write_response(&**responses::groups::WRONGTYPE_ERR).await;
};
match $pre.as_str() {
// First repeat over all the $x => $y patterns, passing in the variables
// and adding .await calls and adding the `?`
$(
$x => $y($db, $con, $buf).await?,
$x => $y($db, $con, flat_array).await?,
)*
// Now add the final case where no action is matched
_ => {
$con.write_response(responses::fresp::R_UNKNOWN_ACTION.to_owned())
.await?;
$con.write_response(&**responses::groups::UNKNOWN_ACTION)
.await;
},
}
};

@ -82,16 +82,6 @@ where
#[derive(Debug, PartialEq)]
pub struct BytesWrapper(pub Bytes);
/// This indicates the beginning of a response group in a response.
///
/// It holds the number of items to be written and writes:
/// ```text
/// #<self.0.to_string().len().to_string().into_bytes()>\n
/// &<self.0.to_string()>\n
/// ```
#[derive(Debug, PartialEq)]
pub struct GroupBegin(pub usize);
impl BytesWrapper {
pub fn finish_into_bytes(self) -> Bytes {
self.0
@ -194,32 +184,6 @@ impl Writable for RespCodes {
}
}
impl Writable for GroupBegin {
fn write<'s>(
self,
con: &'s mut impl IsConnection,
) -> Pin<Box<(dyn Future<Output = Result<(), IoError>> + Send + Sync + 's)>> {
async fn write_bytes(con: &mut impl IsConnection, size: usize) -> Result<(), IoError> {
con.write_lowlevel(b"#2\n*1\n").await?;
// First write a `#` which indicates that the next bytes give the
// prefix length
con.write_lowlevel(&[b'#']).await?;
let group_len_as_bytes = size.to_string().into_bytes();
let group_prefix_len_as_bytes = (group_len_as_bytes.len() + 1).to_string().into_bytes();
// Now write Self's len as bytes
con.write_lowlevel(&group_prefix_len_as_bytes).await?;
// Now write a LF and '&' which signifies the beginning of a datagroup
con.write_lowlevel(&[b'\n', b'&']).await?;
// Now write the number of items in the datagroup as bytes
con.write_lowlevel(&group_len_as_bytes).await?;
// Now write a '\n' character
con.write_lowlevel(&[b'\n']).await?;
Ok(())
}
Box::pin(write_bytes(con, self.0))
}
}
impl Writable for usize {
fn write<'s>(
self,

Loading…
Cancel
Save