diff --git a/server/src/engine/net/protocol/exchange.rs b/server/src/engine/net/protocol/exchange.rs index 98cd9cb9..6a148a5a 100644 --- a/server/src/engine/net/protocol/exchange.rs +++ b/server/src/engine/net/protocol/exchange.rs @@ -86,15 +86,15 @@ fn scan_usize_guaranteed_termination( let mut ret = 0usize; let mut stop = scanner.rounded_eq(b'\n'); while !scanner.eof() & !stop { - let next_byte = unsafe { + let this_byte = unsafe { // UNSAFE(@ohsayan): loop invariant scanner.next_byte() }; match ret .checked_mul(10) - .map(|int| int.checked_add((next_byte & 0x0f) as usize)) + .map(|int| int.checked_add((this_byte & 0x0f) as usize)) { - Some(Some(int)) if next_byte.is_ascii_digit() => ret = int, + Some(Some(int)) if this_byte.is_ascii_digit() => ret = int, _ => return Err(ExchangeError::NotAsciiByteOrOverflow), } stop = scanner.rounded_eq(b'\n'); @@ -357,22 +357,33 @@ impl<'a> Pipeline<'a> { scanner: BufferedScanner::new(buf), } } - pub fn next_query(&mut self) -> Result>, ExchangeError> { +} + +impl<'a> Iterator for Pipeline<'a> { + type Item = Result, ExchangeError>; + fn next(&mut self) -> Option { let nonzero = self.scanner.buffer_len() != 0; if self.scanner.eof() & nonzero { - Ok(None) + None } else { - let query_size = scan_usize_guaranteed_termination(&mut self.scanner)?; - let param_size = scan_usize_guaranteed_termination(&mut self.scanner)?; - let (full_size, overflow) = param_size.overflowing_add(query_size); - if compiler::likely(self.scanner.has_left(full_size) & !overflow) { - Ok(Some(SQuery { - buf: self.scanner.current_buffer(), - q_window: query_size, - })) - } else { - Err(ExchangeError::IncorrectQuerySizeOrMoreBytes) - } + let mut ret = || { + let query_size = scan_usize_guaranteed_termination(&mut self.scanner)?; + let param_size = scan_usize_guaranteed_termination(&mut self.scanner)?; + let (full_size, overflow) = param_size.overflowing_add(query_size); + if compiler::likely(self.scanner.has_left(full_size) & !overflow) { + let block = unsafe { + // UNSAFE(@ohsayan): checked in above branch + self.scanner.next_chunk_variable(full_size) + }; + Ok(SQuery { + buf: block, + q_window: query_size, + }) + } else { + Err(ExchangeError::IncorrectQuerySizeOrMoreBytes) + } + }; + Some(ret()) } } } diff --git a/server/src/engine/net/protocol/mod.rs b/server/src/engine/net/protocol/mod.rs index 25bc4cf3..ace91eb1 100644 --- a/server/src/engine/net/protocol/mod.rs +++ b/server/src/engine/net/protocol/mod.rs @@ -363,12 +363,12 @@ async fn exec_pipe<'a, S: Socket>( con: &mut BufWriter, cs: &mut ClientLocalState, global: &Global, - mut pipe: Pipeline<'a>, + pipe: Pipeline<'a>, ) -> IoResult<()> { - loop { - match pipe.next_query() { - Ok(None) => break Ok(()), - Ok(Some(q)) => { + let mut pipe = pipe.into_iter(); + while let Some(query) = pipe.next() { + match query { + Ok(q) => { write_response( engine::core::exec::dispatch_to_executor(global, cs, q).await, con, @@ -381,8 +381,9 @@ async fn exec_pipe<'a, S: Socket>( .to_le_bytes(); con.write_all(&[ResponseType::Error.value_u8(), a, b]) .await?; - break Ok(()); + return Ok(()); } } } + Ok(()) } diff --git a/server/src/engine/net/protocol/tests.rs b/server/src/engine/net/protocol/tests.rs index 6c04bc4b..0c7fe4c6 100644 --- a/server/src/engine/net/protocol/tests.rs +++ b/server/src/engine/net/protocol/tests.rs @@ -363,8 +363,7 @@ impl EQuery { /* prepare the "back" of the payload */ - let encoded_params: String = params.iter().flat_map(|param| param.chars()).collect(); - let total_size = query.len() + encoded_params.len(); + let total_size = query.len() + params.iter().map(|p| p.len()).sum::(); let total_size_string = format!("{total_size}\n"); /* @@ -531,3 +530,67 @@ fn simple_query() { }) } } + +/* + pipeline +*/ + +fn pipe_query(q: &str, p: [&str; N]) -> String { + let mut buffer = String::new(); + buffer.extend(q.len().to_string().chars()); + buffer.push('\n'); + buffer.extend( + p.iter() + .map(|_p| _p.len()) + .sum::() + .to_string() + .chars(), + ); + buffer.push('\n'); + buffer.extend(q.chars()); + for p_ in p { + buffer.push_str(p_); + } + buffer +} + +fn pipe(queries: [String; N]) -> String { + let packed_queries = queries.concat(); + format!("P{}\n{packed_queries}", packed_queries.len()) +} + +#[test] +fn full_pipe_scan() { + let pipeline_buffer = pipe([ + pipe_query("create space myspace", []), + pipe_query( + "create model myspace.mymodel(username: string, password: string)", + [], + ), + pipe_query("insert into myspace.mymodel(?, ?)", ["sayan", "cake"]), + ]); + let (pipeline, cursor) = Exchange::try_complete( + BufferedScanner::new(pipeline_buffer.as_bytes()), + ExchangeState::default(), + ) + .unwrap(); + assert_eq!(cursor, pipeline_buffer.len()); + let pipeline: Vec> = match pipeline { + ExchangeResult::Pipeline(p) => p.into_iter().map(Result::unwrap).collect(), + _ => panic!("expected pipeline got: {:?}", pipeline), + }; + assert_eq!( + pipeline, + vec![ + SQuery::_new(b"create space myspace", "create space myspace".len()), + SQuery::_new( + b"create model myspace.mymodel(username: string, password: string)", + "create model myspace.mymodel(username: string, password: string)".len() + ), + SQuery::_new( + b"insert into myspace.mymodel(?, ?)sayancake", + "insert into myspace.mymodel(?, ?)".len() + ) + ] + ); +}