diff --git a/server/src/engine/net/protocol/exchange.rs b/server/src/engine/net/protocol/exchange.rs index 388ac87b..e6a81d27 100644 --- a/server/src/engine/net/protocol/exchange.rs +++ b/server/src/engine/net/protocol/exchange.rs @@ -26,14 +26,35 @@ use crate::engine::mem::BufferedScanner; +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] +pub struct Resume(usize); +impl Resume { + #[cfg(test)] + pub(super) const fn test_new(v: usize) -> Self { + Self(v) + } + #[cfg(test)] + pub(super) const fn inner(&self) -> usize { + self.0 + } +} +impl Default for Resume { + fn default() -> Self { + Self(0) + } +} + pub(super) unsafe fn resume<'a>( buf: &'a [u8], - last_cursor: usize, + last_cursor: Resume, last_state: QExchangeState, -) -> (usize, QExchangeResult<'a>) { - let mut scanner = BufferedScanner::new_with_cursor(buf, last_cursor); +) -> (Resume, QExchangeResult<'a>) { + let mut scanner = unsafe { + // UNSAFE(@ohsayan): we are the ones who generate the cursor and restore it + BufferedScanner::new_with_cursor(buf, last_cursor.0) + }; let ret = last_state.resume(&mut scanner); - (scanner.cursor(), ret) + (Resume(scanner.cursor()), ret) } /* diff --git a/server/src/engine/net/protocol/mod.rs b/server/src/engine/net/protocol/mod.rs index 1dd6d62c..20175b05 100644 --- a/server/src/engine/net/protocol/mod.rs +++ b/server/src/engine/net/protocol/mod.rs @@ -144,7 +144,7 @@ pub(super) async fn query_loop( con.write_all(b"H\x00\x00\x00").await?; con.flush().await?; let mut state = QExchangeState::default(); - let mut cursor = 0; + let mut cursor = Default::default(); loop { if con.read_buf(buf).await? == 0 { if buf.is_empty() { @@ -158,7 +158,7 @@ pub(super) async fn query_loop( continue; } let sq = match unsafe { - // UNSAFE(@ohsayan): we store the cursor from the last run + // UNSAFE(@ohsayan): as the resume cursor is private, we can't access this anyways exchange::resume(buf, cursor, state) } { (_, QExchangeResult::SQCompleted(sq)) => sq, @@ -176,7 +176,7 @@ pub(super) async fn query_loop( con.flush().await?; // reset buffer, cursor and state buf.clear(); - cursor = 0; + cursor = Default::default(); state = QExchangeState::default(); continue; } @@ -207,7 +207,7 @@ pub(super) async fn query_loop( con.flush().await?; // reset buffer, cursor and state buf.clear(); - cursor = 0; + cursor = Default::default(); state = QExchangeState::default(); } } diff --git a/server/src/engine/net/protocol/tests.rs b/server/src/engine/net/protocol/tests.rs index 9eab9549..c15a2779 100644 --- a/server/src/engine/net/protocol/tests.rs +++ b/server/src/engine/net/protocol/tests.rs @@ -24,9 +24,12 @@ * */ +use crate::engine::net::protocol::exchange::Resume; + use { super::{ exchange::{self, scanint, LFTIntParseResult, QExchangeResult, QExchangeState}, + handshake::ProtocolError, SQuery, }, crate::{ @@ -68,7 +71,7 @@ pub(super) fn create_simple_query(query: &str, params: [&str; N] client handshake */ -const FULL_HANDSHAKE_WITH_AUTH: [u8; 23] = *b"H\0\0\0\0\x005\n8\nsayanpass1234"; +const FULL_HANDSHAKE_WITH_AUTH: [u8; 23] = *b"H\0\0\0\0\05\n8\nsayanpass1234"; const STATIC_HANDSHAKE_WITH_AUTH: CHandshakeStatic = CHandshakeStatic::new( HandshakeVersion::Original, @@ -178,6 +181,76 @@ fn parse_auth_with_state_updates() { assert_eq!(rounds, 3); // r1 = initial read, r2 = lengths, r3 = items } +const HS_BAD_PACKET: [u8; 6] = *b"I\x00\0\0\0\0"; +const HS_BAD_VERSION_HS: [u8; 6] = *b"H\x01\0\0\0\0"; +const HS_BAD_VERSION_PROTO: [u8; 6] = *b"H\0\x01\0\0\0"; +const HS_BAD_MODE_XCHG: [u8; 6] = *b"H\0\0\x01\0\0"; +const HS_BAD_MODE_QUERY: [u8; 6] = *b"H\0\0\0\x01\0"; +const HS_BAD_MODE_AUTH: [u8; 6] = *b"H\0\0\0\0\x01"; + +fn scan_hs(hs: impl AsRef<[u8]>, f: impl Fn(HandshakeResult)) { + let mut scanner = BufferedScanner::new(hs.as_ref()); + let hs = CHandshake::resume_with(&mut scanner, Default::default()); + f(hs) +} + +#[test] +fn hs_bad_packet() { + scan_hs(HS_BAD_PACKET, |hs_result| { + assert_eq!( + hs_result, + HandshakeResult::Error(ProtocolError::CorruptedHSPacket) + ) + }) +} + +#[test] +fn hs_bad_version_hs() { + scan_hs(HS_BAD_VERSION_HS, |hs_result| { + assert_eq!( + hs_result, + HandshakeResult::Error(ProtocolError::RejectHSVersion) + ) + }) +} + +#[test] +fn hs_bad_version_proto() { + scan_hs(HS_BAD_VERSION_PROTO, |hs_result| { + assert_eq!( + hs_result, + HandshakeResult::Error(ProtocolError::RejectProtocol) + ) + }) +} + +#[test] +fn hs_bad_exchange_mode() { + scan_hs(HS_BAD_MODE_XCHG, |hs_result| { + assert_eq!( + hs_result, + HandshakeResult::Error(ProtocolError::RejectExchangeMode) + ) + }) +} + +#[test] +fn hs_bad_query_mode() { + scan_hs(HS_BAD_MODE_QUERY, |hs_result| { + assert_eq!( + hs_result, + HandshakeResult::Error(ProtocolError::RejectQueryMode) + ) + }) +} + +#[test] +fn hs_bad_auth_mode() { + scan_hs(HS_BAD_MODE_AUTH, |hs_result| { + assert_eq!(hs_result, HandshakeResult::Error(ProtocolError::RejectAuth)) + }) +} + /* QT-DEX/SQ */ @@ -193,7 +266,7 @@ fn parse_staged( let __query_buffer = create_simple_query(query, params); for _ in 0..__query_buffer.len() { let mut __read_total = 0; - let mut cursor = 0; + let mut cursor = Default::default(); let mut state = QExchangeState::default(); loop { let remaining = __query_buffer.len() - __read_total; @@ -256,9 +329,15 @@ fn staged_randomized() { fn stages_manual() { let query = create_simple_query("select * from mymodel where username = ?", ["sayan"]); assert_eq!( - unsafe { exchange::resume(&query[..QExchangeState::MIN_READ], 0, Default::default()) }, + unsafe { + exchange::resume( + &query[..QExchangeState::MIN_READ], + Default::default(), + Default::default(), + ) + }, ( - 5, + Resume::test_new(5), QExchangeResult::ChangeState(QExchangeState::new_test( exchange::QExchangeStateInternal::PendingMeta2, 52, @@ -271,12 +350,12 @@ fn stages_manual() { unsafe { exchange::resume( &query[..QExchangeState::MIN_READ + 1], - 0, + Default::default(), Default::default(), ) }, ( - 6, + Resume::test_new(6), QExchangeResult::ChangeState(QExchangeState::new_test( exchange::QExchangeStateInternal::PendingMeta2, 52, @@ -289,12 +368,12 @@ fn stages_manual() { unsafe { exchange::resume( &query[..QExchangeState::MIN_READ + 2], - 0, + Default::default(), Default::default(), ) }, ( - 7, + Resume::test_new(7), QExchangeResult::ChangeState(QExchangeState::new_test( exchange::QExchangeStateInternal::PendingData, 52, @@ -306,9 +385,15 @@ fn stages_manual() { // the cursor should never change for upper_bound in QExchangeState::MIN_READ + 2..query.len() { assert_eq!( - unsafe { exchange::resume(&query[..upper_bound], 0, Default::default()) }, + unsafe { + exchange::resume( + &query[..upper_bound], + Default::default(), + Default::default(), + ) + }, ( - 7, + Resume::test_new(7), QExchangeResult::ChangeState(QExchangeState::new_test( exchange::QExchangeStateInternal::PendingData, 52, @@ -318,8 +403,8 @@ fn stages_manual() { ) ); } - match unsafe { exchange::resume(&query, 0, Default::default()) } { - (l, QExchangeResult::SQCompleted(q)) if l == query.len() => { + match unsafe { exchange::resume(&query, Default::default(), Default::default()) } { + (l, QExchangeResult::SQCompleted(q)) if l.inner() == query.len() => { assert_eq!(q.query_str(), "select * from mymodel where username = ?"); assert_eq!(q.params_str(), "sayan"); } diff --git a/server/src/engine/storage/header.rs b/server/src/engine/storage/header.rs deleted file mode 100644 index e69de29b..00000000