diff --git a/.github/workflows/test-pr.yml b/.github/workflows/test-pr.yml index 6aebd0b6..1dc798c7 100644 --- a/.github/workflows/test-pr.yml +++ b/.github/workflows/test-pr.yml @@ -39,10 +39,6 @@ jobs: rustup update rustup default ${{ matrix.rust }} if: env.BUILD == 'true' - - name: Install perl modules - uses: perl-actions/install-with-cpanm@v1 - with: - install: "HTML::Entities" - name: Run Tests run: make test env: diff --git a/.github/workflows/test-push.yml b/.github/workflows/test-push.yml index e794039d..2684d037 100644 --- a/.github/workflows/test-push.yml +++ b/.github/workflows/test-push.yml @@ -43,10 +43,6 @@ jobs: brew install gnu-tar echo "/usr/local/opt/gnu-tar/libexec/gnubin" >> $GITHUB_PATH if: runner.os == 'macOS' - - name: Install perl modules - uses: perl-actions/install-with-cpanm@v1 - with: - install: "HTML::Entities" - name: Setup environment run: | chmod +x ci/setvars.sh @@ -93,10 +89,6 @@ jobs: run: | chmod +x ci/setvars.sh ci/setvars.sh - - name: Install perl modules - uses: perl-actions/install-with-cpanm@v1 - with: - install: "HTML::Entities" - name: Restore cache uses: actions/cache@v3 @@ -133,10 +125,6 @@ jobs: uses: actions/checkout@v2 with: fetch-depth: 2 - - name: Install perl modules - uses: perl-actions/install-with-cpanm@v1 - with: - install: "HTML::Entities" - name: Setup environment run: | diff --git a/server/src/engine/mem/scanner.rs b/server/src/engine/mem/scanner.rs index 48332b64..5c3b925c 100644 --- a/server/src/engine/mem/scanner.rs +++ b/server/src/engine/mem/scanner.rs @@ -173,6 +173,7 @@ impl<'a, T> Scanner<'a, T> { } impl<'a> Scanner<'a, u8> { + #[cfg(test)] /// Attempt to parse the next byte pub fn try_next_byte(&mut self) -> Option { if self.eof() { diff --git a/server/src/engine/net/protocol/exchange.rs b/server/src/engine/net/protocol/exchange.rs index 1a1a1547..388ac87b 100644 --- a/server/src/engine/net/protocol/exchange.rs +++ b/server/src/engine/net/protocol/exchange.rs @@ -26,123 +26,27 @@ use crate::engine::mem::BufferedScanner; -pub const EXCHANGE_MIN_SIZE: usize = b"S1\nh".len(); -pub(super) const STATE_READ_INITIAL: QueryTimeExchangeResult<'static> = - QueryTimeExchangeResult::ChangeState { - new_state: QueryTimeExchangeState::Initial, - expect_more: EXCHANGE_MIN_SIZE, - }; -pub(super) const STATE_ERROR: QueryTimeExchangeResult<'static> = QueryTimeExchangeResult::Error; - -#[derive(Debug, PartialEq)] -/// State of a query time exchange -pub enum QueryTimeExchangeState { - /// beginning of exchange - Initial, - /// SQ (part of packet size) - SQ1Meta1Partial { packet_size_part: u64 }, - /// SQ (part of Q window) - SQ2Meta2Partial { - size_of_static_frame: usize, - packet_size: usize, - q_window_part: u64, - }, - /// SQ waiting for block - SQ3FinalizeWaitingForBlock { - dataframe_size: usize, - q_window: usize, - }, -} - -impl Default for QueryTimeExchangeState { - fn default() -> Self { - Self::Initial - } -} - -#[derive(Debug, PartialEq)] -/// Result after attempting to complete (or terminate) a query time exchange -pub enum QueryTimeExchangeResult<'a> { - /// We completed the exchange and yielded a [`SQuery`] - SQCompleted(SQuery<'a>), - /// We're changing states - ChangeState { - new_state: QueryTimeExchangeState, - expect_more: usize, - }, - /// We hit an error and need to terminate this exchange - Error, -} - -/// Resume a query time exchange -pub fn resume<'a>( - scanner: &mut BufferedScanner<'a>, - state: QueryTimeExchangeState, -) -> QueryTimeExchangeResult<'a> { - match state { - QueryTimeExchangeState::Initial => SQuery::resume_initial(scanner), - QueryTimeExchangeState::SQ1Meta1Partial { packet_size_part } => { - SQuery::resume_at_sq1_meta1_partial(scanner, packet_size_part) - } - QueryTimeExchangeState::SQ2Meta2Partial { - packet_size, - q_window_part, - size_of_static_frame, - } => SQuery::resume_at_sq2_meta2_partial( - scanner, - size_of_static_frame, - packet_size, - q_window_part, - ), - QueryTimeExchangeState::SQ3FinalizeWaitingForBlock { - dataframe_size, - q_window, - } => SQuery::resume_at_final(scanner, q_window, dataframe_size), - } +pub(super) unsafe fn resume<'a>( + buf: &'a [u8], + last_cursor: usize, + last_state: QExchangeState, +) -> (usize, QExchangeResult<'a>) { + let mut scanner = BufferedScanner::new_with_cursor(buf, last_cursor); + let ret = last_state.resume(&mut scanner); + (scanner.cursor(), ret) } /* SQ */ -enum LFTIntParseResult { +#[derive(Debug, PartialEq)] +pub(super) enum LFTIntParseResult { Value(u64), Partial(u64), Error, } -fn parse_lf_separated( - scanner: &mut BufferedScanner, - previously_buffered: u64, -) -> LFTIntParseResult { - let mut ret = previously_buffered; - let mut okay = true; - while scanner.rounded_cursor_not_eof_matches(|b| *b != b'\n') & okay { - let b = unsafe { scanner.next_byte() }; - okay &= b.is_ascii_digit(); - ret = match ret.checked_mul(10) { - Some(r) => r, - None => return LFTIntParseResult::Error, - }; - ret = match ret.checked_add((b & 0x0F) as u64) { - Some(r) => r, - None => return LFTIntParseResult::Error, - }; - } - let payload_ok = okay; - let lf_ok = scanner.rounded_cursor_not_eof_matches(|b| *b == b'\n'); - unsafe { scanner.incr_cursor_if(lf_ok) } - if payload_ok & lf_ok { - LFTIntParseResult::Value(ret) - } else { - if payload_ok { - LFTIntParseResult::Partial(ret) - } else { - LFTIntParseResult::Error - } - } -} - #[derive(Debug, PartialEq)] pub struct SQuery<'a> { q: &'a [u8], @@ -166,155 +70,204 @@ impl<'a> SQuery<'a> { &self.payload()[self.q_window()..] } #[cfg(test)] - pub fn query_str(&self) -> Option<&'a str> { - core::str::from_utf8(self.query()).ok() + pub fn query_str(&self) -> &str { + core::str::from_utf8(self.query()).unwrap() } #[cfg(test)] - pub fn params_str(&self) -> Option<&'a str> { - core::str::from_utf8(self.params()).ok() + pub fn params_str(&self) -> &str { + core::str::from_utf8(self.params()).unwrap() } } -impl<'a> SQuery<'a> { - /// We're touching this packet for the first time - fn resume_initial(scanner: &mut BufferedScanner<'a>) -> QueryTimeExchangeResult<'a> { - if cfg!(debug_assertions) { - if !scanner.has_left(EXCHANGE_MIN_SIZE) { - return STATE_READ_INITIAL; +/* + utils +*/ + +/// scan an integer: +/// - if just an LF: +/// - if disallowed single byte: return an error +/// - else, return value +/// - if no LF: return upto limit +/// - if LF: return value +pub(super) fn scanint( + scanner: &mut BufferedScanner, + first_run: bool, + prev: u64, +) -> LFTIntParseResult { + let mut current = prev; + // guard a case where the buffer might be empty and can potentially have invalid chars + let mut okay = !((scanner.rounded_cursor_value() == b'\n') & first_run); + while scanner.rounded_cursor_not_eof_matches(|b| b'\n'.ne(b)) & okay { + let byte = unsafe { scanner.next_byte() }; + okay &= byte.is_ascii_digit(); + match current + .checked_mul(10) + .map(|new| new.checked_add((byte & 0x0f) as u64)) + { + Some(Some(int)) => { + current = int; } - } else { - assert!(scanner.has_left(EXCHANGE_MIN_SIZE)); - } - // attempt to read atleast one byte - if cfg!(debug_assertions) { - match scanner.try_next_byte() { - Some(b'S') => {} - Some(_) => return STATE_ERROR, - None => return STATE_READ_INITIAL, + _ => { + okay = false; } + } + } + let lf = scanner.rounded_cursor_not_eof_equals(b'\n'); + unsafe { + // UNSAFE(@ohsayan): within buffer range + scanner.incr_cursor_if(lf); + } + if lf & okay { + LFTIntParseResult::Value(current) + } else { + if okay { + LFTIntParseResult::Partial(current) } else { - match unsafe { scanner.next_byte() } { - b'S' => {} - _ => return STATE_ERROR, - } + LFTIntParseResult::Error } - Self::resume_at_sq1_meta1_partial(scanner, 0) } - /// We found some part of meta1, and need to resume - fn resume_at_sq1_meta1_partial( - scanner: &mut BufferedScanner<'a>, - prev: u64, - ) -> QueryTimeExchangeResult<'a> { - match parse_lf_separated(scanner, prev) { - LFTIntParseResult::Value(packet_size) => { - // we got the packet size; can we get the q window? - Self::resume_at_sq2_meta2_partial( - scanner, - scanner.cursor(), - packet_size as usize, - 0, - ) - } - LFTIntParseResult::Partial(partial_packet_size) => { - // we couldn't get the packet size - QueryTimeExchangeResult::ChangeState { - new_state: QueryTimeExchangeState::SQ1Meta1Partial { - packet_size_part: partial_packet_size, - }, - expect_more: 3, // 1LF + 1ASCII + 1LF - } - } - LFTIntParseResult::Error => STATE_ERROR, +} + +#[derive(Debug, PartialEq, Clone, Copy)] +pub(super) enum QExchangeStateInternal { + Initial, + PendingMeta1, + PendingMeta2, + PendingData, +} + +impl Default for QExchangeStateInternal { + fn default() -> Self { + Self::Initial + } +} + +#[derive(Debug, PartialEq)] +pub(super) struct QExchangeState { + state: QExchangeStateInternal, + target: usize, + md_packet_size: u64, + md_q_window: u64, +} + +impl Default for QExchangeState { + fn default() -> Self { + Self::new() + } +} + +#[derive(Debug, PartialEq)] +/// Result after attempting to complete (or terminate) a query time exchange +pub(super) enum QExchangeResult<'a> { + /// We completed the exchange and yielded a [`SQuery`] + SQCompleted(SQuery<'a>), + /// We're changing states + ChangeState(QExchangeState), + /// We hit an error and need to terminate this exchange + Error, +} + +impl QExchangeState { + fn _new( + state: QExchangeStateInternal, + target: usize, + md_packet_size: u64, + md_q_window: u64, + ) -> Self { + Self { + state, + target, + md_packet_size, + md_q_window, } } - /// We found some part of meta2, and need to resume - fn resume_at_sq2_meta2_partial( + #[cfg(test)] + pub(super) fn new_test( + state: QExchangeStateInternal, + target: usize, + md_packet_size: u64, + md_q_window: u64, + ) -> Self { + Self::_new(state, target, md_packet_size, md_q_window) + } +} + +impl QExchangeState { + pub const MIN_READ: usize = b"S\x00\n\x00\n".len(); + pub fn new() -> Self { + Self::_new(QExchangeStateInternal::Initial, Self::MIN_READ, 0, 0) + } + pub fn has_reached_target(&self, new_buffer: &[u8]) -> bool { + new_buffer.len() >= self.target + } + fn resume<'a>(self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> { + debug_assert!(scanner.has_left(Self::MIN_READ)); + match self.state { + QExchangeStateInternal::Initial => self.start_initial(scanner), + QExchangeStateInternal::PendingMeta1 => self.resume_at_md1(scanner, false), + QExchangeStateInternal::PendingMeta2 => self.resume_at_md2(scanner, false), + QExchangeStateInternal::PendingData => self.resume_data(scanner), + } + } + fn start_initial<'a>(self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> { + if unsafe { scanner.next_byte() } != b'S' { + // has to be a simple query! + return QExchangeResult::Error; + } + self.resume_at_md1(scanner, true) + } + fn resume_at_md1<'a>( + mut self, scanner: &mut BufferedScanner<'a>, - static_size: usize, - packet_size: usize, - prev_qw_buffered: u64, - ) -> QueryTimeExchangeResult<'a> { - let start = scanner.cursor(); - match parse_lf_separated(scanner, prev_qw_buffered) { - LFTIntParseResult::Value(q_window) => { - // we got the q window; can we complete the exchange? - let df_size = Self::compute_df_size(scanner, static_size, packet_size); - if df_size == 0 { - return QueryTimeExchangeResult::Error; - } - Self::resume_at_final(scanner, q_window as usize, df_size) - } - LFTIntParseResult::Partial(q_window_partial) => { - // not enough bytes for getting Q window - QueryTimeExchangeResult::ChangeState { - new_state: QueryTimeExchangeState::SQ2Meta2Partial { - packet_size, - q_window_part: q_window_partial, - size_of_static_frame: static_size, - }, - // we passed cursor - start bytes out of the packet, so expect this more - expect_more: packet_size - (scanner.cursor() - start), - } + first_run: bool, + ) -> QExchangeResult<'a> { + let packet_size = match scanint(scanner, first_run, self.md_packet_size) { + LFTIntParseResult::Value(v) => v, + LFTIntParseResult::Partial(p) => { + // if this is the first run, we read 5 bytes and need atleast one more; if this is a resume we read one or more bytes and + // need atleast one more + self.target += 1; + self.md_packet_size = p; + self.state = QExchangeStateInternal::PendingMeta1; + return QExchangeResult::ChangeState(self); } - LFTIntParseResult::Error => STATE_ERROR, - } + LFTIntParseResult::Error => return QExchangeResult::Error, + }; + self.md_packet_size = packet_size; + self.target = scanner.cursor() + packet_size as usize; + // hand over control to md2 + self.resume_at_md2(scanner, true) } - /// We got all our meta and need the dataframe - fn resume_at_final( + fn resume_at_md2<'a>( + mut self, scanner: &mut BufferedScanner<'a>, - q_window: usize, - dataframe_size: usize, - ) -> QueryTimeExchangeResult<'a> { - if scanner.has_left(dataframe_size) { - // we have sufficient bytes for the dataframe + first_run: bool, + ) -> QExchangeResult<'a> { + let q_window = match scanint(scanner, first_run, self.md_q_window) { + LFTIntParseResult::Value(v) => v, + LFTIntParseResult::Partial(p) => { + self.md_q_window = p; + self.state = QExchangeStateInternal::PendingMeta2; + return QExchangeResult::ChangeState(self); + } + LFTIntParseResult::Error => return QExchangeResult::Error, + }; + self.md_q_window = q_window; + // hand over control to data + self.resume_data(scanner) + } + fn resume_data<'a>(mut self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> { + let df_size = self.target - scanner.cursor(); + if scanner.remaining() == df_size { unsafe { - // UNSAFE(@ohsayan): +lenck - QueryTimeExchangeResult::SQCompleted(SQuery::new( - scanner.next_chunk_variable(dataframe_size), - q_window, + QExchangeResult::SQCompleted(SQuery::new( + scanner.next_chunk_variable(df_size), + self.md_q_window as usize, )) } } else { - // not enough bytes for the dataframe - QueryTimeExchangeResult::ChangeState { - new_state: QueryTimeExchangeState::SQ3FinalizeWaitingForBlock { - dataframe_size, - q_window, - }, - expect_more: Self::compute_df_remaining(scanner, dataframe_size), // dataframe - } + self.state = QExchangeStateInternal::PendingData; + QExchangeResult::ChangeState(self) } } } - -impl<'a> SQuery<'a> { - fn compute_df_size(scanner: &BufferedScanner, static_size: usize, packet_size: usize) -> usize { - (packet_size + static_size) - scanner.cursor() - } - fn compute_df_remaining(scanner: &BufferedScanner<'_>, df_size: usize) -> usize { - (scanner.cursor() + df_size) - scanner.buffer_len() - } -} - -#[cfg(test)] -pub(super) fn create_simple_query(query: &str, params: [&str; N]) -> Vec { - let mut buf = vec![]; - let query_size_as_string = query.len().to_string(); - let size_of_variable_section = query.len() - + params.iter().map(|l| l.len()).sum::() - + query_size_as_string.len() - + 1; - // segment 1 - buf.push(b'S'); - buf.extend(size_of_variable_section.to_string().as_bytes()); - buf.push(b'\n'); - // segment - buf.extend(query_size_as_string.as_bytes()); - buf.push(b'\n'); - // dataframe - buf.extend(query.as_bytes()); - params - .into_iter() - .for_each(|param| buf.extend(param.as_bytes())); - buf -} diff --git a/server/src/engine/net/protocol/mod.rs b/server/src/engine/net/protocol/mod.rs index c745086c..a8b58074 100644 --- a/server/src/engine/net/protocol/mod.rs +++ b/server/src/engine/net/protocol/mod.rs @@ -34,7 +34,7 @@ pub use exchange::SQuery; use { self::{ - exchange::{QueryTimeExchangeResult, QueryTimeExchangeState}, + exchange::{QExchangeResult, QExchangeState}, handshake::{ AuthMode, CHandshake, DataExchangeMode, HandshakeResult, HandshakeState, HandshakeVersion, ProtocolError, ProtocolVersion, QueryMode, @@ -140,42 +140,41 @@ pub(super) async fn query_loop( // done handshaking con.write_all(b"H\x00\x00\x00").await?; con.flush().await?; - let mut exchg_state = QueryTimeExchangeState::default(); - let mut expect_more = exchange::EXCHANGE_MIN_SIZE; + let mut state = QExchangeState::default(); let mut cursor = 0; loop { - let read_many = con.read_buf(buf).await?; - if read_many == 0 { + if con.read_buf(buf).await? == 0 { if buf.is_empty() { return Ok(QueryLoopResult::Fin); } else { return Ok(QueryLoopResult::Rst); } } - if read_many < expect_more { + if !state.has_reached_target(buf) { // we haven't buffered sufficient bytes; keep working continue; } - let mut buffer = unsafe { BufferedScanner::new_with_cursor(&buf, cursor) }; - let sq = match exchange::resume(&mut buffer, exchg_state) { - QueryTimeExchangeResult::ChangeState { - new_state, - expect_more: _more, - } => { - exchg_state = new_state; - expect_more = _more; - cursor = buffer.cursor(); + let sq = match unsafe { + // UNSAFE(@ohsayan): we store the cursor from the last run + exchange::resume(buf, cursor, state) + } { + (_, QExchangeResult::SQCompleted(sq)) => sq, + (new_cursor, QExchangeResult::ChangeState(new_state)) => { + cursor = new_cursor; + state = new_state; continue; } - QueryTimeExchangeResult::SQCompleted(sq) => sq, - QueryTimeExchangeResult::Error => { + (_, QExchangeResult::Error) => { + // respond with error let [a, b] = (QueryError::NetworkSubsystemCorruptedPacket.value_u8() as u16).to_le_bytes(); con.write_all(&[ResponseType::Error.value_u8(), a, b]) .await?; con.flush().await?; + // reset buffer, cursor and state buf.clear(); - exchg_state = QueryTimeExchangeState::default(); + cursor = 0; + state = QExchangeState::default(); continue; } }; @@ -198,8 +197,10 @@ pub(super) async fn query_loop( } } con.flush().await?; + // reset buffer, cursor and state buf.clear(); - exchg_state = QueryTimeExchangeState::default(); + cursor = 0; + state = QExchangeState::default(); } } diff --git a/server/src/engine/net/protocol/tests.rs b/server/src/engine/net/protocol/tests.rs index c32ca03e..9eab9549 100644 --- a/server/src/engine/net/protocol/tests.rs +++ b/server/src/engine/net/protocol/tests.rs @@ -24,17 +24,46 @@ * */ -use crate::engine::{ - mem::BufferedScanner, - net::protocol::{ - exchange::{self, create_simple_query, QueryTimeExchangeResult, QueryTimeExchangeState}, - handshake::{ - AuthMode, CHandshake, CHandshakeAuth, CHandshakeStatic, DataExchangeMode, - HandshakeResult, HandshakeState, HandshakeVersion, ProtocolVersion, QueryMode, +use { + super::{ + exchange::{self, scanint, LFTIntParseResult, QExchangeResult, QExchangeState}, + SQuery, + }, + crate::{ + engine::{ + mem::BufferedScanner, + net::protocol::handshake::{ + AuthMode, CHandshake, CHandshakeAuth, CHandshakeStatic, DataExchangeMode, + HandshakeResult, HandshakeState, HandshakeVersion, ProtocolVersion, QueryMode, + }, }, + util::test_utils, }, + rand::Rng, }; +pub(super) fn create_simple_query(query: &str, params: [&str; N]) -> Vec { + let mut buf = vec![]; + let query_size_as_string = query.len().to_string(); + let total_packet_size = query.len() + + params.iter().map(|l| l.len()).sum::() + + query_size_as_string.len() + + 1; + // segment 1 + buf.push(b'S'); + buf.extend(total_packet_size.to_string().as_bytes()); + buf.push(b'\n'); + // segment + buf.extend(query_size_as_string.as_bytes()); + buf.push(b'\n'); + // dataframe + buf.extend(query.as_bytes()); + params + .into_iter() + .for_each(|param| buf.extend(param.as_bytes())); + buf +} + /* client handshake */ @@ -155,110 +184,155 @@ fn parse_auth_with_state_updates() { const SQ: &str = "select * from myspace.mymodel where username = ?"; -#[test] -fn qtdex_simple_query() { - let query = create_simple_query(SQ, ["sayan"]); - let mut fin = 52; - for i in 0..query.len() { - let mut scanner = BufferedScanner::new(&query[..i + 1]); - let result = exchange::resume(&mut scanner, Default::default()); - match scanner.buffer_len() { - 1..=3 => assert_eq!(result, exchange::STATE_READ_INITIAL), - 4 => assert_eq!( - result, - QueryTimeExchangeResult::ChangeState { - new_state: QueryTimeExchangeState::SQ2Meta2Partial { - size_of_static_frame: 4, - packet_size: 56, - q_window_part: 0, - }, - expect_more: 56, - } - ), - 5 => assert_eq!( - result, - QueryTimeExchangeResult::ChangeState { - new_state: QueryTimeExchangeState::SQ2Meta2Partial { - size_of_static_frame: 4, - packet_size: 56, - q_window_part: 4, - }, - expect_more: 55, +fn parse_staged( + query: &str, + params: [&str; N], + eq: impl Fn(SQuery), + rng: &mut impl Rng, +) { + let __query_buffer = create_simple_query(query, params); + for _ in 0..__query_buffer.len() { + let mut __read_total = 0; + let mut cursor = 0; + let mut state = QExchangeState::default(); + loop { + let remaining = __query_buffer.len() - __read_total; + let read_this_time = { + let mut cnt = 0; + if remaining == 1 { + 1 + } else { + let mut last = test_utils::random_number(1, remaining, rng); + loop { + if cnt >= 10 { + break last; + } + // if we're reading exact, try to keep it low + if last == remaining { + cnt += 1; + last = test_utils::random_number(1, remaining, rng); + } else { + break last; + } + } } - ), - 6 => assert_eq!( - result, - QueryTimeExchangeResult::ChangeState { - new_state: QueryTimeExchangeState::SQ2Meta2Partial { - size_of_static_frame: 4, - packet_size: 56, - q_window_part: 48, - }, - expect_more: 54, + }; + __read_total += read_this_time; + let buffered = &__query_buffer[..__read_total]; + if !state.has_reached_target(buffered) { + continue; + } + match unsafe { exchange::resume(buffered, cursor, state) } { + (new_cursor, QExchangeResult::ChangeState(new_state)) => { + cursor = new_cursor; + state = new_state; + continue; } - ), - 7 => assert_eq!( - result, - QueryTimeExchangeResult::ChangeState { - new_state: QueryTimeExchangeState::SQ3FinalizeWaitingForBlock { - dataframe_size: 53, - q_window: 48, - }, - expect_more: 53, + (_, QExchangeResult::SQCompleted(q)) => { + eq(q); + break; } - ), - 8..=59 => { - assert_eq!( - result, - QueryTimeExchangeResult::ChangeState { - new_state: QueryTimeExchangeState::SQ3FinalizeWaitingForBlock { - dataframe_size: 53, - q_window: 48 - }, - expect_more: fin, - } - ); - fin -= 1; + _ => panic!(), } - 60 => match result { - QueryTimeExchangeResult::SQCompleted(sq) => { - assert_eq!(sq.query_str().unwrap(), SQ); - assert_eq!(sq.params_str().unwrap(), "sayan"); - } - _ => unreachable!(), - }, - _ => unreachable!(), } } } #[test] -fn qtdex_simple_query_update_state() { - let query = create_simple_query(SQ, ["sayan"]); - let mut state = QueryTimeExchangeState::default(); - let mut cursor = 0; - let mut expected = 0; - let mut rounds = 0; - loop { - rounds += 1; - let buf = &query[..expected + cursor]; - let mut scanner = unsafe { BufferedScanner::new_with_cursor(buf, cursor) }; - match exchange::resume(&mut scanner, state) { - QueryTimeExchangeResult::SQCompleted(sq) => { - assert_eq!(sq.query_str().unwrap(), SQ); - assert_eq!(sq.params_str().unwrap(), "sayan"); - break; - } - QueryTimeExchangeResult::ChangeState { - new_state, - expect_more, - } => { - expected = expect_more; - state = new_state; - } - QueryTimeExchangeResult::Error => panic!("hit error!"), +fn staged_randomized() { + let mut rng = test_utils::rng(); + parse_staged( + SQ, + ["sayan"], + |q| { + assert_eq!(q.query_str(), SQ); + assert_eq!(q.params_str(), "sayan"); + }, + &mut rng, + ); +} + +#[test] +fn stages_manual() { + let query = create_simple_query("select * from mymodel where username = ?", ["sayan"]); + assert_eq!( + unsafe { exchange::resume(&query[..QExchangeState::MIN_READ], 0, Default::default()) }, + ( + 5, + QExchangeResult::ChangeState(QExchangeState::new_test( + exchange::QExchangeStateInternal::PendingMeta2, + 52, + 48, + 4 + )) + ) + ); + assert_eq!( + unsafe { + exchange::resume( + &query[..QExchangeState::MIN_READ + 1], + 0, + Default::default(), + ) + }, + ( + 6, + QExchangeResult::ChangeState(QExchangeState::new_test( + exchange::QExchangeStateInternal::PendingMeta2, + 52, + 48, + 40 + )) + ) + ); + assert_eq!( + unsafe { + exchange::resume( + &query[..QExchangeState::MIN_READ + 2], + 0, + Default::default(), + ) + }, + ( + 7, + QExchangeResult::ChangeState(QExchangeState::new_test( + exchange::QExchangeStateInternal::PendingData, + 52, + 48, + 40 + )) + ) + ); + // the cursor should never change + for upper_bound in QExchangeState::MIN_READ + 2..query.len() { + assert_eq!( + unsafe { exchange::resume(&query[..upper_bound], 0, Default::default()) }, + ( + 7, + QExchangeResult::ChangeState(QExchangeState::new_test( + exchange::QExchangeStateInternal::PendingData, + 52, + 48, + 40 + )) + ) + ); + } + match unsafe { exchange::resume(&query, 0, Default::default()) } { + (l, QExchangeResult::SQCompleted(q)) if l == query.len() => { + assert_eq!(q.query_str(), "select * from mymodel where username = ?"); + assert_eq!(q.params_str(), "sayan"); } - cursor = scanner.cursor(); + e => panic!("expected end, got {e:?}"), } - assert_eq!(rounds, 3); +} + +#[test] +fn scanint_impl() { + let mut s = BufferedScanner::new(b"\n"); + assert_eq!(scanint(&mut s, true, 0), LFTIntParseResult::Error); + let mut s = BufferedScanner::new(b"12"); + assert_eq!(scanint(&mut s, true, 0), LFTIntParseResult::Partial(12)); + let mut s = BufferedScanner::new(b"12\n"); + assert_eq!(scanint(&mut s, true, 0), LFTIntParseResult::Value(12)); } diff --git a/server/src/util/test_utils.rs b/server/src/util/test_utils.rs index e3cf35ab..a12110a1 100644 --- a/server/src/util/test_utils.rs +++ b/server/src/util/test_utils.rs @@ -38,6 +38,10 @@ use { }, }; +pub fn rng() -> ThreadRng { + rand::thread_rng() +} + pub fn shuffle_slice(slice: &mut [T], rng: &mut impl Rng) { slice.shuffle(rng) } @@ -73,8 +77,8 @@ pub fn random_bool(rng: &mut impl Rng) -> bool { rng.gen_bool(0.5) } /// Generate a random number within the given range -pub fn random_number(max: T, min: T, rng: &mut impl Rng) -> T { - rng.gen_range(max..min) +pub fn random_number(min: T, max: T, rng: &mut impl Rng) -> T { + rng.gen_range(min..max) } pub fn random_string(rng: &mut impl Rng, l: usize) -> String {