Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
227 changes: 220 additions & 7 deletions src/client/conn.rs
Original file line number Diff line number Diff line change
@@ -1,23 +1,33 @@
#[allow(clippy::module_inception)]
mod conn;
mod connector;
mod http;
mod proxy;
mod tcp;
mod tls_info;
#[cfg(unix)]
mod uds;
mod verbose;

use std::{
fmt::{self, Debug, Formatter},
io,
io::IoSlice,
pin::Pin,
sync::{
Arc,
atomic::{AtomicBool, Ordering},
},
task::{Context, Poll},
};

use ::http::{Extensions, HeaderMap, HeaderValue};
use tokio::io::{AsyncRead, AsyncWrite};
use pin_project_lite::pin_project;
#[cfg(unix)]
use tokio::net::UnixStream;
use tokio::{
io::{AsyncRead, AsyncWrite, ReadBuf},
net::TcpStream,
};
use tokio_btls::SslStream;
use tower::{
BoxError,
util::{BoxCloneSyncService, BoxCloneSyncServiceLayer},
Expand All @@ -26,16 +36,21 @@ use tower::{
#[cfg(feature = "socks")]
pub(super) use self::proxy::socks;
pub(super) use self::{
conn::Conn,
connector::Connector,
http::{HttpInfo, TcpConnectOptions},
http::{HttpInfo, HttpTransport},
proxy::tunnel,
tcp::{SocketBindOptions, tokio::TokioTcpConnector},
tls_info::TlsInfoFactory,
};
use crate::{client::ConnectRequest, dns::DynResolver, proxy::matcher::Intercept};
use crate::{
client::ConnectRequest,
dns::DynResolver,
proxy::matcher::Intercept,
tls::{MaybeHttpsStream, TlsInfo},
};

/// HTTP connector with dynamic DNS resolver.
pub type HttpConnector = self::http::HttpConnector<DynResolver>;
pub type HttpConnector = self::http::HttpConnector<DynResolver, TokioTcpConnector>;

/// Boxed connector service for establishing connections.
pub type BoxedConnectorService = BoxCloneSyncService<Unnameable, Conn, BoxError>;
Expand Down Expand Up @@ -69,6 +84,31 @@ impl<T> AsyncConn for T where T: AsyncRead + AsyncWrite + Connection + Send + Sy

impl<T> AsyncConnWithInfo for T where T: AsyncConn + TlsInfoFactory {}

pin_project! {
/// Note: the `is_proxy` member means *is plain text HTTP proxy*.
/// This tells core whether the URI should be written in
/// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or
/// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise.
pub struct Conn {
tls_info: bool,
proxy: Option<Intercept>,
#[pin]
stream: Box<dyn AsyncConnWithInfo>,
}
}

pin_project! {
/// A wrapper around `SslStream` that adapts it for use as a generic async connection.
///
/// This type enables unified handling of plain TCP and TLS-encrypted streams by providing
/// implementations of `Connection`, `Read`, `Write`, and `TlsInfoFactory`.
/// It is mainly used internally to abstract over different connection types.
pub struct TlsConn<T> {
#[pin]
stream: SslStream<T>,
}
}

/// Describes a type returned by a connector.
pub trait Connection {
/// Return metadata describing the connection.
Expand Down Expand Up @@ -129,6 +169,179 @@ pub struct Connected {
poisoned: PoisonPill,
}

// ==== impl Conn ====

impl Connection for Conn {
fn connected(&self) -> Connected {
let mut connected = self.stream.connected();

if let Some(proxy) = &self.proxy {
connected = connected.proxy(proxy.clone());
}

if self.tls_info {
if let Some(tls_info) = self.stream.tls_info() {
connected.extra(tls_info)
} else {
connected
}
} else {
connected
}
}
}

impl AsyncRead for Conn {
#[inline]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
AsyncRead::poll_read(self.project().stream, cx, buf)
}
}

impl AsyncWrite for Conn {
#[inline]
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<Result<usize, io::Error>> {
AsyncWrite::poll_write(self.project().stream, cx, buf)
}

#[inline]
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
AsyncWrite::poll_write_vectored(self.project().stream, cx, bufs)
}

#[inline]
fn is_write_vectored(&self) -> bool {
self.stream.is_write_vectored()
}

#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
AsyncWrite::poll_flush(self.project().stream, cx)
}

#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), io::Error>> {
AsyncWrite::poll_shutdown(self.project().stream, cx)
}
}

// ===== impl TcpStream =====

impl Connection for TlsConn<TcpStream> {
fn connected(&self) -> Connected {
let connected = self.stream.get_ref().connected();
if self.stream.ssl().selected_alpn_protocol() == Some(b"h2") {
connected.negotiated_h2()
} else {
connected
}
}
}

impl Connection for TlsConn<MaybeHttpsStream<TcpStream>> {
fn connected(&self) -> Connected {
let connected = self.stream.get_ref().connected();
if self.stream.ssl().selected_alpn_protocol() == Some(b"h2") {
connected.negotiated_h2()
} else {
connected
}
}
}

// ===== impl UnixStream =====

#[cfg(unix)]
impl Connection for TlsConn<UnixStream> {
fn connected(&self) -> Connected {
let connected = self.stream.get_ref().connected();
if self.stream.ssl().selected_alpn_protocol() == Some(b"h2") {
connected.negotiated_h2()
} else {
connected
}
}
}

#[cfg(unix)]
impl Connection for TlsConn<MaybeHttpsStream<UnixStream>> {
fn connected(&self) -> Connected {
let connected = self.stream.get_ref().connected();
if self.stream.ssl().selected_alpn_protocol() == Some(b"h2") {
connected.negotiated_h2()
} else {
connected
}
}
}

impl<T: AsyncRead + AsyncWrite + Unpin> AsyncRead for TlsConn<T> {
#[inline]
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &mut ReadBuf<'_>,
) -> Poll<tokio::io::Result<()>> {
AsyncRead::poll_read(self.project().stream, cx, buf)
}
}

impl<T: AsyncRead + AsyncWrite + Unpin> AsyncWrite for TlsConn<T> {
#[inline]
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context,
buf: &[u8],
) -> Poll<Result<usize, tokio::io::Error>> {
AsyncWrite::poll_write(self.project().stream, cx, buf)
}

#[inline]
fn poll_write_vectored(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
AsyncWrite::poll_write_vectored(self.project().stream, cx, bufs)
}

#[inline]
fn is_write_vectored(&self) -> bool {
self.stream.is_write_vectored()
}

#[inline]
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> {
AsyncWrite::poll_flush(self.project().stream, cx)
}

#[inline]
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), tokio::io::Error>> {
AsyncWrite::poll_shutdown(self.project().stream, cx)
}
}

impl<T> TlsInfoFactory for TlsConn<T>
where
SslStream<T>: TlsInfoFactory,
{
fn tls_info(&self) -> Option<TlsInfo> {
self.stream.tls_info()
}
}

// ===== impl PoisonPill =====

impl fmt::Debug for PoisonPill {
Expand Down
Loading