|
|
@ -31,35 +31,18 @@ use {
|
|
|
|
},
|
|
|
|
},
|
|
|
|
crate::engine::{
|
|
|
|
crate::engine::{
|
|
|
|
mem::BufferedScanner,
|
|
|
|
mem::BufferedScanner,
|
|
|
|
net::protocol::handshake::{
|
|
|
|
net::protocol::{
|
|
|
|
AuthMode, CHandshake, CHandshakeAuth, CHandshakeStatic, DataExchangeMode,
|
|
|
|
exchange::{SQState, Usize},
|
|
|
|
HandshakeResult, HandshakeState, HandshakeVersion, ProtocolVersion, QueryMode,
|
|
|
|
handshake::{
|
|
|
|
|
|
|
|
AuthMode, CHandshake, CHandshakeAuth, CHandshakeStatic, DataExchangeMode,
|
|
|
|
|
|
|
|
HandshakeResult, HandshakeState, HandshakeVersion, ProtocolVersion, QueryMode,
|
|
|
|
|
|
|
|
},
|
|
|
|
|
|
|
|
SQuery,
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
},
|
|
|
|
|
|
|
|
std::ops::Range,
|
|
|
|
};
|
|
|
|
};
|
|
|
|
|
|
|
|
|
|
|
|
pub(super) fn create_simple_query<const N: usize>(query: &str, params: [&str; N]) -> Vec<u8> {
|
|
|
|
|
|
|
|
let mut buf = vec![];
|
|
|
|
|
|
|
|
let query_size_as_string = query.len().to_string();
|
|
|
|
|
|
|
|
let total_packet_size = query.len()
|
|
|
|
|
|
|
|
+ params.iter().map(|l| l.len()).sum::<usize>()
|
|
|
|
|
|
|
|
+ query_size_as_string.len()
|
|
|
|
|
|
|
|
+ 1;
|
|
|
|
|
|
|
|
// segment 1
|
|
|
|
|
|
|
|
buf.push(b'S');
|
|
|
|
|
|
|
|
buf.extend(total_packet_size.to_string().as_bytes());
|
|
|
|
|
|
|
|
buf.push(b'\n');
|
|
|
|
|
|
|
|
// segment
|
|
|
|
|
|
|
|
buf.extend(query_size_as_string.as_bytes());
|
|
|
|
|
|
|
|
buf.push(b'\n');
|
|
|
|
|
|
|
|
// dataframe
|
|
|
|
|
|
|
|
buf.extend(query.as_bytes());
|
|
|
|
|
|
|
|
params
|
|
|
|
|
|
|
|
.into_iter()
|
|
|
|
|
|
|
|
.for_each(|param| buf.extend(param.as_bytes()));
|
|
|
|
|
|
|
|
buf
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
/*
|
|
|
|
client handshake
|
|
|
|
client handshake
|
|
|
|
*/
|
|
|
|
*/
|
|
|
@ -278,14 +261,15 @@ fn hs_bad_auth_mode() {
|
|
|
|
QT-DEX
|
|
|
|
QT-DEX
|
|
|
|
*/
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
|
|
fn iterate_payload(payload: &str, start: usize, f: impl Fn(usize, &[u8])) {
|
|
|
|
fn iterate_payload(payload: impl AsRef<[u8]>, start: usize, f: impl Fn(usize, &[u8])) {
|
|
|
|
for i in start..payload.len() {
|
|
|
|
let payload = payload.as_ref();
|
|
|
|
f(i, &payload.as_bytes()[..i])
|
|
|
|
for i in start..=payload.len() {
|
|
|
|
|
|
|
|
f(i, &payload[..i])
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
fn iterate_exchange_payload(
|
|
|
|
fn iterate_exchange_payload(
|
|
|
|
payload: &str,
|
|
|
|
payload: impl AsRef<[u8]>,
|
|
|
|
start: usize,
|
|
|
|
start: usize,
|
|
|
|
f: impl Fn(usize, Result<(ExchangeResult, usize), ExchangeError>),
|
|
|
|
f: impl Fn(usize, Result<(ExchangeResult, usize), ExchangeError>),
|
|
|
|
) {
|
|
|
|
) {
|
|
|
@ -296,7 +280,7 @@ fn iterate_exchange_payload(
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
fn iterate_exchange_payload_from_zero(
|
|
|
|
fn iterate_exchange_payload_from_zero(
|
|
|
|
payload: &str,
|
|
|
|
payload: impl AsRef<[u8]>,
|
|
|
|
f: impl Fn(usize, Result<(ExchangeResult, usize), ExchangeError>),
|
|
|
|
f: impl Fn(usize, Result<(ExchangeResult, usize), ExchangeError>),
|
|
|
|
) {
|
|
|
|
) {
|
|
|
|
iterate_exchange_payload(payload, 0, f)
|
|
|
|
iterate_exchange_payload(payload, 0, f)
|
|
|
@ -309,7 +293,6 @@ fn iterate_exchange_payload_from_zero(
|
|
|
|
#[test]
|
|
|
|
#[test]
|
|
|
|
fn zero_sized_packet() {
|
|
|
|
fn zero_sized_packet() {
|
|
|
|
for payload in [
|
|
|
|
for payload in [
|
|
|
|
"S\n", // zero packet
|
|
|
|
|
|
|
|
"S0\n", // zero packet
|
|
|
|
"S0\n", // zero packet
|
|
|
|
"S2\n0\n", // zero query
|
|
|
|
"S2\n0\n", // zero query
|
|
|
|
"S1\n\n", // zero query
|
|
|
|
"S1\n\n", // zero query
|
|
|
@ -317,7 +300,11 @@ fn zero_sized_packet() {
|
|
|
|
iterate_exchange_payload_from_zero(payload, |size, result| {
|
|
|
|
iterate_exchange_payload_from_zero(payload, |size, result| {
|
|
|
|
if size == payload.len() {
|
|
|
|
if size == payload.len() {
|
|
|
|
// we got the full payload
|
|
|
|
// we got the full payload
|
|
|
|
assert_eq!(result, Err(ExchangeError::IncorrectQuerySizeOrMoreBytes))
|
|
|
|
if payload.len() == 3 {
|
|
|
|
|
|
|
|
assert_eq!(result, Err(ExchangeError::UnterminatedInteger))
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
assert_eq!(result, Err(ExchangeError::IncorrectQuerySizeOrMoreBytes))
|
|
|
|
|
|
|
|
}
|
|
|
|
} else {
|
|
|
|
} else {
|
|
|
|
// we don't have the full payload
|
|
|
|
// we don't have the full payload
|
|
|
|
if size < 3 {
|
|
|
|
if size < 3 {
|
|
|
@ -356,3 +343,191 @@ fn invalid_first_byte() {
|
|
|
|
})
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
pub struct EQuery {
|
|
|
|
|
|
|
|
// payload
|
|
|
|
|
|
|
|
payload: String,
|
|
|
|
|
|
|
|
variable_range: Range<usize>,
|
|
|
|
|
|
|
|
// query
|
|
|
|
|
|
|
|
query: String,
|
|
|
|
|
|
|
|
query_range: Range<usize>,
|
|
|
|
|
|
|
|
// params
|
|
|
|
|
|
|
|
params: &'static [&'static str],
|
|
|
|
|
|
|
|
param_range: Range<usize>,
|
|
|
|
|
|
|
|
param_indices: Vec<Range<usize>>,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
impl EQuery {
|
|
|
|
|
|
|
|
fn new(query: String, params: &'static [&'static str]) -> Self {
|
|
|
|
|
|
|
|
var!(let variable_start, variable_end, query_start, query_end, param_start);
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
|
|
|
prepare the "back" of the payload
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
let encoded_params: String = params.iter().flat_map(|param| param.chars()).collect();
|
|
|
|
|
|
|
|
let total_size = query.len() + encoded_params.len();
|
|
|
|
|
|
|
|
let total_size_string = format!("{total_size}\n");
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
|
|
|
compute offsets
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
let packet_size = total_size_string.len() + total_size;
|
|
|
|
|
|
|
|
let mut buffer = String::new();
|
|
|
|
|
|
|
|
buffer.push('S');
|
|
|
|
|
|
|
|
buffer.push_str(&format!("{packet_size}\n"));
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// record start of variable block
|
|
|
|
|
|
|
|
variable_start = buffer.len();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
buffer.push_str(&query.len().to_string());
|
|
|
|
|
|
|
|
buffer.push('\n');
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// record start of query
|
|
|
|
|
|
|
|
query_start = buffer.len();
|
|
|
|
|
|
|
|
buffer.push_str(&query);
|
|
|
|
|
|
|
|
query_end = buffer.len();
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
// record start of params
|
|
|
|
|
|
|
|
param_start = buffer.len();
|
|
|
|
|
|
|
|
let mut param_indices = Vec::new();
|
|
|
|
|
|
|
|
for param in params {
|
|
|
|
|
|
|
|
let start = buffer.len();
|
|
|
|
|
|
|
|
buffer.push_str(param);
|
|
|
|
|
|
|
|
param_indices.push(start..buffer.len());
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
variable_end = buffer.len();
|
|
|
|
|
|
|
|
Self {
|
|
|
|
|
|
|
|
payload: buffer,
|
|
|
|
|
|
|
|
variable_range: variable_start..variable_end,
|
|
|
|
|
|
|
|
query,
|
|
|
|
|
|
|
|
query_range: query_start..query_end,
|
|
|
|
|
|
|
|
params,
|
|
|
|
|
|
|
|
param_range: param_start..variable_end,
|
|
|
|
|
|
|
|
param_indices,
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
|
|
|
fn ext_query() {
|
|
|
|
|
|
|
|
let ext_query = EQuery::new("create space myspace".to_owned(), &["sayan", "pass", ""]);
|
|
|
|
|
|
|
|
let query_starts_at = ext_query.payload[ext_query.variable_range.clone()]
|
|
|
|
|
|
|
|
.find('\n')
|
|
|
|
|
|
|
|
.unwrap()
|
|
|
|
|
|
|
|
+ 1;
|
|
|
|
|
|
|
|
assert_eq!(
|
|
|
|
|
|
|
|
&ext_query.payload[ext_query.variable_range.clone()]
|
|
|
|
|
|
|
|
[query_starts_at..query_starts_at + ext_query.query.len()],
|
|
|
|
|
|
|
|
ext_query.query
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
assert_eq!(ext_query.query, &ext_query.payload[ext_query.query_range]);
|
|
|
|
|
|
|
|
assert_eq!("sayanpass", &ext_query.payload[ext_query.param_range]);
|
|
|
|
|
|
|
|
for (param_indices, real_param) in ext_query.param_indices.iter().zip(ext_query.params) {
|
|
|
|
|
|
|
|
assert_eq!(*real_param, &ext_query.payload[param_indices.clone()]);
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
|
|
|
simple queries
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const fn dig_count(real: usize) -> usize {
|
|
|
|
|
|
|
|
// count the number of digits
|
|
|
|
|
|
|
|
let mut dig_count = 0;
|
|
|
|
|
|
|
|
let mut real_ = real;
|
|
|
|
|
|
|
|
while real_ != 0 {
|
|
|
|
|
|
|
|
dig_count += 1;
|
|
|
|
|
|
|
|
real_ /= 10;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
// account for a `0`
|
|
|
|
|
|
|
|
dig_count += (real == 0) as usize;
|
|
|
|
|
|
|
|
dig_count
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
const fn nth_position_value(mut real: usize, mut pos: usize) -> usize {
|
|
|
|
|
|
|
|
let digits = dig_count(real);
|
|
|
|
|
|
|
|
while digits != pos {
|
|
|
|
|
|
|
|
real /= 10;
|
|
|
|
|
|
|
|
pos += 1;
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
real
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
#[test]
|
|
|
|
|
|
|
|
fn simple_query() {
|
|
|
|
|
|
|
|
for query in [
|
|
|
|
|
|
|
|
// small query without params
|
|
|
|
|
|
|
|
EQuery::new("small query".to_owned(), &[]),
|
|
|
|
|
|
|
|
// small query with params
|
|
|
|
|
|
|
|
EQuery::new("small query".to_owned(), &["hello", "world"]),
|
|
|
|
|
|
|
|
// giant query without params
|
|
|
|
|
|
|
|
EQuery::new(
|
|
|
|
|
|
|
|
"abcdefghijklmnopqrstuvwxyz 123456789 ABCDEFGHIJKLMNOPQRSTUVWXYZ".repeat(1000),
|
|
|
|
|
|
|
|
&[],
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
// giant query with params
|
|
|
|
|
|
|
|
EQuery::new(
|
|
|
|
|
|
|
|
"abcdefghijklmnopqrstuvwxyz 123456789 ABCDEFGHIJKLMNOPQRSTUVWXYZ".repeat(1000),
|
|
|
|
|
|
|
|
&["hello", "world"],
|
|
|
|
|
|
|
|
),
|
|
|
|
|
|
|
|
] {
|
|
|
|
|
|
|
|
iterate_exchange_payload_from_zero(query.payload.as_bytes(), |read_position, result| {
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
|
|
|
S<packet size>\n<query window>\n<query><param>
|
|
|
|
|
|
|
|
^ variable ^query ^param
|
|
|
|
|
|
|
|
range start start start
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
- if before (variable range start - 1) then depending on the position from the first byte we will have, say the query size is 123
|
|
|
|
|
|
|
|
then we will have wrt distance from first byte (i.e position - 1) [1], [12], [123]
|
|
|
|
|
|
|
|
- if at (variable range start - 1) then we will have the exact size at [123] and in completed state
|
|
|
|
|
|
|
|
- if >= query start, then we will continue to issue changes of state until we have the full size which will be caught in a different branch
|
|
|
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
if read_position < 3 {
|
|
|
|
|
|
|
|
// didn't reach minimum threshold
|
|
|
|
|
|
|
|
assert_eq!(
|
|
|
|
|
|
|
|
result,
|
|
|
|
|
|
|
|
Ok((ExchangeResult::NewState(ExchangeState::Initial), 0))
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
} else if read_position <= query.variable_range.start - 1 {
|
|
|
|
|
|
|
|
let index = read_position - 1;
|
|
|
|
|
|
|
|
assert_eq!(
|
|
|
|
|
|
|
|
result,
|
|
|
|
|
|
|
|
Ok((
|
|
|
|
|
|
|
|
ExchangeResult::NewState(ExchangeState::Simple(SQState::_new(
|
|
|
|
|
|
|
|
Usize::new_unflagged(nth_position_value(
|
|
|
|
|
|
|
|
query.variable_range.len(),
|
|
|
|
|
|
|
|
index
|
|
|
|
|
|
|
|
))
|
|
|
|
|
|
|
|
))),
|
|
|
|
|
|
|
|
read_position
|
|
|
|
|
|
|
|
))
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
} else if read_position >= query.variable_range.start {
|
|
|
|
|
|
|
|
if read_position == query.payload.len() {
|
|
|
|
|
|
|
|
let (result, cursor) = result.unwrap();
|
|
|
|
|
|
|
|
assert_eq!(cursor, query.payload.len());
|
|
|
|
|
|
|
|
assert_eq!(
|
|
|
|
|
|
|
|
result,
|
|
|
|
|
|
|
|
ExchangeResult::Simple(SQuery::_new(
|
|
|
|
|
|
|
|
query.payload[query.query_range.start..].as_bytes(),
|
|
|
|
|
|
|
|
query.query_range.len()
|
|
|
|
|
|
|
|
))
|
|
|
|
|
|
|
|
);
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
assert_eq!(
|
|
|
|
|
|
|
|
result,
|
|
|
|
|
|
|
|
Ok((
|
|
|
|
|
|
|
|
ExchangeResult::NewState(ExchangeState::Simple(SQState::_new(
|
|
|
|
|
|
|
|
Usize::new_flagged(query.variable_range.len())
|
|
|
|
|
|
|
|
))),
|
|
|
|
|
|
|
|
query.variable_range.start // the cursor will not go ahead until the full query is read
|
|
|
|
|
|
|
|
))
|
|
|
|
|
|
|
|
)
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
} else {
|
|
|
|
|
|
|
|
unreachable!()
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
})
|
|
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
}
|
|
|
|