From 6414b05fa66f57d6bee5be3715bb57724cfe8f27 Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Thu, 4 Apr 2024 10:35:22 +0530 Subject: [PATCH] net: Fix protocol impl and improve protocol testing coverage --- server/src/engine/net/protocol/exchange.rs | 42 ++- server/src/engine/net/protocol/tests.rs | 239 +++++++++++++++--- .../src/engine/tests/client_misc/sec/mod.rs | 17 +- server/src/util/macros.rs | 75 ------ 4 files changed, 251 insertions(+), 122 deletions(-) diff --git a/server/src/engine/net/protocol/exchange.rs b/server/src/engine/net/protocol/exchange.rs index 9f510949..98cd9cb9 100644 --- a/server/src/engine/net/protocol/exchange.rs +++ b/server/src/engine/net/protocol/exchange.rs @@ -64,6 +64,10 @@ impl<'a> SQuery<'a> { fn new(buf: &'a [u8], q_window: usize) -> Self { Self { buf, q_window } } + #[cfg(test)] + pub(super) fn _new(buf: &'a [u8], q_window: usize) -> Self { + Self::new(buf, q_window) + } pub fn query(&self) -> &[u8] { &self.buf[..self.q_window] } @@ -107,7 +111,7 @@ fn scan_usize_guaranteed_termination( } #[derive(Clone, Copy, PartialEq)] -struct Usize { +pub(super) struct Usize { v: isize, } @@ -119,9 +123,13 @@ impl Usize { Self { v } } #[inline(always)] - const fn new_unflagged(int: usize) -> Self { + pub(super) const fn new_unflagged(int: usize) -> Self { Self::new(int as isize) } + #[cfg(test)] + pub(super) const fn new_flagged(int: usize) -> Self { + Self::new(int as isize | Self::MASK) + } #[inline(always)] fn int(&self) -> usize { (self.v & !Self::MASK) as usize @@ -199,6 +207,10 @@ impl SQState { const fn new(packet_s: Usize) -> Self { Self { packet_s } } + #[cfg(test)] + pub(super) const fn _new(s: Usize) -> Self { + Self::new(s) + } } #[derive(Debug, PartialEq)] @@ -288,14 +300,22 @@ impl<'a> Exchange<'a> { .map_err(|_| ExchangeError::NotAsciiByteOrOverflow)?; if sq_state.packet_s.flag() & self.scanner.has_left(sq_state.packet_s.int()) { // we have the full packet size and the required data + // scan the query window + let start = self.scanner.cursor(); let q_window = scan_usize_guaranteed_termination(&mut self.scanner)?; + let stop = self.scanner.cursor(); + // now compute remaining buffer length and nonzero condition + let expected_remaining_buffer = sq_state.packet_s.int() - (stop - start); let nonzero = (q_window != 0) & (sq_state.packet_s.int() != 0); - if compiler::likely(self.scanner.remaining_size_is(sq_state.packet_s.int()) & nonzero) { + // validate and return + if compiler::likely(self.scanner.remaining_size_is(expected_remaining_buffer) & nonzero) + { // this check is important because the client might have given us an incorrect q size - Ok(ExchangeResult::Simple(SQuery::new( - self.scanner.current_buffer(), - q_window, - ))) + let block = unsafe { + // UNSAFE(@ohsayan): just verified earlier + self.scanner.next_chunk_variable(expected_remaining_buffer) + }; + Ok(ExchangeResult::Simple(SQuery::new(block, q_window))) } else { Err(ExchangeError::IncorrectQuerySizeOrMoreBytes) } @@ -311,9 +331,11 @@ impl<'a> Exchange<'a> { .map_err(|_| ExchangeError::NotAsciiByteOrOverflow)?; if pipe_s.packet_s.flag() & self.scanner.remaining_size_is(pipe_s.packet_s.int()) { // great, we have the entire packet - Ok(ExchangeResult::Pipeline(Pipeline::new( - self.scanner.current_buffer(), - ))) + let block = unsafe { + // UNSAFE(@ohsayan): just verified earlier + self.scanner.next_chunk_variable(pipe_s.packet_s.int()) + }; + Ok(ExchangeResult::Pipeline(Pipeline::new(block))) } else { Ok(ExchangeResult::NewState(ExchangeState::Pipeline(pipe_s))) } diff --git a/server/src/engine/net/protocol/tests.rs b/server/src/engine/net/protocol/tests.rs index bd9d2cf9..6c04bc4b 100644 --- a/server/src/engine/net/protocol/tests.rs +++ b/server/src/engine/net/protocol/tests.rs @@ -31,35 +31,18 @@ use { }, crate::engine::{ mem::BufferedScanner, - net::protocol::handshake::{ - AuthMode, CHandshake, CHandshakeAuth, CHandshakeStatic, DataExchangeMode, - HandshakeResult, HandshakeState, HandshakeVersion, ProtocolVersion, QueryMode, + net::protocol::{ + exchange::{SQState, Usize}, + handshake::{ + AuthMode, CHandshake, CHandshakeAuth, CHandshakeStatic, DataExchangeMode, + HandshakeResult, HandshakeState, HandshakeVersion, ProtocolVersion, QueryMode, + }, + SQuery, }, }, + std::ops::Range, }; -pub(super) fn create_simple_query(query: &str, params: [&str; N]) -> Vec { - let mut buf = vec![]; - let query_size_as_string = query.len().to_string(); - let total_packet_size = query.len() - + params.iter().map(|l| l.len()).sum::() - + query_size_as_string.len() - + 1; - // segment 1 - buf.push(b'S'); - buf.extend(total_packet_size.to_string().as_bytes()); - buf.push(b'\n'); - // segment - buf.extend(query_size_as_string.as_bytes()); - buf.push(b'\n'); - // dataframe - buf.extend(query.as_bytes()); - params - .into_iter() - .for_each(|param| buf.extend(param.as_bytes())); - buf -} - /* client handshake */ @@ -278,14 +261,15 @@ fn hs_bad_auth_mode() { QT-DEX */ -fn iterate_payload(payload: &str, start: usize, f: impl Fn(usize, &[u8])) { - for i in start..payload.len() { - f(i, &payload.as_bytes()[..i]) +fn iterate_payload(payload: impl AsRef<[u8]>, start: usize, f: impl Fn(usize, &[u8])) { + let payload = payload.as_ref(); + for i in start..=payload.len() { + f(i, &payload[..i]) } } fn iterate_exchange_payload( - payload: &str, + payload: impl AsRef<[u8]>, start: usize, f: impl Fn(usize, Result<(ExchangeResult, usize), ExchangeError>), ) { @@ -296,7 +280,7 @@ fn iterate_exchange_payload( } fn iterate_exchange_payload_from_zero( - payload: &str, + payload: impl AsRef<[u8]>, f: impl Fn(usize, Result<(ExchangeResult, usize), ExchangeError>), ) { iterate_exchange_payload(payload, 0, f) @@ -309,7 +293,6 @@ fn iterate_exchange_payload_from_zero( #[test] fn zero_sized_packet() { for payload in [ - "S\n", // zero packet "S0\n", // zero packet "S2\n0\n", // zero query "S1\n\n", // zero query @@ -317,7 +300,11 @@ fn zero_sized_packet() { iterate_exchange_payload_from_zero(payload, |size, result| { if size == payload.len() { // we got the full payload - assert_eq!(result, Err(ExchangeError::IncorrectQuerySizeOrMoreBytes)) + if payload.len() == 3 { + assert_eq!(result, Err(ExchangeError::UnterminatedInteger)) + } else { + assert_eq!(result, Err(ExchangeError::IncorrectQuerySizeOrMoreBytes)) + } } else { // we don't have the full payload if size < 3 { @@ -356,3 +343,191 @@ fn invalid_first_byte() { }) } } + +pub struct EQuery { + // payload + payload: String, + variable_range: Range, + // query + query: String, + query_range: Range, + // params + params: &'static [&'static str], + param_range: Range, + param_indices: Vec>, +} + +impl EQuery { + fn new(query: String, params: &'static [&'static str]) -> Self { + var!(let variable_start, variable_end, query_start, query_end, param_start); + /* + prepare the "back" of the payload + */ + let encoded_params: String = params.iter().flat_map(|param| param.chars()).collect(); + let total_size = query.len() + encoded_params.len(); + let total_size_string = format!("{total_size}\n"); + + /* + compute offsets + */ + + let packet_size = total_size_string.len() + total_size; + let mut buffer = String::new(); + buffer.push('S'); + buffer.push_str(&format!("{packet_size}\n")); + + // record start of variable block + variable_start = buffer.len(); + + buffer.push_str(&query.len().to_string()); + buffer.push('\n'); + + // record start of query + query_start = buffer.len(); + buffer.push_str(&query); + query_end = buffer.len(); + + // record start of params + param_start = buffer.len(); + let mut param_indices = Vec::new(); + for param in params { + let start = buffer.len(); + buffer.push_str(param); + param_indices.push(start..buffer.len()); + } + + variable_end = buffer.len(); + Self { + payload: buffer, + variable_range: variable_start..variable_end, + query, + query_range: query_start..query_end, + params, + param_range: param_start..variable_end, + param_indices, + } + } +} + +#[test] +fn ext_query() { + let ext_query = EQuery::new("create space myspace".to_owned(), &["sayan", "pass", ""]); + let query_starts_at = ext_query.payload[ext_query.variable_range.clone()] + .find('\n') + .unwrap() + + 1; + assert_eq!( + &ext_query.payload[ext_query.variable_range.clone()] + [query_starts_at..query_starts_at + ext_query.query.len()], + ext_query.query + ); + assert_eq!(ext_query.query, &ext_query.payload[ext_query.query_range]); + assert_eq!("sayanpass", &ext_query.payload[ext_query.param_range]); + for (param_indices, real_param) in ext_query.param_indices.iter().zip(ext_query.params) { + assert_eq!(*real_param, &ext_query.payload[param_indices.clone()]); + } +} + +/* + simple queries +*/ + +const fn dig_count(real: usize) -> usize { + // count the number of digits + let mut dig_count = 0; + let mut real_ = real; + while real_ != 0 { + dig_count += 1; + real_ /= 10; + } + // account for a `0` + dig_count += (real == 0) as usize; + dig_count +} + +const fn nth_position_value(mut real: usize, mut pos: usize) -> usize { + let digits = dig_count(real); + while digits != pos { + real /= 10; + pos += 1; + } + real +} + +#[test] +fn simple_query() { + for query in [ + // small query without params + EQuery::new("small query".to_owned(), &[]), + // small query with params + EQuery::new("small query".to_owned(), &["hello", "world"]), + // giant query without params + EQuery::new( + "abcdefghijklmnopqrstuvwxyz 123456789 ABCDEFGHIJKLMNOPQRSTUVWXYZ".repeat(1000), + &[], + ), + // giant query with params + EQuery::new( + "abcdefghijklmnopqrstuvwxyz 123456789 ABCDEFGHIJKLMNOPQRSTUVWXYZ".repeat(1000), + &["hello", "world"], + ), + ] { + iterate_exchange_payload_from_zero(query.payload.as_bytes(), |read_position, result| { + /* + S\n\n + ^ variable ^query ^param + range start start start + + - if before (variable range start - 1) then depending on the position from the first byte we will have, say the query size is 123 + then we will have wrt distance from first byte (i.e position - 1) [1], [12], [123] + - if at (variable range start - 1) then we will have the exact size at [123] and in completed state + - if >= query start, then we will continue to issue changes of state until we have the full size which will be caught in a different branch + */ + if read_position < 3 { + // didn't reach minimum threshold + assert_eq!( + result, + Ok((ExchangeResult::NewState(ExchangeState::Initial), 0)) + ) + } else if read_position <= query.variable_range.start - 1 { + let index = read_position - 1; + assert_eq!( + result, + Ok(( + ExchangeResult::NewState(ExchangeState::Simple(SQState::_new( + Usize::new_unflagged(nth_position_value( + query.variable_range.len(), + index + )) + ))), + read_position + )) + ) + } else if read_position >= query.variable_range.start { + if read_position == query.payload.len() { + let (result, cursor) = result.unwrap(); + assert_eq!(cursor, query.payload.len()); + assert_eq!( + result, + ExchangeResult::Simple(SQuery::_new( + query.payload[query.query_range.start..].as_bytes(), + query.query_range.len() + )) + ); + } else { + assert_eq!( + result, + Ok(( + ExchangeResult::NewState(ExchangeState::Simple(SQState::_new( + Usize::new_flagged(query.variable_range.len()) + ))), + query.variable_range.start // the cursor will not go ahead until the full query is read + )) + ) + } + } else { + unreachable!() + } + }) + } +} diff --git a/server/src/engine/tests/client_misc/sec/mod.rs b/server/src/engine/tests/client_misc/sec/mod.rs index ff3df078..9c57c5c9 100644 --- a/server/src/engine/tests/client_misc/sec/mod.rs +++ b/server/src/engine/tests/client_misc/sec/mod.rs @@ -40,6 +40,7 @@ use { const INVALID_SYNTAX_ERR: u16 = QueryError::QLInvalidSyntax.value_u8() as u16; const EXPECTED_STATEMENT_ERR: u16 = QueryError::QLExpectedStatement.value_u8() as u16; const UNKNOWN_STMT_ERR: u16 = QueryError::QLUnknownStatement.value_u8() as u16; +const ILLEGAL_PACKET: u16 = QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16; #[dbtest] fn deny_unknown_tokens() { @@ -48,11 +49,17 @@ fn deny_unknown_tokens() { "model", "space", "where", "force", "into", "from", "with", "set", "add", "remove", "*", ",", "", ] { - assert_err_eq!( - db.query_parse::<()>(&query!(token)), - Error::ServerError(EXPECTED_STATEMENT_ERR), - "{token}", - ); + let result = db.query_parse::<()>(&query!(token)); + if token.is_empty() { + // the server will reject empty queries + assert_err_eq!(result, Error::ServerError(ILLEGAL_PACKET), "{token}") + } else { + assert_err_eq!( + result, + Error::ServerError(EXPECTED_STATEMENT_ERR), + "{token}", + ); + } } } diff --git a/server/src/util/macros.rs b/server/src/util/macros.rs index cc8b014a..aa72e22b 100644 --- a/server/src/util/macros.rs +++ b/server/src/util/macros.rs @@ -132,81 +132,6 @@ macro_rules! assert_hmeq { }; } -#[macro_export] -/// ## The action macro -/// -/// A macro for adding all the _fuss_ to an action. Implementing actions should be simple -/// and should not require us to repeatedly specify generic paramters and/or trait bounds. -/// This is exactly what this macro does: does all the _magic_ behind the scenes for you, -/// including adding generic parameters, handling docs (if any), adding the correct -/// trait bounds and finally making your function async. Rest knowing that all your -/// action requirements have been happily addressed with this macro and that you don't have -/// to write a lot of code to do the exact same thing -/// -/// -/// ## Limitations -/// -/// This macro can only handle mutable parameters for a fixed number of arguments (three) -/// -macro_rules! action { - ( - $($(#[$attr:meta])* - fn $fname:ident($($argname:ident: $argty:ty),* $(,)?) - $block:block)* - ) => { - $($(#[$attr])* - pub async fn $fname< - 'a, - C: 'a + $crate::dbnet::BufferedSocketStream, - P: $crate::protocol::interface::ProtocolSpec, - > ( - $($argname: $argty,)* - ) -> $crate::actions::ActionResult<()> - $block)* - }; - ( - $($(#[$attr:meta])* - fn $fname:ident( - $argone:ident: $argonety:ty, - $argtwo:ident: $argtwoty:ty, - mut $argthree:ident: $argthreety:ty $(,)? - ) $block:block)* - ) => { - $($(#[$attr])* - pub async fn $fname< - 'a, - C: 'a + $crate::dbnet::BufferedSocketStream, - P: $crate::protocol::interface::ProtocolSpec, - >( - $argone: $argonety, - $argtwo: $argtwoty, - mut $argthree: $argthreety - ) -> $crate::actions::ActionResult<()> - $block)* - }; - ( - $($(#[$attr:meta])* - fn $fname:ident( - $argone:ident: $argonety:ty, - $argtwo:ident: $argtwoty:ty, - $argthree:ident: $argthreety:ty - ) $block:block)* - ) => { - $($(#[$attr])* - pub async fn $fname< - 'a, - T: 'a + $crate::dbnet::connection::ClientConnection, - Strm: $crate::dbnet::connection::Stream, - P: $crate::protocol::interface::ProtocolSpec - >( - $argone: $argonety, - $argtwo: $argtwoty, - $argthree: $argthreety - ) -> $crate::actions::ActionResult<()> - $block)* - }; -} - #[macro_export] macro_rules! byt { ($f:expr) => {