net: Retries based on `expect` values are pointless

next
Sayan Nandan 6 months ago
parent 337ea9efe2
commit b961e840f5
No known key found for this signature in database
GPG Key ID: 0EBD769024B24F0A

24
Cargo.lock generated

@ -1300,7 +1300,7 @@ dependencies = [
"libsky", "libsky",
"log", "log",
"num_cpus", "num_cpus",
"skytable 0.8.6 (git+https://github.com/skytable/client-rust.git)", "skytable",
"tokio", "tokio",
] ]
@ -1345,7 +1345,7 @@ dependencies = [
"serde", "serde",
"serde_yaml", "serde_yaml",
"sky_macros", "sky_macros",
"skytable 0.8.6 (git+https://github.com/skytable/client-rust.git?branch=feature/pipeline-batch)", "skytable",
"tokio", "tokio",
"tokio-openssl", "tokio-openssl",
"uuid", "uuid",
@ -1359,29 +1359,13 @@ dependencies = [
"crossterm", "crossterm",
"libsky", "libsky",
"rustyline", "rustyline",
"skytable 0.8.6 (git+https://github.com/skytable/client-rust.git)", "skytable",
] ]
[[package]] [[package]]
name = "skytable" name = "skytable"
version = "0.8.6" version = "0.8.6"
source = "git+https://github.com/skytable/client-rust.git?branch=feature/pipeline-batch#a279456c548781921cb4aed8cc1bab68c74fb37b" source = "git+https://github.com/skytable/client-rust.git#21263b5b3875ac185df58475b057c56d35ca5c30"
dependencies = [
"async-trait",
"bb8",
"itoa",
"native-tls",
"r2d2",
"rand",
"sky-derive",
"tokio",
"tokio-native-tls",
]
[[package]]
name = "skytable"
version = "0.8.6"
source = "git+https://github.com/skytable/client-rust.git#0080df9bddb277eff5f0af704ca0f09aa222b775"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"bb8", "bb8",

@ -50,7 +50,7 @@ libc = "0.2.153"
# external deps # external deps
rand = "0.8.5" rand = "0.8.5"
tokio = { version = "1.37.0", features = ["test-util"] } tokio = { version = "1.37.0", features = ["test-util"] }
skytable = { git = "https://github.com/skytable/client-rust.git", branch = "feature/pipeline-batch" } skytable = { git = "https://github.com/skytable/client-rust.git" }
[features] [features]
nightly = [] nightly = []

@ -210,12 +210,7 @@ pub enum HandshakeResult<'a> {
/// Finished handshake /// Finished handshake
Completed(CHandshake<'a>), Completed(CHandshake<'a>),
/// Update handshake state /// Update handshake state
/// ChangeState { new_state: HandshakeState },
/// **NOTE:** expect does not take into account the current amount of buffered data (hence the unbuffered part must be computed!)
ChangeState {
new_state: HandshakeState,
expect: usize,
},
/// An error occurred /// An error occurred
Error(ProtocolError), Error(ProtocolError),
} }
@ -271,17 +266,15 @@ impl<'a> CHandshake<'a> {
/// Resume from the initial state (nothing buffered yet) /// Resume from the initial state (nothing buffered yet)
fn resume_initial(scanner: &mut BufferedScanner<'a>) -> HandshakeResult<'a> { fn resume_initial(scanner: &mut BufferedScanner<'a>) -> HandshakeResult<'a> {
// get our block // get our block
if cfg!(debug_assertions) { if scanner.remaining() < Self::INITIAL_READ {
if scanner.remaining() < Self::INITIAL_READ { return HandshakeResult::ChangeState {
return HandshakeResult::ChangeState { new_state: HandshakeState::Initial,
new_state: HandshakeState::Initial, };
expect: Self::INITIAL_READ,
};
}
} else {
assert!(scanner.remaining() >= Self::INITIAL_READ);
} }
let buf: [u8; CHandshake::INITIAL_READ] = unsafe { scanner.next_chunk() }; let buf: [u8; CHandshake::INITIAL_READ] = unsafe {
// UNSAFE(@ohsayan): validated in earlier branch
scanner.next_chunk()
};
let invalid_first_byte = buf[0] != Self::CLIENT_HELLO; let invalid_first_byte = buf[0] != Self::CLIENT_HELLO;
let invalid_hs_version = buf[1] > HandshakeVersion::MAX_DSCR; let invalid_hs_version = buf[1] > HandshakeVersion::MAX_DSCR;
let invalid_proto_version = buf[2] > ProtocolVersion::MAX_DSCR; let invalid_proto_version = buf[2] > ProtocolVersion::MAX_DSCR;
@ -350,7 +343,6 @@ impl<'a> CHandshake<'a> {
uname_l, uname_l,
pwd_l, pwd_l,
}, },
expect: (uname_l + pwd_l),
} }
} }
} }
@ -366,7 +358,6 @@ impl<'a> CHandshake<'a> {
// we need more data // we need more data
return HandshakeResult::ChangeState { return HandshakeResult::ChangeState {
new_state: HandshakeState::StaticBlock(static_header), new_state: HandshakeState::StaticBlock(static_header),
expect: static_header.auth_mode.min_payload_bytes(),
}; };
} }
// we seem to have enough data for this auth mode // we seem to have enough data for this auth mode
@ -379,7 +370,6 @@ impl<'a> CHandshake<'a> {
ScannerDecodeResult::NeedMore => { ScannerDecodeResult::NeedMore => {
return HandshakeResult::ChangeState { return HandshakeResult::ChangeState {
new_state: HandshakeState::StaticBlock(static_header), new_state: HandshakeState::StaticBlock(static_header),
expect: AuthMode::Password.min_payload_bytes(), // 2 for uname_l and 2 for pwd_l
}; };
} }
ScannerDecodeResult::Value(v) => v as usize, ScannerDecodeResult::Value(v) => v as usize,
@ -402,7 +392,6 @@ impl<'a> CHandshake<'a> {
// newline missing (or maybe there's more?) // newline missing (or maybe there's more?)
return HandshakeResult::ChangeState { return HandshakeResult::ChangeState {
new_state: HandshakeState::ExpectingMetaForVariableBlock { static_hs, uname_l }, new_state: HandshakeState::ExpectingMetaForVariableBlock { static_hs, uname_l },
expect: uname_l + 2, // space for username + password len
}; };
} }
ScannerDecodeResult::Error => { ScannerDecodeResult::Error => {

@ -48,8 +48,8 @@ use {
self::{ self::{
exchange::{Exchange, ExchangeResult, ExchangeState, Pipeline}, exchange::{Exchange, ExchangeResult, ExchangeState, Pipeline},
handshake::{ handshake::{
AuthMode, CHandshake, DataExchangeMode, HandshakeResult, HandshakeState, AuthMode, DataExchangeMode, HandshakeResult, HandshakeState, HandshakeVersion,
HandshakeVersion, ProtocolError, ProtocolVersion, QueryMode, ProtocolError, ProtocolVersion, QueryMode,
}, },
}, },
super::{IoResult, QueryLoopResult, Socket}, super::{IoResult, QueryLoopResult, Socket},
@ -107,14 +107,43 @@ impl ClientLocalState {
} }
} }
/*
read loop
*/
macro_rules! read_loop {
($con:expr, $buf:expr, $conn_closed:expr, $conn_reset:expr, $body:block) => {
loop {
let read_many = $con.read_buf($buf).await?;
if read_many == 0 {
if $buf.is_empty() {
return $conn_closed;
} else {
return $conn_reset;
}
}
$body
}
};
}
/* /*
handshake handshake
*/ */
#[inline(always)]
async fn write_handshake_error<S: Socket>(
con: &mut BufWriter<S>,
e: ProtocolError,
) -> IoResult<()> {
let hs_err_packet = [b'H', 0, 1, e.value_u8()];
con.write_all(&hs_err_packet).await
}
#[derive(Debug, PartialEq)] #[derive(Debug, PartialEq)]
enum PostHandshake { enum HandshakeCompleteResult {
Okay(ClientLocalState), Okay(ClientLocalState),
Error(ProtocolError), Error,
ConnectionClosedFin, ConnectionClosedFin,
ConnectionClosedRst, ConnectionClosedRst,
} }
@ -123,40 +152,34 @@ async fn do_handshake<S: Socket>(
con: &mut BufWriter<S>, con: &mut BufWriter<S>,
buf: &mut BytesMut, buf: &mut BytesMut,
global: &Global, global: &Global,
) -> IoResult<PostHandshake> { ) -> IoResult<HandshakeCompleteResult> {
let mut expected = CHandshake::INITIAL_READ;
let mut state = HandshakeState::default(); let mut state = HandshakeState::default();
let mut cursor = 0; let mut cursor = 0;
let handshake; let handshake;
loop { read_loop!(
let read_many = con.read_buf(buf).await?; con,
if read_many == 0 { buf,
if buf.is_empty() { Ok(HandshakeCompleteResult::ConnectionClosedFin),
return Ok(PostHandshake::ConnectionClosedFin); Ok(HandshakeCompleteResult::ConnectionClosedRst),
} else { {
return Ok(PostHandshake::ConnectionClosedRst); let mut scanner = unsafe { BufferedScanner::new_with_cursor(buf, cursor) };
} match handshake::CHandshake::resume_with(&mut scanner, state) {
} HandshakeResult::Completed(hs) => {
if buf.len() < expected { handshake = hs;
continue; cursor = scanner.cursor();
} break;
let mut scanner = unsafe { BufferedScanner::new_with_cursor(buf, cursor) }; }
match handshake::CHandshake::resume_with(&mut scanner, state) { HandshakeResult::ChangeState { new_state } => {
HandshakeResult::Completed(hs) => { state = new_state;
handshake = hs; cursor = scanner.cursor();
cursor = scanner.cursor(); }
break; HandshakeResult::Error(e) => {
} write_handshake_error(con, e).await?;
HandshakeResult::ChangeState { new_state, expect } => { return Ok(HandshakeCompleteResult::Error);
expected = expect; }
state = new_state;
cursor = scanner.cursor();
}
HandshakeResult::Error(e) => {
return Ok(PostHandshake::Error(e));
} }
} }
} );
// check handshake // check handshake
if cfg!(debug_assertions) { if cfg!(debug_assertions) {
assert_eq!( assert_eq!(
@ -181,7 +204,7 @@ async fn do_handshake<S: Socket>(
{ {
okay @ (VerifyUser::Okay | VerifyUser::OkayRoot) => { okay @ (VerifyUser::Okay | VerifyUser::OkayRoot) => {
let hs = handshake.hs_static(); let hs = handshake.hs_static();
let ret = Ok(PostHandshake::Okay(ClientLocalState::new( let ret = Ok(HandshakeCompleteResult::Okay(ClientLocalState::new(
uname.into(), uname.into(),
okay.is_root(), okay.is_root(),
hs, hs,
@ -194,7 +217,8 @@ async fn do_handshake<S: Socket>(
} }
Err(_) => {} Err(_) => {}
}; };
Ok(PostHandshake::Error(ProtocolError::RejectAuth)) write_handshake_error(con, ProtocolError::RejectAuth).await?;
Ok(HandshakeCompleteResult::Error)
} }
/* /*
@ -217,60 +241,55 @@ pub(super) async fn query_loop<S: Socket>(
) -> IoResult<QueryLoopResult> { ) -> IoResult<QueryLoopResult> {
// handshake // handshake
let mut client_state = match do_handshake(con, buf, global).await? { let mut client_state = match do_handshake(con, buf, global).await? {
PostHandshake::Okay(hs) => hs, HandshakeCompleteResult::Okay(hs) => hs,
PostHandshake::ConnectionClosedFin => return Ok(QueryLoopResult::Fin), HandshakeCompleteResult::ConnectionClosedFin => return Ok(QueryLoopResult::Fin),
PostHandshake::ConnectionClosedRst => return Ok(QueryLoopResult::Rst), HandshakeCompleteResult::ConnectionClosedRst => return Ok(QueryLoopResult::Rst),
PostHandshake::Error(e) => { HandshakeCompleteResult::Error => return Ok(QueryLoopResult::HSFailed),
// failed to handshake; we'll close the connection
let hs_err_packet = [b'H', 0, 1, e.value_u8()];
con.write_all(&hs_err_packet).await?;
return Ok(QueryLoopResult::HSFailed);
}
}; };
// done handshaking // done handshaking
con.write_all(b"H\x00\x00\x00").await?; con.write_all(b"H\x00\x00\x00").await?;
con.flush().await?; con.flush().await?;
let mut state = ExchangeState::default(); let mut state = ExchangeState::default();
let mut cursor = 0; let mut cursor = 0;
loop { read_loop!(
if con.read_buf(buf).await? == 0 { con,
if buf.is_empty() { buf,
return Ok(QueryLoopResult::Fin); Ok(QueryLoopResult::Fin),
} else { Ok(QueryLoopResult::Rst),
return Ok(QueryLoopResult::Rst); {
} match Exchange::try_complete(
} unsafe {
match Exchange::try_complete( // UNSAFE(@ohsayan): the cursor is either 0 or returned by the exchange impl
unsafe { BufferedScanner::new_with_cursor(&buf, cursor)
// UNSAFE(@ohsayan): the cursor is either 0 or returned by the exchange impl },
BufferedScanner::new_with_cursor(&buf, cursor) state,
}, ) {
state, Ok((result, new_cursor)) => match result {
) { ExchangeResult::NewState(new_state) => {
Ok((result, new_cursor)) => match result { state = new_state;
ExchangeResult::NewState(new_state) => { cursor = new_cursor;
state = new_state; }
cursor = new_cursor; ExchangeResult::Simple(query) => {
} exec_simple(con, &mut client_state, global, query).await?;
ExchangeResult::Simple(query) => { (state, cursor) = cleanup_for_next_query(con, buf).await?;
exec_simple(con, &mut client_state, global, query).await?; }
ExchangeResult::Pipeline(pipe) => {
exec_pipe(con, &mut client_state, global, pipe).await?;
(state, cursor) = cleanup_for_next_query(con, buf).await?;
}
},
Err(_) => {
// respond with error
let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8()
as u16)
.to_le_bytes();
con.write_all(&[ResponseType::Error.value_u8(), a, b])
.await?;
(state, cursor) = cleanup_for_next_query(con, buf).await?; (state, cursor) = cleanup_for_next_query(con, buf).await?;
} }
ExchangeResult::Pipeline(pipe) => {
exec_pipe(con, &mut client_state, global, pipe).await?;
(state, cursor) = cleanup_for_next_query(con, buf).await?;
}
},
Err(_) => {
// respond with error
let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16)
.to_le_bytes();
con.write_all(&[ResponseType::Error.value_u8(), a, b])
.await?;
(state, cursor) = cleanup_for_next_query(con, buf).await?;
} }
} }
} );
} }
/* /*

@ -74,7 +74,6 @@ fn parse_staged_with_auth() {
result, result,
HandshakeResult::ChangeState { HandshakeResult::ChangeState {
new_state: HandshakeState::Initial, new_state: HandshakeState::Initial,
expect: CHandshake::INITIAL_READ
} }
); );
} }
@ -85,7 +84,6 @@ fn parse_staged_with_auth() {
result, result,
HandshakeResult::ChangeState { HandshakeResult::ChangeState {
new_state: HandshakeState::StaticBlock(STATIC_HANDSHAKE_WITH_AUTH), new_state: HandshakeState::StaticBlock(STATIC_HANDSHAKE_WITH_AUTH),
expect: 4
} }
); );
} }
@ -98,7 +96,6 @@ fn parse_staged_with_auth() {
uname_l: 5, uname_l: 5,
pwd_l: 8 pwd_l: 8
}, },
expect: 13,
} }
); );
} }
@ -124,15 +121,13 @@ fn run_state_changes_return_rounds(src: &[u8], expected_final_handshake: CHandsh
let mut rounds = 0; let mut rounds = 0;
let mut state = HandshakeState::default(); let mut state = HandshakeState::default();
let mut cursor = 0; let mut cursor = 0;
let mut expect_many = CHandshake::INITIAL_READ;
loop { loop {
rounds += 1; rounds += 1;
let buf = &src[..cursor + expect_many]; let buf = &src[..rounds];
let mut scanner = unsafe { BufferedScanner::new_with_cursor(buf, cursor) }; let mut scanner = unsafe { BufferedScanner::new_with_cursor(buf, cursor) };
match CHandshake::resume_with(&mut scanner, state) { match CHandshake::resume_with(&mut scanner, state) {
HandshakeResult::ChangeState { new_state, expect } => { HandshakeResult::ChangeState { new_state } => {
state = new_state; state = new_state;
expect_many = expect;
cursor = scanner.cursor(); cursor = scanner.cursor();
} }
HandshakeResult::Completed(hs) => { HandshakeResult::Completed(hs) => {
@ -154,7 +149,7 @@ fn parse_auth_with_state_updates() {
CHandshakeAuth::new(b"sayan", b"pass1234"), CHandshakeAuth::new(b"sayan", b"pass1234"),
), ),
); );
assert_eq!(rounds, 3); // r1 = initial read, r2 = lengths, r3 = items assert_eq!(rounds, FULL_HANDSHAKE_WITH_AUTH.len())
} }
const HS_BAD_PACKET: [u8; 6] = *b"I\x00\0\0\0\0"; const HS_BAD_PACKET: [u8; 6] = *b"I\x00\0\0\0\0";

@ -30,16 +30,18 @@ use skytable::{
response::{Response, Value}, response::{Response, Value},
}; };
const PIPE_RUNS: usize = 20;
#[sky_macros::dbtest] #[sky_macros::dbtest]
fn pipe() { fn pipe() {
let mut db = db!(); let mut db = db!();
let mut pipe = Pipeline::new(); let mut pipe = Pipeline::new();
for _ in 0..100 { for _ in 0..PIPE_RUNS {
pipe.push(&query!("sysctl report status")); pipe.push(&query!("sysctl report status"));
} }
assert_eq!( assert_eq!(
db.execute_pipeline(&pipe).unwrap(), db.execute_pipeline(&pipe).unwrap(),
vec![Response::Empty; 100] vec![Response::Empty; PIPE_RUNS]
); );
} }

Loading…
Cancel
Save