net: Add more protocol tests

next
Sayan Nandan 6 months ago
parent d3b5fd8060
commit 2d905d07d0
No known key found for this signature in database
GPG Key ID: 0EBD769024B24F0A

@ -54,7 +54,7 @@ use {
sq definition sq definition
*/ */
#[derive(Debug)] #[derive(Debug, PartialEq)]
pub struct SQuery<'a> { pub struct SQuery<'a> {
buf: &'a [u8], buf: &'a [u8],
q_window: usize, q_window: usize,
@ -76,7 +76,9 @@ impl<'a> SQuery<'a> {
scanint scanint
*/ */
fn scan_usize_guaranteed_termination(scanner: &mut BufferedScanner) -> Result<usize, ()> { fn scan_usize_guaranteed_termination(
scanner: &mut BufferedScanner,
) -> Result<usize, ExchangeError> {
let mut ret = 0usize; let mut ret = 0usize;
let mut stop = scanner.rounded_eq(b'\n'); let mut stop = scanner.rounded_eq(b'\n');
while !scanner.eof() & !stop { while !scanner.eof() & !stop {
@ -89,7 +91,7 @@ fn scan_usize_guaranteed_termination(scanner: &mut BufferedScanner) -> Result<us
.map(|int| int.checked_add((next_byte & 0x0f) as usize)) .map(|int| int.checked_add((next_byte & 0x0f) as usize))
{ {
Some(Some(int)) if next_byte.is_ascii_digit() => ret = int, Some(Some(int)) if next_byte.is_ascii_digit() => ret = int,
_ => return Err(()), _ => return Err(ExchangeError::NotAsciiByteOrOverflow),
} }
stop = scanner.rounded_eq(b'\n'); stop = scanner.rounded_eq(b'\n');
} }
@ -100,11 +102,11 @@ fn scan_usize_guaranteed_termination(scanner: &mut BufferedScanner) -> Result<us
if stop { if stop {
Ok(ret) Ok(ret)
} else { } else {
Err(()) Err(ExchangeError::UnterminatedInteger)
} }
} }
#[derive(Clone, Copy)] #[derive(Clone, Copy, PartialEq)]
struct Usize { struct Usize {
v: isize, v: isize,
} }
@ -181,14 +183,14 @@ impl Usize {
states states
*/ */
#[derive(Debug)] #[derive(Debug, PartialEq)]
pub enum ExchangeState { pub enum ExchangeState {
Initial, Initial,
Simple(SQState), Simple(SQState),
Pipeline(PipeState), Pipeline(PipeState),
} }
#[derive(Debug)] #[derive(Debug, PartialEq)]
pub struct SQState { pub struct SQState {
packet_s: Usize, packet_s: Usize,
} }
@ -199,7 +201,7 @@ impl SQState {
} }
} }
#[derive(Debug)] #[derive(Debug, PartialEq)]
pub struct PipeState { pub struct PipeState {
packet_s: Usize, packet_s: Usize,
} }
@ -216,12 +218,22 @@ impl Default for ExchangeState {
} }
} }
#[derive(Debug, PartialEq)]
pub enum ExchangeResult<'a> { pub enum ExchangeResult<'a> {
NewState(ExchangeState), NewState(ExchangeState),
Simple(SQuery<'a>), Simple(SQuery<'a>),
Pipeline(Pipeline<'a>), Pipeline(Pipeline<'a>),
} }
#[derive(Debug, PartialEq, Clone, Copy)]
#[repr(u8)]
pub enum ExchangeError {
UnknownFirstByte,
NotAsciiByteOrOverflow,
UnterminatedInteger,
IncorrectQuerySizeOrMoreBytes,
}
pub struct Exchange<'a> { pub struct Exchange<'a> {
scanner: BufferedScanner<'a>, scanner: BufferedScanner<'a>,
} }
@ -234,13 +246,16 @@ impl<'a> Exchange<'a> {
pub fn try_complete( pub fn try_complete(
scanner: BufferedScanner<'a>, scanner: BufferedScanner<'a>,
state: ExchangeState, state: ExchangeState,
) -> Result<(ExchangeResult, usize), ()> { ) -> Result<(ExchangeResult, usize), ExchangeError> {
Self::new(scanner).complete(state) Self::new(scanner).complete(state)
} }
} }
impl<'a> Exchange<'a> { impl<'a> Exchange<'a> {
fn complete(mut self, state: ExchangeState) -> Result<(ExchangeResult<'a>, usize), ()> { fn complete(
mut self,
state: ExchangeState,
) -> Result<(ExchangeResult<'a>, usize), ExchangeError> {
match state { match state {
ExchangeState::Initial => { ExchangeState::Initial => {
if compiler::likely(self.scanner.has_left(Self::MIN_Q_SIZE)) { if compiler::likely(self.scanner.has_left(Self::MIN_Q_SIZE)) {
@ -251,7 +266,7 @@ impl<'a> Exchange<'a> {
match first_byte { match first_byte {
b'S' => self.process_simple(SQState::new(Usize::new_unflagged(0))), b'S' => self.process_simple(SQState::new(Usize::new_unflagged(0))),
b'P' => self.process_pipe(PipeState::new(Usize::new_unflagged(0))), b'P' => self.process_pipe(PipeState::new(Usize::new_unflagged(0))),
_ => return Err(()), _ => return Err(ExchangeError::UnknownFirstByte),
} }
} else { } else {
Ok(ExchangeResult::NewState(state)) Ok(ExchangeResult::NewState(state))
@ -262,29 +277,38 @@ impl<'a> Exchange<'a> {
} }
.map(|ret| (ret, self.scanner.cursor())) .map(|ret| (ret, self.scanner.cursor()))
} }
fn process_simple(&mut self, mut sq_state: SQState) -> Result<ExchangeResult<'a>, ()> { fn process_simple(
&mut self,
mut sq_state: SQState,
) -> Result<ExchangeResult<'a>, ExchangeError> {
// try to complete the packet size if needed // try to complete the packet size if needed
sq_state.packet_s.update_scanned(&mut self.scanner)?; sq_state
if sq_state.packet_s.flag() & self.scanner.remaining_size_is(sq_state.packet_s.int()) { .packet_s
.update_scanned(&mut self.scanner)
.map_err(|_| ExchangeError::NotAsciiByteOrOverflow)?;
if sq_state.packet_s.flag() & self.scanner.has_left(sq_state.packet_s.int()) {
// we have the full packet size and the required data // we have the full packet size and the required data
let q_window = scan_usize_guaranteed_termination(&mut self.scanner)?; let q_window = scan_usize_guaranteed_termination(&mut self.scanner)?;
let nonzero = (q_window != 0) & (sq_state.packet_s.int() != 0); let nonzero = (q_window != 0) & (sq_state.packet_s.int() != 0);
if compiler::likely(self.scanner.remaining_size_is(q_window) & nonzero) { if compiler::likely(self.scanner.remaining_size_is(sq_state.packet_s.int()) & nonzero) {
// this check is important because the client might have given us an incorrect q size // this check is important because the client might have given us an incorrect q size
Ok(ExchangeResult::Simple(SQuery::new( Ok(ExchangeResult::Simple(SQuery::new(
self.scanner.current_buffer(), self.scanner.current_buffer(),
q_window, q_window,
))) )))
} else { } else {
Err(()) Err(ExchangeError::IncorrectQuerySizeOrMoreBytes)
} }
} else { } else {
Ok(ExchangeResult::NewState(ExchangeState::Simple(sq_state))) Ok(ExchangeResult::NewState(ExchangeState::Simple(sq_state)))
} }
} }
fn process_pipe(&mut self, mut pipe_s: PipeState) -> Result<ExchangeResult<'a>, ()> { fn process_pipe(&mut self, mut pipe_s: PipeState) -> Result<ExchangeResult<'a>, ExchangeError> {
// try to complete the packet size if needed // try to complete the packet size if needed
pipe_s.packet_s.update_scanned(&mut self.scanner)?; pipe_s
.packet_s
.update_scanned(&mut self.scanner)
.map_err(|_| ExchangeError::NotAsciiByteOrOverflow)?;
if pipe_s.packet_s.flag() & self.scanner.remaining_size_is(pipe_s.packet_s.int()) { if pipe_s.packet_s.flag() & self.scanner.remaining_size_is(pipe_s.packet_s.int()) {
// great, we have the entire packet // great, we have the entire packet
Ok(ExchangeResult::Pipeline(Pipeline::new( Ok(ExchangeResult::Pipeline(Pipeline::new(
@ -300,6 +324,7 @@ impl<'a> Exchange<'a> {
pipeline pipeline
*/ */
#[derive(Debug, PartialEq)]
pub struct Pipeline<'a> { pub struct Pipeline<'a> {
scanner: BufferedScanner<'a>, scanner: BufferedScanner<'a>,
} }
@ -310,7 +335,7 @@ impl<'a> Pipeline<'a> {
scanner: BufferedScanner::new(buf), scanner: BufferedScanner::new(buf),
} }
} }
pub fn next_query(&mut self) -> Result<Option<SQuery<'a>>, ()> { pub fn next_query(&mut self) -> Result<Option<SQuery<'a>>, ExchangeError> {
let nonzero = self.scanner.buffer_len() != 0; let nonzero = self.scanner.buffer_len() != 0;
if self.scanner.eof() & nonzero { if self.scanner.eof() & nonzero {
Ok(None) Ok(None)
@ -318,13 +343,13 @@ impl<'a> Pipeline<'a> {
let query_size = scan_usize_guaranteed_termination(&mut self.scanner)?; let query_size = scan_usize_guaranteed_termination(&mut self.scanner)?;
let param_size = scan_usize_guaranteed_termination(&mut self.scanner)?; let param_size = scan_usize_guaranteed_termination(&mut self.scanner)?;
let (full_size, overflow) = param_size.overflowing_add(query_size); let (full_size, overflow) = param_size.overflowing_add(query_size);
if compiler::likely(self.scanner.remaining_size_is(full_size) & !overflow) { if compiler::likely(self.scanner.has_left(full_size) & !overflow) {
Ok(Some(SQuery { Ok(Some(SQuery {
buf: self.scanner.current_buffer(), buf: self.scanner.current_buffer(),
q_window: query_size, q_window: query_size,
})) }))
} else { } else {
Err(()) Err(ExchangeError::IncorrectQuerySizeOrMoreBytes)
} }
} }
} }

@ -259,7 +259,7 @@ pub(super) async fn query_loop<S: Socket>(
(state, cursor) = cleanup_for_next_query(con, buf).await?; (state, cursor) = cleanup_for_next_query(con, buf).await?;
} }
}, },
Err(()) => { Err(_) => {
// respond with error // respond with error
let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16) let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16)
.to_le_bytes(); .to_le_bytes();
@ -375,7 +375,7 @@ async fn exec_pipe<'a, S: Socket>(
) )
.await? .await?
} }
Err(()) => { Err(_) => {
// respond with error // respond with error
let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16) let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16)
.to_le_bytes(); .to_le_bytes();

@ -25,7 +25,10 @@
*/ */
use { use {
super::handshake::ProtocolError, super::{
exchange::{Exchange, ExchangeError, ExchangeResult, ExchangeState},
handshake::ProtocolError,
},
crate::engine::{ crate::engine::{
mem::BufferedScanner, mem::BufferedScanner,
net::protocol::handshake::{ net::protocol::handshake::{
@ -270,3 +273,86 @@ fn hs_bad_auth_mode() {
assert_eq!(hs_result, HandshakeResult::Error(ProtocolError::RejectAuth)) assert_eq!(hs_result, HandshakeResult::Error(ProtocolError::RejectAuth))
}) })
} }
/*
QT-DEX
*/
fn iterate_payload(payload: &str, start: usize, f: impl Fn(usize, &[u8])) {
for i in start..payload.len() {
f(i, &payload.as_bytes()[..i])
}
}
fn iterate_exchange_payload(
payload: &str,
start: usize,
f: impl Fn(usize, Result<(ExchangeResult, usize), ExchangeError>),
) {
iterate_payload(payload, start, |i, bytes| {
let scanner = BufferedScanner::new(bytes);
f(i, Exchange::try_complete(scanner, ExchangeState::default()))
})
}
fn iterate_exchange_payload_from_zero(
payload: &str,
f: impl Fn(usize, Result<(ExchangeResult, usize), ExchangeError>),
) {
iterate_exchange_payload(payload, 0, f)
}
/*
corner cases
*/
#[test]
fn zero_sized_packet() {
for payload in [
"S\n", // zero packet
"S0\n", // zero packet
"S2\n0\n", // zero query
"S1\n\n", // zero query
] {
iterate_exchange_payload_from_zero(payload, |size, result| {
if size == payload.len() {
// we got the full payload
assert_eq!(result, Err(ExchangeError::IncorrectQuerySizeOrMoreBytes))
} else {
// we don't have the full payload
if size < 3 {
assert_eq!(
result,
Ok((ExchangeResult::NewState(ExchangeState::Initial), 0))
)
} else {
assert!(
matches!(
result,
Ok((ExchangeResult::NewState(ExchangeState::Simple(_)), _))
),
"failed for {:?}, result is {:?}",
&payload[..size],
result,
);
}
}
});
}
}
#[test]
fn invalid_first_byte() {
for payload in ["A1\n\n", "B7\n5\nsayan"] {
iterate_exchange_payload(payload, 1, |size, result| {
if size >= 3 {
assert_eq!(result, Err(ExchangeError::UnknownFirstByte))
} else {
assert_eq!(
result,
Ok((ExchangeResult::NewState(ExchangeState::Initial), 0))
)
}
})
}
}

Loading…
Cancel
Save