diff --git a/server/src/engine/net/protocol/exchange.rs b/server/src/engine/net/protocol/exchange.rs index 419054ac..9f510949 100644 --- a/server/src/engine/net/protocol/exchange.rs +++ b/server/src/engine/net/protocol/exchange.rs @@ -54,7 +54,7 @@ use { sq definition */ -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub struct SQuery<'a> { buf: &'a [u8], q_window: usize, @@ -76,7 +76,9 @@ impl<'a> SQuery<'a> { scanint */ -fn scan_usize_guaranteed_termination(scanner: &mut BufferedScanner) -> Result { +fn scan_usize_guaranteed_termination( + scanner: &mut BufferedScanner, +) -> Result { let mut ret = 0usize; let mut stop = scanner.rounded_eq(b'\n'); while !scanner.eof() & !stop { @@ -89,7 +91,7 @@ fn scan_usize_guaranteed_termination(scanner: &mut BufferedScanner) -> Result ret = int, - _ => return Err(()), + _ => return Err(ExchangeError::NotAsciiByteOrOverflow), } stop = scanner.rounded_eq(b'\n'); } @@ -100,11 +102,11 @@ fn scan_usize_guaranteed_termination(scanner: &mut BufferedScanner) -> Result { NewState(ExchangeState), Simple(SQuery<'a>), Pipeline(Pipeline<'a>), } +#[derive(Debug, PartialEq, Clone, Copy)] +#[repr(u8)] +pub enum ExchangeError { + UnknownFirstByte, + NotAsciiByteOrOverflow, + UnterminatedInteger, + IncorrectQuerySizeOrMoreBytes, +} + pub struct Exchange<'a> { scanner: BufferedScanner<'a>, } @@ -234,13 +246,16 @@ impl<'a> Exchange<'a> { pub fn try_complete( scanner: BufferedScanner<'a>, state: ExchangeState, - ) -> Result<(ExchangeResult, usize), ()> { + ) -> Result<(ExchangeResult, usize), ExchangeError> { Self::new(scanner).complete(state) } } impl<'a> Exchange<'a> { - fn complete(mut self, state: ExchangeState) -> Result<(ExchangeResult<'a>, usize), ()> { + fn complete( + mut self, + state: ExchangeState, + ) -> Result<(ExchangeResult<'a>, usize), ExchangeError> { match state { ExchangeState::Initial => { if compiler::likely(self.scanner.has_left(Self::MIN_Q_SIZE)) { @@ -251,7 +266,7 @@ impl<'a> Exchange<'a> { match first_byte { b'S' => self.process_simple(SQState::new(Usize::new_unflagged(0))), b'P' => self.process_pipe(PipeState::new(Usize::new_unflagged(0))), - _ => return Err(()), + _ => return Err(ExchangeError::UnknownFirstByte), } } else { Ok(ExchangeResult::NewState(state)) @@ -262,29 +277,38 @@ impl<'a> Exchange<'a> { } .map(|ret| (ret, self.scanner.cursor())) } - fn process_simple(&mut self, mut sq_state: SQState) -> Result, ()> { + fn process_simple( + &mut self, + mut sq_state: SQState, + ) -> Result, ExchangeError> { // try to complete the packet size if needed - sq_state.packet_s.update_scanned(&mut self.scanner)?; - if sq_state.packet_s.flag() & self.scanner.remaining_size_is(sq_state.packet_s.int()) { + sq_state + .packet_s + .update_scanned(&mut self.scanner) + .map_err(|_| ExchangeError::NotAsciiByteOrOverflow)?; + if sq_state.packet_s.flag() & self.scanner.has_left(sq_state.packet_s.int()) { // we have the full packet size and the required data let q_window = scan_usize_guaranteed_termination(&mut self.scanner)?; let nonzero = (q_window != 0) & (sq_state.packet_s.int() != 0); - if compiler::likely(self.scanner.remaining_size_is(q_window) & nonzero) { + if compiler::likely(self.scanner.remaining_size_is(sq_state.packet_s.int()) & nonzero) { // this check is important because the client might have given us an incorrect q size Ok(ExchangeResult::Simple(SQuery::new( self.scanner.current_buffer(), q_window, ))) } else { - Err(()) + Err(ExchangeError::IncorrectQuerySizeOrMoreBytes) } } else { Ok(ExchangeResult::NewState(ExchangeState::Simple(sq_state))) } } - fn process_pipe(&mut self, mut pipe_s: PipeState) -> Result, ()> { + fn process_pipe(&mut self, mut pipe_s: PipeState) -> Result, ExchangeError> { // try to complete the packet size if needed - pipe_s.packet_s.update_scanned(&mut self.scanner)?; + pipe_s + .packet_s + .update_scanned(&mut self.scanner) + .map_err(|_| ExchangeError::NotAsciiByteOrOverflow)?; if pipe_s.packet_s.flag() & self.scanner.remaining_size_is(pipe_s.packet_s.int()) { // great, we have the entire packet Ok(ExchangeResult::Pipeline(Pipeline::new( @@ -300,6 +324,7 @@ impl<'a> Exchange<'a> { pipeline */ +#[derive(Debug, PartialEq)] pub struct Pipeline<'a> { scanner: BufferedScanner<'a>, } @@ -310,7 +335,7 @@ impl<'a> Pipeline<'a> { scanner: BufferedScanner::new(buf), } } - pub fn next_query(&mut self) -> Result>, ()> { + pub fn next_query(&mut self) -> Result>, ExchangeError> { let nonzero = self.scanner.buffer_len() != 0; if self.scanner.eof() & nonzero { Ok(None) @@ -318,13 +343,13 @@ impl<'a> Pipeline<'a> { let query_size = scan_usize_guaranteed_termination(&mut self.scanner)?; let param_size = scan_usize_guaranteed_termination(&mut self.scanner)?; let (full_size, overflow) = param_size.overflowing_add(query_size); - if compiler::likely(self.scanner.remaining_size_is(full_size) & !overflow) { + if compiler::likely(self.scanner.has_left(full_size) & !overflow) { Ok(Some(SQuery { buf: self.scanner.current_buffer(), q_window: query_size, })) } else { - Err(()) + Err(ExchangeError::IncorrectQuerySizeOrMoreBytes) } } } diff --git a/server/src/engine/net/protocol/mod.rs b/server/src/engine/net/protocol/mod.rs index 93a57764..25bc4cf3 100644 --- a/server/src/engine/net/protocol/mod.rs +++ b/server/src/engine/net/protocol/mod.rs @@ -259,7 +259,7 @@ pub(super) async fn query_loop( (state, cursor) = cleanup_for_next_query(con, buf).await?; } }, - Err(()) => { + Err(_) => { // respond with error let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16) .to_le_bytes(); @@ -375,7 +375,7 @@ async fn exec_pipe<'a, S: Socket>( ) .await? } - Err(()) => { + Err(_) => { // respond with error let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16) .to_le_bytes(); diff --git a/server/src/engine/net/protocol/tests.rs b/server/src/engine/net/protocol/tests.rs index 85355527..bd9d2cf9 100644 --- a/server/src/engine/net/protocol/tests.rs +++ b/server/src/engine/net/protocol/tests.rs @@ -25,7 +25,10 @@ */ use { - super::handshake::ProtocolError, + super::{ + exchange::{Exchange, ExchangeError, ExchangeResult, ExchangeState}, + handshake::ProtocolError, + }, crate::engine::{ mem::BufferedScanner, net::protocol::handshake::{ @@ -270,3 +273,86 @@ fn hs_bad_auth_mode() { assert_eq!(hs_result, HandshakeResult::Error(ProtocolError::RejectAuth)) }) } + +/* + QT-DEX +*/ + +fn iterate_payload(payload: &str, start: usize, f: impl Fn(usize, &[u8])) { + for i in start..payload.len() { + f(i, &payload.as_bytes()[..i]) + } +} + +fn iterate_exchange_payload( + payload: &str, + start: usize, + f: impl Fn(usize, Result<(ExchangeResult, usize), ExchangeError>), +) { + iterate_payload(payload, start, |i, bytes| { + let scanner = BufferedScanner::new(bytes); + f(i, Exchange::try_complete(scanner, ExchangeState::default())) + }) +} + +fn iterate_exchange_payload_from_zero( + payload: &str, + f: impl Fn(usize, Result<(ExchangeResult, usize), ExchangeError>), +) { + iterate_exchange_payload(payload, 0, f) +} + +/* + corner cases +*/ + +#[test] +fn zero_sized_packet() { + for payload in [ + "S\n", // zero packet + "S0\n", // zero packet + "S2\n0\n", // zero query + "S1\n\n", // zero query + ] { + iterate_exchange_payload_from_zero(payload, |size, result| { + if size == payload.len() { + // we got the full payload + assert_eq!(result, Err(ExchangeError::IncorrectQuerySizeOrMoreBytes)) + } else { + // we don't have the full payload + if size < 3 { + assert_eq!( + result, + Ok((ExchangeResult::NewState(ExchangeState::Initial), 0)) + ) + } else { + assert!( + matches!( + result, + Ok((ExchangeResult::NewState(ExchangeState::Simple(_)), _)) + ), + "failed for {:?}, result is {:?}", + &payload[..size], + result, + ); + } + } + }); + } +} + +#[test] +fn invalid_first_byte() { + for payload in ["A1\n\n", "B7\n5\nsayan"] { + iterate_exchange_payload(payload, 1, |size, result| { + if size >= 3 { + assert_eq!(result, Err(ExchangeError::UnknownFirstByte)) + } else { + assert_eq!( + result, + Ok((ExchangeResult::NewState(ExchangeState::Initial), 0)) + ) + } + }) + } +}