Only discard part of buffer that was parsed

next
Sayan Nandan 3 years ago
parent 74893c275e
commit abb1b9bf33

@ -112,7 +112,7 @@ where
if self.get_buffer().is_empty() {
return Err(());
}
Ok(protocol::parse(&self.get_buffer()))
Ok(protocol::parse(&self.get_buffer()[..]))
}
/// Read a query from the remote end
///
@ -128,15 +128,15 @@ where
Box::pin(async move {
let mv_self = self;
let _: Result<QueryResult, IoError> = {
mv_self.read_again().await?;
loop {
mv_self.read_again().await?;
match mv_self.try_query() {
Ok(ParseResult::Query(query, forward)) => {
mv_self.advance_buffer(forward);
return Ok(QueryResult::Q(query));
}
Ok(ParseResult::BadPacket) => {
mv_self.clear_buffer();
Ok(ParseResult::BadPacket(discard_len)) => {
mv_self.advance_buffer(discard_len);
return Ok(QueryResult::E(responses::fresp::R_PACKET_ERR.to_owned()));
}
Err(_) => {
@ -144,7 +144,6 @@ where
}
_ => (),
}
mv_self.read_again().await?;
}
};
})
@ -315,7 +314,7 @@ where
}
}
pub async fn run(&mut self) -> TResult<()> {
log::debug!("SslConnectionHanler initialized to handle a remote client");
log::debug!("ConnectionHandler initialized to handle a remote client");
while !self.terminator.is_termination_signal() {
let try_df = tokio::select! {
tdf = self.con.read_query() => tdf,

@ -92,7 +92,8 @@ pub enum ParseResult {
/// The packet is incomplete, i.e more data needs to be read
Incomplete,
/// The packet is corrupted, in the sense that it contains invalid data
BadPacket,
/// This variant wraps the number of bytes that should be discarded as it is invalid
BadPacket(usize),
/// A successfully parsed query
///
/// The second field is the number of bytes that should be discarded from the buffer as it has already
@ -134,7 +135,7 @@ pub fn parse(buf: &[u8]) -> ParseResult {
*/
let mut pos = 0;
if buf[pos] != b'#' {
return ParseResult::BadPacket;
return ParseResult::BadPacket(buf.len());
} else {
pos += 1;
}
@ -156,18 +157,18 @@ pub fn parse(buf: &[u8]) -> ParseResult {
let curdig: usize = match dig.checked_sub(48) {
Some(dig) => {
if dig > 9 {
return ParseResult::BadPacket;
return ParseResult::BadPacket(buf.len());
} else {
dig.into()
}
}
None => return ParseResult::BadPacket,
None => return ParseResult::BadPacket(buf.len()),
};
action_size = (action_size * 10) + curdig;
}
// This line gives us the number of actions
} else {
return ParseResult::BadPacket;
return ParseResult::BadPacket(buf.len());
}
let mut items: Vec<ActionGroup> = Vec::with_capacity(action_size);
while pos < buf.len() && items.len() <= action_size {
@ -196,13 +197,13 @@ pub fn parse(buf: &[u8]) -> ParseResult {
if dig > 9 {
// If `dig` is greater than 9, then the current
// UTF-8 char isn't a number
return ParseResult::BadPacket;
return ParseResult::BadPacket(buf.len());
} else {
dig.into()
}
}
None => {
return ParseResult::BadPacket;
return ParseResult::BadPacket(buf.len());
}
};
current_array_size = (current_array_size * 10) + curdg; // Increment the size
@ -216,7 +217,7 @@ pub fn parse(buf: &[u8]) -> ParseResult {
if buf[pos] == b'#' {
pos += 1; // skip the '#' character
} else {
return ParseResult::BadPacket;
return ParseResult::BadPacket(buf.len());
}
while pos < buf.len() && buf[pos] != b'\n' {
let curdig: usize = match buf[pos].checked_sub(48) {
@ -224,13 +225,13 @@ pub fn parse(buf: &[u8]) -> ParseResult {
if dig > 9 {
// If `dig` is greater than 9, then the current
// UTF-8 char isn't a number
return ParseResult::BadPacket;
return ParseResult::BadPacket(buf.len());
} else {
dig.into()
}
}
None => {
return ParseResult::BadPacket;
return ParseResult::BadPacket(buf.len());
}
};
element_size = (element_size * 10) + curdig; // Increment the size
@ -253,7 +254,7 @@ pub fn parse(buf: &[u8]) -> ParseResult {
items.push(ActionGroup(actiongroup));
}
_ => {
return ParseResult::BadPacket;
return ParseResult::BadPacket(buf.len());
}
}
continue;
@ -263,7 +264,7 @@ pub fn parse(buf: &[u8]) -> ParseResult {
// parsing business, we should never reach here unless
// the packet is invalid
return ParseResult::BadPacket;
return ParseResult::BadPacket(buf.len());
}
}
}
@ -279,7 +280,7 @@ pub fn parse(buf: &[u8]) -> ParseResult {
ParseResult::Incomplete
}
} else {
ParseResult::BadPacket
ParseResult::BadPacket(buf.len())
}
}
/// Read a size line and return the following line

Loading…
Cancel
Save