From d3b5fd8060ff4d3bdf366b8d25eaba7b8d1a9573 Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Wed, 3 Apr 2024 17:24:20 +0530 Subject: [PATCH] net: Rewrite protocol impl to support pipelines --- server/src/engine/fractal/mgr.rs | 4 +- server/src/engine/mem/scanner.rs | 5 +- server/src/engine/net/protocol/exchange.rs | 417 +++++++++++++-------- server/src/engine/net/protocol/mod.rs | 357 ++++++++++-------- server/src/engine/net/protocol/tests.rs | 234 +----------- 5 files changed, 464 insertions(+), 553 deletions(-) diff --git a/server/src/engine/fractal/mgr.rs b/server/src/engine/fractal/mgr.rs index 09924415..0b8f7190 100644 --- a/server/src/engine/fractal/mgr.rs +++ b/server/src/engine/fractal/mgr.rs @@ -430,13 +430,13 @@ impl FractalMgr { _ = sigterm.recv() => { info!("flp: finishing any pending maintenance tasks"); let global = global.clone(); - tokio::task::spawn_blocking(|| self.general_executor(global)).await.unwrap(); + let _ = tokio::task::spawn_blocking(|| self.general_executor(global)).await; info!("flp: exited executor service"); break; }, _ = tokio::time::sleep(dur) => { let global = global.clone(); - tokio::task::spawn_blocking(|| self.general_executor(global)).await.unwrap() + let _ = tokio::task::spawn_blocking(|| self.general_executor(global)).await; } task = lpq.recv() => { let Task { threshold, task } = match task { diff --git a/server/src/engine/mem/scanner.rs b/server/src/engine/mem/scanner.rs index 2f2a2f03..c65b38fc 100644 --- a/server/src/engine/mem/scanner.rs +++ b/server/src/engine/mem/scanner.rs @@ -66,7 +66,7 @@ impl<'a, T> Scanner<'a, T> { self.__cursor } /// Returns the buffer from the current position - pub fn current_buffer(&self) -> &[T] { + pub fn current_buffer(&self) -> &'a [T] { &self.d[self.__cursor..] } /// Returns the ptr to the cursor @@ -86,6 +86,9 @@ impl<'a, T> Scanner<'a, T> { pub fn has_left(&self, sizeof: usize) -> bool { self.remaining() >= sizeof } + pub fn remaining_size_is(&self, size: usize) -> bool { + self.remaining() == size + } /// Returns true if the rounded cursor matches the predicate pub fn rounded_cursor_matches(&self, f: impl Fn(&T) -> bool) -> bool { f(&self.d[self.rounded_cursor()]) & !self.eof() diff --git a/server/src/engine/net/protocol/exchange.rs b/server/src/engine/net/protocol/exchange.rs index 7522b97e..419054ac 100644 --- a/server/src/engine/net/protocol/exchange.rs +++ b/server/src/engine/net/protocol/exchange.rs @@ -1,5 +1,5 @@ /* - * Created on Wed Sep 20 2023 + * Created on Tue Apr 02 2024 * * This file is a part of Skytable * Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source @@ -7,7 +7,7 @@ * vision to provide flexibility in data modelling without compromising * on performance, queryability or scalability. * - * Copyright (c) 2023, Sayan Nandan + * Copyright (c) 2024, Sayan Nandan * * This program is free software: you can redistribute it and/or modify * it under the terms of the GNU Affero General Public License as published by @@ -24,211 +24,308 @@ * */ -use {super::AccumlatorStatus, crate::engine::mem::BufferedScanner}; +use { + crate::{engine::mem::BufferedScanner, util::compiler}, + core::fmt, +}; -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)] -pub struct Resume(usize); -impl Resume { - #[cfg(test)] - pub(super) const fn test_new(v: usize) -> Self { - Self(v) - } - #[cfg(test)] - pub(super) const fn inner(&self) -> usize { - self.0 - } -} -impl Default for Resume { - fn default() -> Self { - Self(0) - } -} +/* + Skyhash/2.1 Implementation + --- + This is an implementation of Skyhash/2.1, Skytable's data exchange protocol. -pub(super) unsafe fn resume<'a>( - buf: &'a [u8], - last_cursor: Resume, - last_state: QExchangeState, -) -> (Resume, QExchangeResult<'a>) { - let mut scanner = unsafe { - // UNSAFE(@ohsayan): we are the ones who generate the cursor and restore it - BufferedScanner::new_with_cursor(buf, last_cursor.0) - }; - let ret = last_state.resume(&mut scanner); - (Resume(scanner.cursor()), ret) -} + 0. Notes + ++++++++++++++++++ + - 2.1 is fully backwards compatible with 2.0 clients. As such we don't even designate it as a separate version. + - The "LF exception" essentially allows `0\n` to be equal to `\n`. It's unimportant to enforce this + + 1.1 Query Types + ++++++++++++++++++ + The protocol makes two distinctions, at the protocol-level about the type of queries: + a. Simple query + b. Pipeline + + 1.1.1 Simple Query + ++++++++++++++++++ + A simple query +*/ /* - SQ + sq definition */ -#[derive(Debug, PartialEq)] +#[derive(Debug)] pub struct SQuery<'a> { - q: &'a [u8], + buf: &'a [u8], q_window: usize, } impl<'a> SQuery<'a> { - pub(super) fn new(q: &'a [u8], q_window: usize) -> Self { - Self { q, q_window } + fn new(buf: &'a [u8], q_window: usize) -> Self { + Self { buf, q_window } + } + pub fn query(&self) -> &[u8] { + &self.buf[..self.q_window] + } + pub fn params(&self) -> &[u8] { + &self.buf[self.q_window..] + } +} + +/* + scanint +*/ + +fn scan_usize_guaranteed_termination(scanner: &mut BufferedScanner) -> Result { + let mut ret = 0usize; + let mut stop = scanner.rounded_eq(b'\n'); + while !scanner.eof() & !stop { + let next_byte = unsafe { + // UNSAFE(@ohsayan): loop invariant + scanner.next_byte() + }; + match ret + .checked_mul(10) + .map(|int| int.checked_add((next_byte & 0x0f) as usize)) + { + Some(Some(int)) if next_byte.is_ascii_digit() => ret = int, + _ => return Err(()), + } + stop = scanner.rounded_eq(b'\n'); + } + unsafe { + // UNSAFE(@ohsayan): scanned stop, but not accounted for yet + scanner.incr_cursor_if(stop) + } + if stop { + Ok(ret) + } else { + Err(()) + } +} + +#[derive(Clone, Copy)] +struct Usize { + v: isize, +} + +impl Usize { + const SHIFT: u32 = isize::BITS - 1; + const MASK: isize = 1 << Self::SHIFT; + #[inline(always)] + const fn new(v: isize) -> Self { + Self { v } } - pub fn payload(&self) -> &'a [u8] { - self.q + #[inline(always)] + const fn new_unflagged(int: usize) -> Self { + Self::new(int as isize) } - pub fn q_window(&self) -> usize { - self.q_window + #[inline(always)] + fn int(&self) -> usize { + (self.v & !Self::MASK) as usize } - pub fn query(&self) -> &'a [u8] { - &self.payload()[..self.q_window()] + #[inline(always)] + fn update(&mut self, new: usize) { + self.v = (new as isize) | (self.v & Self::MASK); } - pub fn params(&self) -> &'a [u8] { - &self.payload()[self.q_window()..] + #[inline(always)] + fn flag(&self) -> bool { + (self.v & Self::MASK) != 0 } - #[cfg(test)] - pub fn query_str(&self) -> &str { - core::str::from_utf8(self.query()).unwrap() + #[inline(always)] + fn set_flag_if(&mut self, iff: bool) { + self.v |= (iff as isize) << Self::SHIFT; } - #[cfg(test)] - pub fn params_str(&self) -> &str { - core::str::from_utf8(self.params()).unwrap() +} + +impl fmt::Debug for Usize { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Usize") + .field("int", &self.int()) + .field("flag", &self.flag()) + .finish() + } +} + +impl Usize { + /// Attempt to "complete" a scan of the integer. Idempotency guarantee: it is guaranteed that calls would not change the state + /// of the [`Usize`] or the buffer if the final state has been reached + fn update_scanned(&mut self, scanner: &mut BufferedScanner) -> Result<(), ()> { + let mut stop = scanner.rounded_eq(b'\n'); + while !stop & !scanner.eof() & !self.flag() { + let byte = unsafe { + // UNSAFE(@ohsayan): verified by loop invariant + scanner.next_byte() + }; + match (self.int() as isize) // this cast allows us to guarantee that we don't trip the flag + .checked_mul(10) + .map(|int| int.checked_add((byte & 0x0f) as isize)) + { + Some(Some(int)) if byte.is_ascii_digit() => self.update(int as usize), + _ => return Err(()), + } + stop = scanner.rounded_eq(b'\n'); + } + unsafe { + // UNSAFE(@ohsayan): scanned stop byte but did not account for it; the flag check is for cases where the input buffer + // has something like [LF][LF] in which case we stopped at the first LF but we would accidentally read the second one + // on the second derogatory call + scanner.incr_cursor_if(stop & !self.flag()) + } + self.set_flag_if(stop | self.flag()); // if second call we must check the flag + Ok(()) } } /* - utils + states */ -#[derive(Debug, PartialEq, Clone, Copy)] -pub(super) enum QExchangeStateInternal { +#[derive(Debug)] +pub enum ExchangeState { Initial, - PendingMeta1, - PendingMeta2, - PendingData, + Simple(SQState), + Pipeline(PipeState), } -impl Default for QExchangeStateInternal { - fn default() -> Self { - Self::Initial +#[derive(Debug)] +pub struct SQState { + packet_s: Usize, +} + +impl SQState { + const fn new(packet_s: Usize) -> Self { + Self { packet_s } } } -#[derive(Debug, PartialEq)] -pub(super) struct QExchangeState { - state: QExchangeStateInternal, - target: usize, - md_packet_size: u64, - md_q_window: u64, +#[derive(Debug)] +pub struct PipeState { + packet_s: Usize, +} + +impl PipeState { + const fn new(packet_s: Usize) -> Self { + Self { packet_s } + } } -impl Default for QExchangeState { +impl Default for ExchangeState { fn default() -> Self { - Self::new() + Self::Initial } } -#[derive(Debug, PartialEq)] -/// Result after attempting to complete (or terminate) a query time exchange -pub(super) enum QExchangeResult<'a> { - /// We completed the exchange and yielded a [`SQuery`] - SQCompleted(SQuery<'a>), - /// We're changing states - ChangeState(QExchangeState), - /// We hit an error and need to terminate this exchange - Error, +pub enum ExchangeResult<'a> { + NewState(ExchangeState), + Simple(SQuery<'a>), + Pipeline(Pipeline<'a>), } -impl QExchangeState { - fn _new( - state: QExchangeStateInternal, - target: usize, - md_packet_size: u64, - md_q_window: u64, - ) -> Self { - Self { - state, - target, - md_packet_size, - md_q_window, - } +pub struct Exchange<'a> { + scanner: BufferedScanner<'a>, +} + +impl<'a> Exchange<'a> { + const MIN_Q_SIZE: usize = "P0\n".len(); + fn new(scanner: BufferedScanner<'a>) -> Self { + Self { scanner } } - #[cfg(test)] - pub(super) fn new_test( - state: QExchangeStateInternal, - target: usize, - md_packet_size: u64, - md_q_window: u64, - ) -> Self { - Self::_new(state, target, md_packet_size, md_q_window) + pub fn try_complete( + scanner: BufferedScanner<'a>, + state: ExchangeState, + ) -> Result<(ExchangeResult, usize), ()> { + Self::new(scanner).complete(state) } } -impl QExchangeState { - pub const MIN_READ: usize = b"S\x00\n\x00\n".len(); - pub fn new() -> Self { - Self::_new(QExchangeStateInternal::Initial, Self::MIN_READ, 0, 0) +impl<'a> Exchange<'a> { + fn complete(mut self, state: ExchangeState) -> Result<(ExchangeResult<'a>, usize), ()> { + match state { + ExchangeState::Initial => { + if compiler::likely(self.scanner.has_left(Self::MIN_Q_SIZE)) { + let first_byte = unsafe { + // UNSAFE(@ohsayan): already verified in above branch + self.scanner.next_byte() + }; + match first_byte { + b'S' => self.process_simple(SQState::new(Usize::new_unflagged(0))), + b'P' => self.process_pipe(PipeState::new(Usize::new_unflagged(0))), + _ => return Err(()), + } + } else { + Ok(ExchangeResult::NewState(state)) + } + } + ExchangeState::Simple(sq_s) => self.process_simple(sq_s), + ExchangeState::Pipeline(pipe_s) => self.process_pipe(pipe_s), + } + .map(|ret| (ret, self.scanner.cursor())) } - pub fn has_reached_target(&self, new_buffer: &[u8]) -> bool { - new_buffer.len() >= self.target + fn process_simple(&mut self, mut sq_state: SQState) -> Result, ()> { + // try to complete the packet size if needed + sq_state.packet_s.update_scanned(&mut self.scanner)?; + if sq_state.packet_s.flag() & self.scanner.remaining_size_is(sq_state.packet_s.int()) { + // we have the full packet size and the required data + let q_window = scan_usize_guaranteed_termination(&mut self.scanner)?; + let nonzero = (q_window != 0) & (sq_state.packet_s.int() != 0); + if compiler::likely(self.scanner.remaining_size_is(q_window) & nonzero) { + // this check is important because the client might have given us an incorrect q size + Ok(ExchangeResult::Simple(SQuery::new( + self.scanner.current_buffer(), + q_window, + ))) + } else { + Err(()) + } + } else { + Ok(ExchangeResult::NewState(ExchangeState::Simple(sq_state))) + } } - fn resume<'a>(self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> { - debug_assert!(scanner.has_left(Self::MIN_READ)); - match self.state { - QExchangeStateInternal::Initial => self.start_initial(scanner), - QExchangeStateInternal::PendingMeta1 => self.resume_at_md1(scanner), - QExchangeStateInternal::PendingMeta2 => self.resume_at_md2(scanner), - QExchangeStateInternal::PendingData => self.resume_data(scanner), + fn process_pipe(&mut self, mut pipe_s: PipeState) -> Result, ()> { + // try to complete the packet size if needed + pipe_s.packet_s.update_scanned(&mut self.scanner)?; + if pipe_s.packet_s.flag() & self.scanner.remaining_size_is(pipe_s.packet_s.int()) { + // great, we have the entire packet + Ok(ExchangeResult::Pipeline(Pipeline::new( + self.scanner.current_buffer(), + ))) + } else { + Ok(ExchangeResult::NewState(ExchangeState::Pipeline(pipe_s))) } } - fn start_initial<'a>(self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> { - if unsafe { scanner.next_byte() } != b'S' { - // has to be a simple query! - return QExchangeResult::Error; +} + +/* + pipeline +*/ + +pub struct Pipeline<'a> { + scanner: BufferedScanner<'a>, +} + +impl<'a> Pipeline<'a> { + fn new(buf: &'a [u8]) -> Self { + Self { + scanner: BufferedScanner::new(buf), } - self.resume_at_md1(scanner) - } - fn resume_at_md1<'a>(mut self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> { - let packet_size = match super::scan_int(scanner, self.md_packet_size) { - Ok(AccumlatorStatus::Completed(v)) => v, - Ok(AccumlatorStatus::Pending(p)) => { - // if this is the first run, we read 5 bytes and need atleast one more; if this is a resume we read one or more bytes and - // need atleast one more - self.target += 1; - self.md_packet_size = p; - self.state = QExchangeStateInternal::PendingMeta1; - return QExchangeResult::ChangeState(self); - } - Err(()) => return QExchangeResult::Error, - }; - self.md_packet_size = packet_size; - self.target = scanner.cursor() + packet_size as usize; - // hand over control to md2 - self.resume_at_md2(scanner) - } - fn resume_at_md2<'a>(mut self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> { - let q_window = match super::scan_int(scanner, self.md_q_window) { - Ok(AccumlatorStatus::Completed(v)) => v, - Ok(AccumlatorStatus::Pending(p)) => { - self.md_q_window = p; - self.state = QExchangeStateInternal::PendingMeta2; - return QExchangeResult::ChangeState(self); - } - Err(()) => return QExchangeResult::Error, - }; - self.md_q_window = q_window; - // hand over control to data - self.resume_data(scanner) - } - fn resume_data<'a>(mut self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> { - let df_size = self.target - scanner.cursor(); - if scanner.remaining() == df_size { - unsafe { - QExchangeResult::SQCompleted(SQuery::new( - scanner.next_chunk_variable(df_size), - self.md_q_window as usize, - )) - } + } + pub fn next_query(&mut self) -> Result>, ()> { + let nonzero = self.scanner.buffer_len() != 0; + if self.scanner.eof() & nonzero { + Ok(None) } else { - self.state = QExchangeStateInternal::PendingData; - QExchangeResult::ChangeState(self) + let query_size = scan_usize_guaranteed_termination(&mut self.scanner)?; + let param_size = scan_usize_guaranteed_termination(&mut self.scanner)?; + let (full_size, overflow) = param_size.overflowing_add(query_size); + if compiler::likely(self.scanner.remaining_size_is(full_size) & !overflow) { + Ok(Some(SQuery { + buf: self.scanner.current_buffer(), + q_window: query_size, + })) + } else { + Err(()) + } } } } diff --git a/server/src/engine/net/protocol/mod.rs b/server/src/engine/net/protocol/mod.rs index c1745eae..93a57764 100644 --- a/server/src/engine/net/protocol/mod.rs +++ b/server/src/engine/net/protocol/mod.rs @@ -44,14 +44,9 @@ mod handshake; #[cfg(test)] mod tests; -// re-export -pub use exchange::SQuery; - -use crate::engine::core::system_db::VerifyUser; - use { self::{ - exchange::{QExchangeResult, QExchangeState}, + exchange::{Exchange, ExchangeResult, ExchangeState, Pipeline}, handshake::{ AuthMode, CHandshake, DataExchangeMode, HandshakeResult, HandshakeState, HandshakeVersion, ProtocolError, ProtocolVersion, QueryMode, @@ -60,7 +55,8 @@ use { super::{IoResult, QueryLoopResult, Socket}, crate::engine::{ self, - error::QueryError, + core::system_db::VerifyUser, + error::{QueryError, QueryResult}, fractal::{Global, GlobalInstanceLike}, mem::{BufferedScanner, IntegerRepr}, }, @@ -68,31 +64,12 @@ use { tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter}, }; -#[repr(u8)] -#[derive(sky_macros::EnumMethods, Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] -#[allow(unused)] -pub enum ResponseType { - Null = 0x00, - Bool = 0x01, - UInt8 = 0x02, - UInt16 = 0x03, - UInt32 = 0x04, - UInt64 = 0x05, - SInt8 = 0x06, - SInt16 = 0x07, - SInt32 = 0x08, - SInt64 = 0x09, - Float32 = 0x0A, - Float64 = 0x0B, - Binary = 0x0C, - String = 0x0D, - List = 0x0E, - Dict = 0x0F, - Error = 0x10, - Row = 0x11, - Empty = 0x12, - MultiRow = 0x13, -} +// re-export +pub use self::exchange::SQuery; + +/* + connection state +*/ #[derive(Debug, PartialEq)] pub struct ClientLocalState { @@ -128,106 +105,9 @@ impl ClientLocalState { } } -#[derive(Debug, PartialEq)] -pub enum Response { - Empty, - Null, - Serialized { - ty: ResponseType, - size: usize, - data: Vec, - }, - Bool(bool), -} - -pub(super) async fn query_loop( - con: &mut BufWriter, - buf: &mut BytesMut, - global: &Global, -) -> IoResult { - // handshake - let mut client_state = match do_handshake(con, buf, global).await? { - PostHandshake::Okay(hs) => hs, - PostHandshake::ConnectionClosedFin => return Ok(QueryLoopResult::Fin), - PostHandshake::ConnectionClosedRst => return Ok(QueryLoopResult::Rst), - PostHandshake::Error(e) => { - // failed to handshake; we'll close the connection - let hs_err_packet = [b'H', 0, 1, e.value_u8()]; - con.write_all(&hs_err_packet).await?; - return Ok(QueryLoopResult::HSFailed); - } - }; - // done handshaking - con.write_all(b"H\x00\x00\x00").await?; - con.flush().await?; - let mut state = QExchangeState::default(); - let mut cursor = Default::default(); - loop { - if con.read_buf(buf).await? == 0 { - if buf.is_empty() { - return Ok(QueryLoopResult::Fin); - } else { - return Ok(QueryLoopResult::Rst); - } - } - if !state.has_reached_target(buf) { - // we haven't buffered sufficient bytes; keep working - continue; - } - let sq = match unsafe { - // UNSAFE(@ohsayan): as the resume cursor is private, we can't access this anyways - exchange::resume(buf, cursor, state) - } { - (_, QExchangeResult::SQCompleted(sq)) => sq, - (new_cursor, QExchangeResult::ChangeState(new_state)) => { - cursor = new_cursor; - state = new_state; - continue; - } - (_, QExchangeResult::Error) => { - // respond with error - let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16) - .to_le_bytes(); - con.write_all(&[ResponseType::Error.value_u8(), a, b]) - .await?; - con.flush().await?; - // reset buffer, cursor and state - buf.clear(); - cursor = Default::default(); - state = QExchangeState::default(); - continue; - } - }; - // now execute query - match engine::core::exec::dispatch_to_executor(global, &mut client_state, sq).await { - Ok(Response::Empty) => { - con.write_all(&[ResponseType::Empty.value_u8()]).await?; - } - Ok(Response::Serialized { ty, size, data }) => { - con.write_u8(ty.value_u8()).await?; - let mut irep = IntegerRepr::new(); - con.write_all(irep.as_bytes(size as u64)).await?; - con.write_u8(b'\n').await?; - con.write_all(&data).await?; - } - Ok(Response::Bool(b)) => { - con.write_all(&[ResponseType::Bool.value_u8(), b as u8]) - .await? - } - Ok(Response::Null) => con.write_u8(ResponseType::Null.value_u8()).await?, - Err(e) => { - let [a, b] = (e.value_u8() as u16).to_le_bytes(); - con.write_all(&[ResponseType::Error.value_u8(), a, b]) - .await?; - } - } - con.flush().await?; - // reset buffer, cursor and state - buf.clear(); - cursor = Default::default(); - state = QExchangeState::default(); - } -} +/* + handshake +*/ #[derive(Debug, PartialEq)] enum PostHandshake { @@ -315,41 +195,194 @@ async fn do_handshake( Ok(PostHandshake::Error(ProtocolError::RejectAuth)) } -#[derive(Debug, PartialEq, Clone, Copy)] -enum AccumlatorStatus { - Pending(u64), - Completed(u64), +/* + exec event loop +*/ + +async fn cleanup_for_next_query( + con: &mut BufWriter, + buf: &mut BytesMut, +) -> IoResult<(ExchangeState, usize)> { + con.flush().await?; // flush write buffer + buf.clear(); // clear read buffer + Ok((ExchangeState::default(), 0)) } -/// Scan an integer -/// -/// Allowed sequences: -/// - < int >\n -/// - \n ; FIXME(@ohsayan): a LF only sequence is allowed. should it be removed? -fn scan_int(s: &mut BufferedScanner, acc: u64) -> Result { - let mut acc = acc; - let mut okay = true; - let mut end = s.rounded_eq(b'\n'); - while okay & !end & !s.eof() { - let d = unsafe { s.next_byte() }; - okay &= d.is_ascii_digit(); - match acc.checked_mul(10).map(|v| v.checked_add((d & 0x0f) as _)) { - Some(Some(v)) => acc = v, - _ => okay = false, +pub(super) async fn query_loop( + con: &mut BufWriter, + buf: &mut BytesMut, + global: &Global, +) -> IoResult { + // handshake + let mut client_state = match do_handshake(con, buf, global).await? { + PostHandshake::Okay(hs) => hs, + PostHandshake::ConnectionClosedFin => return Ok(QueryLoopResult::Fin), + PostHandshake::ConnectionClosedRst => return Ok(QueryLoopResult::Rst), + PostHandshake::Error(e) => { + // failed to handshake; we'll close the connection + let hs_err_packet = [b'H', 0, 1, e.value_u8()]; + con.write_all(&hs_err_packet).await?; + return Ok(QueryLoopResult::HSFailed); + } + }; + // done handshaking + con.write_all(b"H\x00\x00\x00").await?; + con.flush().await?; + let mut state = ExchangeState::default(); + let mut cursor = 0; + loop { + if con.read_buf(buf).await? == 0 { + if buf.is_empty() { + return Ok(QueryLoopResult::Fin); + } else { + return Ok(QueryLoopResult::Rst); + } + } + match Exchange::try_complete( + unsafe { + // UNSAFE(@ohsayan): the cursor is either 0 or returned by the exchange impl + BufferedScanner::new_with_cursor(&buf, cursor) + }, + state, + ) { + Ok((result, new_cursor)) => match result { + ExchangeResult::NewState(new_state) => { + state = new_state; + cursor = new_cursor; + } + ExchangeResult::Simple(query) => { + exec_simple(con, &mut client_state, global, query).await?; + (state, cursor) = cleanup_for_next_query(con, buf).await?; + } + ExchangeResult::Pipeline(pipe) => { + exec_pipe(con, &mut client_state, global, pipe).await?; + (state, cursor) = cleanup_for_next_query(con, buf).await?; + } + }, + Err(()) => { + // respond with error + let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16) + .to_le_bytes(); + con.write_all(&[ResponseType::Error.value_u8(), a, b]) + .await?; + (state, cursor) = cleanup_for_next_query(con, buf).await?; + } } - end = s.rounded_eq(b'\n'); } - unsafe { - // UNSAFE(@ohsayan): if we hit an LF, then we still have space until EOA - s.incr_cursor_if(end) +} + +/* + responses +*/ + +#[repr(u8)] +#[derive(sky_macros::EnumMethods, Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] +#[allow(unused)] +pub enum ResponseType { + Null = 0x00, + Bool = 0x01, + UInt8 = 0x02, + UInt16 = 0x03, + UInt32 = 0x04, + UInt64 = 0x05, + SInt8 = 0x06, + SInt16 = 0x07, + SInt32 = 0x08, + SInt64 = 0x09, + Float32 = 0x0A, + Float64 = 0x0B, + Binary = 0x0C, + String = 0x0D, + List = 0x0E, + Dict = 0x0F, + Error = 0x10, + Row = 0x11, + Empty = 0x12, + MultiRow = 0x13, +} + +#[derive(Debug, PartialEq)] +pub enum Response { + Empty, + Null, + Serialized { + ty: ResponseType, + size: usize, + data: Vec, + }, + Bool(bool), +} + +async fn write_response( + resp: QueryResult, + con: &mut BufWriter, +) -> IoResult<()> { + match resp { + Ok(Response::Empty) => con.write_all(&[ResponseType::Empty.value_u8()]).await, + Ok(Response::Serialized { ty, size, data }) => { + con.write_u8(ty.value_u8()).await?; + let mut irep = IntegerRepr::new(); + con.write_all(irep.as_bytes(size as u64)).await?; + con.write_u8(b'\n').await?; + con.write_all(&data).await + } + Ok(Response::Bool(b)) => { + con.write_all(&[ResponseType::Bool.value_u8(), b as u8]) + .await + } + Ok(Response::Null) => con.write_u8(ResponseType::Null.value_u8()).await, + Err(e) => { + let [a, b] = (e.value_u8() as u16).to_le_bytes(); + con.write_all(&[ResponseType::Error.value_u8(), a, b]).await + } } - if okay & end { - Ok(AccumlatorStatus::Completed(acc)) - } else { - if okay { - Ok(AccumlatorStatus::Pending(acc)) - } else { - Err(()) +} + +/* + simple query +*/ + +async fn exec_simple( + con: &mut BufWriter, + cs: &mut ClientLocalState, + global: &Global, + query: SQuery<'_>, +) -> IoResult<()> { + write_response( + engine::core::exec::dispatch_to_executor(global, cs, query).await, + con, + ) + .await +} + +/* + pipeline +*/ + +async fn exec_pipe<'a, S: Socket>( + con: &mut BufWriter, + cs: &mut ClientLocalState, + global: &Global, + mut pipe: Pipeline<'a>, +) -> IoResult<()> { + loop { + match pipe.next_query() { + Ok(None) => break Ok(()), + Ok(Some(q)) => { + write_response( + engine::core::exec::dispatch_to_executor(global, cs, q).await, + con, + ) + .await? + } + Err(()) => { + // respond with error + let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16) + .to_le_bytes(); + con.write_all(&[ResponseType::Error.value_u8(), a, b]) + .await?; + break Ok(()); + } } } } diff --git a/server/src/engine/net/protocol/tests.rs b/server/src/engine/net/protocol/tests.rs index 811a0569..85355527 100644 --- a/server/src/engine/net/protocol/tests.rs +++ b/server/src/engine/net/protocol/tests.rs @@ -24,28 +24,15 @@ * */ -use crate::engine::net::protocol::exchange::Resume; - use { - super::{ - exchange::{self, QExchangeResult, QExchangeState}, - handshake::ProtocolError, - SQuery, - }, - crate::{ - engine::{ - mem::BufferedScanner, - net::protocol::{ - handshake::{ - AuthMode, CHandshake, CHandshakeAuth, CHandshakeStatic, DataExchangeMode, - HandshakeResult, HandshakeState, HandshakeVersion, ProtocolVersion, QueryMode, - }, - scan_int, AccumlatorStatus, - }, + super::handshake::ProtocolError, + crate::engine::{ + mem::BufferedScanner, + net::protocol::handshake::{ + AuthMode, CHandshake, CHandshakeAuth, CHandshakeStatic, DataExchangeMode, + HandshakeResult, HandshakeState, HandshakeVersion, ProtocolVersion, QueryMode, }, - util::test_utils, }, - rand::Rng, }; pub(super) fn create_simple_query(query: &str, params: [&str; N]) -> Vec { @@ -283,212 +270,3 @@ fn hs_bad_auth_mode() { assert_eq!(hs_result, HandshakeResult::Error(ProtocolError::RejectAuth)) }) } - -/* - QT-DEX/SQ -*/ - -const SQ: &str = "select * from myspace.mymodel where username = ?"; - -fn parse_staged( - query: &str, - params: [&str; N], - eq: impl Fn(SQuery), - rng: &mut impl Rng, -) { - let __query_buffer = create_simple_query(query, params); - for _ in 0..__query_buffer.len() { - let mut __read_total = 0; - let mut cursor = Default::default(); - let mut state = QExchangeState::default(); - loop { - let remaining = __query_buffer.len() - __read_total; - let read_this_time = { - let mut cnt = 0; - if remaining == 1 { - 1 - } else { - let mut last = test_utils::random_number(1, remaining, rng); - loop { - if cnt >= 10 { - break last; - } - // if we're reading exact, try to keep it low - if last == remaining { - cnt += 1; - last = test_utils::random_number(1, remaining, rng); - } else { - break last; - } - } - } - }; - __read_total += read_this_time; - let buffered = &__query_buffer[..__read_total]; - if !state.has_reached_target(buffered) { - continue; - } - match unsafe { exchange::resume(buffered, cursor, state) } { - (new_cursor, QExchangeResult::ChangeState(new_state)) => { - cursor = new_cursor; - state = new_state; - continue; - } - (_, QExchangeResult::SQCompleted(q)) => { - eq(q); - break; - } - _ => panic!(), - } - } - } -} - -#[test] -fn staged_randomized() { - let mut rng = test_utils::rng(); - parse_staged( - SQ, - ["sayan"], - |q| { - assert_eq!(q.query_str(), SQ); - assert_eq!(q.params_str(), "sayan"); - }, - &mut rng, - ); -} - -#[test] -fn stages_manual() { - let query = create_simple_query("select * from mymodel where username = ?", ["sayan"]); - assert_eq!( - unsafe { - exchange::resume( - &query[..QExchangeState::MIN_READ], - Default::default(), - Default::default(), - ) - }, - ( - Resume::test_new(5), - QExchangeResult::ChangeState(QExchangeState::new_test( - exchange::QExchangeStateInternal::PendingMeta2, - 52, - 48, - 4 - )) - ) - ); - assert_eq!( - unsafe { - exchange::resume( - &query[..QExchangeState::MIN_READ + 1], - Default::default(), - Default::default(), - ) - }, - ( - Resume::test_new(6), - QExchangeResult::ChangeState(QExchangeState::new_test( - exchange::QExchangeStateInternal::PendingMeta2, - 52, - 48, - 40 - )) - ) - ); - assert_eq!( - unsafe { - exchange::resume( - &query[..QExchangeState::MIN_READ + 2], - Default::default(), - Default::default(), - ) - }, - ( - Resume::test_new(7), - QExchangeResult::ChangeState(QExchangeState::new_test( - exchange::QExchangeStateInternal::PendingData, - 52, - 48, - 40 - )) - ) - ); - // the cursor should never change - for upper_bound in QExchangeState::MIN_READ + 2..query.len() { - assert_eq!( - unsafe { - exchange::resume( - &query[..upper_bound], - Default::default(), - Default::default(), - ) - }, - ( - Resume::test_new(7), - QExchangeResult::ChangeState(QExchangeState::new_test( - exchange::QExchangeStateInternal::PendingData, - 52, - 48, - 40 - )) - ) - ); - } - match unsafe { exchange::resume(&query, Default::default(), Default::default()) } { - (l, QExchangeResult::SQCompleted(q)) if l.inner() == query.len() => { - assert_eq!(q.query_str(), "select * from mymodel where username = ?"); - assert_eq!(q.params_str(), "sayan"); - } - e => panic!("expected end, got {e:?}"), - } -} - -#[test] -fn exchange_bad_segment_metadata() { - let exchange_packets = [ - ( - b"S4A\n32\ndelete from dbs where dbname = ?sillydb".as_slice(), - 4, - "incorrect packet size", - ), - ( - b"S42\n3A\ndelete from dbs where dbname = ?sillydb", - 7, - "incorrect q window", - ), - ( - b"S4A\n3A\ndelete from dbs where dbname = ?sillydb", - 4, - "incorrect packet size and q window", - ), - ]; - for (packet, cursor, description) in exchange_packets { - assert_eq!( - unsafe { exchange::resume(packet, Resume::test_new(0), Default::default()) }, - (Resume::test_new(cursor), QExchangeResult::Error), - "failed for `{description}`" - ) - } -} - -#[test] -fn num_accumulate() { - let x = [ - (b"1".as_slice(), Ok(AccumlatorStatus::Pending(1)), 1usize), - (b"12", Ok(AccumlatorStatus::Pending(12)), 2), - (b"123", Ok(AccumlatorStatus::Pending(123)), 3), - (b"123\n", Ok(AccumlatorStatus::Completed(123)), 4), - (b"\n", Ok(AccumlatorStatus::Completed(0)), 1), - (b"A", Err(()), 1), - (b"A\n", Err(()), 2), - (b"1A", Err(()), 2), - (b"1A\n", Err(()), 3), - ]; - for (buf, res, cursor) in x { - let mut bs = BufferedScanner::new(buf); - assert_eq!(scan_int(&mut bs, 0), res); - assert_eq!(bs.cursor(), cursor); - } -}