From d91e696e24661cb0ee520a2c07d01c17e8a8e46d Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Mon, 10 May 2021 13:18:23 +0530 Subject: [PATCH] Add more tests --- server/src/protocol/parserv2.rs | 85 ++++++++++++++++++++++++++++----- 1 file changed, 74 insertions(+), 11 deletions(-) diff --git a/server/src/protocol/parserv2.rs b/server/src/protocol/parserv2.rs index 9d3c8e4a..49a00ad7 100644 --- a/server/src/protocol/parserv2.rs +++ b/server/src/protocol/parserv2.rs @@ -33,10 +33,11 @@ pub(super) struct Parser<'a> { buffer: &'a [u8], } -#[derive(Debug)] +#[derive(Debug, PartialEq)] pub enum ParseError { NotEnough, UnexpectedByte, + BadPacket, } type ActionGroup = Vec>; @@ -89,17 +90,17 @@ impl<'a> Parser<'a> { /// the following line) /// This **will forward the cursor itself** fn read_sizeline(&mut self, opt_char: Option) -> ParseResult { - let opt_char: u8 = opt_char.unwrap_or(b'#'); - if let Some(opt_char) = self.buffer.get(self.cursor) { - // Good, we found a #; time to move ahead - self.incr_cursor(); - // Now read the remaining line - let (started_at, stopped_at) = self.read_line(); - Self::parse_into_usize(&self.buffer[started_at..stopped_at]) - } else { - // A sizeline should begin with a '#'; this one doesn't so it's a bad packet; ugh - Err(ParseError::UnexpectedByte) + if let Some(b) = self.buffer.get(self.cursor) { + if *b == opt_char.unwrap_or(b'#') { + // Good, we found a opt_char; time to move ahead + self.incr_cursor(); + // Now read the remaining line + let (started_at, stopped_at) = self.read_line(); + return Self::parse_into_usize(&self.buffer[started_at..stopped_at]); + } } + // A sizeline should begin with a opt_char; this one doesn't so it's a bad packet; ugh + Err(ParseError::UnexpectedByte) } fn incr_cursor(&mut self) { self.cursor += 1; @@ -108,6 +109,9 @@ impl<'a> Parser<'a> { let mut byte_iter = bytes.into_iter(); let mut item_usize = 0usize; while let Some(dig) = byte_iter.next() { + // 48 is the ASCII code for 0, and 57 is the ascii code for 9 + // so if 0 is given, the subtraction should give 0; similarly + // if 9 is given, the subtraction should give us 9! let curdig: usize = match dig.checked_sub(48) { Some(dig) => { if dig > 9 { @@ -185,6 +189,10 @@ impl<'a> Parser<'a> { } pub fn parse(mut self) -> Result<(Query, usize), ParseError> { let number_of_queries = self.parse_metaframe_get_datagroup_count()?; + if number_of_queries == 0 { + // how on earth do you expect us to execute 0 queries? waste of bandwidth + return Err(ParseError::BadPacket); + } if number_of_queries == 1 { // This is a simple query let single_group = self.parse_next_actiongroup()?; @@ -218,6 +226,14 @@ fn test_sizeline_parse() { assert_eq!(parser.cursor, sizeline.len()); } +#[test] +#[should_panic] +fn test_fail_sizeline_parse_wrong_firstbyte() { + let sizeline = "125\n".as_bytes(); + let mut parser = Parser::new(&sizeline); + parser.read_sizeline(None).unwrap(); +} + #[test] fn test_metaframe_parse() { let metaframe = "\r2\n*2\n".as_bytes(); @@ -226,6 +242,20 @@ fn test_metaframe_parse() { assert_eq!(parser.cursor, metaframe.len()); } +#[test] +#[should_panic] +fn test_metaframe_parse_fail() { + // First byte should be CR and not $ + let metaframe = "$2\n*2\n".as_bytes(); + let mut parser = Parser::new(&metaframe); + parser.parse_metaframe_get_datagroup_count().unwrap(); + // Give a wrong length approximation + let metaframe = "\r1\n*2\n".as_bytes(); + Parser::new(&metaframe) + .parse_metaframe_get_datagroup_count() + .unwrap(); +} + #[test] fn test_actiongroup_size_parse() { let dataframe_layout = "#6\n&12345\n".as_bytes(); @@ -272,3 +302,36 @@ fn test_complete_query_packet_parse() { ); assert_eq!(forward_by, query_packet.len()); } + +#[test] +#[should_panic] +fn test_query_parse_fail() { + // this packet has an extra \n, where it should have been nothing or a \r + let query_packet = "\r2\n*1\n#2\n&2\n#3\nGET\n#3\nfoo\n\n".as_bytes(); + Parser::new(&query_packet).parse().unwrap(); +} + +#[test] +fn test_query_parse_pass_part_of_next_query() { + // we read a part of the next query, we should happily ignore it (`\r2\n*1\n`) + let query_packet = "\r2\n*1\n#2\n&2\n#3\nGET\n#3\nfoo\n\r2\n*1\n".as_bytes(); + let (ret, forward_by) = Parser::new(&query_packet).parse().unwrap(); + assert_eq!( + ret, + Query::SimpleQuery(vec![ + "GET".to_owned().into_bytes(), + "foo".to_owned().into_bytes() + ]) + ); + // the cursor should be at the '\n' byte + assert!(forward_by == query_packet.len() - "\r2\n*1\n".len()); +} + +#[test] +fn test_query_fail_not_enough() { + let query_packet = "\r2".as_bytes(); + assert_eq!( + Parser::new(&query_packet).parse().err().unwrap(), + ParseError::NotEnough + ); +}