Simplify protocol messages

It's easier to formally specify the protocol when we stop disallowing
a lone single-byte LF sequence, because this reduces the number of
states.
next
Sayan Nandan 9 months ago
parent bc7d02d8cf
commit c0c4ad4248
No known key found for this signature in database
GPG Key ID: 0EBD769024B24F0A

@ -88,7 +88,7 @@ impl<'a, T> Scanner<'a, T> {
}
/// Returns true if the rounded cursor matches the predicate
pub fn rounded_cursor_matches(&self, f: impl Fn(&T) -> bool) -> bool {
f(&self.d[self.rounded_cursor()])
f(&self.d[self.rounded_cursor()]) & !self.eof()
}
/// Same as `rounded_cursor_matches`, but with the added guarantee that no rounding was done
pub fn rounded_cursor_not_eof_matches(&self, f: impl Fn(&T) -> bool) -> bool {
@ -101,6 +101,12 @@ impl<'a, T> Scanner<'a, T> {
{
self.rounded_cursor_matches(|v| v_t.eq(v)) & !self.eof()
}
/// Returns true if the scanner is not at EOF and the element at the rounded cursor equals `v`
///
/// Note: the element at the rounded cursor is read unconditionally (the check is not
/// short-circuited), and only then is the EOF state taken into account.
pub fn rounded_eq(&self, v: T) -> bool
where
    T: PartialEq,
{
    let at_cursor = &self.d[self.rounded_cursor()];
    let matches = v.eq(at_cursor);
    matches & !self.eof()
}
}
impl<'a, T> Scanner<'a, T> {
@ -160,7 +166,7 @@ impl<'a, T> Scanner<'a, T> {
*self.cursor_ptr()
}
/// Returns the rounded cursor
pub fn rounded_cursor(&self) -> usize {
fn rounded_cursor(&self) -> usize {
(self.buffer_len() - 1).min(self.__cursor)
}
/// Returns the current cursor value with rounding

@ -24,7 +24,7 @@
*
*/
use crate::engine::mem::BufferedScanner;
use {super::AccumlatorStatus, crate::engine::mem::BufferedScanner};
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct Resume(usize);
@ -61,13 +61,6 @@ pub(super) unsafe fn resume<'a>(
SQ
*/
/// Outcome of scanning an LF-terminated integer with `scanint`
#[derive(Debug, PartialEq)]
pub(super) enum LFTIntParseResult {
/// The scan stayed valid and an LF was found at the cursor; carries the parsed value
Value(u64),
/// The scan stayed valid but no LF was found yet; carries the value accumulated so far
Partial(u64),
/// The scan was rejected (non-digit byte, `u64` overflow, or an LF at the cursor on a first run)
Error,
}
#[derive(Debug, PartialEq)]
pub struct SQuery<'a> {
q: &'a [u8],
@ -104,51 +97,6 @@ impl<'a> SQuery<'a> {
utils
*/
/// scan an integer:
/// - if just an LF:
/// - if disallowed single byte: return an error
/// - else, return value
/// - if no LF: return upto limit
/// - if LF: return value
pub(super) fn scanint(
scanner: &mut BufferedScanner,
first_run: bool,
prev: u64,
) -> LFTIntParseResult {
let mut current = prev;
// guard a case where the buffer might be empty and can potentially have invalid chars
let mut okay = !((scanner.rounded_cursor_value() == b'\n') & first_run);
while scanner.rounded_cursor_not_eof_matches(|b| b'\n'.ne(b)) & okay {
let byte = unsafe { scanner.next_byte() };
okay &= byte.is_ascii_digit();
match current
.checked_mul(10)
.map(|new| new.checked_add((byte & 0x0f) as u64))
{
Some(Some(int)) => {
current = int;
}
_ => {
okay = false;
}
}
}
let lf = scanner.rounded_cursor_not_eof_equals(b'\n');
unsafe {
// UNSAFE(@ohsayan): within buffer range
scanner.incr_cursor_if(lf);
}
if lf & okay {
LFTIntParseResult::Value(current)
} else {
if okay {
LFTIntParseResult::Partial(current)
} else {
LFTIntParseResult::Error
}
}
}
#[derive(Debug, PartialEq, Clone, Copy)]
pub(super) enum QExchangeStateInternal {
Initial,
@ -225,8 +173,8 @@ impl QExchangeState {
debug_assert!(scanner.has_left(Self::MIN_READ));
match self.state {
QExchangeStateInternal::Initial => self.start_initial(scanner),
QExchangeStateInternal::PendingMeta1 => self.resume_at_md1(scanner, false),
QExchangeStateInternal::PendingMeta2 => self.resume_at_md2(scanner, false),
QExchangeStateInternal::PendingMeta1 => self.resume_at_md1(scanner),
QExchangeStateInternal::PendingMeta2 => self.resume_at_md2(scanner),
QExchangeStateInternal::PendingData => self.resume_data(scanner),
}
}
@ -235,16 +183,12 @@ impl QExchangeState {
// has to be a simple query!
return QExchangeResult::Error;
}
self.resume_at_md1(scanner, true)
self.resume_at_md1(scanner)
}
fn resume_at_md1<'a>(
mut self,
scanner: &mut BufferedScanner<'a>,
first_run: bool,
) -> QExchangeResult<'a> {
let packet_size = match scanint(scanner, first_run, self.md_packet_size) {
LFTIntParseResult::Value(v) => v,
LFTIntParseResult::Partial(p) => {
fn resume_at_md1<'a>(mut self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> {
let packet_size = match super::scan_int(scanner, self.md_packet_size) {
Ok(AccumlatorStatus::Completed(v)) => v,
Ok(AccumlatorStatus::Pending(p)) => {
// if this is the first run, we read 5 bytes and need at least one more; if this is a resume, we read one or more bytes and
// need at least one more
self.target += 1;
@ -252,26 +196,22 @@ impl QExchangeState {
self.state = QExchangeStateInternal::PendingMeta1;
return QExchangeResult::ChangeState(self);
}
LFTIntParseResult::Error => return QExchangeResult::Error,
Err(()) => return QExchangeResult::Error,
};
self.md_packet_size = packet_size;
self.target = scanner.cursor() + packet_size as usize;
// hand over control to md2
self.resume_at_md2(scanner, true)
self.resume_at_md2(scanner)
}
fn resume_at_md2<'a>(
mut self,
scanner: &mut BufferedScanner<'a>,
first_run: bool,
) -> QExchangeResult<'a> {
let q_window = match scanint(scanner, first_run, self.md_q_window) {
LFTIntParseResult::Value(v) => v,
LFTIntParseResult::Partial(p) => {
fn resume_at_md2<'a>(mut self, scanner: &mut BufferedScanner<'a>) -> QExchangeResult<'a> {
let q_window = match super::scan_int(scanner, self.md_q_window) {
Ok(AccumlatorStatus::Completed(v)) => v,
Ok(AccumlatorStatus::Pending(p)) => {
self.md_q_window = p;
self.state = QExchangeStateInternal::PendingMeta2;
return QExchangeResult::ChangeState(self);
}
LFTIntParseResult::Error => return QExchangeResult::Error,
Err(()) => return QExchangeResult::Error,
};
self.md_q_window = q_window;
// hand over control to data

@ -24,6 +24,21 @@
*
*/
/*
* Implementation of the Skyhash/2.0 Protocol
* ---
* This module implements handshake and exchange mode extensions for the Skyhash protocol.
*
* Notable points:
* - [Deprecated] Newline exception: while all integers are to be encoded and postfixed with an LF, a single LF
* without any integer payload is equivalent to a zero value. We allow this because it makes the protocol easier
* to specify formally as states
* - Handshake parameter versions: We currently only evaluate values for the version "original" (shipped with
* Skytable 0.8.0)
* - FIXME(@ohsayan) Optimistic retry without timeout: Our current algorithm does not apply a timeout to receive data
* and optimistically retries infinitely until the target block size is received
*/
mod exchange;
mod handshake;
#[cfg(test)]
@ -294,3 +309,42 @@ async fn do_handshake<S: Socket>(
};
Ok(PostHandshake::Error(ProtocolError::RejectAuth))
}
/// Progress of an integer scan performed by `scan_int`
#[derive(Debug, PartialEq, Clone, Copy)]
enum AccumlatorStatus {
/// No terminating LF was found yet; carries the value accumulated so far so the scan can be resumed
Pending(u64),
/// A terminating LF was found; carries the final value
Completed(u64),
}
/// Scan an LF-terminated integer, resuming from the previously accumulated value `acc`
///
/// Allowed sequences:
/// - `<int>\n`
/// - `\n` ; FIXME(@ohsayan): a LF only sequence is allowed. should it be removed?
///
/// Returns:
/// - `Ok(AccumlatorStatus::Completed(v))`: only digits were read and an LF was found (and stepped over)
/// - `Ok(AccumlatorStatus::Pending(v))`: only digits were read but the input ended before an LF;
///   `v` is the value accumulated so far
/// - `Err(())`: a non-digit byte was read or the value overflowed a `u64`
fn scan_int(s: &mut BufferedScanner, acc: u64) -> Result<AccumlatorStatus, ()> {
    let mut value = acc;
    let mut valid = true;
    let mut at_lf = s.rounded_eq(b'\n');
    while valid & !at_lf & !s.eof() {
        // the loop guard checked `!s.eof()` right before this read
        let byte = unsafe { s.next_byte() };
        valid &= byte.is_ascii_digit();
        // value = value * 10 + digit; any overflow invalidates the scan
        let widened = value
            .checked_mul(10)
            .and_then(|scaled| scaled.checked_add((byte & 0x0f) as u64));
        match widened {
            Some(sum) => value = sum,
            None => valid = false,
        }
        at_lf = s.rounded_eq(b'\n');
    }
    unsafe {
        // UNSAFE(@ohsayan): if we hit an LF, then we still have space until EOA
        s.incr_cursor_if(at_lf);
    }
    match (valid, at_lf) {
        (true, true) => Ok(AccumlatorStatus::Completed(value)),
        (true, false) => Ok(AccumlatorStatus::Pending(value)),
        (false, _) => Err(()),
    }
}

@ -28,16 +28,19 @@ use crate::engine::net::protocol::exchange::Resume;
use {
super::{
exchange::{self, scanint, LFTIntParseResult, QExchangeResult, QExchangeState},
exchange::{self, QExchangeResult, QExchangeState},
handshake::ProtocolError,
SQuery,
},
crate::{
engine::{
mem::BufferedScanner,
net::protocol::handshake::{
AuthMode, CHandshake, CHandshakeAuth, CHandshakeStatic, DataExchangeMode,
HandshakeResult, HandshakeState, HandshakeVersion, ProtocolVersion, QueryMode,
net::protocol::{
handshake::{
AuthMode, CHandshake, CHandshakeAuth, CHandshakeStatic, DataExchangeMode,
HandshakeResult, HandshakeState, HandshakeVersion, ProtocolVersion, QueryMode,
},
scan_int, AccumlatorStatus,
},
},
util::test_utils,
@ -413,11 +416,21 @@ fn stages_manual() {
}
#[test]
fn scanint_impl() {
let mut s = BufferedScanner::new(b"\n");
assert_eq!(scanint(&mut s, true, 0), LFTIntParseResult::Error);
let mut s = BufferedScanner::new(b"12");
assert_eq!(scanint(&mut s, true, 0), LFTIntParseResult::Partial(12));
let mut s = BufferedScanner::new(b"12\n");
assert_eq!(scanint(&mut s, true, 0), LFTIntParseResult::Value(12));
// Table-driven check of `scan_int`, always starting from an accumulator of 0
fn num_accumulate() {
// each case: (input buffer, expected scan result, expected cursor position after the scan)
let x = [
(b"1".as_slice(), Ok(AccumlatorStatus::Pending(1)), 1usize),
(b"12", Ok(AccumlatorStatus::Pending(12)), 2),
(b"123", Ok(AccumlatorStatus::Pending(123)), 3),
(b"123\n", Ok(AccumlatorStatus::Completed(123)), 4),
// a lone LF is accepted and yields zero
(b"\n", Ok(AccumlatorStatus::Completed(0)), 1),
// non-digit bytes are rejected; a directly following LF is still stepped over
(b"A", Err(()), 1),
(b"A\n", Err(()), 2),
(b"1A", Err(()), 2),
(b"1A\n", Err(()), 3),
];
for (buf, res, cursor) in x {
let mut bs = BufferedScanner::new(buf);
assert_eq!(scan_int(&mut bs, 0), res);
assert_eq!(bs.cursor(), cursor);
}
}

Loading…
Cancel
Save