net: Rewrite protocol impl to support pipelines

next
Sayan Nandan 6 months ago
parent 0b0acd2038
commit d3b5fd8060
No known key found for this signature in database
GPG Key ID: 0EBD769024B24F0A

@ -430,13 +430,13 @@ impl FractalMgr {
_ = sigterm.recv() => { _ = sigterm.recv() => {
info!("flp: finishing any pending maintenance tasks"); info!("flp: finishing any pending maintenance tasks");
let global = global.clone(); let global = global.clone();
tokio::task::spawn_blocking(|| self.general_executor(global)).await.unwrap(); let _ = tokio::task::spawn_blocking(|| self.general_executor(global)).await;
info!("flp: exited executor service"); info!("flp: exited executor service");
break; break;
}, },
_ = tokio::time::sleep(dur) => { _ = tokio::time::sleep(dur) => {
let global = global.clone(); let global = global.clone();
tokio::task::spawn_blocking(|| self.general_executor(global)).await.unwrap() let _ = tokio::task::spawn_blocking(|| self.general_executor(global)).await;
} }
task = lpq.recv() => { task = lpq.recv() => {
let Task { threshold, task } = match task { let Task { threshold, task } = match task {

@ -66,7 +66,7 @@ impl<'a, T> Scanner<'a, T> {
self.__cursor self.__cursor
} }
/// Returns the buffer from the current position /// Returns the buffer from the current position
pub fn current_buffer(&self) -> &[T] { pub fn current_buffer(&self) -> &'a [T] {
&self.d[self.__cursor..] &self.d[self.__cursor..]
} }
/// Returns the ptr to the cursor /// Returns the ptr to the cursor
@ -86,6 +86,9 @@ impl<'a, T> Scanner<'a, T> {
pub fn has_left(&self, sizeof: usize) -> bool { pub fn has_left(&self, sizeof: usize) -> bool {
self.remaining() >= sizeof self.remaining() >= sizeof
} }
pub fn remaining_size_is(&self, size: usize) -> bool {
self.remaining() == size
}
/// Returns true if the rounded cursor matches the predicate /// Returns true if the rounded cursor matches the predicate
pub fn rounded_cursor_matches(&self, f: impl Fn(&T) -> bool) -> bool { pub fn rounded_cursor_matches(&self, f: impl Fn(&T) -> bool) -> bool {
f(&self.d[self.rounded_cursor()]) & !self.eof() f(&self.d[self.rounded_cursor()]) & !self.eof()

@ -1,5 +1,5 @@
/* /*
* Created on Wed Sep 20 2023 * Created on Tue Apr 02 2024
* *
* This file is a part of Skytable * This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
@ -7,7 +7,7 @@
* vision to provide flexibility in data modelling without compromising * vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability. * on performance, queryability or scalability.
* *
* Copyright (c) 2023, Sayan Nandan <ohsayan@outlook.com> * Copyright (c) 2024, Sayan Nandan <nandansayan@outlook.com>
* *
* This program is free software: you can redistribute it and/or modify * This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by * it under the terms of the GNU Affero General Public License as published by
@ -24,211 +24,308 @@
* *
*/ */
use {super::AccumlatorStatus, crate::engine::mem::BufferedScanner}; use {
crate::{engine::mem::BufferedScanner, util::compiler},
core::fmt,
};
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] /*
pub struct Resume(usize); Skyhash/2.1 Implementation
impl Resume { ---
#[cfg(test)] This is an implementation of Skyhash/2.1, Skytable's data exchange protocol.
pub(super) const fn test_new(v: usize) -> Self {
Self(v)
}
#[cfg(test)]
pub(super) const fn inner(&self) -> usize {
self.0
}
}
impl Default for Resume {
fn default() -> Self {
Self(0)
}
}
pub(super) unsafe fn resume<'a>( 0. Notes
buf: &'a [u8], ++++++++++++++++++
last_cursor: Resume, - 2.1 is fully backwards compatible with 2.0 clients. As such we don't even designate it as a separate version.
last_state: QExchangeState, - The "LF exception" essentially allows `0\n` to be equal to `\n`. It's unimportant to enforce this
) -> (Resume, QExchangeResult<'a>) {
let mut scanner = unsafe { 1.1 Query Types
// UNSAFE(@ohsayan): we are the ones who generate the cursor and restore it ++++++++++++++++++
BufferedScanner::new_with_cursor(buf, last_cursor.0) The protocol makes two distinctions, at the protocol-level about the type of queries:
}; a. Simple query
let ret = last_state.resume(&mut scanner); b. Pipeline
(Resume(scanner.cursor()), ret)
} 1.1.1 Simple Query
++++++++++++++++++
A simple query
*/
/* /*
SQ sq definition
*/ */
#[derive(Debug, PartialEq)] #[derive(Debug)]
pub struct SQuery<'a> { pub struct SQuery<'a> {
q: &'a [u8], buf: &'a [u8],
q_window: usize, q_window: usize,
} }
impl<'a> SQuery<'a> { impl<'a> SQuery<'a> {
pub(super) fn new(q: &'a [u8], q_window: usize) -> Self { fn new(buf: &'a [u8], q_window: usize) -> Self {
Self { q, q_window } Self { buf, q_window }
}
pub fn query(&self) -> &[u8] {
&self.buf[..self.q_window]
}
pub fn params(&self) -> &[u8] {
&self.buf[self.q_window..]
}
}
/*
scanint
*/
fn scan_usize_guaranteed_termination(scanner: &mut BufferedScanner) -> Result<usize, ()> {
let mut ret = 0usize;
let mut stop = scanner.rounded_eq(b'\n');
while !scanner.eof() & !stop {
let next_byte = unsafe {
// UNSAFE(@ohsayan): loop invariant
scanner.next_byte()
};
match ret
.checked_mul(10)
.map(|int| int.checked_add((next_byte & 0x0f) as usize))
{
Some(Some(int)) if next_byte.is_ascii_digit() => ret = int,
_ => return Err(()),
}
stop = scanner.rounded_eq(b'\n');
}
unsafe {
// UNSAFE(@ohsayan): scanned stop, but not accounted for yet
scanner.incr_cursor_if(stop)
}
if stop {
Ok(ret)
} else {
Err(())
}
}
#[derive(Clone, Copy)]
struct Usize {
v: isize,
}
impl Usize {
const SHIFT: u32 = isize::BITS - 1;
const MASK: isize = 1 << Self::SHIFT;
#[inline(always)]
const fn new(v: isize) -> Self {
Self { v }
} }
pub fn payload(&self) -> &'a [u8] { #[inline(always)]
self.q const fn new_unflagged(int: usize) -> Self {
Self::new(int as isize)
} }
pub fn q_window(&self) -> usize { #[inline(always)]
self.q_window fn int(&self) -> usize {
(self.v & !Self::MASK) as usize
} }
pub fn query(&self) -> &'a [u8] { #[inline(always)]
&self.payload()[..self.q_window()] fn update(&mut self, new: usize) {
self.v = (new as isize) | (self.v & Self::MASK);
} }
pub fn params(&self) -> &'a [u8] { #[inline(always)]
&self.payload()[self.q_window()..] fn flag(&self) -> bool {
(self.v & Self::MASK) != 0
} }
#[cfg(test)] #[inline(always)]
pub fn query_str(&self) -> &str { fn set_flag_if(&mut self, iff: bool) {
core::str::from_utf8(self.query()).unwrap() self.v |= (iff as isize) << Self::SHIFT;
} }
#[cfg(test)] }
pub fn params_str(&self) -> &str {
core::str::from_utf8(self.params()).unwrap() impl fmt::Debug for Usize {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Usize")
.field("int", &self.int())
.field("flag", &self.flag())
.finish()
}
}
impl Usize {
/// Attempt to "complete" a scan of the integer. Idempotency guarantee: it is guaranteed that calls would not change the state
/// of the [`Usize`] or the buffer if the final state has been reached
fn update_scanned(&mut self, scanner: &mut BufferedScanner) -> Result<(), ()> {
let mut stop = scanner.rounded_eq(b'\n');
while !stop & !scanner.eof() & !self.flag() {
let byte = unsafe {
// UNSAFE(@ohsayan): verified by loop invariant
scanner.next_byte()
};
match (self.int() as isize) // this cast allows us to guarantee that we don't trip the flag
.checked_mul(10)
.map(|int| int.checked_add((byte & 0x0f) as isize))
{
Some(Some(int)) if byte.is_ascii_digit() => self.update(int as usize),
_ => return Err(()),
}
stop = scanner.rounded_eq(b'\n');
}
unsafe {
// UNSAFE(@ohsayan): scanned stop byte but did not account for it; the flag check is for cases where the input buffer
// has something like [LF][LF] in which case we stopped at the first LF but we would accidentally read the second one
// on the second derogatory call
scanner.incr_cursor_if(stop & !self.flag())
}
self.set_flag_if(stop | self.flag()); // if second call we must check the flag
Ok(())
} }
} }
/* /*
utils states
*/ */
#[derive(Debug, PartialEq, Clone, Copy)] #[derive(Debug)]
pub(super) enum QExchangeStateInternal { pub enum ExchangeState {
Initial, Initial,
PendingMeta1, Simple(SQState),
PendingMeta2, Pipeline(PipeState),
PendingData,
} }
impl Default for QExchangeStateInternal { #[derive(Debug)]
fn default() -> Self { pub struct SQState {
Self::Initial packet_s: Usize,
}
impl SQState {
const fn new(packet_s: Usize) -> Self {
Self { packet_s }
} }
} }
#[derive(Debug, PartialEq)] #[derive(Debug)]
pub(super) struct QExchangeState { pub struct PipeState {
state: QExchangeStateInternal, packet_s: Usize,
target: usize, }
md_packet_size: u64,
md_q_window: u64, impl PipeState {
const fn new(packet_s: Usize) -> Self {
Self { packet_s }
}
} }
impl Default for QExchangeState { impl Default for ExchangeState {
fn default() -> Self { fn default() -> Self {
Self::new() Self::Initial
} }
} }
#[derive(Debug, PartialEq)] pub enum ExchangeResult<'a> {
/// Result after attempting to complete (or terminate) a query time exchange NewState(ExchangeState),
pub(super) enum QExchangeResult<'a> { Simple(SQuery<'a>),
/// We completed the exchange and yielded a [`SQuery`] Pipeline(Pipeline<'a>),
SQCompleted(SQuery<'a>),
/// We're changing states
ChangeState(QExchangeState),
/// We hit an error and need to terminate this exchange
Error,
} }
impl QExchangeState { pub struct Exchange<'a> {
fn _new( scanner: BufferedScanner<'a>,
state: QExchangeStateInternal, }
target: usize,
md_packet_size: u64, impl<'a> Exchange<'a> {
md_q_window: u64, const MIN_Q_SIZE: usize = "P0\n".len();
) -> Self { fn new(scanner: BufferedScanner<'a>) -> Self {
Self { Self { scanner }
state,
target,
md_packet_size,
md_q_window,
}
} }
#[cfg(test)] pub fn try_complete(
pub(super) fn new_test( scanner: BufferedScanner<'a>,
state: QExchangeStateInternal, state: ExchangeState,
target: usize, ) -> Result<(ExchangeResult, usize), ()> {
md_packet_size: u64, Self::new(scanner).complete(state)
md_q_window: u64,
) -> Self {
Self::_new(state, target, md_packet_size, md_q_window)
} }
} }
impl QExchangeState { impl<'a> Exchange<'a> {
pub const MIN_READ: usize = b"S\x00\n\x00\n".len(); fn complete(mut self, state: ExchangeState) -> Result<(ExchangeResult<'a>, usize), ()> {
pub fn new() -> Self { match state {
Self::_new(QExchangeStateInternal::Initial, Self::MIN_READ, 0, 0) ExchangeState::Initial => {
if compiler::likely(self.scanner.has_left(Self::MIN_Q_SIZE)) {
let first_byte = unsafe {
// UNSAFE(@ohsayan): already verified in above branch
self.scanner.next_byte()
};
match first_byte {
b'S' => self.process_simple(SQState::new(Usize::new_unflagged(0))),
b'P' => self.process_pipe(PipeState::new(Usize::new_unflagged(0))),
_ => return Err(()),
}
} else {
Ok(ExchangeResult::NewState(state))
}
}
ExchangeState::Simple(sq_s) => self.process_simple(sq_s),
ExchangeState::Pipeline(pipe_s) => self.process_pipe(pipe_s),
}
.map(|ret| (ret, self.scanner.cursor()))
} }
pub fn has_reached_target(&self, new_buffer: &[u8]) -> bool { fn process_simple(&mut self, mut sq_state: SQState) -> Result<ExchangeResult<'a>, ()> {
new_buffer.len() >= self.target // try to complete the packet size if needed
sq_state.packet_s.update_scanned(&mut self.scanner)?;
if sq_state.packet_s.flag() & self.scanner.remaining_size_is(sq_state.packet_s.int()) {
// we have the full packet size and the required data
let q_window = scan_usize_guaranteed_termination(&mut self.scanner)?;
let nonzero = (q_window != 0) & (sq_state.packet_s.int() != 0);
if compiler::likely(self.scanner.remaining_size_is(q_window) & nonzero) {
// this check is important because the client might have given us an incorrect q size
Ok(ExchangeResult::Simple(SQuery::new(
self.scanner.current_buffer(),
q_window,
)))
} else {
Err(())
}
} else {
Ok(ExchangeResult::NewState(ExchangeState::Simple(sq_state)))
}
} }
fn resume<'a>(self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> { fn process_pipe(&mut self, mut pipe_s: PipeState) -> Result<ExchangeResult<'a>, ()> {
debug_assert!(scanner.has_left(Self::MIN_READ)); // try to complete the packet size if needed
match self.state { pipe_s.packet_s.update_scanned(&mut self.scanner)?;
QExchangeStateInternal::Initial => self.start_initial(scanner), if pipe_s.packet_s.flag() & self.scanner.remaining_size_is(pipe_s.packet_s.int()) {
QExchangeStateInternal::PendingMeta1 => self.resume_at_md1(scanner), // great, we have the entire packet
QExchangeStateInternal::PendingMeta2 => self.resume_at_md2(scanner), Ok(ExchangeResult::Pipeline(Pipeline::new(
QExchangeStateInternal::PendingData => self.resume_data(scanner), self.scanner.current_buffer(),
)))
} else {
Ok(ExchangeResult::NewState(ExchangeState::Pipeline(pipe_s)))
} }
} }
fn start_initial<'a>(self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> { }
if unsafe { scanner.next_byte() } != b'S' {
// has to be a simple query! /*
return QExchangeResult::Error; pipeline
*/
pub struct Pipeline<'a> {
scanner: BufferedScanner<'a>,
}
impl<'a> Pipeline<'a> {
fn new(buf: &'a [u8]) -> Self {
Self {
scanner: BufferedScanner::new(buf),
} }
self.resume_at_md1(scanner) }
} pub fn next_query(&mut self) -> Result<Option<SQuery<'a>>, ()> {
fn resume_at_md1<'a>(mut self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> { let nonzero = self.scanner.buffer_len() != 0;
let packet_size = match super::scan_int(scanner, self.md_packet_size) { if self.scanner.eof() & nonzero {
Ok(AccumlatorStatus::Completed(v)) => v, Ok(None)
Ok(AccumlatorStatus::Pending(p)) => {
// if this is the first run, we read 5 bytes and need atleast one more; if this is a resume we read one or more bytes and
// need atleast one more
self.target += 1;
self.md_packet_size = p;
self.state = QExchangeStateInternal::PendingMeta1;
return QExchangeResult::ChangeState(self);
}
Err(()) => return QExchangeResult::Error,
};
self.md_packet_size = packet_size;
self.target = scanner.cursor() + packet_size as usize;
// hand over control to md2
self.resume_at_md2(scanner)
}
fn resume_at_md2<'a>(mut self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> {
let q_window = match super::scan_int(scanner, self.md_q_window) {
Ok(AccumlatorStatus::Completed(v)) => v,
Ok(AccumlatorStatus::Pending(p)) => {
self.md_q_window = p;
self.state = QExchangeStateInternal::PendingMeta2;
return QExchangeResult::ChangeState(self);
}
Err(()) => return QExchangeResult::Error,
};
self.md_q_window = q_window;
// hand over control to data
self.resume_data(scanner)
}
fn resume_data<'a>(mut self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> {
let df_size = self.target - scanner.cursor();
if scanner.remaining() == df_size {
unsafe {
QExchangeResult::SQCompleted(SQuery::new(
scanner.next_chunk_variable(df_size),
self.md_q_window as usize,
))
}
} else { } else {
self.state = QExchangeStateInternal::PendingData; let query_size = scan_usize_guaranteed_termination(&mut self.scanner)?;
QExchangeResult::ChangeState(self) let param_size = scan_usize_guaranteed_termination(&mut self.scanner)?;
let (full_size, overflow) = param_size.overflowing_add(query_size);
if compiler::likely(self.scanner.remaining_size_is(full_size) & !overflow) {
Ok(Some(SQuery {
buf: self.scanner.current_buffer(),
q_window: query_size,
}))
} else {
Err(())
}
} }
} }
} }

@ -44,14 +44,9 @@ mod handshake;
#[cfg(test)] #[cfg(test)]
mod tests; mod tests;
// re-export
pub use exchange::SQuery;
use crate::engine::core::system_db::VerifyUser;
use { use {
self::{ self::{
exchange::{QExchangeResult, QExchangeState}, exchange::{Exchange, ExchangeResult, ExchangeState, Pipeline},
handshake::{ handshake::{
AuthMode, CHandshake, DataExchangeMode, HandshakeResult, HandshakeState, AuthMode, CHandshake, DataExchangeMode, HandshakeResult, HandshakeState,
HandshakeVersion, ProtocolError, ProtocolVersion, QueryMode, HandshakeVersion, ProtocolError, ProtocolVersion, QueryMode,
@ -60,7 +55,8 @@ use {
super::{IoResult, QueryLoopResult, Socket}, super::{IoResult, QueryLoopResult, Socket},
crate::engine::{ crate::engine::{
self, self,
error::QueryError, core::system_db::VerifyUser,
error::{QueryError, QueryResult},
fractal::{Global, GlobalInstanceLike}, fractal::{Global, GlobalInstanceLike},
mem::{BufferedScanner, IntegerRepr}, mem::{BufferedScanner, IntegerRepr},
}, },
@ -68,31 +64,12 @@ use {
tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter}, tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter},
}; };
#[repr(u8)] // re-export
#[derive(sky_macros::EnumMethods, Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] pub use self::exchange::SQuery;
#[allow(unused)]
pub enum ResponseType { /*
Null = 0x00, connection state
Bool = 0x01, */
UInt8 = 0x02,
UInt16 = 0x03,
UInt32 = 0x04,
UInt64 = 0x05,
SInt8 = 0x06,
SInt16 = 0x07,
SInt32 = 0x08,
SInt64 = 0x09,
Float32 = 0x0A,
Float64 = 0x0B,
Binary = 0x0C,
String = 0x0D,
List = 0x0E,
Dict = 0x0F,
Error = 0x10,
Row = 0x11,
Empty = 0x12,
MultiRow = 0x13,
}
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
pub struct ClientLocalState { pub struct ClientLocalState {
@ -128,106 +105,9 @@ impl ClientLocalState {
} }
} }
#[derive(Debug, PartialEq)] /*
pub enum Response { handshake
Empty, */
Null,
Serialized {
ty: ResponseType,
size: usize,
data: Vec<u8>,
},
Bool(bool),
}
pub(super) async fn query_loop<S: Socket>(
con: &mut BufWriter<S>,
buf: &mut BytesMut,
global: &Global,
) -> IoResult<QueryLoopResult> {
// handshake
let mut client_state = match do_handshake(con, buf, global).await? {
PostHandshake::Okay(hs) => hs,
PostHandshake::ConnectionClosedFin => return Ok(QueryLoopResult::Fin),
PostHandshake::ConnectionClosedRst => return Ok(QueryLoopResult::Rst),
PostHandshake::Error(e) => {
// failed to handshake; we'll close the connection
let hs_err_packet = [b'H', 0, 1, e.value_u8()];
con.write_all(&hs_err_packet).await?;
return Ok(QueryLoopResult::HSFailed);
}
};
// done handshaking
con.write_all(b"H\x00\x00\x00").await?;
con.flush().await?;
let mut state = QExchangeState::default();
let mut cursor = Default::default();
loop {
if con.read_buf(buf).await? == 0 {
if buf.is_empty() {
return Ok(QueryLoopResult::Fin);
} else {
return Ok(QueryLoopResult::Rst);
}
}
if !state.has_reached_target(buf) {
// we haven't buffered sufficient bytes; keep working
continue;
}
let sq = match unsafe {
// UNSAFE(@ohsayan): as the resume cursor is private, we can't access this anyways
exchange::resume(buf, cursor, state)
} {
(_, QExchangeResult::SQCompleted(sq)) => sq,
(new_cursor, QExchangeResult::ChangeState(new_state)) => {
cursor = new_cursor;
state = new_state;
continue;
}
(_, QExchangeResult::Error) => {
// respond with error
let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16)
.to_le_bytes();
con.write_all(&[ResponseType::Error.value_u8(), a, b])
.await?;
con.flush().await?;
// reset buffer, cursor and state
buf.clear();
cursor = Default::default();
state = QExchangeState::default();
continue;
}
};
// now execute query
match engine::core::exec::dispatch_to_executor(global, &mut client_state, sq).await {
Ok(Response::Empty) => {
con.write_all(&[ResponseType::Empty.value_u8()]).await?;
}
Ok(Response::Serialized { ty, size, data }) => {
con.write_u8(ty.value_u8()).await?;
let mut irep = IntegerRepr::new();
con.write_all(irep.as_bytes(size as u64)).await?;
con.write_u8(b'\n').await?;
con.write_all(&data).await?;
}
Ok(Response::Bool(b)) => {
con.write_all(&[ResponseType::Bool.value_u8(), b as u8])
.await?
}
Ok(Response::Null) => con.write_u8(ResponseType::Null.value_u8()).await?,
Err(e) => {
let [a, b] = (e.value_u8() as u16).to_le_bytes();
con.write_all(&[ResponseType::Error.value_u8(), a, b])
.await?;
}
}
con.flush().await?;
// reset buffer, cursor and state
buf.clear();
cursor = Default::default();
state = QExchangeState::default();
}
}
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
enum PostHandshake { enum PostHandshake {
@ -315,41 +195,194 @@ async fn do_handshake<S: Socket>(
Ok(PostHandshake::Error(ProtocolError::RejectAuth)) Ok(PostHandshake::Error(ProtocolError::RejectAuth))
} }
#[derive(Debug, PartialEq, Clone, Copy)] /*
enum AccumlatorStatus { exec event loop
Pending(u64), */
Completed(u64),
async fn cleanup_for_next_query<S: Socket>(
con: &mut BufWriter<S>,
buf: &mut BytesMut,
) -> IoResult<(ExchangeState, usize)> {
con.flush().await?; // flush write buffer
buf.clear(); // clear read buffer
Ok((ExchangeState::default(), 0))
} }
/// Scan an integer pub(super) async fn query_loop<S: Socket>(
/// con: &mut BufWriter<S>,
/// Allowed sequences: buf: &mut BytesMut,
/// - < int >\n global: &Global,
/// - \n ; FIXME(@ohsayan): a LF only sequence is allowed. should it be removed? ) -> IoResult<QueryLoopResult> {
fn scan_int(s: &mut BufferedScanner, acc: u64) -> Result<AccumlatorStatus, ()> { // handshake
let mut acc = acc; let mut client_state = match do_handshake(con, buf, global).await? {
let mut okay = true; PostHandshake::Okay(hs) => hs,
let mut end = s.rounded_eq(b'\n'); PostHandshake::ConnectionClosedFin => return Ok(QueryLoopResult::Fin),
while okay & !end & !s.eof() { PostHandshake::ConnectionClosedRst => return Ok(QueryLoopResult::Rst),
let d = unsafe { s.next_byte() }; PostHandshake::Error(e) => {
okay &= d.is_ascii_digit(); // failed to handshake; we'll close the connection
match acc.checked_mul(10).map(|v| v.checked_add((d & 0x0f) as _)) { let hs_err_packet = [b'H', 0, 1, e.value_u8()];
Some(Some(v)) => acc = v, con.write_all(&hs_err_packet).await?;
_ => okay = false, return Ok(QueryLoopResult::HSFailed);
}
};
// done handshaking
con.write_all(b"H\x00\x00\x00").await?;
con.flush().await?;
let mut state = ExchangeState::default();
let mut cursor = 0;
loop {
if con.read_buf(buf).await? == 0 {
if buf.is_empty() {
return Ok(QueryLoopResult::Fin);
} else {
return Ok(QueryLoopResult::Rst);
}
}
match Exchange::try_complete(
unsafe {
// UNSAFE(@ohsayan): the cursor is either 0 or returned by the exchange impl
BufferedScanner::new_with_cursor(&buf, cursor)
},
state,
) {
Ok((result, new_cursor)) => match result {
ExchangeResult::NewState(new_state) => {
state = new_state;
cursor = new_cursor;
}
ExchangeResult::Simple(query) => {
exec_simple(con, &mut client_state, global, query).await?;
(state, cursor) = cleanup_for_next_query(con, buf).await?;
}
ExchangeResult::Pipeline(pipe) => {
exec_pipe(con, &mut client_state, global, pipe).await?;
(state, cursor) = cleanup_for_next_query(con, buf).await?;
}
},
Err(()) => {
// respond with error
let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16)
.to_le_bytes();
con.write_all(&[ResponseType::Error.value_u8(), a, b])
.await?;
(state, cursor) = cleanup_for_next_query(con, buf).await?;
}
} }
end = s.rounded_eq(b'\n');
} }
unsafe { }
// UNSAFE(@ohsayan): if we hit an LF, then we still have space until EOA
s.incr_cursor_if(end) /*
responses
*/
#[repr(u8)]
#[derive(sky_macros::EnumMethods, Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[allow(unused)]
pub enum ResponseType {
Null = 0x00,
Bool = 0x01,
UInt8 = 0x02,
UInt16 = 0x03,
UInt32 = 0x04,
UInt64 = 0x05,
SInt8 = 0x06,
SInt16 = 0x07,
SInt32 = 0x08,
SInt64 = 0x09,
Float32 = 0x0A,
Float64 = 0x0B,
Binary = 0x0C,
String = 0x0D,
List = 0x0E,
Dict = 0x0F,
Error = 0x10,
Row = 0x11,
Empty = 0x12,
MultiRow = 0x13,
}
#[derive(Debug, PartialEq)]
pub enum Response {
Empty,
Null,
Serialized {
ty: ResponseType,
size: usize,
data: Vec<u8>,
},
Bool(bool),
}
async fn write_response<S: Socket>(
resp: QueryResult<Response>,
con: &mut BufWriter<S>,
) -> IoResult<()> {
match resp {
Ok(Response::Empty) => con.write_all(&[ResponseType::Empty.value_u8()]).await,
Ok(Response::Serialized { ty, size, data }) => {
con.write_u8(ty.value_u8()).await?;
let mut irep = IntegerRepr::new();
con.write_all(irep.as_bytes(size as u64)).await?;
con.write_u8(b'\n').await?;
con.write_all(&data).await
}
Ok(Response::Bool(b)) => {
con.write_all(&[ResponseType::Bool.value_u8(), b as u8])
.await
}
Ok(Response::Null) => con.write_u8(ResponseType::Null.value_u8()).await,
Err(e) => {
let [a, b] = (e.value_u8() as u16).to_le_bytes();
con.write_all(&[ResponseType::Error.value_u8(), a, b]).await
}
} }
if okay & end { }
Ok(AccumlatorStatus::Completed(acc))
} else { /*
if okay { simple query
Ok(AccumlatorStatus::Pending(acc)) */
} else {
Err(()) async fn exec_simple<S: Socket>(
con: &mut BufWriter<S>,
cs: &mut ClientLocalState,
global: &Global,
query: SQuery<'_>,
) -> IoResult<()> {
write_response(
engine::core::exec::dispatch_to_executor(global, cs, query).await,
con,
)
.await
}
/*
pipeline
*/
async fn exec_pipe<'a, S: Socket>(
con: &mut BufWriter<S>,
cs: &mut ClientLocalState,
global: &Global,
mut pipe: Pipeline<'a>,
) -> IoResult<()> {
loop {
match pipe.next_query() {
Ok(None) => break Ok(()),
Ok(Some(q)) => {
write_response(
engine::core::exec::dispatch_to_executor(global, cs, q).await,
con,
)
.await?
}
Err(()) => {
// respond with error
let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16)
.to_le_bytes();
con.write_all(&[ResponseType::Error.value_u8(), a, b])
.await?;
break Ok(());
}
} }
} }
} }

@ -24,28 +24,15 @@
* *
*/ */
use crate::engine::net::protocol::exchange::Resume;
use { use {
super::{ super::handshake::ProtocolError,
exchange::{self, QExchangeResult, QExchangeState}, crate::engine::{
handshake::ProtocolError, mem::BufferedScanner,
SQuery, net::protocol::handshake::{
}, AuthMode, CHandshake, CHandshakeAuth, CHandshakeStatic, DataExchangeMode,
crate::{ HandshakeResult, HandshakeState, HandshakeVersion, ProtocolVersion, QueryMode,
engine::{
mem::BufferedScanner,
net::protocol::{
handshake::{
AuthMode, CHandshake, CHandshakeAuth, CHandshakeStatic, DataExchangeMode,
HandshakeResult, HandshakeState, HandshakeVersion, ProtocolVersion, QueryMode,
},
scan_int, AccumlatorStatus,
},
}, },
util::test_utils,
}, },
rand::Rng,
}; };
pub(super) fn create_simple_query<const N: usize>(query: &str, params: [&str; N]) -> Vec<u8> { pub(super) fn create_simple_query<const N: usize>(query: &str, params: [&str; N]) -> Vec<u8> {
@ -283,212 +270,3 @@ fn hs_bad_auth_mode() {
assert_eq!(hs_result, HandshakeResult::Error(ProtocolError::RejectAuth)) assert_eq!(hs_result, HandshakeResult::Error(ProtocolError::RejectAuth))
}) })
} }
/*
QT-DEX/SQ
*/
const SQ: &str = "select * from myspace.mymodel where username = ?";
fn parse_staged<const N: usize>(
query: &str,
params: [&str; N],
eq: impl Fn(SQuery),
rng: &mut impl Rng,
) {
let __query_buffer = create_simple_query(query, params);
for _ in 0..__query_buffer.len() {
let mut __read_total = 0;
let mut cursor = Default::default();
let mut state = QExchangeState::default();
loop {
let remaining = __query_buffer.len() - __read_total;
let read_this_time = {
let mut cnt = 0;
if remaining == 1 {
1
} else {
let mut last = test_utils::random_number(1, remaining, rng);
loop {
if cnt >= 10 {
break last;
}
// if we're reading exact, try to keep it low
if last == remaining {
cnt += 1;
last = test_utils::random_number(1, remaining, rng);
} else {
break last;
}
}
}
};
__read_total += read_this_time;
let buffered = &__query_buffer[..__read_total];
if !state.has_reached_target(buffered) {
continue;
}
match unsafe { exchange::resume(buffered, cursor, state) } {
(new_cursor, QExchangeResult::ChangeState(new_state)) => {
cursor = new_cursor;
state = new_state;
continue;
}
(_, QExchangeResult::SQCompleted(q)) => {
eq(q);
break;
}
_ => panic!(),
}
}
}
}
#[test]
fn staged_randomized() {
let mut rng = test_utils::rng();
parse_staged(
SQ,
["sayan"],
|q| {
assert_eq!(q.query_str(), SQ);
assert_eq!(q.params_str(), "sayan");
},
&mut rng,
);
}
#[test]
fn stages_manual() {
let query = create_simple_query("select * from mymodel where username = ?", ["sayan"]);
assert_eq!(
unsafe {
exchange::resume(
&query[..QExchangeState::MIN_READ],
Default::default(),
Default::default(),
)
},
(
Resume::test_new(5),
QExchangeResult::ChangeState(QExchangeState::new_test(
exchange::QExchangeStateInternal::PendingMeta2,
52,
48,
4
))
)
);
assert_eq!(
unsafe {
exchange::resume(
&query[..QExchangeState::MIN_READ + 1],
Default::default(),
Default::default(),
)
},
(
Resume::test_new(6),
QExchangeResult::ChangeState(QExchangeState::new_test(
exchange::QExchangeStateInternal::PendingMeta2,
52,
48,
40
))
)
);
assert_eq!(
unsafe {
exchange::resume(
&query[..QExchangeState::MIN_READ + 2],
Default::default(),
Default::default(),
)
},
(
Resume::test_new(7),
QExchangeResult::ChangeState(QExchangeState::new_test(
exchange::QExchangeStateInternal::PendingData,
52,
48,
40
))
)
);
// the cursor should never change
for upper_bound in QExchangeState::MIN_READ + 2..query.len() {
assert_eq!(
unsafe {
exchange::resume(
&query[..upper_bound],
Default::default(),
Default::default(),
)
},
(
Resume::test_new(7),
QExchangeResult::ChangeState(QExchangeState::new_test(
exchange::QExchangeStateInternal::PendingData,
52,
48,
40
))
)
);
}
match unsafe { exchange::resume(&query, Default::default(), Default::default()) } {
(l, QExchangeResult::SQCompleted(q)) if l.inner() == query.len() => {
assert_eq!(q.query_str(), "select * from mymodel where username = ?");
assert_eq!(q.params_str(), "sayan");
}
e => panic!("expected end, got {e:?}"),
}
}
#[test]
fn exchange_bad_segment_metadata() {
let exchange_packets = [
(
b"S4A\n32\ndelete from dbs where dbname = ?sillydb".as_slice(),
4,
"incorrect packet size",
),
(
b"S42\n3A\ndelete from dbs where dbname = ?sillydb",
7,
"incorrect q window",
),
(
b"S4A\n3A\ndelete from dbs where dbname = ?sillydb",
4,
"incorrect packet size and q window",
),
];
for (packet, cursor, description) in exchange_packets {
assert_eq!(
unsafe { exchange::resume(packet, Resume::test_new(0), Default::default()) },
(Resume::test_new(cursor), QExchangeResult::Error),
"failed for `{description}`"
)
}
}
#[test]
fn num_accumulate() {
let x = [
(b"1".as_slice(), Ok(AccumlatorStatus::Pending(1)), 1usize),
(b"12", Ok(AccumlatorStatus::Pending(12)), 2),
(b"123", Ok(AccumlatorStatus::Pending(123)), 3),
(b"123\n", Ok(AccumlatorStatus::Completed(123)), 4),
(b"\n", Ok(AccumlatorStatus::Completed(0)), 1),
(b"A", Err(()), 1),
(b"A\n", Err(()), 2),
(b"1A", Err(()), 2),
(b"1A\n", Err(()), 3),
];
for (buf, res, cursor) in x {
let mut bs = BufferedScanner::new(buf);
assert_eq!(scan_int(&mut bs, 0), res);
assert_eq!(bs.cursor(), cursor);
}
}

Loading…
Cancel
Save