Add a generic implementation for a connection

This commit defines two traits: `Con` and `ConOps`. Implementors of
`ConOps` get a free implementation for `Con`. `Con` is the ultimate
object that can be used in place of the current SSL/non-SSL connection
objects. If you look at the implementations of the current connection
objects, they have a lot of repetition as they do almost the same thing
except for the fact that they have a different underlying stream.
This is exactly what we're trying to eliminate. We will also define a
generic connection handler object to reduce redundancy.
next
Sayan Nandan 3 years ago
parent 7dadf4411f
commit ba478b9f5a

BIN
.DS_Store vendored

Binary file not shown.

3
.gitignore vendored

@ -5,4 +5,5 @@ data.bin
snapstore.bin
snapstore.partmap
/snapshots
/.idea
/.idea
.DS_Store

@ -24,7 +24,6 @@ regex = "1.4.5"
sky_macros = {path="../sky-macros"}
tokio-openssl = "0.6.1"
openssl = { version = "0.10.33", features = ["vendored"] }
[target.'cfg(not(target_env = "msvc"))'.dependencies]
jemallocator = "0.3.2"

@ -0,0 +1,258 @@
/*
* Created on Sun Apr 25 2021
*
* This file is a part of Skytable
* Skytable (formerly known as TerrabaseDB or Skybase) is a free and open-source
* NoSQL database written by Sayan Nandan ("the Author") with the
* vision to provide flexibility in data modelling without compromising
* on performance, queryability or scalability.
*
* Copyright (c) 2020, Sayan Nandan <ohsayan@outlook.com>
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*
*/
use super::deserializer;
use super::responses;
use crate::dbnet::Terminator;
use crate::protocol::tls::SslConnection;
use crate::protocol::Connection;
use crate::protocol::ParseResult;
use crate::protocol::QueryResult;
use crate::resp::Writable;
use crate::CoreDB;
use bytes::Buf;
use bytes::BytesMut;
use std::future::Future;
use std::io::Error as IoError;
use std::io::ErrorKind;
use std::io::Result as IoResult;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufWriter;
use tokio::net::TcpStream;
use tokio::sync::mpsc;
use tokio::sync::Semaphore;
use tokio_openssl::SslStream;
pub trait Con<Strm>: ConOps<Strm>
where
Strm: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
/// Try to fill the buffer again
fn read_again<'r, 's>(&'r mut self) -> Pin<Box<dyn Future<Output = IoResult<()>> + Send + 's>>
where
'r: 's,
Self: Send + 's,
{
Box::pin(async move {
let mv_self = self;
let ret: IoResult<()> = {
let (buffer, stream) = mv_self.get_mut_both();
match stream.read_buf(buffer).await {
Ok(0) => {
if buffer.is_empty() {
return Ok(());
} else {
return Err(IoError::from(ErrorKind::ConnectionReset));
}
}
Ok(_) => Ok(()),
Err(e) => return Err(e),
}
};
ret
})
}
/// Try to parse a query from the buffered data
fn try_query(&self) -> Result<ParseResult, ()> {
if self.get_buffer().is_empty() {
return Err(());
}
Ok(deserializer::parse(&self.get_buffer()))
}
/// Read a query from the remote end
///
/// This function asynchronously waits until all the data required
/// for parsing the query is available
fn read_query<'r, 's>(
&'r mut self,
) -> Pin<Box<dyn Future<Output = Result<QueryResult, IoError>> + Send + 's>>
where
'r: 's,
Self: Send + 's,
{
Box::pin(async move {
let mv_self = self;
let _: Result<QueryResult, IoError> = {
mv_self.read_again().await?;
loop {
match mv_self.try_query() {
Ok(ParseResult::Query(query, forward)) => {
mv_self.advance_buffer(forward);
return Ok(QueryResult::Q(query));
}
Ok(ParseResult::BadPacket) => {
mv_self.clear_buffer();
return Ok(QueryResult::E(responses::fresp::R_PACKET_ERR.to_owned()));
}
Err(_) => {
return Ok(QueryResult::Empty);
}
_ => (),
}
mv_self.read_again().await?;
}
};
})
}
/// Write a response to the stream
fn write_response<'r, 's>(
&'r mut self,
streamer: impl Writable + 's + Send,
) -> Pin<Box<dyn Future<Output = IoResult<()>> + Send + 's>>
where
'r: 's,
Self: Send + 's,
{
Box::pin(async move {
let mv_self = self;
let streamer = streamer;
let ret: IoResult<()> = {
streamer.write(&mut mv_self.get_mut_stream()).await?;
Ok(())
};
ret
})
}
/// Wraps around the `write_response` used to differentiate between a
/// success response and an error response
fn close_conn_with_error<'r, 's>(
&'r mut self,
resp: Vec<u8>,
) -> Pin<Box<dyn Future<Output = IoResult<()>> + Send + 's>>
where
'r: 's,
Self: Send + 's,
{
Box::pin(async move {
let mv_self = self;
let ret: IoResult<()> = {
mv_self.write_response(resp).await?;
mv_self.flush_stream().await?;
Ok(())
};
ret
})
}
fn flush_stream<'r, 's>(&'r mut self) -> Pin<Box<dyn Future<Output = IoResult<()>> + Send + 's>>
where
'r: 's,
Self: Send + 's,
{
Box::pin(async move {
let mv_self = self;
let ret: IoResult<()> = {
mv_self.get_mut_stream().flush().await?;
Ok(())
};
ret
})
}
}
pub trait ConOps<Strm> {
/// Returns an **immutable** reference to the underlying read buffer
fn get_buffer(&self) -> &BytesMut;
/// Returns an **immutable** reference to the underlying stream
fn get_stream(&self) -> &BufWriter<Strm>;
/// Returns a **mutable** reference to the underlying read buffer
fn get_mut_buffer(&mut self) -> &mut BytesMut;
/// Returns a **mutable** reference to the underlying stream
fn get_mut_stream(&mut self) -> &mut BufWriter<Strm>;
/// Returns a **mutable** reference to (buffer, stream)
///
/// This is to avoid double mutable reference errors
fn get_mut_both(&mut self) -> (&mut BytesMut, &mut BufWriter<Strm>);
/// Advance the read buffer by `forward_by` positions
fn advance_buffer(&mut self, forward_by: usize) {
self.get_mut_buffer().advance(forward_by)
}
/// Clear the internal buffer completely
fn clear_buffer(&mut self) {
self.get_mut_buffer().clear()
}
}
// Give ConOps implementors a free Con impl
impl<Strm, T> Con<Strm> for T
where
T: ConOps<Strm>,
Strm: Sync + Send + Unpin + AsyncWriteExt + AsyncReadExt,
{
}
impl ConOps<SslStream<TcpStream>> for SslConnection {
fn get_buffer(&self) -> &BytesMut {
&self.buffer
}
fn get_stream(&self) -> &BufWriter<SslStream<TcpStream>> {
&self.stream
}
fn get_mut_buffer(&mut self) -> &mut BytesMut {
&mut self.buffer
}
fn get_mut_stream(&mut self) -> &mut BufWriter<SslStream<TcpStream>> {
&mut self.stream
}
fn get_mut_both(&mut self) -> (&mut BytesMut, &mut BufWriter<SslStream<TcpStream>>) {
(&mut self.buffer, &mut self.stream)
}
}
impl ConOps<TcpStream> for Connection {
fn get_buffer(&self) -> &BytesMut {
&self.buffer
}
fn get_stream(&self) -> &BufWriter<TcpStream> {
&self.stream
}
fn get_mut_buffer(&mut self) -> &mut BytesMut {
&mut self.buffer
}
fn get_mut_stream(&mut self) -> &mut BufWriter<TcpStream> {
&mut self.stream
}
fn get_mut_both(&mut self) -> (&mut BytesMut, &mut BufWriter<TcpStream>) {
(&mut self.buffer, &mut self.stream)
}
}
pub struct ConnectionHandler<T, Strm>
where
T: Con<Strm>,
Strm: Sync + Send + Unpin + AsyncWriteExt + AsyncReadExt,
{
db: CoreDB,
con: T,
climit: Arc<Semaphore>,
terminator: Terminator,
_term_sig_tx: mpsc::Sender<()>,
_marker: PhantomData<Strm>,
}

@ -44,6 +44,7 @@ use std::net::SocketAddr;
use tokio::io::{AsyncReadExt, AsyncWriteExt, BufWriter};
use tokio::net::TcpStream;
pub mod tls;
mod con;
/// A TCP connection wrapper
pub struct Connection {

@ -189,8 +189,8 @@ impl Drop for SslConnectionHandler {
}
pub struct SslConnection {
stream: BufWriter<SslStream<TcpStream>>,
buffer: BytesMut,
pub stream: BufWriter<SslStream<TcpStream>>,
pub buffer: BytesMut,
}
impl SslConnection {
@ -227,7 +227,7 @@ impl SslConnection {
self.stream.get_ref().get_ref().peer_addr()
}
/// Try to parse a query from the buffered data
fn try_query(&mut self) -> Result<ParseResult, ()> {
fn try_query(&self) -> Result<ParseResult, ()> {
if self.buffer.is_empty() {
return Err(());
}

@ -28,14 +28,11 @@
//!
use bytes::Bytes;
use libsky::terrapipe::RespCodes;
use std::error::Error;
use std::future::Future;
use std::io::Error as IoError;
use std::pin::Pin;
use tokio::io::AsyncReadExt;
use tokio::io::AsyncWriteExt;
use tokio::io::BufWriter;
use tokio::net::TcpStream;
use tokio_openssl::SslStream;
/// # The `Writable` trait
/// All trait implementors are given access to an asynchronous stream to which
@ -55,7 +52,7 @@ pub trait Writable {
fn write<'s>(
self,
con: &'s mut impl IsConnection,
) -> Pin<Box<dyn Future<Output = Result<(), Box<dyn Error>>> + Send + Sync + 's>>;
) -> Pin<Box<dyn Future<Output = Result<(), IoError>> + Send + Sync + 's>>;
}
pub trait IsConnection: std::marker::Sync + std::marker::Send {
@ -65,25 +62,10 @@ pub trait IsConnection: std::marker::Sync + std::marker::Send {
) -> Pin<Box<dyn Future<Output = Result<usize, IoError>> + Send + Sync + 's>>;
}
impl IsConnection for BufWriter<TcpStream> {
fn write_lowlevel<'s>(
&'s mut self,
bytes: &'s [u8],
) -> Pin<Box<dyn Future<Output = Result<usize, IoError>> + Send + Sync + 's>> {
Box::pin(self.write(bytes))
}
}
impl IsConnection for SslStream<TcpStream> {
fn write_lowlevel<'s>(
&'s mut self,
bytes: &'s [u8],
) -> Pin<Box<dyn Future<Output = Result<usize, IoError>> + Send + Sync + 's>> {
Box::pin(self.write(bytes))
}
}
impl IsConnection for BufWriter<SslStream<TcpStream>> {
impl<T> IsConnection for T
where
T: AsyncReadExt + AsyncWriteExt + Unpin + Send + Sync,
{
fn write_lowlevel<'s>(
&'s mut self,
bytes: &'s [u8],
@ -120,12 +102,8 @@ impl Writable for Vec<u8> {
fn write<'s>(
self,
con: &'s mut impl IsConnection,
) -> Pin<Box<(dyn Future<Output = Result<(), Box<(dyn Error + 'static)>>> + Send + Sync + 's)>>
{
async fn write_bytes(
con: &mut impl IsConnection,
resp: Vec<u8>,
) -> Result<(), Box<dyn Error>> {
) -> Pin<Box<(dyn Future<Output = Result<(), IoError>> + Send + Sync + 's)>> {
async fn write_bytes(con: &mut impl IsConnection, resp: Vec<u8>) -> Result<(), IoError> {
con.write_lowlevel(&resp).await?;
Ok(())
}
@ -137,12 +115,8 @@ impl Writable for &'static [u8] {
fn write<'s>(
self,
con: &'s mut impl IsConnection,
) -> Pin<Box<(dyn Future<Output = Result<(), Box<(dyn Error + 'static)>>> + Send + Sync + 's)>>
{
async fn write_bytes(
con: &mut impl IsConnection,
resp: &[u8],
) -> Result<(), Box<dyn Error>> {
) -> Pin<Box<(dyn Future<Output = Result<(), IoError>> + Send + Sync + 's)>> {
async fn write_bytes(con: &mut impl IsConnection, resp: &[u8]) -> Result<(), IoError> {
con.write_lowlevel(&resp).await?;
Ok(())
}
@ -154,12 +128,8 @@ impl Writable for BytesWrapper {
fn write<'s>(
self,
con: &'s mut impl IsConnection,
) -> Pin<Box<(dyn Future<Output = Result<(), Box<(dyn Error + 'static)>>> + Send + Sync + 's)>>
{
async fn write_bytes(
con: &mut impl IsConnection,
bytes: Bytes,
) -> Result<(), Box<dyn Error>> {
) -> Pin<Box<(dyn Future<Output = Result<(), IoError>> + Send + Sync + 's)>> {
async fn write_bytes(con: &mut impl IsConnection, bytes: Bytes) -> Result<(), IoError> {
// First write a `+` character to the stream since this is a
// string (we represent `String`s as `Byte` objects internally)
// and since `Bytes` are effectively `String`s we will append the
@ -185,12 +155,8 @@ impl Writable for RespCodes {
fn write<'s>(
self,
con: &'s mut impl IsConnection,
) -> Pin<Box<(dyn Future<Output = Result<(), Box<(dyn Error + 'static)>>> + Send + Sync + 's)>>
{
async fn write_bytes(
con: &mut impl IsConnection,
code: RespCodes,
) -> Result<(), Box<dyn Error>> {
) -> Pin<Box<(dyn Future<Output = Result<(), IoError>> + Send + Sync + 's)>> {
async fn write_bytes(con: &mut impl IsConnection, code: RespCodes) -> Result<(), IoError> {
if let RespCodes::OtherError(Some(e)) = code {
// Since this is an other error which contains a description
// we'll write !<no_of_bytes> followed by the string
@ -232,12 +198,8 @@ impl Writable for GroupBegin {
fn write<'s>(
self,
con: &'s mut impl IsConnection,
) -> Pin<Box<(dyn Future<Output = Result<(), Box<(dyn Error + 'static)>>> + Send + Sync + 's)>>
{
async fn write_bytes(
con: &mut impl IsConnection,
size: usize,
) -> Result<(), Box<dyn Error>> {
) -> Pin<Box<(dyn Future<Output = Result<(), IoError>> + Send + Sync + 's)>> {
async fn write_bytes(con: &mut impl IsConnection, size: usize) -> Result<(), IoError> {
con.write_lowlevel(b"#2\n*1\n").await?;
// First write a `#` which indicates that the next bytes give the
// prefix length
@ -262,12 +224,8 @@ impl Writable for usize {
fn write<'s>(
self,
con: &'s mut impl IsConnection,
) -> Pin<Box<(dyn Future<Output = Result<(), Box<(dyn Error + 'static)>>> + Send + Sync + 's)>>
{
async fn write_bytes(
con: &mut impl IsConnection,
val: usize,
) -> Result<(), Box<dyn Error>> {
) -> Pin<Box<(dyn Future<Output = Result<(), IoError>> + Send + Sync + 's)>> {
async fn write_bytes(con: &mut impl IsConnection, val: usize) -> Result<(), IoError> {
con.write_lowlevel(b":").await?;
let usize_bytes = val.to_string().into_bytes();
let usize_bytes_len = usize_bytes.len().to_string().into_bytes();
@ -285,9 +243,8 @@ impl Writable for u64 {
fn write<'s>(
self,
con: &'s mut impl IsConnection,
) -> Pin<Box<(dyn Future<Output = Result<(), Box<(dyn Error + 'static)>>> + Send + Sync + 's)>>
{
async fn write_bytes(con: &mut impl IsConnection, val: u64) -> Result<(), Box<dyn Error>> {
) -> Pin<Box<(dyn Future<Output = Result<(), IoError>> + Send + Sync + 's)>> {
async fn write_bytes(con: &mut impl IsConnection, val: u64) -> Result<(), IoError> {
con.write_lowlevel(b":").await?;
let usize_bytes = val.to_string().into_bytes();
let usize_bytes_len = usize_bytes.len().to_string().into_bytes();

Loading…
Cancel
Save