net: Rewrite protocol impl to support pipelines

next
Sayan Nandan 6 months ago
parent 0b0acd2038
commit d3b5fd8060
No known key found for this signature in database
GPG Key ID: 0EBD769024B24F0A

@ -430,13 +430,13 @@ impl FractalMgr {
_ = sigterm.recv() => {
info!("flp: finishing any pending maintenance tasks");
let global = global.clone();
tokio::task::spawn_blocking(|| self.general_executor(global)).await.unwrap();
let _ = tokio::task::spawn_blocking(|| self.general_executor(global)).await;
info!("flp: exited executor service");
break;
},
_ = tokio::time::sleep(dur) => {
let global = global.clone();
tokio::task::spawn_blocking(|| self.general_executor(global)).await.unwrap()
let _ = tokio::task::spawn_blocking(|| self.general_executor(global)).await;
}
task = lpq.recv() => {
let Task { threshold, task } = match task {

@ -66,7 +66,7 @@ impl<'a, T> Scanner<'a, T> {
self.__cursor
}
/// Returns the buffer from the current position
pub fn current_buffer(&self) -> &[T] {
pub fn current_buffer(&self) -> &'a [T] {
&self.d[self.__cursor..]
}
/// Returns the ptr to the cursor
@ -86,6 +86,9 @@ impl<'a, T> Scanner<'a, T> {
pub fn has_left(&self, sizeof: usize) -> bool {
self.remaining() >= sizeof
}
pub fn remaining_size_is(&self, size: usize) -> bool {
self.remaining() == size
}
/// Returns true if the rounded cursor matches the predicate
pub fn rounded_cursor_matches(&self, f: impl Fn(&T) -> bool) -> bool {
f(&self.d[self.rounded_cursor()]) & !self.eof()

@ -1,5 +1,5 @@
/*
* Created on Wed Sep 20 2023
* Created on Tue Apr 02 2024
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
@ -7,7 +7,7 @@
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2023, Sayan Nandan <ohsayan@outlook.com>
* Copyright (c) 2024, Sayan Nandan <nandansayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
@ -24,211 +24,308 @@
*
*/
use {super::AccumlatorStatus, crate::engine::mem::BufferedScanner};
use {
crate::{engine::mem::BufferedScanner, util::compiler},
core::fmt,
};
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Resume(usize);
impl Resume {
#[cfg(test)]
pub(super) const fn test_new(v: usize) -> Self {
Self(v)
}
#[cfg(test)]
pub(super) const fn inner(&self) -> usize {
self.0
}
}
impl Default for Resume {
fn default() -> Self {
Self(0)
}
}
/*
Skyhash/2.1 Implementation
---
This is an implementation of Skyhash/2.1, Skytable's data exchange protocol.
pub(super) unsafe fn resume<'a>(
buf: &'a [u8],
last_cursor: Resume,
last_state: QExchangeState,
) -> (Resume, QExchangeResult<'a>) {
let mut scanner = unsafe {
// UNSAFE(@ohsayan): we are the ones who generate the cursor and restore it
BufferedScanner::new_with_cursor(buf, last_cursor.0)
};
let ret = last_state.resume(&mut scanner);
(Resume(scanner.cursor()), ret)
}
0. Notes
++++++++++++++++++
- 2.1 is fully backwards compatible with 2.0 clients. As such we don't even designate it as a separate version.
- The "LF exception" essentially allows `0\n` to be equal to `\n`. It's unimportant to enforce this
1.1 Query Types
++++++++++++++++++
The protocol makes two distinctions, at the protocol-level about the type of queries:
a. Simple query
b. Pipeline
1.1.1 Simple Query
++++++++++++++++++
A simple query
*/
/*
SQ
sq definition
*/
#[derive(Debug, PartialEq)]
#[derive(Debug)]
pub struct SQuery<'a> {
q: &'a [u8],
buf: &'a [u8],
q_window: usize,
}
impl<'a> SQuery<'a> {
pub(super) fn new(q: &'a [u8], q_window: usize) -> Self {
Self { q, q_window }
fn new(buf: &'a [u8], q_window: usize) -> Self {
Self { buf, q_window }
}
pub fn query(&self) -> &[u8] {
&self.buf[..self.q_window]
}
pub fn params(&self) -> &[u8] {
&self.buf[self.q_window..]
}
}
/*
scanint
*/
fn scan_usize_guaranteed_termination(scanner: &mut BufferedScanner) -> Result<usize, ()> {
let mut ret = 0usize;
let mut stop = scanner.rounded_eq(b'\n');
while !scanner.eof() & !stop {
let next_byte = unsafe {
// UNSAFE(@ohsayan): loop invariant
scanner.next_byte()
};
match ret
.checked_mul(10)
.map(|int| int.checked_add((next_byte & 0x0f) as usize))
{
Some(Some(int)) if next_byte.is_ascii_digit() => ret = int,
_ => return Err(()),
}
stop = scanner.rounded_eq(b'\n');
}
unsafe {
// UNSAFE(@ohsayan): scanned stop, but not accounted for yet
scanner.incr_cursor_if(stop)
}
if stop {
Ok(ret)
} else {
Err(())
}
}
#[derive(Clone, Copy)]
struct Usize {
v: isize,
}
impl Usize {
const SHIFT: u32 = isize::BITS - 1;
const MASK: isize = 1 << Self::SHIFT;
#[inline(always)]
const fn new(v: isize) -> Self {
Self { v }
}
pub fn payload(&self) -> &'a [u8] {
self.q
#[inline(always)]
const fn new_unflagged(int: usize) -> Self {
Self::new(int as isize)
}
pub fn q_window(&self) -> usize {
self.q_window
#[inline(always)]
fn int(&self) -> usize {
(self.v & !Self::MASK) as usize
}
pub fn query(&self) -> &'a [u8] {
&self.payload()[..self.q_window()]
#[inline(always)]
fn update(&mut self, new: usize) {
self.v = (new as isize) | (self.v & Self::MASK);
}
pub fn params(&self) -> &'a [u8] {
&self.payload()[self.q_window()..]
#[inline(always)]
fn flag(&self) -> bool {
(self.v & Self::MASK) != 0
}
#[cfg(test)]
pub fn query_str(&self) -> &str {
core::str::from_utf8(self.query()).unwrap()
#[inline(always)]
fn set_flag_if(&mut self, iff: bool) {
self.v |= (iff as isize) << Self::SHIFT;
}
#[cfg(test)]
pub fn params_str(&self) -> &str {
core::str::from_utf8(self.params()).unwrap()
}
impl fmt::Debug for Usize {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Usize")
.field("int", &self.int())
.field("flag", &self.flag())
.finish()
}
}
impl Usize {
/// Attempt to "complete" a scan of the integer. Idempotency guarantee: it is guaranteed that calls would not change the state
/// of the [`Usize`] or the buffer if the final state has been reached
fn update_scanned(&mut self, scanner: &mut BufferedScanner) -> Result<(), ()> {
let mut stop = scanner.rounded_eq(b'\n');
while !stop & !scanner.eof() & !self.flag() {
let byte = unsafe {
// UNSAFE(@ohsayan): verified by loop invariant
scanner.next_byte()
};
match (self.int() as isize) // this cast allows us to guarantee that we don't trip the flag
.checked_mul(10)
.map(|int| int.checked_add((byte & 0x0f) as isize))
{
Some(Some(int)) if byte.is_ascii_digit() => self.update(int as usize),
_ => return Err(()),
}
stop = scanner.rounded_eq(b'\n');
}
unsafe {
// UNSAFE(@ohsayan): scanned stop byte but did not account for it; the flag check is for cases where the input buffer
// has something like [LF][LF] in which case we stopped at the first LF but we would accidentally read the second one
// on the second derogatory call
scanner.incr_cursor_if(stop & !self.flag())
}
self.set_flag_if(stop | self.flag()); // if second call we must check the flag
Ok(())
}
}
/*
utils
states
*/
#[derive(Debug, PartialEq, Clone, Copy)]
pub(super) enum QExchangeStateInternal {
#[derive(Debug)]
pub enum ExchangeState {
Initial,
PendingMeta1,
PendingMeta2,
PendingData,
Simple(SQState),
Pipeline(PipeState),
}
impl Default for QExchangeStateInternal {
fn default() -> Self {
Self::Initial
#[derive(Debug)]
pub struct SQState {
packet_s: Usize,
}
impl SQState {
const fn new(packet_s: Usize) -> Self {
Self { packet_s }
}
}
#[derive(Debug, PartialEq)]
pub(super) struct QExchangeState {
state: QExchangeStateInternal,
target: usize,
md_packet_size: u64,
md_q_window: u64,
#[derive(Debug)]
pub struct PipeState {
packet_s: Usize,
}
impl PipeState {
const fn new(packet_s: Usize) -> Self {
Self { packet_s }
}
}
impl Default for QExchangeState {
impl Default for ExchangeState {
fn default() -> Self {
Self::new()
Self::Initial
}
}
#[derive(Debug, PartialEq)]
/// Result after attempting to complete (or terminate) a query time exchange
pub(super) enum QExchangeResult<'a> {
/// We completed the exchange and yielded a [`SQuery`]
SQCompleted(SQuery<'a>),
/// We're changing states
ChangeState(QExchangeState),
/// We hit an error and need to terminate this exchange
Error,
pub enum ExchangeResult<'a> {
NewState(ExchangeState),
Simple(SQuery<'a>),
Pipeline(Pipeline<'a>),
}
impl QExchangeState {
fn _new(
state: QExchangeStateInternal,
target: usize,
md_packet_size: u64,
md_q_window: u64,
) -> Self {
Self {
state,
target,
md_packet_size,
md_q_window,
}
pub struct Exchange<'a> {
scanner: BufferedScanner<'a>,
}
impl<'a> Exchange<'a> {
const MIN_Q_SIZE: usize = "P0\n".len();
fn new(scanner: BufferedScanner<'a>) -> Self {
Self { scanner }
}
#[cfg(test)]
pub(super) fn new_test(
state: QExchangeStateInternal,
target: usize,
md_packet_size: u64,
md_q_window: u64,
) -> Self {
Self::_new(state, target, md_packet_size, md_q_window)
pub fn try_complete(
scanner: BufferedScanner<'a>,
state: ExchangeState,
) -> Result<(ExchangeResult, usize), ()> {
Self::new(scanner).complete(state)
}
}
impl QExchangeState {
pub const MIN_READ: usize = b"S\x00\n\x00\n".len();
pub fn new() -> Self {
Self::_new(QExchangeStateInternal::Initial, Self::MIN_READ, 0, 0)
impl<'a> Exchange<'a> {
fn complete(mut self, state: ExchangeState) -> Result<(ExchangeResult<'a>, usize), ()> {
match state {
ExchangeState::Initial => {
if compiler::likely(self.scanner.has_left(Self::MIN_Q_SIZE)) {
let first_byte = unsafe {
// UNSAFE(@ohsayan): already verified in above branch
self.scanner.next_byte()
};
match first_byte {
b'S' => self.process_simple(SQState::new(Usize::new_unflagged(0))),
b'P' => self.process_pipe(PipeState::new(Usize::new_unflagged(0))),
_ => return Err(()),
}
} else {
Ok(ExchangeResult::NewState(state))
}
}
ExchangeState::Simple(sq_s) => self.process_simple(sq_s),
ExchangeState::Pipeline(pipe_s) => self.process_pipe(pipe_s),
}
.map(|ret| (ret, self.scanner.cursor()))
}
pub fn has_reached_target(&self, new_buffer: &[u8]) -> bool {
new_buffer.len() >= self.target
fn process_simple(&mut self, mut sq_state: SQState) -> Result<ExchangeResult<'a>, ()> {
// try to complete the packet size if needed
sq_state.packet_s.update_scanned(&mut self.scanner)?;
if sq_state.packet_s.flag() & self.scanner.remaining_size_is(sq_state.packet_s.int()) {
// we have the full packet size and the required data
let q_window = scan_usize_guaranteed_termination(&mut self.scanner)?;
let nonzero = (q_window != 0) & (sq_state.packet_s.int() != 0);
if compiler::likely(self.scanner.remaining_size_is(q_window) & nonzero) {
// this check is important because the client might have given us an incorrect q size
Ok(ExchangeResult::Simple(SQuery::new(
self.scanner.current_buffer(),
q_window,
)))
} else {
Err(())
}
} else {
Ok(ExchangeResult::NewState(ExchangeState::Simple(sq_state)))
}
}
fn resume<'a>(self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> {
debug_assert!(scanner.has_left(Self::MIN_READ));
match self.state {
QExchangeStateInternal::Initial => self.start_initial(scanner),
QExchangeStateInternal::PendingMeta1 => self.resume_at_md1(scanner),
QExchangeStateInternal::PendingMeta2 => self.resume_at_md2(scanner),
QExchangeStateInternal::PendingData => self.resume_data(scanner),
fn process_pipe(&mut self, mut pipe_s: PipeState) -> Result<ExchangeResult<'a>, ()> {
// try to complete the packet size if needed
pipe_s.packet_s.update_scanned(&mut self.scanner)?;
if pipe_s.packet_s.flag() & self.scanner.remaining_size_is(pipe_s.packet_s.int()) {
// great, we have the entire packet
Ok(ExchangeResult::Pipeline(Pipeline::new(
self.scanner.current_buffer(),
)))
} else {
Ok(ExchangeResult::NewState(ExchangeState::Pipeline(pipe_s)))
}
}
fn start_initial<'a>(self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> {
if unsafe { scanner.next_byte() } != b'S' {
// has to be a simple query!
return QExchangeResult::Error;
}
/*
pipeline
*/
pub struct Pipeline<'a> {
scanner: BufferedScanner<'a>,
}
impl<'a> Pipeline<'a> {
fn new(buf: &'a [u8]) -> Self {
Self {
scanner: BufferedScanner::new(buf),
}
self.resume_at_md1(scanner)
}
fn resume_at_md1<'a>(mut self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> {
let packet_size = match super::scan_int(scanner, self.md_packet_size) {
Ok(AccumlatorStatus::Completed(v)) => v,
Ok(AccumlatorStatus::Pending(p)) => {
// if this is the first run, we read 5 bytes and need atleast one more; if this is a resume we read one or more bytes and
// need atleast one more
self.target += 1;
self.md_packet_size = p;
self.state = QExchangeStateInternal::PendingMeta1;
return QExchangeResult::ChangeState(self);
}
Err(()) => return QExchangeResult::Error,
};
self.md_packet_size = packet_size;
self.target = scanner.cursor() + packet_size as usize;
// hand over control to md2
self.resume_at_md2(scanner)
}
fn resume_at_md2<'a>(mut self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> {
let q_window = match super::scan_int(scanner, self.md_q_window) {
Ok(AccumlatorStatus::Completed(v)) => v,
Ok(AccumlatorStatus::Pending(p)) => {
self.md_q_window = p;
self.state = QExchangeStateInternal::PendingMeta2;
return QExchangeResult::ChangeState(self);
}
Err(()) => return QExchangeResult::Error,
};
self.md_q_window = q_window;
// hand over control to data
self.resume_data(scanner)
}
fn resume_data<'a>(mut self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> {
let df_size = self.target - scanner.cursor();
if scanner.remaining() == df_size {
unsafe {
QExchangeResult::SQCompleted(SQuery::new(
scanner.next_chunk_variable(df_size),
self.md_q_window as usize,
))
}
}
pub fn next_query(&mut self) -> Result<Option<SQuery<'a>>, ()> {
let nonzero = self.scanner.buffer_len() != 0;
if self.scanner.eof() & nonzero {
Ok(None)
} else {
self.state = QExchangeStateInternal::PendingData;
QExchangeResult::ChangeState(self)
let query_size = scan_usize_guaranteed_termination(&mut self.scanner)?;
let param_size = scan_usize_guaranteed_termination(&mut self.scanner)?;
let (full_size, overflow) = param_size.overflowing_add(query_size);
if compiler::likely(self.scanner.remaining_size_is(full_size) & !overflow) {
Ok(Some(SQuery {
buf: self.scanner.current_buffer(),
q_window: query_size,
}))
} else {
Err(())
}
}
}
}

@ -44,14 +44,9 @@ mod handshake;
#[cfg(test)]
mod tests;
// re-export
pub use exchange::SQuery;
use crate::engine::core::system_db::VerifyUser;
use {
self::{
exchange::{QExchangeResult, QExchangeState},
exchange::{Exchange, ExchangeResult, ExchangeState, Pipeline},
handshake::{
AuthMode, CHandshake, DataExchangeMode, HandshakeResult, HandshakeState,
HandshakeVersion, ProtocolError, ProtocolVersion, QueryMode,
@ -60,7 +55,8 @@ use {
super::{IoResult, QueryLoopResult, Socket},
crate::engine::{
self,
error::QueryError,
core::system_db::VerifyUser,
error::{QueryError, QueryResult},
fractal::{Global, GlobalInstanceLike},
mem::{BufferedScanner, IntegerRepr},
},
@ -68,31 +64,12 @@ use {
tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter},
};
#[repr(u8)]
#[derive(sky_macros::EnumMethods, Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[allow(unused)]
pub enum ResponseType {
Null = 0x00,
Bool = 0x01,
UInt8 = 0x02,
UInt16 = 0x03,
UInt32 = 0x04,
UInt64 = 0x05,
SInt8 = 0x06,
SInt16 = 0x07,
SInt32 = 0x08,
SInt64 = 0x09,
Float32 = 0x0A,
Float64 = 0x0B,
Binary = 0x0C,
String = 0x0D,
List = 0x0E,
Dict = 0x0F,
Error = 0x10,
Row = 0x11,
Empty = 0x12,
MultiRow = 0x13,
}
// re-export
pub use self::exchange::SQuery;
/*
connection state
*/
#[derive(Debug, PartialEq)]
pub struct ClientLocalState {
@ -128,106 +105,9 @@ impl ClientLocalState {
}
}
#[derive(Debug, PartialEq)]
pub enum Response {
Empty,
Null,
Serialized {
ty: ResponseType,
size: usize,
data: Vec<u8>,
},
Bool(bool),
}
pub(super) async fn query_loop<S: Socket>(
con: &mut BufWriter<S>,
buf: &mut BytesMut,
global: &Global,
) -> IoResult<QueryLoopResult> {
// handshake
let mut client_state = match do_handshake(con, buf, global).await? {
PostHandshake::Okay(hs) => hs,
PostHandshake::ConnectionClosedFin => return Ok(QueryLoopResult::Fin),
PostHandshake::ConnectionClosedRst => return Ok(QueryLoopResult::Rst),
PostHandshake::Error(e) => {
// failed to handshake; we'll close the connection
let hs_err_packet = [b'H', 0, 1, e.value_u8()];
con.write_all(&hs_err_packet).await?;
return Ok(QueryLoopResult::HSFailed);
}
};
// done handshaking
con.write_all(b"H\x00\x00\x00").await?;
con.flush().await?;
let mut state = QExchangeState::default();
let mut cursor = Default::default();
loop {
if con.read_buf(buf).await? == 0 {
if buf.is_empty() {
return Ok(QueryLoopResult::Fin);
} else {
return Ok(QueryLoopResult::Rst);
}
}
if !state.has_reached_target(buf) {
// we haven't buffered sufficient bytes; keep working
continue;
}
let sq = match unsafe {
// UNSAFE(@ohsayan): as the resume cursor is private, we can't access this anyways
exchange::resume(buf, cursor, state)
} {
(_, QExchangeResult::SQCompleted(sq)) => sq,
(new_cursor, QExchangeResult::ChangeState(new_state)) => {
cursor = new_cursor;
state = new_state;
continue;
}
(_, QExchangeResult::Error) => {
// respond with error
let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16)
.to_le_bytes();
con.write_all(&[ResponseType::Error.value_u8(), a, b])
.await?;
con.flush().await?;
// reset buffer, cursor and state
buf.clear();
cursor = Default::default();
state = QExchangeState::default();
continue;
}
};
// now execute query
match engine::core::exec::dispatch_to_executor(global, &mut client_state, sq).await {
Ok(Response::Empty) => {
con.write_all(&[ResponseType::Empty.value_u8()]).await?;
}
Ok(Response::Serialized { ty, size, data }) => {
con.write_u8(ty.value_u8()).await?;
let mut irep = IntegerRepr::new();
con.write_all(irep.as_bytes(size as u64)).await?;
con.write_u8(b'\n').await?;
con.write_all(&data).await?;
}
Ok(Response::Bool(b)) => {
con.write_all(&[ResponseType::Bool.value_u8(), b as u8])
.await?
}
Ok(Response::Null) => con.write_u8(ResponseType::Null.value_u8()).await?,
Err(e) => {
let [a, b] = (e.value_u8() as u16).to_le_bytes();
con.write_all(&[ResponseType::Error.value_u8(), a, b])
.await?;
}
}
con.flush().await?;
// reset buffer, cursor and state
buf.clear();
cursor = Default::default();
state = QExchangeState::default();
}
}
/*
handshake
*/
#[derive(Debug, PartialEq)]
enum PostHandshake {
@ -315,41 +195,194 @@ async fn do_handshake<S: Socket>(
Ok(PostHandshake::Error(ProtocolError::RejectAuth))
}
#[derive(Debug, PartialEq, Clone, Copy)]
enum AccumlatorStatus {
Pending(u64),
Completed(u64),
/*
exec event loop
*/
async fn cleanup_for_next_query<S: Socket>(
con: &mut BufWriter<S>,
buf: &mut BytesMut,
) -> IoResult<(ExchangeState, usize)> {
con.flush().await?; // flush write buffer
buf.clear(); // clear read buffer
Ok((ExchangeState::default(), 0))
}
/// Scan an integer
///
/// Allowed sequences:
/// - < int >\n
/// - \n ; FIXME(@ohsayan): a LF only sequence is allowed. should it be removed?
fn scan_int(s: &mut BufferedScanner, acc: u64) -> Result<AccumlatorStatus, ()> {
let mut acc = acc;
let mut okay = true;
let mut end = s.rounded_eq(b'\n');
while okay & !end & !s.eof() {
let d = unsafe { s.next_byte() };
okay &= d.is_ascii_digit();
match acc.checked_mul(10).map(|v| v.checked_add((d & 0x0f) as _)) {
Some(Some(v)) => acc = v,
_ => okay = false,
pub(super) async fn query_loop<S: Socket>(
con: &mut BufWriter<S>,
buf: &mut BytesMut,
global: &Global,
) -> IoResult<QueryLoopResult> {
// handshake
let mut client_state = match do_handshake(con, buf, global).await? {
PostHandshake::Okay(hs) => hs,
PostHandshake::ConnectionClosedFin => return Ok(QueryLoopResult::Fin),
PostHandshake::ConnectionClosedRst => return Ok(QueryLoopResult::Rst),
PostHandshake::Error(e) => {
// failed to handshake; we'll close the connection
let hs_err_packet = [b'H', 0, 1, e.value_u8()];
con.write_all(&hs_err_packet).await?;
return Ok(QueryLoopResult::HSFailed);
}
};
// done handshaking
con.write_all(b"H\x00\x00\x00").await?;
con.flush().await?;
let mut state = ExchangeState::default();
let mut cursor = 0;
loop {
if con.read_buf(buf).await? == 0 {
if buf.is_empty() {
return Ok(QueryLoopResult::Fin);
} else {
return Ok(QueryLoopResult::Rst);
}
}
match Exchange::try_complete(
unsafe {
// UNSAFE(@ohsayan): the cursor is either 0 or returned by the exchange impl
BufferedScanner::new_with_cursor(&buf, cursor)
},
state,
) {
Ok((result, new_cursor)) => match result {
ExchangeResult::NewState(new_state) => {
state = new_state;
cursor = new_cursor;
}
ExchangeResult::Simple(query) => {
exec_simple(con, &mut client_state, global, query).await?;
(state, cursor) = cleanup_for_next_query(con, buf).await?;
}
ExchangeResult::Pipeline(pipe) => {
exec_pipe(con, &mut client_state, global, pipe).await?;
(state, cursor) = cleanup_for_next_query(con, buf).await?;
}
},
Err(()) => {
// respond with error
let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16)
.to_le_bytes();
con.write_all(&[ResponseType::Error.value_u8(), a, b])
.await?;
(state, cursor) = cleanup_for_next_query(con, buf).await?;
}
}
end = s.rounded_eq(b'\n');
}
unsafe {
// UNSAFE(@ohsayan): if we hit an LF, then we still have space until EOA
s.incr_cursor_if(end)
}
/*
responses
*/
#[repr(u8)]
#[derive(sky_macros::EnumMethods, Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)]
#[allow(unused)]
pub enum ResponseType {
Null = 0x00,
Bool = 0x01,
UInt8 = 0x02,
UInt16 = 0x03,
UInt32 = 0x04,
UInt64 = 0x05,
SInt8 = 0x06,
SInt16 = 0x07,
SInt32 = 0x08,
SInt64 = 0x09,
Float32 = 0x0A,
Float64 = 0x0B,
Binary = 0x0C,
String = 0x0D,
List = 0x0E,
Dict = 0x0F,
Error = 0x10,
Row = 0x11,
Empty = 0x12,
MultiRow = 0x13,
}
#[derive(Debug, PartialEq)]
pub enum Response {
Empty,
Null,
Serialized {
ty: ResponseType,
size: usize,
data: Vec<u8>,
},
Bool(bool),
}
async fn write_response<S: Socket>(
resp: QueryResult<Response>,
con: &mut BufWriter<S>,
) -> IoResult<()> {
match resp {
Ok(Response::Empty) => con.write_all(&[ResponseType::Empty.value_u8()]).await,
Ok(Response::Serialized { ty, size, data }) => {
con.write_u8(ty.value_u8()).await?;
let mut irep = IntegerRepr::new();
con.write_all(irep.as_bytes(size as u64)).await?;
con.write_u8(b'\n').await?;
con.write_all(&data).await
}
Ok(Response::Bool(b)) => {
con.write_all(&[ResponseType::Bool.value_u8(), b as u8])
.await
}
Ok(Response::Null) => con.write_u8(ResponseType::Null.value_u8()).await,
Err(e) => {
let [a, b] = (e.value_u8() as u16).to_le_bytes();
con.write_all(&[ResponseType::Error.value_u8(), a, b]).await
}
}
if okay & end {
Ok(AccumlatorStatus::Completed(acc))
} else {
if okay {
Ok(AccumlatorStatus::Pending(acc))
} else {
Err(())
}
/*
simple query
*/
async fn exec_simple<S: Socket>(
con: &mut BufWriter<S>,
cs: &mut ClientLocalState,
global: &Global,
query: SQuery<'_>,
) -> IoResult<()> {
write_response(
engine::core::exec::dispatch_to_executor(global, cs, query).await,
con,
)
.await
}
/*
pipeline
*/
async fn exec_pipe<'a, S: Socket>(
con: &mut BufWriter<S>,
cs: &mut ClientLocalState,
global: &Global,
mut pipe: Pipeline<'a>,
) -> IoResult<()> {
loop {
match pipe.next_query() {
Ok(None) => break Ok(()),
Ok(Some(q)) => {
write_response(
engine::core::exec::dispatch_to_executor(global, cs, q).await,
con,
)
.await?
}
Err(()) => {
// respond with error
let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16)
.to_le_bytes();
con.write_all(&[ResponseType::Error.value_u8(), a, b])
.await?;
break Ok(());
}
}
}
}

@ -24,28 +24,15 @@
*
*/
use crate::engine::net::protocol::exchange::Resume;
use {
super::{
exchange::{self, QExchangeResult, QExchangeState},
handshake::ProtocolError,
SQuery,
},
crate::{
engine::{
mem::BufferedScanner,
net::protocol::{
handshake::{
AuthMode, CHandshake, CHandshakeAuth, CHandshakeStatic, DataExchangeMode,
HandshakeResult, HandshakeState, HandshakeVersion, ProtocolVersion, QueryMode,
},
scan_int, AccumlatorStatus,
},
super::handshake::ProtocolError,
crate::engine::{
mem::BufferedScanner,
net::protocol::handshake::{
AuthMode, CHandshake, CHandshakeAuth, CHandshakeStatic, DataExchangeMode,
HandshakeResult, HandshakeState, HandshakeVersion, ProtocolVersion, QueryMode,
},
util::test_utils,
},
rand::Rng,
};
pub(super) fn create_simple_query<const N: usize>(query: &str, params: [&str; N]) -> Vec<u8> {
@ -283,212 +270,3 @@ fn hs_bad_auth_mode() {
assert_eq!(hs_result, HandshakeResult::Error(ProtocolError::RejectAuth))
})
}
/*
QT-DEX/SQ
*/
const SQ: &str = "select * from myspace.mymodel where username = ?";
fn parse_staged<const N: usize>(
query: &str,
params: [&str; N],
eq: impl Fn(SQuery),
rng: &mut impl Rng,
) {
let __query_buffer = create_simple_query(query, params);
for _ in 0..__query_buffer.len() {
let mut __read_total = 0;
let mut cursor = Default::default();
let mut state = QExchangeState::default();
loop {
let remaining = __query_buffer.len() - __read_total;
let read_this_time = {
let mut cnt = 0;
if remaining == 1 {
1
} else {
let mut last = test_utils::random_number(1, remaining, rng);
loop {
if cnt >= 10 {
break last;
}
// if we're reading exact, try to keep it low
if last == remaining {
cnt += 1;
last = test_utils::random_number(1, remaining, rng);
} else {
break last;
}
}
}
};
__read_total += read_this_time;
let buffered = &__query_buffer[..__read_total];
if !state.has_reached_target(buffered) {
continue;
}
match unsafe { exchange::resume(buffered, cursor, state) } {
(new_cursor, QExchangeResult::ChangeState(new_state)) => {
cursor = new_cursor;
state = new_state;
continue;
}
(_, QExchangeResult::SQCompleted(q)) => {
eq(q);
break;
}
_ => panic!(),
}
}
}
}
#[test]
fn staged_randomized() {
let mut rng = test_utils::rng();
parse_staged(
SQ,
["sayan"],
|q| {
assert_eq!(q.query_str(), SQ);
assert_eq!(q.params_str(), "sayan");
},
&mut rng,
);
}
#[test]
fn stages_manual() {
let query = create_simple_query("select * from mymodel where username = ?", ["sayan"]);
assert_eq!(
unsafe {
exchange::resume(
&query[..QExchangeState::MIN_READ],
Default::default(),
Default::default(),
)
},
(
Resume::test_new(5),
QExchangeResult::ChangeState(QExchangeState::new_test(
exchange::QExchangeStateInternal::PendingMeta2,
52,
48,
4
))
)
);
assert_eq!(
unsafe {
exchange::resume(
&query[..QExchangeState::MIN_READ + 1],
Default::default(),
Default::default(),
)
},
(
Resume::test_new(6),
QExchangeResult::ChangeState(QExchangeState::new_test(
exchange::QExchangeStateInternal::PendingMeta2,
52,
48,
40
))
)
);
assert_eq!(
unsafe {
exchange::resume(
&query[..QExchangeState::MIN_READ + 2],
Default::default(),
Default::default(),
)
},
(
Resume::test_new(7),
QExchangeResult::ChangeState(QExchangeState::new_test(
exchange::QExchangeStateInternal::PendingData,
52,
48,
40
))
)
);
// the cursor should never change
for upper_bound in QExchangeState::MIN_READ + 2..query.len() {
assert_eq!(
unsafe {
exchange::resume(
&query[..upper_bound],
Default::default(),
Default::default(),
)
},
(
Resume::test_new(7),
QExchangeResult::ChangeState(QExchangeState::new_test(
exchange::QExchangeStateInternal::PendingData,
52,
48,
40
))
)
);
}
match unsafe { exchange::resume(&query, Default::default(), Default::default()) } {
(l, QExchangeResult::SQCompleted(q)) if l.inner() == query.len() => {
assert_eq!(q.query_str(), "select * from mymodel where username = ?");
assert_eq!(q.params_str(), "sayan");
}
e => panic!("expected end, got {e:?}"),
}
}
#[test]
fn exchange_bad_segment_metadata() {
let exchange_packets = [
(
b"S4A\n32\ndelete from dbs where dbname = ?sillydb".as_slice(),
4,
"incorrect packet size",
),
(
b"S42\n3A\ndelete from dbs where dbname = ?sillydb",
7,
"incorrect q window",
),
(
b"S4A\n3A\ndelete from dbs where dbname = ?sillydb",
4,
"incorrect packet size and q window",
),
];
for (packet, cursor, description) in exchange_packets {
assert_eq!(
unsafe { exchange::resume(packet, Resume::test_new(0), Default::default()) },
(Resume::test_new(cursor), QExchangeResult::Error),
"failed for `{description}`"
)
}
}
#[test]
fn num_accumulate() {
let x = [
(b"1".as_slice(), Ok(AccumlatorStatus::Pending(1)), 1usize),
(b"12", Ok(AccumlatorStatus::Pending(12)), 2),
(b"123", Ok(AccumlatorStatus::Pending(123)), 3),
(b"123\n", Ok(AccumlatorStatus::Completed(123)), 4),
(b"\n", Ok(AccumlatorStatus::Completed(0)), 1),
(b"A", Err(()), 1),
(b"A\n", Err(()), 2),
(b"1A", Err(()), 2),
(b"1A\n", Err(()), 3),
];
for (buf, res, cursor) in x {
let mut bs = BufferedScanner::new(buf);
assert_eq!(scan_int(&mut bs, 0), res);
assert_eq!(bs.cursor(), cursor);
}
}

Loading…
Cancel
Save