Add more proto tests

next
Sayan Nandan 9 months ago
parent 8e0a94d4f9
commit 6c50a4042a

@ -26,14 +26,35 @@
use crate::engine::mem::BufferedScanner;
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Resume(usize);
impl Resume {
#[cfg(test)]
pub(super) const fn test_new(v: usize) -> Self {
Self(v)
}
#[cfg(test)]
pub(super) const fn inner(&self) -> usize {
self.0
}
}
impl Default for Resume {
fn default() -> Self {
Self(0)
}
}
pub(super) unsafe fn resume<'a>(
buf: &'a [u8],
last_cursor: usize,
last_cursor: Resume,
last_state: QExchangeState,
) -> (usize, QExchangeResult<'a>) {
let mut scanner = BufferedScanner::new_with_cursor(buf, last_cursor);
) -> (Resume, QExchangeResult<'a>) {
let mut scanner = unsafe {
// UNSAFE(@ohsayan): we are the ones who generate the cursor and restore it
BufferedScanner::new_with_cursor(buf, last_cursor.0)
};
let ret = last_state.resume(&mut scanner);
(scanner.cursor(), ret)
(Resume(scanner.cursor()), ret)
}
/*

@ -144,7 +144,7 @@ pub(super) async fn query_loop<S: Socket>(
con.write_all(b"H\x00\x00\x00").await?;
con.flush().await?;
let mut state = QExchangeState::default();
let mut cursor = 0;
let mut cursor = Default::default();
loop {
if con.read_buf(buf).await? == 0 {
if buf.is_empty() {
@ -158,7 +158,7 @@ pub(super) async fn query_loop<S: Socket>(
continue;
}
let sq = match unsafe {
// UNSAFE(@ohsayan): we store the cursor from the last run
// UNSAFE(@ohsayan): as the resume cursor is private, we can't access this anyways
exchange::resume(buf, cursor, state)
} {
(_, QExchangeResult::SQCompleted(sq)) => sq,
@ -176,7 +176,7 @@ pub(super) async fn query_loop<S: Socket>(
con.flush().await?;
// reset buffer, cursor and state
buf.clear();
cursor = 0;
cursor = Default::default();
state = QExchangeState::default();
continue;
}
@ -207,7 +207,7 @@ pub(super) async fn query_loop<S: Socket>(
con.flush().await?;
// reset buffer, cursor and state
buf.clear();
cursor = 0;
cursor = Default::default();
state = QExchangeState::default();
}
}

@ -24,9 +24,12 @@
*
*/
use crate::engine::net::protocol::exchange::Resume;
use {
super::{
exchange::{self, scanint, LFTIntParseResult, QExchangeResult, QExchangeState},
handshake::ProtocolError,
SQuery,
},
crate::{
@ -68,7 +71,7 @@ pub(super) fn create_simple_query<const N: usize>(query: &str, params: [&str; N]
client handshake
*/
const FULL_HANDSHAKE_WITH_AUTH: [u8; 23] = *b"H\0\0\0\0\x005\n8\nsayanpass1234";
const FULL_HANDSHAKE_WITH_AUTH: [u8; 23] = *b"H\0\0\0\0\05\n8\nsayanpass1234";
const STATIC_HANDSHAKE_WITH_AUTH: CHandshakeStatic = CHandshakeStatic::new(
HandshakeVersion::Original,
@ -178,6 +181,76 @@ fn parse_auth_with_state_updates() {
assert_eq!(rounds, 3); // r1 = initial read, r2 = lengths, r3 = items
}
const HS_BAD_PACKET: [u8; 6] = *b"I\x00\0\0\0\0";
const HS_BAD_VERSION_HS: [u8; 6] = *b"H\x01\0\0\0\0";
const HS_BAD_VERSION_PROTO: [u8; 6] = *b"H\0\x01\0\0\0";
const HS_BAD_MODE_XCHG: [u8; 6] = *b"H\0\0\x01\0\0";
const HS_BAD_MODE_QUERY: [u8; 6] = *b"H\0\0\0\x01\0";
const HS_BAD_MODE_AUTH: [u8; 6] = *b"H\0\0\0\0\x01";
fn scan_hs(hs: impl AsRef<[u8]>, f: impl Fn(HandshakeResult)) {
let mut scanner = BufferedScanner::new(hs.as_ref());
let hs = CHandshake::resume_with(&mut scanner, Default::default());
f(hs)
}
#[test]
fn hs_bad_packet() {
scan_hs(HS_BAD_PACKET, |hs_result| {
assert_eq!(
hs_result,
HandshakeResult::Error(ProtocolError::CorruptedHSPacket)
)
})
}
#[test]
fn hs_bad_version_hs() {
scan_hs(HS_BAD_VERSION_HS, |hs_result| {
assert_eq!(
hs_result,
HandshakeResult::Error(ProtocolError::RejectHSVersion)
)
})
}
#[test]
fn hs_bad_version_proto() {
scan_hs(HS_BAD_VERSION_PROTO, |hs_result| {
assert_eq!(
hs_result,
HandshakeResult::Error(ProtocolError::RejectProtocol)
)
})
}
#[test]
fn hs_bad_exchange_mode() {
scan_hs(HS_BAD_MODE_XCHG, |hs_result| {
assert_eq!(
hs_result,
HandshakeResult::Error(ProtocolError::RejectExchangeMode)
)
})
}
#[test]
fn hs_bad_query_mode() {
scan_hs(HS_BAD_MODE_QUERY, |hs_result| {
assert_eq!(
hs_result,
HandshakeResult::Error(ProtocolError::RejectQueryMode)
)
})
}
#[test]
fn hs_bad_auth_mode() {
scan_hs(HS_BAD_MODE_AUTH, |hs_result| {
assert_eq!(hs_result, HandshakeResult::Error(ProtocolError::RejectAuth))
})
}
/*
QT-DEX/SQ
*/
@ -193,7 +266,7 @@ fn parse_staged<const N: usize>(
let __query_buffer = create_simple_query(query, params);
for _ in 0..__query_buffer.len() {
let mut __read_total = 0;
let mut cursor = 0;
let mut cursor = Default::default();
let mut state = QExchangeState::default();
loop {
let remaining = __query_buffer.len() - __read_total;
@ -256,9 +329,15 @@ fn staged_randomized() {
fn stages_manual() {
let query = create_simple_query("select * from mymodel where username = ?", ["sayan"]);
assert_eq!(
unsafe { exchange::resume(&query[..QExchangeState::MIN_READ], 0, Default::default()) },
unsafe {
exchange::resume(
&query[..QExchangeState::MIN_READ],
Default::default(),
Default::default(),
)
},
(
5,
Resume::test_new(5),
QExchangeResult::ChangeState(QExchangeState::new_test(
exchange::QExchangeStateInternal::PendingMeta2,
52,
@ -271,12 +350,12 @@ fn stages_manual() {
unsafe {
exchange::resume(
&query[..QExchangeState::MIN_READ + 1],
0,
Default::default(),
Default::default(),
)
},
(
6,
Resume::test_new(6),
QExchangeResult::ChangeState(QExchangeState::new_test(
exchange::QExchangeStateInternal::PendingMeta2,
52,
@ -289,12 +368,12 @@ fn stages_manual() {
unsafe {
exchange::resume(
&query[..QExchangeState::MIN_READ + 2],
0,
Default::default(),
Default::default(),
)
},
(
7,
Resume::test_new(7),
QExchangeResult::ChangeState(QExchangeState::new_test(
exchange::QExchangeStateInternal::PendingData,
52,
@ -306,9 +385,15 @@ fn stages_manual() {
// the cursor should never change
for upper_bound in QExchangeState::MIN_READ + 2..query.len() {
assert_eq!(
unsafe { exchange::resume(&query[..upper_bound], 0, Default::default()) },
unsafe {
exchange::resume(
&query[..upper_bound],
Default::default(),
Default::default(),
)
},
(
7,
Resume::test_new(7),
QExchangeResult::ChangeState(QExchangeState::new_test(
exchange::QExchangeStateInternal::PendingData,
52,
@ -318,8 +403,8 @@ fn stages_manual() {
)
);
}
match unsafe { exchange::resume(&query, 0, Default::default()) } {
(l, QExchangeResult::SQCompleted(q)) if l == query.len() => {
match unsafe { exchange::resume(&query, Default::default(), Default::default()) } {
(l, QExchangeResult::SQCompleted(q)) if l.inner() == query.len() => {
assert_eq!(q.query_str(), "select * from mymodel where username = ?");
assert_eq!(q.params_str(), "sayan");
}

Loading…
Cancel
Save