|
|
|
@ -30,15 +30,15 @@ use crate::resp::Writable;
|
|
|
|
|
use crate::CoreDB;
|
|
|
|
|
use bytes::Buf;
|
|
|
|
|
use bytes::BytesMut;
|
|
|
|
|
use futures::future;
|
|
|
|
|
use libtdb::TResult;
|
|
|
|
|
use libtdb::BUF_CAP;
|
|
|
|
|
use openssl::ssl::{SslAcceptor, SslFiletype, SslMethod};
|
|
|
|
|
use openssl::ssl::{Ssl, SslAcceptor, SslFiletype, SslMethod};
|
|
|
|
|
use std::net::SocketAddr;
|
|
|
|
|
use std::pin::Pin;
|
|
|
|
|
use std::sync::Arc;
|
|
|
|
|
use tokio::io::BufWriter;
|
|
|
|
|
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
|
|
|
use tokio::net::TcpListener;
|
|
|
|
|
use tokio::net::TcpStream;
|
|
|
|
|
use tokio::io::{AsyncReadExt, AsyncWrite, AsyncWriteExt};
|
|
|
|
|
use tokio::net::{TcpListener, TcpStream};
|
|
|
|
|
use tokio::sync::Semaphore;
|
|
|
|
|
use tokio::sync::{broadcast, mpsc};
|
|
|
|
|
use tokio::time::{self, Duration};
|
|
|
|
@ -57,7 +57,7 @@ pub struct SslListener {
|
|
|
|
|
// We send a clone of `terminate_tx` to each `CHandler`
|
|
|
|
|
pub terminate_tx: mpsc::Sender<()>,
|
|
|
|
|
pub terminate_rx: mpsc::Receiver<()>,
|
|
|
|
|
acceptor: Arc<SslAcceptor>,
|
|
|
|
|
acceptor: SslAcceptor,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl SslListener {
|
|
|
|
@ -71,10 +71,11 @@ impl SslListener {
|
|
|
|
|
terminate_tx: mpsc::Sender<()>,
|
|
|
|
|
terminate_rx: mpsc::Receiver<()>,
|
|
|
|
|
) -> TResult<Self> {
|
|
|
|
|
log::debug!("New SSL/TLS connection registered");
|
|
|
|
|
let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls())?;
|
|
|
|
|
acceptor.set_private_key_file(key_file, SslFiletype::PEM)?;
|
|
|
|
|
acceptor.set_certificate_chain_file(chain_file)?;
|
|
|
|
|
let acceptor = Arc::new(acceptor.build());
|
|
|
|
|
let acceptor = acceptor.build();
|
|
|
|
|
Ok(SslListener {
|
|
|
|
|
db,
|
|
|
|
|
listener,
|
|
|
|
@ -86,19 +87,23 @@ impl SslListener {
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
async fn accept(&mut self) -> TResult<SslStream<TcpStream>> {
|
|
|
|
|
println!("Received connection");
|
|
|
|
|
log::debug!("Trying to accept a SSL connection");
|
|
|
|
|
let mut backoff = 1;
|
|
|
|
|
loop {
|
|
|
|
|
match self.listener.accept().await {
|
|
|
|
|
// We don't need the bindaddr
|
|
|
|
|
// We get the encrypted stream which we need to decrypt
|
|
|
|
|
// by using the acceptor
|
|
|
|
|
Ok((encrypted_stream, _)) => {
|
|
|
|
|
let decrypted_stream =
|
|
|
|
|
tokio_openssl::accept(&self.acceptor, encrypted_stream).await?;
|
|
|
|
|
return Ok(decrypted_stream);
|
|
|
|
|
Ok((stream, _)) => {
|
|
|
|
|
log::debug!("Accepted an SSL/TLS connection");
|
|
|
|
|
let ssl = Ssl::new(self.acceptor.context())?;
|
|
|
|
|
let mut stream = SslStream::new(ssl, stream)?;
|
|
|
|
|
Pin::new(&mut stream).accept().await?;
|
|
|
|
|
log::debug!("Connected to secure socket over TCP");
|
|
|
|
|
return Ok(stream);
|
|
|
|
|
}
|
|
|
|
|
Err(e) => {
|
|
|
|
|
log::debug!("Failed to establish a secure connection");
|
|
|
|
|
if backoff > 64 {
|
|
|
|
|
// Too many retries, goodbye user
|
|
|
|
|
return Err(e.into());
|
|
|
|
@ -112,10 +117,11 @@ impl SslListener {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
pub async fn run(&mut self) -> TResult<()> {
|
|
|
|
|
log::debug!("Started secure server");
|
|
|
|
|
loop {
|
|
|
|
|
// Take the permit first, but we won't use it right now
|
|
|
|
|
// that's why we will forget it
|
|
|
|
|
self.climit.acquire().await.forget();
|
|
|
|
|
self.climit.acquire().await.unwrap().forget();
|
|
|
|
|
let stream = self.accept().await?;
|
|
|
|
|
let mut sslhandle = SslConnectionHandler {
|
|
|
|
|
db: self.db.clone(),
|
|
|
|
@ -125,6 +131,7 @@ impl SslListener {
|
|
|
|
|
_term_sig_tx: self.terminate_tx.clone(),
|
|
|
|
|
};
|
|
|
|
|
tokio::spawn(async move {
|
|
|
|
|
log::debug!("Spawned listener task");
|
|
|
|
|
if let Err(e) = sslhandle.run().await {
|
|
|
|
|
eprintln!("Error: {}", e);
|
|
|
|
|
}
|
|
|
|
@ -143,6 +150,7 @@ pub struct SslConnectionHandler {
|
|
|
|
|
|
|
|
|
|
impl SslConnectionHandler {
|
|
|
|
|
pub async fn run(&mut self) -> TResult<()> {
|
|
|
|
|
log::debug!("SslConnectionHanler initialized to handle a remote client");
|
|
|
|
|
while !self.terminator.is_termination_signal() {
|
|
|
|
|
let try_df = tokio::select! {
|
|
|
|
|
tdf = self.con.read_query() => tdf,
|
|
|
|
@ -174,19 +182,19 @@ impl Drop for SslConnectionHandler {
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
pub struct SslConnection {
|
|
|
|
|
stream: BufWriter<SslStream<TcpStream>>,
|
|
|
|
|
stream: SslStream<TcpStream>,
|
|
|
|
|
buffer: BytesMut,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
impl SslConnection {
|
|
|
|
|
pub fn new(stream: SslStream<TcpStream>) -> Self {
|
|
|
|
|
SslConnection {
|
|
|
|
|
stream: BufWriter::new(stream),
|
|
|
|
|
stream: stream,
|
|
|
|
|
buffer: BytesMut::with_capacity(BUF_CAP),
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
async fn read_again(&mut self) -> Result<(), String> {
|
|
|
|
|
match self.stream.read_buf(&mut self.buffer).await {
|
|
|
|
|
match self.stream.get_mut().read_buf(&mut self.buffer).await {
|
|
|
|
|
Ok(0) => {
|
|
|
|
|
// If 0 bytes were received, then the remote end closed
|
|
|
|
|
// the connection
|
|
|
|
@ -209,7 +217,7 @@ impl SslConnection {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
fn get_peer(&self) -> IoResult<SocketAddr> {
|
|
|
|
|
self.stream.get_ref().get_ref().peer_addr()
|
|
|
|
|
self.stream.get_ref().peer_addr()
|
|
|
|
|
}
|
|
|
|
|
/// Try to parse a query from the buffered data
|
|
|
|
|
fn try_query(&mut self) -> Result<ParseResult, ()> {
|
|
|
|
@ -244,14 +252,14 @@ impl SslConnection {
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
pub async fn flush_stream(&mut self) -> TResult<()> {
|
|
|
|
|
self.stream.flush().await?;
|
|
|
|
|
self.stream.get_mut().flush().await?;
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
/// Wraps around the `write_response` used to differentiate between a
|
|
|
|
|
/// success response and an error response
|
|
|
|
|
pub async fn close_conn_with_error(&mut self, resp: Vec<u8>) -> TResult<()> {
|
|
|
|
|
self.write_response(resp).await?;
|
|
|
|
|
self.stream.flush().await?;
|
|
|
|
|
self.stream.get_mut().flush().await?;
|
|
|
|
|
Ok(())
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|