|
|
|
@ -48,8 +48,8 @@ use {
|
|
|
|
|
self::{
|
|
|
|
|
exchange::{Exchange, ExchangeResult, ExchangeState, Pipeline},
|
|
|
|
|
handshake::{
|
|
|
|
|
AuthMode, CHandshake, DataExchangeMode, HandshakeResult, HandshakeState,
|
|
|
|
|
HandshakeVersion, ProtocolError, ProtocolVersion, QueryMode,
|
|
|
|
|
AuthMode, DataExchangeMode, HandshakeResult, HandshakeState, HandshakeVersion,
|
|
|
|
|
ProtocolError, ProtocolVersion, QueryMode,
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
super::{IoResult, QueryLoopResult, Socket},
|
|
|
|
@ -107,14 +107,43 @@ impl ClientLocalState {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
read loop
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
macro_rules! read_loop {
|
|
|
|
|
($con:expr, $buf:expr, $conn_closed:expr, $conn_reset:expr, $body:block) => {
|
|
|
|
|
loop {
|
|
|
|
|
let read_many = $con.read_buf($buf).await?;
|
|
|
|
|
if read_many == 0 {
|
|
|
|
|
if $buf.is_empty() {
|
|
|
|
|
return $conn_closed;
|
|
|
|
|
} else {
|
|
|
|
|
return $conn_reset;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
$body
|
|
|
|
|
}
|
|
|
|
|
};
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|
handshake
|
|
|
|
|
*/
|
|
|
|
|
|
|
|
|
|
#[inline(always)]
|
|
|
|
|
async fn write_handshake_error<S: Socket>(
|
|
|
|
|
con: &mut BufWriter<S>,
|
|
|
|
|
e: ProtocolError,
|
|
|
|
|
) -> IoResult<()> {
|
|
|
|
|
let hs_err_packet = [b'H', 0, 1, e.value_u8()];
|
|
|
|
|
con.write_all(&hs_err_packet).await
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
#[derive(Debug, PartialEq)]
|
|
|
|
|
enum PostHandshake {
|
|
|
|
|
enum HandshakeCompleteResult {
|
|
|
|
|
Okay(ClientLocalState),
|
|
|
|
|
Error(ProtocolError),
|
|
|
|
|
Error,
|
|
|
|
|
ConnectionClosedFin,
|
|
|
|
|
ConnectionClosedRst,
|
|
|
|
|
}
|
|
|
|
@ -123,40 +152,34 @@ async fn do_handshake<S: Socket>(
|
|
|
|
|
con: &mut BufWriter<S>,
|
|
|
|
|
buf: &mut BytesMut,
|
|
|
|
|
global: &Global,
|
|
|
|
|
) -> IoResult<PostHandshake> {
|
|
|
|
|
let mut expected = CHandshake::INITIAL_READ;
|
|
|
|
|
) -> IoResult<HandshakeCompleteResult> {
|
|
|
|
|
let mut state = HandshakeState::default();
|
|
|
|
|
let mut cursor = 0;
|
|
|
|
|
let handshake;
|
|
|
|
|
loop {
|
|
|
|
|
let read_many = con.read_buf(buf).await?;
|
|
|
|
|
if read_many == 0 {
|
|
|
|
|
if buf.is_empty() {
|
|
|
|
|
return Ok(PostHandshake::ConnectionClosedFin);
|
|
|
|
|
} else {
|
|
|
|
|
return Ok(PostHandshake::ConnectionClosedRst);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
if buf.len() < expected {
|
|
|
|
|
continue;
|
|
|
|
|
}
|
|
|
|
|
let mut scanner = unsafe { BufferedScanner::new_with_cursor(buf, cursor) };
|
|
|
|
|
match handshake::CHandshake::resume_with(&mut scanner, state) {
|
|
|
|
|
HandshakeResult::Completed(hs) => {
|
|
|
|
|
handshake = hs;
|
|
|
|
|
cursor = scanner.cursor();
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
HandshakeResult::ChangeState { new_state, expect } => {
|
|
|
|
|
expected = expect;
|
|
|
|
|
state = new_state;
|
|
|
|
|
cursor = scanner.cursor();
|
|
|
|
|
}
|
|
|
|
|
HandshakeResult::Error(e) => {
|
|
|
|
|
return Ok(PostHandshake::Error(e));
|
|
|
|
|
read_loop!(
|
|
|
|
|
con,
|
|
|
|
|
buf,
|
|
|
|
|
Ok(HandshakeCompleteResult::ConnectionClosedFin),
|
|
|
|
|
Ok(HandshakeCompleteResult::ConnectionClosedRst),
|
|
|
|
|
{
|
|
|
|
|
let mut scanner = unsafe { BufferedScanner::new_with_cursor(buf, cursor) };
|
|
|
|
|
match handshake::CHandshake::resume_with(&mut scanner, state) {
|
|
|
|
|
HandshakeResult::Completed(hs) => {
|
|
|
|
|
handshake = hs;
|
|
|
|
|
cursor = scanner.cursor();
|
|
|
|
|
break;
|
|
|
|
|
}
|
|
|
|
|
HandshakeResult::ChangeState { new_state } => {
|
|
|
|
|
state = new_state;
|
|
|
|
|
cursor = scanner.cursor();
|
|
|
|
|
}
|
|
|
|
|
HandshakeResult::Error(e) => {
|
|
|
|
|
write_handshake_error(con, e).await?;
|
|
|
|
|
return Ok(HandshakeCompleteResult::Error);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
);
|
|
|
|
|
// check handshake
|
|
|
|
|
if cfg!(debug_assertions) {
|
|
|
|
|
assert_eq!(
|
|
|
|
@ -181,7 +204,7 @@ async fn do_handshake<S: Socket>(
|
|
|
|
|
{
|
|
|
|
|
okay @ (VerifyUser::Okay | VerifyUser::OkayRoot) => {
|
|
|
|
|
let hs = handshake.hs_static();
|
|
|
|
|
let ret = Ok(PostHandshake::Okay(ClientLocalState::new(
|
|
|
|
|
let ret = Ok(HandshakeCompleteResult::Okay(ClientLocalState::new(
|
|
|
|
|
uname.into(),
|
|
|
|
|
okay.is_root(),
|
|
|
|
|
hs,
|
|
|
|
@ -194,7 +217,8 @@ async fn do_handshake<S: Socket>(
|
|
|
|
|
}
|
|
|
|
|
Err(_) => {}
|
|
|
|
|
};
|
|
|
|
|
Ok(PostHandshake::Error(ProtocolError::RejectAuth))
|
|
|
|
|
write_handshake_error(con, ProtocolError::RejectAuth).await?;
|
|
|
|
|
Ok(HandshakeCompleteResult::Error)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
@ -217,60 +241,55 @@ pub(super) async fn query_loop<S: Socket>(
|
|
|
|
|
) -> IoResult<QueryLoopResult> {
|
|
|
|
|
// handshake
|
|
|
|
|
let mut client_state = match do_handshake(con, buf, global).await? {
|
|
|
|
|
PostHandshake::Okay(hs) => hs,
|
|
|
|
|
PostHandshake::ConnectionClosedFin => return Ok(QueryLoopResult::Fin),
|
|
|
|
|
PostHandshake::ConnectionClosedRst => return Ok(QueryLoopResult::Rst),
|
|
|
|
|
PostHandshake::Error(e) => {
|
|
|
|
|
// failed to handshake; we'll close the connection
|
|
|
|
|
let hs_err_packet = [b'H', 0, 1, e.value_u8()];
|
|
|
|
|
con.write_all(&hs_err_packet).await?;
|
|
|
|
|
return Ok(QueryLoopResult::HSFailed);
|
|
|
|
|
}
|
|
|
|
|
HandshakeCompleteResult::Okay(hs) => hs,
|
|
|
|
|
HandshakeCompleteResult::ConnectionClosedFin => return Ok(QueryLoopResult::Fin),
|
|
|
|
|
HandshakeCompleteResult::ConnectionClosedRst => return Ok(QueryLoopResult::Rst),
|
|
|
|
|
HandshakeCompleteResult::Error => return Ok(QueryLoopResult::HSFailed),
|
|
|
|
|
};
|
|
|
|
|
// done handshaking
|
|
|
|
|
con.write_all(b"H\x00\x00\x00").await?;
|
|
|
|
|
con.flush().await?;
|
|
|
|
|
let mut state = ExchangeState::default();
|
|
|
|
|
let mut cursor = 0;
|
|
|
|
|
loop {
|
|
|
|
|
if con.read_buf(buf).await? == 0 {
|
|
|
|
|
if buf.is_empty() {
|
|
|
|
|
return Ok(QueryLoopResult::Fin);
|
|
|
|
|
} else {
|
|
|
|
|
return Ok(QueryLoopResult::Rst);
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
match Exchange::try_complete(
|
|
|
|
|
unsafe {
|
|
|
|
|
// UNSAFE(@ohsayan): the cursor is either 0 or returned by the exchange impl
|
|
|
|
|
BufferedScanner::new_with_cursor(&buf, cursor)
|
|
|
|
|
},
|
|
|
|
|
state,
|
|
|
|
|
) {
|
|
|
|
|
Ok((result, new_cursor)) => match result {
|
|
|
|
|
ExchangeResult::NewState(new_state) => {
|
|
|
|
|
state = new_state;
|
|
|
|
|
cursor = new_cursor;
|
|
|
|
|
}
|
|
|
|
|
ExchangeResult::Simple(query) => {
|
|
|
|
|
exec_simple(con, &mut client_state, global, query).await?;
|
|
|
|
|
read_loop!(
|
|
|
|
|
con,
|
|
|
|
|
buf,
|
|
|
|
|
Ok(QueryLoopResult::Fin),
|
|
|
|
|
Ok(QueryLoopResult::Rst),
|
|
|
|
|
{
|
|
|
|
|
match Exchange::try_complete(
|
|
|
|
|
unsafe {
|
|
|
|
|
// UNSAFE(@ohsayan): the cursor is either 0 or returned by the exchange impl
|
|
|
|
|
BufferedScanner::new_with_cursor(&buf, cursor)
|
|
|
|
|
},
|
|
|
|
|
state,
|
|
|
|
|
) {
|
|
|
|
|
Ok((result, new_cursor)) => match result {
|
|
|
|
|
ExchangeResult::NewState(new_state) => {
|
|
|
|
|
state = new_state;
|
|
|
|
|
cursor = new_cursor;
|
|
|
|
|
}
|
|
|
|
|
ExchangeResult::Simple(query) => {
|
|
|
|
|
exec_simple(con, &mut client_state, global, query).await?;
|
|
|
|
|
(state, cursor) = cleanup_for_next_query(con, buf).await?;
|
|
|
|
|
}
|
|
|
|
|
ExchangeResult::Pipeline(pipe) => {
|
|
|
|
|
exec_pipe(con, &mut client_state, global, pipe).await?;
|
|
|
|
|
(state, cursor) = cleanup_for_next_query(con, buf).await?;
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
Err(_) => {
|
|
|
|
|
// respond with error
|
|
|
|
|
let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8()
|
|
|
|
|
as u16)
|
|
|
|
|
.to_le_bytes();
|
|
|
|
|
con.write_all(&[ResponseType::Error.value_u8(), a, b])
|
|
|
|
|
.await?;
|
|
|
|
|
(state, cursor) = cleanup_for_next_query(con, buf).await?;
|
|
|
|
|
}
|
|
|
|
|
ExchangeResult::Pipeline(pipe) => {
|
|
|
|
|
exec_pipe(con, &mut client_state, global, pipe).await?;
|
|
|
|
|
(state, cursor) = cleanup_for_next_query(con, buf).await?;
|
|
|
|
|
}
|
|
|
|
|
},
|
|
|
|
|
Err(_) => {
|
|
|
|
|
// respond with error
|
|
|
|
|
let [a, b] = (QueryError::SysNetworkSystemIllegalClientPacket.value_u8() as u16)
|
|
|
|
|
.to_le_bytes();
|
|
|
|
|
con.write_all(&[ResponseType::Error.value_u8(), a, b])
|
|
|
|
|
.await?;
|
|
|
|
|
(state, cursor) = cleanup_for_next_query(con, buf).await?;
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
);
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
/*
|
|
|
|
|