Use iterators for actions

next
Sayan Nandan 3 years ago
parent a839137643
commit 58830edc80

@ -25,18 +25,19 @@
*/ */
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::queryengine::ActionIter;
/// Get the number of keys in the database /// Get the number of keys in the database
pub async fn dbsize<T, Strm>( pub async fn dbsize<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
crate::err_if_len_is!(act, con, != 0); crate::err_if_len_is!(act, con, not 0);
let len; let len;
{ {
len = handle.get_ref().len(); len = handle.get_ref().len();

@ -29,6 +29,7 @@
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::protocol::responses; use crate::protocol::responses;
use crate::queryengine::ActionIter;
/// Run a `DEL` query /// Run a `DEL` query
/// ///
@ -37,13 +38,13 @@ use crate::protocol::responses;
pub async fn del<T, Strm>( pub async fn del<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
crate::err_if_len_is!(act, con, == 0); crate::err_if_len_is!(act, con, eq 0);
let done_howmany: Option<usize>; let done_howmany: Option<usize>;
{ {
if handle.is_poisoned() { if handle.is_poisoned() {
@ -51,7 +52,7 @@ where
} else { } else {
let mut many = 0; let mut many = 0;
let cmap = handle.get_ref(); let cmap = handle.get_ref();
act.into_iter().skip(1).for_each(|key| { act.for_each(|key| {
if cmap.true_if_removed(key.as_bytes()) { if cmap.true_if_removed(key.as_bytes()) {
many += 1 many += 1
} }

@ -28,22 +28,23 @@
//! This module provides functions to work with `EXISTS` queries //! This module provides functions to work with `EXISTS` queries
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::queryengine::ActionIter;
/// Run an `EXISTS` query /// Run an `EXISTS` query
pub async fn exists<T, Strm>( pub async fn exists<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
crate::err_if_len_is!(act, con, == 0); crate::err_if_len_is!(act, con, eq 0);
let mut how_many_of_them_exist = 0usize; let mut how_many_of_them_exist = 0usize;
{ {
let cmap = handle.get_ref(); let cmap = handle.get_ref();
act.into_iter().skip(1).for_each(|key| { act.for_each(|key| {
if cmap.contains_key(key.as_bytes()) { if cmap.contains_key(key.as_bytes()) {
how_many_of_them_exist += 1; how_many_of_them_exist += 1;
} }

@ -26,18 +26,19 @@
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::protocol::responses; use crate::protocol::responses;
use crate::queryengine::ActionIter;
/// Delete all the keys in the database /// Delete all the keys in the database
pub async fn flushdb<T, Strm>( pub async fn flushdb<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
crate::err_if_len_is!(act, con, != 0); crate::err_if_len_is!(act, con, not 0);
let failed; let failed;
{ {
if handle.is_poisoned() { if handle.is_poisoned() {

@ -29,27 +29,33 @@
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::protocol::responses; use crate::protocol::responses;
use crate::queryengine::ActionIter;
use crate::resp::BytesWrapper; use crate::resp::BytesWrapper;
use bytes::Bytes; use bytes::Bytes;
use core::hint::unreachable_unchecked;
/// Run a `GET` query /// Run a `GET` query
pub async fn get<T, Strm>( pub async fn get<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, mut act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
crate::err_if_len_is!(act, con, != 1); crate::err_if_len_is!(act, con, not 1);
let res: Option<Bytes> = { let res: Option<Bytes> = {
let reader = handle.get_ref(); let reader = handle.get_ref();
unsafe { unsafe {
// UNSAFE(@ohsayan): act.get_ref().get_unchecked() is safe because we've already if the action // UNSAFE(@ohsayan): unreachable_unchecked is safe because we've already checked if the action
// group contains one argument (excluding the action itself) // group contains one argument (excluding the action itself)
reader reader
.get(act.get_unchecked(1).as_bytes()) .get(
act.next()
.unwrap_or_else(|| unreachable_unchecked())
.as_bytes(),
)
.map(|b| b.get_blob().clone()) .map(|b| b.get_blob().clone())
} }
}; };

@ -29,6 +29,7 @@
//! Functions for handling `JGET` queries //! Functions for handling `JGET` queries
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::queryengine::ActionIter;
/// Run a `JGET` query /// Run a `JGET` query
/// This returns a JSON key/value pair of keys and values /// This returns a JSON key/value pair of keys and values
@ -42,13 +43,13 @@ use crate::dbnet::connection::prelude::*;
pub async fn jget<T, Strm>( pub async fn jget<T, Strm>(
_handle: &crate::coredb::CoreDB, _handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
crate::err_if_len_is!(act, con, != 1); crate::err_if_len_is!(act, con, not 1);
todo!() todo!()
} }

@ -26,6 +26,8 @@
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::protocol::responses; use crate::protocol::responses;
use crate::queryengine::ActionIter;
use core::hint::unreachable_unchecked;
/// Run a `KEYLEN` query /// Run a `KEYLEN` query
/// ///
@ -33,20 +35,24 @@ use crate::protocol::responses;
pub async fn keylen<T, Strm>( pub async fn keylen<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, mut act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
crate::err_if_len_is!(act, con, != 1); crate::err_if_len_is!(act, con, not 1);
let res: Option<usize> = { let res: Option<usize> = {
let reader = handle.get_ref(); let reader = handle.get_ref();
unsafe { unsafe {
// UNSAFE(@ohsayan): get_unchecked() is completely safe as we've already checked // UNSAFE(@ohsayan): unreachable_unchecked() is completely safe as we've already checked
// the number of arguments is one // the number of arguments is one
reader reader
.get(act.get_unchecked(1).as_bytes()) .get(
act.next()
.unwrap_or_else(|| unreachable_unchecked())
.as_bytes(),
)
.map(|b| b.get_blob().len()) .map(|b| b.get_blob().len())
} }
}; };

@ -26,6 +26,7 @@
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::protocol::responses; use crate::protocol::responses;
use crate::queryengine::ActionIter;
use crate::resp::BytesWrapper; use crate::resp::BytesWrapper;
use bytes::Bytes; use bytes::Bytes;
@ -33,14 +34,14 @@ use bytes::Bytes;
pub async fn lskeys<T, Strm>( pub async fn lskeys<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, mut act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
crate::err_if_len_is!(act, con, > 1); crate::err_if_len_is!(act, con, gt 1);
let item_count = if let Some(cnt) = act.get(1) { let item_count = if let Some(cnt) = act.next() {
if let Ok(cnt) = cnt.parse::<usize>() { if let Ok(cnt) = cnt.parse::<usize>() {
cnt cnt
} else { } else {

@ -25,6 +25,7 @@
*/ */
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::queryengine::ActionIter;
use crate::resp::BytesWrapper; use crate::resp::BytesWrapper;
use bytes::Bytes; use bytes::Bytes;
use skytable::RespCode; use skytable::RespCode;
@ -34,19 +35,20 @@ use skytable::RespCode;
pub async fn mget<T, Strm>( pub async fn mget<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, mut act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
crate::err_if_len_is!(act, con, == 0); crate::err_if_len_is!(act, con, eq 0);
con.write_array_length(act.len() - 1).await?; con.write_array_length(act.len()).await?;
let mut keys = act.into_iter().skip(1); while let Some(key) = act.next() {
while let Some(key) = keys.next() {
let res: Option<Bytes> = { let res: Option<Bytes> = {
let reader = handle.get_ref(); handle
reader.get(key.as_bytes()).map(|b| b.get_blob().clone()) .get_ref()
.get(key.as_bytes())
.map(|b| b.get_blob().clone())
}; };
if let Some(value) = res { if let Some(value) = res {
// Good, we got the value, write it off to the stream // Good, we got the value, write it off to the stream

@ -24,9 +24,11 @@
* *
*/ */
//! # The Key/Value Engine //! # Actions
//! This is Skytable's K/V engine. It contains utilities to interface with //!
//! Skytable's K/V store //! Actions are like shell commands, you provide arguments -- they return output. This module contains a collection
//! of the actions supported by Skytable
//!
pub mod dbsize; pub mod dbsize;
pub mod del; pub mod del;
@ -47,12 +49,13 @@ pub mod heya {
//! Respond to `HEYA` queries //! Respond to `HEYA` queries
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::protocol; use crate::protocol;
use crate::queryengine::ActionIter;
use protocol::responses; use protocol::responses;
/// Returns a `HEY!` `Response` /// Returns a `HEY!` `Response`
pub async fn heya<T, Strm>( pub async fn heya<T, Strm>(
_handle: &crate::coredb::CoreDB, _handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
_act: Vec<String>, _act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
@ -64,50 +67,50 @@ pub mod heya {
#[macro_export] #[macro_export]
macro_rules! err_if_len_is { macro_rules! err_if_len_is {
($buf:ident, $con:ident, == $len:literal) => { ($buf:ident, $con:ident, eq $len:literal) => {
if $buf.len() - 1 == $len { if $buf.len() == $len {
return $con return $con
.write_response(&**crate::protocol::responses::groups::ACTION_ERR) .write_response(&**crate::protocol::responses::groups::ACTION_ERR)
.await; .await;
} }
}; };
($buf:ident, $con:ident, != $len:literal) => { ($buf:ident, $con:ident, not $len:literal) => {
if $buf.len() - 1 != $len { if $buf.len() != $len {
return $con return $con
.write_response(&**crate::protocol::responses::groups::ACTION_ERR) .write_response(&**crate::protocol::responses::groups::ACTION_ERR)
.await; .await;
} }
}; };
($buf:ident, $con:ident, > $len:literal) => { ($buf:ident, $con:ident, gt $len:literal) => {
if $buf.len() - 1 > $len { if $buf.len() > $len {
return $con return $con
.write_response(&**crate::protocol::responses::groups::ACTION_ERR) .write_response(&**crate::protocol::responses::groups::ACTION_ERR)
.await; .await;
} }
}; };
($buf:ident, $con:ident, < $len:literal) => { ($buf:ident, $con:ident, lt $len:literal) => {
if $buf.len() - 1 < $len { if $buf.len() < $len {
return $con return $con
.write_response(&**crate::protocol::responses::groups::ACTION_ERR) .write_response(&**crate::protocol::responses::groups::ACTION_ERR)
.await; .await;
} }
}; };
($buf:ident, $con:ident, >= $len:literal) => { ($buf:ident, $con:ident, gt_or_eq $len:literal) => {
if $buf.len() - 1 >= $len { if $buf.len() >= $len {
return $con return $con
.write_response(&**crate::protocol::responses::groups::ACTION_ERR) .write_response(&**crate::protocol::responses::groups::ACTION_ERR)
.await; .await;
} }
}; };
($buf:ident, $con:ident, <= $len:literal) => { ($buf:ident, $con:ident, lt_or_eq $len:literal) => {
if $buf.len() - 1 <= $len { if $buf.len() <= $len {
return $con return $con
.write_response(&**crate::protocol::responses::groups::ACTION_ERR) .write_response(&**crate::protocol::responses::groups::ACTION_ERR)
.await; .await;
} }
}; };
($buf:ident, $con:ident, & $len:literal) => { ($buf:ident, $con:ident, & $len:literal) => {
if $buf.len() - 1 & $len { if $buf.len() & $len {
return $con return $con
.write_response(&**crate::protocol::responses::groups::ACTION_ERR) .write_response(&**crate::protocol::responses::groups::ACTION_ERR)
.await; .await;

@ -27,25 +27,25 @@
use crate::coredb::Data; use crate::coredb::Data;
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::protocol::responses; use crate::protocol::responses;
use crate::queryengine::ActionIter;
/// Run an `MSET` query /// Run an `MSET` query
pub async fn mset<T, Strm>( pub async fn mset<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, mut act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
let howmany = act.len() - 1; let howmany = act.len();
if howmany & 1 == 1 || howmany == 0 { if howmany & 1 == 1 || howmany == 0 {
// An odd number of arguments means that the number of keys // An odd number of arguments means that the number of keys
// is not the same as the number of values, we won't run this // is not the same as the number of values, we won't run this
// action at all // action at all
return con.write_response(&**responses::groups::ACTION_ERR).await; return con.write_response(&**responses::groups::ACTION_ERR).await;
} }
let mut kviter = act.into_iter().skip(1);
let done_howmany: Option<usize>; let done_howmany: Option<usize>;
{ {
if handle.is_poisoned() { if handle.is_poisoned() {
@ -53,7 +53,7 @@ where
} else { } else {
let writer = handle.get_ref(); let writer = handle.get_ref();
let mut didmany = 0; let mut didmany = 0;
while let (Some(key), Some(val)) = (kviter.next(), kviter.next()) { while let (Some(key), Some(val)) = (act.next(), act.next()) {
if writer.true_if_insert(Data::from(key), Data::from(val)) { if writer.true_if_insert(Data::from(key), Data::from(val)) {
didmany += 1; didmany += 1;
} }

@ -27,25 +27,25 @@
use crate::coredb::Data; use crate::coredb::Data;
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::protocol::responses; use crate::protocol::responses;
use crate::queryengine::ActionIter;
/// Run an `MUPDATE` query /// Run an `MUPDATE` query
pub async fn mupdate<T, Strm>( pub async fn mupdate<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, mut act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
let howmany = act.len() - 1; let howmany = act.len();
if howmany & 1 == 1 || howmany == 0 { if howmany & 1 == 1 || howmany == 0 {
// An odd number of arguments means that the number of keys // An odd number of arguments means that the number of keys
// is not the same as the number of values, we won't run this // is not the same as the number of values, we won't run this
// action at all // action at all
return con.write_response(&**responses::groups::ACTION_ERR).await; return con.write_response(&**responses::groups::ACTION_ERR).await;
} }
let mut kviter = act.into_iter().skip(1);
let done_howmany: Option<usize>; let done_howmany: Option<usize>;
{ {
if handle.is_poisoned() { if handle.is_poisoned() {
@ -53,7 +53,7 @@ where
} else { } else {
let writer = handle.get_ref(); let writer = handle.get_ref();
let mut didmany = 0; let mut didmany = 0;
while let (Some(key), Some(val)) = (kviter.next(), kviter.next()) { while let (Some(key), Some(val)) = (act.next(), act.next()) {
if writer.true_if_update(Data::from(key), Data::from(val)) { if writer.true_if_update(Data::from(key), Data::from(val)) {
didmany += 1; didmany += 1;
} }

@ -30,6 +30,7 @@
use crate::coredb; use crate::coredb;
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::protocol::responses; use crate::protocol::responses;
use crate::queryengine::ActionIter;
use coredb::Data; use coredb::Data;
use std::hint::unreachable_unchecked; use std::hint::unreachable_unchecked;
@ -37,30 +38,25 @@ use std::hint::unreachable_unchecked;
pub async fn set<T, Strm>( pub async fn set<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, mut act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
let howmany = act.len() - 1; crate::err_if_len_is!(act, con, not 2);
if howmany != 2 {
// There should be exactly 2 arguments
return con.write_response(&**responses::groups::ACTION_ERR).await;
}
let mut it = act.into_iter().skip(1);
let did_we = { let did_we = {
if handle.is_poisoned() { if handle.is_poisoned() {
None None
} else { } else {
let writer = handle.get_ref(); let writer = handle.get_ref();
if writer.true_if_insert( if writer.true_if_insert(
Data::from_string(it.next().unwrap_or_else(|| unsafe { Data::from_string(act.next().unwrap_or_else(|| unsafe {
// UNSAFE(@ohsayan): This is completely safe as we've already checked // UNSAFE(@ohsayan): This is completely safe as we've already checked
// that there are exactly 2 arguments // that there are exactly 2 arguments
unreachable_unchecked() unreachable_unchecked()
})), })),
Data::from(it.next().unwrap_or_else(|| unsafe { Data::from(act.next().unwrap_or_else(|| unsafe {
// UNSAFE(@ohsayan): This is completely safe as we've already checked // UNSAFE(@ohsayan): This is completely safe as we've already checked
// that there are exactly 2 arguments // that there are exactly 2 arguments
unreachable_unchecked() unreachable_unchecked()

@ -38,8 +38,8 @@
use crate::coredb::Data; use crate::coredb::Data;
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::protocol::responses; use crate::protocol::responses;
use crate::queryengine::ActionIter;
use std::hint::unreachable_unchecked; use core::hint::unreachable_unchecked;
/// Run an `SSET` query /// Run an `SSET` query
/// ///
@ -48,13 +48,13 @@ use std::hint::unreachable_unchecked;
pub async fn sset<T, Strm>( pub async fn sset<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, mut act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
let howmany = act.len() - 1; let howmany = act.len();
if howmany & 1 == 1 || howmany == 0 { if howmany & 1 == 1 || howmany == 0 {
return con.write_response(&**responses::groups::ACTION_ERR).await; return con.write_response(&**responses::groups::ACTION_ERR).await;
} }
@ -66,13 +66,7 @@ where
// This iterator gives us the keys and values, skipping the first argument which // This iterator gives us the keys and values, skipping the first argument which
// is the action name // is the action name
let mut key_iter = act let mut key_iter = act.as_ref().iter();
.get(1..)
.unwrap_or_else(|| unsafe {
// UNSAFE(@ohsayan): We've already checked if the action group contains more than one arugment
unreachable_unchecked()
})
.iter();
if handle.is_poisoned() { if handle.is_poisoned() {
failed = None; failed = None;
} else { } else {
@ -92,8 +86,7 @@ where
}) { }) {
// Since the failed flag is false, none of the keys existed // Since the failed flag is false, none of the keys existed
// So we can safely set the keys // So we can safely set the keys
let mut iter = act.into_iter().skip(1); while let (Some(key), Some(value)) = (act.next(), act.next()) {
while let (Some(key), Some(value)) = (iter.next(), iter.next()) {
if !mut_table.true_if_insert(Data::from(key), Data::from_string(value)) { if !mut_table.true_if_insert(Data::from(key), Data::from_string(value)) {
// Tell the compiler that this will never be the case // Tell the compiler that this will never be the case
unsafe { unsafe {
@ -126,13 +119,13 @@ where
pub async fn sdel<T, Strm>( pub async fn sdel<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
let howmany = act.len() - 1; let howmany = act.len();
if howmany == 0 { if howmany == 0 {
return con.write_response(&**responses::groups::ACTION_ERR).await; return con.write_response(&**responses::groups::ACTION_ERR).await;
} }
@ -141,13 +134,7 @@ where
// We use this additional scope to tell the compiler that the write lock // We use this additional scope to tell the compiler that the write lock
// doesn't go beyond the scope of this function - and is never used across // doesn't go beyond the scope of this function - and is never used across
// an await: cause, the compiler ain't as smart as we are ;) // an await: cause, the compiler ain't as smart as we are ;)
let mut key_iter = act let mut key_iter = act.as_ref().iter();
.get(1..)
.unwrap_or_else(|| unsafe {
// UNSAFE(@ohsayan): We've already checked if the action group contains more than one arugment
unreachable_unchecked()
})
.iter();
if handle.is_poisoned() { if handle.is_poisoned() {
failed = None; failed = None;
} else { } else {
@ -168,7 +155,7 @@ where
}) { }) {
// Since the failed flag is false, all of the keys exist // Since the failed flag is false, all of the keys exist
// So we can safely delete the keys // So we can safely delete the keys
act.into_iter().skip(1).for_each(|key| { act.into_iter().for_each(|key| {
// Since we've already checked that the keys don't exist // Since we've already checked that the keys don't exist
// We'll tell the compiler to optimize this // We'll tell the compiler to optimize this
let _ = mut_table.remove(key.as_bytes()).unwrap_or_else(|| unsafe { let _ = mut_table.remove(key.as_bytes()).unwrap_or_else(|| unsafe {
@ -198,13 +185,13 @@ where
pub async fn supdate<T, Strm>( pub async fn supdate<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, mut act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
let howmany = act.len() - 1; let howmany = act.len();
if howmany & 1 == 1 || howmany == 0 { if howmany & 1 == 1 || howmany == 0 {
return con.write_response(&**responses::groups::ACTION_ERR).await; return con.write_response(&**responses::groups::ACTION_ERR).await;
} }
@ -213,13 +200,7 @@ where
// We use this additional scope to tell the compiler that the write lock // We use this additional scope to tell the compiler that the write lock
// doesn't go beyond the scope of this function - and is never used across // doesn't go beyond the scope of this function - and is never used across
// an await: cause, the compiler ain't as smart as we are ;) // an await: cause, the compiler ain't as smart as we are ;)
let mut key_iter = act let mut key_iter = act.as_ref().iter();
.get(1..)
.unwrap_or_else(|| unsafe {
// UNSAFE(@ohsayan): We've already checked if the action group contains more than one arugment
unreachable_unchecked()
})
.iter();
if handle.is_poisoned() { if handle.is_poisoned() {
failed = None; failed = None;
} else { } else {
@ -245,8 +226,7 @@ where
}) { }) {
// Since the failed flag is false, none of the keys existed // Since the failed flag is false, none of the keys existed
// So we can safely update the keys // So we can safely update the keys
let mut iter = act.into_iter().skip(1); while let (Some(key), Some(value)) = (act.next(), act.next()) {
while let (Some(key), Some(value)) = (iter.next(), iter.next()) {
if !mut_table.true_if_update(Data::from(key), Data::from_string(value)) { if !mut_table.true_if_update(Data::from(key), Data::from_string(value)) {
// Tell the compiler that this will never be the case // Tell the compiler that this will never be the case
unsafe { unreachable_unchecked() } unsafe { unreachable_unchecked() }

@ -31,6 +31,7 @@
use crate::coredb::{self}; use crate::coredb::{self};
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::protocol::responses; use crate::protocol::responses;
use crate::queryengine::ActionIter;
use coredb::Data; use coredb::Data;
use std::hint::unreachable_unchecked; use std::hint::unreachable_unchecked;
@ -38,30 +39,25 @@ use std::hint::unreachable_unchecked;
pub async fn update<T, Strm>( pub async fn update<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, mut act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
let howmany = act.len() - 1; crate::err_if_len_is!(act, con, not 2);
if howmany != 2 {
// There should be exactly 2 arguments
return con.write_response(&**responses::groups::ACTION_ERR).await;
}
let mut it = act.into_iter().skip(1);
let did_we = { let did_we = {
if handle.is_poisoned() { if handle.is_poisoned() {
None None
} else { } else {
let writer = handle.get_ref(); let writer = handle.get_ref();
if writer.true_if_update( if writer.true_if_update(
Data::from(it.next().unwrap_or_else(|| unsafe { Data::from(act.next().unwrap_or_else(|| unsafe {
// UNSAFE(@ohsayan): We've already checked that the action contains exactly // UNSAFE(@ohsayan): We've already checked that the action contains exactly
// two arguments (excluding the action itself). So, this branch won't ever be reached // two arguments (excluding the action itself). So, this branch won't ever be reached
unreachable_unchecked() unreachable_unchecked()
})), })),
Data::from_string(it.next().unwrap_or_else(|| unsafe { Data::from_string(act.next().unwrap_or_else(|| unsafe {
// UNSAFE(@ohsayan): We've already checked that the action contains exactly // UNSAFE(@ohsayan): We've already checked that the action contains exactly
// two arguments (excluding the action itself). So, this branch won't ever be reached // two arguments (excluding the action itself). So, this branch won't ever be reached
unreachable_unchecked() unreachable_unchecked()

@ -27,6 +27,7 @@
use crate::coredb::Data; use crate::coredb::Data;
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::protocol::responses; use crate::protocol::responses;
use crate::queryengine::ActionIter;
/// Run an `USET` query /// Run an `USET` query
/// ///
@ -34,26 +35,25 @@ use crate::protocol::responses;
pub async fn uset<T, Strm>( pub async fn uset<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, mut act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
let howmany = act.len() - 1; let howmany = act.len();
if howmany & 1 == 1 || howmany == 0 { if howmany & 1 == 1 || howmany == 0 {
// An odd number of arguments means that the number of keys // An odd number of arguments means that the number of keys
// is not the same as the number of values, we won't run this // is not the same as the number of values, we won't run this
// action at all // action at all
return con.write_response(&**responses::groups::ACTION_ERR).await; return con.write_response(&**responses::groups::ACTION_ERR).await;
} }
let mut kviter = act.into_iter().skip(1);
let failed = { let failed = {
if handle.is_poisoned() { if handle.is_poisoned() {
true true
} else { } else {
let writer = handle.get_ref(); let writer = handle.get_ref();
while let (Some(key), Some(val)) = (kviter.next(), kviter.next()) { while let (Some(key), Some(val)) = (act.next(), act.next()) {
let _ = writer.upsert(Data::from(key), Data::from(val)); let _ = writer.upsert(Data::from(key), Data::from(val));
} }
drop(writer); drop(writer);

@ -29,6 +29,7 @@ use crate::diskstore;
use crate::diskstore::snapshot::SnapshotEngine; use crate::diskstore::snapshot::SnapshotEngine;
use crate::diskstore::snapshot::DIR_SNAPSHOT; use crate::diskstore::snapshot::DIR_SNAPSHOT;
use crate::protocol::responses; use crate::protocol::responses;
use crate::queryengine::ActionIter;
use std::hint::unreachable_unchecked; use std::hint::unreachable_unchecked;
use std::path::{Component, PathBuf}; use std::path::{Component, PathBuf};
@ -37,14 +38,13 @@ use std::path::{Component, PathBuf};
pub async fn mksnap<T, Strm>( pub async fn mksnap<T, Strm>(
handle: &crate::coredb::CoreDB, handle: &crate::coredb::CoreDB,
con: &mut T, con: &mut T,
act: Vec<String>, mut act: ActionIter,
) -> std::io::Result<()> ) -> std::io::Result<()>
where where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
let howmany = act.len() - 1; if act.len() == 0 {
if howmany == 0 {
if !handle.is_snapshot_enabled() { if !handle.is_snapshot_enabled() {
// Since snapshotting is disabled, we can't create a snapshot! // Since snapshotting is disabled, we can't create a snapshot!
// We'll just return an error returning the same // We'll just return an error returning the same
@ -100,9 +100,9 @@ where
.await; .await;
} }
} else { } else {
if howmany == 1 { if act.len() == 1 {
// This means that the user wants to create a 'named' snapshot // This means that the user wants to create a 'named' snapshot
let snapname = act.get(1).unwrap_or_else(|| unsafe { let snapname = act.next().unwrap_or_else(|| unsafe {
// UNSAFE(@ohsayan): We've already checked that the action // UNSAFE(@ohsayan): We've already checked that the action
// contains a second argument, so this can't be reached // contains a second argument, so this can't be reached
unreachable_unchecked() unreachable_unchecked()

@ -24,8 +24,6 @@
* *
*/ */
use std::borrow::Cow;
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
#[non_exhaustive] #[non_exhaustive]
/// # Data Types /// # Data Types
@ -41,27 +39,3 @@ pub enum Element {
/// A non-recursive String array; tsymbol: `_` /// A non-recursive String array; tsymbol: `_`
FlatArray(Vec<String>), FlatArray(Vec<String>),
} }
impl Element {
/// This will return a reference to the first element in the element
///
/// If this element is a compound type, it will return a reference to the first element in the compound
/// type
pub fn get_first(&self) -> Option<Cow<String>> {
match self {
Self::Array(elem) => match elem.first() {
Some(el) => match el {
Element::String(st) => Some(Cow::Borrowed(&st)),
_ => None,
},
None => None,
},
Self::FlatArray(elem) => match elem.first() {
Some(el) => Some(Cow::Borrowed(&el)),
None => None,
},
Self::String(ref st) => Some(Cow::Borrowed(&st)),
_ => None,
}
}
}

@ -28,50 +28,36 @@
use crate::coredb::CoreDB; use crate::coredb::CoreDB;
use crate::dbnet::connection::prelude::*; use crate::dbnet::connection::prelude::*;
use crate::gen_match;
use crate::protocol::responses; use crate::protocol::responses;
use crate::protocol::Element; use crate::protocol::Element;
use crate::{actions, admin}; use crate::{actions, admin};
mod tags { use std::vec::IntoIter;
//! This module is a collection of tags/strings used for evaluating queries pub type ActionIter = IntoIter<String>;
//! and responses
/// `GET` action tag macro_rules! gen_constants_and_matches {
pub const TAG_GET: &'static str = "GET"; ($con:ident, $buf:ident, $db:ident, $($action:ident),*; $($fns:expr),*) => {
/// `SET` action tag mod tags {
pub const TAG_SET: &'static str = "SET"; //! This module is a collection of tags/strings used for evaluating queries
/// `UPDATE` action tag //! and responses
pub const TAG_UPDATE: &'static str = "UPDATE"; $(
/// `DEL` action tag pub const $action: &'static str = stringify!($action);
pub const TAG_DEL: &'static str = "DEL"; )*
/// `HEYA` action tag }
pub const TAG_HEYA: &'static str = "HEYA"; let mut first = match $buf.next() {
/// `EXISTS` action tag Some(first) => first,
pub const TAG_EXISTS: &'static str = "EXISTS"; None => return $con.write_response(&**responses::groups::PACKET_ERR).await,
/// `MSET` action tag };
pub const TAG_MSET: &'static str = "MSET"; first.make_ascii_uppercase();
/// `MGET` action tag match first.as_str() {
pub const TAG_MGET: &'static str = "MGET"; $(
/// `MUPDATE` action tag tags::$action => $fns($db, $con, $buf).await?,
pub const TAG_MUPDATE: &'static str = "MUPDATE"; )*
/// `SSET` action tag _ => {
pub const TAG_SSET: &'static str = "SSET"; return $con.write_response(&**responses::groups::UNKNOWN_ACTION).await;
/// `SDEL` action tag }
pub const TAG_SDEL: &'static str = "SDEL"; }
/// `SUPDATE` action tag };
pub const TAG_SUPDATE: &'static str = "SUPDATE";
/// `DBSIZE` action tag
pub const TAG_DBSIZE: &'static str = "DBSIZE";
/// `FLUSHDB` action tag
pub const TAG_FLUSHDB: &'static str = "FLUSHDB";
/// `USET` action tag
pub const TAG_USET: &'static str = "USET";
/// `KEYLEN` action tag
pub const TAG_KEYLEN: &'static str = "KEYLEN";
/// `MKSNAP` action tag
pub const TAG_MKSNAP: &'static str = "MKSNAP";
/// `LSKEYS` action tag
pub const TAG_LSKEYS: &str = "LSKEYS";
} }
/// Execute a simple(*) query /// Execute a simple(*) query
@ -80,59 +66,35 @@ where
T: ProtocolConnectionExt<Strm>, T: ProtocolConnectionExt<Strm>,
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync, Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{ {
let first = match buf.get_first() { let buf = if let Element::FlatArray(arr) = buf {
Some(element) => element.to_ascii_uppercase(), arr
None => return con.write_response(&**responses::groups::PACKET_ERR).await, } else {
return con
.write_response(&**responses::full_responses::R_WRONGTYPE_ERR)
.await;
}; };
gen_match!( let mut buf = buf.into_iter();
first, gen_constants_and_matches!(
db, con, buf, db, GET, SET, UPDATE, DEL, HEYA, EXISTS, MSET, MGET, MUPDATE, SSET, SDEL,
con, SUPDATE, DBSIZE, FLUSHDB, USET, KEYLEN, MKSNAP, LSKEYS;
buf, actions::get::get,
tags::TAG_DEL => actions::del::del, actions::set::set,
tags::TAG_GET => actions::get::get, actions::update::update,
tags::TAG_HEYA => actions::heya::heya, actions::del::del,
tags::TAG_EXISTS => actions::exists::exists, actions::heya::heya,
tags::TAG_SET => actions::set::set, actions::exists::exists,
tags::TAG_MGET => actions::mget::mget, actions::mset::mset,
tags::TAG_MSET => actions::mset::mset, actions::mget::mget,
tags::TAG_UPDATE => actions::update::update, actions::mupdate::mupdate,
tags::TAG_MUPDATE => actions::mupdate::mupdate, actions::strong::sset,
tags::TAG_SSET => actions::strong::sset, actions::strong::sdel,
tags::TAG_SDEL => actions::strong::sdel, actions::strong::supdate,
tags::TAG_SUPDATE => actions::strong::supdate, actions::dbsize::dbsize,
tags::TAG_DBSIZE => actions::dbsize::dbsize, actions::flushdb::flushdb,
tags::TAG_FLUSHDB => actions::flushdb::flushdb, actions::uset::uset,
tags::TAG_USET => actions::uset::uset, actions::keylen::keylen,
tags::TAG_KEYLEN => actions::keylen::keylen, admin::mksnap::mksnap,
tags::TAG_MKSNAP => admin::mksnap::mksnap, actions::lskeys::lskeys
tags::TAG_LSKEYS => actions::lskeys::lskeys
); );
Ok(()) Ok(())
} }
#[macro_export]
/// A match generator macro built specifically for the `crate::queryengine::execute_simple` function
///
/// **NOTE:** This macro needs _paths_ for both sides of the $x => $y, to produce something sensible
macro_rules! gen_match {
($pre:ident, $db:ident, $con:ident, $buf:ident, $($x:pat => $y:expr),*) => {
let flat_array = if let crate::protocol::Element::FlatArray(array) = $buf {
array
} else {
return $con.write_response(&**responses::groups::WRONGTYPE_ERR).await;
};
match $pre.as_str() {
// First repeat over all the $x => $y patterns, passing in the variables
// and adding .await calls and adding the `?`
$(
$x => $y($db, $con, flat_array).await?,
)*
// Now add the final case where no action is matched
_ => {
return $con.write_response(&**responses::groups::UNKNOWN_ACTION)
.await;
},
}
};
}

Loading…
Cancel
Save