net: Add more protocol tests

next
Sayan Nandan 6 months ago
parent d3b5fd8060
commit 2d905d07d0
No known key found for this signature in database
GPG Key ID: 0EBD769024B24F0A

@ -54,7 +54,7 @@ use {
sq definition
*/
#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub struct SQuery<'a> {
buf: &'a [u8],
q_window: usize,
@ -76,7 +76,9 @@ impl<'a> SQuery<'a> {
scanint
*/
fn scan_usize_guaranteed_termination(scanner: &mut BufferedScanner) -> Result<usize, ()> {
fn scan_usize_guaranteed_termination(
scanner: &mut BufferedScanner,
) -> Result<usize, ExchangeError> {
let mut ret = 0usize;
let mut stop = scanner.rounded_eq(b'\n');
while !scanner.eof() & !stop {
@ -89,7 +91,7 @@ fn scan_usize_guaranteed_termination(scanner: &mut BufferedScanner) -> Result<us
.map(|int| int.checked_add((next_byte & 0x0f) as usize))
{
Some(Some(int)) if next_byte.is_ascii_digit() => ret = int,
_ => return Err(()),
_ => return Err(ExchangeError::NotAsciiByteOrOverflow),
}
stop = scanner.rounded_eq(b'\n');
}
@ -100,11 +102,11 @@ fn scan_usize_guaranteed_termination(scanner: &mut BufferedScanner) -> Result<us
if stop {
Ok(ret)
} else {
Err(())
Err(ExchangeError::UnterminatedInteger)
}
}
#[derive(Clone, Copy)]
#[derive(Clone, Copy, PartialEq)]
struct Usize {
v: isize,
}
@ -181,14 +183,14 @@ impl Usize {
states
*/
#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub enum ExchangeState {
Initial,
Simple(SQState),
Pipeline(PipeState),
}
#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub struct SQState {
packet_s: Usize,
}
@ -199,7 +201,7 @@ impl SQState {
}
}
#[derive(Debug)]
#[derive(Debug, PartialEq)]
pub struct PipeState {
packet_s: Usize,
}
@ -216,12 +218,22 @@ impl Default for ExchangeState {
}
}
#[derive(Debug, PartialEq)]
pub enum ExchangeResult<'a> {
NewState(ExchangeState),
Simple(SQuery<'a>),
Pipeline(Pipeline<'a>),
}
#[derive(Debug, PartialEq, Clone, Copy)]
#[repr(u8)]
pub enum ExchangeError {
UnknownFirstByte,
NotAsciiByteOrOverflow,
UnterminatedInteger,
IncorrectQuerySizeOrMoreBytes,
}
pub struct Exchange<'a> {
scanner: BufferedScanner<'a>,
}
@ -234,13 +246,16 @@ impl<'a> Exchange<'a> {
pub fn try_complete(
scanner: BufferedScanner<'a>,
state: ExchangeState,
) -> Result<(ExchangeResult, usize), ()> {
) -> Result<(ExchangeResult, usize), ExchangeError> {
Self::new(scanner).complete(state)
}
}
impl<'a> Exchange<'a> {
fn complete(mut self, state: ExchangeState) -> Result<(ExchangeResult<'a>, usize), ()> {
fn complete(
mut self,
state: ExchangeState,
) -> Result<(ExchangeResult<'a>, usize), ExchangeError> {
match state {
ExchangeState::Initial => {
if compiler::likely(self.scanner.has_left(Self::MIN_Q_SIZE)) {
@ -251,7 +266,7 @@ impl<'a> Exchange<'a> {
match first_byte {
b'S' => self.process_simple(SQState::new(Usize::new_unflagged(0))),
b'P' => self.process_pipe(PipeState::new(Usize::new_unflagged(0))),
_ => return Err(()),
_ => return Err(ExchangeError::UnknownFirstByte),
}
} else {
Ok(ExchangeResult::NewState(state))
@ -262,29 +277,38 @@ impl<'a> Exchange<'a> {
}
.map(|ret| (ret, self.scanner.cursor()))
}
fn process_simple(&mut self, mut sq_state: SQState) -> Result<ExchangeResult<'a>, ()> {
fn process_simple(
&mut self,
mut sq_state: SQState,
) -> Result<ExchangeResult<'a>, ExchangeError> {
// try to complete the packet size if needed
sq_state.packet_s.update_scanned(&mut self.scanner)?;
if sq_state.packet_s.flag() & self.scanner.remaining_size_is(sq_state.packet_s.int()) {
sq_state
.packet_s
.update_scanned(&mut self.scanner)
.map_err(|_| ExchangeError::NotAsciiByteOrOverflow)?;
if sq_state.packet_s.flag() & self.scanner.has_left(sq_state.packet_s.int()) {
// we have the full packet size and the required data
let q_window = scan_usize_guaranteed_termination(&mut self.scanner)?;
let nonzero = (q_window != 0) & (sq_state.packet_s.int() != 0);
if compiler::likely(self.scanner.remaining_size_is(q_window) & nonzero) {
if compiler::likely(self.scanner.remaining_size_is(sq_state.packet_s.int()) & nonzero) {
// this check is important because the client might have given us an incorrect q size
Ok(ExchangeResult::Simple(SQuery::new(
self.scanner.current_buffer(),
q_window,
)))
} else {
Err(())
Err(ExchangeError::IncorrectQuerySizeOrMoreBytes)
}
} else {
Ok(ExchangeResult::NewState(ExchangeState::Simple(sq_state)))
}
}
fn process_pipe(&mut self, mut pipe_s: PipeState) -> Result<ExchangeResult<'a>, ()> {
fn process_pipe(&mut self, mut pipe_s: PipeState) -> Result<ExchangeResult<'a>, ExchangeError> {
// try to complete the packet size if needed
pipe_s.packet_s.update_scanned(&mut self.scanner)?;
pipe_s
.packet_s
.update_scanned(&mut self.scanner)
.map_err(|_| ExchangeError::NotAsciiByteOrOverflow)?;
if pipe_s.packet_s.flag() & self.scanner.remaining_size_is(pipe_s.packet_s.int()) {
// great, we have the entire packet
Ok(ExchangeResult::Pipeline(Pipeline::new(
@ -300,6 +324,7 @@ impl<'a> Exchange<'a> {
pipeline
*/
#[derive(Debug, PartialEq)]
pub struct Pipeline<'a> {
scanner: BufferedScanner<'a>,
}
@ -310,7 +335,7 @@ impl<'a> Pipeline<'a> {
scanner: BufferedScanner::new(buf),
}
}
pub fn next_query(&mut self) -> Result<Option<SQuery<'a>>, ()> {
pub fn next_query(&mut self) -> Result<Option<SQuery<'a>>, ExchangeError> {
let nonzero = self.scanner.buffer_len() != 0;
if self.scanner.eof() & nonzero {
Ok(None)
@ -318,13 +343,13 @@ impl<'a> Pipeline<'a> {
let query_size = scan_usize_guaranteed_termination(&mut self.scanner)?;
let param_size = scan_usize_guaranteed_termination(&mut self.scanner)?;
let (full_size, overflow) = param_size.overflowing_add(query_size);
if compiler::likely(self.scanner.remaining_size_is(full_size) & !overflow) {
if compiler::likely(self.scanner.has_left(full_size) & !overflow) {
Ok(Some(SQuery {
buf: self.scanner.current_buffer(),
q_window: query_size,
}))
} else {
Err(())
Err(ExchangeError::IncorrectQuerySizeOrMoreBytes)
}
}
}

@ -259,7 +259,7 @@ pub(super) async fn query_loop<S: Socket>(
(state, cursor) = cleanup_for_next_query(con, buf).await?;
}
},
Err(()) => {
Err(_) => {
// respond with error
let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16)
.to_le_bytes();
@ -375,7 +375,7 @@ async fn exec_pipe<'a, S: Socket>(
)
.await?
}
Err(()) => {
Err(_) => {
// respond with error
let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16)
.to_le_bytes();

@ -25,7 +25,10 @@
*/
use {
super::handshake::ProtocolError,
super::{
exchange::{Exchange, ExchangeError, ExchangeResult, ExchangeState},
handshake::ProtocolError,
},
crate::engine::{
mem::BufferedScanner,
net::protocol::handshake::{
@ -270,3 +273,86 @@ fn hs_bad_auth_mode() {
assert_eq!(hs_result, HandshakeResult::Error(ProtocolError::RejectAuth))
})
}
/*
QT-DEX
*/
fn iterate_payload(payload: &str, start: usize, f: impl Fn(usize, &[u8])) {
for i in start..payload.len() {
f(i, &payload.as_bytes()[..i])
}
}
fn iterate_exchange_payload(
payload: &str,
start: usize,
f: impl Fn(usize, Result<(ExchangeResult, usize), ExchangeError>),
) {
iterate_payload(payload, start, |i, bytes| {
let scanner = BufferedScanner::new(bytes);
f(i, Exchange::try_complete(scanner, ExchangeState::default()))
})
}
fn iterate_exchange_payload_from_zero(
payload: &str,
f: impl Fn(usize, Result<(ExchangeResult, usize), ExchangeError>),
) {
iterate_exchange_payload(payload, 0, f)
}
/*
corner cases
*/
#[test]
fn zero_sized_packet() {
for payload in [
"S\n", // zero packet
"S0\n", // zero packet
"S2\n0\n", // zero query
"S1\n\n", // zero query
] {
iterate_exchange_payload_from_zero(payload, |size, result| {
if size == payload.len() {
// we got the full payload
assert_eq!(result, Err(ExchangeError::IncorrectQuerySizeOrMoreBytes))
} else {
// we don't have the full payload
if size < 3 {
assert_eq!(
result,
Ok((ExchangeResult::NewState(ExchangeState::Initial), 0))
)
} else {
assert!(
matches!(
result,
Ok((ExchangeResult::NewState(ExchangeState::Simple(_)), _))
),
"failed for {:?}, result is {:?}",
&payload[..size],
result,
);
}
}
});
}
}
#[test]
fn invalid_first_byte() {
for payload in ["A1\n\n", "B7\n5\nsayan"] {
iterate_exchange_payload(payload, 1, |size, result| {
if size >= 3 {
assert_eq!(result, Err(ExchangeError::UnknownFirstByte))
} else {
assert_eq!(
result,
Ok((ExchangeResult::NewState(ExchangeState::Initial), 0))
)
}
})
}
}

Loading…
Cancel
Save