From abb1b9bf338731953908172a338bc40bdcfd6d22 Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Mon, 10 May 2021 08:13:35 +0530 Subject: [PATCH] Only discard part of buffer that was parsed --- server/src/dbnet/connection.rs | 11 +++++------ server/src/protocol/mod.rs | 27 ++++++++++++++------------- 2 files changed, 19 insertions(+), 19 deletions(-) diff --git a/server/src/dbnet/connection.rs b/server/src/dbnet/connection.rs index 677af371..da52fd8d 100644 --- a/server/src/dbnet/connection.rs +++ b/server/src/dbnet/connection.rs @@ -112,7 +112,7 @@ where if self.get_buffer().is_empty() { return Err(()); } - Ok(protocol::parse(&self.get_buffer())) + Ok(protocol::parse(&self.get_buffer()[..])) } /// Read a query from the remote end /// @@ -128,15 +128,15 @@ where Box::pin(async move { let mv_self = self; let _: Result = { - mv_self.read_again().await?; loop { + mv_self.read_again().await?; match mv_self.try_query() { Ok(ParseResult::Query(query, forward)) => { mv_self.advance_buffer(forward); return Ok(QueryResult::Q(query)); } - Ok(ParseResult::BadPacket) => { - mv_self.clear_buffer(); + Ok(ParseResult::BadPacket(discard_len)) => { + mv_self.advance_buffer(discard_len); return Ok(QueryResult::E(responses::fresp::R_PACKET_ERR.to_owned())); } Err(_) => { @@ -144,7 +144,6 @@ where } _ => (), } - mv_self.read_again().await?; } }; }) @@ -315,7 +314,7 @@ where } } pub async fn run(&mut self) -> TResult<()> { - log::debug!("SslConnectionHanler initialized to handle a remote client"); + log::debug!("ConnectionHandler initialized to handle a remote client"); while !self.terminator.is_termination_signal() { let try_df = tokio::select! { tdf = self.con.read_query() => tdf, diff --git a/server/src/protocol/mod.rs b/server/src/protocol/mod.rs index 19a2acb9..e3cb79cc 100644 --- a/server/src/protocol/mod.rs +++ b/server/src/protocol/mod.rs @@ -92,7 +92,8 @@ pub enum ParseResult { /// The packet is incomplete, i.e more data needs to be read Incomplete, /// The packet is corrupted, in the sense that it contains invalid data - BadPacket, + /// This variant wraps the number of bytes that should be discarded as it is invalid + BadPacket(usize), /// A successfully parsed query /// /// The second field is the number of bytes that should be discarded from the buffer as it has already @@ -134,7 +135,7 @@ pub fn parse(buf: &[u8]) -> ParseResult { */ let mut pos = 0; if buf[pos] != b'#' { - return ParseResult::BadPacket; + return ParseResult::BadPacket(buf.len()); } else { pos += 1; } @@ -156,18 +157,18 @@ pub fn parse(buf: &[u8]) -> ParseResult { let curdig: usize = match dig.checked_sub(48) { Some(dig) => { if dig > 9 { - return ParseResult::BadPacket; + return ParseResult::BadPacket(buf.len()); } else { dig.into() } } - None => return ParseResult::BadPacket, + None => return ParseResult::BadPacket(buf.len()), }; action_size = (action_size * 10) + curdig; } // This line gives us the number of actions } else { - return ParseResult::BadPacket; + return ParseResult::BadPacket(buf.len()); } let mut items: Vec = Vec::with_capacity(action_size); while pos < buf.len() && items.len() <= action_size { @@ -196,13 +197,13 @@ pub fn parse(buf: &[u8]) -> ParseResult { if dig > 9 { // If `dig` is greater than 9, then the current // UTF-8 char isn't a number - return ParseResult::BadPacket; + return ParseResult::BadPacket(buf.len()); } else { dig.into() } } None => { - return ParseResult::BadPacket; + return ParseResult::BadPacket(buf.len()); } }; current_array_size = (current_array_size * 10) + curdg; // Increment the size @@ -216,7 +217,7 @@ pub fn parse(buf: &[u8]) -> ParseResult { if buf[pos] == b'#' { pos += 1; // skip the '#' character } else { - return ParseResult::BadPacket; + return ParseResult::BadPacket(buf.len()); } while pos < buf.len() && buf[pos] != b'\n' { let curdig: usize = match buf[pos].checked_sub(48) { @@ -224,13 +225,13 @@ pub fn parse(buf: &[u8]) -> ParseResult { if dig > 9 { // If `dig` is greater than 9, then the current // UTF-8 char isn't a number - return ParseResult::BadPacket; + return ParseResult::BadPacket(buf.len()); } else { dig.into() } } None => { - return ParseResult::BadPacket; + return ParseResult::BadPacket(buf.len()); } }; element_size = (element_size * 10) + curdig; // Increment the size @@ -253,7 +254,7 @@ pub fn parse(buf: &[u8]) -> ParseResult { items.push(ActionGroup(actiongroup)); } _ => { - return ParseResult::BadPacket; + return ParseResult::BadPacket(buf.len()); } } continue; @@ -263,7 +264,7 @@ pub fn parse(buf: &[u8]) -> ParseResult { // parsing business, we should never reach here unless // the packet is invalid - return ParseResult::BadPacket; + return ParseResult::BadPacket(buf.len()); } } } @@ -279,7 +280,7 @@ pub fn parse(buf: &[u8]) -> ParseResult { ParseResult::Incomplete } } else { - ParseResult::BadPacket + ParseResult::BadPacket(buf.len()) } } /// Read a size line and return the following line