From b961e840f525ddf81e9152300c79ea76f308c710 Mon Sep 17 00:00:00 2001 From: Sayan Nandan Date: Fri, 5 Apr 2024 18:30:31 +0530 Subject: [PATCH] net: Retries based on `expect` values are pointless --- Cargo.lock | 24 +-- server/Cargo.toml | 2 +- server/src/engine/net/protocol/handshake.rs | 29 +--- server/src/engine/net/protocol/mod.rs | 177 +++++++++++--------- server/src/engine/net/protocol/tests.rs | 11 +- server/src/engine/tests/client/mod.rs | 6 +- 6 files changed, 119 insertions(+), 130 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 18f43f91..98275317 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1300,7 +1300,7 @@ dependencies = [ "libsky", "log", "num_cpus", - "skytable 0.8.6 (git+https://github.com/skytable/client-rust.git)", + "skytable", "tokio", ] @@ -1345,7 +1345,7 @@ dependencies = [ "serde", "serde_yaml", "sky_macros", - "skytable 0.8.6 (git+https://github.com/skytable/client-rust.git?branch=feature/pipeline-batch)", + "skytable", "tokio", "tokio-openssl", "uuid", @@ -1359,29 +1359,13 @@ dependencies = [ "crossterm", "libsky", "rustyline", - "skytable 0.8.6 (git+https://github.com/skytable/client-rust.git)", + "skytable", ] [[package]] name = "skytable" version = "0.8.6" -source = "git+https://github.com/skytable/client-rust.git?branch=feature/pipeline-batch#a279456c548781921cb4aed8cc1bab68c74fb37b" -dependencies = [ - "async-trait", - "bb8", - "itoa", - "native-tls", - "r2d2", - "rand", - "sky-derive", - "tokio", - "tokio-native-tls", -] - -[[package]] -name = "skytable" -version = "0.8.6" -source = "git+https://github.com/skytable/client-rust.git#0080df9bddb277eff5f0af704ca0f09aa222b775" +source = "git+https://github.com/skytable/client-rust.git#21263b5b3875ac185df58475b057c56d35ca5c30" dependencies = [ "async-trait", "bb8", diff --git a/server/Cargo.toml b/server/Cargo.toml index 3842c9be..f349aebf 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -50,7 +50,7 @@ libc = "0.2.153" # external deps rand = "0.8.5" tokio = { version = "1.37.0", features = ["test-util"] } -skytable = { git = "https://github.com/skytable/client-rust.git", branch = "feature/pipeline-batch" } +skytable = { git = "https://github.com/skytable/client-rust.git" } [features] nightly = [] diff --git a/server/src/engine/net/protocol/handshake.rs b/server/src/engine/net/protocol/handshake.rs index 29a798de..7dd6ec22 100644 --- a/server/src/engine/net/protocol/handshake.rs +++ b/server/src/engine/net/protocol/handshake.rs @@ -210,12 +210,7 @@ pub enum HandshakeResult<'a> { /// Finished handshake Completed(CHandshake<'a>), /// Update handshake state - /// - /// **NOTE:** expect does not take into account the current amount of buffered data (hence the unbuffered part must be computed!) - ChangeState { - new_state: HandshakeState, - expect: usize, - }, + ChangeState { new_state: HandshakeState }, /// An error occurred Error(ProtocolError), } @@ -271,17 +266,15 @@ impl<'a> CHandshake<'a> { /// Resume from the initial state (nothing buffered yet) fn resume_initial(scanner: &mut BufferedScanner<'a>) -> HandshakeResult<'a> { // get our block - if cfg!(debug_assertions) { - if scanner.remaining() < Self::INITIAL_READ { - return HandshakeResult::ChangeState { - new_state: HandshakeState::Initial, - expect: Self::INITIAL_READ, - }; - } - } else { - assert!(scanner.remaining() >= Self::INITIAL_READ); + if scanner.remaining() < Self::INITIAL_READ { + return HandshakeResult::ChangeState { + new_state: HandshakeState::Initial, + }; } - let buf: [u8; CHandshake::INITIAL_READ] = unsafe { scanner.next_chunk() }; + let buf: [u8; CHandshake::INITIAL_READ] = unsafe { + // UNSAFE(@ohsayan): validated in earlier branch + scanner.next_chunk() + }; let invalid_first_byte = buf[0] != Self::CLIENT_HELLO; let invalid_hs_version = buf[1] > HandshakeVersion::MAX_DSCR; let invalid_proto_version = buf[2] > ProtocolVersion::MAX_DSCR; @@ -350,7 +343,6 @@ impl<'a> CHandshake<'a> { uname_l, pwd_l, }, - expect: (uname_l + pwd_l), } } } @@ -366,7 +358,6 @@ impl<'a> CHandshake<'a> { // we need more data return HandshakeResult::ChangeState { new_state: HandshakeState::StaticBlock(static_header), - expect: static_header.auth_mode.min_payload_bytes(), }; } // we seem to have enough data for this auth mode @@ -379,7 +370,6 @@ impl<'a> CHandshake<'a> { ScannerDecodeResult::NeedMore => { return HandshakeResult::ChangeState { new_state: HandshakeState::StaticBlock(static_header), - expect: AuthMode::Password.min_payload_bytes(), // 2 for uname_l and 2 for pwd_l }; } ScannerDecodeResult::Value(v) => v as usize, @@ -402,7 +392,6 @@ impl<'a> CHandshake<'a> { // newline missing (or maybe there's more?) return HandshakeResult::ChangeState { new_state: HandshakeState::ExpectingMetaForVariableBlock { static_hs, uname_l }, - expect: uname_l + 2, // space for username + password len }; } ScannerDecodeResult::Error => { diff --git a/server/src/engine/net/protocol/mod.rs b/server/src/engine/net/protocol/mod.rs index e83c5523..6017d8ce 100644 --- a/server/src/engine/net/protocol/mod.rs +++ b/server/src/engine/net/protocol/mod.rs @@ -48,8 +48,8 @@ use { self::{ exchange::{Exchange, ExchangeResult, ExchangeState, Pipeline}, handshake::{ - AuthMode, CHandshake, DataExchangeMode, HandshakeResult, HandshakeState, - HandshakeVersion, ProtocolError, ProtocolVersion, QueryMode, + AuthMode, DataExchangeMode, HandshakeResult, HandshakeState, HandshakeVersion, + ProtocolError, ProtocolVersion, QueryMode, }, }, super::{IoResult, QueryLoopResult, Socket}, @@ -107,14 +107,43 @@ impl ClientLocalState { } } +/* + read loop +*/ + +macro_rules! read_loop { + ($con:expr, $buf:expr, $conn_closed:expr, $conn_reset:expr, $body:block) => { + loop { + let read_many = $con.read_buf($buf).await?; + if read_many == 0 { + if $buf.is_empty() { + return $conn_closed; + } else { + return $conn_reset; + } + } + $body + } + }; +} + /* handshake */ +#[inline(always)] +async fn write_handshake_error( + con: &mut BufWriter, + e: ProtocolError, +) -> IoResult<()> { + let hs_err_packet = [b'H', 0, 1, e.value_u8()]; + con.write_all(&hs_err_packet).await +} + #[derive(Debug, PartialEq)] -enum PostHandshake { +enum HandshakeCompleteResult { Okay(ClientLocalState), - Error(ProtocolError), + Error, ConnectionClosedFin, ConnectionClosedRst, } @@ -123,40 +152,34 @@ async fn do_handshake( con: &mut BufWriter, buf: &mut BytesMut, global: &Global, -) -> IoResult { - let mut expected = CHandshake::INITIAL_READ; +) -> IoResult { let mut state = HandshakeState::default(); let mut cursor = 0; let handshake; - loop { - let read_many = con.read_buf(buf).await?; - if read_many == 0 { - if buf.is_empty() { - return Ok(PostHandshake::ConnectionClosedFin); - } else { - return Ok(PostHandshake::ConnectionClosedRst); - } - } - if buf.len() < expected { - continue; - } - let mut scanner = unsafe { BufferedScanner::new_with_cursor(buf, cursor) }; - match handshake::CHandshake::resume_with(&mut scanner, state) { - HandshakeResult::Completed(hs) => { - handshake = hs; - cursor = scanner.cursor(); - break; - } - HandshakeResult::ChangeState { new_state, expect } => { - expected = expect; - state = new_state; - cursor = scanner.cursor(); - } - HandshakeResult::Error(e) => { - return Ok(PostHandshake::Error(e)); + read_loop!( + con, + buf, + Ok(HandshakeCompleteResult::ConnectionClosedFin), + Ok(HandshakeCompleteResult::ConnectionClosedRst), + { + let mut scanner = unsafe { BufferedScanner::new_with_cursor(buf, cursor) }; + match handshake::CHandshake::resume_with(&mut scanner, state) { + HandshakeResult::Completed(hs) => { + handshake = hs; + cursor = scanner.cursor(); + break; + } + HandshakeResult::ChangeState { new_state } => { + state = new_state; + cursor = scanner.cursor(); + } + HandshakeResult::Error(e) => { + write_handshake_error(con, e).await?; + return Ok(HandshakeCompleteResult::Error); + } } } - } + ); // check handshake if cfg!(debug_assertions) { assert_eq!( @@ -181,7 +204,7 @@ async fn do_handshake( { okay @ (VerifyUser::Okay | VerifyUser::OkayRoot) => { let hs = handshake.hs_static(); - let ret = Ok(PostHandshake::Okay(ClientLocalState::new( + let ret = Ok(HandshakeCompleteResult::Okay(ClientLocalState::new( uname.into(), okay.is_root(), hs, @@ -194,7 +217,8 @@ async fn do_handshake( } Err(_) => {} }; - Ok(PostHandshake::Error(ProtocolError::RejectAuth)) + write_handshake_error(con, ProtocolError::RejectAuth).await?; + Ok(HandshakeCompleteResult::Error) } /* @@ -217,60 +241,55 @@ pub(super) async fn query_loop( ) -> IoResult { // handshake let mut client_state = match do_handshake(con, buf, global).await? { - PostHandshake::Okay(hs) => hs, - PostHandshake::ConnectionClosedFin => return Ok(QueryLoopResult::Fin), - PostHandshake::ConnectionClosedRst => return Ok(QueryLoopResult::Rst), - PostHandshake::Error(e) => { - // failed to handshake; we'll close the connection - let hs_err_packet = [b'H', 0, 1, e.value_u8()]; - con.write_all(&hs_err_packet).await?; - return Ok(QueryLoopResult::HSFailed); - } + HandshakeCompleteResult::Okay(hs) => hs, + HandshakeCompleteResult::ConnectionClosedFin => return Ok(QueryLoopResult::Fin), + HandshakeCompleteResult::ConnectionClosedRst => return Ok(QueryLoopResult::Rst), + HandshakeCompleteResult::Error => return Ok(QueryLoopResult::HSFailed), }; // done handshaking con.write_all(b"H\x00\x00\x00").await?; con.flush().await?; let mut state = ExchangeState::default(); let mut cursor = 0; - loop { - if con.read_buf(buf).await? == 0 { - if buf.is_empty() { - return Ok(QueryLoopResult::Fin); - } else { - return Ok(QueryLoopResult::Rst); - } - } - match Exchange::try_complete( - unsafe { - // UNSAFE(@ohsayan): the cursor is either 0 or returned by the exchange impl - BufferedScanner::new_with_cursor(&buf, cursor) - }, - state, - ) { - Ok((result, new_cursor)) => match result { - ExchangeResult::NewState(new_state) => { - state = new_state; - cursor = new_cursor; - } - ExchangeResult::Simple(query) => { - exec_simple(con, &mut client_state, global, query).await?; + read_loop!( + con, + buf, + Ok(QueryLoopResult::Fin), + Ok(QueryLoopResult::Rst), + { + match Exchange::try_complete( + unsafe { + // UNSAFE(@ohsayan): the cursor is either 0 or returned by the exchange impl + BufferedScanner::new_with_cursor(&buf, cursor) + }, + state, + ) { + Ok((result, new_cursor)) => match result { + ExchangeResult::NewState(new_state) => { + state = new_state; + cursor = new_cursor; + } + ExchangeResult::Simple(query) => { + exec_simple(con, &mut client_state, global, query).await?; + (state, cursor) = cleanup_for_next_query(con, buf).await?; + } + ExchangeResult::Pipeline(pipe) => { + exec_pipe(con, &mut client_state, global, pipe).await?; + (state, cursor) = cleanup_for_next_query(con, buf).await?; + } + }, + Err(_) => { + // respond with error + let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() + as u16) + .to_le_bytes(); + con.write_all(&[ResponseType::Error.value_u8(), a, b]) + .await?; (state, cursor) = cleanup_for_next_query(con, buf).await?; } - ExchangeResult::Pipeline(pipe) => { - exec_pipe(con, &mut client_state, global, pipe).await?; - (state, cursor) = cleanup_for_next_query(con, buf).await?; - } - }, - Err(_) => { - // respond with error - let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16) - .to_le_bytes(); - con.write_all(&[ResponseType::Error.value_u8(), a, b]) - .await?; - (state, cursor) = cleanup_for_next_query(con, buf).await?; } } - } + ); } /* diff --git a/server/src/engine/net/protocol/tests.rs b/server/src/engine/net/protocol/tests.rs index c0ad111a..ed3b5448 100644 --- a/server/src/engine/net/protocol/tests.rs +++ b/server/src/engine/net/protocol/tests.rs @@ -74,7 +74,6 @@ fn parse_staged_with_auth() { result, HandshakeResult::ChangeState { new_state: HandshakeState::Initial, - expect: CHandshake::INITIAL_READ } ); } @@ -85,7 +84,6 @@ fn parse_staged_with_auth() { result, HandshakeResult::ChangeState { new_state: HandshakeState::StaticBlock(STATIC_HANDSHAKE_WITH_AUTH), - expect: 4 } ); } @@ -98,7 +96,6 @@ fn parse_staged_with_auth() { uname_l: 5, pwd_l: 8 }, - expect: 13, } ); } @@ -124,15 +121,13 @@ fn run_state_changes_return_rounds(src: &[u8], expected_final_handshake: CHandsh let mut rounds = 0; let mut state = HandshakeState::default(); let mut cursor = 0; - let mut expect_many = CHandshake::INITIAL_READ; loop { rounds += 1; - let buf = &src[..cursor + expect_many]; + let buf = &src[..rounds]; let mut scanner = unsafe { BufferedScanner::new_with_cursor(buf, cursor) }; match CHandshake::resume_with(&mut scanner, state) { - HandshakeResult::ChangeState { new_state, expect } => { + HandshakeResult::ChangeState { new_state } => { state = new_state; - expect_many = expect; cursor = scanner.cursor(); } HandshakeResult::Completed(hs) => { @@ -154,7 +149,7 @@ fn parse_auth_with_state_updates() { CHandshakeAuth::new(b"sayan", b"pass1234"), ), ); - assert_eq!(rounds, 3); // r1 = initial read, r2 = lengths, r3 = items + assert_eq!(rounds, FULL_HANDSHAKE_WITH_AUTH.len()) } const HS_BAD_PACKET: [u8; 6] = *b"I\x00\0\0\0\0"; diff --git a/server/src/engine/tests/client/mod.rs b/server/src/engine/tests/client/mod.rs index 26501c2e..b1ff7d7d 100644 --- a/server/src/engine/tests/client/mod.rs +++ b/server/src/engine/tests/client/mod.rs @@ -30,16 +30,18 @@ use skytable::{ response::{Response, Value}, }; +const PIPE_RUNS: usize = 20; + #[sky_macros::dbtest] fn pipe() { let mut db = db!(); let mut pipe = Pipeline::new(); - for _ in 0..100 { + for _ in 0..PIPE_RUNS { pipe.push(&query!("sysctl report status")); } assert_eq!( db.execute_pipeline(&pipe).unwrap(), - vec![Response::Empty; 100] + vec![Response::Empty; PIPE_RUNS] ); }