Merge pull request #258 from skytable/protocol/compat

Add backwards compatibility for Skyhash 1.0
Merged into branch: next
Glydr committed 2 years ago via GitHub
commit 4e90d97ee3

@ -2,6 +2,18 @@
All changes in this project will be noted in this file.
## Version 0.8.0
### Additions
- New protocol: Skyhash 2.0
  - Reduced bandwidth usage (as much as 50%)
  - Even simpler client implementations
- Backward compatibility with Skyhash 1.0:
  - Simply set the protocol version you want to use in the config file, env vars or pass it as a CLI argument (see the sketch below)
- Even faster implementation, even for Skyhash 1.0
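
The changelog entry above describes the protocol toggle only in prose. As a purely illustrative sketch (the flag and environment-variable names used here, `--protover` and `SKY_PROTOVER`, are invented for the example and are not Skytable's actual configuration keys), a toggle of this kind typically checks a CLI argument first and falls back to an environment variable:

```rust
// Hypothetical sketch of resolving the protocol version at startup.
// Flag and env-var names are illustrative only.
use std::env;

#[derive(Clone, Copy, Debug)]
enum ProtocolVersion {
    Skyhash1, // compatibility mode
    Skyhash2, // default
}

fn resolve_protocol(cli: &[String]) -> ProtocolVersion {
    // CLI argument wins over the environment variable; anything else
    // falls back to the newest protocol
    let from_cli = cli
        .windows(2)
        .find(|w| w[0] == "--protover")
        .map(|w| w[1].clone());
    let raw = from_cli.or_else(|| env::var("SKY_PROTOVER").ok());
    match raw.as_deref() {
        Some("1") | Some("1.0") => ProtocolVersion::Skyhash1,
        _ => ProtocolVersion::Skyhash2,
    }
}

fn main() {
    let args: Vec<String> = env::args().collect();
    println!("negotiated protocol: {:?}", resolve_protocol(&args));
}
```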
## Version 0.7.5
### Additions

@ -29,14 +29,15 @@ use crate::dbnet::connection::prelude::*;
action!(
/// Returns the number of keys in the database
fn dbsize(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |len| len < 2)?;
ensure_length::<P>(act.len(), |len| len < 2)?;
if act.is_empty() {
let len = get_tbl_ref!(handle, con).count();
con.write_response(len).await?;
con.write_usize(len).await?;
} else {
let raw_entity = unsafe { act.next().unsafe_unwrap() };
let entity = handle_entity!(con, raw_entity);
conwrite!(con, get_tbl!(entity, handle, con).count())?;
con.write_usize(get_tbl!(entity, handle, con).count())
.await?;
}
Ok(())
}

@ -38,7 +38,7 @@ action!(
/// Do note that this function is blocking since it acquires a write lock.
/// It will write an entire datagroup, for this `del` action
fn del(handle: &Corestore, con: &'a mut T, act: ActionIter<'a>) {
ensure_length(act.len(), |size| size != 0)?;
ensure_length::<P>(act.len(), |size| size != 0)?;
let table = get_tbl_ref!(handle, con);
macro_rules! remove {
($engine:expr) => {{
@ -57,12 +57,12 @@ action!(
}
}
if let Some(done_howmany) = done_howmany {
con.write_response(done_howmany).await?;
con.write_usize(done_howmany).await?;
} else {
con.write_response(responses::groups::SERVER_ERR).await?;
con._write_raw(P::RCODE_SERVER_ERR).await?;
}
} else {
compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?;
return util::err(P::RCODE_ENCODING_ERROR);
}
}};
}
@ -74,7 +74,7 @@ action!(
remove!(kvlmap)
}
#[allow(unreachable_patterns)]
_ => conwrite!(con, groups::WRONG_MODEL)?,
_ => return util::err(P::RSTRING_WRONG_MODEL),
}
Ok(())
}

@ -36,7 +36,7 @@ use crate::util::compiler;
action!(
/// Run an `EXISTS` query
fn exists(handle: &Corestore, con: &'a mut T, act: ActionIter<'a>) {
ensure_length(act.len(), |len| len != 0)?;
ensure_length::<P>(act.len(), |len| len != 0)?;
let mut how_many_of_them_exist = 0usize;
macro_rules! exists {
($engine:expr) => {{
@ -45,9 +45,9 @@ action!(
act.for_each(|key| {
how_many_of_them_exist += $engine.exists_unchecked(key) as usize;
});
conwrite!(con, how_many_of_them_exist)?;
con.write_usize(how_many_of_them_exist).await?;
} else {
compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?;
return util::err(P::RCODE_ENCODING_ERROR);
}
}};
}
@ -56,7 +56,7 @@ action!(
DataModel::KV(kve) => exists!(kve),
DataModel::KVExtListmap(kve) => exists!(kve),
#[allow(unreachable_patterns)]
_ => conwrite!(con, groups::WRONG_MODEL)?,
_ => return util::err(P::RSTRING_WRONG_MODEL),
}
Ok(())
}

@ -30,7 +30,7 @@ use crate::queryengine::ActionIter;
action!(
/// Delete all the keys in the database
fn flushdb(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |len| len < 2)?;
ensure_length::<P>(act.len(), |len| len < 2)?;
if registry::state_okay() {
if act.is_empty() {
// flush the current table
@ -41,9 +41,9 @@ action!(
let entity = handle_entity!(con, raw_entity);
get_tbl!(entity, handle, con).truncate_table();
}
conwrite!(con, responses::groups::OKAY)?;
con._write_raw(P::RCODE_OKAY).await?;
} else {
conwrite!(con, responses::groups::SERVER_ERR)?;
con._write_raw(P::RCODE_SERVER_ERR).await?;
}
Ok(())
}

@ -28,19 +28,21 @@
//! This module provides functions to work with `GET` queries
use crate::dbnet::connection::prelude::*;
use crate::resp::writer;
use crate::util::compiler;
action!(
/// Run a `GET` query
fn get(handle: &crate::corestore::Corestore, con: &mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |len| len == 1)?;
let kve = handle.get_table_with::<KVEBlob>()?;
ensure_length::<P>(act.len(), |len| len == 1)?;
let kve = handle.get_table_with::<P, KVEBlob>()?;
unsafe {
match kve.get_cloned(act.next_unchecked()) {
Ok(Some(val)) => writer::write_raw_mono(con, kve.get_value_tsymbol(), &val).await?,
Err(_) => compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?,
Ok(_) => conwrite!(con, groups::NIL)?,
Ok(Some(val)) => {
con.write_mono_length_prefixed_with_tsymbol(&val, kve.get_value_tsymbol())
.await?
}
Err(_) => compiler::cold_err(con._write_raw(P::RCODE_ENCODING_ERROR)).await?,
Ok(_) => con._write_raw(P::RCODE_NIL).await?,
}
}
Ok(())

@ -31,9 +31,9 @@ action!(
///
/// At this moment, `keylen` only supports a single key
fn keylen(handle: &crate::corestore::Corestore, con: &mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |len| len == 1)?;
ensure_length::<P>(act.len(), |len| len == 1)?;
let res: Option<usize> = {
let reader = handle.get_table_with::<KVEBlob>()?;
let reader = handle.get_table_with::<P, KVEBlob>()?;
unsafe {
// UNSAFE(@ohsayan): this is completely safe as we've already checked
// the number of arguments is one
@ -45,10 +45,10 @@ action!(
};
if let Some(value) = res {
// Good, we got the key's length, write it off to the stream
con.write_response(value).await?;
con.write_usize(value).await?;
} else {
// Ah, couldn't find that key
con.write_response(responses::groups::NIL).await?;
con._write_raw(P::RCODE_NIL).await?;
}
Ok(())
}

@ -26,9 +26,6 @@
use crate::corestore::Data;
use crate::dbnet::connection::prelude::*;
use crate::resp::writer;
use crate::resp::writer::TypedArrayWriter;
const LEN: &[u8] = "LEN".as_bytes();
const LIMIT: &[u8] = "LIMIT".as_bytes();
const VALUEAT: &[u8] = "VALUEAT".as_bytes();
@ -66,8 +63,8 @@ action! {
/// - `LGET <mylist> LAST` will return the last item
/// if it exists
fn lget(handle: &Corestore, con: &mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |len| len != 0)?;
let listmap = handle.get_table_with::<KVEList>()?;
ensure_length::<P>(act.len(), |len| len != 0)?;
let listmap = handle.get_table_with::<P, KVEList>()?;
// get the list name
let listname = unsafe { act.next_unchecked() };
// now let us see what we need to do
@ -75,7 +72,7 @@ action! {
() => {
match unsafe { String::from_utf8_lossy(act.next_unchecked()) }.parse::<usize>() {
Ok(int) => int,
Err(_) => return util::err(groups::WRONGTYPE_ERR),
Err(_) => return util::err(P::RCODE_WRONGTYPE_ERR),
}
};
}
@ -84,32 +81,32 @@ action! {
// just return everything in the list
let items = match listmap.list_cloned_full(listname) {
Ok(Some(list)) => list,
Ok(None) => return conwrite!(con, groups::NIL),
Err(()) => return conwrite!(con, groups::ENCODING_ERROR),
Ok(None) => return Err(P::RCODE_NIL.into()),
Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()),
};
writelist!(con, listmap, items);
}
Some(subaction) => {
match subaction.as_ref() {
LEN => {
ensure_length(act.len(), |len| len == 0)?;
ensure_length::<P>(act.len(), |len| len == 0)?;
match listmap.list_len(listname) {
Ok(Some(len)) => conwrite!(con, len)?,
Ok(None) => conwrite!(con, groups::NIL)?,
Err(()) => conwrite!(con, groups::ENCODING_ERROR)?,
Ok(Some(len)) => con.write_usize(len).await?,
Ok(None) => return Err(P::RCODE_NIL.into()),
Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()),
}
}
LIMIT => {
ensure_length(act.len(), |len| len == 1)?;
ensure_length::<P>(act.len(), |len| len == 1)?;
let count = get_numeric_count!();
match listmap.list_cloned(listname, count) {
Ok(Some(items)) => writelist!(con, listmap, items),
Ok(None) => conwrite!(con, groups::NIL)?,
Err(()) => conwrite!(con, groups::ENCODING_ERROR)?,
Ok(None) => return Err(P::RCODE_NIL.into()),
Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()),
}
}
VALUEAT => {
ensure_length(act.len(), |len| len == 1)?;
ensure_length::<P>(act.len(), |len| len == 1)?;
let idx = get_numeric_count!();
let maybe_value = listmap.get(listname).map(|list| {
list.map(|lst| lst.read().get(idx).cloned())
@ -117,58 +114,56 @@ action! {
match maybe_value {
Ok(v) => match v {
Some(Some(value)) => {
unsafe {
// tsymbol is verified
writer::write_raw_mono(con, listmap.get_value_tsymbol(), &value)
.await?;
}
con.write_mono_length_prefixed_with_tsymbol(
&value, listmap.get_value_tsymbol()
).await?;
}
Some(None) => {
// bad index
conwrite!(con, groups::LISTMAP_BAD_INDEX)?;
return Err(P::RSTRING_LISTMAP_BAD_INDEX.into());
}
None => {
// not found
conwrite!(con, groups::NIL)?;
return Err(P::RCODE_NIL.into());
}
}
Err(()) => conwrite!(con, groups::ENCODING_ERROR)?,
Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()),
}
}
LAST => {
ensure_length(act.len(), |len| len == 0)?;
ensure_length::<P>(act.len(), |len| len == 0)?;
let maybe_value = listmap.get(listname).map(|list| {
list.map(|lst| lst.read().last().cloned())
});
match maybe_value {
Ok(v) => match v {
Some(Some(value)) => {
unsafe {
writer::write_raw_mono(con, listmap.get_value_tsymbol(), &value).await?;
}
con.write_mono_length_prefixed_with_tsymbol(
&value, listmap.get_value_tsymbol()
).await?;
},
Some(None) => conwrite!(con, groups::LISTMAP_LIST_IS_EMPTY)?,
None => conwrite!(con, groups::NIL)?,
Some(None) => return Err(P::RSTRING_LISTMAP_LIST_IS_EMPTY.into()),
None => return Err(P::RCODE_NIL.into()),
}
Err(()) => conwrite!(con, groups::ENCODING_ERROR)?,
Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()),
}
}
FIRST => {
ensure_length(act.len(), |len| len == 0)?;
ensure_length::<P>(act.len(), |len| len == 0)?;
let maybe_value = listmap.get(listname).map(|list| {
list.map(|lst| lst.read().first().cloned())
});
match maybe_value {
Ok(v) => match v {
Some(Some(value)) => {
unsafe {
writer::write_raw_mono(con, listmap.get_value_tsymbol(), &value).await?;
}
con.write_mono_length_prefixed_with_tsymbol(
&value, listmap.get_value_tsymbol()
).await?;
},
Some(None) => conwrite!(con, groups::LISTMAP_LIST_IS_EMPTY)?,
None => conwrite!(con, groups::NIL)?,
Some(None) => return Err(P::RSTRING_LISTMAP_LIST_IS_EMPTY.into()),
None => return Err(P::RCODE_NIL.into()),
}
Err(()) => conwrite!(con, groups::ENCODING_ERROR)?,
Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()),
}
}
RANGE => {
@ -176,13 +171,13 @@ action! {
Some(start) => {
let start: usize = match start.parse() {
Ok(v) => v,
Err(_) => return util::err(groups::WRONGTYPE_ERR),
Err(_) => return util::err(P::RCODE_WRONGTYPE_ERR),
};
let mut range = Range::new(start);
if let Some(stop) = act.next_string_owned() {
let stop: usize = match stop.parse() {
Ok(v) => v,
Err(_) => return util::err(groups::WRONGTYPE_ERR),
Err(_) => return util::err(P::RCODE_WRONGTYPE_ERR),
};
range.set_stop(stop);
};
@ -193,17 +188,17 @@ action! {
Some(ret) => {
writelist!(con, listmap, ret);
},
None => conwrite!(con, groups::LISTMAP_BAD_INDEX)?,
None => return Err(P::RSTRING_LISTMAP_BAD_INDEX.into()),
}
}
Ok(None) => conwrite!(con, groups::NIL)?,
Err(()) => conwrite!(con, groups::ENCODING_ERROR)?,
Ok(None) => return Err(P::RCODE_NIL.into()),
Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()),
}
}
None => aerr!(con),
None => return Err(P::RCODE_ACTION_ERR.into()),
}
}
_ => conwrite!(con, groups::UNKNOWN_ACTION)?,
_ => return Err(P::RCODE_UNKNOWN_ACTION.into()),
}
}
}

@ -24,7 +24,6 @@
*
*/
use super::{writer, OKAY_BADIDX_NIL_NLUT};
use crate::corestore::Data;
use crate::dbnet::connection::prelude::*;
use crate::util::compiler;
@ -44,55 +43,55 @@ action! {
/// - `LMOD <mylist> remove <index>`
/// - `LMOD <mylist> clear`
fn lmod(handle: &Corestore, con: &mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |len| len > 1)?;
let listmap = handle.get_table_with::<KVEList>()?;
ensure_length::<P>(act.len(), |len| len > 1)?;
let listmap = handle.get_table_with::<P, KVEList>()?;
// get the list name
let listname = unsafe { act.next_unchecked() };
macro_rules! get_numeric_count {
() => {
match unsafe { String::from_utf8_lossy(act.next_unchecked()) }.parse::<usize>() {
Ok(int) => int,
Err(_) => return conwrite!(con, groups::WRONGTYPE_ERR),
Err(_) => return Err(P::RCODE_WRONGTYPE_ERR.into()),
}
};
}
// now let us see what we need to do
match unsafe { act.next_uppercase_unchecked() }.as_ref() {
CLEAR => {
ensure_length(act.len(), |len| len == 0)?;
ensure_length::<P>(act.len(), |len| len == 0)?;
let list = match listmap.get_inner_ref().get(listname) {
Some(l) => l,
_ => return conwrite!(con, groups::NIL),
_ => return Err(P::RCODE_NIL.into()),
};
let okay = if registry::state_okay() {
list.write().clear();
groups::OKAY
P::RCODE_OKAY
} else {
groups::SERVER_ERR
P::RCODE_SERVER_ERR
};
conwrite!(con, okay)?;
con._write_raw(okay).await?
}
PUSH => {
ensure_boolean_or_aerr(!act.is_empty())?;
ensure_boolean_or_aerr::<P>(!act.is_empty())?;
let list = match listmap.get_inner_ref().get(listname) {
Some(l) => l,
_ => return conwrite!(con, groups::NIL),
_ => return Err(P::RCODE_NIL.into()),
};
let venc_ok = listmap.get_val_encoder();
let ret = if compiler::likely(act.as_ref().all(venc_ok)) {
if registry::state_okay() {
list.write().extend(act.map(Data::copy_from_slice));
groups::OKAY
P::RCODE_OKAY
} else {
groups::SERVER_ERR
P::RCODE_SERVER_ERR
}
} else {
groups::ENCODING_ERROR
P::RCODE_ENCODING_ERROR
};
conwrite!(con, ret)?;
con._write_raw(ret).await?
}
REMOVE => {
ensure_length(act.len(), |len| len == 1)?;
ensure_length::<P>(act.len(), |len| len == 1)?;
let idx_to_remove = get_numeric_count!();
if registry::state_okay() {
let maybe_value = listmap.get_inner_ref().get(listname).map(|list| {
@ -104,13 +103,13 @@ action! {
false
}
});
conwrite!(con, OKAY_BADIDX_NIL_NLUT[maybe_value])?;
con._write_raw(P::OKAY_BADIDX_NIL_NLUT[maybe_value]).await?
} else {
conwrite!(con, groups::SERVER_ERR)?;
return Err(P::RCODE_SERVER_ERR.into());
}
}
INSERT => {
ensure_length(act.len(), |len| len == 2)?;
ensure_length::<P>(act.len(), |len| len == 2)?;
let idx_to_insert_at = get_numeric_count!();
let bts = unsafe { act.next_unchecked() };
let ret = if compiler::likely(listmap.is_val_ok(bts)) {
@ -128,21 +127,21 @@ action! {
false
}
}),
Err(()) => return conwrite!(con, groups::ENCODING_ERROR),
Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()),
};
OKAY_BADIDX_NIL_NLUT[maybe_insert]
P::OKAY_BADIDX_NIL_NLUT[maybe_insert]
} else {
// flush broken; server err
groups::SERVER_ERR
P::RCODE_SERVER_ERR
}
} else {
// encoding failed, uh
groups::ENCODING_ERROR
P::RCODE_ENCODING_ERROR
};
conwrite!(con, ret)?;
con._write_raw(ret).await?
}
POP => {
ensure_length(act.len(), |len| len < 2)?;
ensure_length::<P>(act.len(), |len| len < 2)?;
let idx = if act.len() == 1 {
// we have an idx
Some(get_numeric_count!())
@ -165,24 +164,24 @@ action! {
wlock.pop()
}
}),
Err(()) => return conwrite!(con, groups::ENCODING_ERROR),
Err(()) => return Err(P::RCODE_ENCODING_ERROR.into()),
};
match maybe_pop {
Some(Some(val)) => {
unsafe {
writer::write_raw_mono(con, listmap.get_value_tsymbol(), &val).await?;
}
con.write_mono_length_prefixed_with_tsymbol(
&val, listmap.get_value_tsymbol()
).await?;
}
Some(None) => {
conwrite!(con, groups::LISTMAP_BAD_INDEX)?;
con._write_raw(P::RSTRING_LISTMAP_BAD_INDEX).await?;
}
None => conwrite!(con, groups::NIL)?,
None => con._write_raw(P::RCODE_NIL).await?,
}
} else {
conwrite!(con, groups::SERVER_ERR)?;
con._write_raw(P::RCODE_SERVER_ERR).await?
}
}
_ => conwrite!(con, groups::UNKNOWN_ACTION)?,
_ => con._write_raw(P::RCODE_UNKNOWN_ACTION).await?,
}
Ok(())
}

@ -26,11 +26,10 @@
macro_rules! writelist {
($con:expr, $listmap:expr, $items:expr) => {{
let mut typed_array_writer =
unsafe { TypedArrayWriter::new($con, $listmap.get_value_tsymbol(), $items.len()) }
.await?;
$con.write_typed_non_null_array_header($items.len(), $listmap.get_value_tsymbol())
.await?;
for item in $items {
typed_array_writer.write_element(item).await?;
$con.write_typed_non_null_array_element(&item).await?;
}
}};
}

@ -30,23 +30,16 @@ mod macros;
pub mod lget;
pub mod lmod;
use crate::corestore::booltable::BytesBoolTable;
use crate::corestore::booltable::BytesNicheLUT;
use crate::corestore::Data;
use crate::dbnet::connection::prelude::*;
use crate::kvengine::LockedVec;
use crate::resp::writer;
const OKAY_OVW_BLUT: BytesBoolTable = BytesBoolTable::new(groups::OKAY, groups::OVERWRITE_ERR);
const OKAY_BADIDX_NIL_NLUT: BytesNicheLUT =
BytesNicheLUT::new(groups::NIL, groups::OKAY, groups::LISTMAP_BAD_INDEX);
action! {
/// Handle an `LSET` query for the list model
/// Syntax: `LSET <listname> <values ...>`
fn lset(handle: &Corestore, con: &mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |len| len > 0)?;
let listmap = handle.get_table_with::<KVEList>()?;
ensure_length::<P>(act.len(), |len| len > 0)?;
let listmap = handle.get_table_with::<P, KVEList>()?;
let listname = unsafe { act.next_unchecked_bytes() };
let list = listmap.get_inner_ref();
if registry::state_okay() {
@ -57,9 +50,9 @@ action! {
} else {
false
};
conwrite!(con, OKAY_OVW_BLUT[did])?;
con._write_raw(P::OKAY_OVW_BLUT[did]).await?
} else {
conwrite!(con, groups::SERVER_ERR)?;
con._write_raw(P::RCODE_SERVER_ERR).await?
}
Ok(())
}
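
For readers unfamiliar with the `P::OKAY_OVW_BLUT[did]` / `P::OKAY_BADIDX_NIL_NLUT[maybe_value]` pattern in the hunks above: an operation's boolean outcome indexes a small constant table to pick the wire response, instead of branching at every call site. A self-contained sketch of the idea, using stand-in types and byte strings rather than Skytable's real `BytesBoolTable`/`BytesNicheLUT` definitions:

```rust
// Stand-in for the lookup-table idiom: index a tiny const table with the
// outcome of an operation to select the response bytes.
use std::ops::Index;

struct BoolTable<T> {
    if_true: T,
    if_false: T,
}

impl<T> BoolTable<T> {
    const fn new(if_true: T, if_false: T) -> Self {
        Self { if_true, if_false }
    }
}

impl<T> Index<bool> for BoolTable<T> {
    type Output = T;
    fn index(&self, did_succeed: bool) -> &T {
        if did_succeed {
            &self.if_true
        } else {
            &self.if_false
        }
    }
}

// "okay" vs "overwrite error", keyed by whether the set actually happened
// (byte strings are examples, not the real response encodings)
const OKAY_OVW_BLUT: BoolTable<&[u8]> = BoolTable::new(b"!0\n", b"!err-overwrite\n");

fn main() {
    let did_set = false;
    assert_eq!(OKAY_OVW_BLUT[did_set], b"!err-overwrite\n");
    println!("response: {:?}", OKAY_OVW_BLUT[did_set]);
}
```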

@ -27,14 +27,13 @@
use crate::corestore::table::DataModel;
use crate::corestore::Data;
use crate::dbnet::connection::prelude::*;
use crate::resp::writer::TypedArrayWriter;
const DEFAULT_COUNT: usize = 10;
action!(
/// Run an `LSKEYS` query
fn lskeys(handle: &crate::corestore::Corestore, con: &mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |size| size < 4)?;
ensure_length::<P>(act.len(), |size| size < 4)?;
let (table, count) = if act.is_empty() {
(get_tbl!(handle, con), DEFAULT_COUNT)
} else if act.len() == 1 {
@ -45,7 +44,7 @@ action!(
let count = if let Ok(cnt) = String::from_utf8_lossy(nextret).parse::<usize>() {
cnt
} else {
return util::err(groups::WRONGTYPE_ERR);
return util::err(P::RCODE_WRONGTYPE_ERR);
};
(get_tbl!(handle, con), count)
} else {
@ -61,7 +60,7 @@ action!(
let count = if let Ok(cnt) = String::from_utf8_lossy(count_ret).parse::<usize>() {
cnt
} else {
return util::err(groups::WRONGTYPE_ERR);
return util::err(P::RCODE_WRONGTYPE_ERR);
};
(get_tbl!(entity, handle, con), count)
};
@ -73,13 +72,9 @@ action!(
DataModel::KV(kv) => kv.get_inner_ref().get_keys(count),
DataModel::KVExtListmap(kv) => kv.get_inner_ref().get_keys(count),
};
let mut writer = unsafe {
// SAFETY: We have checked kty ourselves
TypedArrayWriter::new(con, tsymbol, items.len())
}
.await?;
con.write_typed_non_null_array_header(items.len(), tsymbol).await?;
for key in items {
writer.write_element(key).await?;
con.write_typed_non_null_array_element(&key).await?;
}
Ok(())
}

@ -46,31 +46,17 @@ macro_rules! is_lowbit_unset {
};
}
#[macro_export]
macro_rules! conwrite {
($con:expr, $what:expr) => {
$con.write_response($what)
.await
.map_err(|e| $crate::actions::ActionError::IoError(e))
};
}
#[macro_export]
macro_rules! aerr {
($con:expr) => {
return conwrite!($con, $crate::protocol::responses::groups::ACTION_ERR)
};
}
#[macro_export]
macro_rules! get_tbl {
($entity:expr, $store:expr, $con:expr) => {{
$store.get_table($entity)?
$crate::actions::translate_ddl_error::<P, ::std::sync::Arc<$crate::corestore::table::Table>>(
$store.get_table($entity),
)?
}};
($store:expr, $con:expr) => {{
match $store.get_ctable() {
Some(tbl) => tbl,
None => return $crate::util::err($crate::protocol::responses::groups::DEFAULT_UNSET),
None => return $crate::util::err(P::RSTRING_DEFAULT_UNSET),
}
}};
}
@ -80,7 +66,7 @@ macro_rules! get_tbl_ref {
($store:expr, $con:expr) => {{
match $store.get_ctable_ref() {
Some(tbl) => tbl,
None => return $crate::util::err($crate::protocol::responses::groups::DEFAULT_UNSET),
None => return $crate::util::err(P::RSTRING_DEFAULT_UNSET),
}
}};
}
@ -88,9 +74,9 @@ macro_rules! get_tbl_ref {
#[macro_export]
macro_rules! handle_entity {
($con:expr, $ident:expr) => {{
match $crate::queryengine::parser::Entity::from_slice(&$ident) {
match $crate::queryengine::parser::Entity::from_slice::<P>(&$ident) {
Ok(e) => e,
Err(e) => return conwrite!($con, e),
Err(e) => return Err(e.into()),
}
}};
}

@ -27,30 +27,26 @@
use crate::dbnet::connection::prelude::*;
use crate::kvengine::encoding::ENCODING_LUT_ITER;
use crate::queryengine::ActionIter;
use crate::resp::writer::TypedArrayWriter;
use crate::util::compiler;
action!(
/// Run an `MGET` query
///
fn mget(handle: &crate::corestore::Corestore, con: &mut T, act: ActionIter<'a>) {
ensure_length(act.len(), |size| size != 0)?;
let kve = handle.get_table_with::<KVEBlob>()?;
ensure_length::<P>(act.len(), |size| size != 0)?;
let kve = handle.get_table_with::<P, KVEBlob>()?;
let encoding_is_okay = ENCODING_LUT_ITER[kve.is_key_encoded()](act.as_ref());
if compiler::likely(encoding_is_okay) {
let mut writer = unsafe {
// SAFETY: We are getting the value type ourselves
TypedArrayWriter::new(con, kve.get_value_tsymbol(), act.len())
}
.await?;
con.write_typed_array_header(act.len(), kve.get_value_tsymbol())
.await?;
for key in act {
match kve.get_cloned_unchecked(key) {
Some(v) => writer.write_element(&v).await?,
None => writer.write_null().await?,
Some(v) => con.write_typed_array_element(&v).await?,
None => con.write_typed_array_element_null().await?,
}
}
} else {
compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?;
return util::err(P::RCODE_ENCODING_ERROR);
}
Ok(())
}
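
The `MGET` hunk above swaps the old `TypedArrayWriter` helper for direct connection methods: write a typed-array header carrying the element count and type symbol, then stream one element (or an explicit null) per key. A rough sketch of that shape over an in-memory buffer; the byte layout below is invented for illustration and is not the actual Skyhash wire format:

```rust
// Illustrative streaming typed-array writer over a plain buffer.
struct MockConnection {
    buf: Vec<u8>,
}

impl MockConnection {
    fn write_typed_array_header(&mut self, len: usize, tsymbol: u8) {
        // e.g. "@<tsymbol><count>\n" -- invented layout
        self.buf.push(b'@');
        self.buf.push(tsymbol);
        self.buf.extend_from_slice(len.to_string().as_bytes());
        self.buf.push(b'\n');
    }
    fn write_typed_array_element(&mut self, element: &[u8]) {
        // length prefix, then the payload
        self.buf.extend_from_slice(element.len().to_string().as_bytes());
        self.buf.push(b'\n');
        self.buf.extend_from_slice(element);
        self.buf.push(b'\n');
    }
    fn write_typed_array_element_null(&mut self) {
        // explicit null marker keeps one slot per requested key
        self.buf.extend_from_slice(b"\0\n");
    }
}

fn main() {
    // MGET-style response: one slot per key, nulls for misses
    let values: Vec<Option<&[u8]>> = vec![Some(&b"sayan"[..]), None, Some(&b"17"[..])];
    let mut con = MockConnection { buf: Vec::new() };
    con.write_typed_array_header(values.len(), b'+');
    for v in &values {
        match v {
            Some(v) => con.write_typed_array_element(v),
            None => con.write_typed_array_element_null(),
        }
    }
    println!("{:?}", String::from_utf8_lossy(&con.buf));
}
```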

@ -51,7 +51,7 @@ pub mod update;
pub mod uset;
pub mod whereami;
use crate::corestore::memstore::DdlError;
use crate::protocol::responses::groups;
use crate::protocol::interface::ProtocolSpec;
use crate::util;
use std::io::Error as IoError;
@ -65,6 +65,16 @@ pub enum ActionError {
IoError(std::io::Error),
}
impl PartialEq for ActionError {
fn eq(&self, other: &Self) -> bool {
match (self, other) {
(Self::ActionError(a1), Self::ActionError(a2)) => a1 == a2,
(Self::IoError(ioe1), Self::IoError(ioe2)) => ioe1.to_string() == ioe2.to_string(),
_ => false,
}
}
}
impl From<&'static [u8]> for ActionError {
fn from(e: &'static [u8]) -> Self {
Self::ActionError(e)
@ -77,36 +87,44 @@ impl From<IoError> for ActionError {
}
}
impl From<DdlError> for ActionError {
fn from(e: DdlError) -> Self {
let ret = match e {
DdlError::AlreadyExists => groups::ALREADY_EXISTS,
DdlError::DdlTransactionFailure => groups::DDL_TRANSACTIONAL_FAILURE,
DdlError::DefaultNotFound => groups::DEFAULT_UNSET,
DdlError::NotEmpty => groups::KEYSPACE_NOT_EMPTY,
DdlError::NotReady => groups::NOT_READY,
DdlError::ObjectNotFound => groups::CONTAINER_NOT_FOUND,
DdlError::ProtectedObject => groups::PROTECTED_OBJECT,
DdlError::StillInUse => groups::STILL_IN_USE,
DdlError::WrongModel => groups::WRONG_MODEL,
};
Self::ActionError(ret)
#[cold]
#[inline(never)]
fn map_ddl_error_to_status<P: ProtocolSpec>(e: DdlError) -> ActionError {
let r = match e {
DdlError::AlreadyExists => P::RSTRING_ALREADY_EXISTS,
DdlError::DdlTransactionFailure => P::RSTRING_DDL_TRANSACTIONAL_FAILURE,
DdlError::DefaultNotFound => P::RSTRING_DEFAULT_UNSET,
DdlError::NotEmpty => P::RSTRING_KEYSPACE_NOT_EMPTY,
DdlError::NotReady => P::RSTRING_NOT_READY,
DdlError::ObjectNotFound => P::RSTRING_CONTAINER_NOT_FOUND,
DdlError::ProtectedObject => P::RSTRING_PROTECTED_OBJECT,
DdlError::StillInUse => P::RSTRING_STILL_IN_USE,
DdlError::WrongModel => P::RSTRING_WRONG_MODEL,
};
ActionError::ActionError(r)
}
#[inline(always)]
pub fn translate_ddl_error<P: ProtocolSpec, T>(r: Result<T, DdlError>) -> Result<T, ActionError> {
match r {
Ok(r) => Ok(r),
Err(e) => Err(map_ddl_error_to_status::<P>(e)),
}
}
pub fn ensure_length(len: usize, is_valid: fn(usize) -> bool) -> ActionResult<()> {
pub fn ensure_length<P: ProtocolSpec>(len: usize, is_valid: fn(usize) -> bool) -> ActionResult<()> {
if util::compiler::likely(is_valid(len)) {
Ok(())
} else {
util::err(groups::ACTION_ERR)
util::err(P::RCODE_ACTION_ERR)
}
}
pub fn ensure_boolean_or_aerr(boolean: bool) -> ActionResult<()> {
pub fn ensure_boolean_or_aerr<P: ProtocolSpec>(boolean: bool) -> ActionResult<()> {
if util::compiler::likely(boolean) {
Ok(())
} else {
util::err(groups::ACTION_ERR)
util::err(P::RCODE_ACTION_ERR)
}
}
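
This hunk is the heart of the refactor: response codes move from the free-standing `responses::groups` constants to associated constants on a protocol trait, and helpers such as `ensure_length` become generic over `P: ProtocolSpec` so that Skyhash 1.0 and 2.0 can each encode the same logical status differently. A minimal, self-contained sketch of the pattern; the trait body, `DemoProto`, and the byte strings are illustrative, not the real Skytable definitions:

```rust
// Response bytes live as associated constants on a protocol trait,
// and action helpers are generic over the protocol.
trait ProtocolSpec {
    const RCODE_OKAY: &'static [u8];
    const RCODE_ACTION_ERR: &'static [u8];
}

struct DemoProto;

impl ProtocolSpec for DemoProto {
    const RCODE_OKAY: &'static [u8] = b"!0\n";
    const RCODE_ACTION_ERR: &'static [u8] = b"!2\n";
}

type ActionResult<T> = Result<T, &'static [u8]>;

// Reject a query whose argument count fails the predicate, using the
// protocol-specific "action error" response as the error value.
fn ensure_length<P: ProtocolSpec>(len: usize, is_valid: fn(usize) -> bool) -> ActionResult<()> {
    if is_valid(len) {
        Ok(())
    } else {
        Err(P::RCODE_ACTION_ERR)
    }
}

fn main() {
    // Each protocol implementation supplies its own wire-level byte strings.
    println!("okay is encoded as {:?}", DemoProto::RCODE_OKAY);
    assert!(ensure_length::<DemoProto>(2, |len| len == 2).is_ok());
    assert_eq!(
        ensure_length::<DemoProto>(3, |len| len == 2).unwrap_err(),
        DemoProto::RCODE_ACTION_ERR
    );
}
```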
@ -121,16 +139,16 @@ pub fn ensure_cond_or_err(cond: bool, err: &'static [u8]) -> ActionResult<()> {
pub mod heya {
//! Respond to `HEYA` queries
use crate::dbnet::connection::prelude::*;
use crate::resp::BytesWrapper;
action!(
/// Returns a `HEY!` `Response`
fn heya(_handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |len| len == 0 || len == 1)?;
ensure_length::<P>(act.len(), |len| len == 0 || len == 1)?;
if act.len() == 1 {
let raw_byte = unsafe { act.next_unchecked_bytes() };
con.write_response(BytesWrapper(raw_byte)).await?;
let raw_byte = unsafe { act.next_unchecked() };
con.write_mono_length_prefixed_with_tsymbol(raw_byte, b'+')
.await?;
} else {
con.write_response(responses::groups::HEYA).await?;
con._write_raw(P::ELEMRESP_HEYA).await?;
}
Ok(())
}

@ -27,36 +27,31 @@
use crate::corestore;
use crate::dbnet::connection::prelude::*;
use crate::kvengine::encoding::ENCODING_LUT_ITER;
use crate::protocol::responses;
use crate::queryengine::ActionIter;
use crate::resp::writer::TypedArrayWriter;
use crate::util::compiler;
action!(
/// Run an MPOP action
fn mpop(handle: &corestore::Corestore, con: &mut T, act: ActionIter<'a>) {
ensure_length(act.len(), |len| len != 0)?;
ensure_length::<P>(act.len(), |len| len != 0)?;
if registry::state_okay() {
let kve = handle.get_table_with::<KVEBlob>()?;
let kve = handle.get_table_with::<P, KVEBlob>()?;
let encoding_is_okay = ENCODING_LUT_ITER[kve.is_key_encoded()](act.as_ref());
if compiler::likely(encoding_is_okay) {
let mut writer = unsafe {
// SAFETY: We have verified the tsymbol ourselves
TypedArrayWriter::new(con, kve.get_value_tsymbol(), act.len())
}
.await?;
con.write_typed_array_header(act.len(), kve.get_value_tsymbol())
.await?;
for key in act {
match kve.pop_unchecked(key) {
Some(val) => writer.write_element(val).await?,
None => writer.write_null().await?,
Some(val) => con.write_typed_array_element(&val).await?,
None => con.write_typed_array_element_null().await?,
}
}
} else {
compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?;
return util::err(P::RCODE_ENCODING_ERROR);
}
} else {
// don't begin the operation at all if the database is poisoned
con.write_response(responses::groups::SERVER_ERR).await?;
return util::err(P::RCODE_SERVER_ERR);
}
Ok(())
}

@ -33,8 +33,8 @@ action!(
/// Run an `MSET` query
fn mset(handle: &crate::corestore::Corestore, con: &mut T, mut act: ActionIter<'a>) {
let howmany = act.len();
ensure_length(howmany, |size| size & 1 == 0 && size != 0)?;
let kve = handle.get_table_with::<KVEBlob>()?;
ensure_length::<P>(howmany, |size| size & 1 == 0 && size != 0)?;
let kve = handle.get_table_with::<P, KVEBlob>()?;
let encoding_is_okay = ENCODING_LUT_ITER_PAIR[kve.get_encoding_tuple()](&act);
if compiler::likely(encoding_is_okay) {
let done_howmany: Option<usize> = if registry::state_okay() {
@ -49,12 +49,12 @@ action!(
None
};
if let Some(done_howmany) = done_howmany {
con.write_response(done_howmany as usize).await?;
con.write_usize(done_howmany).await?;
} else {
con.write_response(responses::groups::SERVER_ERR).await?;
return util::err(P::RCODE_SERVER_ERR);
}
} else {
compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?;
return util::err(P::RCODE_ENCODING_ERROR);
}
Ok(())
}

@ -33,8 +33,8 @@ action!(
/// Run an `MUPDATE` query
fn mupdate(handle: &crate::corestore::Corestore, con: &mut T, mut act: ActionIter<'a>) {
let howmany = act.len();
ensure_length(howmany, |size| size & 1 == 0 && size != 0)?;
let kve = handle.get_table_with::<KVEBlob>()?;
ensure_length::<P>(howmany, |size| size & 1 == 0 && size != 0)?;
let kve = handle.get_table_with::<P, KVEBlob>()?;
let encoding_is_okay = ENCODING_LUT_ITER_PAIR[kve.get_encoding_tuple()](&act);
let done_howmany: Option<usize>;
if compiler::likely(encoding_is_okay) {
@ -50,12 +50,12 @@ action!(
done_howmany = None;
}
if let Some(done_howmany) = done_howmany {
con.write_response(done_howmany as usize).await?;
con.write_usize(done_howmany).await?;
} else {
con.write_response(responses::groups::SERVER_ERR).await?;
return util::err(P::RCODE_SERVER_ERR);
}
} else {
compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?;
return util::err(P::RCODE_ENCODING_ERROR);
}
Ok(())
}

@ -25,29 +25,25 @@
*/
use crate::dbnet::connection::prelude::*;
use crate::resp::writer;
use crate::util::compiler;
action! {
fn pop(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |len| len == 1)?;
ensure_length::<P>(act.len(), |len| len == 1)?;
let key = unsafe {
// SAFETY: We have checked for there to be one arg
act.next_unchecked()
};
if registry::state_okay() {
let kve = handle.get_table_with::<KVEBlob>()?;
let tsymbol = kve.get_value_tsymbol();
let kve = handle.get_table_with::<P, KVEBlob>()?;
match kve.pop(key) {
Ok(Some(val)) => unsafe {
// SAFETY: We have verified the tsymbol ourselves
writer::write_raw_mono(con, tsymbol, &val).await?
},
Ok(None) => conwrite!(con, groups::NIL)?,
Err(()) => compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?,
Ok(Some(val)) => con.write_mono_length_prefixed_with_tsymbol(
&val, kve.get_value_tsymbol()
).await?,
Ok(None) => return util::err(P::RCODE_NIL),
Err(()) => return util::err(P::RCODE_ENCODING_ERROR),
}
} else {
conwrite!(con, groups::SERVER_ERR)?;
return util::err(P::RCODE_SERVER_ERR);
}
Ok(())
}

@ -28,21 +28,17 @@
//! This module provides functions to work with `SET` queries
use crate::corestore;
use crate::corestore::booltable::BytesNicheLUT;
use crate::dbnet::connection::prelude::*;
use crate::queryengine::ActionIter;
use corestore::Data;
const SET_NLUT: BytesNicheLUT =
BytesNicheLUT::new(groups::ENCODING_ERROR, groups::OKAY, groups::OVERWRITE_ERR);
action!(
/// Run a `SET` query
fn set(handle: &crate::corestore::Corestore, con: &mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |len| len == 2)?;
ensure_length::<P>(act.len(), |len| len == 2)?;
if registry::state_okay() {
let did_we = {
let writer = handle.get_table_with::<KVEBlob>()?;
let writer = handle.get_table_with::<P, KVEBlob>()?;
match unsafe {
// UNSAFE(@ohsayan): This is completely safe as we've already checked
// that there are exactly 2 arguments
@ -56,9 +52,9 @@ action!(
Err(()) => None,
}
};
conwrite!(con, SET_NLUT[did_we])?;
con._write_raw(P::SET_NLUT[did_we]).await?;
} else {
conwrite!(con, groups::SERVER_ERR)?;
con._write_raw(P::RCODE_SERVER_ERR).await?;
}
Ok(())
}

@ -37,8 +37,8 @@ action! {
/// This either returns `Okay` if all the keys were `del`eted, or it returns a
/// `Nil`, which is code `1`
fn sdel(handle: &crate::corestore::Corestore, con: &mut T, act: ActionIter<'a>) {
ensure_length(act.len(), |len| len != 0)?;
let kve = handle.get_table_with::<KVEBlob>()?;
ensure_length::<P>(act.len(), |len| len != 0)?;
let kve = handle.get_table_with::<P, KVEBlob>()?;
if registry::state_okay() {
// guarantee one check: consistency
let key_encoder = kve.get_key_encoder();
@ -48,15 +48,15 @@ action! {
self::snapshot_and_del(kve, key_encoder, act.into_inner())
};
match outcome {
StrongActionResult::Okay => conwrite!(con, groups::OKAY)?,
StrongActionResult::Okay => con._write_raw(P::RCODE_OKAY).await?,
StrongActionResult::Nil => {
// good, it failed because some key didn't exist
conwrite!(con, groups::NIL)?;
return util::err(P::RCODE_NIL);
},
StrongActionResult::ServerError => conwrite!(con, groups::SERVER_ERR)?,
StrongActionResult::ServerError => return util::err(P::RCODE_SERVER_ERR),
StrongActionResult::EncodingError => {
// error we love to hate: encoding error, ugh
compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?
return util::err(P::RCODE_ENCODING_ERROR);
},
StrongActionResult::OverwriteError => unsafe {
// SAFETY check: never the case
@ -64,7 +64,7 @@ action! {
}
}
} else {
conwrite!(con, groups::SERVER_ERR)?;
return util::err(P::RCODE_SERVER_ERR);
}
Ok(())
}

@ -40,8 +40,8 @@ action! {
/// `Overwrite Error` or code `2`
fn sset(handle: &crate::corestore::Corestore, con: &mut T, act: ActionIter<'a>) {
let howmany = act.len();
ensure_length(howmany, |size| size & 1 == 0 && size != 0)?;
let kve = handle.get_table_with::<KVEBlob>()?;
ensure_length::<P>(howmany, |size| size & 1 == 0 && size != 0)?;
let kve = handle.get_table_with::<P, KVEBlob>()?;
if registry::state_okay() {
let encoder = kve.get_double_encoder();
let outcome = unsafe {
@ -50,12 +50,12 @@ action! {
self::snapshot_and_insert(kve, encoder, act.into_inner())
};
match outcome {
StrongActionResult::Okay => conwrite!(con, groups::OKAY)?,
StrongActionResult::OverwriteError => conwrite!(con, groups::OVERWRITE_ERR)?,
StrongActionResult::ServerError => conwrite!(con, groups::SERVER_ERR)?,
StrongActionResult::Okay => con._write_raw(P::RCODE_OKAY).await?,
StrongActionResult::OverwriteError => return util::err(P::RCODE_OVERWRITE_ERR),
StrongActionResult::ServerError => return util::err(P::RCODE_SERVER_ERR),
StrongActionResult::EncodingError => {
// error we love to hate: encoding error, ugh
compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?
return util::err(P::RCODE_ENCODING_ERROR);
},
StrongActionResult::Nil => unsafe {
// SAFETY check: never the case
@ -63,7 +63,7 @@ action! {
}
}
} else {
conwrite!(con, groups::SERVER_ERR)?;
return util::err(P::RCODE_SERVER_ERR);
}
Ok(())
}

@ -40,8 +40,8 @@ action! {
/// or code `1`
fn supdate(handle: &crate::corestore::Corestore, con: &mut T, act: ActionIter<'a>) {
let howmany = act.len();
ensure_length(howmany, |size| size & 1 == 0 && size != 0)?;
let kve = handle.get_table_with::<KVEBlob>()?;
ensure_length::<P>(howmany, |size| size & 1 == 0 && size != 0)?;
let kve = handle.get_table_with::<P, KVEBlob>()?;
if registry::state_okay() {
let encoder = kve.get_double_encoder();
let outcome = unsafe {
@ -49,15 +49,15 @@ action! {
self::snapshot_and_update(kve, encoder, act.into_inner())
};
match outcome {
StrongActionResult::Okay => conwrite!(con, groups::OKAY)?,
StrongActionResult::Okay => con._write_raw(P::RCODE_OKAY).await?,
StrongActionResult::Nil => {
// good, it failed because some key didn't exist
conwrite!(con, groups::NIL)?;
return util::err(P::RCODE_NIL);
},
StrongActionResult::ServerError => conwrite!(con, groups::SERVER_ERR)?,
StrongActionResult::ServerError => return util::err(P::RCODE_SERVER_ERR),
StrongActionResult::EncodingError => {
// error we love to hate: encoding error, ugh
compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?
return util::err(P::RCODE_ENCODING_ERROR);
},
StrongActionResult::OverwriteError => unsafe {
// SAFETY check: never the case
@ -65,7 +65,7 @@ action! {
}
}
} else {
conwrite!(con, groups::SERVER_ERR)?;
return util::err(P::RCODE_SERVER_ERR);
}
Ok(())
}

@ -28,20 +28,16 @@
//! This module provides functions to work with `UPDATE` queries
//!
use crate::corestore::booltable::BytesNicheLUT;
use crate::corestore::Data;
use crate::dbnet::connection::prelude::*;
const UPDATE_NLUT: BytesNicheLUT =
BytesNicheLUT::new(groups::ENCODING_ERROR, groups::OKAY, groups::NIL);
action!(
/// Run an `UPDATE` query
fn update(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |len| len == 2)?;
ensure_length::<P>(act.len(), |len| len == 2)?;
if registry::state_okay() {
let did_we = {
let writer = handle.get_table_with::<KVEBlob>()?;
let writer = handle.get_table_with::<P, KVEBlob>()?;
match unsafe {
// UNSAFE(@ohsayan): This is completely safe as we've already checked
// that there are exactly 2 arguments
@ -55,9 +51,9 @@ action!(
Err(()) => None,
}
};
conwrite!(con, UPDATE_NLUT[did_we])?;
con._write_raw(P::UPDATE_NLUT[did_we]).await?;
} else {
conwrite!(con, groups::SERVER_ERR)?;
return util::err(P::RCODE_SERVER_ERR);
}
Ok(())
}

@ -36,20 +36,20 @@ action!(
/// This is like "INSERT or UPDATE"
fn uset(handle: &crate::corestore::Corestore, con: &mut T, mut act: ActionIter<'a>) {
let howmany = act.len();
ensure_length(howmany, |size| size & 1 == 0 && size != 0)?;
let kve = handle.get_table_with::<KVEBlob>()?;
ensure_length::<P>(howmany, |size| size & 1 == 0 && size != 0)?;
let kve = handle.get_table_with::<P, KVEBlob>()?;
let encoding_is_okay = ENCODING_LUT_ITER_PAIR[kve.get_encoding_tuple()](&act);
if compiler::likely(encoding_is_okay) {
if registry::state_okay() {
while let (Some(key), Some(val)) = (act.next(), act.next()) {
kve.upsert_unchecked(Data::copy_from_slice(key), Data::copy_from_slice(val));
}
conwrite!(con, howmany / 2)?;
con.write_usize(howmany / 2).await?;
} else {
conwrite!(con, groups::SERVER_ERR)?;
return util::err(P::RCODE_SERVER_ERR);
}
} else {
compiler::cold_err(conwrite!(con, groups::ENCODING_ERROR))?;
return util::err(P::RCODE_ENCODING_ERROR);
}
Ok(())
}

@ -25,20 +25,19 @@
*/
use crate::dbnet::connection::prelude::*;
use crate::resp::writer::NonNullArrayWriter;
action! {
fn whereami(store: &Corestore, con: &mut T, act: ActionIter<'a>) {
ensure_length(act.len(), |len| len == 0)?;
ensure_length::<P>(act.len(), |len| len == 0)?;
match store.get_ids() {
(Some(ks), Some(tbl)) => {
let mut writer = unsafe { NonNullArrayWriter::new(con, b'+', 2).await? };
writer.write_element(ks).await?;
writer.write_element(tbl).await?;
con.write_typed_non_null_array_header(2, b'+').await?;
con.write_typed_non_null_array_element(ks).await?;
con.write_typed_non_null_array_element(tbl).await?;
},
(Some(ks), None) => {
let mut writer = unsafe { NonNullArrayWriter::new(con, b'+', 1).await? };
writer.write_element(ks).await?;
con.write_typed_non_null_array_header(1, b'+').await?;
con.write_typed_non_null_array_element(ks).await?;
},
_ => unsafe { impossible!() }
}

@ -38,10 +38,10 @@ action!(
if act.is_empty() {
// traditional mksnap
match engine.mksnap(handle.clone_store()).await {
SnapshotActionResult::Ok => conwrite!(con, groups::OKAY)?,
SnapshotActionResult::Failure => conwrite!(con, groups::SERVER_ERR)?,
SnapshotActionResult::Disabled => conwrite!(con, groups::SNAPSHOT_DISABLED)?,
SnapshotActionResult::Busy => conwrite!(con, groups::SNAPSHOT_BUSY)?,
SnapshotActionResult::Ok => con._write_raw(P::RCODE_OKAY).await?,
SnapshotActionResult::Failure => return util::err(P::RCODE_SERVER_ERR),
SnapshotActionResult::Disabled => return util::err(P::RSTRING_SNAPSHOT_DISABLED),
SnapshotActionResult::Busy => return util::err(P::RSTRING_SNAPSHOT_BUSY),
_ => unsafe { impossible!() },
}
} else if act.len() == 1 {
@ -51,7 +51,7 @@ action!(
act.next_unchecked_bytes()
};
if !encoding::is_utf8(&name) {
return conwrite!(con, groups::ENCODING_ERROR);
return util::err(P::RCODE_ENCODING_ERROR);
}
// SECURITY: Check for directory traversal syntax
@ -72,19 +72,21 @@ action!(
.count()
!= 0;
if illegal_snapshot {
return conwrite!(con, groups::SNAPSHOT_ILLEGAL_NAME);
return util::err(P::RSTRING_SNAPSHOT_ILLEGAL_NAME);
}
// now make the snapshot
match engine.mkrsnap(name, handle.clone_store()).await {
SnapshotActionResult::Ok => conwrite!(con, groups::OKAY)?,
SnapshotActionResult::Failure => conwrite!(con, groups::SERVER_ERR)?,
SnapshotActionResult::Busy => conwrite!(con, groups::SNAPSHOT_BUSY)?,
SnapshotActionResult::AlreadyExists => conwrite!(con, groups::SNAPSHOT_DUPLICATE)?,
SnapshotActionResult::Ok => con._write_raw(P::RCODE_OKAY).await?,
SnapshotActionResult::Failure => return util::err(P::RCODE_SERVER_ERR),
SnapshotActionResult::Busy => return util::err(P::RSTRING_SNAPSHOT_BUSY),
SnapshotActionResult::AlreadyExists => {
return util::err(P::RSTRING_SNAPSHOT_DUPLICATE)
}
_ => unsafe { impossible!() },
}
} else {
conwrite!(con, groups::ACTION_ERR)?;
return util::err(P::RCODE_ACTION_ERR);
}
Ok(())
}

@ -25,9 +25,7 @@
*/
use crate::{
corestore::booltable::BoolTable,
dbnet::connection::prelude::*,
protocol::{PROTOCOL_VERSION, PROTOCOL_VERSIONSTRING},
corestore::booltable::BoolTable, dbnet::connection::prelude::*,
storage::v1::interface::DIR_ROOT,
};
use ::libsky::VERSION;
@ -47,18 +45,18 @@ const HEALTH_TABLE: BoolTable<&str> = BoolTable::new("good", "critical");
action! {
fn sys(_handle: &Corestore, con: &mut T, iter: ActionIter<'_>) {
let mut iter = iter;
ensure_boolean_or_aerr(iter.len() == 2)?;
ensure_boolean_or_aerr::<P>(iter.len() == 2)?;
match unsafe { iter.next_lowercase_unchecked() }.as_ref() {
INFO => sys_info(con, &mut iter).await,
METRIC => sys_metric(con, &mut iter).await,
_ => util::err(groups::UNKNOWN_ACTION),
_ => util::err(P::RCODE_UNKNOWN_ACTION),
}
}
fn sys_info(con: &mut T, iter: &mut ActionIter<'_>) {
match unsafe { iter.next_lowercase_unchecked() }.as_ref() {
INFO_PROTOCOL => con.write_response(PROTOCOL_VERSIONSTRING).await?,
INFO_PROTOVER => con.write_response(PROTOCOL_VERSION).await?,
INFO_VERSION => con.write_response(VERSION).await?,
INFO_PROTOCOL => con.write_string(P::PROTOCOL_VERSIONSTRING).await?,
INFO_PROTOVER => con.write_float(P::PROTOCOL_VERSION).await?,
INFO_VERSION => con.write_string(VERSION).await?,
_ => return util::err(ERR_UNKNOWN_PROPERTY),
}
Ok(())
@ -66,14 +64,14 @@ action! {
fn sys_metric(con: &mut T, iter: &mut ActionIter<'_>) {
match unsafe { iter.next_lowercase_unchecked() }.as_ref() {
METRIC_HEALTH => {
con.write_response(HEALTH_TABLE[registry::state_okay()]).await?
con.write_string(HEALTH_TABLE[registry::state_okay()]).await?
}
METRIC_STORAGE_USAGE => {
match util::os::dirsize(DIR_ROOT) {
Ok(size) => con.write_response(size).await?,
Ok(size) => con.write_int64(size).await?,
Err(e) => {
log::error!("Failed to get storage usage with: {e}");
con.write_response(groups::SERVER_ERR).await?
return util::err(P::RCODE_SERVER_ERR);
},
}
}

@ -57,6 +57,7 @@ pub async fn run(
snapshot,
maxcon,
auth,
protocol,
..
}: ConfigurationSet,
restore_filepath: Option<String>,
@ -100,8 +101,15 @@ pub async fn run(
let termsig =
TerminationSignal::init().map_err(|e| Error::ioerror_extra(e, "binding to signals"))?;
// start the server (single or multiple listeners)
let mut server =
dbnet::connect(ports, maxcon, db.clone(), auth_provider, signal.clone()).await?;
let mut server = dbnet::connect(
ports,
protocol,
maxcon,
db.clone(),
auth_provider,
signal.clone(),
)
.await?;
tokio::select! {
_ = server.run_server() => {},

@ -1,70 +0,0 @@
/*
* Created on Sun Mar 06 2022
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2022, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
use crate::actions::ActionError;
/// Skyhash respstring: already claimed (user was already claimed)
pub const AUTH_ERROR_ALREADYCLAIMED: &[u8] = b"!err-auth-already-claimed\n";
/// Skyhash respcode(10): bad credentials (either bad creds or invalid user)
pub const AUTH_CODE_BAD_CREDENTIALS: &[u8] = b"!10\n";
/// Skyhash respstring: auth is disabled
pub const AUTH_ERROR_DISABLED: &[u8] = b"!err-auth-disabled\n";
/// Skyhash respcode(11): Insufficient permissions (same for anonymous user)
pub const AUTH_CODE_PERMS: &[u8] = b"!11\n";
/// Skyhash respstring: ID is too long
pub const AUTH_ERROR_ILLEGAL_USERNAME: &[u8] = b"!err-auth-illegal-username\n";
/// Skyhash respstring: ID is protected/in use
pub const AUTH_ERROR_FAILED_TO_DELETE_USER: &[u8] = b"!err-auth-deluser-fail\n";
/// Auth errors
#[derive(PartialEq, Debug)]
pub enum AuthError {
/// The auth slot was already claimed
AlreadyClaimed,
/// Bad userid/tokens/keys
BadCredentials,
/// Auth is disabled
Disabled,
/// The action is not available to the current account
PermissionDenied,
/// The user is anonymous and doesn't have the right to execute this
Anonymous,
/// Some other error
Other(&'static [u8]),
}
impl From<AuthError> for ActionError {
fn from(e: AuthError) -> Self {
let r = match e {
AuthError::AlreadyClaimed => AUTH_ERROR_ALREADYCLAIMED,
AuthError::Anonymous | AuthError::PermissionDenied => AUTH_CODE_PERMS,
AuthError::BadCredentials => AUTH_CODE_BAD_CREDENTIALS,
AuthError::Disabled => AUTH_ERROR_DISABLED,
AuthError::Other(e) => e,
};
ActionError::ActionError(r)
}
}

@ -38,10 +38,7 @@
mod keys;
pub mod provider;
use crate::resp::{writer::NonNullArrayWriter, TSYMBOL_UNICODE_STRING};
pub use provider::{AuthProvider, AuthResult, Authmap};
pub mod errors;
pub use errors::AuthError;
pub use provider::{AuthProvider, Authmap};
#[cfg(test)]
mod tests;
@ -61,103 +58,102 @@ action! {
/// Handle auth. Should have passed the `auth` token
fn auth(
con: &mut T,
auth: &mut AuthProviderHandle<'_, T, Strm>,
auth: &mut AuthProviderHandle<'_, P, T, Strm>,
iter: ActionIter<'_>
) {
let mut iter = iter;
match iter.next_lowercase().unwrap_or_aerr()?.as_ref() {
match iter.next_lowercase().unwrap_or_aerr::<P>()?.as_ref() {
AUTH_LOGIN => self::_auth_login(con, auth, &mut iter).await,
AUTH_CLAIM => self::_auth_claim(con, auth, &mut iter).await,
AUTH_ADDUSER => {
ensure_boolean_or_aerr(iter.len() == 1)?; // just the username
ensure_boolean_or_aerr::<P>(iter.len() == 1)?; // just the username
let username = unsafe { iter.next_unchecked() };
let key = auth.provider_mut().claim_user(username)?;
con.write_response(StringWrapper(key)).await?;
let key = auth.provider_mut().claim_user::<P>(username)?;
con.write_string(&key).await?;
Ok(())
}
AUTH_LOGOUT => {
ensure_boolean_or_aerr(iter.is_empty())?; // nothing else
auth.provider_mut().logout()?;
ensure_boolean_or_aerr::<P>(iter.is_empty())?; // nothing else
auth.provider_mut().logout::<P>()?;
auth.swap_executor_to_anonymous();
con.write_response(groups::OKAY).await?;
con._write_raw(P::RCODE_OKAY).await?;
Ok(())
}
AUTH_DELUSER => {
ensure_boolean_or_aerr(iter.len() == 1)?; // just the username
auth.provider_mut().delete_user(unsafe { iter.next_unchecked() })?;
con.write_response(groups::OKAY).await?;
ensure_boolean_or_aerr::<P>(iter.len() == 1)?; // just the username
auth.provider_mut().delete_user::<P>(unsafe { iter.next_unchecked() })?;
con._write_raw(P::RCODE_OKAY).await?;
Ok(())
}
AUTH_RESTORE => self::auth_restore(con, auth, &mut iter).await,
AUTH_LISTUSER => self::auth_listuser(con, auth, &mut iter).await,
AUTH_WHOAMI => self::auth_whoami(con, auth, &mut iter).await,
_ => util::err(groups::UNKNOWN_ACTION),
_ => util::err(P::RCODE_UNKNOWN_ACTION),
}
}
fn auth_whoami(con: &mut T, auth: &mut AuthProviderHandle<'_, T, Strm>, iter: &mut ActionIter<'_>) {
ensure_boolean_or_aerr(ActionIter::is_empty(iter))?;
con.write_response(StringWrapper(auth.provider().whoami()?)).await?;
fn auth_whoami(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) {
ensure_boolean_or_aerr::<P>(ActionIter::is_empty(iter))?;
con.write_string(&auth.provider().whoami::<P>()?).await?;
Ok(())
}
fn auth_listuser(con: &mut T, auth: &mut AuthProviderHandle<'_, T, Strm>, iter: &mut ActionIter<'_>) {
ensure_boolean_or_aerr(ActionIter::is_empty(iter))?;
let usernames = auth.provider().collect_usernames()?;
let mut array_writer = unsafe {
// The symbol is definitely correct, obvious from this context
NonNullArrayWriter::new(con, TSYMBOL_UNICODE_STRING, usernames.len())
}.await?;
fn auth_listuser(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) {
ensure_boolean_or_aerr::<P>(ActionIter::is_empty(iter))?;
let usernames = auth.provider().collect_usernames::<P>()?;
con.write_typed_non_null_array_header(usernames.len(), b'+').await?;
for username in usernames {
array_writer.write_element(username).await?;
con.write_typed_non_null_array_element(username.as_bytes()).await?;
}
Ok(())
}
fn auth_restore(con: &mut T, auth: &mut AuthProviderHandle<'_, T, Strm>, iter: &mut ActionIter<'_>) {
fn auth_restore(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) {
let newkey = match iter.len() {
1 => {
// so this fella thinks they're root
auth.provider().regenerate(unsafe {iter.next_unchecked()})?
auth.provider().regenerate::<P>(
unsafe { iter.next_unchecked() }
)?
}
2 => {
// so this fella is giving us the origin key
let origin = unsafe { iter.next_unchecked() };
let id = unsafe { iter.next_unchecked() };
auth.provider().regenerate_using_origin(origin, id)?
auth.provider().regenerate_using_origin::<P>(origin, id)?
}
_ => return util::err(groups::ACTION_ERR),
_ => return util::err(P::RCODE_ACTION_ERR),
};
con.write_response(StringWrapper(newkey)).await?;
con.write_string(&newkey).await?;
Ok(())
}
fn _auth_claim(con: &mut T, auth: &mut AuthProviderHandle<'_, T, Strm>, iter: &mut ActionIter<'_>) {
ensure_boolean_or_aerr(iter.len() == 1)?; // just the origin key
fn _auth_claim(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) {
ensure_boolean_or_aerr::<P>(iter.len() == 1)?; // just the origin key
let origin_key = unsafe { iter.next_unchecked() };
let key = auth.provider_mut().claim_root(origin_key)?;
let key = auth.provider_mut().claim_root::<P>(origin_key)?;
auth.swap_executor_to_authenticated();
con.write_response(StringWrapper(key)).await?;
con.write_string(&key).await?;
Ok(())
}
/// Handle a login operation only. The **`login` token is expected to be present**
fn auth_login_only(
con: &mut T,
auth: &mut AuthProviderHandle<'_, T, Strm>,
auth: &mut AuthProviderHandle<'_, P, T, Strm>,
iter: ActionIter<'_>
) {
let mut iter = iter;
match iter.next_lowercase().unwrap_or_aerr()?.as_ref() {
match iter.next_lowercase().unwrap_or_aerr::<P>()?.as_ref() {
AUTH_LOGIN => self::_auth_login(con, auth, &mut iter).await,
AUTH_CLAIM => self::_auth_claim(con, auth, &mut iter).await,
AUTH_RESTORE => self::auth_restore(con, auth, &mut iter).await,
AUTH_WHOAMI => self::auth_whoami(con, auth, &mut iter).await,
_ => util::err(errors::AUTH_CODE_PERMS),
_ => util::err(P::AUTH_CODE_PERMS),
}
}
fn _auth_login(con: &mut T, auth: &mut AuthProviderHandle<'_, T, Strm>, iter: &mut ActionIter<'_>) {
fn _auth_login(con: &mut T, auth: &mut AuthProviderHandle<'_, P, T, Strm>, iter: &mut ActionIter<'_>) {
// sweet, where's our username and password
ensure_boolean_or_aerr(iter.len() == 2)?; // just the uname and pass
ensure_boolean_or_aerr::<P>(iter.len() == 2)?; // just the uname and pass
let (username, password) = unsafe { (iter.next_unchecked(), iter.next_unchecked()) };
auth.provider_mut().login(username, password)?;
auth.provider_mut().login::<P>(username, password)?;
auth.swap_executor_to_authenticated();
con.write_response(groups::OKAY).await?;
con._write_raw(P::RCODE_OKAY).await?;
Ok(())
}
}

@ -24,9 +24,12 @@
*
*/
use super::{errors, keys, AuthError};
use super::keys;
use crate::actions::{ActionError, ActionResult};
use crate::corestore::array::Array;
use crate::corestore::htable::Coremap;
use crate::protocol::interface::ProtocolSpec;
use crate::util::err;
use std::sync::Arc;
// constants
@ -35,14 +38,12 @@ pub const AUTHKEY_SIZE: usize = 40;
/// Size of an authn ID in bytes
pub const AUTHID_SIZE: usize = 40;
#[cfg(debug_assertions)]
pub mod testsuite_data {
#![allow(unused)]
//! Temporary users created by the testsuite in debug mode
pub const TESTSUITE_ROOT_USER: &str = "root";
pub const TESTSUITE_TEST_USER: &str = "testuser";
#[cfg(test)]
pub const TESTSUITE_ROOT_TOKEN: &str = "XUOdVKhEONnnGwNwT7WeLqbspDgVtKex0/nwFwBSW7XJxioHwpg6H.";
#[cfg(all(not(feature = "persist-suite"), test))]
pub const TESTSUITE_TEST_TOKEN: &str = "mpobAB7EY8vnBs70d/..h1VvfinKIeEJgt1rg4wUkwF6aWCvGGR9le";
}
@ -56,8 +57,6 @@ const USER_ROOT: AuthID = unsafe { AuthID::from_const(USER_ROOT_ARRAY, 4) };
type AuthID = Array<u8, AUTHID_SIZE>;
/// An authn key
pub type Authkey = [u8; AUTHKEY_SIZE];
/// Result of an auth operation
pub type AuthResult<T> = Result<T, AuthError>;
/// Authmap
pub type Authmap = Arc<Coremap<AuthID, Authkey>>;
@ -121,8 +120,8 @@ impl AuthProvider {
pub const fn is_enabled(&self) -> bool {
matches!(self.origin, Some(_))
}
pub fn claim_root(&mut self, origin_key: &[u8]) -> AuthResult<String> {
self.verify_origin(origin_key)?;
pub fn claim_root<P: ProtocolSpec>(&mut self, origin_key: &[u8]) -> ActionResult<String> {
self.verify_origin::<P>(origin_key)?;
// the origin key was good, let's try claiming root
let (key, store) = keys::generate_full();
if self.authmap.true_if_insert(USER_ROOT, store) {
@ -130,33 +129,33 @@ impl AuthProvider {
self.whoami = Some(USER_ROOT);
Ok(key)
} else {
Err(AuthError::AlreadyClaimed)
err(P::AUTH_ERROR_ALREADYCLAIMED)
}
}
fn are_you_root(&self) -> AuthResult<bool> {
self.ensure_enabled()?;
fn are_you_root<P: ProtocolSpec>(&self) -> ActionResult<bool> {
self.ensure_enabled::<P>()?;
match self.whoami.as_ref().map(|v| v.eq(&USER_ROOT)) {
Some(v) => Ok(v),
None => Err(AuthError::Anonymous),
None => err(P::AUTH_CODE_PERMS),
}
}
pub fn claim_user(&self, claimant: &[u8]) -> AuthResult<String> {
self.ensure_root()?;
self._claim_user(claimant)
pub fn claim_user<P: ProtocolSpec>(&self, claimant: &[u8]) -> ActionResult<String> {
self.ensure_root::<P>()?;
self._claim_user::<P>(claimant)
}
pub fn _claim_user(&self, claimant: &[u8]) -> AuthResult<String> {
pub fn _claim_user<P: ProtocolSpec>(&self, claimant: &[u8]) -> ActionResult<String> {
let (key, store) = keys::generate_full();
if self
.authmap
.true_if_insert(Self::try_auth_id(claimant)?, store)
.true_if_insert(Self::try_auth_id::<P>(claimant)?, store)
{
Ok(key)
} else {
Err(AuthError::AlreadyClaimed)
err(P::AUTH_ERROR_ALREADYCLAIMED)
}
}
pub fn login(&mut self, account: &[u8], token: &[u8]) -> AuthResult<()> {
self.ensure_enabled()?;
pub fn login<P: ProtocolSpec>(&mut self, account: &[u8], token: &[u8]) -> ActionResult<()> {
self.ensure_enabled::<P>()?;
match self
.authmap
.get(account)
@ -164,84 +163,94 @@ impl AuthProvider {
{
Some(Some(true)) => {
// great, authenticated
self.whoami = Some(Self::try_auth_id(account)?);
self.whoami = Some(Self::try_auth_id::<P>(account)?);
Ok(())
}
_ => {
// either the password was wrong, or the username was wrong
Err(AuthError::BadCredentials)
err(P::AUTH_CODE_BAD_CREDENTIALS)
}
}
}
pub fn regenerate_using_origin(&self, origin: &[u8], account: &[u8]) -> AuthResult<String> {
self.verify_origin(origin)?;
self._regenerate(account)
pub fn regenerate_using_origin<P: ProtocolSpec>(
&self,
origin: &[u8],
account: &[u8],
) -> ActionResult<String> {
self.verify_origin::<P>(origin)?;
self._regenerate::<P>(account)
}
pub fn regenerate(&self, account: &[u8]) -> AuthResult<String> {
self.ensure_root()?;
self._regenerate(account)
pub fn regenerate<P: ProtocolSpec>(&self, account: &[u8]) -> ActionResult<String> {
self.ensure_root::<P>()?;
self._regenerate::<P>(account)
}
/// Regenerate the token for the given user. This returns a new token
fn _regenerate(&self, account: &[u8]) -> AuthResult<String> {
let id = Self::try_auth_id(account)?;
fn _regenerate<P: ProtocolSpec>(&self, account: &[u8]) -> ActionResult<String> {
let id = Self::try_auth_id::<P>(account)?;
let (key, store) = keys::generate_full();
if self.authmap.true_if_update(id, store) {
Ok(key)
} else {
Err(AuthError::BadCredentials)
err(P::AUTH_CODE_BAD_CREDENTIALS)
}
}
fn try_auth_id(authid: &[u8]) -> AuthResult<AuthID> {
fn try_auth_id<P: ProtocolSpec>(authid: &[u8]) -> ActionResult<AuthID> {
if authid.is_ascii() && authid.len() <= AUTHID_SIZE {
Ok(unsafe {
// We just verified the length
AuthID::from_slice(authid)
})
} else {
Err(AuthError::Other(errors::AUTH_ERROR_ILLEGAL_USERNAME))
err(P::AUTH_ERROR_ILLEGAL_USERNAME)
}
}
pub fn logout(&mut self) -> AuthResult<()> {
self.ensure_enabled()?;
self.whoami.take().map(|_| ()).ok_or(AuthError::Anonymous)
pub fn logout<P: ProtocolSpec>(&mut self) -> ActionResult<()> {
self.ensure_enabled::<P>()?;
self.whoami
.take()
.map(|_| ())
.ok_or(ActionError::ActionError(P::AUTH_CODE_PERMS))
}
fn ensure_enabled(&self) -> AuthResult<()> {
self.origin.as_ref().map(|_| ()).ok_or(AuthError::Disabled)
fn ensure_enabled<P: ProtocolSpec>(&self) -> ActionResult<()> {
self.origin
.as_ref()
.map(|_| ())
.ok_or(ActionError::ActionError(P::AUTH_ERROR_DISABLED))
}
pub fn verify_origin(&self, origin: &[u8]) -> AuthResult<()> {
if self.get_origin()?.eq(origin) {
pub fn verify_origin<P: ProtocolSpec>(&self, origin: &[u8]) -> ActionResult<()> {
if self.get_origin::<P>()?.eq(origin) {
Ok(())
} else {
Err(AuthError::BadCredentials)
err(P::AUTH_CODE_BAD_CREDENTIALS)
}
}
fn get_origin(&self) -> AuthResult<&Authkey> {
fn get_origin<P: ProtocolSpec>(&self) -> ActionResult<&Authkey> {
match self.origin.as_ref() {
Some(key) => Ok(key),
None => Err(AuthError::Disabled),
None => err(P::AUTH_ERROR_DISABLED),
}
}
fn ensure_root(&self) -> AuthResult<()> {
if self.are_you_root()? {
fn ensure_root<P: ProtocolSpec>(&self) -> ActionResult<()> {
if self.are_you_root::<P>()? {
Ok(())
} else {
Err(AuthError::PermissionDenied)
err(P::AUTH_CODE_PERMS)
}
}
pub fn delete_user(&self, user: &[u8]) -> AuthResult<()> {
self.ensure_root()?;
pub fn delete_user<P: ProtocolSpec>(&self, user: &[u8]) -> ActionResult<()> {
self.ensure_root::<P>()?;
if user.eq(&USER_ROOT) {
// can't delete root!
Err(AuthError::Other(errors::AUTH_ERROR_FAILED_TO_DELETE_USER))
err(P::AUTH_ERROR_FAILED_TO_DELETE_USER)
} else if self.authmap.true_if_removed(user) {
Ok(())
} else {
Err(AuthError::BadCredentials)
err(P::AUTH_CODE_BAD_CREDENTIALS)
}
}
/// List all the users
pub fn collect_usernames(&self) -> AuthResult<Vec<String>> {
self.ensure_root()?;
pub fn collect_usernames<P: ProtocolSpec>(&self) -> ActionResult<Vec<String>> {
self.ensure_root::<P>()?;
Ok(self
.authmap
.iter()
@ -249,12 +258,12 @@ impl AuthProvider {
.collect())
}
/// Return the AuthID of the current user
pub fn whoami(&self) -> AuthResult<String> {
self.ensure_enabled()?;
pub fn whoami<P: ProtocolSpec>(&self) -> ActionResult<String> {
self.ensure_enabled::<P>()?;
self.whoami
.as_ref()
.map(|v| String::from_utf8_lossy(v).to_string())
.ok_or(AuthError::Anonymous)
.ok_or(ActionError::ActionError(P::AUTH_CODE_PERMS))
}
}

@ -35,77 +35,88 @@ mod keys {
}
mod authn {
use crate::auth::{AuthError, AuthProvider};
use crate::actions::ActionError;
use crate::auth::AuthProvider;
use crate::protocol::{interface::ProtocolSpec, Skyhash2};
const ORIG: &[u8; 40] = b"c4299d190fb9a00626797fcc138c56eae9971664";
#[test]
fn claim_root_okay() {
let mut provider = AuthProvider::new_blank(Some(*ORIG));
let _ = provider.claim_root(ORIG).unwrap();
let _ = provider.claim_root::<Skyhash2>(ORIG).unwrap();
}
#[test]
fn claim_root_wrongkey() {
let mut provider = AuthProvider::new_blank(Some(*ORIG));
let claim_err = provider.claim_root(&ORIG[1..]).unwrap_err();
assert_eq!(claim_err, AuthError::BadCredentials);
let claim_err = provider.claim_root::<Skyhash2>(&ORIG[1..]).unwrap_err();
assert_eq!(
claim_err,
ActionError::ActionError(Skyhash2::AUTH_CODE_BAD_CREDENTIALS)
);
}
#[test]
fn claim_root_disabled() {
let mut provider = AuthProvider::new_disabled();
assert_eq!(
provider.claim_root(b"abcd").unwrap_err(),
AuthError::Disabled
provider.claim_root::<Skyhash2>(b"abcd").unwrap_err(),
ActionError::ActionError(Skyhash2::AUTH_ERROR_DISABLED)
);
}
#[test]
fn claim_root_already_claimed() {
let mut provider = AuthProvider::new_blank(Some(*ORIG));
let _ = provider.claim_root(ORIG).unwrap();
let _ = provider.claim_root::<Skyhash2>(ORIG).unwrap();
assert_eq!(
provider.claim_root(ORIG).unwrap_err(),
AuthError::AlreadyClaimed
provider.claim_root::<Skyhash2>(ORIG).unwrap_err(),
ActionError::ActionError(Skyhash2::AUTH_ERROR_ALREADYCLAIMED)
);
}
#[test]
fn claim_user_okay_with_login() {
let mut provider = AuthProvider::new_blank(Some(*ORIG));
// claim root
let rootkey = provider.claim_root(ORIG).unwrap();
let rootkey = provider.claim_root::<Skyhash2>(ORIG).unwrap();
// login as root
provider.login(b"root", rootkey.as_bytes()).unwrap();
provider
.login::<Skyhash2>(b"root", rootkey.as_bytes())
.unwrap();
// claim user
let _ = provider.claim_user(b"sayan").unwrap();
let _ = provider.claim_user::<Skyhash2>(b"sayan").unwrap();
}
#[test]
fn claim_user_fail_not_root_with_login() {
let mut provider = AuthProvider::new_blank(Some(*ORIG));
// claim root
let rootkey = provider.claim_root(ORIG).unwrap();
let rootkey = provider.claim_root::<Skyhash2>(ORIG).unwrap();
// login as root
provider.login(b"root", rootkey.as_bytes()).unwrap();
provider
.login::<Skyhash2>(b"root", rootkey.as_bytes())
.unwrap();
// claim user
let userkey = provider.claim_user(b"user").unwrap();
let userkey = provider.claim_user::<Skyhash2>(b"user").unwrap();
// login as user
provider.login(b"user", userkey.as_bytes()).unwrap();
provider
.login::<Skyhash2>(b"user", userkey.as_bytes())
.unwrap();
// now try to claim a user from a non-root account
assert_eq!(
provider.claim_user(b"otheruser").unwrap_err(),
AuthError::PermissionDenied
provider.claim_user::<Skyhash2>(b"otheruser").unwrap_err(),
ActionError::ActionError(Skyhash2::AUTH_CODE_PERMS)
);
}
#[test]
fn claim_user_fail_anonymous() {
let mut provider = AuthProvider::new_blank(Some(*ORIG));
// claim root
let _ = provider.claim_root(ORIG).unwrap();
let _ = provider.claim_root::<Skyhash2>(ORIG).unwrap();
// logout
provider.logout().unwrap();
provider.logout::<Skyhash2>().unwrap();
// try to claim as an anonymous user
assert_eq!(
provider.claim_user(b"newuser").unwrap_err(),
AuthError::Anonymous
provider.claim_user::<Skyhash2>(b"newuser").unwrap_err(),
ActionError::ActionError(Skyhash2::AUTH_CODE_PERMS)
);
}
}

@ -115,3 +115,9 @@ args:
takes_value: true
help: Set the authentication origin key
value_name: origin_key
- protover:
required: false
long: protover
takes_value: true
help: Set the protocol version
value_name: protover

@ -72,6 +72,12 @@ pub(super) fn parse_cli_args(matches: ArgMatches) -> Configset {
)
};
}
// protocol settings
fcli! {
protocol_settings,
matches.value_of("protover"),
"--protover"
};
// server settings
fcli!(
server_tcp,

@ -44,6 +44,8 @@ pub(super) fn parse_env_config() -> Configset {
);
};
}
// protocol settings
fenv!(protocol_settings, SKY_PROTOCOL_VERSION);
// server settings
fenv!(server_tcp, SKY_SYSTEM_HOST, SKY_SYSTEM_PORT);
fenv!(server_noart, SKY_SYSTEM_NOART);

@ -25,7 +25,8 @@
*/
use super::{
AuthSettings, ConfigSourceParseResult, Configset, Modeset, OptString, TryFromConfigSource,
AuthSettings, ConfigSourceParseResult, Configset, Modeset, OptString, ProtocolVersion,
TryFromConfigSource,
};
use serde::Deserialize;
use std::net::IpAddr;
@ -59,6 +60,7 @@ pub struct ConfigKeyServer {
pub(super) maxclient: Option<usize>,
/// The deployment mode
pub(super) mode: Option<Modeset>,
pub(super) protocol: Option<ProtocolVersion>,
}
/// The BGSAVE section in the config file
@ -175,6 +177,7 @@ pub fn from_file(file: ConfigFile) -> Configset {
Optional::some(server.port),
"server.port",
);
set.protocol_settings(server.protocol, "server.protocol");
set.server_maxcon(Optional::from(server.maxclient), "server.maxcon");
set.server_noart(Optional::from(server.noart), "server.noart");
set.server_mode(Optional::from(server.mode), "server.mode");

@ -68,6 +68,54 @@ impl BGSave {
}
}
#[repr(u8)]
#[derive(Debug, PartialEq)]
pub enum ProtocolVersion {
V1,
V2,
}
impl Default for ProtocolVersion {
fn default() -> Self {
Self::V2
}
}
impl ToString for ProtocolVersion {
fn to_string(&self) -> String {
match self {
Self::V1 => "Skyhash 1.0".to_owned(),
Self::V2 => "Skyhash 2.0".to_owned(),
}
}
}
struct ProtocolVersionVisitor;
impl<'de> Visitor<'de> for ProtocolVersionVisitor {
type Value = ProtocolVersion;
fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(f, "a 40 character ASCII string")
}
fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
value.parse().map_err(|_| {
E::custom("Invalid value for protocol version. Valid inputs: 1.0, 1.1, 1.2, 2.0")
})
}
}
impl<'de> Deserialize<'de> for ProtocolVersion {
fn deserialize<D>(deserializer: D) -> Result<ProtocolVersion, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_str(ProtocolVersionVisitor)
}
}
/// A `ConfigurationSet` which can be used by main::check_args_or_connect() to bind
/// to a `TcpListener` and show the corresponding terminal output for the given
/// configuration
@ -87,9 +135,12 @@ pub struct ConfigurationSet {
pub mode: Modeset,
/// The auth settings
pub auth: AuthSettings,
/// The protocol version
pub protocol: ProtocolVersion,
}
impl ConfigurationSet {
#[allow(clippy::too_many_arguments)]
pub const fn new(
noart: bool,
bgsave: BGSave,
@ -98,6 +149,7 @@ impl ConfigurationSet {
maxcon: usize,
mode: Modeset,
auth: AuthSettings,
protocol: ProtocolVersion,
) -> Self {
Self {
noart,
@ -107,6 +159,7 @@ impl ConfigurationSet {
maxcon,
mode,
auth,
protocol,
}
}
/// Create a default `ConfigurationSet` with the following setup defaults:
@ -125,6 +178,7 @@ impl ConfigurationSet {
MAXIMUM_CONNECTION_LIMIT,
Modeset::Dev,
AuthSettings::default(),
ProtocolVersion::V2,
)
}
/// Returns `false` if `noart` is enabled. Otherwise it returns `true`
@ -207,14 +261,14 @@ impl PortConfig {
Self::Multi { host, port, ssl } => {
format!(
"skyhash://{host}:{port} and skyhash-secure://{host}:{tlsport}",
tlsport = ssl.get_port()
tlsport = ssl.get_port(),
)
}
Self::SecureOnly {
host,
ssl: SslOpts { port, .. },
} => format!("skyhash-secure://{host}:{port}"),
Self::InsecureOnly { host, port } => format!("skyhash://{host}:{port}"),
Self::InsecureOnly { host, port } => format!("skyhash://{host}:{port}",),
}
}
}

@ -66,9 +66,17 @@ impl FeedbackStack {
impl fmt::Display for FeedbackStack {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if !self.is_empty() {
write!(f, "{} {}:", self.feedback_source, self.feedback_type)?;
for err in self.stack.iter() {
write!(f, "\n{}- {}", TAB, err)?;
if self.len() == 1 {
write!(
f,
"{} {}: {}",
self.feedback_source, self.feedback_type, self.stack[0]
)?;
} else {
write!(f, "{} {}:", self.feedback_source, self.feedback_type)?;
for err in self.stack.iter() {
write!(f, "\n{}- {}", TAB, err)?;
}
}
}
Ok(())
@ -265,8 +273,7 @@ mod test {
#[test]
fn errorstack_fmt() {
const EXPECTED: &str = "\
Environment errors:
- Invalid value for `SKY_SYSTEM_PORT`. Expected a 16-bit integer\
Environment errors: Invalid value for `SKY_SYSTEM_PORT`. Expected a 16-bit integer\
";
let mut estk = ErrorStack::new(EMSG_ENV);
estk.push("Invalid value for `SKY_SYSTEM_PORT`. Expected a 16-bit integer");

@ -211,6 +211,34 @@ impl FromStr for OptString {
}
}
impl FromStr for ProtocolVersion {
type Err = ();
fn from_str(st: &str) -> Result<Self, Self::Err> {
match st {
"1" | "1.0" | "1.1" | "1.2" => Ok(Self::V1),
"2" | "2.0" => Ok(Self::V2),
_ => Err(()),
}
}
}
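The `--protover` CLI flag, the `SKY_PROTOCOL_VERSION` environment variable and the `server.protocol` config key all feed their string values through this `FromStr` impl (the serde visitor shown earlier delegates to `parse()` as well). A small standalone sketch of the accepted inputs follows; the enum below is a local mirror for illustration only, not the server's own `ProtocolVersion` type.

```rust
use std::str::FromStr;

// Local mirror of `ProtocolVersion` and its parsing rules, so the accepted
// inputs can be checked in isolation (the real type lives in the config module).
#[derive(Debug, PartialEq)]
enum ProtocolVersion {
    V1,
    V2,
}

impl Default for ProtocolVersion {
    fn default() -> Self {
        Self::V2
    }
}

impl FromStr for ProtocolVersion {
    type Err = ();
    fn from_str(st: &str) -> Result<Self, Self::Err> {
        match st {
            "1" | "1.0" | "1.1" | "1.2" => Ok(Self::V1),
            "2" | "2.0" => Ok(Self::V2),
            _ => Err(()),
        }
    }
}

fn main() {
    // every Skyhash 1.x spelling selects the V1 wire format
    assert_eq!("1.0".parse::<ProtocolVersion>(), Ok(ProtocolVersion::V1));
    assert_eq!("1.2".parse::<ProtocolVersion>(), Ok(ProtocolVersion::V1));
    // "2" and "2.0" select the current protocol, which is also the default
    assert_eq!("2".parse::<ProtocolVersion>(), Ok(ProtocolVersion::V2));
    assert_eq!(ProtocolVersion::default(), ProtocolVersion::V2);
    // anything else is rejected and surfaces as a configuration error
    assert!("3.0".parse::<ProtocolVersion>().is_err());
}
```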
impl TryFromConfigSource<ProtocolVersion> for Option<ProtocolVersion> {
fn is_present(&self) -> bool {
self.is_some()
}
fn mutate_failed(self, target: &mut ProtocolVersion, trip: &mut bool) -> bool {
if let Some(v) = self {
*target = v;
*trip = true;
}
false
}
fn try_parse(self) -> ConfigSourceParseResult<ProtocolVersion> {
self.map(ConfigSourceParseResult::Okay)
.unwrap_or(ConfigSourceParseResult::Absent)
}
}
impl TryFromConfigSource<OptString> for OptString {
fn is_present(&self) -> bool {
self.base.is_some()
@ -225,7 +253,7 @@ impl TryFromConfigSource<OptString> for OptString {
fn try_parse(self) -> ConfigSourceParseResult<OptString> {
self.base
.map(|v| ConfigSourceParseResult::Okay(OptString { base: Some(v) }))
.unwrap_or(ConfigSourceParseResult::Okay(OptString::new_null()))
.unwrap_or(ConfigSourceParseResult::Absent)
}
}
@ -365,6 +393,13 @@ impl Configset {
} else {
return Err(ConfigError::CfgError(self.estack));
};
if target.config.protocol != ProtocolVersion::default() {
target.wpush(format!(
"{} is deprecated. Switch to {}",
target.config.protocol.to_string(),
ProtocolVersion::default().to_string()
));
}
if target.is_prod_mode() {
self::feedback::evaluate_prod_settings(&target.config).map(|_| target)
} else {
@ -374,6 +409,24 @@ impl Configset {
}
}
// protocol settings
impl Configset {
pub fn protocol_settings(
&mut self,
nproto: impl TryFromConfigSource<ProtocolVersion>,
nproto_key: StaticStr,
) {
let mut proto = ProtocolVersion::default();
self.try_mutate(
nproto,
&mut proto,
nproto_key,
"a protocol version like 2.0 or 1.0",
);
self.cfg.protocol = proto;
}
}
// server settings
impl Configset {
pub fn server_tcp(

@ -345,7 +345,7 @@ mod cfg_file_tests {
use crate::config::AuthkeyWrapper;
use crate::config::{
cfgfile, AuthSettings, BGSave, Configset, ConfigurationSet, Modeset, PortConfig,
SnapshotConfig, SnapshotPref, SslOpts, DEFAULT_IPV4, DEFAULT_PORT,
ProtocolVersion, SnapshotConfig, SnapshotPref, SslOpts, DEFAULT_IPV4, DEFAULT_PORT,
};
use crate::dbnet::MAXIMUM_CONNECTION_LIMIT;
use std::net::{IpAddr, Ipv6Addr};
@ -401,6 +401,7 @@ mod cfg_file_tests {
maxcon: MAXIMUM_CONNECTION_LIMIT,
mode: Modeset::Dev,
auth: AuthSettings::default(),
protocol: ProtocolVersion::default(),
}
);
}
@ -422,6 +423,7 @@ mod cfg_file_tests {
maxcon: MAXIMUM_CONNECTION_LIMIT,
mode: Modeset::Dev,
auth: AuthSettings::default(),
protocol: ProtocolVersion::default(),
}
);
}
@ -447,7 +449,8 @@ mod cfg_file_tests {
),
MAXIMUM_CONNECTION_LIMIT,
Modeset::Dev,
AuthSettings::new(AuthkeyWrapper::try_new(crate::TEST_AUTH_ORIGIN_KEY).unwrap())
AuthSettings::new(AuthkeyWrapper::try_new(crate::TEST_AUTH_ORIGIN_KEY).unwrap()),
ProtocolVersion::default()
)
);
}
@ -473,6 +476,7 @@ mod cfg_file_tests {
maxcon: MAXIMUM_CONNECTION_LIMIT,
mode: Modeset::Dev,
auth: AuthSettings::default(),
protocol: ProtocolVersion::default(),
}
);
}
@ -495,6 +499,7 @@ mod cfg_file_tests {
maxcon: MAXIMUM_CONNECTION_LIMIT,
mode: Modeset::Dev,
auth: AuthSettings::default(),
protocol: ProtocolVersion::default(),
}
)
}
@ -517,6 +522,7 @@ mod cfg_file_tests {
maxcon: MAXIMUM_CONNECTION_LIMIT,
mode: Modeset::Dev,
auth: AuthSettings::default(),
protocol: ProtocolVersion::default(),
}
)
}
@ -535,6 +541,7 @@ mod cfg_file_tests {
maxcon: MAXIMUM_CONNECTION_LIMIT,
mode: Modeset::Dev,
auth: AuthSettings::default(),
protocol: ProtocolVersion::default(),
}
);
}

@ -132,6 +132,7 @@ mod cluster {
#[derive(Debug, PartialEq)]
/// Errors arising from trying to modify/access containers
#[allow(dead_code)]
pub enum DdlError {
/// The object is still in use
StillInUse,

@ -29,6 +29,7 @@ use crate::corestore::{
memstore::{DdlError, Keyspace, Memstore, ObjectID, DEFAULT},
table::{DescribeTable, Table},
};
use crate::protocol::interface::ProtocolSpec;
use crate::queryengine::parser::{Entity, OwnedEntity};
use crate::registry;
use crate::storage;
@ -210,8 +211,8 @@ impl Corestore {
self.estate.table.as_ref().map(|(_, tbl)| tbl.as_ref())
}
/// Returns a table with the provided specification
pub fn get_table_with<T: DescribeTable>(&self) -> ActionResult<&T::Table> {
T::get(self)
pub fn get_table_with<P: ProtocolSpec, T: DescribeTable>(&self) -> ActionResult<&T::Table> {
T::get::<P>(self)
}
/// Create a table: in-memory; **no transactional guarantees**. Two tables can be created
/// simultaneously, but are never flushed unless we are very lucky. If the global flush

@ -32,22 +32,22 @@ use crate::corestore::Data;
use crate::corestore::{memstore::DdlError, KeyspaceResult};
use crate::dbnet::connection::prelude::Corestore;
use crate::kvengine::{KVEListmap, KVEStandard, LockedVec};
use crate::protocol::responses::groups;
use crate::protocol::interface::ProtocolSpec;
use crate::util;
pub trait DescribeTable {
type Table;
fn try_get(table: &Table) -> Option<&Self::Table>;
fn get(store: &Corestore) -> ActionResult<&Self::Table> {
fn get<P: ProtocolSpec>(store: &Corestore) -> ActionResult<&Self::Table> {
match store.estate.table {
Some((_, ref table)) => {
// so we do have a table
match Self::try_get(table) {
Some(tbl) => Ok(tbl),
None => util::err(groups::WRONG_MODEL),
None => util::err(P::RSTRING_WRONG_MODEL),
}
}
None => util::err(groups::DEFAULT_UNSET),
None => util::err(P::RSTRING_DEFAULT_UNSET),
}
}
}

@ -25,45 +25,41 @@
*/
//! # Generic connection traits
//! The `con` module defines the generic connection traits `ProtocolConnection` and `ProtocolConnectionExt`.
//! The `con` module defines the generic connection traits `RawConnection` and `ProtocolRead`.
//! These two traits can be used to interface with sockets that are used for communication through the Skyhash
//! protocol.
//!
//! The `ProtocolConnection` trait provides a basic set of methods that are required by prospective connection
//! The `RawConnection` trait provides a basic set of methods that are required by prospective connection
//! objects to be eligible for higher level protocol interactions (such as interactions with high-level query objects).
//! Once a type implements this trait, it automatically gets a free `ProtocolConnectionExt` implementation. This immediately
//! Once a type implements this trait, it automatically gets a free `ProtocolRead` implementation. This immediately
//! enables this connection object/type to use methods like read_query, letting it read and interact with queries and write
//! responses in compliance with the Skyhash protocol.
use crate::{
actions::{ActionError, ActionResult},
auth::{self, AuthProvider},
corestore::{buffers::Integer64, Corestore},
auth::AuthProvider,
corestore::Corestore,
dbnet::{
connection::prelude::FutureResult,
tcp::{BufferedSocketStream, Connection},
Terminator,
},
protocol::{self, responses, ParseError, Query},
queryengine,
resp::Writable,
IoResult,
protocol::{
interface::{ProtocolRead, ProtocolSpec, ProtocolWrite},
Query,
},
queryengine, IoResult,
};
use bytes::{Buf, BytesMut};
use std::{
future::Future,
io::{Error as IoError, ErrorKind},
marker::PhantomData,
pin::Pin,
sync::Arc,
};
#[cfg(windows)]
use std::io::ErrorKind;
use std::{marker::PhantomData, sync::Arc};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt, BufWriter},
sync::{mpsc, Semaphore},
};
pub const SIMPLE_QUERY_HEADER: [u8; 1] = [b'*'];
type QueryWithAdvance = (Query, usize);
pub type QueryWithAdvance = (Query, usize);
pub enum QueryResult {
Q(QueryWithAdvance),
@ -72,18 +68,19 @@ pub enum QueryResult {
Disconnected,
}
pub struct AuthProviderHandle<'a, T, Strm> {
pub struct AuthProviderHandle<'a, P, T, Strm> {
provider: &'a mut AuthProvider,
executor: &'a mut ExecutorFn<T, Strm>,
executor: &'a mut ExecutorFn<P, T, Strm>,
_phantom: PhantomData<(T, Strm)>,
}
impl<'a, T, Strm> AuthProviderHandle<'a, T, Strm>
impl<'a, P, T, Strm> AuthProviderHandle<'a, P, T, Strm>
where
T: ClientConnection<Strm>,
T: ClientConnection<P, Strm>,
Strm: Stream,
P: ProtocolSpec,
{
pub fn new(provider: &'a mut AuthProvider, executor: &'a mut ExecutorFn<T, Strm>) -> Self {
pub fn new(provider: &'a mut AuthProvider, executor: &'a mut ExecutorFn<P, T, Strm>) -> Self {
Self {
provider,
executor,
@ -105,217 +102,33 @@ where
}
pub mod prelude {
//! A 'prelude' for callers that would like to use the `ProtocolConnection` and `ProtocolConnectionExt` traits
//! A 'prelude' for callers that would like to use the `RawConnection` and `ProtocolRead` traits
//!
//! This module is hollow itself, it only re-exports from `dbnet::con` and `tokio::io`
pub use super::{AuthProviderHandle, ClientConnection, ProtocolConnectionExt, Stream};
pub use crate::actions::{ensure_boolean_or_aerr, ensure_cond_or_err, ensure_length};
pub use crate::corestore::{
table::{KVEBlob, KVEList},
Corestore,
pub use super::{AuthProviderHandle, ClientConnection, Stream};
pub use crate::{
actions::{ensure_boolean_or_aerr, ensure_cond_or_err, ensure_length, translate_ddl_error},
corestore::{
table::{KVEBlob, KVEList},
Corestore,
},
get_tbl, handle_entity, is_lowbit_set,
protocol::interface::ProtocolSpec,
queryengine::ActionIter,
registry,
util::{self, FutureResult, UnwrapActionError, Unwrappable},
};
pub use crate::protocol::responses::{self, groups};
pub use crate::queryengine::ActionIter;
pub use crate::resp::StringWrapper;
pub use crate::util::{self, FutureResult, UnwrapActionError, Unwrappable};
pub use crate::{aerr, conwrite, get_tbl, handle_entity, is_lowbit_set, registry};
pub use tokio::io::{AsyncReadExt, AsyncWriteExt};
}
/// # The `ProtocolConnectionExt` trait
/// # The `RawConnection` trait
///
/// The `ProtocolConnectionExt` trait has default implementations and doesn't ever require explicit definitions, unless
/// there's some black magic that you want to do. All [`ProtocolConnection`] objects will get a free implementation for this trait.
/// Hence implementing [`ProtocolConnection`] alone is enough for you to get high-level methods to interface with the protocol.
/// The `RawConnection` trait has low-level methods that can be used to interface with raw sockets. Any type
/// that successfully implements this trait will get an implementation for `ProtocolRead` and `ProtocolWrite`
/// provided that it uses a protocol that implements the `ProtocolSpec` trait.
///
/// ## DO NOT
/// The fact that this is a trait enables great flexibility in terms of visibility, but **DO NOT EVER CALL any function other than
/// `read_query`, `close_conn_with_error` or `write_response`**. If you mess with functions like `read_again`, you're likely to pull yourself into some
/// good trouble.
pub trait ProtocolConnectionExt<Strm>: ProtocolConnection<Strm> + Send
where
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
/// Try to parse a query from the buffered data
fn try_query(&self) -> Result<QueryWithAdvance, ParseError> {
protocol::Parser::parse(self.get_buffer())
}
/// Read a query from the remote end
///
/// This function asynchronously waits until all the data required
/// for parsing the query is available
fn read_query<'r, 's>(
&'r mut self,
) -> Pin<Box<dyn Future<Output = Result<QueryResult, IoError>> + Send + 's>>
where
'r: 's,
Self: Sync + Send + 's,
{
Box::pin(async move {
let mv_self = self;
loop {
let (buffer, stream) = mv_self.get_mut_both();
match stream.read_buf(buffer).await {
Ok(0) => {
if buffer.is_empty() {
return Ok(QueryResult::Disconnected);
} else {
return Err(IoError::from(ErrorKind::ConnectionReset));
}
}
Ok(_) => {}
Err(e) => return Err(e),
}
match mv_self.try_query() {
Ok(query_with_advance) => {
return Ok(QueryResult::Q(query_with_advance));
}
Err(ParseError::NotEnough) => (),
Err(ParseError::DatatypeParseFailure) => return Ok(QueryResult::Wrongtype),
Err(ParseError::UnexpectedByte) | Err(ParseError::BadPacket) => {
return Ok(QueryResult::E(responses::full_responses::R_PACKET_ERR));
}
}
}
})
}
/// Write a response to the stream
fn write_response<'r, 's>(
&'r mut self,
streamer: impl Writable + 's + Send + Sync,
) -> Pin<Box<dyn Future<Output = IoResult<()>> + Sync + Send + 's>>
where
'r: 's,
Self: Send + 's,
Self: Sync,
{
Box::pin(async move {
let mv_self = self;
let streamer = streamer;
let ret: IoResult<()> = {
streamer.write(&mut mv_self.get_mut_stream()).await?;
Ok(())
};
ret
})
}
/// Write the simple query header `*1\n` to the stream
fn write_simple_query_header<'r, 's>(
&'r mut self,
) -> Pin<Box<dyn Future<Output = IoResult<()>> + Send + Sync + 's>>
where
'r: 's,
Self: Send + Sync + 's,
{
Box::pin(async move {
let mv_self = self;
let ret: IoResult<()> = {
mv_self.write_response(SIMPLE_QUERY_HEADER).await?;
Ok(())
};
ret
})
}
/// Write the length of the pipeline query (*)
fn write_pipeline_query_header<'r, 's>(
&'r mut self,
len: usize,
) -> FutureResult<'s, IoResult<()>>
where
'r: 's,
Self: Send + Sync + 's,
{
Box::pin(async move {
let slf = self;
slf.write_response([b'$']).await?;
slf.get_mut_stream()
.write_all(&Integer64::init(len as u64))
.await?;
slf.write_response([b'\n']).await?;
Ok(())
})
}
/// Write the flat array length (`_<size>\n`)
fn write_flat_array_length<'r, 's>(&'r mut self, len: usize) -> FutureResult<'s, IoResult<()>>
where
'r: 's,
Self: Send + Sync + 's,
{
Box::pin(async move {
let mv_self = self;
let ret: IoResult<()> = {
mv_self.write_response([b'_']).await?;
mv_self.write_response(len.to_string().into_bytes()).await?;
mv_self.write_response([b'\n']).await?;
Ok(())
};
ret
})
}
/// Write the array length (`&<size>\n`)
fn write_array_length<'r, 's>(&'r mut self, len: usize) -> FutureResult<'s, IoResult<()>>
where
'r: 's,
Self: Send + Sync + 's,
{
Box::pin(async move {
let mv_self = self;
let ret: IoResult<()> = {
mv_self.write_response([b'&']).await?;
mv_self.write_response(len.to_string().into_bytes()).await?;
mv_self.write_response([b'\n']).await?;
Ok(())
};
ret
})
}
/// Wraps around the `write_response` used to differentiate between a
/// success response and an error response
fn close_conn_with_error<'r, 's>(
&'r mut self,
resp: impl Writable + 's + Send + Sync,
) -> FutureResult<'s, IoResult<()>>
where
'r: 's,
Self: Send + Sync + 's,
{
Box::pin(async move {
let mv_self = self;
let ret: IoResult<()> = {
mv_self.write_response(resp).await?;
mv_self.flush_stream().await?;
Ok(())
};
ret
})
}
fn flush_stream<'r, 's>(&'r mut self) -> FutureResult<'s, IoResult<()>>
where
'r: 's,
Self: Sync + Send + 's,
{
Box::pin(async move {
let mv_self = self;
let ret: IoResult<()> = {
mv_self.get_mut_stream().flush().await?;
Ok(())
};
ret
})
}
unsafe fn raw_stream(&mut self) -> &mut BufWriter<Strm> {
self.get_mut_stream()
}
}
/// # The `ProtocolConnection` trait
///
/// The `ProtocolConnection` trait has low-level methods that can be used to interface with raw sockets. Any type
/// that successfully implements this trait will get an implementation for `ProtocolConnectionExt` which augments and
/// builds on these fundamental methods to provide high-level interfacing with queries.
///
/// ## Example of a `ProtocolConnection` object
/// Ideally a `ProtocolConnection` object should look like (the generic parameter just exists for doc-tests, just think that
/// ## Example of a `RawConnection` object
/// Ideally a `RawConnection` object should look like this (the generic parameter exists only for doc-tests; just assume that
/// there is a type `Strm`):
/// ```no_run
/// struct Connection<Strm> {
@ -325,7 +138,7 @@ where
/// ```
///
/// `Strm` should be a stream, i.e something like an SSL connection/TCP connection.
pub trait ProtocolConnection<Strm> {
pub trait RawConnection<P: ProtocolSpec, Strm>: Send + Sync {
/// Returns an **immutable** reference to the underlying read buffer
fn get_buffer(&self) -> &BytesMut;
/// Returns an **immutable** reference to the underlying stream
@ -348,18 +161,10 @@ pub trait ProtocolConnection<Strm> {
}
}
// Give ProtocolConnection implementors a free ProtocolConnectionExt impl
impl<Strm, T> ProtocolConnectionExt<Strm> for T
where
T: ProtocolConnection<Strm> + Send,
Strm: Sync + Send + Unpin + AsyncWriteExt + AsyncReadExt,
{
}
impl<T> ProtocolConnection<T> for Connection<T>
impl<T, P> RawConnection<P, T> for Connection<T>
where
T: BufferedSocketStream,
T: BufferedSocketStream + Sync + Send,
P: ProtocolSpec,
{
fn get_buffer(&self) -> &BytesMut {
&self.buffer
@ -378,35 +183,36 @@ where
}
}
pub(super) type ExecutorFn<T, Strm> =
for<'s> fn(&'s mut ConnectionHandler<T, Strm>, Query) -> FutureResult<'s, ActionResult<()>>;
pub(super) type ExecutorFn<P, T, Strm> =
for<'s> fn(&'s mut ConnectionHandler<P, T, Strm>, Query) -> FutureResult<'s, ActionResult<()>>;
/// # A generic connection handler
///
/// A [`ConnectionHandler`] object is a generic connection handler for any object that implements the [`ProtocolConnection`] trait (or
/// the [`ProtocolConnectionExt`] trait). This function will accept such a type `T`, possibly a listener object and then use it to read
/// A [`ConnectionHandler`] object is a generic connection handler for any object that implements the [`RawConnection`] trait (or
/// the [`ProtocolRead`] trait). This handler will accept such a type `T`, possibly a listener object, and then use it to read
/// a query, parse it and return an appropriate response through [`corestore::Corestore::execute_query`]
pub struct ConnectionHandler<T, Strm> {
pub struct ConnectionHandler<P, T, Strm> {
db: Corestore,
con: T,
climit: Arc<Semaphore>,
auth: AuthProvider,
executor: ExecutorFn<T, Strm>,
executor: ExecutorFn<P, T, Strm>,
terminator: Terminator,
_term_sig_tx: mpsc::Sender<()>,
_marker: PhantomData<Strm>,
}
impl<T, Strm> ConnectionHandler<T, Strm>
impl<P, T, Strm> ConnectionHandler<P, T, Strm>
where
T: ProtocolConnectionExt<Strm> + Send + Sync,
Strm: Sync + Send + Unpin + AsyncWriteExt + AsyncReadExt,
T: ProtocolRead<P, Strm> + ProtocolWrite<P, Strm> + Send + Sync,
Strm: Stream,
P: ProtocolSpec,
{
pub fn new(
db: Corestore,
con: T,
auth: AuthProvider,
executor: ExecutorFn<T, Strm>,
executor: ExecutorFn<P, T, Strm>,
climit: Arc<Semaphore>,
terminator: Terminator,
_term_sig_tx: mpsc::Sender<()>,
@ -434,23 +240,50 @@ where
Ok(QueryResult::Q((query, advance_by))) => {
// the mutable reference to self ensures that the buffer is not modified
// hence ensuring that the pointers will remain valid
match self.execute_query(query).await {
Ok(()) => {}
Err(ActionError::ActionError(e)) => {
self.con.close_conn_with_error(e).await?;
}
Err(ActionError::IoError(e)) => {
return Err(e);
#[cfg(debug_assertions)]
let len_at_start = self.con.get_buffer().len();
#[cfg(debug_assertions)]
let sptr_at_start = self.con.get_buffer().as_ptr() as usize;
#[cfg(debug_assertions)]
let eptr_at_start = sptr_at_start + len_at_start;
{
match self.execute_query(query).await {
Ok(()) => {}
Err(ActionError::ActionError(e)) => {
self.con.close_conn_with_error(e).await?;
}
Err(ActionError::IoError(e)) => {
return Err(e);
}
}
}
// this is only when we clear the buffer. since execute_query is not called
// at this point, it's totally fine (so invalidating ptrs is totally cool)
self.con.advance_buffer(advance_by);
{
// do these assertions to ensure memory safety (this is just for sanity's sake)
#[cfg(debug_assertions)]
// len should be unchanged. no functions should **ever** touch the buffer
debug_assert_eq!(self.con.get_buffer().len(), len_at_start);
#[cfg(debug_assertions)]
// start of allocation should be unchanged
debug_assert_eq!(self.con.get_buffer().as_ptr() as usize, sptr_at_start);
#[cfg(debug_assertions)]
// end of allocation should be unchanged. else we're entirely violating
// memory safety guarantees
debug_assert_eq!(
unsafe {
// UNSAFE(@ohsayan): This is always okay
self.con.get_buffer().as_ptr().add(len_at_start)
} as usize,
eptr_at_start
);
// this is only when we clear the buffer. since execute_query is not called
// at this point, it's totally fine (so invalidating ptrs is totally cool)
self.con.advance_buffer(advance_by);
}
}
Ok(QueryResult::E(r)) => self.con.close_conn_with_error(r).await?,
Ok(QueryResult::Wrongtype) => {
self.con
.close_conn_with_error(responses::groups::WRONGTYPE_ERR.to_owned())
.close_conn_with_error(P::RCODE_WRONGTYPE_ERR)
.await?
}
Ok(QueryResult::Disconnected) => return Ok(()),
@ -479,8 +312,7 @@ where
}
Query::Pipelined(_) => {
con.write_simple_query_header().await?;
con.write_response(auth::errors::AUTH_CODE_BAD_CREDENTIALS)
.await?;
con._write_raw(P::AUTH_CODE_BAD_CREDENTIALS).await?;
}
}
Ok(())
@ -499,7 +331,7 @@ where
queryengine::execute_simple(db, con, &mut auth_provider, q).await?;
}
Query::Pipelined(pipeline) => {
con.write_pipeline_query_header(pipeline.len()).await?;
con.write_pipelined_query_header(pipeline.len()).await?;
queryengine::execute_pipeline(db, con, &mut auth_provider, pipeline).await?;
}
}
@ -510,12 +342,12 @@ where
/// Execute a query that has already been validated by `Connection::read_query`
async fn execute_query(&mut self, query: Query) -> ActionResult<()> {
(self.executor)(self, query).await?;
self.con.flush_stream().await?;
self.con._flush_stream().await?;
Ok(())
}
}
impl<T, Strm> Drop for ConnectionHandler<T, Strm> {
impl<P, T, Strm> Drop for ConnectionHandler<P, T, Strm> {
fn drop(&mut self) {
// Make sure that the permit is returned to the semaphore
// in the case that there is a panic inside
@ -528,10 +360,14 @@ pub trait Stream: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync {}
impl<T> Stream for T where T: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync {}
/// A simple _shorthand trait_ for the insanely long definition of the connection generic type
pub trait ClientConnection<Strm: Stream>: ProtocolConnectionExt<Strm> + Send + Sync {}
impl<T, Strm> ClientConnection<Strm> for T
pub trait ClientConnection<P: ProtocolSpec, Strm: Stream>:
ProtocolWrite<P, Strm> + ProtocolRead<P, Strm> + Send + Sync
{
}
impl<P, T, Strm> ClientConnection<P, Strm> for T
where
T: ProtocolConnectionExt<Strm> + Send + Sync,
T: ProtocolWrite<P, Strm> + ProtocolRead<P, Strm> + Send + Sync,
Strm: Stream,
P: ProtocolSpec,
{
}

@ -39,19 +39,24 @@
//! 5. Now errors are handled if they occur. Otherwise, the query is executed by `Corestore::execute_query()`
//!
use self::tcp::Listener;
use crate::{
auth::AuthProvider,
config::{PortConfig, SslOpts},
corestore::Corestore,
util::error::{Error, SkyResult},
IoResult,
};
use std::{net::IpAddr, sync::Arc};
use tls::SslListener;
use tokio::{
net::TcpListener,
sync::{broadcast, mpsc, Semaphore},
use {
self::{
tcp::{Listener, ListenerV1},
tls::{SslListener, SslListenerV1},
},
crate::{
auth::AuthProvider,
config::{PortConfig, ProtocolVersion, SslOpts},
corestore::Corestore,
util::error::{Error, SkyResult},
IoResult,
},
core::future::Future,
std::{net::IpAddr, sync::Arc},
tokio::{
net::TcpListener,
sync::{broadcast, mpsc, Semaphore},
},
};
pub mod connection;
#[macro_use]
@ -160,35 +165,93 @@ impl BaseListener {
#[allow(clippy::large_enum_variant)]
pub enum MultiListener {
SecureOnly(SslListener),
SecureOnlyV1(SslListenerV1),
InsecureOnly(Listener),
InsecureOnlyV1(ListenerV1),
Multi(Listener, SslListener),
MultiV1(ListenerV1, SslListenerV1),
}
async fn wait_on_port_futures(
a: impl Future<Output = IoResult<()>>,
b: impl Future<Output = IoResult<()>>,
) -> IoResult<()> {
let (e1, e2) = tokio::join!(a, b);
if let Err(e) = e1 {
log::error!("Insecure listener failed with: {}", e);
}
if let Err(e) = e2 {
log::error!("Secure listener failed with: {}", e);
}
Ok(())
}
impl MultiListener {
/// Create a new `InsecureOnly` listener
pub fn new_insecure_only(base: BaseListener) -> Self {
MultiListener::InsecureOnly(Listener::new(base))
pub fn new_insecure_only(base: BaseListener, protocol: ProtocolVersion) -> Self {
match protocol {
ProtocolVersion::V2 => MultiListener::InsecureOnly(Listener::new(base)),
ProtocolVersion::V1 => MultiListener::InsecureOnlyV1(ListenerV1::new(base)),
}
}
/// Create a new `SecureOnly` listener
pub fn new_secure_only(base: BaseListener, ssl: SslOpts) -> SkyResult<Self> {
let listener =
SslListener::new_pem_based_ssl_connection(ssl.key, ssl.chain, base, ssl.passfile)?;
Ok(MultiListener::SecureOnly(listener))
pub fn new_secure_only(
base: BaseListener,
ssl: SslOpts,
protocol: ProtocolVersion,
) -> SkyResult<Self> {
let listener = match protocol {
ProtocolVersion::V2 => {
let listener = SslListener::new_pem_based_ssl_connection(
ssl.key,
ssl.chain,
base,
ssl.passfile,
)?;
MultiListener::SecureOnly(listener)
}
ProtocolVersion::V1 => {
let listener = SslListenerV1::new_pem_based_ssl_connection(
ssl.key,
ssl.chain,
base,
ssl.passfile,
)?;
MultiListener::SecureOnlyV1(listener)
}
};
Ok(listener)
}
/// Create a new `Multi` listener that has both a secure and an insecure listener
pub async fn new_multi(
ssl_base_listener: BaseListener,
tcp_base_listener: BaseListener,
ssl: SslOpts,
protocol: ProtocolVersion,
) -> SkyResult<Self> {
let secure_listener = SslListener::new_pem_based_ssl_connection(
ssl.key,
ssl.chain,
ssl_base_listener,
ssl.passfile,
)?;
let insecure_listener = Listener::new(tcp_base_listener);
Ok(MultiListener::Multi(insecure_listener, secure_listener))
let mls = match protocol {
ProtocolVersion::V2 => {
let secure_listener = SslListener::new_pem_based_ssl_connection(
ssl.key,
ssl.chain,
ssl_base_listener,
ssl.passfile,
)?;
let insecure_listener = Listener::new(tcp_base_listener);
MultiListener::Multi(insecure_listener, secure_listener)
}
ProtocolVersion::V1 => {
let secure_listener = SslListenerV1::new_pem_based_ssl_connection(
ssl.key,
ssl.chain,
ssl_base_listener,
ssl.passfile,
)?;
let insecure_listener = ListenerV1::new(tcp_base_listener);
MultiListener::MultiV1(insecure_listener, secure_listener)
}
};
Ok(mls)
}
/// Start the server
///
@ -197,18 +260,14 @@ impl MultiListener {
pub async fn run_server(&mut self) -> IoResult<()> {
match self {
MultiListener::SecureOnly(secure_listener) => secure_listener.run().await,
MultiListener::SecureOnlyV1(secure_listener) => secure_listener.run().await,
MultiListener::InsecureOnly(insecure_listener) => insecure_listener.run().await,
MultiListener::InsecureOnlyV1(insecure_listener) => insecure_listener.run().await,
MultiListener::Multi(insecure_listener, secure_listener) => {
let insec = insecure_listener.run();
let sec = secure_listener.run();
let (e1, e2) = tokio::join!(insec, sec);
if let Err(e) = e1 {
log::error!("Insecure listener failed with: {}", e);
}
if let Err(e) = e2 {
log::error!("Secure listener failed with: {}", e);
}
Ok(())
wait_on_port_futures(insecure_listener.run(), secure_listener.run()).await
}
MultiListener::MultiV1(insecure_listener, secure_listener) => {
wait_on_port_futures(insecure_listener.run(), secure_listener.run()).await
}
}
}
@ -218,12 +277,18 @@ impl MultiListener {
/// make sure that the data is saved!**
pub async fn finish_with_termsig(self) {
match self {
MultiListener::InsecureOnly(server) => server.base.release_self().await,
MultiListener::SecureOnly(server) => server.base.release_self().await,
MultiListener::InsecureOnly(Listener { base, .. })
| MultiListener::SecureOnly(SslListener { base, .. })
| MultiListener::InsecureOnlyV1(ListenerV1 { base, .. })
| MultiListener::SecureOnlyV1(SslListenerV1 { base, .. }) => base.release_self().await,
MultiListener::Multi(insecure, secure) => {
insecure.base.release_self().await;
secure.base.release_self().await;
}
MultiListener::MultiV1(insecure, secure) => {
insecure.base.release_self().await;
secure.base.release_self().await;
}
}
}
}
@ -231,6 +296,7 @@ impl MultiListener {
/// Initialize the database networking
pub async fn connect(
ports: PortConfig,
protocol: ProtocolVersion,
maxcon: usize,
db: Corestore,
auth: AuthProvider,
@ -250,17 +316,19 @@ pub async fn connect(
let description = ports.get_description();
let server = match ports {
PortConfig::InsecureOnly { host, port } => {
MultiListener::new_insecure_only(base_listener_init(host, port).await?)
}
PortConfig::SecureOnly { host, ssl } => {
MultiListener::new_secure_only(base_listener_init(host, ssl.port).await?, ssl)?
MultiListener::new_insecure_only(base_listener_init(host, port).await?, protocol)
}
PortConfig::SecureOnly { host, ssl } => MultiListener::new_secure_only(
base_listener_init(host, ssl.port).await?,
ssl,
protocol,
)?,
PortConfig::Multi { host, port, ssl } => {
let secure_listener = base_listener_init(host, ssl.port).await?;
let insecure_listener = base_listener_init(host, port).await?;
MultiListener::new_multi(secure_listener, insecure_listener, ssl).await?
MultiListener::new_multi(secure_listener, insecure_listener, ssl, protocol).await?
}
};
log::info!("Server started on {}", description);
log::info!("Server started on {description}");
Ok(server)
}

@ -24,28 +24,35 @@
*
*/
use crate::{
dbnet::{
connection::{ConnectionHandler, ExecutorFn},
BaseListener, Terminator,
},
protocol, IoResult,
};
use bytes::BytesMut;
use libsky::BUF_CAP;
pub use protocol::{ParseResult, Query};
use std::{cell::Cell, time::Duration};
use tokio::{
io::{AsyncWrite, BufWriter},
net::TcpStream,
time,
use {
crate::{
dbnet::{
connection::{ConnectionHandler, ExecutorFn},
BaseListener, Terminator,
},
protocol::{
self,
interface::{ProtocolRead, ProtocolSpec, ProtocolWrite},
Skyhash1, Skyhash2,
},
IoResult,
},
bytes::BytesMut,
libsky::BUF_CAP,
std::{cell::Cell, time::Duration},
tokio::{
io::{AsyncWrite, BufWriter},
net::TcpStream,
time,
},
};
pub trait BufferedSocketStream: AsyncWrite {}
impl BufferedSocketStream for TcpStream {}
type TcpExecutorFn = ExecutorFn<Connection<TcpStream>, TcpStream>;
type TcpExecutorFn<P> = ExecutorFn<P, Connection<TcpStream>, TcpStream>;
/// A TCP/SSL connection wrapper
pub struct Connection<T>
@ -94,13 +101,19 @@ impl TcpBackoff {
}
}
pub type Listener = RawListener<Skyhash2>;
pub type ListenerV1 = RawListener<Skyhash1>;
/// A listener
pub struct Listener {
pub struct RawListener<P> {
pub base: BaseListener,
executor_fn: TcpExecutorFn,
executor_fn: TcpExecutorFn<P>,
}
impl Listener {
impl<P: ProtocolSpec + 'static> RawListener<P>
where
Connection<TcpStream>: ProtocolRead<P, TcpStream> + ProtocolWrite<P, TcpStream>,
{
pub fn new(base: BaseListener) -> Self {
Self {
executor_fn: if base.auth.is_enabled() {

@ -24,40 +24,53 @@
*
*/
use crate::{
dbnet::{
connection::{ConnectionHandler, ExecutorFn},
tcp::{BufferedSocketStream, Connection, TcpBackoff},
BaseListener, Terminator,
use {
crate::{
dbnet::{
connection::{ConnectionHandler, ExecutorFn},
tcp::{BufferedSocketStream, Connection, TcpBackoff},
BaseListener, Terminator,
},
protocol::{
interface::{ProtocolRead, ProtocolSpec, ProtocolWrite},
Skyhash1, Skyhash2,
},
util::error::{Error, SkyResult},
IoResult,
},
util::error::{Error, SkyResult},
IoResult,
};
use openssl::{
pkey::PKey,
rsa::Rsa,
ssl::{Ssl, SslAcceptor, SslFiletype, SslMethod},
openssl::{
pkey::PKey,
rsa::Rsa,
ssl::{Ssl, SslAcceptor, SslFiletype, SslMethod},
},
std::{fs, pin::Pin},
tokio::net::TcpStream,
tokio_openssl::SslStream,
};
use std::{fs, pin::Pin};
use tokio::net::TcpStream;
use tokio_openssl::SslStream;
impl BufferedSocketStream for SslStream<TcpStream> {}
type SslExecutorFn = ExecutorFn<Connection<SslStream<TcpStream>>, SslStream<TcpStream>>;
type SslExecutorFn<P> = ExecutorFn<P, Connection<SslStream<TcpStream>>, SslStream<TcpStream>>;
pub type SslListener = SslListenerRaw<Skyhash2>;
pub type SslListenerV1 = SslListenerRaw<Skyhash1>;
pub struct SslListener {
pub struct SslListenerRaw<P> {
pub base: BaseListener,
acceptor: SslAcceptor,
executor_fn: SslExecutorFn,
executor_fn: SslExecutorFn<P>,
}
impl SslListener {
impl<P: ProtocolSpec + 'static> SslListenerRaw<P>
where
Connection<SslStream<TcpStream>>:
ProtocolRead<P, SslStream<TcpStream>> + ProtocolWrite<P, SslStream<TcpStream>>,
{
pub fn new_pem_based_ssl_connection(
key_file: String,
chain_file: String,
base: BaseListener,
tls_passfile: Option<String>,
) -> SkyResult<Self> {
) -> SkyResult<SslListenerRaw<P>> {
let mut acceptor_builder = SslAcceptor::mozilla_intermediate(SslMethod::tls())?;
// cert is the same for both
acceptor_builder.set_certificate_chain_file(chain_file)?;
@ -77,7 +90,7 @@ impl SslListener {
// no passphrase, needs interactive
acceptor_builder.set_private_key_file(key_file, SslFiletype::PEM)?;
}
Ok(SslListener {
Ok(Self {
acceptor: acceptor_builder.build(),
executor_fn: if base.auth.is_enabled() {
ConnectionHandler::execute_unauth

@ -56,7 +56,6 @@ mod kvengine;
mod protocol;
mod queryengine;
pub mod registry;
mod resp;
mod services;
mod storage;
#[cfg(test)]

@ -1,79 +0,0 @@
/*
* Created on Tue Nov 02 2021
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2021, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
/*
Do note that the result of the benches might actually be slower, than faster! The reason it is so, is simply because of
the fact that we generate owned queries, by copying bytes which adds an overhead, but offers simplicity in writing tests
and/or benches
*/
extern crate test;
use super::{element::OwnedElement, OwnedQuery, Parser};
use bytes::Bytes;
use test::Bencher;
#[bench]
fn bench_simple_query_string(b: &mut Bencher) {
const PAYLOAD: &[u8] = b"*1\n+5\nsayan\n";
unsafe {
b.iter(|| {
assert_eq!(
Parser::new(PAYLOAD).parse().unwrap().0.into_owned_query(),
OwnedQuery::SimpleQuery(OwnedElement::String(Bytes::from("sayan")))
);
})
}
}
#[bench]
fn bench_simple_query_uint(b: &mut Bencher) {
const PAYLOAD: &[u8] = b"*1\n:5\n12345\n";
unsafe {
b.iter(|| {
assert_eq!(
Parser::new(PAYLOAD).parse().unwrap().0.into_owned_query(),
OwnedQuery::SimpleQuery(OwnedElement::UnsignedInt(12345))
);
})
}
}
#[bench]
fn bench_simple_query_any_array(b: &mut Bencher) {
const PAYLOAD: &[u8] = b"*1\n~3\n3\nthe\n3\ncat\n6\nmeowed\n";
unsafe {
b.iter(|| {
assert_eq!(
Parser::new(PAYLOAD).parse().unwrap().0.into_owned_query(),
OwnedQuery::SimpleQuery(OwnedElement::AnyArray(vec![
"the".into(),
"cat".into(),
"meowed".into()
]))
)
})
}
}

@ -0,0 +1,492 @@
/*
* Created on Tue Apr 26 2022
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2022, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
use super::ParseError;
use crate::{
corestore::{
booltable::{BytesBoolTable, BytesNicheLUT},
buffers::Integer64,
},
dbnet::connection::{QueryResult, QueryWithAdvance, RawConnection, Stream},
util::FutureResult,
IoResult,
};
use std::io::{Error as IoError, ErrorKind};
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter};
/*
NOTE TO SELF (@ohsayan): Why do we split everything into separate traits? To avoid mistakes
in the future. We don't want any action to randomly call `read_query`, which was possible
with the earlier `ProtocolConnectionExt` trait, since it was imported by every action from
the prelude.
- `ProtocolSpec`: this is like a charset definition of the protocol along with some other
good stuff
- `ProtocolRead`: should only read from the stream and never write
- `ProtocolWrite`: should only write data and never read
These distinctions reduce the likelihood of making mistakes while implementing the traits
-- Sayan (May, 2022)
*/
/// The `ProtocolSpec` trait is used to define the character set and pre-generated elements
/// and responses for a protocol version. To make any actual use of it, you need to implement
/// both the `ProtocolRead` and `ProtocolWrite` traits for the protocol
pub trait ProtocolSpec: Send + Sync {
// spec information
/// The Skyhash protocol version
const PROTOCOL_VERSION: f32;
/// The Skyhash protocol version string (Skyhash-x.y)
const PROTOCOL_VERSIONSTRING: &'static str;
// type symbols
/// Type symbol for unicode strings
const TSYMBOL_STRING: u8;
/// Type symbol for blobs
const TSYMBOL_BINARY: u8;
/// Type symbol for float
const TSYMBOL_FLOAT: u8;
/// Type symbol for int64
const TSYMBOL_INT64: u8;
/// Type symbol for typed array
const TSYMBOL_TYPED_ARRAY: u8;
/// Type symbol for typed non-null array
const TSYMBOL_TYPED_NON_NULL_ARRAY: u8;
/// Type symbol for an array
const TSYMBOL_ARRAY: u8;
/// Type symbol for a flat array
const TSYMBOL_FLAT_ARRAY: u8;
// charset
/// The line-feed character or separator
const LF: u8 = b'\n';
// metaframe
/// The header for simple queries
const SIMPLE_QUERY_HEADER: &'static [u8];
/// The header for pipelined queries (excluding length, obviously)
const PIPELINED_QUERY_FIRST_BYTE: u8;
// typed array
/// Null element representation for a typed array
const TYPE_TYPED_ARRAY_ELEMENT_NULL: &'static [u8];
// respcodes
/// Respcode 0: Okay
const RCODE_OKAY: &'static [u8];
/// Respcode 1: Nil
const RCODE_NIL: &'static [u8];
/// Respcode 2: Overwrite error
const RCODE_OVERWRITE_ERR: &'static [u8];
/// Respcode 3: Action error
const RCODE_ACTION_ERR: &'static [u8];
/// Respcode 4: Packet error
const RCODE_PACKET_ERR: &'static [u8];
/// Respcode 5: Server error
const RCODE_SERVER_ERR: &'static [u8];
/// Respcode 6: Other error
const RCODE_OTHER_ERR_EMPTY: &'static [u8];
/// Respcode 7: Unknown action
const RCODE_UNKNOWN_ACTION: &'static [u8];
/// Respcode 8: Wrongtype error
const RCODE_WRONGTYPE_ERR: &'static [u8];
/// Respcode 9: Unknown data type error
const RCODE_UNKNOWN_DATA_TYPE: &'static [u8];
/// Respcode 10: Encoding error
const RCODE_ENCODING_ERROR: &'static [u8];
// respstrings
/// Respstring when snapshot engine is busy
const RSTRING_SNAPSHOT_BUSY: &'static [u8];
/// Respstring when snapshots are disabled
const RSTRING_SNAPSHOT_DISABLED: &'static [u8];
/// Respstring when duplicate snapshot creation is attempted
const RSTRING_SNAPSHOT_DUPLICATE: &'static [u8];
/// Respstring when snapshot has illegal chars
const RSTRING_SNAPSHOT_ILLEGAL_NAME: &'static [u8];
/// Respstring when a **very bad error** happens (use after termsig)
const RSTRING_ERR_ACCESS_AFTER_TERMSIG: &'static [u8];
/// Respstring when the default container is unset
const RSTRING_DEFAULT_UNSET: &'static [u8];
/// Respstring when the container is not found
const RSTRING_CONTAINER_NOT_FOUND: &'static [u8];
/// Respstring when the container is still in use, but a _free_ op is attempted
const RSTRING_STILL_IN_USE: &'static [u8];
/// Respstring when a protected container is attempted to be accessed/modified
const RSTRING_PROTECTED_OBJECT: &'static [u8];
/// Respstring when an action is not suitable for the current table model
const RSTRING_WRONG_MODEL: &'static [u8];
/// Respstring when the container already exists
const RSTRING_ALREADY_EXISTS: &'static [u8];
/// Respstring when the container is not ready
const RSTRING_NOT_READY: &'static [u8];
/// Respstring when a DDL transaction fails
const RSTRING_DDL_TRANSACTIONAL_FAILURE: &'static [u8];
/// Respstring when an unknown DDL query is run (`CREATE BLAH`, for example)
const RSTRING_UNKNOWN_DDL_QUERY: &'static [u8];
/// Respstring when a bad DDL expression is run
const RSTRING_BAD_EXPRESSION: &'static [u8];
/// Respstring when an unsupported model is attempted to be used during table creation
const RSTRING_UNKNOWN_MODEL: &'static [u8];
/// Respstring when too many arguments are passed to a DDL query
const RSTRING_TOO_MANY_ARGUMENTS: &'static [u8];
/// Respstring when the container name is too long
const RSTRING_CONTAINER_NAME_TOO_LONG: &'static [u8];
/// Respstring when the container name is invalid
const RSTRING_BAD_CONTAINER_NAME: &'static [u8];
/// Respstring when an unknown inspect query is run (`INSPECT blah`, for example)
const RSTRING_UNKNOWN_INSPECT_QUERY: &'static [u8];
/// Respstring when an unknown table property is passed during table creation
const RSTRING_UNKNOWN_PROPERTY: &'static [u8];
/// Respstring when a non-empty keyspace is attempted to be dropped
const RSTRING_KEYSPACE_NOT_EMPTY: &'static [u8];
/// Respstring when a bad type is provided for a key in the K/V engine (like using a `list`
/// for the key)
const RSTRING_BAD_TYPE_FOR_KEY: &'static [u8];
/// Respstring when a non-existent index is attempted to be accessed in a list
const RSTRING_LISTMAP_BAD_INDEX: &'static [u8];
/// Respstring when a list is empty and we attempt to access/modify it
const RSTRING_LISTMAP_LIST_IS_EMPTY: &'static [u8];
// element responses
/// A string element containing the text "HEY!"
const ELEMRESP_HEYA: &'static [u8];
// full responses
/// A **full response** for a packet error
const FULLRESP_RCODE_PACKET_ERR: &'static [u8];
/// A **full response** for a wrongtype error
const FULLRESP_RCODE_WRONG_TYPE: &'static [u8];
// LUTs
/// A LUT for SET operations
const SET_NLUT: BytesNicheLUT = BytesNicheLUT::new(
Self::RCODE_ENCODING_ERROR,
Self::RCODE_OKAY,
Self::RCODE_OVERWRITE_ERR,
);
/// A LUT for lists
const OKAY_BADIDX_NIL_NLUT: BytesNicheLUT = BytesNicheLUT::new(
Self::RCODE_NIL,
Self::RCODE_OKAY,
Self::RSTRING_LISTMAP_BAD_INDEX,
);
/// A LUT for SET operations
const OKAY_OVW_BLUT: BytesBoolTable =
BytesBoolTable::new(Self::RCODE_OKAY, Self::RCODE_OVERWRITE_ERR);
/// A LUT for UPDATE operations
const UPDATE_NLUT: BytesNicheLUT = BytesNicheLUT::new(
Self::RCODE_ENCODING_ERROR,
Self::RCODE_OKAY,
Self::RCODE_NIL,
);
// auth error respstrings
/// respstring: already claimed (user was already claimed)
const AUTH_ERROR_ALREADYCLAIMED: &'static [u8];
/// respcode(10): bad credentials (either bad creds or invalid user)
const AUTH_CODE_BAD_CREDENTIALS: &'static [u8];
/// respstring: auth is disabled
const AUTH_ERROR_DISABLED: &'static [u8];
/// respcode(11): Insufficient permissions (same for anonymous user)
const AUTH_CODE_PERMS: &'static [u8];
/// respstring: ID is too long
const AUTH_ERROR_ILLEGAL_USERNAME: &'static [u8];
/// respstring: ID is protected/in use
const AUTH_ERROR_FAILED_TO_DELETE_USER: &'static [u8];
}
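To see why every respcode and respstring lives as an associated constant on `ProtocolSpec`, here is a minimal sketch of the pattern: actions stay generic over the wire format and the compiler picks the right bytes through the type parameter, which is what the `err(P::...)` and `con._write_raw(P::...)` calls earlier rely on. The constant values below are placeholders, not the real Skyhash 1.0/2.0 encodings.

```rust
// Minimal sketch of the "constants on a spec trait" pattern; the byte strings
// here are placeholders, not the real Skyhash encodings.
trait Spec {
    const RCODE_OKAY: &'static [u8];
    const RSTRING_WRONG_MODEL: &'static [u8];
}

struct V1Like;
struct V2Like;

impl Spec for V1Like {
    const RCODE_OKAY: &'static [u8] = b"*1\n!1\n0\n"; // hypothetical
    const RSTRING_WRONG_MODEL: &'static [u8] = b"!11\nwrong-model\n"; // hypothetical
}

impl Spec for V2Like {
    const RCODE_OKAY: &'static [u8] = b"!0\n"; // hypothetical
    const RSTRING_WRONG_MODEL: &'static [u8] = b"!wrong-model\n"; // hypothetical
}

// a generic "action": it never hardcodes a wire format
fn respond_ok<P: Spec>(out: &mut Vec<u8>) {
    out.extend_from_slice(P::RCODE_OKAY);
}

fn main() {
    let (mut v1, mut v2) = (Vec::new(), Vec::new());
    respond_ok::<V1Like>(&mut v1);
    respond_ok::<V2Like>(&mut v2);
    assert_ne!(v1, v2); // same action, protocol-specific bytes on the wire
}
```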
/// # The `ProtocolRead` trait
///
/// The `ProtocolRead` trait enables read operations using the protocol for a given stream `Strm` and protocol
/// `P`. Both the stream and the protocol must implement the appropriate traits for you to be able to use
/// these methods
///
/// ## DO NOT
/// The fact that this is a trait enables great flexibility in terms of visibility, but **DO NOT EVER CALL any
/// function other than `read_query`, `close_conn_with_error` or `write_response`**. If you mess with functions
/// like `read_again`, you're likely to pull yourself into some good trouble.
pub trait ProtocolRead<P, Strm>: RawConnection<P, Strm>
where
Strm: Stream,
P: ProtocolSpec,
{
/// Try to parse a query from the buffered data
fn try_query(&self) -> Result<QueryWithAdvance, ParseError>;
/// Read a query from the remote end
///
/// This function asynchronously waits until all the data required
/// for parsing the query is available
fn read_query<'s, 'r: 's>(&'r mut self) -> FutureResult<'s, Result<QueryResult, IoError>> {
Box::pin(async move {
let mv_self = self;
loop {
let (buffer, stream) = mv_self.get_mut_both();
match stream.read_buf(buffer).await {
Ok(0) => {
if buffer.is_empty() {
return Ok(QueryResult::Disconnected);
} else {
return Err(IoError::from(ErrorKind::ConnectionReset));
}
}
Ok(_) => {}
Err(e) => return Err(e),
}
match mv_self.try_query() {
Ok(query_with_advance) => {
return Ok(QueryResult::Q(query_with_advance));
}
Err(ParseError::NotEnough) => (),
Err(ParseError::DatatypeParseFailure) => return Ok(QueryResult::Wrongtype),
Err(ParseError::UnexpectedByte | ParseError::BadPacket) => {
return Ok(QueryResult::E(P::FULLRESP_RCODE_PACKET_ERR));
}
Err(ParseError::WrongType) => {
return Ok(QueryResult::E(P::FULLRESP_RCODE_WRONG_TYPE));
}
}
}
})
}
}
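The default `read_query` above is an accumulate-and-retry loop: pull more bytes into the connection buffer, re-run the parser, and only return once the parser yields a query (together with how far to advance the buffer) or a hard error; `NotEnough` simply means "read again". Below is a self-contained, synchronous sketch of that shape, using a toy newline-terminated format in place of the real Skyhash parser.

```rust
// Toy stand-in for the parser: a "query" is complete once a newline arrives.
enum Parse {
    NotEnough,
    Done(usize), // bytes consumed, i.e. the `advance_by` used by the handler
}

fn try_parse(buf: &[u8]) -> Parse {
    match buf.iter().position(|b| *b == b'\n') {
        Some(at) => Parse::Done(at + 1),
        None => Parse::NotEnough,
    }
}

fn main() {
    // each chunk stands in for one `stream.read_buf(buffer)` call
    let chunks: [&[u8]; 3] = [b"HEYA", b" again", b"\n"];
    let mut buffer: Vec<u8> = Vec::new();
    for chunk in chunks {
        buffer.extend_from_slice(chunk);
        match try_parse(&buffer) {
            // not enough data yet: loop back and read more, as `read_query` does
            Parse::NotEnough => continue,
            Parse::Done(advance_by) => {
                println!("query ready; advance the buffer by {advance_by} bytes");
                // the connection handler only advances *after* executing the query
                buffer.drain(..advance_by);
                break;
            }
        }
    }
    assert!(buffer.is_empty());
}
```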
pub trait ProtocolWrite<P, Strm>: RawConnection<P, Strm>
where
Strm: Stream,
P: ProtocolSpec,
{
// utility
fn _get_raw_stream(&mut self) -> &mut BufWriter<Strm> {
self.get_mut_stream()
}
fn _flush_stream<'life0, 'ret_life>(&'life0 mut self) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
Self: Send + 'ret_life,
{
Box::pin(async move { self.get_mut_stream().flush().await })
}
fn _write_raw<'life0, 'life1, 'ret_life>(
&'life0 mut self,
data: &'life1 [u8],
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
'life1: 'ret_life,
Self: Send + 'ret_life,
{
Box::pin(async move { self.get_mut_stream().write_all(data).await })
}
fn _write_raw_flushed<'life0, 'life1, 'ret_life>(
&'life0 mut self,
data: &'life1 [u8],
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
'life1: 'ret_life,
Self: Send + 'ret_life,
{
Box::pin(async move {
self._write_raw(data).await?;
self._flush_stream().await
})
}
fn close_conn_with_error<'life0, 'life1, 'ret_life>(
&'life0 mut self,
resp: &'life1 [u8],
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
'life1: 'ret_life,
Self: Send + 'ret_life,
{
Box::pin(async move { self._write_raw_flushed(resp).await })
}
// metaframe
fn write_simple_query_header<'life0, 'ret_life>(
&'life0 mut self,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
Self: Send + 'ret_life,
{
Box::pin(async move {
self.get_mut_stream()
.write_all(P::SIMPLE_QUERY_HEADER)
.await
})
}
fn write_pipelined_query_header<'life0, 'ret_life>(
&'life0 mut self,
qcount: usize,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
Self: Send + 'ret_life,
{
Box::pin(async move {
self.get_mut_stream()
.write_all(&[P::PIPELINED_QUERY_FIRST_BYTE])
.await?;
self.get_mut_stream()
.write_all(&Integer64::from(qcount))
.await?;
self.get_mut_stream().write_all(&[P::LF]).await
})
}
// monoelement
fn write_mono_length_prefixed_with_tsymbol<'life0, 'life1, 'ret_life>(
&'life0 mut self,
data: &'life1 [u8],
tsymbol: u8,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
'life1: 'ret_life,
Self: Send + 'ret_life;
/// serialize and write a `&str` to the stream
fn write_string<'life0, 'life1, 'ret_life>(
&'life0 mut self,
string: &'life1 str,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
'life1: 'ret_life,
Self: 'ret_life;
/// serialize and write an `&[u8]` to the stream
fn write_binary<'life0, 'life1, 'ret_life>(
&'life0 mut self,
binary: &'life1 [u8],
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
'life1: 'ret_life,
Self: 'ret_life;
/// serialize and write a `usize` to the stream
fn write_usize<'life0, 'ret_life>(
&'life0 mut self,
size: usize,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
Self: 'ret_life;
/// serialize and write a `u64` to the stream
fn write_int64<'life0, 'ret_life>(
&'life0 mut self,
int: u64,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
Self: 'ret_life;
/// serialize and write an `f32` to the stream
fn write_float<'life0, 'ret_life>(
&'life0 mut self,
float: f32,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
Self: 'ret_life;
// typed array
fn write_typed_array_header<'life0, 'ret_life>(
&'life0 mut self,
len: usize,
tsymbol: u8,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
Self: Send + 'ret_life,
{
Box::pin(async move {
self.get_mut_stream()
.write_all(&[P::TSYMBOL_TYPED_ARRAY, tsymbol])
.await?;
self.get_mut_stream()
.write_all(&Integer64::from(len))
.await?;
self.get_mut_stream().write_all(&[P::LF]).await?;
Ok(())
})
}
fn write_typed_array_element_null<'life0, 'ret_life>(
&'life0 mut self,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
Self: Send + 'ret_life,
{
Box::pin(async move {
self.get_mut_stream()
.write_all(P::TYPE_TYPED_ARRAY_ELEMENT_NULL)
.await
})
}
fn write_typed_array_element<'life0, 'life1, 'ret_life>(
&'life0 mut self,
element: &'life1 [u8],
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
'life1: 'ret_life,
Self: 'ret_life;
// typed non-null array
fn write_typed_non_null_array_header<'life0, 'ret_life>(
&'life0 mut self,
len: usize,
tsymbol: u8,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
Self: Send + 'ret_life,
{
Box::pin(async move {
self.get_mut_stream()
.write_all(&[P::TSYMBOL_TYPED_NON_NULL_ARRAY, tsymbol])
.await?;
self.get_mut_stream()
.write_all(&Integer64::from(len))
.await?;
self.get_mut_stream().write_all(&[P::LF]).await?;
Ok(())
})
}
fn write_typed_non_null_array_element<'life0, 'life1, 'ret_life>(
&'life0 mut self,
element: &'life1 [u8],
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
'life1: 'ret_life,
Self: Send + 'ret_life,
{
Box::pin(async move { self.write_typed_array_element(element).await })
}
}
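// Illustrative sketch (not part of this diff): the bytes that the typed-array helpers above
// produce under the Skyhash 2.0 impl further down ('@' as TSYMBOL_TYPED_ARRAY, a single NUL
// byte for a null element, and `<len>\n<body>` for a non-null element). `build_typed_array`
// is a hypothetical standalone helper, not a function in this codebase.
fn build_typed_array(tsymbol: u8, elements: &[Option<&[u8]>]) -> Vec<u8> {
    let mut out = Vec::new();
    // header: @<tsymbol><len>\n (see write_typed_array_header)
    out.push(b'@');
    out.push(tsymbol);
    out.extend_from_slice(elements.len().to_string().as_bytes());
    out.push(b'\n');
    for element in elements {
        match element {
            // null element: TYPE_TYPED_ARRAY_ELEMENT_NULL (see write_typed_array_element_null)
            None => out.push(b'\0'),
            // non-null element: <len>\n<body> (see the Skyhash 2.0 write_typed_array_element)
            Some(body) => {
                out.extend_from_slice(body.len().to_string().as_bytes());
                out.push(b'\n');
                out.extend_from_slice(body);
            }
        }
    }
    out
}

#[test]
fn typed_array_framing() {
    // a typed array of two binary ('?') elements, the second one null
    let frame = build_typed_array(b'?', &[Some(&b"sayan"[..]), None]);
    assert_eq!(frame, b"@?2\n5\nsayan\0".to_vec());
}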

@ -75,16 +75,14 @@ impl<'a> AnyArrayIter<'a> {
}
/// Returns the next value in uppercase
pub fn next_uppercase(&mut self) -> Option<Box<[u8]>> {
self.iter.next().map(|v| unsafe {
// SAFETY: Only construction is unsafe, forwarding is not
v.as_slice().to_ascii_uppercase().into_boxed_slice()
})
self.iter
.next()
.map(|v| v.as_slice().to_ascii_uppercase().into_boxed_slice())
}
pub fn next_lowercase(&mut self) -> Option<Box<[u8]>> {
self.iter.next().map(|v| unsafe {
// SAFETY: Only construction is unsafe, forwarding is not
v.as_slice().to_ascii_lowercase().into_boxed_slice()
})
self.iter
.next()
.map(|v| v.as_slice().to_ascii_lowercase().into_boxed_slice())
}
pub unsafe fn next_lowercase_unchecked(&mut self) -> Box<[u8]> {
self.next_lowercase().unwrap_or_else(|| impossible!())
@ -143,7 +141,7 @@ unsafe impl DerefUnsafeSlice for Bytes {
impl<'a> Iterator for AnyArrayIter<'a> {
type Item = &'a [u8];
fn next(&mut self) -> Option<Self::Item> {
self.iter.next().map(|v| unsafe { v.as_slice() })
self.iter.next().map(|v| v.as_slice())
}
fn size_hint(&self) -> (usize, Option<usize>) {
self.iter.size_hint()
@ -152,7 +150,7 @@ impl<'a> Iterator for AnyArrayIter<'a> {
impl<'a> DoubleEndedIterator for AnyArrayIter<'a> {
fn next_back(&mut self) -> Option<<Self as Iterator>::Item> {
self.iter.next_back().map(|v| unsafe { v.as_slice() })
self.iter.next_back().map(|v| v.as_slice())
}
}
@ -162,30 +160,15 @@ impl<'a> FusedIterator for AnyArrayIter<'a> {}
impl<'a> Iterator for BorrowedAnyArrayIter<'a> {
type Item = &'a [u8];
fn next(&mut self) -> Option<Self::Item> {
self.iter.next().map(|v| unsafe { v.as_slice() })
self.iter.next().map(|v| v.as_slice())
}
}
impl<'a> DoubleEndedIterator for BorrowedAnyArrayIter<'a> {
fn next_back(&mut self) -> Option<<Self as Iterator>::Item> {
self.iter.next_back().map(|v| unsafe { v.as_slice() })
self.iter.next_back().map(|v| v.as_slice())
}
}
impl<'a> ExactSizeIterator for BorrowedAnyArrayIter<'a> {}
impl<'a> FusedIterator for BorrowedAnyArrayIter<'a> {}
#[test]
fn test_iter() {
use super::{Parser, Query};
let (q, _fwby) = Parser::parse(b"*3\n3\nset1\nx3\n100").unwrap();
let r = match q {
Query::Simple(q) => q,
_ => panic!("Wrong query"),
};
let it = r.as_slice().iter();
let mut iter = unsafe { AnyArrayIter::new(it) };
assert_eq!(iter.next_uppercase().unwrap().as_ref(), "SET".as_bytes());
assert_eq!(iter.next().unwrap(), "x".as_bytes());
assert_eq!(iter.next().unwrap(), "100".as_bytes());
}

@ -24,21 +24,30 @@
*
*/
use crate::corestore::heap_array::HeapArray;
use core::{fmt, marker::PhantomData, mem::transmute, slice};
#[cfg(feature = "nightly")]
mod benches;
#[cfg(test)]
mod tests;
use self::interface::ProtocolSpec;
use {
crate::corestore::heap_array::HeapArray,
core::{fmt, slice},
};
// pub mods
pub mod interface;
pub mod iter;
pub mod responses;
// internal mods
mod raw_parser;
// versions
mod v1;
mod v2;
// endof pub mods
/// The Skyhash protocol version
pub const PROTOCOL_VERSION: f32 = 2.0;
/// The Skyhash protocol version string (Skyhash-x.y)
pub const PROTOCOL_VERSIONSTRING: &str = "Skyhash-2.0";
pub type Skyhash2 = v2::Parser;
pub type Skyhash1 = v1::Parser;
#[cfg(test)]
/// The latest protocol version supported by this version
pub const LATEST_PROTOCOL_VERSION: f32 = Skyhash2::PROTOCOL_VERSION;
#[cfg(test)]
/// The latest protocol version supported by this version (`Skyhash-x.y`)
pub const LATEST_PROTOCOL_VERSIONSTRING: &str = Skyhash2::PROTOCOL_VERSIONSTRING;
#[derive(PartialEq)]
/// As its name says, an [`UnsafeSlice`] is a terribly unsafe slice. Its guarantees are
@ -72,8 +81,13 @@ impl UnsafeSlice {
Self { start_ptr, len }
}
/// Return self as a slice
pub unsafe fn as_slice(&self) -> &[u8] {
slice::from_raw_parts(self.start_ptr, self.len)
pub fn as_slice(&self) -> &[u8] {
unsafe {
// UNSAFE(@ohsayan): Just like core::slice, we resemble the same idea:
// we assume that the unsafe construction was correct and hence *assume*
// that calling this is safe
slice::from_raw_parts(self.start_ptr, self.len)
}
}
}
@ -86,7 +100,6 @@ pub enum ParseError {
/// Didn't get the number of expected bytes
NotEnough = 0u8,
/// The packet simply contains invalid data
#[allow(dead_code)] // HACK(@ohsayan): rustc can't "guess" the transmutation
BadPacket = 1u8,
/// The query contains an unexpected byte
UnexpectedByte = 2u8,
@ -94,6 +107,8 @@ pub enum ParseError {
///
/// This can happen not just for elements but can also happen for their sizes ([`Self::parse_into_u64`])
DatatypeParseFailure = 3u8,
/// The client supplied the wrong query data type for the given query
WrongType = 4u8,
}
/// A generic result to indicate parsing errors through the [`ParseError`] enum
@ -114,13 +129,12 @@ impl SimpleQuery {
#[cfg(test)]
fn into_owned(self) -> OwnedSimpleQuery {
OwnedSimpleQuery {
data: self
.data
.iter()
.map(|v| unsafe { v.as_slice().to_owned() })
.collect(),
data: self.data.iter().map(|v| v.as_slice().to_owned()).collect(),
}
}
pub const fn new(data: HeapArray<UnsafeSlice>) -> Self {
Self { data }
}
pub fn as_slice(&self) -> &[UnsafeSlice] {
&self.data
}
@ -128,7 +142,7 @@ impl SimpleQuery {
#[cfg(test)]
struct OwnedSimpleQuery {
data: Vec<Vec<u8>>,
pub data: Vec<Vec<u8>>,
}
#[derive(Debug)]
@ -137,6 +151,9 @@ pub struct PipelinedQuery {
}
impl PipelinedQuery {
pub const fn new(data: HeapArray<HeapArray<UnsafeSlice>>) -> Self {
Self { data }
}
pub fn len(&self) -> usize {
self.data.len()
}
@ -149,11 +166,7 @@ impl PipelinedQuery {
data: self
.data
.iter()
.map(|v| {
v.iter()
.map(|v| unsafe { v.as_slice().to_owned() })
.collect()
})
.map(|v| v.iter().map(|v| v.as_slice().to_owned()).collect())
.collect(),
}
}
@ -161,261 +174,5 @@ impl PipelinedQuery {
#[cfg(test)]
struct OwnedPipelinedQuery {
data: Vec<Vec<Vec<u8>>>,
}
/// A parser for Skyhash 2.0
pub struct Parser<'a> {
end: *const u8,
cursor: *const u8,
_lt: PhantomData<&'a ()>,
}
impl<'a> Parser<'a> {
/// Initialize a new parser
pub fn new(slice: &[u8]) -> Self {
unsafe {
Self {
end: slice.as_ptr().add(slice.len()),
cursor: slice.as_ptr(),
_lt: PhantomData,
}
}
}
}
// basic methods
impl<'a> Parser<'a> {
/// Returns a ptr one byte past the allocation of the buffer
const fn data_end_ptr(&self) -> *const u8 {
self.end
}
/// Returns the position of the cursor
/// WARNING: Deref might lead to a segfault
const fn cursor_ptr(&self) -> *const u8 {
self.cursor
}
/// Check how many bytes we have left
fn remaining(&self) -> usize {
self.data_end_ptr() as usize - self.cursor_ptr() as usize
}
/// Check if we have `size` bytes remaining
fn has_remaining(&self, size: usize) -> bool {
self.remaining() >= size
}
#[cfg(test)]
/// Check if we have exhausted the buffer
fn exhausted(&self) -> bool {
self.cursor_ptr() >= self.data_end_ptr()
}
/// Check if the buffer is not exhausted
fn not_exhausted(&self) -> bool {
self.cursor_ptr() < self.data_end_ptr()
}
/// Attempts to return the byte pointed at by the cursor.
/// WARNING: The same segfault warning
const unsafe fn get_byte_at_cursor(&self) -> u8 {
*self.cursor_ptr()
}
}
// mutable refs
impl<'a> Parser<'a> {
/// Increment the cursor by `by` positions
unsafe fn incr_cursor_by(&mut self, by: usize) {
self.cursor = self.cursor.add(by);
}
/// Increment the position of the cursor by one position
unsafe fn incr_cursor(&mut self) {
self.incr_cursor_by(1);
}
}
// higher level abstractions
impl<'a> Parser<'a> {
/// Attempt to read `len` bytes
fn read_until(&mut self, len: usize) -> ParseResult<UnsafeSlice> {
if self.has_remaining(len) {
unsafe {
// UNSAFE(@ohsayan): Already verified lengths
let slice = UnsafeSlice::new(self.cursor_ptr(), len);
self.incr_cursor_by(len);
Ok(slice)
}
} else {
Err(ParseError::NotEnough)
}
}
#[cfg(test)]
/// Attempt to read a byte slice terminated by an LF
fn read_line(&mut self) -> ParseResult<UnsafeSlice> {
let start_ptr = self.cursor_ptr();
unsafe {
while self.not_exhausted() && self.get_byte_at_cursor() != b'\n' {
self.incr_cursor();
}
if self.not_exhausted() && self.get_byte_at_cursor() == b'\n' {
let len = self.cursor_ptr() as usize - start_ptr as usize;
self.incr_cursor(); // skip LF
Ok(UnsafeSlice::new(start_ptr, len))
} else {
Err(ParseError::NotEnough)
}
}
}
/// Attempt to read a line, **rejecting an empty payload**
fn read_line_pedantic(&mut self) -> ParseResult<UnsafeSlice> {
let start_ptr = self.cursor_ptr();
unsafe {
while self.not_exhausted() && self.get_byte_at_cursor() != b'\n' {
self.incr_cursor();
}
let len = self.cursor_ptr() as usize - start_ptr as usize;
let has_lf = self.not_exhausted() && self.get_byte_at_cursor() == b'\n';
if has_lf && len != 0 {
self.incr_cursor(); // skip LF
Ok(UnsafeSlice::new(start_ptr, len))
} else {
// just some silly hackery
Err(transmute(has_lf))
}
}
}
/// Attempt to read an `usize` from the buffer
fn read_usize(&mut self) -> ParseResult<usize> {
let line = self.read_line_pedantic()?;
let bytes = unsafe {
// UNSAFE(@ohsayan): We just extracted the slice
line.as_slice()
};
let mut ret = 0usize;
for byte in bytes {
if byte.is_ascii_digit() {
ret = match ret.checked_mul(10) {
Some(r) => r,
None => return Err(ParseError::DatatypeParseFailure),
};
ret = match ret.checked_add((byte & 0x0F) as _) {
Some(r) => r,
None => return Err(ParseError::DatatypeParseFailure),
};
} else {
return Err(ParseError::DatatypeParseFailure);
}
}
Ok(ret)
}
}
// query impls
impl<'a> Parser<'a> {
/// Parse the next simple query. This should have passed the `*` tsymbol
///
/// Simple query structure (tokenized line-by-line):
/// ```text
/// * -> Simple Query Header
/// <n>\n -> Count of elements in the simple query
/// <l0>\n -> Length of element 1
/// <e0> -> element 1 itself
/// <l1>\n -> Length of element 2
/// <e1> -> element 2 itself
/// ...
/// ```
fn _next_simple_query(&mut self) -> ParseResult<HeapArray<UnsafeSlice>> {
let element_count = self.read_usize()?;
unsafe {
let mut data = HeapArray::new_writer(element_count);
for i in 0..element_count {
let element_size = self.read_usize()?;
let element = self.read_until(element_size)?;
data.write_to_index(i, element);
}
Ok(data.finish())
}
}
/// Parse a simple query
fn next_simple_query(&mut self) -> ParseResult<SimpleQuery> {
Ok(SimpleQuery {
data: self._next_simple_query()?,
})
}
/// Parse a pipelined query. This should have passed the `$` tsymbol
///
/// Pipelined query structure (tokenized line-by-line):
/// ```text
/// $ -> Pipeline
/// <n>\n -> Pipeline has n queries
/// <lq0>\n -> Query 1 has 3 elements
/// <lq0e0>\n -> Q1E1 has 3 bytes
/// <q0e0> -> Q1E1 itself
/// <lq0e1>\n -> Q1E2 has 1 byte
/// <q0e1> -> Q1E2 itself
/// <lq0e2>\n -> Q1E3 has 3 bytes
/// <q0e2> -> Q1E3 itself
/// <lq1>\n -> Query 2 has 2 elements
/// <lq1e0>\n -> Q2E1 has 3 bytes
/// <q1e0> -> Q2E1 itself
/// <lq1e1>\n -> Q2E2 has 1 byte
/// <q1e1> -> Q2E2 itself
/// ...
/// ```
///
/// Example:
/// ```text
/// $ -> Pipeline
/// 2\n -> Pipeline has 2 queries
/// 3\n -> Query 1 has 3 elements
/// 3\n -> Q1E1 has 3 bytes
/// SET -> Q1E1 itself
/// 1\n -> Q1E2 has 1 byte
/// x -> Q1E2 itself
/// 3\n -> Q1E3 has 3 bytes
/// 100 -> Q1E3 itself
/// 2\n -> Query 2 has 2 elements
/// 3\n -> Q2E1 has 3 bytes
/// GET -> Q2E1 itself
/// 1\n -> Q2E2 has 1 byte
/// x -> Q2E2 itself
/// ```
fn next_pipeline(&mut self) -> ParseResult<PipelinedQuery> {
let query_count = self.read_usize()?;
unsafe {
let mut queries = HeapArray::new_writer(query_count);
for i in 0..query_count {
let sq = self._next_simple_query()?;
queries.write_to_index(i, sq);
}
Ok(PipelinedQuery {
data: queries.finish(),
})
}
}
fn _parse(&mut self) -> ParseResult<Query> {
if self.not_exhausted() {
unsafe {
let first_byte = self.get_byte_at_cursor();
self.incr_cursor();
let data = match first_byte {
b'*' => {
// a simple query
Query::Simple(self.next_simple_query()?)
}
b'$' => {
// a pipelined query
Query::Pipelined(self.next_pipeline()?)
}
_ => return Err(ParseError::UnexpectedByte),
};
Ok(data)
}
} else {
Err(ParseError::NotEnough)
}
}
pub fn parse(buf: &[u8]) -> ParseResult<(Query, usize)> {
let mut slf = Self::new(buf);
let body = slf._parse()?;
let consumed = slf.cursor_ptr() as usize - buf.as_ptr() as usize;
Ok((body, consumed))
}
pub data: Vec<Vec<Vec<u8>>>,
}

@ -0,0 +1,174 @@
/*
* Created on Tue May 03 2022
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2022, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
use {
super::{ParseError, ParseResult, UnsafeSlice},
core::mem::transmute,
};
/*
NOTE TO SELF (@ohsayan): We split this into three traits because:
- `RawParser` is the only one that is to be implemented. Just provide information about the cursor
- `RawParserMeta` provides information about the buffer based on cursor and end ptr information
- `RawParserExt` provides high-level abstractions over `RawParserMeta`. It is like the "super trait"
These distinctions reduce the likelihood of "accidentally incorrect impls" (we could've easily included
`RawParserMeta` inside `RawParser`).
-- Sayan (May, 2022)
*/
/// The `RawParser` trait has three methods that implementors must define:
///
/// - `cursor_ptr` -> Should point to the current position in the buffer for the parser
/// - `cursor_ptr_mut` -> a mutable reference to the cursor
/// - `data_end_ptr` -> a ptr to one byte past the allocated area of the buffer
///
/// All implementors of `RawParser` get a free implementation for `RawParserMeta` and `RawParserExt`
///
/// # Safety
/// - `cursor_ptr` must point to a valid location in memory
/// - `data_end_ptr` must point to a valid location in memory, in the **same allocated area**
pub(super) unsafe trait RawParser {
fn cursor_ptr(&self) -> *const u8;
fn cursor_ptr_mut(&mut self) -> &mut *const u8;
fn data_end_ptr(&self) -> *const u8;
}
/// The `RawParserMeta` trait builds on top of the `RawParser` trait to provide low-level interactions
/// and information with the parser's buffer. It is implemented for any type that implements the `RawParser`
/// trait. Manual implementation is discouraged
pub(super) trait RawParserMeta: RawParser {
/// Check how many bytes we have left
fn remaining(&self) -> usize {
self.data_end_ptr() as usize - self.cursor_ptr() as usize
}
/// Check if we have `size` bytes remaining
fn has_remaining(&self, size: usize) -> bool {
self.remaining() >= size
}
/// Check if we have exhausted the buffer
fn exhausted(&self) -> bool {
self.cursor_ptr() >= self.data_end_ptr()
}
/// Check if the buffer is not exhausted
fn not_exhausted(&self) -> bool {
self.cursor_ptr() < self.data_end_ptr()
}
/// Attempts to return the byte pointed at by the cursor.
/// WARNING: The same segfault warning
unsafe fn get_byte_at_cursor(&self) -> u8 {
*self.cursor_ptr()
}
/// Increment the cursor by `by` positions
unsafe fn incr_cursor_by(&mut self, by: usize) {
let current = *self.cursor_ptr_mut();
*self.cursor_ptr_mut() = current.add(by);
}
/// Increment the position of the cursor by one position
unsafe fn incr_cursor(&mut self) {
self.incr_cursor_by(1);
}
}
impl<T> RawParserMeta for T where T: RawParser {}
/// `RawParserExt` builds on the `RawParser` and `RawParserMeta` traits to provide high level abstractions
/// like reading lines, or a slice of a given length. It is implemented for any type that
/// implements the `RawParser` trait. Manual implementation is discouraged
pub(super) trait RawParserExt: RawParser + RawParserMeta {
/// Attempt to read `len` bytes
fn read_until(&mut self, len: usize) -> ParseResult<UnsafeSlice> {
if self.has_remaining(len) {
unsafe {
// UNSAFE(@ohsayan): Already verified lengths
let slice = UnsafeSlice::new(self.cursor_ptr(), len);
self.incr_cursor_by(len);
Ok(slice)
}
} else {
Err(ParseError::NotEnough)
}
}
#[cfg(test)]
/// Attempt to read a byte slice terminated by an LF
fn read_line(&mut self) -> ParseResult<UnsafeSlice> {
let start_ptr = self.cursor_ptr();
unsafe {
while self.not_exhausted() && self.get_byte_at_cursor() != b'\n' {
self.incr_cursor();
}
if self.not_exhausted() && self.get_byte_at_cursor() == b'\n' {
let len = self.cursor_ptr() as usize - start_ptr as usize;
self.incr_cursor(); // skip LF
Ok(UnsafeSlice::new(start_ptr, len))
} else {
Err(ParseError::NotEnough)
}
}
}
/// Attempt to read a line, **rejecting an empty payload**
fn read_line_pedantic(&mut self) -> ParseResult<UnsafeSlice> {
let start_ptr = self.cursor_ptr();
unsafe {
while self.not_exhausted() && self.get_byte_at_cursor() != b'\n' {
self.incr_cursor();
}
let len = self.cursor_ptr() as usize - start_ptr as usize;
let has_lf = self.not_exhausted() && self.get_byte_at_cursor() == b'\n';
if has_lf && len != 0 {
self.incr_cursor(); // skip LF
Ok(UnsafeSlice::new(start_ptr, len))
} else {
// transmute the flag into the right error: false (0) => NotEnough, true (1) => BadPacket (an LF was found, but the payload was empty)
Err(transmute(has_lf))
}
}
}
/// Attempt to read a `usize` from the buffer
fn read_usize(&mut self) -> ParseResult<usize> {
let line = self.read_line_pedantic()?;
let bytes = line.as_slice();
let mut ret = 0usize;
for byte in bytes {
if byte.is_ascii_digit() {
ret = match ret.checked_mul(10) {
Some(r) => r,
None => return Err(ParseError::DatatypeParseFailure),
};
ret = match ret.checked_add((byte & 0x0F) as _) {
Some(r) => r,
None => return Err(ParseError::DatatypeParseFailure),
};
} else {
return Err(ParseError::DatatypeParseFailure);
}
}
Ok(ret)
}
}
impl<T> RawParserExt for T where T: RawParser + RawParserMeta {}
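// Illustrative sketch (not part of this diff): the same `read_line_pedantic` + `read_usize`
// logic expressed over an index-based cursor instead of raw pointers. `SliceReader` and
// `ReadError` are hypothetical stand-ins for a `RawParser` implementor and `ParseError`.
struct SliceReader<'a> {
    buf: &'a [u8],
    cursor: usize,
}

#[derive(Debug, PartialEq)]
enum ReadError {
    NotEnough,
    BadPacket,
    DatatypeParseFailure,
}

impl<'a> SliceReader<'a> {
    /// Read one LF-terminated line, rejecting an empty payload (mirrors `read_line_pedantic`)
    fn read_line_pedantic(&mut self) -> Result<&'a [u8], ReadError> {
        let buf = self.buf; // copy the &'a [u8] out so the return value borrows from 'a
        let start = self.cursor;
        while self.cursor < buf.len() && buf[self.cursor] != b'\n' {
            self.cursor += 1;
        }
        let has_lf = self.cursor < buf.len();
        let len = self.cursor - start;
        if has_lf && len != 0 {
            self.cursor += 1; // skip the LF
            Ok(&buf[start..start + len])
        } else if has_lf {
            Err(ReadError::BadPacket) // an LF was found, but the payload was empty
        } else {
            Err(ReadError::NotEnough) // the buffer ran out before an LF
        }
    }
    /// Decimal digits only, with overflow checks (mirrors `read_usize`)
    fn read_usize(&mut self) -> Result<usize, ReadError> {
        let line = self.read_line_pedantic()?;
        let mut ret = 0usize;
        for &byte in line {
            if !byte.is_ascii_digit() {
                return Err(ReadError::DatatypeParseFailure);
            }
            ret = ret
                .checked_mul(10)
                .and_then(|r| r.checked_add((byte & 0x0F) as usize))
                .ok_or(ReadError::DatatypeParseFailure)?;
        }
        Ok(ret)
    }
}

#[test]
fn pedantic_lines_and_sizes() {
    let mut r = SliceReader { buf: &b"3\nSET\n"[..], cursor: 0 };
    assert_eq!(r.read_usize(), Ok(3));
    assert_eq!(r.read_line_pedantic(), Ok(&b"SET"[..]));
}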

@ -1,153 +0,0 @@
/*
* Created on Sat Aug 22 2020
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2020, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
//! Primitives for generating Skyhash compatible responses
pub mod groups {
#![allow(unused)]
//! # Pre-compiled response **elements**
//! These are pre-compiled response groups and **not** complete responses. If complete
//! responses are required, use protocol::responses::fresp
use ::sky_macros::compiled_eresp_bytes as eresp;
/// Response code 0 as an array element
pub const OKAY: &[u8] = eresp!("0");
/// Response code 1 as an array element
pub const NIL: &[u8] = eresp!("1");
/// Response code 2 as an array element
pub const OVERWRITE_ERR: &[u8] = eresp!("2");
/// Response code 3 as an array element
pub const ACTION_ERR: &[u8] = eresp!("3");
/// Response code 4 as an array element
pub const PACKET_ERR: &[u8] = eresp!("4");
/// Response code 5 as an array element
pub const SERVER_ERR: &[u8] = eresp!("5");
/// Response code 6 as an array element
pub const OTHER_ERR_EMPTY: &[u8] = eresp!("6");
/// Response group element for the `heya` action (the string "HEY!")
pub const HEYA: &[u8] = "+4\nHEY!".as_bytes();
/// "Unknown action" error response
pub const UNKNOWN_ACTION: &[u8] = eresp!("Unknown action");
/// Response code 7
pub const WRONGTYPE_ERR: &[u8] = eresp!("7");
/// Response code 8
pub const UNKNOWN_DATA_TYPE: &[u8] = eresp!("8");
/// Response code 9 as an array element
pub const ENCODING_ERROR: &[u8] = eresp!("9");
/// Snapshot busy error
pub const SNAPSHOT_BUSY: &[u8] = eresp!("err-snapshot-busy");
/// Snapshot disabled (other error)
pub const SNAPSHOT_DISABLED: &[u8] = eresp!("err-snapshot-disabled");
/// Duplicate snapshot
pub const SNAPSHOT_DUPLICATE: &[u8] = eresp!("duplicate-snapshot");
/// Snapshot has illegal name (other error)
pub const SNAPSHOT_ILLEGAL_NAME: &[u8] = eresp!("err-invalid-snapshot-name");
/// Access after termination signal (other error)
pub const ERR_ACCESS_AFTER_TERMSIG: &[u8] = eresp!("err-access-after-termsig");
// keyspace related resps
/// The default container was not set
pub const DEFAULT_UNSET: &[u8] = eresp!("default-container-unset");
/// The container was not found
pub const CONTAINER_NOT_FOUND: &[u8] = eresp!("container-not-found");
/// The container is still in use and so cannot be removed
pub const STILL_IN_USE: &[u8] = eresp!("still-in-use");
/// This is a protected object and hence cannot be accessed
pub const PROTECTED_OBJECT: &[u8] = eresp!("err-protected-object");
/// The action was applied against the wrong model
pub const WRONG_MODEL: &[u8] = eresp!("wrong-model");
/// The container already exists
pub const ALREADY_EXISTS: &[u8] = eresp!("err-already-exists");
/// The container is not ready
pub const NOT_READY: &[u8] = eresp!("not-ready");
/// A transactional failure occurred
pub const DDL_TRANSACTIONAL_FAILURE: &[u8] = eresp!("transactional-failure");
/// An unknown DDL query was run
pub const UNKNOWN_DDL_QUERY: &[u8] = eresp!("unknown-ddl-query");
/// The expression for a DDL query was malformed
pub const BAD_EXPRESSION: &[u8] = eresp!("malformed-expression");
/// An unknown model was passed in a DDL query
pub const UNKNOWN_MODEL: &[u8] = eresp!("unknown-model");
/// Too many arguments were passed to model constructor
pub const TOO_MANY_ARGUMENTS: &[u8] = eresp!("too-many-args");
/// The container name is too long
pub const CONTAINER_NAME_TOO_LONG: &[u8] = eresp!("container-name-too-long");
/// The container name contains invalid characters
pub const BAD_CONTAINER_NAME: &[u8] = eresp!("bad-container-name");
/// An unknown inspect query
pub const UNKNOWN_INSPECT_QUERY: &[u8] = eresp!("unknown-inspect-query");
/// An unknown table property was passed
pub const UNKNOWN_PROPERTY: &[u8] = eresp!("unknown-property");
/// The keyspace is not empty and hence cannot be removed
pub const KEYSPACE_NOT_EMPTY: &[u8] = eresp!("keyspace-not-empty");
/// Bad type supplied in a DDL query for the key
pub const BAD_TYPE_FOR_KEY: &[u8] = eresp!("bad-type-for-key");
/// The index for the provided list was non-existent
pub const LISTMAP_BAD_INDEX: &[u8] = eresp!("bad-list-index");
/// The list is empty
pub const LISTMAP_LIST_IS_EMPTY: &[u8] = eresp!("list-is-empty");
}
pub mod full_responses {
#![allow(unused)]
//! # Pre-compiled **responses**
//! These are pre-compiled **complete** responses. This means that they should
//! be written directly to the stream and should **not be preceded by any response metaframe**
/// Response code: 0 (Okay)
pub const R_OKAY: &[u8] = "*!1\n0\n".as_bytes();
/// Response code: 1 (Nil)
pub const R_NIL: &[u8] = "*!1\n1\n".as_bytes();
/// Response code: 2 (Overwrite Error)
pub const R_OVERWRITE_ERR: &[u8] = "*!1\n2\n".as_bytes();
/// Response code: 3 (Action Error)
pub const R_ACTION_ERR: &[u8] = "*!1\n3\n".as_bytes();
/// Response code: 4 (Packet Error)
pub const R_PACKET_ERR: &[u8] = "*!1\n4\n".as_bytes();
/// Response code: 5 (Server Error)
pub const R_SERVER_ERR: &[u8] = "*!1\n5\n".as_bytes();
/// Response code: 6 (Other Error _without description_)
pub const R_OTHER_ERR_EMPTY: &[u8] = "*!1\n6\n".as_bytes();
/// Response code: 7; wrongtype
pub const R_WRONGTYPE_ERR: &[u8] = "*!1\n7".as_bytes();
/// Response code: 8; unknown data type
pub const R_UNKNOWN_DATA_TYPE: &[u8] = "*!1\n8\n".as_bytes();
/// A heya response
pub const R_HEYA: &[u8] = "*+4\nHEY!\n".as_bytes();
/// An other response with description: "Unknown action"
pub const R_UNKNOWN_ACTION: &[u8] = "*!14\nUnknown action\n".as_bytes();
/// A 1 uint64 reply
pub const R_ONE_INT_REPLY: &[u8] = "*:1\n1\n".as_bytes();
/// A 0 uint64 reply
pub const R_ZERO_INT_REPLY: &[u8] = "*:1\n0\n".as_bytes();
/// Snapshot busy (other error)
pub const R_SNAPSHOT_BUSY: &[u8] = "*!17\nerr-snapshot-busy\n".as_bytes();
/// Snapshot disabled (other error)
pub const R_SNAPSHOT_DISABLED: &[u8] = "*!21\nerr-snapshot-disabled\n".as_bytes();
/// Snapshot has illegal name (other error)
pub const R_SNAPSHOT_ILLEGAL_NAME: &[u8] = "*!25\nerr-invalid-snapshot-name\n".as_bytes();
/// Access after termination signal (other error)
pub const R_ERR_ACCESS_AFTER_TERMSIG: &[u8] = "*!24\nerr-access-after-termsig\n".as_bytes();
}

@ -0,0 +1,80 @@
/*
* Created on Mon May 02 2022
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2022, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
extern crate test;
use super::{super::Query, Parser};
use test::Bencher;
#[bench]
fn simple_query(b: &mut Bencher) {
const PAYLOAD: &[u8] = b"*1\n~3\n3\nSET\n1\nx\n3\n100\n";
let expected = vec!["SET".to_owned(), "x".to_owned(), "100".to_owned()];
b.iter(|| {
let (query, forward) = Parser::parse(PAYLOAD).unwrap();
assert_eq!(forward, PAYLOAD.len());
let query = if let Query::Simple(sq) = query {
sq
} else {
panic!("Got pipeline instead of simple query");
};
let ret: Vec<String> = query
.as_slice()
.iter()
.map(|s| String::from_utf8_lossy(s.as_slice()).to_string())
.collect();
assert_eq!(ret, expected)
});
}
#[bench]
fn pipelined_query(b: &mut Bencher) {
const PAYLOAD: &[u8] = b"*2\n~3\n3\nSET\n1\nx\n3\n100\n~2\n3\nGET\n1\nx\n";
let expected = vec![
vec!["SET".to_owned(), "x".to_owned(), "100".to_owned()],
vec!["GET".to_owned(), "x".to_owned()],
];
b.iter(|| {
let (query, forward) = Parser::parse(PAYLOAD).unwrap();
assert_eq!(forward, PAYLOAD.len());
let query = if let Query::Pipelined(sq) = query {
sq
} else {
panic!("Got simple instead of pipeline query");
};
let ret: Vec<Vec<String>> = query
.into_inner()
.iter()
.map(|query| {
query
.as_slice()
.iter()
.map(|v| String::from_utf8_lossy(v.as_slice()).to_string())
.collect()
})
.collect();
assert_eq!(ret, expected)
});
}

@ -0,0 +1,288 @@
/*
* Created on Mon May 02 2022
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2022, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
use {
crate::{
corestore::buffers::Integer64,
dbnet::connection::{QueryWithAdvance, RawConnection, Stream},
protocol::{
interface::{ProtocolRead, ProtocolSpec, ProtocolWrite},
ParseError, Skyhash1,
},
util::FutureResult,
IoResult,
},
::sky_macros::compiled_eresp_bytes_v1 as eresp,
tokio::io::AsyncWriteExt,
};
impl ProtocolSpec for Skyhash1 {
// spec information
const PROTOCOL_VERSION: f32 = 1.0;
const PROTOCOL_VERSIONSTRING: &'static str = "Skyhash-1.0";
// type symbols
const TSYMBOL_STRING: u8 = b'+';
const TSYMBOL_BINARY: u8 = b'?';
const TSYMBOL_FLOAT: u8 = b'%';
const TSYMBOL_INT64: u8 = b':';
const TSYMBOL_TYPED_ARRAY: u8 = b'@';
const TSYMBOL_TYPED_NON_NULL_ARRAY: u8 = b'^';
const TSYMBOL_ARRAY: u8 = b'&';
const TSYMBOL_FLAT_ARRAY: u8 = b'_';
// typed array
const TYPE_TYPED_ARRAY_ELEMENT_NULL: &'static [u8] = b"\0";
// metaframe
const SIMPLE_QUERY_HEADER: &'static [u8] = b"*";
const PIPELINED_QUERY_FIRST_BYTE: u8 = b'$';
// respcodes
const RCODE_OKAY: &'static [u8] = eresp!("0");
const RCODE_NIL: &'static [u8] = eresp!("1");
const RCODE_OVERWRITE_ERR: &'static [u8] = eresp!("2");
const RCODE_ACTION_ERR: &'static [u8] = eresp!("3");
const RCODE_PACKET_ERR: &'static [u8] = eresp!("4");
const RCODE_SERVER_ERR: &'static [u8] = eresp!("5");
const RCODE_OTHER_ERR_EMPTY: &'static [u8] = eresp!("6");
const RCODE_UNKNOWN_ACTION: &'static [u8] = eresp!("Unknown action");
const RCODE_WRONGTYPE_ERR: &'static [u8] = eresp!("7");
const RCODE_UNKNOWN_DATA_TYPE: &'static [u8] = eresp!("8");
const RCODE_ENCODING_ERROR: &'static [u8] = eresp!("9");
// respstrings
const RSTRING_SNAPSHOT_BUSY: &'static [u8] = eresp!("err-snapshot-busy");
const RSTRING_SNAPSHOT_DISABLED: &'static [u8] = eresp!("err-snapshot-disabled");
const RSTRING_SNAPSHOT_DUPLICATE: &'static [u8] = eresp!("duplicate-snapshot");
const RSTRING_SNAPSHOT_ILLEGAL_NAME: &'static [u8] = eresp!("err-invalid-snapshot-name");
const RSTRING_ERR_ACCESS_AFTER_TERMSIG: &'static [u8] = eresp!("err-access-after-termsig");
// keyspace related resps
const RSTRING_DEFAULT_UNSET: &'static [u8] = eresp!("default-container-unset");
const RSTRING_CONTAINER_NOT_FOUND: &'static [u8] = eresp!("container-not-found");
const RSTRING_STILL_IN_USE: &'static [u8] = eresp!("still-in-use");
const RSTRING_PROTECTED_OBJECT: &'static [u8] = eresp!("err-protected-object");
const RSTRING_WRONG_MODEL: &'static [u8] = eresp!("wrong-model");
const RSTRING_ALREADY_EXISTS: &'static [u8] = eresp!("err-already-exists");
const RSTRING_NOT_READY: &'static [u8] = eresp!("not-ready");
const RSTRING_DDL_TRANSACTIONAL_FAILURE: &'static [u8] = eresp!("transactional-failure");
const RSTRING_UNKNOWN_DDL_QUERY: &'static [u8] = eresp!("unknown-ddl-query");
const RSTRING_BAD_EXPRESSION: &'static [u8] = eresp!("malformed-expression");
const RSTRING_UNKNOWN_MODEL: &'static [u8] = eresp!("unknown-model");
const RSTRING_TOO_MANY_ARGUMENTS: &'static [u8] = eresp!("too-many-args");
const RSTRING_CONTAINER_NAME_TOO_LONG: &'static [u8] = eresp!("container-name-too-long");
const RSTRING_BAD_CONTAINER_NAME: &'static [u8] = eresp!("bad-container-name");
const RSTRING_UNKNOWN_INSPECT_QUERY: &'static [u8] = eresp!("unknown-inspect-query");
const RSTRING_UNKNOWN_PROPERTY: &'static [u8] = eresp!("unknown-property");
const RSTRING_KEYSPACE_NOT_EMPTY: &'static [u8] = eresp!("keyspace-not-empty");
const RSTRING_BAD_TYPE_FOR_KEY: &'static [u8] = eresp!("bad-type-for-key");
const RSTRING_LISTMAP_BAD_INDEX: &'static [u8] = eresp!("bad-list-index");
const RSTRING_LISTMAP_LIST_IS_EMPTY: &'static [u8] = eresp!("list-is-empty");
// elements
const ELEMRESP_HEYA: &'static [u8] = b"+4\nHEY!\n";
// full responses
const FULLRESP_RCODE_PACKET_ERR: &'static [u8] = b"*1\n!1\n4\n";
const FULLRESP_RCODE_WRONG_TYPE: &'static [u8] = b"*1\n!1\n7\n";
// auth rcodes/strings
const AUTH_ERROR_ALREADYCLAIMED: &'static [u8] = eresp!("err-auth-already-claimed");
const AUTH_CODE_BAD_CREDENTIALS: &'static [u8] = eresp!("10");
const AUTH_ERROR_DISABLED: &'static [u8] = eresp!("err-auth-disabled");
const AUTH_CODE_PERMS: &'static [u8] = eresp!("11");
const AUTH_ERROR_ILLEGAL_USERNAME: &'static [u8] = eresp!("err-auth-illegal-username");
const AUTH_ERROR_FAILED_TO_DELETE_USER: &'static [u8] = eresp!("err-auth-deluser-fail");
}
impl<Strm, T> ProtocolRead<Skyhash1, Strm> for T
where
T: RawConnection<Skyhash1, Strm> + Send + Sync,
Strm: Stream,
{
fn try_query(&self) -> Result<QueryWithAdvance, ParseError> {
Skyhash1::parse(self.get_buffer())
}
}
impl<Strm, T> ProtocolWrite<Skyhash1, Strm> for T
where
T: RawConnection<Skyhash1, Strm> + Send + Sync,
Strm: Stream,
{
fn write_mono_length_prefixed_with_tsymbol<'life0, 'life1, 'ret_life>(
&'life0 mut self,
data: &'life1 [u8],
tsymbol: u8,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
'life1: 'ret_life,
Self: Send + 'ret_life,
{
Box::pin(async move {
let stream = self.get_mut_stream();
// <tsymbol><length><lf>
stream.write_all(&[tsymbol]).await?;
stream.write_all(&Integer64::from(data.len())).await?;
stream.write_all(&[Skyhash1::LF]).await?;
// <data><lf>
stream.write_all(data).await?;
stream.write_all(&[Skyhash1::LF]).await
})
}
fn write_string<'life0, 'life1, 'ret_life>(
&'life0 mut self,
string: &'life1 str,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
'life1: 'ret_life,
Self: 'ret_life,
{
Box::pin(async move {
let stream = self.get_mut_stream();
// tsymbol
stream.write_all(&[Skyhash1::TSYMBOL_STRING]).await?;
// length
let len_bytes = Integer64::from(string.len());
stream.write_all(&len_bytes).await?;
// LF
stream.write_all(&[Skyhash1::LF]).await?;
// payload
stream.write_all(string.as_bytes()).await?;
// final LF
stream.write_all(&[Skyhash1::LF]).await
})
}
fn write_binary<'life0, 'life1, 'ret_life>(
&'life0 mut self,
binary: &'life1 [u8],
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
'life1: 'ret_life,
Self: 'ret_life,
{
Box::pin(async move {
let stream = self.get_mut_stream();
// tsymbol
stream.write_all(&[Skyhash1::TSYMBOL_BINARY]).await?;
// length
let len_bytes = Integer64::from(binary.len());
stream.write_all(&len_bytes).await?;
// LF
stream.write_all(&[Skyhash1::LF]).await?;
// payload
stream.write_all(binary).await?;
// final LF
stream.write_all(&[Skyhash1::LF]).await
})
}
fn write_usize<'life0, 'ret_life>(
&'life0 mut self,
size: usize,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
Self: 'ret_life,
{
Box::pin(async move { self.write_int64(size as _).await })
}
fn write_int64<'life0, 'ret_life>(
&'life0 mut self,
int: u64,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
Self: 'ret_life,
{
Box::pin(async move {
let stream = self.get_mut_stream();
// tsymbol
stream.write_all(&[Skyhash1::TSYMBOL_INT64]).await?;
// get body and sizeline
let body = Integer64::from(int);
let body_len = Integer64::from(body.len());
// len of body
stream.write_all(&body_len).await?;
// sizeline LF
stream.write_all(&[Skyhash1::LF]).await?;
// body
stream.write_all(&body).await?;
// LF
stream.write_all(&[Skyhash1::LF]).await
})
}
fn write_float<'life0, 'ret_life>(
&'life0 mut self,
float: f32,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
Self: 'ret_life,
{
Box::pin(async move {
let stream = self.get_mut_stream();
// tsymbol
stream.write_all(&[Skyhash1::TSYMBOL_FLOAT]).await?;
// get body and sizeline
let body = float.to_string();
let body = body.as_bytes();
let sizeline = Integer64::from(body.len());
// sizeline
stream.write_all(&sizeline).await?;
// sizeline LF
stream.write_all(&[Skyhash1::LF]).await?;
// body
stream.write_all(body).await?;
// LF
stream.write_all(&[Skyhash1::LF]).await
})
}
fn write_typed_array_element<'life0, 'life1, 'ret_life>(
&'life0 mut self,
element: &'life1 [u8],
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
'life1: 'ret_life,
Self: 'ret_life,
{
Box::pin(async move {
let stream = self.get_mut_stream();
// len
stream.write_all(&Integer64::from(element.len())).await?;
// LF
stream.write_all(&[Skyhash1::LF]).await?;
// body
stream.write_all(element).await?;
// LF
stream.write_all(&[Skyhash1::LF]).await
})
}
}
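// Illustrative sketch (not part of this diff): the monoelement framing produced by the
// Skyhash 1.0 writers above is `<tsymbol><length>\n<payload>\n`, i.e. every element carries a
// terminal LF. `frame_string_v1` is a hypothetical helper mirroring `write_string`, not a
// function in this codebase.
fn frame_string_v1(payload: &str) -> Vec<u8> {
    let mut out = Vec::new();
    out.push(b'+'); // TSYMBOL_STRING
    out.extend_from_slice(payload.len().to_string().as_bytes());
    out.push(b'\n');
    out.extend_from_slice(payload.as_bytes());
    out.push(b'\n'); // Skyhash 1.0 keeps a terminal LF (compare with the 2.0 impl below)
    out
}

#[test]
fn v1_string_framing_matches_elemresp_heya() {
    assert_eq!(frame_string_v1("HEY!"), b"+4\nHEY!\n".to_vec());
}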

@ -0,0 +1,241 @@
/*
* Created on Sat Apr 30 2022
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2022, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
use {
super::{
raw_parser::{RawParser, RawParserExt, RawParserMeta},
ParseError, ParseResult, PipelinedQuery, Query, SimpleQuery, UnsafeSlice,
},
crate::{
corestore::heap_array::{HeapArray, HeapArrayWriter},
dbnet::connection::QueryWithAdvance,
},
};
mod interface_impls;
// test and bench modules
#[cfg(feature = "nightly")]
mod benches;
#[cfg(test)]
mod tests;
/// A parser for Skyhash 1.0
///
/// Packet structure example (simple query):
/// ```text
/// *1\n
/// ~3\n
/// 3\n
/// SET\n
/// 1\n
/// x\n
/// 3\n
/// 100\n
/// ```
pub struct Parser {
end: *const u8,
cursor: *const u8,
}
unsafe impl RawParser for Parser {
fn cursor_ptr(&self) -> *const u8 {
self.cursor
}
fn cursor_ptr_mut(&mut self) -> &mut *const u8 {
&mut self.cursor
}
fn data_end_ptr(&self) -> *const u8 {
self.end
}
}
unsafe impl Send for Parser {}
unsafe impl Sync for Parser {}
impl Parser {
/// Initialize a new parser
fn new(slice: &[u8]) -> Self {
unsafe {
Self {
end: slice.as_ptr().add(slice.len()),
cursor: slice.as_ptr(),
}
}
}
}
// utility methods
impl Parser {
/// Returns true if the byte at the cursor equals `ch`. If the buffer is exhausted, this
/// returns true when `true_if_nothing_ahead` is set, and errors with `NotEnough` otherwise
fn will_cursor_give_char(&self, ch: u8, true_if_nothing_ahead: bool) -> ParseResult<bool> {
if self.exhausted() {
// nothing left
if true_if_nothing_ahead {
Ok(true)
} else {
Err(ParseError::NotEnough)
}
} else if unsafe { self.get_byte_at_cursor().eq(&ch) } {
Ok(true)
} else {
Ok(false)
}
}
/// Check if the current cursor will give an LF
fn will_cursor_give_linefeed(&self) -> ParseResult<bool> {
self.will_cursor_give_char(b'\n', false)
}
/// Gets the next element. **The tsymbol should already have been passed**
fn _next(&mut self) -> ParseResult<UnsafeSlice> {
let element_size = self.read_usize()?;
self.read_until(element_size)
}
}
// higher level abstractions
impl Parser {
/// Parse the next blob. **The tsymbol should already have been passed**
fn parse_next_blob(&mut self) -> ParseResult<UnsafeSlice> {
{
let chunk = self._next()?;
if self.will_cursor_give_linefeed()? {
unsafe {
// UNSAFE(@ohsayan): We know that the buffer is not exhausted
// due to the above condition
self.incr_cursor();
}
Ok(chunk)
} else {
Err(ParseError::UnexpectedByte)
}
}
}
}
// query abstractions
impl Parser {
/// The buffer should resemble the below structure:
/// ```text
/// ~<count>\n
/// <e0l0>\n
/// <e0>\n
/// <e1l1>\n
/// <e1>\n
/// ...
/// ```
fn _parse_simple_query(&mut self) -> ParseResult<HeapArray<UnsafeSlice>> {
if self.not_exhausted() {
if unsafe { self.get_byte_at_cursor() } != b'~' {
// we need an any array
return Err(ParseError::WrongType);
}
unsafe {
// UNSAFE(@ohsayan): Just checked length
self.incr_cursor();
}
let query_count = self.read_usize()?;
let mut writer = HeapArrayWriter::with_capacity(query_count);
for i in 0..query_count {
unsafe {
// UNSAFE(@ohsayan): The index of the for loop ensures that
// we never attempt to write to a bad memory location
writer.write_to_index(i, self.parse_next_blob()?);
}
}
Ok(unsafe {
// UNSAFE(@ohsayan): If we've reached here, then we have initialized
// all the queries
writer.finish()
})
} else {
Err(ParseError::NotEnough)
}
}
fn parse_simple_query(&mut self) -> ParseResult<SimpleQuery> {
Ok(SimpleQuery::new(self._parse_simple_query()?))
}
/// The buffer should resemble the following structure:
/// ```text
/// # query 1
/// ~<count>\n
/// <e0l0>\n
/// <e0>\n
/// <e1l1>\n
/// <e1>\n
/// # query 2
/// ~<count>\n
/// <e0l0>\n
/// <e0>\n
/// <e1l1>\n
/// <e1>\n
/// ...
/// ```
fn parse_pipelined_query(&mut self, length: usize) -> ParseResult<PipelinedQuery> {
let mut writer = HeapArrayWriter::with_capacity(length);
for i in 0..length {
unsafe {
// UNSAFE(@ohsayan): The above condition guarantees that the index
// never causes an overflow
writer.write_to_index(i, self._parse_simple_query()?);
}
}
unsafe {
// UNSAFE(@ohsayan): if we reached here, then we have inited everything
Ok(PipelinedQuery::new(writer.finish()))
}
}
fn _parse(&mut self) -> ParseResult<Query> {
if self.not_exhausted() {
let first_byte = unsafe {
// UNSAFE(@ohsayan): Just checked if buffer is exhausted or not
self.get_byte_at_cursor()
};
if first_byte != b'*' {
// unknown query scheme, so it's a bad packet
return Err(ParseError::BadPacket);
}
unsafe {
// UNSAFE(@ohsayan): Checked buffer len and incremented, so we're good
self.incr_cursor()
};
let query_count = self.read_usize()?; // get the length
if query_count == 1 {
Ok(Query::Simple(self.parse_simple_query()?))
} else {
Ok(Query::Pipelined(self.parse_pipelined_query(query_count)?))
}
} else {
Err(ParseError::NotEnough)
}
}
pub fn parse(buf: &[u8]) -> ParseResult<QueryWithAdvance> {
let mut slf = Self::new(buf);
let body = slf._parse()?;
let consumed = slf.cursor_ptr() as usize - buf.as_ptr() as usize;
Ok((body, consumed))
}
}

@ -0,0 +1,93 @@
/*
* Created on Mon May 02 2022
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2022, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
use {
super::Parser,
crate::protocol::{ParseError, Query},
};
#[cfg(test)]
const SQPAYLOAD: &[u8] = b"*1\n~3\n3\nSET\n1\nx\n3\n100\n";
#[cfg(test)]
const PQPAYLOAD: &[u8] = b"*2\n~3\n3\nSET\n1\nx\n3\n100\n~2\n3\nGET\n1\nx\n";
#[test]
fn parse_simple_query() {
let payload = SQPAYLOAD.to_vec();
let (q, f) = Parser::parse(&payload).unwrap();
let q: Vec<String> = if let Query::Simple(q) = q {
q.as_slice()
.iter()
.map(|v| String::from_utf8_lossy(v.as_slice()).to_string())
.collect()
} else {
panic!("Expected simple query")
};
assert_eq!(f, payload.len());
assert_eq!(q, vec!["SET".to_owned(), "x".into(), "100".into()]);
}
#[test]
fn parse_simple_query_incomplete() {
for i in 0..SQPAYLOAD.len() - 1 {
let slice = &SQPAYLOAD[..i];
assert_eq!(Parser::parse(slice).unwrap_err(), ParseError::NotEnough);
}
}
#[test]
fn parse_pipelined_query() {
let payload = PQPAYLOAD.to_vec();
let (q, f) = Parser::parse(&payload).unwrap();
let q: Vec<Vec<String>> = if let Query::Pipelined(q) = q {
q.into_inner()
.iter()
.map(|sq| {
sq.iter()
.map(|v| String::from_utf8_lossy(v.as_slice()).to_string())
.collect()
})
.collect()
} else {
panic!("Expected pipelined query query")
};
assert_eq!(f, payload.len());
assert_eq!(
q,
vec![
vec!["SET".to_owned(), "x".into(), "100".into()],
vec!["GET".into(), "x".into()]
]
);
}
#[test]
fn parse_pipelined_query_incomplete() {
for i in 0..PQPAYLOAD.len() - 1 {
let slice = &PQPAYLOAD[..i];
assert_eq!(Parser::parse(slice).unwrap_err(), ParseError::NotEnough);
}
}

@ -0,0 +1,80 @@
/*
* Created on Sat Apr 30 2022
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2022, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
extern crate test;
use super::{super::Query, Parser};
use test::Bencher;
#[bench]
fn simple_query(b: &mut Bencher) {
const PAYLOAD: &[u8] = b"*3\n3\nSET1\nx3\n100";
let expected = vec!["SET".to_owned(), "x".to_owned(), "100".to_owned()];
b.iter(|| {
let (query, forward) = Parser::parse(PAYLOAD).unwrap();
assert_eq!(forward, PAYLOAD.len());
let query = if let Query::Simple(sq) = query {
sq
} else {
panic!("Got pipeline instead of simple query");
};
let ret: Vec<String> = query
.as_slice()
.iter()
.map(|s| String::from_utf8_lossy(s.as_slice()).to_string())
.collect();
assert_eq!(ret, expected)
});
}
#[bench]
fn pipelined_query(b: &mut Bencher) {
const PAYLOAD: &[u8] = b"$2\n3\n3\nSET1\nx3\n1002\n3\nGET1\nx";
let expected = vec![
vec!["SET".to_owned(), "x".to_owned(), "100".to_owned()],
vec!["GET".to_owned(), "x".to_owned()],
];
b.iter(|| {
let (query, forward) = Parser::parse(PAYLOAD).unwrap();
assert_eq!(forward, PAYLOAD.len());
let query = if let Query::Pipelined(sq) = query {
sq
} else {
panic!("Got simple instead of pipeline query");
};
let ret: Vec<Vec<String>> = query
.into_inner()
.iter()
.map(|query| {
query
.as_slice()
.iter()
.map(|v| String::from_utf8_lossy(v.as_slice()).to_string())
.collect()
})
.collect();
assert_eq!(ret, expected)
});
}

@ -0,0 +1,263 @@
/*
* Created on Sat Apr 30 2022
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2022, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
use crate::{
corestore::buffers::Integer64,
dbnet::connection::{QueryWithAdvance, RawConnection, Stream},
protocol::{
interface::{ProtocolRead, ProtocolSpec, ProtocolWrite},
ParseError, Skyhash2,
},
util::FutureResult,
IoResult,
};
use ::sky_macros::compiled_eresp_bytes as eresp;
use tokio::io::AsyncWriteExt;
impl ProtocolSpec for Skyhash2 {
// spec information
const PROTOCOL_VERSION: f32 = 2.0;
const PROTOCOL_VERSIONSTRING: &'static str = "Skyhash-2.0";
// type symbols
const TSYMBOL_STRING: u8 = b'+';
const TSYMBOL_BINARY: u8 = b'?';
const TSYMBOL_FLOAT: u8 = b'%';
const TSYMBOL_INT64: u8 = b':';
const TSYMBOL_TYPED_ARRAY: u8 = b'@';
const TSYMBOL_TYPED_NON_NULL_ARRAY: u8 = b'^';
const TSYMBOL_ARRAY: u8 = b'&';
const TSYMBOL_FLAT_ARRAY: u8 = b'_';
// typed array
const TYPE_TYPED_ARRAY_ELEMENT_NULL: &'static [u8] = b"\0";
// metaframe
const SIMPLE_QUERY_HEADER: &'static [u8] = b"*";
const PIPELINED_QUERY_FIRST_BYTE: u8 = b'$';
// respcodes
const RCODE_OKAY: &'static [u8] = eresp!("0");
const RCODE_NIL: &'static [u8] = eresp!("1");
const RCODE_OVERWRITE_ERR: &'static [u8] = eresp!("2");
const RCODE_ACTION_ERR: &'static [u8] = eresp!("3");
const RCODE_PACKET_ERR: &'static [u8] = eresp!("4");
const RCODE_SERVER_ERR: &'static [u8] = eresp!("5");
const RCODE_OTHER_ERR_EMPTY: &'static [u8] = eresp!("6");
const RCODE_UNKNOWN_ACTION: &'static [u8] = eresp!("Unknown action");
const RCODE_WRONGTYPE_ERR: &'static [u8] = eresp!("7");
const RCODE_UNKNOWN_DATA_TYPE: &'static [u8] = eresp!("8");
const RCODE_ENCODING_ERROR: &'static [u8] = eresp!("9");
// respstrings
const RSTRING_SNAPSHOT_BUSY: &'static [u8] = eresp!("err-snapshot-busy");
const RSTRING_SNAPSHOT_DISABLED: &'static [u8] = eresp!("err-snapshot-disabled");
const RSTRING_SNAPSHOT_DUPLICATE: &'static [u8] = eresp!("duplicate-snapshot");
const RSTRING_SNAPSHOT_ILLEGAL_NAME: &'static [u8] = eresp!("err-invalid-snapshot-name");
const RSTRING_ERR_ACCESS_AFTER_TERMSIG: &'static [u8] = eresp!("err-access-after-termsig");
// keyspace related resps
const RSTRING_DEFAULT_UNSET: &'static [u8] = eresp!("default-container-unset");
const RSTRING_CONTAINER_NOT_FOUND: &'static [u8] = eresp!("container-not-found");
const RSTRING_STILL_IN_USE: &'static [u8] = eresp!("still-in-use");
const RSTRING_PROTECTED_OBJECT: &'static [u8] = eresp!("err-protected-object");
const RSTRING_WRONG_MODEL: &'static [u8] = eresp!("wrong-model");
const RSTRING_ALREADY_EXISTS: &'static [u8] = eresp!("err-already-exists");
const RSTRING_NOT_READY: &'static [u8] = eresp!("not-ready");
const RSTRING_DDL_TRANSACTIONAL_FAILURE: &'static [u8] = eresp!("transactional-failure");
const RSTRING_UNKNOWN_DDL_QUERY: &'static [u8] = eresp!("unknown-ddl-query");
const RSTRING_BAD_EXPRESSION: &'static [u8] = eresp!("malformed-expression");
const RSTRING_UNKNOWN_MODEL: &'static [u8] = eresp!("unknown-model");
const RSTRING_TOO_MANY_ARGUMENTS: &'static [u8] = eresp!("too-many-args");
const RSTRING_CONTAINER_NAME_TOO_LONG: &'static [u8] = eresp!("container-name-too-long");
const RSTRING_BAD_CONTAINER_NAME: &'static [u8] = eresp!("bad-container-name");
const RSTRING_UNKNOWN_INSPECT_QUERY: &'static [u8] = eresp!("unknown-inspect-query");
const RSTRING_UNKNOWN_PROPERTY: &'static [u8] = eresp!("unknown-property");
const RSTRING_KEYSPACE_NOT_EMPTY: &'static [u8] = eresp!("keyspace-not-empty");
const RSTRING_BAD_TYPE_FOR_KEY: &'static [u8] = eresp!("bad-type-for-key");
const RSTRING_LISTMAP_BAD_INDEX: &'static [u8] = eresp!("bad-list-index");
const RSTRING_LISTMAP_LIST_IS_EMPTY: &'static [u8] = eresp!("list-is-empty");
// elements
const ELEMRESP_HEYA: &'static [u8] = b"+4\nHEY!";
// full responses
const FULLRESP_RCODE_PACKET_ERR: &'static [u8] = b"*!4\n";
const FULLRESP_RCODE_WRONG_TYPE: &'static [u8] = b"*!7\n";
// auth respcodes/strings
const AUTH_ERROR_ALREADYCLAIMED: &'static [u8] = eresp!("err-auth-already-claimed");
const AUTH_CODE_BAD_CREDENTIALS: &'static [u8] = eresp!("10");
const AUTH_ERROR_DISABLED: &'static [u8] = eresp!("err-auth-disabled");
const AUTH_CODE_PERMS: &'static [u8] = eresp!("11");
const AUTH_ERROR_ILLEGAL_USERNAME: &'static [u8] = eresp!("err-auth-illegal-username");
const AUTH_ERROR_FAILED_TO_DELETE_USER: &'static [u8] = eresp!("err-auth-deluser-fail");
}
impl<Strm, T> ProtocolRead<Skyhash2, Strm> for T
where
T: RawConnection<Skyhash2, Strm> + Send + Sync,
Strm: Stream,
{
fn try_query(&self) -> Result<QueryWithAdvance, ParseError> {
Skyhash2::parse(self.get_buffer())
}
}
impl<Strm, T> ProtocolWrite<Skyhash2, Strm> for T
where
T: RawConnection<Skyhash2, Strm> + Send + Sync,
Strm: Stream,
{
fn write_mono_length_prefixed_with_tsymbol<'life0, 'life1, 'ret_life>(
&'life0 mut self,
data: &'life1 [u8],
tsymbol: u8,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
'life1: 'ret_life,
Self: Send + 'ret_life,
{
Box::pin(async move {
let stream = self.get_mut_stream();
// <tsymbol><length><lf>
stream.write_all(&[tsymbol]).await?;
stream.write_all(&Integer64::from(data.len())).await?;
stream.write_all(&[Skyhash2::LF]).await?;
stream.write_all(data).await
})
}
fn write_string<'life0, 'life1, 'ret_life>(
&'life0 mut self,
string: &'life1 str,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
'life1: 'ret_life,
Self: 'ret_life,
{
Box::pin(async move {
let stream = self.get_mut_stream();
// tsymbol
stream.write_all(&[Skyhash2::TSYMBOL_STRING]).await?;
// length
let len_bytes = Integer64::from(string.len());
stream.write_all(&len_bytes).await?;
// LF
stream.write_all(&[Skyhash2::LF]).await?;
// payload
stream.write_all(string.as_bytes()).await
})
}
fn write_binary<'life0, 'life1, 'ret_life>(
&'life0 mut self,
binary: &'life1 [u8],
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
'life1: 'ret_life,
Self: 'ret_life,
{
Box::pin(async move {
let stream = self.get_mut_stream();
// tsymbol
stream.write_all(&[Skyhash2::TSYMBOL_BINARY]).await?;
// length
let len_bytes = Integer64::from(binary.len());
stream.write_all(&len_bytes).await?;
// LF
stream.write_all(&[Skyhash2::LF]).await?;
// payload
stream.write_all(binary).await
})
}
fn write_usize<'life0, 'ret_life>(
&'life0 mut self,
size: usize,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
Self: 'ret_life,
{
Box::pin(async move { self.write_int64(size as _).await })
}
fn write_int64<'life0, 'ret_life>(
&'life0 mut self,
int: u64,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
Self: 'ret_life,
{
Box::pin(async move {
let stream = self.get_mut_stream();
// tsymbol
stream.write_all(&[Skyhash2::TSYMBOL_INT64]).await?;
// body
stream.write_all(&Integer64::from(int)).await?;
// LF
stream.write_all(&[Skyhash2::LF]).await
})
}
fn write_float<'life0, 'ret_life>(
&'life0 mut self,
float: f32,
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
Self: 'ret_life,
{
Box::pin(async move {
let stream = self.get_mut_stream();
// tsymbol
stream.write_all(&[Skyhash2::TSYMBOL_FLOAT]).await?;
// body
stream.write_all(float.to_string().as_bytes()).await?;
// LF
stream.write_all(&[Skyhash2::LF]).await
})
}
fn write_typed_array_element<'life0, 'life1, 'ret_life>(
&'life0 mut self,
element: &'life1 [u8],
) -> FutureResult<'ret_life, IoResult<()>>
where
'life0: 'ret_life,
'life1: 'ret_life,
Self: 'ret_life,
{
Box::pin(async move {
let stream = self.get_mut_stream();
// len
stream.write_all(&Integer64::from(element.len())).await?;
// LF
stream.write_all(&[Skyhash2::LF]).await?;
// body
stream.write_all(element).await
})
}
}

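For reference, here is a standalone sketch of the monoelement framing implemented above: length-prefixed elements are written as `<tsymbol><length>\n<payload>`, while integers carry only `<tsymbol><digits>\n`. The helper names are illustrative and not part of this commit, and the `:` integer tsymbol is assumed from the Skyhash 1.0 `Writable` impls removed later in this diff.
// Illustrative only: mirrors the byte layout produced by write_string/write_int64 above.
fn encode_string(s: &str) -> Vec<u8> {
    // the string tsymbol is `+` (see ELEMRESP_HEYA above), then <len>\n<payload>
    let mut out = vec![b'+'];
    out.extend_from_slice(s.len().to_string().as_bytes());
    out.push(b'\n');
    out.extend_from_slice(s.as_bytes());
    out
}
fn encode_int64(int: u64) -> Vec<u8> {
    // assumed `:` tsymbol; integers are digits terminated by LF, with no length prefix
    let mut out = vec![b':'];
    out.extend_from_slice(int.to_string().as_bytes());
    out.push(b'\n');
    out
}
fn main() {
    // matches the ELEMRESP_HEYA constant: a 4-byte string is framed as "+4\nHEY!"
    assert_eq!(encode_string("HEY!"), b"+4\nHEY!");
    assert_eq!(encode_int64(100), b":100\n");
}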
@ -0,0 +1,186 @@
/*
* Created on Fri Apr 29 2022
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2022, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
mod interface_impls;
use {
super::{
raw_parser::{RawParser, RawParserExt, RawParserMeta},
ParseError, ParseResult, PipelinedQuery, Query, SimpleQuery, UnsafeSlice,
},
crate::{corestore::heap_array::HeapArray, dbnet::connection::QueryWithAdvance},
};
#[cfg(feature = "nightly")]
mod benches;
#[cfg(test)]
mod tests;
/// A parser for Skyhash 2.0
pub struct Parser {
end: *const u8,
cursor: *const u8,
}
unsafe impl RawParser for Parser {
fn cursor_ptr(&self) -> *const u8 {
self.cursor
}
fn cursor_ptr_mut(&mut self) -> &mut *const u8 {
&mut self.cursor
}
fn data_end_ptr(&self) -> *const u8 {
self.end
}
}
unsafe impl Sync for Parser {}
unsafe impl Send for Parser {}
impl Parser {
/// Initialize a new parser
fn new(slice: &[u8]) -> Self {
unsafe {
Self {
end: slice.as_ptr().add(slice.len()),
cursor: slice.as_ptr(),
}
}
}
}
// query impls
impl Parser {
/// Parse the next simple query. This should be called only after the `*` tsymbol has been read
///
/// Simple query structure (tokenized line-by-line):
/// ```text
/// * -> Simple Query Header
/// <n>\n -> Count of elements in the simple query
/// <l0>\n -> Length of element 1
/// <e0> -> element 1 itself
/// <l1>\n -> Length of element 2
/// <e1> -> element 2 itself
/// ...
/// ```
fn _next_simple_query(&mut self) -> ParseResult<HeapArray<UnsafeSlice>> {
let element_count = self.read_usize()?;
unsafe {
let mut data = HeapArray::new_writer(element_count);
for i in 0..element_count {
let element_size = self.read_usize()?;
let element = self.read_until(element_size)?;
data.write_to_index(i, element);
}
Ok(data.finish())
}
}
/// Parse a simple query
fn next_simple_query(&mut self) -> ParseResult<SimpleQuery> {
Ok(SimpleQuery::new(self._next_simple_query()?))
}
/// Parse a pipelined query. This should be called only after the `$` tsymbol has been read
///
/// Pipelined query structure (tokenized line-by-line):
/// ```text
/// $ -> Pipeline
/// <n>\n -> Pipeline has n queries
/// <lq0>\n -> Number of elements in query 1
/// <lq0e0>\n -> Length of Q1E1
/// <q0e0> -> Q1E1 itself
/// <lq0e1>\n -> Length of Q1E2
/// <q0e1> -> Q1E2 itself
/// <lq0e2>\n -> Length of Q1E3
/// <q0e2> -> Q1E3 itself
/// <lq1>\n -> Number of elements in query 2
/// <lq1e0>\n -> Length of Q2E1
/// <q1e0> -> Q2E1 itself
/// <lq1e1>\n -> Length of Q2E2
/// <q1e1> -> Q2E2 itself
/// ...
/// ```
///
/// Example:
/// ```text
/// $ -> Pipeline
/// 2\n -> Pipeline has 2 queries
/// 3\n -> Query 1 has 3 elements
/// 3\n -> Q1E1 has 3 bytes
/// SET -> Q1E1 itself
/// 1\n -> Q1E2 has 1 byte
/// x -> Q1E2 itself
/// 3\n -> Q1E3 has 3 bytes
/// 100 -> Q1E3 itself
/// 2\n -> Query 2 has 2 elements
/// 3\n -> Q2E1 has 3 bytes
/// GET -> Q2E1 itself
/// 1\n -> Q2E2 has 1 byte
/// x -> Q2E2 itself
/// ```
fn next_pipeline(&mut self) -> ParseResult<PipelinedQuery> {
let query_count = self.read_usize()?;
unsafe {
let mut queries = HeapArray::new_writer(query_count);
for i in 0..query_count {
let sq = self._next_simple_query()?;
queries.write_to_index(i, sq);
}
Ok(PipelinedQuery {
data: queries.finish(),
})
}
}
fn _parse(&mut self) -> ParseResult<Query> {
if self.not_exhausted() {
unsafe {
let first_byte = self.get_byte_at_cursor();
self.incr_cursor();
let data = match first_byte {
b'*' => {
// a simple query
Query::Simple(self.next_simple_query()?)
}
b'$' => {
// a pipelined query
Query::Pipelined(self.next_pipeline()?)
}
_ => return Err(ParseError::UnexpectedByte),
};
Ok(data)
}
} else {
Err(ParseError::NotEnough)
}
}
// only expose this. don't expose Self::new, since that would make it relatively easier
// to invalidate the parser's invariants
pub fn parse(buf: &[u8]) -> ParseResult<QueryWithAdvance> {
let mut slf = Self::new(buf);
let body = slf._parse()?;
let consumed = slf.cursor_ptr() as usize - buf.as_ptr() as usize;
Ok((body, consumed))
}
}

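To complement the parser doc comments above, here is a standalone sketch (not part of this commit) that produces the byte layouts `Parser::parse` expects; the simple-query case reproduces the exact literal used by `test_iter` further down, and the pipeline case reproduces the worked example from the `next_pipeline` doc comment.
// Illustrative serializers for the Skyhash 2.0 query formats documented above.
fn serialize_elements(elements: &[&[u8]], out: &mut Vec<u8>) {
    // <n>\n, then for every element: <len>\n<payload>
    out.extend_from_slice(elements.len().to_string().as_bytes());
    out.push(b'\n');
    for e in elements {
        out.extend_from_slice(e.len().to_string().as_bytes());
        out.push(b'\n');
        out.extend_from_slice(e);
    }
}
fn serialize_simple(elements: &[&[u8]]) -> Vec<u8> {
    let mut out = vec![b'*']; // simple query header
    serialize_elements(elements, &mut out);
    out
}
fn serialize_pipeline(queries: &[&[&[u8]]]) -> Vec<u8> {
    let mut out = vec![b'$']; // pipeline header
    out.extend_from_slice(queries.len().to_string().as_bytes());
    out.push(b'\n');
    for query in queries {
        serialize_elements(query, &mut out);
    }
    out
}
fn main() {
    // same byte string that `test_iter` below feeds to Parser::parse
    let simple: &[&[u8]] = &[b"set", b"x", b"100"];
    assert_eq!(serialize_simple(simple), b"*3\n3\nset1\nx3\n100");
    // the two-query pipeline (SET x 100, then GET x) walked through above
    let q1: &[&[u8]] = &[b"SET", b"x", b"100"];
    let q2: &[&[u8]] = &[b"GET", b"x"];
    assert_eq!(serialize_pipeline(&[q1, q2]), b"$2\n3\n3\nSET1\nx3\n1002\n3\nGET1\nx");
}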
@ -24,8 +24,11 @@
*
*/
use super::{Parser, PipelinedQuery, Query, SimpleQuery};
use crate::protocol::ParseError;
use super::{
super::raw_parser::{RawParser, RawParserExt, RawParserMeta},
Parser, PipelinedQuery, Query, SimpleQuery,
};
use crate::protocol::{iter::AnyArrayIter, ParseError};
use std::iter::Map;
use std::vec::IntoIter as VecIntoIter;
@ -67,11 +70,9 @@ fn get_slices(slices: &[&[u8]]) -> Packets {
fn ensure_zero_reads(parser: &mut Parser) {
let r = parser.read_until(0).unwrap();
unsafe {
let slice = r.as_slice();
assert_eq!(slice, b"");
assert!(slice.is_empty());
}
let slice = r.as_slice();
assert_eq!(slice, b"");
assert!(slice.is_empty());
}
// We do this intentionally for "heap simulation"
@ -317,11 +318,9 @@ fn read_until_nonempty() {
ensure_zero_reads(&mut parser);
// now read the entire length; should always work
let r = parser.read_until(len).unwrap();
unsafe {
let slice = r.as_slice();
assert_eq!(slice, src.as_slice());
assert_eq!(slice.len(), len);
}
let slice = r.as_slice();
assert_eq!(slice, src.as_slice());
assert_eq!(slice.len(), len);
// even after the buffer is exhausted, `0` should always work
ensure_zero_reads(&mut parser);
}
@ -346,23 +345,19 @@ fn read_until_not_enough() {
fn read_until_more_bytes() {
let sample1 = v!(b"abcd1");
let mut p1 = Parser::new(&sample1);
unsafe {
assert_eq!(
p1.read_until(&sample1.len() - 1).unwrap().as_slice(),
&sample1[..&sample1.len() - 1]
);
// ensure we have not exhausted the buffer
ensure_not_exhausted(&p1);
ensure_remaining(&p1, 1);
}
assert_eq!(
p1.read_until(&sample1.len() - 1).unwrap().as_slice(),
&sample1[..&sample1.len() - 1]
);
// ensure we have not exhausted the buffer
ensure_not_exhausted(&p1);
ensure_remaining(&p1, 1);
let sample2 = v!(b"abcd1234567890!@#$");
let mut p2 = Parser::new(&sample2);
unsafe {
assert_eq!(p2.read_until(4).unwrap().as_slice(), &sample2[..4]);
// ensure we have not exhausted the buffer
ensure_not_exhausted(&p2);
ensure_remaining(&p2, sample2.len() - 4);
}
assert_eq!(p2.read_until(4).unwrap().as_slice(), &sample2[..4]);
// ensure we have not exhausted the buffer
ensure_not_exhausted(&p2);
ensure_remaining(&p2, sample2.len() - 4);
}
// read_line
@ -370,12 +365,10 @@ fn read_until_more_bytes() {
fn read_line_special_case_only_lf() {
let b = v!(b"\n");
let mut parser = Parser::new(&b);
unsafe {
let r = parser.read_line().unwrap();
let slice = r.as_slice();
assert_eq!(slice, b"");
assert!(slice.is_empty());
};
let r = parser.read_line().unwrap();
let slice = r.as_slice();
assert_eq!(slice, b"");
assert!(slice.is_empty());
// ensure it is exhausted
ensure_exhausted(&parser);
}
@ -389,12 +382,10 @@ fn read_line() {
assert_eq!(parser.read_line().unwrap_err(), ParseError::NotEnough);
} else {
// should work
unsafe {
assert_eq!(
parser.read_line().unwrap().as_slice(),
&src.as_slice()[..len - 1]
);
}
assert_eq!(
parser.read_line().unwrap().as_slice(),
&src.as_slice()[..len - 1]
);
// now, we attempt to read which should work
ensure_zero_reads(&mut parser);
}
@ -414,9 +405,7 @@ fn read_line_more_bytes() {
let sample1 = v!(b"abcd\n1");
let mut p1 = Parser::new(&sample1);
let line = p1.read_line().unwrap();
unsafe {
assert_eq!(line.as_slice(), b"abcd");
}
assert_eq!(line.as_slice(), b"abcd");
// we should still have one remaining
ensure_not_exhausted(&p1);
ensure_remaining(&p1, 1);
@ -427,17 +416,13 @@ fn read_line_subsequent_lf() {
let sample1 = v!(b"abcd\n1\n");
let mut p1 = Parser::new(&sample1);
let line = p1.read_line().unwrap();
unsafe {
assert_eq!(line.as_slice(), b"abcd");
}
assert_eq!(line.as_slice(), b"abcd");
// we should still have two octets remaining
ensure_not_exhausted(&p1);
ensure_remaining(&p1, 2);
// and we should be able to read in another line
let line = p1.read_line().unwrap();
unsafe {
assert_eq!(line.as_slice(), b"1");
}
assert_eq!(line.as_slice(), b"1");
ensure_exhausted(&p1);
}
@ -453,12 +438,10 @@ fn read_line_pedantic_okay() {
);
} else {
// should work
unsafe {
assert_eq!(
parser.read_line_pedantic().unwrap().as_slice(),
&src.as_slice()[..len - 1]
);
}
assert_eq!(
parser.read_line_pedantic().unwrap().as_slice(),
&src.as_slice()[..len - 1]
);
// now, we attempt to read which should work
ensure_zero_reads(&mut parser);
}
@ -641,3 +624,18 @@ fn pipelined_query_fail_because_not_enough() {
assert_eq!(ret, ParseError::NotEnough)
}
}
#[test]
fn test_iter() {
use super::{Parser, Query};
let (q, _fwby) = Parser::parse(b"*3\n3\nset1\nx3\n100").unwrap();
let r = match q {
Query::Simple(q) => q,
_ => panic!("Wrong query"),
};
let it = r.as_slice().iter();
let mut iter = unsafe { AnyArrayIter::new(it) };
assert_eq!(iter.next_uppercase().unwrap().as_ref(), "SET".as_bytes());
assert_eq!(iter.next().unwrap(), "x".as_bytes());
assert_eq!(iter.next().unwrap(), "100".as_bytes());
}

@ -42,15 +42,14 @@ action! {
/// like queries
fn create(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) {
// minlength is 2 (create has already been checked)
ensure_length(act.len(), |size| size > 1)?;
ensure_length::<P>(act.len(), |size| size > 1)?;
let mut create_what = unsafe { act.next().unsafe_unwrap() }.to_vec();
create_what.make_ascii_uppercase();
match create_what.as_ref() {
TABLE => create_table(handle, con, act).await?,
KEYSPACE => create_keyspace(handle, con, act).await?,
_ => {
con.write_response(responses::groups::UNKNOWN_DDL_QUERY)
.await?;
con._write_raw(P::RSTRING_UNKNOWN_DDL_QUERY).await?;
}
}
Ok(())
@ -60,15 +59,14 @@ action! {
/// like queries
fn ddl_drop(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) {
// minlength is 2 (drop has already been checked)
ensure_length(act.len(), |size| size > 1)?;
ensure_length::<P>(act.len(), |size| size > 1)?;
let mut create_what = unsafe { act.next().unsafe_unwrap() }.to_vec();
create_what.make_ascii_uppercase();
match create_what.as_ref() {
TABLE => drop_table(handle, con, act).await?,
KEYSPACE => drop_keyspace(handle, con, act).await?,
_ => {
con.write_response(responses::groups::UNKNOWN_DDL_QUERY)
.await?;
con._write_raw(P::RSTRING_UNKNOWN_DDL_QUERY).await?;
}
}
Ok(())
@ -76,77 +74,77 @@ action! {
/// We should have `<tableid> <model>(args) properties`
fn create_table(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |size| size > 1 && size < 4)?;
ensure_length::<P>(act.len(), |size| size > 1 && size < 4)?;
let table_name = unsafe { act.next().unsafe_unwrap() };
let model_name = unsafe { act.next().unsafe_unwrap() };
let (table_entity, model_code) = parser::parse_table_args(table_name, model_name)?;
let (table_entity, model_code) = parser::parse_table_args::<P>(table_name, model_name)?;
let is_volatile = match act.next() {
Some(maybe_volatile) => {
ensure_cond_or_err(maybe_volatile.eq(VOLATILE), responses::groups::UNKNOWN_PROPERTY)?;
ensure_cond_or_err(maybe_volatile.eq(VOLATILE), P::RSTRING_UNKNOWN_PROPERTY)?;
true
}
None => false,
};
if registry::state_okay() {
handle.create_table(table_entity, model_code, is_volatile)?;
con.write_response(responses::groups::OKAY).await?;
translate_ddl_error::<P, ()>(handle.create_table(table_entity, model_code, is_volatile))?;
con._write_raw(P::RCODE_OKAY).await?;
} else {
conwrite!(con, responses::groups::SERVER_ERR)?;
return util::err(P::RCODE_SERVER_ERR);
}
Ok(())
}
/// We should have `<ksid>`
fn create_keyspace(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |len| len == 1)?;
ensure_length::<P>(act.len(), |len| len == 1)?;
match act.next() {
Some(ksid) => {
ensure_cond_or_err(encoding::is_utf8(&ksid), responses::groups::ENCODING_ERROR)?;
ensure_cond_or_err(encoding::is_utf8(&ksid), P::RCODE_ENCODING_ERROR)?;
let ksid_str = unsafe { str::from_utf8_unchecked(ksid) };
ensure_cond_or_err(VALID_CONTAINER_NAME.is_match(ksid_str), responses::groups::BAD_EXPRESSION)?;
ensure_cond_or_err(ksid.len() < 64, responses::groups::CONTAINER_NAME_TOO_LONG)?;
ensure_cond_or_err(VALID_CONTAINER_NAME.is_match(ksid_str), P::RSTRING_BAD_EXPRESSION)?;
ensure_cond_or_err(ksid.len() < 64, P::RSTRING_CONTAINER_NAME_TOO_LONG)?;
let ksid = unsafe { ObjectID::from_slice(ksid_str) };
if registry::state_okay() {
handle.create_keyspace(ksid)?;
con.write_response(responses::groups::OKAY).await?
translate_ddl_error::<P, ()>(handle.create_keyspace(ksid))?;
con._write_raw(P::RCODE_OKAY).await?
} else {
conwrite!(con, responses::groups::SERVER_ERR)?;
return util::err(P::RCODE_SERVER_ERR);
}
}
None => con.write_response(responses::groups::ACTION_ERR).await?,
None => return util::err(P::RCODE_ACTION_ERR),
}
Ok(())
}
/// Drop a table (`<tblid>` only)
fn drop_table(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |size| size == 1)?;
ensure_length::<P>(act.len(), |size| size == 1)?;
match act.next() {
Some(eg) => {
let entity_group = parser::Entity::from_slice(eg)?;
let entity_group = parser::Entity::from_slice::<P>(eg)?;
if registry::state_okay() {
handle.drop_table(entity_group)?;
con.write_response(responses::groups::OKAY).await?;
translate_ddl_error::<P, ()>(handle.drop_table(entity_group))?;
con._write_raw(P::RCODE_OKAY).await?;
} else {
conwrite!(con, responses::groups::SERVER_ERR)?;
return util::err(P::RCODE_SERVER_ERR);
}
},
None => con.write_response(responses::groups::ACTION_ERR).await?,
None => return util::err(P::RCODE_ACTION_ERR),
}
Ok(())
}
/// Drop a keyspace (`<ksid>` only)
fn drop_keyspace(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |size| size == 1)?;
ensure_length::<P>(act.len(), |size| size == 1)?;
match act.next() {
Some(ksid) => {
ensure_cond_or_err(ksid.len() < 64, responses::groups::CONTAINER_NAME_TOO_LONG)?;
ensure_cond_or_err(ksid.len() < 64, P::RSTRING_CONTAINER_NAME_TOO_LONG)?;
let force_remove = match act.next() {
Some(bts) if bts.eq(FORCE_REMOVE) => true,
None => false,
_ => {
return util::err(responses::groups::UNKNOWN_ACTION);
return util::err(P::RCODE_UNKNOWN_ACTION);
}
};
if registry::state_okay() {
@ -156,13 +154,13 @@ action! {
} else {
handle.drop_keyspace(objid)
};
result?;
con.write_response(responses::groups::OKAY).await?;
translate_ddl_error::<P, ()>(result)?;
con._write_raw(P::RCODE_OKAY).await?;
} else {
conwrite!(con, responses::groups::SERVER_ERR)?;
return util::err(P::RCODE_SERVER_ERR);
}
},
None => con.write_response(responses::groups::ACTION_ERR).await?,
None => return util::err(P::RCODE_ACTION_ERR),
}
Ok(())
}

@ -25,11 +25,14 @@
*/
use super::ddl::{KEYSPACE, TABLE};
use crate::corestore::memstore::ObjectID;
use crate::corestore::{
memstore::{Keyspace, ObjectID},
table::Table,
};
use crate::dbnet::connection::prelude::*;
use crate::resp::writer::TypedArrayWriter;
const KEYSPACES: &[u8] = "KEYSPACES".as_bytes();
action! {
/// Runs an inspect query:
/// - `INSPECT KEYSPACES` is run by this function itself
@ -44,7 +47,7 @@ action! {
KEYSPACE => inspect_keyspace(handle, con, act).await?,
TABLE => inspect_table(handle, con, act).await?,
KEYSPACES => {
ensure_length(act.len(), |len| len == 0)?;
ensure_length::<P>(act.len(), |len| len == 0)?;
// return all the keyspaces that exist
let ks_list: Vec<ObjectID> = handle
.get_store()
@ -52,66 +55,62 @@ action! {
.iter()
.map(|kv| kv.key().clone())
.collect();
let mut writer = unsafe {
TypedArrayWriter::new(con, b'+', ks_list.len())
}.await?;
for tbl in ks_list {
writer.write_element(tbl).await?;
con.write_typed_non_null_array_header(ks_list.len(), b'+').await?;
for ks in ks_list {
con.write_typed_non_null_array_element(&ks).await?;
}
}
_ => conwrite!(con, responses::groups::UNKNOWN_INSPECT_QUERY)?,
_ => return util::err(P::RSTRING_UNKNOWN_INSPECT_QUERY),
}
}
None => aerr!(con),
None => return util::err(P::RCODE_ACTION_ERR),
}
Ok(())
}
/// INSPECT a keyspace. This should only have the keyspace ID
fn inspect_keyspace(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |len| len < 2)?;
let tbl_list: Vec<ObjectID>;
ensure_length::<P>(act.len(), |len| len < 2)?;
let tbl_list: Vec<ObjectID> =
match act.next() {
Some(keyspace_name) => {
// inspect the provided keyspace
let ksid = if keyspace_name.len() > 64 {
return conwrite!(con, responses::groups::BAD_CONTAINER_NAME);
return util::err(P::RSTRING_BAD_CONTAINER_NAME);
} else {
keyspace_name
};
let ks = match handle.get_keyspace(ksid) {
Some(kspace) => kspace,
None => return conwrite!(con, responses::groups::CONTAINER_NOT_FOUND),
None => return util::err(P::RSTRING_CONTAINER_NOT_FOUND),
};
tbl_list = ks.tables.iter().map(|kv| kv.key().clone()).collect();
ks.tables.iter().map(|kv| kv.key().clone()).collect()
},
None => {
// inspect the current keyspace
let cks = handle.get_cks()?;
tbl_list = cks.tables.iter().map(|kv| kv.key().clone()).collect();
let cks = translate_ddl_error::<P, &Keyspace>(handle.get_cks())?;
cks.tables.iter().map(|kv| kv.key().clone()).collect()
},
}
let mut writer = unsafe {
TypedArrayWriter::new(con, b'+', tbl_list.len())
}.await?;
};
con.write_typed_non_null_array_header(tbl_list.len(), b'+').await?;
for tbl in tbl_list {
writer.write_element(tbl).await?;
con.write_typed_non_null_array_element(&tbl).await?;
}
Ok(())
}
/// INSPECT a table. This should only have the table ID
fn inspect_table(handle: &Corestore, con: &'a mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |len| len < 2)?;
ensure_length::<P>(act.len(), |len| len < 2)?;
match act.next() {
Some(entity) => {
let entity = handle_entity!(con, entity);
conwrite!(con, get_tbl!(entity, handle, con).describe_self())?;
con.write_string(get_tbl!(entity, handle, con).describe_self()).await?;
},
None => {
// inspect the current table
let tbl = handle.get_table_result()?;
con.write_response(tbl.describe_self()).await?;
let tbl = translate_ddl_error::<P, &Table>(handle.get_table_result())?;
con.write_string(tbl.describe_self()).await?;
},
}
Ok(())

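For illustration only: assuming `write_typed_non_null_array_header` keeps the `^<tsymbol><count>\n` framing of the Skyhash 1.0 `NonNullArrayWriter` removed later in this diff, and given the `<len>\n<payload>` element framing of `write_typed_array_element` above, an `INSPECT KEYSPACES` reply listing `default` and `system` would be laid out as below. This is a hypothetical sketch, not Skytable code.
// The `^+` header below is an assumption carried over from the removed NonNullArrayWriter.
fn typed_non_null_str_array(items: &[&str]) -> Vec<u8> {
    let mut out = vec![b'^', b'+'];
    out.extend_from_slice(items.len().to_string().as_bytes());
    out.push(b'\n');
    for item in items {
        // element framing: <len>\n<payload>, matching write_typed_array_element above
        out.extend_from_slice(item.len().to_string().as_bytes());
        out.push(b'\n');
        out.extend_from_slice(item.as_bytes());
    }
    out
}
fn main() {
    assert_eq!(
        typed_non_null_str_array(&["default", "system"]),
        b"^+2\n7\ndefault6\nsystem"
    );
}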
@ -30,7 +30,7 @@ use crate::actions::{ActionError, ActionResult};
use crate::auth;
use crate::corestore::Corestore;
use crate::dbnet::connection::prelude::*;
use crate::protocol::{iter::AnyArrayIter, responses, PipelinedQuery, SimpleQuery, UnsafeSlice};
use crate::protocol::{iter::AnyArrayIter, PipelinedQuery, SimpleQuery, UnsafeSlice};
use crate::queryengine::parser::Entity;
use crate::{actions, admin};
mod ddl;
@ -58,7 +58,7 @@ macro_rules! gen_constants_and_matches {
pub const $action2: &[u8] = stringify!($action2).as_bytes();
)*
}
let first = $buf.next_uppercase().unwrap_or_custom_aerr(groups::PACKET_ERR)?;
let first = $buf.next_uppercase().unwrap_or_custom_aerr(P::RCODE_PACKET_ERR)?;
match first.as_ref() {
$(
tags::$action => $fns($db, $con, $buf).await?,
@ -67,7 +67,7 @@ macro_rules! gen_constants_and_matches {
tags::$action2 => $fns2.await?,
)*
_ => {
$con.write_response(responses::groups::UNKNOWN_ACTION).await?;
$con._write_raw(P::RCODE_UNKNOWN_ACTION).await?;
}
}
};
@ -78,7 +78,7 @@ action! {
fn execute_simple_noauth(
_db: &mut Corestore,
con: &mut T,
auth: &mut AuthProviderHandle<'_, T, Strm>,
auth: &mut AuthProviderHandle<'_, P, T, Strm>,
buf: SimpleQuery
) {
let bufref = buf.as_slice();
@ -87,26 +87,26 @@ action! {
// won't suddenly become invalid
AnyArrayIter::new(bufref.iter())
};
match iter.next_lowercase().unwrap_or_custom_aerr(groups::PACKET_ERR)?.as_ref() {
match iter.next_lowercase().unwrap_or_custom_aerr(P::RCODE_PACKET_ERR)?.as_ref() {
ACTION_AUTH => auth::auth_login_only(con, auth, iter).await,
_ => util::err(auth::errors::AUTH_CODE_BAD_CREDENTIALS),
_ => util::err(P::AUTH_CODE_BAD_CREDENTIALS),
}
}
/// Execute a simple query
fn execute_simple(
db: &mut Corestore,
con: &mut T,
auth: &mut AuthProviderHandle<'_, T, Strm>,
auth: &mut AuthProviderHandle<'_, P, T, Strm>,
buf: SimpleQuery
) {
self::execute_stage(db, con, auth, buf.as_slice()).await
}
}
async fn execute_stage<'a, T: 'a + ClientConnection<Strm>, Strm: Stream>(
async fn execute_stage<'a, P: ProtocolSpec, T: 'a + ClientConnection<P, Strm>, Strm: Stream>(
db: &mut Corestore,
con: &'a mut T,
auth: &mut AuthProviderHandle<'_, T, Strm>,
auth: &mut AuthProviderHandle<'_, P, T, Strm>,
buf: &[UnsafeSlice],
) -> ActionResult<()> {
let mut iter = unsafe {
@ -158,23 +158,28 @@ async fn execute_stage<'a, T: 'a + ClientConnection<Strm>, Strm: Stream>(
action! {
/// Handle `use <entity>` like queries
fn entity_swap(handle: &mut Corestore, con: &mut T, mut act: ActionIter<'a>) {
ensure_length(act.len(), |len| len == 1)?;
ensure_length::<P>(act.len(), |len| len == 1)?;
let entity = unsafe {
// SAFETY: Already checked len
act.next_unchecked()
};
handle.swap_entity(Entity::from_slice(entity)?)?;
con.write_response(groups::OKAY).await?;
translate_ddl_error::<P, ()>(handle.swap_entity(Entity::from_slice::<P>(entity)?))?;
con._write_raw(P::RCODE_OKAY).await?;
Ok(())
}
}
/// Execute a stage **completely**. This means that action errors are never propagated
/// over the try operator
async fn execute_stage_pedantic<'a, T: ClientConnection<Strm> + 'a, Strm: Stream + 'a>(
async fn execute_stage_pedantic<
'a,
P: ProtocolSpec,
T: ClientConnection<P, Strm> + 'a,
Strm: Stream + 'a,
>(
handle: &mut Corestore,
con: &mut T,
auth: &mut AuthProviderHandle<'_, T, Strm>,
auth: &mut AuthProviderHandle<'_, P, T, Strm>,
stage: &[UnsafeSlice],
) -> crate::IoResult<()> {
let ret = async {
@ -183,7 +188,7 @@ async fn execute_stage_pedantic<'a, T: ClientConnection<Strm> + 'a, Strm: Stream
};
match ret.await {
Ok(()) => Ok(()),
Err(ActionError::ActionError(e)) => con.write_response(e).await,
Err(ActionError::ActionError(e)) => con._write_raw(e).await,
Err(ActionError::IoError(ioe)) => Err(ioe),
}
}
@ -193,7 +198,7 @@ action! {
fn execute_pipeline(
handle: &mut Corestore,
con: &mut T,
auth: &mut AuthProviderHandle<'_, T, Strm>,
auth: &mut AuthProviderHandle<'_, P, T, Strm>,
pipeline: PipelinedQuery
) {
for stage in pipeline.into_inner().iter() {

@ -26,7 +26,7 @@
use crate::corestore::{lazy::Lazy, memstore::ObjectID};
use crate::kvengine::encoding;
use crate::protocol::responses;
use crate::queryengine::ProtocolSpec;
use crate::util::{
self,
compiler::{self, cold_err},
@ -47,20 +47,20 @@ pub(super) static VALID_CONTAINER_NAME: LazyRegexFn =
pub(super) static VALID_TYPENAME: LazyRegexFn =
LazyRegexFn::new(|| Regex::new("^<[a-zA-Z][a-zA-Z0-9]+[^>\\s]?>{1}$").unwrap());
pub(super) fn parse_table_args<'a>(
pub(super) fn parse_table_args<'a, P: ProtocolSpec>(
table_name: &'a [u8],
model_name: &'a [u8],
) -> Result<(Entity<'a>, u8), &'static [u8]> {
if compiler::unlikely(!encoding::is_utf8(&table_name) || !encoding::is_utf8(&model_name)) {
return Err(responses::groups::ENCODING_ERROR);
return Err(P::RCODE_ENCODING_ERROR);
}
let model_name_str = unsafe { str::from_utf8_unchecked(model_name) };
// get the entity group
let entity_group = Entity::from_slice(table_name)?;
let entity_group = Entity::from_slice::<P>(table_name)?;
let splits: Vec<&str> = model_name_str.split('(').collect();
if compiler::unlikely(splits.len() != 2) {
return Err(responses::groups::BAD_EXPRESSION);
return Err(P::RSTRING_BAD_EXPRESSION);
}
let model_name_split = unsafe { ucidx!(splits, 0) };
@ -69,19 +69,19 @@ pub(super) fn parse_table_args<'a>(
// the model name has to have at least one char, while the model args must have
// at least 1 char, the closing `)` (for example, if the model takes no arguments: `smh()`)
if compiler::unlikely(model_name_split.is_empty() || model_args_split.is_empty()) {
return Err(responses::groups::BAD_EXPRESSION);
return Err(P::RSTRING_BAD_EXPRESSION);
}
// THIS IS WHERE WE HANDLE THE NEWER MODELS
if model_name_split.as_bytes() != KEYMAP {
return Err(responses::groups::UNKNOWN_MODEL);
return Err(P::RSTRING_UNKNOWN_MODEL);
}
let non_bracketed_end =
unsafe { ucidx!(*model_args_split.as_bytes(), model_args_split.len() - 1) != b')' };
if compiler::unlikely(non_bracketed_end) {
return Err(responses::groups::BAD_EXPRESSION);
return Err(P::RSTRING_BAD_EXPRESSION);
}
// should be (ty1, ty2)
@ -96,10 +96,10 @@ pub(super) fn parse_table_args<'a>(
let all_nonzero = model_args.into_iter().all(|v| !v.is_empty());
if all_nonzero {
// every argument is non-empty, so there are simply too many of them
Err(responses::groups::TOO_MANY_ARGUMENTS)
Err(P::RSTRING_TOO_MANY_ARGUMENTS)
} else {
// at least one empty argument, i.e. a stray comma
Err(responses::groups::BAD_EXPRESSION)
Err(P::RSTRING_BAD_EXPRESSION)
}
});
}
@ -116,7 +116,7 @@ pub(super) fn parse_table_args<'a>(
VALID_CONTAINER_NAME.is_match(val_ty)
};
if compiler::unlikely(!(valid_key_ty || valid_val_ty)) {
return Err(responses::groups::BAD_EXPRESSION);
return Err(P::RSTRING_BAD_EXPRESSION);
}
let key_ty = key_ty.as_bytes();
let val_ty = val_ty.as_bytes();
@ -132,8 +132,8 @@ pub(super) fn parse_table_args<'a>(
(STR, LIST_BINSTR) => 6,
(STR, LIST_STR) => 7,
// KVExt bad keytypes (we can't use lists as keys for obvious reasons)
(LIST_STR, _) | (LIST_BINSTR, _) => return Err(responses::groups::BAD_TYPE_FOR_KEY),
_ => return Err(responses::groups::UNKNOWN_DATA_TYPE),
(LIST_STR, _) | (LIST_BINSTR, _) => return Err(P::RSTRING_BAD_TYPE_FOR_KEY),
_ => return Err(P::RCODE_UNKNOWN_DATA_TYPE),
};
Ok((entity_group, model_code))
}
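A standalone sketch (not part of this commit) of the `<model>(<kty>,<vty>)` handling above: the error strings reuse the respstring payloads defined at the top of this diff (`malformed-expression`, `unknown-model`, `too-many-args`), and the regex-based type/name validation is deliberately omitted.
// Illustrative only: mirrors the argument-splitting rules of parse_table_args.
fn split_model_args(model: &str) -> Result<(&str, &str), &'static str> {
    // the expression must look like `keymap(<kty>,<vty>)`
    let (name, rest) = model.split_once('(').ok_or("malformed-expression")?;
    if name != "keymap" {
        return Err("unknown-model");
    }
    let args = rest.strip_suffix(')').ok_or("malformed-expression")?;
    let types: Vec<&str> = args.split(',').map(str::trim).collect();
    match types.as_slice() {
        [kty, vty] if !kty.is_empty() && !vty.is_empty() => Ok((*kty, *vty)),
        more if more.len() > 2 && more.iter().all(|t| !t.is_empty()) => Err("too-many-args"),
        _ => Err("malformed-expression"),
    }
}
fn main() {
    assert_eq!(
        split_model_args("keymap(str, list<binstr>)"),
        Ok(("str", "list<binstr>"))
    );
    assert_eq!(split_model_args("keymap(str,str,str)"), Err("too-many-args"));
    assert_eq!(split_model_args("wthmap(str,str)"), Err("unknown-model"));
    assert_eq!(split_model_args("keymap(str,"), Err("malformed-expression"));
}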
@ -179,29 +179,29 @@ impl<'a> fmt::Debug for Entity<'a> {
}
impl<'a> Entity<'a> {
pub fn from_slice(input: ByteSlice<'a>) -> Result<Entity<'a>, &'static [u8]> {
pub fn from_slice<P: ProtocolSpec>(input: ByteSlice<'a>) -> Result<Entity<'a>, &'static [u8]> {
let parts: Vec<&[u8]> = input.split(|b| *b == b':').collect();
if compiler::unlikely(parts.is_empty() || parts.len() > 2) {
return util::err(responses::groups::BAD_EXPRESSION);
return util::err(P::RSTRING_BAD_EXPRESSION);
}
// just the table
let first_entity = unsafe { ucidx!(parts, 0) };
if parts.len() == 1 {
Ok(Entity::Single(Self::verify_entity_name(first_entity)?))
Ok(Entity::Single(Self::verify_entity_name::<P>(first_entity)?))
} else {
let second_entity = Self::verify_entity_name(unsafe { ucidx!(parts, 1) })?;
let second_entity = Self::verify_entity_name::<P>(unsafe { ucidx!(parts, 1) })?;
if first_entity.is_empty() {
// partial syntax; so the table is in the second position
Ok(Entity::Partial(second_entity))
} else {
let keyspace = Self::verify_entity_name(first_entity)?;
let table = Self::verify_entity_name(second_entity)?;
let keyspace = Self::verify_entity_name::<P>(first_entity)?;
let table = Self::verify_entity_name::<P>(second_entity)?;
Ok(Entity::Full(keyspace, table))
}
}
}
#[inline(always)]
fn verify_entity_name(input: &[u8]) -> Result<&[u8], &'static [u8]> {
fn verify_entity_name<P: ProtocolSpec>(input: &[u8]) -> Result<&[u8], &'static [u8]> {
let mut valid_name = input.len() < 65
&& encoding::is_utf8(input)
&& unsafe { VALID_CONTAINER_NAME.is_match(str::from_utf8_unchecked(input)) };
@ -220,13 +220,13 @@ impl<'a> Entity<'a> {
Ok(input)
} else if compiler::unlikely(input.is_empty()) {
// bad expression (something like `:`)
util::err(responses::groups::BAD_EXPRESSION)
util::err(P::RSTRING_BAD_EXPRESSION)
} else if compiler::unlikely(input.eq(b"system")) {
// system cannot be switched to
util::err(responses::groups::PROTECTED_OBJECT)
util::err(P::RSTRING_PROTECTED_OBJECT)
} else {
// the container has a bad name
util::err(responses::groups::BAD_CONTAINER_NAME)
util::err(P::RSTRING_BAD_CONTAINER_NAME)
}
}
pub fn as_owned(&self) -> OwnedEntity {

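For quick reference, a standalone sketch (not part of this commit) of the entity addressing syntax that `Entity::from_slice` accepts; container-name validation (the length limit and the `VALID_CONTAINER_NAME` regex) is omitted here.
// Illustrative only: classifies `ks:tbl`, `tbl` and `:tbl` the way Entity::from_slice does.
#[derive(Debug, PartialEq)]
enum EntityRef<'a> {
    Full(&'a str, &'a str), // `ks:tbl`: explicit keyspace and table
    Single(&'a str),        // `tbl`: table in the current keyspace
    Partial(&'a str),       // `:tbl`: partial syntax, table in the current keyspace
}
fn classify(input: &str) -> Option<EntityRef<'_>> {
    let parts: Vec<&str> = input.split(':').collect();
    if parts.is_empty() || parts.len() > 2 {
        return None; // bad expression, e.g. `::` or `ks::tbl`
    }
    if parts.len() == 1 {
        return if parts[0].is_empty() {
            None
        } else {
            Some(EntityRef::Single(parts[0]))
        };
    }
    let (first, second) = (parts[0], parts[1]);
    if second.is_empty() {
        None // e.g. `ks:` or `:`
    } else if first.is_empty() {
        Some(EntityRef::Partial(second))
    } else {
        Some(EntityRef::Full(first, second))
    }
}
fn main() {
    assert_eq!(classify("ks:tbl"), Some(EntityRef::Full("ks", "tbl")));
    assert_eq!(classify("tbl"), Some(EntityRef::Single("tbl")));
    assert_eq!(classify(":tbl"), Some(EntityRef::Partial("tbl")));
    assert_eq!(classify("ks:"), None);
    assert_eq!(classify("::ks"), None);
}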
@ -28,6 +28,8 @@ use super::parser;
mod parser_ddl_tests {
use super::parser::Entity;
use crate::protocol::interface::ProtocolSpec;
use crate::protocol::Skyhash2;
macro_rules! byvec {
($($element:expr),*) => {
vec![
@ -38,9 +40,8 @@ mod parser_ddl_tests {
};
}
fn parse_table_args_test(input: Vec<&'static [u8]>) -> Result<(Entity<'_>, u8), &'static [u8]> {
super::parser::parse_table_args(input[0], input[1])
super::parser::parse_table_args::<Skyhash2>(input[0], input[1])
}
use crate::protocol::responses;
#[test]
fn test_table_args_valid() {
// binstr, binstr
@ -97,12 +98,12 @@ mod parser_ddl_tests {
let it = byvec!("1one", "keymap(binstr,binstr)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_CONTAINER_NAME
Skyhash2::RSTRING_BAD_CONTAINER_NAME
);
let it = byvec!("%whywouldsomeone", "keymap(binstr,binstr)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_CONTAINER_NAME
Skyhash2::RSTRING_BAD_CONTAINER_NAME
);
}
#[test]
@ -133,22 +134,22 @@ mod parser_ddl_tests {
let it = byvec!("mycooltbl", "keymap(wth, str)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::UNKNOWN_DATA_TYPE
Skyhash2::RCODE_UNKNOWN_DATA_TYPE
);
let it = byvec!("mycooltbl", "keymap(wth, wth)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::UNKNOWN_DATA_TYPE
Skyhash2::RCODE_UNKNOWN_DATA_TYPE
);
let it = byvec!("mycooltbl", "keymap(str, wth)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::UNKNOWN_DATA_TYPE
Skyhash2::RCODE_UNKNOWN_DATA_TYPE
);
let it = byvec!("mycooltbl", "keymap(wth1, wth2)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::UNKNOWN_DATA_TYPE
Skyhash2::RCODE_UNKNOWN_DATA_TYPE
);
}
#[test]
@ -156,17 +157,17 @@ mod parser_ddl_tests {
let it = byvec!("mycooltbl", "wthmap(wth, wth)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::UNKNOWN_MODEL
Skyhash2::RSTRING_UNKNOWN_MODEL
);
let it = byvec!("mycooltbl", "wthmap(str, str)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::UNKNOWN_MODEL
Skyhash2::RSTRING_UNKNOWN_MODEL
);
let it = byvec!("mycooltbl", "wthmap()");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::UNKNOWN_MODEL
Skyhash2::RSTRING_UNKNOWN_MODEL
);
}
#[test]
@ -174,82 +175,82 @@ mod parser_ddl_tests {
let it = byvec!("mycooltbl", "keymap(");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_EXPRESSION
Skyhash2::RSTRING_BAD_EXPRESSION
);
let it = byvec!("mycooltbl", "keymap(,");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_EXPRESSION
Skyhash2::RSTRING_BAD_EXPRESSION
);
let it = byvec!("mycooltbl", "keymap(,,");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_EXPRESSION
Skyhash2::RSTRING_BAD_EXPRESSION
);
let it = byvec!("mycooltbl", "keymap),");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_EXPRESSION
Skyhash2::RSTRING_BAD_EXPRESSION
);
let it = byvec!("mycooltbl", "keymap),,");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_EXPRESSION
Skyhash2::RSTRING_BAD_EXPRESSION
);
let it = byvec!("mycooltbl", "keymap),,)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_EXPRESSION
Skyhash2::RSTRING_BAD_EXPRESSION
);
let it = byvec!("mycooltbl", "keymap(,)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_EXPRESSION
Skyhash2::RSTRING_BAD_EXPRESSION
);
let it = byvec!("mycooltbl", "keymap(,,)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_EXPRESSION
Skyhash2::RSTRING_BAD_EXPRESSION
);
let it = byvec!("mycooltbl", "keymap,");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_EXPRESSION
Skyhash2::RSTRING_BAD_EXPRESSION
);
let it = byvec!("mycooltbl", "keymap,,");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_EXPRESSION
Skyhash2::RSTRING_BAD_EXPRESSION
);
let it = byvec!("mycooltbl", "keymap,,)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_EXPRESSION
Skyhash2::RSTRING_BAD_EXPRESSION
);
let it = byvec!("mycooltbl", "keymap(str,");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_EXPRESSION
Skyhash2::RSTRING_BAD_EXPRESSION
);
let it = byvec!("mycooltbl", "keymap(str,str");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_EXPRESSION
Skyhash2::RSTRING_BAD_EXPRESSION
);
let it = byvec!("mycooltbl", "keymap(str,str,");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_EXPRESSION
Skyhash2::RSTRING_BAD_EXPRESSION
);
let it = byvec!("mycooltbl", "keymap(str,str,)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_EXPRESSION
Skyhash2::RSTRING_BAD_EXPRESSION
);
let it = byvec!("mycooltbl", "keymap(str,str,),");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_EXPRESSION
Skyhash2::RSTRING_BAD_EXPRESSION
);
}
@ -258,14 +259,14 @@ mod parser_ddl_tests {
let it = byvec!("mycooltbl", "keymap(str, str, str)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::TOO_MANY_ARGUMENTS
Skyhash2::RSTRING_TOO_MANY_ARGUMENTS
);
// this should be valid for not-yet-known data types too
let it = byvec!("mycooltbl", "keymap(wth, wth, wth)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::TOO_MANY_ARGUMENTS
Skyhash2::RSTRING_TOO_MANY_ARGUMENTS
);
}
@ -274,86 +275,96 @@ mod parser_ddl_tests {
let it = byvec!("myverycooltbl", "keymap(list<str>, str)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_TYPE_FOR_KEY
Skyhash2::RSTRING_BAD_TYPE_FOR_KEY
);
let it = byvec!("myverycooltbl", "keymap(list<binstr>, binstr)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_TYPE_FOR_KEY
Skyhash2::RSTRING_BAD_TYPE_FOR_KEY
);
// for consistency checks
let it = byvec!("myverycooltbl", "keymap(list<str>, binstr)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_TYPE_FOR_KEY
Skyhash2::RSTRING_BAD_TYPE_FOR_KEY
);
let it = byvec!("myverycooltbl", "keymap(list<binstr>, str)");
assert_eq!(
parse_table_args_test(it).unwrap_err(),
responses::groups::BAD_TYPE_FOR_KEY
Skyhash2::RSTRING_BAD_TYPE_FOR_KEY
);
}
}
mod entity_parser_tests {
use super::parser::Entity;
use crate::protocol::responses;
use crate::protocol::interface::ProtocolSpec;
use crate::protocol::Skyhash2;
#[test]
fn test_query_full_entity_okay() {
let x = byt!("ks:tbl");
assert_eq!(Entity::from_slice(&x).unwrap(), Entity::Full(b"ks", b"tbl"));
assert_eq!(
Entity::from_slice::<Skyhash2>(&x).unwrap(),
Entity::Full(b"ks", b"tbl")
);
}
#[test]
fn test_query_half_entity() {
let x = byt!("tbl");
assert_eq!(Entity::from_slice(&x).unwrap(), Entity::Single(b"tbl"))
assert_eq!(
Entity::from_slice::<Skyhash2>(&x).unwrap(),
Entity::Single(b"tbl")
)
}
#[test]
fn test_query_partial_entity() {
let x = byt!(":tbl");
assert_eq!(Entity::from_slice(&x).unwrap(), Entity::Partial(b"tbl"))
assert_eq!(
Entity::from_slice::<Skyhash2>(&x).unwrap(),
Entity::Partial(b"tbl")
)
}
#[test]
fn test_query_entity_badexpr() {
let x = byt!("ks:");
assert_eq!(
Entity::from_slice(&x).unwrap_err(),
responses::groups::BAD_EXPRESSION
Entity::from_slice::<Skyhash2>(&x).unwrap_err(),
Skyhash2::RSTRING_BAD_EXPRESSION
);
let x = byt!(":");
assert_eq!(
Entity::from_slice(&x).unwrap_err(),
responses::groups::BAD_EXPRESSION
Entity::from_slice::<Skyhash2>(&x).unwrap_err(),
Skyhash2::RSTRING_BAD_EXPRESSION
);
let x = byt!("::");
assert_eq!(
Entity::from_slice(&x).unwrap_err(),
responses::groups::BAD_EXPRESSION
Entity::from_slice::<Skyhash2>(&x).unwrap_err(),
Skyhash2::RSTRING_BAD_EXPRESSION
);
let x = byt!("::ks");
assert_eq!(
Entity::from_slice(&x).unwrap_err(),
responses::groups::BAD_EXPRESSION
Entity::from_slice::<Skyhash2>(&x).unwrap_err(),
Skyhash2::RSTRING_BAD_EXPRESSION
);
let x = byt!("ks::tbl");
assert_eq!(
Entity::from_slice(&x).unwrap_err(),
responses::groups::BAD_EXPRESSION
Entity::from_slice::<Skyhash2>(&x).unwrap_err(),
Skyhash2::RSTRING_BAD_EXPRESSION
);
let x = byt!("ks::");
assert_eq!(
Entity::from_slice(&x).unwrap_err(),
responses::groups::BAD_EXPRESSION
Entity::from_slice::<Skyhash2>(&x).unwrap_err(),
Skyhash2::RSTRING_BAD_EXPRESSION
);
let x = byt!("ks::tbl::");
assert_eq!(
Entity::from_slice(&x).unwrap_err(),
responses::groups::BAD_EXPRESSION
Entity::from_slice::<Skyhash2>(&x).unwrap_err(),
Skyhash2::RSTRING_BAD_EXPRESSION
);
let x = byt!("::ks::tbl::");
assert_eq!(
Entity::from_slice(&x).unwrap_err(),
responses::groups::BAD_EXPRESSION
Entity::from_slice::<Skyhash2>(&x).unwrap_err(),
Skyhash2::RSTRING_BAD_EXPRESSION
);
}
@ -361,21 +372,21 @@ mod entity_parser_tests {
fn test_bad_entity_name() {
let ename = byt!("$var");
assert_eq!(
Entity::from_slice(&ename).unwrap_err(),
responses::groups::BAD_CONTAINER_NAME
Entity::from_slice::<Skyhash2>(&ename).unwrap_err(),
Skyhash2::RSTRING_BAD_CONTAINER_NAME
);
}
#[test]
fn ks_or_table_with_preload_or_partmap() {
let badname = byt!("PARTMAP");
assert_eq!(
Entity::from_slice(&badname).unwrap_err(),
responses::groups::BAD_CONTAINER_NAME
Entity::from_slice::<Skyhash2>(&badname).unwrap_err(),
Skyhash2::RSTRING_BAD_CONTAINER_NAME
);
let badname = byt!("PRELOAD");
assert_eq!(
Entity::from_slice(&badname).unwrap_err(),
responses::groups::BAD_CONTAINER_NAME
Entity::from_slice::<Skyhash2>(&badname).unwrap_err(),
Skyhash2::RSTRING_BAD_CONTAINER_NAME
);
}
}

@ -1,226 +0,0 @@
/*
* Created on Mon Aug 17 2020
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2020, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
#![allow(clippy::needless_lifetimes)]
//! Utilities for generating responses, which are only used by the `server`
//!
use crate::corestore::buffers::Integer64;
use crate::corestore::memstore::ObjectID;
use crate::util::FutureResult;
use bytes::Bytes;
use std::io::Error as IoError;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
pub mod writer;
pub const TSYMBOL_UNICODE_STRING: u8 = b'+';
pub const TSYMBOL_FLOAT: u8 = b'%';
type FutureIoResult<'s> = FutureResult<'s, Result<(), IoError>>;
/// # The `Writable` trait
/// All trait implementors are given access to an asynchronous stream to which
/// they must write a response.
///
/// Every `write()` call makes a call to the [`IsConnection`](./IsConnection)'s
/// `write_lowlevel` function, which in turn writes something to the underlying stream.
///
/// Do note that this write **doesn't guarantee immediate completion** as the underlying
/// stream might use buffering. So, the best idea would be to use the `flush()`
/// call on the stream.
pub trait Writable {
/*
HACK(@ohsayan): Since `async` is not supported in traits just yet, we will have to
use explicit declarations for asynchronous functions
*/
fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s>;
}
pub trait IsConnection: std::marker::Sync + std::marker::Send {
fn write_lowlevel<'s>(&'s mut self, bytes: &'s [u8]) -> FutureIoResult<'s>;
}
impl<T> IsConnection for T
where
T: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
fn write_lowlevel<'s>(&'s mut self, bytes: &'s [u8]) -> FutureIoResult<'s> {
Box::pin(self.write_all(bytes))
}
}
/// A `BytesWrapper` object wraps around a `Bytes` object that might have been pulled
/// from `Corestore`.
///
/// This wrapper exists to prevent trait implementation conflicts when
/// an impl for `fmt::Display` may be implemented upstream
#[derive(Debug, PartialEq)]
pub struct BytesWrapper(pub Bytes);
impl BytesWrapper {
pub fn finish_into_bytes(self) -> Bytes {
self.0
}
}
#[derive(Debug, PartialEq)]
pub struct StringWrapper(pub String);
impl Writable for StringWrapper {
fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> {
Box::pin(async move {
con.write_lowlevel(&[TSYMBOL_UNICODE_STRING]).await?;
// Now get the size of the Bytes object as bytes
let size = Integer64::from(self.0.len());
// Write this to the stream
con.write_lowlevel(&size).await?;
// Now write a LF character
con.write_lowlevel(&[b'\n']).await?;
// Now write the REAL bytes (of the object)
con.write_lowlevel(self.0.as_bytes()).await?;
Ok(())
})
}
}
impl Writable for Vec<u8> {
fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> {
Box::pin(async move { con.write_lowlevel(&self).await })
}
}
impl<const N: usize> Writable for [u8; N] {
fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> {
Box::pin(async move { con.write_lowlevel(&self).await })
}
}
impl Writable for &'static [u8] {
fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> {
Box::pin(async move { con.write_lowlevel(self).await })
}
}
impl Writable for &'static str {
fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> {
Box::pin(async move {
// First write a `+` character to the stream since this is a
// string (we represent `String`s as `Byte` objects internally)
// and since `Bytes` are effectively `String`s we will append the
// type operator `+` to the stream
con.write_lowlevel(&[TSYMBOL_UNICODE_STRING]).await?;
// Now get the size of the Bytes object as bytes
let size = Integer64::from(self.len());
// Write this to the stream
con.write_lowlevel(&size).await?;
// Now write a LF character
con.write_lowlevel(&[b'\n']).await?;
// Now write the REAL bytes (of the object)
con.write_lowlevel(self.as_bytes()).await?;
Ok(())
})
}
}
impl Writable for BytesWrapper {
fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> {
Box::pin(async move {
// First write a `+` character to the stream since this is a
// string (we represent `String`s as `Byte` objects internally)
// and since `Bytes` are effectively `String`s we will append the
// type operator `+` to the stream
let bytes = self.finish_into_bytes();
con.write_lowlevel(&[TSYMBOL_UNICODE_STRING]).await?;
// Now get the size of the Bytes object as bytes
let size = Integer64::from(bytes.len());
// Write this to the stream
con.write_lowlevel(&size).await?;
// Now write a LF character
con.write_lowlevel(&[b'\n']).await?;
// Now write the REAL bytes (of the object)
con.write_lowlevel(&bytes).await?;
Ok(())
})
}
}
impl Writable for usize {
fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> {
Box::pin(async move {
con.write_lowlevel(b":").await?;
let usize_bytes = Integer64::from(self);
con.write_lowlevel(&usize_bytes).await?;
con.write_lowlevel(b"\n").await?;
Ok(())
})
}
}
impl Writable for u64 {
fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> {
Box::pin(async move {
con.write_lowlevel(b":").await?;
let usize_bytes = Integer64::from(self);
con.write_lowlevel(&usize_bytes).await?;
con.write_lowlevel(b"\n").await?;
Ok(())
})
}
}
impl Writable for ObjectID {
fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> {
Box::pin(async move {
// First write a `+` character to the stream since this is a
// string (we represent `String`s as `Byte` objects internally)
// and since `Bytes` are effectively `String`s we will append the
// type operator `+` to the stream
con.write_lowlevel(&[TSYMBOL_UNICODE_STRING]).await?;
// Now get the size of the Bytes object as bytes
let size = Integer64::from(self.len());
// Write this to the stream
con.write_lowlevel(&size).await?;
// Now write a LF character
con.write_lowlevel(&[b'\n']).await?;
// Now write the REAL bytes (of the object)
con.write_lowlevel(&self).await?;
Ok(())
})
}
}
impl Writable for f32 {
fn write<'s>(self, con: &'s mut impl IsConnection) -> FutureIoResult<'s> {
Box::pin(async move {
let payload = self.to_string();
con.write_lowlevel(&[TSYMBOL_FLOAT]).await?;
con.write_lowlevel(payload.as_bytes()).await?;
con.write_lowlevel(&[b'\n']).await?;
Ok(())
})
}
}

@ -1,225 +0,0 @@
/*
* Created on Thu Aug 12 2021
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2021, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
use crate::corestore::buffers::Integer64;
use crate::corestore::Data;
use crate::dbnet::connection::ProtocolConnectionExt;
use crate::protocol::responses::groups;
use crate::IoResult;
use core::marker::PhantomData;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
/// Write a raw mono group with a custom tsymbol
pub async unsafe fn write_raw_mono<T, Strm>(
con: &mut T,
tsymbol: u8,
payload: &Data,
) -> IoResult<()>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
let raw_stream = con.raw_stream();
raw_stream.write_all(&[tsymbol; 1]).await?; // first write tsymbol
let bytes = Integer64::from(payload.len());
raw_stream.write_all(&bytes).await?; // then len
raw_stream.write_all(&[b'\n']).await?; // LF
raw_stream.write_all(payload).await?; // payload
Ok(())
}
#[derive(Debug)]
/// A writer for a flat array, which is a multi-typed non-recursive array
pub struct FlatArrayWriter<'a, T, Strm> {
tsymbol: u8,
con: &'a mut T,
_owned: PhantomData<Strm>,
}
#[allow(dead_code)] // TODO(@ohsayan): Remove this once we start using the flat array writer
impl<'a, T, Strm> FlatArrayWriter<'a, T, Strm>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
/// Initialize a new flat array writer. This will write out the tsymbol
/// and length for the flat array
pub async unsafe fn new(
con: &'a mut T,
tsymbol: u8,
len: usize,
) -> IoResult<FlatArrayWriter<'a, T, Strm>> {
{
let stream = con.raw_stream();
// first write _
stream.write_all(&[b'_']).await?;
let bytes = Integer64::from(len);
// now write len
stream.write_all(&bytes).await?;
// first LF
stream.write_all(&[b'\n']).await?;
}
Ok(Self {
con,
tsymbol,
_owned: PhantomData,
})
}
/// Write an element
pub async fn write_element(&mut self, bytes: impl AsRef<[u8]>) -> IoResult<()> {
let stream = unsafe { self.con.raw_stream() };
let bytes = bytes.as_ref();
// first write <tsymbol>
stream.write_all(&[self.tsymbol]).await?;
// now len
let len = Integer64::from(bytes.len());
stream.write_all(&len).await?;
// now LF
stream.write_all(&[b'\n']).await?;
// now element
stream.write_all(bytes).await?;
Ok(())
}
/// Write the NIL response code
pub async fn write_nil(&mut self) -> IoResult<()> {
let stream = unsafe { self.con.raw_stream() };
stream.write_all(groups::NIL).await?;
Ok(())
}
/// Write the SERVER_ERR (5) response code
pub async fn write_server_error(&mut self) -> IoResult<()> {
let stream = unsafe { self.con.raw_stream() };
stream.write_all(groups::NIL).await?;
Ok(())
}
}
#[derive(Debug)]
/// A writer for a typed array, which is a singly-typed array which either
/// has a typed element or a `NULL`
pub struct TypedArrayWriter<'a, T, Strm> {
con: &'a mut T,
_owned: PhantomData<Strm>,
}
impl<'a, T, Strm> TypedArrayWriter<'a, T, Strm>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
/// Create a new `typedarraywriter`. This will write the tsymbol and
/// the array length
pub async unsafe fn new(
con: &'a mut T,
tsymbol: u8,
len: usize,
) -> IoResult<TypedArrayWriter<'a, T, Strm>> {
{
let stream = con.raw_stream();
// first write @<tsymbol>
stream.write_all(&[b'@', tsymbol]).await?;
let bytes = Integer64::from(len);
// now write len
stream.write_all(&bytes).await?;
// first LF
stream.write_all(&[b'\n']).await?;
}
Ok(Self {
con,
_owned: PhantomData,
})
}
/// Write an element
pub async fn write_element(&mut self, bytes: impl AsRef<[u8]>) -> IoResult<()> {
let stream = unsafe { self.con.raw_stream() };
let bytes = bytes.as_ref();
// write len
let len = Integer64::from(bytes.len());
stream.write_all(&len).await?;
// now LF
stream.write_all(&[b'\n']).await?;
// now element
stream.write_all(bytes).await?;
Ok(())
}
/// Write a null
pub async fn write_null(&mut self) -> IoResult<()> {
let stream = unsafe { self.con.raw_stream() };
stream.write_all(&[b'\0']).await?;
Ok(())
}
}
#[derive(Debug)]
/// A writer for a non-null typed array, which is a singly-typed array whose elements
/// can never be `NULL`
pub struct NonNullArrayWriter<'a, T, Strm> {
con: &'a mut T,
_owned: PhantomData<Strm>,
}
impl<'a, T, Strm> NonNullArrayWriter<'a, T, Strm>
where
T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
/// Create a new `typedarraywriter`. This will write the tsymbol and
/// the array length
pub async unsafe fn new(
con: &'a mut T,
tsymbol: u8,
len: usize,
) -> IoResult<NonNullArrayWriter<'a, T, Strm>> {
{
let stream = con.raw_stream();
// first write ^<tsymbol>
stream.write_all(&[b'^', tsymbol]).await?;
let bytes = Integer64::from(len);
// now write len
stream.write_all(&bytes).await?;
// first LF
stream.write_all(&[b'\n']).await?;
}
Ok(Self {
con,
_owned: PhantomData,
})
}
/// Write an element
pub async fn write_element(&mut self, bytes: impl AsRef<[u8]>) -> IoResult<()> {
let stream = unsafe { self.con.raw_stream() };
let bytes = bytes.as_ref();
// write len
let len = Integer64::from(bytes.len());
stream.write_all(&len).await?;
// now LF
stream.write_all(&[b'\n']).await?;
// now element
stream.write_all(bytes).await?;
Ok(())
}
}

@ -34,7 +34,7 @@ mod __private {
query.push("KEYSPACES");
assert!(matches!(
con.run_query_raw(&query).await.unwrap(),
Element::Array(Array::Str(_))
Element::Array(Array::NonNullStr(_))
))
}
async fn test_inspect_keyspace() {
@ -43,7 +43,7 @@ mod __private {
query.push(&__MYKS__);
assert!(matches!(
con.run_query_raw(&query).await.unwrap(),
Element::Array(Array::Str(_))
Element::Array(Array::NonNullStr(_))
))
}
async fn test_inspect_current_keyspace() {

@ -1035,8 +1035,7 @@ mod __private {
.into_iter()
.map(|element| element.to_owned())
.collect();
if let Element::Array(Array::Str(arr)) = ret {
let arr: Vec<String> = arr.into_iter().map(|v| v.unwrap()).collect();
if let Element::Array(Array::NonNullStr(arr)) = ret {
assert_eq!(ret_should_have.len(), arr.len());
assert!(ret_should_have.into_iter().all(|key| arr.contains(&key)));
} else {
@ -1070,8 +1069,7 @@ mod __private {
.into_iter()
.map(|element| element.to_owned())
.collect();
if let Element::Array(Array::Str(arr)) = ret {
let arr: Vec<String> = arr.into_iter().map(|v| v.unwrap()).collect();
if let Element::Array(Array::NonNullStr(arr)) = ret {
assert_eq!(ret_should_have.len(), arr.len());
assert!(ret_should_have.into_iter().all(|key| arr.contains(&key)));
} else {
@ -1092,8 +1090,7 @@ mod __private {
.into_iter()
.map(|element| element.to_owned())
.collect();
if let Element::Array(Array::Str(arr)) = ret {
let arr: Vec<String> = arr.into_iter().map(|v| v.unwrap()).collect();
if let Element::Array(Array::NonNullStr(arr)) = ret {
assert_eq!(ret_should_have.len(), arr.len());
assert!(ret_should_have.into_iter().all(|key| arr.contains(&key)));
} else {
@ -1115,8 +1112,7 @@ mod __private {
.into_iter()
.map(|element| element.to_owned())
.collect();
if let Element::Array(Array::Str(arr)) = ret {
let arr: Vec<String> = arr.into_iter().map(|v| v.unwrap()).collect();
if let Element::Array(Array::NonNullStr(arr)) = ret {
assert_eq!(ret_should_have.len(), arr.len());
assert!(ret_should_have.into_iter().all(|key| arr.contains(&key)));
} else {

@ -61,7 +61,7 @@ mod __private {
async fn test_lget_emptylist_okay() {
lset!(con, "mysuperlist");
let q = query!("lget", "mysuperlist");
runeq!(con, q, Element::Array(Array::Str(vec![])));
runeq!(con, q, Element::Array(Array::NonNullStr(vec![])));
}
async fn test_lget_list_with_elements_okay() {
lset!(con, "mysuperlist", "elementa", "elementb", "elementc");

@ -104,7 +104,7 @@ macro_rules! assert_okay {
}
macro_rules! assert_skyhash_arrayeq {
(str, $con:expr, $query:expr, $($val:expr),*) => {
(!str, $con:expr, $query:expr, $($val:expr),*) => {
runeq!(
$con,
$query,
@ -115,13 +115,13 @@ macro_rules! assert_skyhash_arrayeq {
))
)
};
(bin, $con:expr, $query:expr, $($val:expr),*) => {
(str, $con:expr, $query:expr, $($val:expr),*) => {
runeq!(
$con,
$query,
skytable::Element::Array(skytable::types::Array::Bin(
skytable::Element::Array(skytable::types::Array::NonNullStr(
vec![
$(Some($val.into()),)*
$($val.into(),)*
]
))
)

@ -52,7 +52,7 @@ mod tls {
}
mod sys {
use crate::protocol::{PROTOCOL_VERSION, PROTOCOL_VERSIONSTRING};
use crate::protocol::{LATEST_PROTOCOL_VERSION, LATEST_PROTOCOL_VERSIONSTRING};
use libsky::VERSION;
use sky_macros::dbtest_func as dbtest;
use skytable::{query, Element, RespCode};
@ -79,7 +79,7 @@ mod sys {
runeq!(
con,
query!("sys", "info", "protocol"),
Element::String(PROTOCOL_VERSIONSTRING.to_owned())
Element::String(LATEST_PROTOCOL_VERSIONSTRING.to_owned())
)
}
#[dbtest]
@ -87,7 +87,7 @@ mod sys {
runeq!(
con,
query!("sys", "info", "protover"),
Element::Float(PROTOCOL_VERSION)
Element::Float(LATEST_PROTOCOL_VERSION)
)
}
#[dbtest]

@ -161,9 +161,9 @@ impl<const N: usize> PersistValue for [&[u8]; N] {
fn response_load(&self) -> Element {
let mut flat = Vec::with_capacity(N);
for item in self {
flat.push(Some(item.to_vec()));
flat.push(item.to_vec());
}
Element::Array(Array::Bin(flat))
Element::Array(Array::NonNullBin(flat))
}
}
@ -174,9 +174,9 @@ impl<const N: usize> PersistValue for [&str; N] {
fn response_load(&self) -> Element {
let mut flat = Vec::with_capacity(N);
for item in self {
flat.push(Some(item.to_string()));
flat.push(item.to_string());
}
Element::Array(Array::Str(flat))
Element::Array(Array::NonNullStr(flat))
}
}

@ -120,26 +120,58 @@ macro_rules! action {
$block:block)*
) => {
$($(#[$attr])*
pub async fn $fname<'a, T: 'a + ClientConnection<Strm>, Strm:Stream>(
pub async fn $fname<
'a,
T: 'a + $crate::dbnet::connection::ClientConnection<P, Strm>,
Strm: $crate::dbnet::connection::Stream,
P: $crate::protocol::interface::ProtocolSpec
> (
$($argname: $argty,)*
) -> $crate::actions::ActionResult<()>
$block)*
};
(
$($(#[$attr:meta])*
fn $fname:ident($argone:ident: $argonety:ty,
fn $fname:ident(
$argone:ident: $argonety:ty,
$argtwo:ident: $argtwoty:ty,
mut $argthree:ident: $argthreety:ty)
$block:block)*
mut $argthree:ident: $argthreety:ty
) $block:block)*
) => {
$($(#[$attr])*
pub async fn $fname<'a, T: 'a + ClientConnection<Strm>, Strm:Stream>(
pub async fn $fname<
'a,
T: 'a + $crate::dbnet::connection::ClientConnection<P, Strm>,
Strm: $crate::dbnet::connection::Stream,
P: $crate::protocol::interface::ProtocolSpec
>(
$argone: $argonety,
$argtwo: $argtwoty,
mut $argthree: $argthreety
) -> $crate::actions::ActionResult<()>
$block)*
};
(
$($(#[$attr:meta])*
fn $fname:ident(
$argone:ident: $argonety:ty,
$argtwo:ident: $argtwoty:ty,
$argthree:ident: $argthreety:ty
) $block:block)*
) => {
$($(#[$attr])*
pub async fn $fname<
'a,
T: 'a + $crate::dbnet::connection::ClientConnection<P, Strm>,
Strm: $crate::dbnet::connection::Stream,
P: $crate::protocol::interface::ProtocolSpec
>(
$argone: $argonety,
$argtwo: $argtwoty,
$argthree: $argthreety
) -> $crate::actions::ActionResult<()>
$block)*
};
}
#[macro_export]

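// Editor illustration (not part of the diff): a sketch of what the rewritten `action!`
// macro emits for a hypothetical `my_action`, showing the new `P: ProtocolSpec` parameter
// threaded into the generated signature. `Corestore` and `ActionIter` are the argument
// types used by the actions in this commit; the body placeholder stands for whatever was
// written inside the invocation.
pub async fn my_action<
    'a,
    T: 'a + crate::dbnet::connection::ClientConnection<P, Strm>,
    Strm: crate::dbnet::connection::Stream,
    P: crate::protocol::interface::ProtocolSpec,
>(
    _handle: &Corestore,
    _con: &'a mut T,
    _act: ActionIter<'a>,
) -> crate::actions::ActionResult<()> {
    // ... action body as written in the `action!` invocation ...
    Ok(())
}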
@ -27,10 +27,10 @@
#[macro_use]
mod macros;
pub mod compiler;
pub mod os;
pub mod error;
pub mod os;
use crate::actions::{ActionError, ActionResult};
use crate::protocol::responses::groups;
use crate::protocol::interface::ProtocolSpec;
use core::fmt::Debug;
use core::future::Future;
use core::ops::Deref;
@ -79,15 +79,15 @@ unsafe impl<T> Unwrappable<T> for Option<T> {
pub trait UnwrapActionError<T> {
fn unwrap_or_custom_aerr(self, e: impl Into<ActionError>) -> ActionResult<T>;
fn unwrap_or_aerr(self) -> ActionResult<T>;
fn unwrap_or_aerr<P: ProtocolSpec>(self) -> ActionResult<T>;
}
impl<T> UnwrapActionError<T> for Option<T> {
fn unwrap_or_custom_aerr(self, e: impl Into<ActionError>) -> ActionResult<T> {
self.ok_or_else(|| e.into())
}
fn unwrap_or_aerr(self) -> ActionResult<T> {
self.ok_or_else(|| groups::ACTION_ERR.into())
fn unwrap_or_aerr<P: ProtocolSpec>(self) -> ActionResult<T> {
self.ok_or_else(|| P::RCODE_ACTION_ERR.into())
}
}

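// Editor illustration (not part of the diff): with `unwrap_or_aerr` now generic over the
// protocol, a call site has to name `P` so the action-error bytes come from the matching
// protocol table. Minimal sketch; `example` and `value` are placeholder names:
fn example<P: ProtocolSpec, T>(value: Option<T>) -> ActionResult<T> {
    // returns Err(P::RCODE_ACTION_ERR.into()) when `value` is None
    value.unwrap_or_aerr::<P>()
}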
@ -106,18 +106,40 @@ pub fn dbtest_func(args: TokenStream, item: TokenStream) -> TokenStream {
/// Get a compile time respcode/respstring array. For example, if you pass: "Unknown action",
/// it will return: `!Unknown action\n` (the length line is only emitted by the `_v1` variant)
pub fn compiled_eresp_array(tokens: TokenStream) -> TokenStream {
_get_eresp_array(tokens)
_get_eresp_array(tokens, false)
}
fn _get_eresp_array(tokens: TokenStream) -> TokenStream {
#[proc_macro]
/// Get a compile time respcode/respstring array. For example, if you pass: "Unknown action",
/// it will return: `!14\nUnknown Action\n`
pub fn compiled_eresp_array_v1(tokens: TokenStream) -> TokenStream {
_get_eresp_array(tokens, true)
}
fn _get_eresp_array(tokens: TokenStream, sizeline: bool) -> TokenStream {
let payload_str = match syn::parse_macro_input!(tokens as Lit) {
Lit::Str(st) => st.value(),
_ => panic!("Expected a string literal"),
};
let payload_bytes = payload_str.as_bytes();
let mut processed = quote! {
b'!',
};
if sizeline {
let payload_len = payload_str.as_bytes().len();
let payload_len_str = payload_len.to_string();
let payload_len_bytes = payload_len_str.as_bytes();
for byte in payload_len_bytes {
processed = quote! {
#processed
#byte,
};
}
processed = quote! {
#processed
b'\n',
};
}
let payload_bytes = payload_str.as_bytes();
for byte in payload_bytes {
processed = quote! {
#processed
@ -145,3 +167,15 @@ pub fn compiled_eresp_bytes(tokens: TokenStream) -> TokenStream {
}
.into()
}
#[proc_macro]
/// Get a compile time respcode/respstring slice. For example, if you pass: "Unknown action",
/// it will return: `!14\nUnknown Action\n`
pub fn compiled_eresp_bytes_v1(tokens: TokenStream) -> TokenStream {
let ret = compiled_eresp_array_v1(tokens);
let ret = syn::parse_macro_input!(ret as syn::Expr);
quote! {
&#ret
}
.into()
}
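// Editor illustration (not part of the diff): a plain-function sketch of what the
// `sizeline` flag contributes. With it set (the `_v1` macros), the ASCII length of the
// payload plus a newline is emitted between the `!` symbol and the payload, matching the
// Skyhash 1.0 shape; without it the payload follows the `!` directly. Any trailing bytes
// come from the part of `_get_eresp_array` not shown in this hunk.
fn eresp_prefix(payload: &str, sizeline: bool) -> Vec<u8> {
    let mut out = vec![b'!'];
    if sizeline {
        out.extend_from_slice(payload.len().to_string().as_bytes());
        out.push(b'\n');
    }
    out.extend_from_slice(payload.as_bytes());
    out
}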
