net: Fix protocol impl and improve protocol testing coverage

next
Sayan Nandan 6 months ago
parent 2d905d07d0
commit 6414b05fa6
No known key found for this signature in database
GPG Key ID: 0EBD769024B24F0A

@ -64,6 +64,10 @@ impl<'a> SQuery<'a> {
fn new(buf: &'a [u8], q_window: usize) -> Self {
Self { buf, q_window }
}
#[cfg(test)]
pub(super) fn _new(buf: &'a [u8], q_window: usize) -> Self {
Self::new(buf, q_window)
}
pub fn query(&self) -> &[u8] {
&self.buf[..self.q_window]
}
@ -107,7 +111,7 @@ fn scan_usize_guaranteed_termination(
}
#[derive(Clone, Copy, PartialEq)]
struct Usize {
pub(super) struct Usize {
v: isize,
}
@ -119,9 +123,13 @@ impl Usize {
Self { v }
}
#[inline(always)]
const fn new_unflagged(int: usize) -> Self {
pub(super) const fn new_unflagged(int: usize) -> Self {
Self::new(int as isize)
}
#[cfg(test)]
pub(super) const fn new_flagged(int: usize) -> Self {
Self::new(int as isize | Self::MASK)
}
#[inline(always)]
fn int(&self) -> usize {
(self.v & !Self::MASK) as usize
@ -199,6 +207,10 @@ impl SQState {
const fn new(packet_s: Usize) -> Self {
Self { packet_s }
}
#[cfg(test)]
pub(super) const fn _new(s: Usize) -> Self {
Self::new(s)
}
}
#[derive(Debug, PartialEq)]
@ -288,14 +300,22 @@ impl<'a> Exchange<'a> {
.map_err(|_| ExchangeError::NotAsciiByteOrOverflow)?;
if sq_state.packet_s.flag() & self.scanner.has_left(sq_state.packet_s.int()) {
// we have the full packet size and the required data
// scan the query window
let start = self.scanner.cursor();
let q_window = scan_usize_guaranteed_termination(&mut self.scanner)?;
let stop = self.scanner.cursor();
// now compute remaining buffer length and nonzero condition
let expected_remaining_buffer = sq_state.packet_s.int() - (stop - start);
let nonzero = (q_window != 0) & (sq_state.packet_s.int() != 0);
if compiler::likely(self.scanner.remaining_size_is(sq_state.packet_s.int()) & nonzero) {
// validate and return
if compiler::likely(self.scanner.remaining_size_is(expected_remaining_buffer) & nonzero)
{
// this check is important because the client might have given us an incorrect q size
Ok(ExchangeResult::Simple(SQuery::new(
self.scanner.current_buffer(),
q_window,
)))
let block = unsafe {
// UNSAFE(@ohsayan): just verified earlier
self.scanner.next_chunk_variable(expected_remaining_buffer)
};
Ok(ExchangeResult::Simple(SQuery::new(block, q_window)))
} else {
Err(ExchangeError::IncorrectQuerySizeOrMoreBytes)
}
@ -311,9 +331,11 @@ impl<'a> Exchange<'a> {
.map_err(|_| ExchangeError::NotAsciiByteOrOverflow)?;
if pipe_s.packet_s.flag() & self.scanner.remaining_size_is(pipe_s.packet_s.int()) {
// great, we have the entire packet
Ok(ExchangeResult::Pipeline(Pipeline::new(
self.scanner.current_buffer(),
)))
let block = unsafe {
// UNSAFE(@ohsayan): just verified earlier
self.scanner.next_chunk_variable(pipe_s.packet_s.int())
};
Ok(ExchangeResult::Pipeline(Pipeline::new(block)))
} else {
Ok(ExchangeResult::NewState(ExchangeState::Pipeline(pipe_s)))
}

@ -31,35 +31,18 @@ use {
},
crate::engine::{
mem::BufferedScanner,
net::protocol::handshake::{
net::protocol::{
exchange::{SQState, Usize},
handshake::{
AuthMode, CHandshake, CHandshakeAuth, CHandshakeStatic, DataExchangeMode,
HandshakeResult, HandshakeState, HandshakeVersion, ProtocolVersion, QueryMode,
},
SQuery,
},
},
std::ops::Range,
};
pub(super) fn create_simple_query<const N: usize>(query: &str, params: [&str; N]) -> Vec<u8> {
let mut buf = vec![];
let query_size_as_string = query.len().to_string();
let total_packet_size = query.len()
+ params.iter().map(|l| l.len()).sum::<usize>()
+ query_size_as_string.len()
+ 1;
// segment 1
buf.push(b'S');
buf.extend(total_packet_size.to_string().as_bytes());
buf.push(b'\n');
// segment
buf.extend(query_size_as_string.as_bytes());
buf.push(b'\n');
// dataframe
buf.extend(query.as_bytes());
params
.into_iter()
.for_each(|param| buf.extend(param.as_bytes()));
buf
}
/*
client handshake
*/
@ -278,14 +261,15 @@ fn hs_bad_auth_mode() {
QT-DEX
*/
fn iterate_payload(payload: &str, start: usize, f: impl Fn(usize, &[u8])) {
for i in start..payload.len() {
f(i, &payload.as_bytes()[..i])
fn iterate_payload(payload: impl AsRef<[u8]>, start: usize, f: impl Fn(usize, &[u8])) {
let payload = payload.as_ref();
for i in start..=payload.len() {
f(i, &payload[..i])
}
}
fn iterate_exchange_payload(
payload: &str,
payload: impl AsRef<[u8]>,
start: usize,
f: impl Fn(usize, Result<(ExchangeResult, usize), ExchangeError>),
) {
@ -296,7 +280,7 @@ fn iterate_exchange_payload(
}
fn iterate_exchange_payload_from_zero(
payload: &str,
payload: impl AsRef<[u8]>,
f: impl Fn(usize, Result<(ExchangeResult, usize), ExchangeError>),
) {
iterate_exchange_payload(payload, 0, f)
@ -309,7 +293,6 @@ fn iterate_exchange_payload_from_zero(
#[test]
fn zero_sized_packet() {
for payload in [
"S\n", // zero packet
"S0\n", // zero packet
"S2\n0\n", // zero query
"S1\n\n", // zero query
@ -317,7 +300,11 @@ fn zero_sized_packet() {
iterate_exchange_payload_from_zero(payload, |size, result| {
if size == payload.len() {
// we got the full payload
if payload.len() == 3 {
assert_eq!(result, Err(ExchangeError::UnterminatedInteger))
} else {
assert_eq!(result, Err(ExchangeError::IncorrectQuerySizeOrMoreBytes))
}
} else {
// we don't have the full payload
if size < 3 {
@ -356,3 +343,191 @@ fn invalid_first_byte() {
})
}
}
pub struct EQuery {
// payload
payload: String,
variable_range: Range<usize>,
// query
query: String,
query_range: Range<usize>,
// params
params: &'static [&'static str],
param_range: Range<usize>,
param_indices: Vec<Range<usize>>,
}
impl EQuery {
fn new(query: String, params: &'static [&'static str]) -> Self {
var!(let variable_start, variable_end, query_start, query_end, param_start);
/*
prepare the "back" of the payload
*/
let encoded_params: String = params.iter().flat_map(|param| param.chars()).collect();
let total_size = query.len() + encoded_params.len();
let total_size_string = format!("{total_size}\n");
/*
compute offsets
*/
let packet_size = total_size_string.len() + total_size;
let mut buffer = String::new();
buffer.push('S');
buffer.push_str(&format!("{packet_size}\n"));
// record start of variable block
variable_start = buffer.len();
buffer.push_str(&query.len().to_string());
buffer.push('\n');
// record start of query
query_start = buffer.len();
buffer.push_str(&query);
query_end = buffer.len();
// record start of params
param_start = buffer.len();
let mut param_indices = Vec::new();
for param in params {
let start = buffer.len();
buffer.push_str(param);
param_indices.push(start..buffer.len());
}
variable_end = buffer.len();
Self {
payload: buffer,
variable_range: variable_start..variable_end,
query,
query_range: query_start..query_end,
params,
param_range: param_start..variable_end,
param_indices,
}
}
}
#[test]
fn ext_query() {
let ext_query = EQuery::new("create space myspace".to_owned(), &["sayan", "pass", ""]);
let query_starts_at = ext_query.payload[ext_query.variable_range.clone()]
.find('\n')
.unwrap()
+ 1;
assert_eq!(
&ext_query.payload[ext_query.variable_range.clone()]
[query_starts_at..query_starts_at + ext_query.query.len()],
ext_query.query
);
assert_eq!(ext_query.query, &ext_query.payload[ext_query.query_range]);
assert_eq!("sayanpass", &ext_query.payload[ext_query.param_range]);
for (param_indices, real_param) in ext_query.param_indices.iter().zip(ext_query.params) {
assert_eq!(*real_param, &ext_query.payload[param_indices.clone()]);
}
}
/*
simple queries
*/
const fn dig_count(real: usize) -> usize {
// count the number of digits
let mut dig_count = 0;
let mut real_ = real;
while real_ != 0 {
dig_count += 1;
real_ /= 10;
}
// account for a `0`
dig_count += (real == 0) as usize;
dig_count
}
const fn nth_position_value(mut real: usize, mut pos: usize) -> usize {
let digits = dig_count(real);
while digits != pos {
real /= 10;
pos += 1;
}
real
}
#[test]
fn simple_query() {
for query in [
// small query without params
EQuery::new("small query".to_owned(), &[]),
// small query with params
EQuery::new("small query".to_owned(), &["hello", "world"]),
// giant query without params
EQuery::new(
"abcdefghijklmnopqrstuvwxyz 123456789 ABCDEFGHIJKLMNOPQRSTUVWXYZ".repeat(1000),
&[],
),
// giant query with params
EQuery::new(
"abcdefghijklmnopqrstuvwxyz 123456789 ABCDEFGHIJKLMNOPQRSTUVWXYZ".repeat(1000),
&["hello", "world"],
),
] {
iterate_exchange_payload_from_zero(query.payload.as_bytes(), |read_position, result| {
/*
S<packet size>\n<query window>\n<query><param>
^ variable ^query ^param
range start start start
- if before (variable range start - 1) then depending on the position from the first byte we will have, say the query size is 123
then we will have wrt distance from first byte (i.e position - 1) [1], [12], [123]
- if at (variable range start - 1) then we will have the exact size at [123] and in completed state
- if >= query start, then we will continue to issue changes of state until we have the full size which will be caught in a different branch
*/
if read_position < 3 {
// didn't reach minimum threshold
assert_eq!(
result,
Ok((ExchangeResult::NewState(ExchangeState::Initial), 0))
)
} else if read_position <= query.variable_range.start - 1 {
let index = read_position - 1;
assert_eq!(
result,
Ok((
ExchangeResult::NewState(ExchangeState::Simple(SQState::_new(
Usize::new_unflagged(nth_position_value(
query.variable_range.len(),
index
))
))),
read_position
))
)
} else if read_position >= query.variable_range.start {
if read_position == query.payload.len() {
let (result, cursor) = result.unwrap();
assert_eq!(cursor, query.payload.len());
assert_eq!(
result,
ExchangeResult::Simple(SQuery::_new(
query.payload[query.query_range.start..].as_bytes(),
query.query_range.len()
))
);
} else {
assert_eq!(
result,
Ok((
ExchangeResult::NewState(ExchangeState::Simple(SQState::_new(
Usize::new_flagged(query.variable_range.len())
))),
query.variable_range.start // the cursor will not go ahead until the full query is read
))
)
}
} else {
unreachable!()
}
})
}
}

@ -40,6 +40,7 @@ use {
const INVALID_SYNTAX_ERR: u16 = QueryError::QLInvalidSyntax.value_u8() as u16;
const EXPECTED_STATEMENT_ERR: u16 = QueryError::QLExpectedStatement.value_u8() as u16;
const UNKNOWN_STMT_ERR: u16 = QueryError::QLUnknownStatement.value_u8() as u16;
const ILLEGAL_PACKET: u16 = QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16;
#[dbtest]
fn deny_unknown_tokens() {
@ -48,13 +49,19 @@ fn deny_unknown_tokens() {
"model", "space", "where", "force", "into", "from", "with", "set", "add", "remove", "*",
",", "",
] {
let result = db.query_parse::<()>(&query!(token));
if token.is_empty() {
// the server will reject empty queries
assert_err_eq!(result, Error::ServerError(ILLEGAL_PACKET), "{token}")
} else {
assert_err_eq!(
db.query_parse::<()>(&query!(token)),
result,
Error::ServerError(EXPECTED_STATEMENT_ERR),
"{token}",
);
}
}
}
#[dbtest(username = "root", password = "")]
fn ensure_empty_password_returns_hs_error_5() {

@ -132,81 +132,6 @@ macro_rules! assert_hmeq {
};
}
#[macro_export]
/// ## The action macro
///
/// A macro for adding all the _fuss_ to an action. Implementing actions should be simple
/// and should not require us to repeatedly specify generic paramters and/or trait bounds.
/// This is exactly what this macro does: does all the _magic_ behind the scenes for you,
/// including adding generic parameters, handling docs (if any), adding the correct
/// trait bounds and finally making your function async. Rest knowing that all your
/// action requirements have been happily addressed with this macro and that you don't have
/// to write a lot of code to do the exact same thing
///
///
/// ## Limitations
///
/// This macro can only handle mutable parameters for a fixed number of arguments (three)
///
macro_rules! action {
(
$($(#[$attr:meta])*
fn $fname:ident($($argname:ident: $argty:ty),* $(,)?)
$block:block)*
) => {
$($(#[$attr])*
pub async fn $fname<
'a,
C: 'a + $crate::dbnet::BufferedSocketStream,
P: $crate::protocol::interface::ProtocolSpec,
> (
$($argname: $argty,)*
) -> $crate::actions::ActionResult<()>
$block)*
};
(
$($(#[$attr:meta])*
fn $fname:ident(
$argone:ident: $argonety:ty,
$argtwo:ident: $argtwoty:ty,
mut $argthree:ident: $argthreety:ty $(,)?
) $block:block)*
) => {
$($(#[$attr])*
pub async fn $fname<
'a,
C: 'a + $crate::dbnet::BufferedSocketStream,
P: $crate::protocol::interface::ProtocolSpec,
>(
$argone: $argonety,
$argtwo: $argtwoty,
mut $argthree: $argthreety
) -> $crate::actions::ActionResult<()>
$block)*
};
(
$($(#[$attr:meta])*
fn $fname:ident(
$argone:ident: $argonety:ty,
$argtwo:ident: $argtwoty:ty,
$argthree:ident: $argthreety:ty
) $block:block)*
) => {
$($(#[$attr])*
pub async fn $fname<
'a,
T: 'a + $crate::dbnet::connection::ClientConnection<P, Strm>,
Strm: $crate::dbnet::connection::Stream,
P: $crate::protocol::interface::ProtocolSpec
>(
$argone: $argonety,
$argtwo: $argtwoty,
$argthree: $argthreety
) -> $crate::actions::ActionResult<()>
$block)*
};
}
#[macro_export]
macro_rules! byt {
($f:expr) => {

Loading…
Cancel
Save