diff --git a/server/src/engine/net/protocol/exchange.rs b/server/src/engine/net/protocol/exchange.rs index 6a148a5a..ad1e9a39 100644 --- a/server/src/engine/net/protocol/exchange.rs +++ b/server/src/engine/net/protocol/exchange.rs @@ -222,6 +222,10 @@ impl PipeState { const fn new(packet_s: Usize) -> Self { Self { packet_s } } + #[cfg(test)] + pub const fn _new(packet_s: Usize) -> Self { + Self::new(packet_s) + } } impl Default for ExchangeState { diff --git a/server/src/engine/net/protocol/tests.rs b/server/src/engine/net/protocol/tests.rs index 0c7fe4c6..7d7986bd 100644 --- a/server/src/engine/net/protocol/tests.rs +++ b/server/src/engine/net/protocol/tests.rs @@ -32,7 +32,7 @@ use { crate::engine::{ mem::BufferedScanner, net::protocol::{ - exchange::{SQState, Usize}, + exchange::{PipeState, SQState, Usize}, handshake::{ AuthMode, CHandshake, CHandshakeAuth, CHandshakeStatic, DataExchangeMode, HandshakeResult, HandshakeState, HandshakeVersion, ProtocolVersion, QueryMode, @@ -535,23 +535,37 @@ fn simple_query() { pipeline */ -fn pipe_query(q: &str, p: [&str; N]) -> String { - let mut buffer = String::new(); - buffer.extend(q.len().to_string().chars()); - buffer.push('\n'); - buffer.extend( - p.iter() - .map(|_p| _p.len()) - .sum::() - .to_string() - .chars(), - ); - buffer.push('\n'); - buffer.extend(q.chars()); - for p_ in p { - buffer.push_str(p_); +struct EPipeQuery { + payload: String, + q: String, + p: &'static [&'static str], + q_range: Range, + p_range: Range, +} + +impl EPipeQuery { + fn new_string(q: &str, p: &'static [&'static str]) -> String { + Self::new(q.to_owned(), p).payload + } + fn new(q: String, p: &'static [&'static str]) -> Self { + let mut buffer = String::new(); + buffer.push_str(&q.len().to_string()); + buffer.push('\n'); + buffer.push_str(&p.iter().map(|s| s.len()).sum::().to_string()); + buffer.push('\n'); + let q_start = buffer.len(); + buffer.push_str(q.as_str()); + let q_stop = buffer.len(); + p.iter().for_each(|p_| buffer.push_str(p_)); + let p_stop = buffer.len(); + Self { + payload: buffer, + q, + p, + q_range: q_start..q_stop, + p_range: q_stop..p_stop, + } } - buffer } fn pipe(queries: [String; N]) -> String { @@ -559,15 +573,23 @@ fn pipe(queries: [String; N]) -> String { format!("P{}\n{packed_queries}", packed_queries.len()) } +fn get_pipeline_from_result(pipeline: ExchangeResult) -> Vec { + let pipeline: Vec> = match pipeline { + ExchangeResult::Pipeline(p) => p.into_iter().map(Result::unwrap).collect(), + _ => panic!("expected pipeline got: {:?}", pipeline), + }; + pipeline +} + #[test] fn full_pipe_scan() { let pipeline_buffer = pipe([ - pipe_query("create space myspace", []), - pipe_query( + EPipeQuery::new_string("create space myspace", &[]), + EPipeQuery::new_string( "create model myspace.mymodel(username: string, password: string)", - [], + &[], ), - pipe_query("insert into myspace.mymodel(?, ?)", ["sayan", "cake"]), + EPipeQuery::new_string("insert into myspace.mymodel(?, ?)", &["sayan", "cake"]), ]); let (pipeline, cursor) = Exchange::try_complete( BufferedScanner::new(pipeline_buffer.as_bytes()), @@ -575,10 +597,7 @@ fn full_pipe_scan() { ) .unwrap(); assert_eq!(cursor, pipeline_buffer.len()); - let pipeline: Vec> = match pipeline { - ExchangeResult::Pipeline(p) => p.into_iter().map(Result::unwrap).collect(), - _ => panic!("expected pipeline got: {:?}", pipeline), - }; + let pipeline = get_pipeline_from_result(pipeline); assert_eq!( pipeline, vec![ @@ -594,3 +613,138 @@ fn full_pipe_scan() { ] ); } + +struct EPipe { + payload: String, + variable_range: Range, + queries: Vec, +} + +impl EPipe { + fn new(queries: [EPipeQuery; N]) -> Self { + let packed_queries_len = queries.iter().map(|epq| epq.payload.len()).sum::(); + let mut buffer = format!("P{packed_queries_len}\n"); + let variable_start = buffer.len(); + for query in queries.iter() { + buffer.push_str(query.payload.as_str()); + } + let variable_end = buffer.len(); + Self { + payload: buffer, + variable_range: variable_start..variable_end, + queries: queries.into_iter().collect(), + } + } +} + +#[test] +fn pipeline() { + for pipe in [ + EPipe::new([ + // small query with no params + EPipeQuery::new("create space myspace".to_owned(), &[]), + // small query with params + EPipeQuery::new( + "insert into myspace.mymodel(?, ?)".to_owned(), + &["sayan", "elx"], + ), + // giant query with no params + EPipeQuery::new( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz123456789".repeat(1000), + &[], + ), + // giant query with params + EPipeQuery::new( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz123456789".repeat(1000), + &["hello", "world"], + ), + ]), + EPipe::new([ + // giant query with no params + EPipeQuery::new( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz123456789".repeat(1000), + &[], + ), + // giant query with params + EPipeQuery::new( + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz123456789".repeat(1000), + &["hello", "world"], + ), + // small query with no params + EPipeQuery::new("create space myspace".to_owned(), &[]), + // small query with params + EPipeQuery::new( + "insert into myspace.mymodel(?, ?)".to_owned(), + &["sayan", "elx"], + ), + ]), + ] { + iterate_exchange_payload_from_zero(pipe.payload.as_bytes(), |buffer_position, result| { + if buffer_position < 3 { + // we didn't read enough to pass the initial bounds check + assert_eq!( + result, + Ok((ExchangeResult::NewState(ExchangeState::Initial), 0)) + ) + } else { + if buffer_position <= pipe.variable_range.start - 1 { + // in this case we're before completion of the packet size. we succesively read [1], [12], [123] + assert_eq!( + result, + Ok(( + ExchangeResult::NewState(ExchangeState::Pipeline(PipeState::_new( + Usize::new_unflagged(nth_position_value( + pipe.variable_range.len(), + buffer_position - 1 + )) + ))), + buffer_position + )) + ) + } else if buffer_position == pipe.variable_range.start { + // in this case we passed the newline for the packet size so we have everything in order + assert_eq!( + result, + Ok(( + ExchangeResult::NewState(ExchangeState::Pipeline(PipeState::_new( + Usize::new_flagged(pipe.variable_range.len()) + ))), + pipe.variable_range.start + )) + ) + } else if buffer_position >= pipe.variable_range.start { + // in this case we have more bytes than the variable range start + if buffer_position == pipe.payload.len() { + // if we're here, we've read the full packet + let (status, cursor) = result.unwrap(); + assert_eq!(cursor, pipe.variable_range.end); // the whole chunk is obtained + let pipeline = get_pipeline_from_result(status); + pipeline + .into_iter() + .zip(pipe.queries.iter()) + .for_each(|(spq, epq)| { + let decoded_query = core::str::from_utf8(spq.query()).unwrap(); + let decoded_params = core::str::from_utf8(spq.params()).unwrap(); + assert_eq!(decoded_query, epq.q); + assert_eq!(decoded_params, epq.p.concat()); + assert_eq!(decoded_query, &epq.payload[epq.q_range.clone()]); + assert_eq!(decoded_params, &epq.payload[epq.p_range.clone()]); + }) + } else { + // if we're here we haven't read the full packet because we're waiting for it to complete + let (status, cursor) = result.unwrap(); + assert_eq!(cursor, pipe.variable_range.start); // the cursor won't move until + assert_eq!( + status, + ExchangeResult::NewState(ExchangeState::Pipeline(PipeState::_new( + Usize::new_flagged(pipe.variable_range.len()) + ))) + ) + } + } else { + unreachable!() + } + } + }) + } +}