diff --git a/src/client/conn.rs b/src/client/conn.rs index 3524720e..e55927c0 100644 --- a/src/client/conn.rs +++ b/src/client/conn.rs @@ -1,8 +1,7 @@ -#[allow(clippy::module_inception)] -mod conn; mod connector; mod http; mod proxy; +mod tcp; mod tls_info; #[cfg(unix)] mod uds; @@ -10,14 +9,20 @@ mod verbose; use std::{ fmt::{self, Debug, Formatter}, + io, + io::IoSlice, + pin::Pin, sync::{ Arc, atomic::{AtomicBool, Ordering}, }, + task::{Context, Poll}, }; use ::http::{Extensions, HeaderMap, HeaderValue}; -use tokio::io::{AsyncRead, AsyncWrite}; +use pin_project_lite::pin_project; +use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; +use tokio_btls::SslStream; use tower::{ BoxError, util::{BoxCloneSyncService, BoxCloneSyncServiceLayer}, @@ -26,16 +31,16 @@ use tower::{ #[cfg(feature = "socks")] pub(super) use self::proxy::socks; pub(super) use self::{ - conn::Conn, connector::Connector, - http::{HttpInfo, TcpConnectOptions}, + http::{HttpInfo, HttpTransport}, proxy::tunnel, + tcp::{SocketBindOptions, tokio::TokioTcpConnector}, tls_info::TlsInfoFactory, }; -use crate::{client::ConnectRequest, dns::DynResolver, proxy::matcher::Intercept}; +use crate::{client::ConnectRequest, dns::DynResolver, proxy::matcher::Intercept, tls::TlsInfo}; /// HTTP connector with dynamic DNS resolver. -pub type HttpConnector = self::http::HttpConnector; +pub type HttpConnector = self::http::HttpConnector; /// Boxed connector service for establishing connections. pub type BoxedConnectorService = BoxCloneSyncService; @@ -69,6 +74,31 @@ impl AsyncConn for T where T: AsyncRead + AsyncWrite + Connection + Send + Sy impl AsyncConnWithInfo for T where T: AsyncConn + TlsInfoFactory {} +pin_project! { + /// Note: the `is_proxy` member means *is plain text HTTP proxy*. + /// This tells core whether the URI should be written in + /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or + /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise. + pub struct Conn { + tls_info: bool, + proxy: Option, + #[pin] + stream: Box, + } +} + +pin_project! { + /// A wrapper around `SslStream` that adapts it for use as a generic async connection. + /// + /// This type enables unified handling of plain TCP and TLS-encrypted streams by providing + /// implementations of `Connection`, `Read`, `Write`, and `TlsInfoFactory`. + /// It is mainly used internally to abstract over different connection types. + pub struct TlsConn { + #[pin] + stream: SslStream, + } +} + /// Describes a type returned by a connector. pub trait Connection { /// Return metadata describing the connection. @@ -129,6 +159,145 @@ pub struct Connected { poisoned: PoisonPill, } +// ==== impl Conn ==== + +impl Connection for Conn { + fn connected(&self) -> Connected { + let mut connected = self.stream.connected(); + + if let Some(proxy) = &self.proxy { + connected = connected.proxy(proxy.clone()); + } + + if self.tls_info { + if let Some(tls_info) = self.stream.tls_info() { + connected.extra(tls_info) + } else { + connected + } + } else { + connected + } + } +} + +impl AsyncRead for Conn { + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + AsyncRead::poll_read(self.project().stream, cx, buf) + } +} + +impl AsyncWrite for Conn { + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + AsyncWrite::poll_write(self.project().stream, cx, buf) + } + + #[inline] + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + AsyncWrite::poll_write_vectored(self.project().stream, cx, bufs) + } + + #[inline] + fn is_write_vectored(&self) -> bool { + self.stream.is_write_vectored() + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + AsyncWrite::poll_flush(self.project().stream, cx) + } + + #[inline] + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + AsyncWrite::poll_shutdown(self.project().stream, cx) + } +} + +// ===== impl TlsConn ===== + +impl Connection for TlsConn +where + T: Connection, +{ + fn connected(&self) -> Connected { + let connected = self.stream.get_ref().connected(); + if self.stream.ssl().selected_alpn_protocol() == Some(b"h2") { + connected.negotiated_h2() + } else { + connected + } + } +} + +impl AsyncRead for TlsConn { + #[inline] + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + AsyncRead::poll_read(self.project().stream, cx, buf) + } +} + +impl AsyncWrite for TlsConn { + #[inline] + fn poll_write( + self: Pin<&mut Self>, + cx: &mut Context, + buf: &[u8], + ) -> Poll> { + AsyncWrite::poll_write(self.project().stream, cx, buf) + } + + #[inline] + fn poll_write_vectored( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + bufs: &[IoSlice<'_>], + ) -> Poll> { + AsyncWrite::poll_write_vectored(self.project().stream, cx, bufs) + } + + #[inline] + fn is_write_vectored(&self) -> bool { + self.stream.is_write_vectored() + } + + #[inline] + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + AsyncWrite::poll_flush(self.project().stream, cx) + } + + #[inline] + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + AsyncWrite::poll_shutdown(self.project().stream, cx) + } +} + +impl TlsInfoFactory for TlsConn +where + SslStream: TlsInfoFactory, +{ + fn tls_info(&self) -> Option { + self.stream.tls_info() + } +} + // ===== impl PoisonPill ===== impl fmt::Debug for PoisonPill { diff --git a/src/client/conn/conn.rs b/src/client/conn/conn.rs deleted file mode 100644 index 3eeda84b..00000000 --- a/src/client/conn/conn.rs +++ /dev/null @@ -1,231 +0,0 @@ -use std::{ - io::{self, IoSlice}, - pin::Pin, - task::{Context, Poll}, -}; - -use pin_project_lite::pin_project; -#[cfg(unix)] -use tokio::net::UnixStream; -use tokio::{ - io::{AsyncRead, AsyncWrite, ReadBuf}, - net::TcpStream, -}; -use tokio_btls::SslStream; - -use super::{AsyncConnWithInfo, Connected, Connection, TlsInfoFactory}; -use crate::{ - proxy::matcher::Intercept, - tls::{MaybeHttpsStream, TlsInfo}, -}; - -pin_project! { - /// Note: the `is_proxy` member means *is plain text HTTP proxy*. - /// This tells core whether the URI should be written in - /// * origin-form (`GET /just/a/path HTTP/1.1`), when `is_proxy == false`, or - /// * absolute-form (`GET http://foo.bar/and/a/path HTTP/1.1`), otherwise. - pub struct Conn { - #[pin] - pub(super) inner: Box, - pub(super) tls_info: bool, - pub(super) proxy: Option, - } -} - -pin_project! { - /// A wrapper around `SslStream` that adapts it for use as a generic async connection. - /// - /// This type enables unified handling of plain TCP and TLS-encrypted streams by providing - /// implementations of `Connection`, `Read`, `Write`, and `TlsInfoFactory`. - /// It is mainly used internally to abstract over different connection types. - pub struct TlsConn { - #[pin] - inner: SslStream, - } -} - -// ==== impl Conn ==== - -impl Connection for Conn { - fn connected(&self) -> Connected { - let mut connected = self.inner.connected(); - - if let Some(proxy) = &self.proxy { - connected = connected.proxy(proxy.clone()); - } - - if self.tls_info { - if let Some(tls_info) = self.inner.tls_info() { - connected.extra(tls_info) - } else { - connected - } - } else { - connected - } - } -} - -impl AsyncRead for Conn { - #[inline] - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - AsyncRead::poll_read(self.project().inner, cx, buf) - } -} - -impl AsyncWrite for Conn { - #[inline] - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &[u8], - ) -> Poll> { - AsyncWrite::poll_write(self.project().inner, cx, buf) - } - - #[inline] - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) - } - - #[inline] - fn is_write_vectored(&self) -> bool { - self.inner.is_write_vectored() - } - - #[inline] - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - AsyncWrite::poll_flush(self.project().inner, cx) - } - - #[inline] - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - AsyncWrite::poll_shutdown(self.project().inner, cx) - } -} - -// ==== impl TlsConn ==== - -impl TlsConn -where - T: AsyncRead + AsyncWrite + Unpin, -{ - /// Creates a new `TlsConn` wrapping the provided `SslStream`. - #[inline(always)] - pub fn new(inner: SslStream) -> Self { - Self { inner } - } -} - -// ===== impl TcpStream ===== - -impl Connection for TlsConn { - fn connected(&self) -> Connected { - let connected = self.inner.get_ref().connected(); - if self.inner.ssl().selected_alpn_protocol() == Some(b"h2") { - connected.negotiated_h2() - } else { - connected - } - } -} - -impl Connection for TlsConn> { - fn connected(&self) -> Connected { - let connected = self.inner.get_ref().connected(); - if self.inner.ssl().selected_alpn_protocol() == Some(b"h2") { - connected.negotiated_h2() - } else { - connected - } - } -} - -// ===== impl UnixStream ===== - -#[cfg(unix)] -impl Connection for TlsConn { - fn connected(&self) -> Connected { - let connected = self.inner.get_ref().connected(); - if self.inner.ssl().selected_alpn_protocol() == Some(b"h2") { - connected.negotiated_h2() - } else { - connected - } - } -} - -#[cfg(unix)] -impl Connection for TlsConn> { - fn connected(&self) -> Connected { - let connected = self.inner.get_ref().connected(); - if self.inner.ssl().selected_alpn_protocol() == Some(b"h2") { - connected.negotiated_h2() - } else { - connected - } - } -} - -impl AsyncRead for TlsConn { - #[inline] - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut ReadBuf<'_>, - ) -> Poll> { - AsyncRead::poll_read(self.project().inner, cx, buf) - } -} - -impl AsyncWrite for TlsConn { - #[inline] - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &[u8], - ) -> Poll> { - AsyncWrite::poll_write(self.project().inner, cx, buf) - } - - #[inline] - fn poll_write_vectored( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - bufs: &[IoSlice<'_>], - ) -> Poll> { - AsyncWrite::poll_write_vectored(self.project().inner, cx, bufs) - } - - #[inline] - fn is_write_vectored(&self) -> bool { - self.inner.is_write_vectored() - } - - #[inline] - fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - AsyncWrite::poll_flush(self.project().inner, cx) - } - - #[inline] - fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - AsyncWrite::poll_shutdown(self.project().inner, cx) - } -} - -impl TlsInfoFactory for TlsConn -where - SslStream: TlsInfoFactory, -{ - fn tls_info(&self) -> Option { - self.inner.tls_info() - } -} diff --git a/src/client/conn/connector.rs b/src/client/conn/connector.rs index d73faeb0..76204394 100644 --- a/src/client/conn/connector.rs +++ b/src/client/conn/connector.rs @@ -1,4 +1,5 @@ use std::{ + borrow::Cow, future::Future, pin::Pin, sync::Arc, @@ -6,8 +7,6 @@ use std::{ time::Duration, }; -use http::Uri; -use tokio::io::{AsyncRead, AsyncWrite}; use tokio_btls::SslStream; use tower::{ Service, ServiceBuilder, ServiceExt, @@ -18,14 +17,14 @@ use tower::{ #[cfg(unix)] use super::uds::UnixConnector; use super::{ - AsyncConnWithInfo, BoxedConnectorLayer, BoxedConnectorService, Connection, HttpConnector, - TlsInfoFactory, Unnameable, - conn::{Conn, TlsConn}, - proxy, - verbose::Verbose, + AsyncConnWithInfo, BoxedConnectorLayer, BoxedConnectorService, Conn, Connection, HttpConnector, + TlsConn, TlsInfoFactory, Unnameable, http::HttpTransport, proxy, verbose::Verbose, }; use crate::{ - client::http::{ConnectExtra, ConnectRequest}, + client::{ + conn::TokioTcpConnector, + http::{ConnectExtra, ConnectRequest}, + }, dns::DynResolver, error::{BoxError, ProxyConnect, TimedOut, map_timeout_to_connector_error}, ext::UriExt, @@ -43,7 +42,7 @@ type Connecting = Pin> + Send>>; struct Config { proxies: Arc>, verbose: Verbose, - tcp_nodelay: bool, + nodelay: bool, tls_info: bool, /// When there is a single timeout layer and no other layers, /// we embed it directly inside our base Service::call(). @@ -58,8 +57,7 @@ pub struct ConnectorBuilder { #[cfg(feature = "socks")] resolver: DynResolver, http: HttpConnector, - tls_options: TlsOptions, - tls_builder: TlsConnectorBuilder, + builder: TlsConnectorBuilder, } /// Connector service that establishes connections. @@ -75,9 +73,9 @@ pub struct ConnectorService { config: Config, #[cfg(feature = "socks")] resolver: DynResolver, - http: HttpConnector, tls: TlsConnector, - tls_builder: Arc, + http: HttpConnector, + builder: Arc, } // ===== impl ConnectorBuilder ===== @@ -99,7 +97,7 @@ impl ConnectorBuilder { where F: FnOnce(TlsConnectorBuilder) -> TlsConnectorBuilder, { - self.tls_builder = call(self.tls_builder); + self.builder = call(self.builder); self } @@ -127,31 +125,28 @@ impl ConnectorBuilder { self } - /// Sets the TLS options to use. - #[inline] - pub fn tls_options(mut self, opts: Option) -> ConnectorBuilder { - if let Some(opts) = opts { - self.tls_options = opts; - } - self - } - /// Sets the TCP_NODELAY option for connections. #[inline] pub fn tcp_nodelay(mut self, enabled: bool) -> ConnectorBuilder { - self.config.tcp_nodelay = enabled; + self.config.nodelay = enabled; self } /// Build a [`Connector`] with the provided layers. - pub fn build(self, layers: Vec) -> crate::Result { + pub fn build( + self, + tls_options: Option, + layers: Vec, + ) -> crate::Result { let mut service = ConnectorService { config: self.config, #[cfg(feature = "socks")] resolver: self.resolver.clone(), http: self.http, - tls: self.tls_builder.build(&self.tls_options)?, - tls_builder: Arc::new(self.tls_builder), + tls: self + .builder + .build(tls_options.map(Cow::Owned).unwrap_or_default())?, + builder: Arc::new(self.builder), }; // we have no user-provided layers, only use concrete types @@ -210,15 +205,14 @@ impl Connector { config: Config { proxies: Arc::new(proxies), verbose: Verbose::OFF, - tcp_nodelay: true, + nodelay: true, tls_info: false, timeout: None, }, #[cfg(feature = "socks")] resolver: resolver.clone(), - http: HttpConnector::new_with_resolver(resolver), - tls_options: TlsOptions::default(), - tls_builder: TlsConnector::builder(), + http: HttpConnector::new(resolver, TokioTcpConnector::new()), + builder: TlsConnector::builder(), } } } @@ -258,37 +252,39 @@ impl ConnectorService { // Disable Nagle's algorithm for TLS handshake // // https://www.openssl.org/docs/man1.1.1/man3/SSL_connect.html#NOTES - if https && !self.config.tcp_nodelay { + if https && !self.config.nodelay { http.set_nodelay(true); } // Apply TCP options if provided in metadata if let Some(opts) = extra.tcp_options() { - http.set_connect_options(opts.clone()); - } + #[cfg(any( + target_os = "android", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "solaris", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos", + ))] + if let Some(interface) = &opts.interface { + http.set_interface(interface.clone()); + } - self.build_tls_connector_generic(http, extra) - } + http.set_local_addresses(opts.local_address_ipv4, opts.local_address_ipv6); + } - fn build_tls_connector_generic( - &self, - connector: S, - extra: &ConnectExtra, - ) -> Result, BoxError> - where - S: Service + Send, - S::Error: Into, - S::Future: Unpin + Send + 'static, - T: AsyncRead + AsyncWrite + Connection + Unpin + std::fmt::Debug + Sync + Send + 'static, - { // Prefer TLS options from metadata, fallback to default let tls = extra .tls_options() - .map(|opts| self.tls_builder.build(opts)) + .map(|opts| self.builder.build(Cow::Borrowed(opts))) .transpose()? .unwrap_or_else(|| self.tls.clone()); - Ok(HttpsConnector::with_connector(connector, tls)) + Ok(HttpsConnector::new(http, tls)) } fn tunnel_conn_from_stream(&self, io: MaybeHttpsStream) -> Result @@ -298,13 +294,13 @@ impl ConnectorService { SslStream: TlsInfoFactory, { let conn = match io { - MaybeHttpsStream::Http(inner) => Conn { - inner: self.config.verbose.wrap(inner), + MaybeHttpsStream::Http(stream) => Conn { + stream: self.config.verbose.wrap(stream), tls_info: false, proxy: None, }, - MaybeHttpsStream::Https(inner) => Conn { - inner: self.config.verbose.wrap(TlsConn::new(inner)), + MaybeHttpsStream::Https(stream) => Conn { + stream: self.config.verbose.wrap(TlsConn { stream }), tls_info: self.config.tls_info, proxy: None, }, @@ -321,19 +317,17 @@ impl ConnectorService { P: Into>, { let conn = match io { - MaybeHttpsStream::Http(inner) => self.config.verbose.wrap(inner), - MaybeHttpsStream::Https(inner) => self.config.verbose.wrap(TlsConn::new(inner)), + MaybeHttpsStream::Http(stream) => self.config.verbose.wrap(stream), + MaybeHttpsStream::Https(stream) => self.config.verbose.wrap(TlsConn { stream }), }; Ok(Conn { - inner: conn, + stream: conn, tls_info: self.config.tls_info, proxy: proxy.into(), }) } -} -impl ConnectorService { async fn connect_auto_proxy>>( self, req: ConnectRequest, @@ -354,7 +348,7 @@ impl ConnectorService { let io = connector.call(req).await?; // Re-enable Nagle's algorithm if it was disabled earlier - if is_https && !self.config.tcp_nodelay { + if is_https && !self.config.nodelay { io.as_ref().set_nodelay(false)?; } @@ -407,7 +401,7 @@ impl ConnectorService { let io = connector.call(EstablishedConn::new(conn, req)).await?; // Re-enable Nagle's algorithm if it was disabled earlier - if is_https && !self.config.tcp_nodelay { + if is_https && !self.config.nodelay { io.as_ref().set_nodelay(false)?; } @@ -444,7 +438,7 @@ impl ConnectorService { let io = connector.call(EstablishedConn::new(tunneled, req)).await?; // Re-enable Nagle's algorithm if it was disabled earlier - if !self.config.tcp_nodelay { + if !self.config.nodelay { io.as_ref().as_ref().set_nodelay(false)?; } @@ -463,13 +457,13 @@ impl ConnectorService { // Create a Unix connector with the specified socket path. let mut connector = - self.build_tls_connector_generic(UnixConnector(unix_socket), req.extra())?; + HttpsConnector::new(UnixConnector::new(unix_socket), self.tls.clone()); // If the target URI is HTTPS, establish a CONNECT tunnel over the Unix socket, // then upgrade the tunneled stream to TLS. if uri.is_https() { // Use a dummy HTTP URI so the HTTPS connector works over the Unix socket. - let proxy_uri = Uri::from_static("http://localhost"); + let proxy_uri = http::Uri::from_static("http://localhost"); // The tunnel connector will first establish a CONNECT tunnel, // then perform the TLS handshake over the tunneled stream. diff --git a/src/client/conn/http.rs b/src/client/conn/http.rs index f496f5fd..2cc66de3 100644 --- a/src/client/conn/http.rs +++ b/src/client/conn/http.rs @@ -1,30 +1,108 @@ use std::{ error::Error as StdError, - fmt, future::Future, - io, marker::PhantomData, - net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + net::{Ipv4Addr, Ipv6Addr, SocketAddr}, pin::Pin, sync::Arc, task::{self, Poll}, time::Duration, }; -use futures_util::future::Either; use http::uri::{Scheme, Uri}; use pin_project_lite::pin_project; -use socket2::TcpKeepalive; -use tokio::{ - net::{TcpSocket, TcpStream}, - time::Sleep, +use tokio::io::{AsyncRead, AsyncWrite}; +use tower::Service; + +use super::{ + Connection, + tcp::{ + ConnectError, ConnectingTcp, SocketBindOptions, TcpConnector, TcpKeepaliveOptions, + TcpOptions, + }, }; +use crate::dns::{self, InternalResolve}; -use super::{Connected, Connection}; -use crate::{ - dns::{self, GaiResolver, InternalResolve, resolve}, - error::BoxError, -}; +static INVALID_NOT_HTTP: &str = "invalid URI, scheme is not http"; +static INVALID_MISSING_SCHEME: &str = "invalid URI, scheme is missing"; +static INVALID_MISSING_HOST: &str = "invalid URI, host is missing"; + +type ConnectResult = Result<::Connection, ConnectError>; +type BoxConnecting = Pin> + Send>>; + +/// A trait for configuring HTTP transport options on a [`Service`] connector. +/// +/// Provides methods to adjust TCP/socket-level settings such as keepalive, +/// timeouts, buffer sizes, and local address binding. [`HttpConnector`] +/// is the default implementation. +pub trait HttpTransport: Service + Clone + Send + Sized + 'static +where + Self::Response: AsyncRead + AsyncWrite + Connection + Unpin + Send + 'static, + Self::Error: Into>, + Self::Future: Unpin + Send + 'static, +{ + /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration + /// to remain idle before sending TCP keepalive probes. + fn enforce_http(&mut self, enforced: bool); + + /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`. + fn set_nodelay(&mut self, nodelay: bool); + + /// Sets the value of the `SO_SNDBUF` option on the socket. + fn set_send_buffer_size(&mut self, size: Option); + + /// Sets the value of the `SO_RCVBUF` option on the socket. + fn set_recv_buffer_size(&mut self, size: Option); + + /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`. + fn set_reuse_address(&mut self, reuse: bool); + + /// Sets the value of the `TCP_USER_TIMEOUT` option on the socket. + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + fn set_tcp_user_timeout(&mut self, time: Option); + + /// Set the connect timeout. + fn set_connect_timeout(&mut self, dur: Option); + + /// Set timeout for [RFC 6555 (Happy Eyeballs)][RFC 6555] algorithm. + /// + /// [RFC 6555]: https://tools.ietf.org/html/rfc6555 + fn set_happy_eyeballs_timeout(&mut self, dur: Option); + + /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration + /// to remain idle before sending TCP keepalive probes. + fn set_keepalive(&mut self, time: Option); + + /// Set the duration between two successive TCP keepalive retransmissions, + /// if acknowledgement to the previous keepalive transmission is not received. + fn set_keepalive_interval(&mut self, interval: Option); + + /// Set the number of retransmissions to be carried out before declaring that remote end is not + /// available. + fn set_keepalive_retries(&mut self, retries: Option); + + /// Sets the name of the interface to bind sockets produced. + #[cfg(any( + target_os = "android", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "solaris", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos", + ))] + fn set_interface>>(&mut self, interface: I); + + /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's + /// preferences) before connection. + fn set_local_addresses(&mut self, local_ipv4: V4, local_ipv6: V6) + where + V4: Into>, + V6: Into>; +} /// A connector for the `http` scheme. /// @@ -35,9 +113,10 @@ use crate::{ /// Sets the [`HttpInfo`] value on responses, which includes /// transport information such as the remote socket address used. #[derive(Clone)] -pub struct HttpConnector { - config: Arc, +pub struct HttpConnector { + options: Arc, resolver: R, + connector: S, } /// Extra information about the transport when an HttpConnector is used. @@ -62,381 +141,89 @@ pub struct HttpConnector { /// connector to see what "extra" information it might provide to responses. #[derive(Clone, Debug)] pub struct HttpInfo { - remote_addr: SocketAddr, - local_addr: SocketAddr, -} - -/// Options for configuring a TCP network connection. -/// -/// `TcpConnectOptions` allows fine-grained control over how TCP sockets -/// are created and connected. It can be used to: -/// -/// - Bind a socket to a specific **network interface** -/// - Bind to a **local IPv4 or IPv6 address** -/// -/// This is especially useful for scenarios involving: -/// - Virtual routing tables (e.g. Linux VRFs) -/// - Multiple NICs (network interface cards) -/// - Explicit source IP routing or firewall rules -/// -/// Platform-specific behavior is handled internally, with the interface binding -/// mechanism differing across Unix-like systems. -/// -/// # Platform Notes -/// -/// ## Interface binding (`set_interface`) -/// -/// - **Linux / Android / Fuchsia**: uses the `SO_BINDTODEVICE` socket option See [`man 7 socket`](https://man7.org/linux/man-pages/man7/socket.7.html) -/// -/// - **macOS / iOS / tvOS / watchOS / visionOS / illumos / Solaris**: uses the `IP_BOUND_IF` socket -/// option See [`man 7p ip`](https://docs.oracle.com/cd/E86824_01/html/E54777/ip-7p.html) -/// -/// Binding to an interface ensures that: -/// - **Outgoing packets** are sent through the specified interface -/// - **Incoming packets** are only accepted if received via that interface -/// -/// ❗ This only applies to certain socket types (e.g. `AF_INET`), and may require -/// elevated permissions (e.g. `CAP_NET_RAW` on Linux). -#[derive(Debug, Clone, Hash, PartialEq, Eq, Default)] -#[non_exhaustive] -pub struct TcpConnectOptions { - #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] - interface: Option>, - #[cfg(any( - target_os = "illumos", - target_os = "ios", - target_os = "macos", - target_os = "solaris", - target_os = "tvos", - target_os = "visionos", - target_os = "watchos", - ))] - interface: Option, - local_ipv4: Option, - local_ipv6: Option, -} - -impl TcpConnectOptions { - /// Sets the name of the network interface to bind the socket to. - /// - /// ## Platform behavior - /// - On Linux/Fuchsia/Android: sets `SO_BINDTODEVICE` - /// - On macOS/illumos/Solaris/iOS/etc.: sets `IP_BOUND_IF` - /// - /// If `interface` is `None`, the socket will not be explicitly bound to any device. - /// - /// # Errors - /// - /// On platforms that require a `CString` (e.g. macOS), this will return an error if the - /// interface name contains an internal null byte (`\0`), which is invalid in C strings. - /// - /// # See Also - /// - [VRF documentation](https://www.kernel.org/doc/Documentation/networking/vrf.txt) - /// - [`man 7 socket`](https://man7.org/linux/man-pages/man7/socket.7.html) - /// - [`man 7p ip`](https://docs.oracle.com/cd/E86824_01/html/E54777/ip-7p.html) - #[cfg(any( - target_os = "android", - target_os = "fuchsia", - target_os = "illumos", - target_os = "ios", - target_os = "linux", - target_os = "macos", - target_os = "solaris", - target_os = "tvos", - target_os = "visionos", - target_os = "watchos", - ))] - #[inline] - pub fn set_interface(&mut self, interface: S) -> &mut Self - where - S: Into>, - { - #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] - { - self.interface = Some(interface.into()); - } - - #[cfg(not(any(target_os = "android", target_os = "fuchsia", target_os = "linux")))] - { - self.interface = std::ffi::CString::new(interface.into().into_owned()).ok() - } - - self - } - - /// Set that all sockets are bound to the configured address before connection. - /// - /// If `None`, the sockets will not be bound. - /// - /// Default is `None`. - #[inline] - pub fn set_local_address(&mut self, local_addr: Option) { - match local_addr { - Some(IpAddr::V4(a)) => { - self.local_ipv4 = Some(a); - } - Some(IpAddr::V6(a)) => { - self.local_ipv6 = Some(a); - } - _ => {} - }; - } - - /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's - /// preferences) before connection. - #[inline] - pub fn set_local_addresses(&mut self, local_ipv4: V4, local_ipv6: V6) - where - V4: Into>, - V6: Into>, - { - self.local_ipv4 = local_ipv4.into(); - self.local_ipv6 = local_ipv6.into(); - } -} - -#[derive(Clone)] -struct Config { - connect_timeout: Option, - enforce_http: bool, - happy_eyeballs_timeout: Option, - tcp_keepalive_config: TcpKeepaliveConfig, - tcp_connect_options: TcpConnectOptions, - nodelay: bool, - reuse_address: bool, - send_buffer_size: Option, - recv_buffer_size: Option, - #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] - tcp_user_timeout: Option, -} - -#[derive(Default, Debug, Clone, Copy)] -struct TcpKeepaliveConfig { - time: Option, - #[cfg(any( - target_os = "android", - target_os = "dragonfly", - target_os = "freebsd", - target_os = "fuchsia", - target_os = "illumos", - target_os = "ios", - target_os = "visionos", - target_os = "linux", - target_os = "macos", - target_os = "netbsd", - target_os = "tvos", - target_os = "watchos", - target_os = "windows", - target_os = "cygwin", - ))] - interval: Option, - #[cfg(any( - target_os = "android", - target_os = "dragonfly", - target_os = "freebsd", - target_os = "fuchsia", - target_os = "illumos", - target_os = "ios", - target_os = "visionos", - target_os = "linux", - target_os = "macos", - target_os = "netbsd", - target_os = "tvos", - target_os = "watchos", - target_os = "cygwin", - target_os = "windows", - ))] - retries: Option, -} - -impl TcpKeepaliveConfig { - /// Converts into a `socket2::TcpKeealive` if there is any keep alive configuration. - fn into_tcpkeepalive(self) -> Option { - let mut dirty = false; - let mut ka = TcpKeepalive::new(); - if let Some(time) = self.time { - ka = ka.with_time(time); - dirty = true - } - - // Set the value of the `TCP_KEEPINTVL` option. On Windows, this sets the - // value of the `tcp_keepalive` struct's `keepaliveinterval` field. - // - // Sets the time interval between TCP keepalive probes. - // - // Some platforms specify this value in seconds, so sub-second - // specifications may be omitted. - #[cfg(any( - target_os = "android", - target_os = "dragonfly", - target_os = "freebsd", - target_os = "fuchsia", - target_os = "illumos", - target_os = "ios", - target_os = "visionos", - target_os = "linux", - target_os = "macos", - target_os = "netbsd", - target_os = "tvos", - target_os = "watchos", - target_os = "windows", - target_os = "cygwin", - ))] - { - if let Some(interval) = self.interval { - dirty = true; - ka = ka.with_interval(interval) - }; - } - - // Set the value of the `TCP_KEEPCNT` option. - // - // Set the maximum number of TCP keepalive probes that will be sent before - // dropping a connection, if TCP keepalive is enabled on this socket. - #[cfg(any( - target_os = "android", - target_os = "dragonfly", - target_os = "freebsd", - target_os = "fuchsia", - target_os = "illumos", - target_os = "ios", - target_os = "visionos", - target_os = "linux", - target_os = "macos", - target_os = "netbsd", - target_os = "tvos", - target_os = "watchos", - target_os = "cygwin", - target_os = "windows", - ))] - if let Some(retries) = self.retries { - dirty = true; - ka = ka.with_retries(retries) - }; - - if dirty { Some(ka) } else { None } - } + pub(crate) remote_addr: SocketAddr, + pub(crate) local_addr: SocketAddr, } // ===== impl HttpConnector ===== -impl Default for HttpConnector { - fn default() -> Self { - Self::new() - } -} - -impl HttpConnector { - /// Construct a new HttpConnector. - pub fn new() -> HttpConnector { - HttpConnector::new_with_resolver(GaiResolver::new()) - } -} - -impl HttpConnector { +impl HttpConnector { /// Construct a new [`HttpConnector`]. - pub fn new_with_resolver(resolver: R) -> HttpConnector { + pub fn new(resolver: R, connector: S) -> HttpConnector { HttpConnector { - config: Arc::new(Config { + options: Arc::new(TcpOptions { connect_timeout: None, enforce_http: true, happy_eyeballs_timeout: Some(Duration::from_millis(300)), - tcp_keepalive_config: TcpKeepaliveConfig::default(), - tcp_connect_options: TcpConnectOptions::default(), nodelay: false, reuse_address: false, send_buffer_size: None, recv_buffer_size: None, #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] tcp_user_timeout: None, + tcp_keepalive_config: TcpKeepaliveOptions::default(), + socket_bind_options: SocketBindOptions::default(), }), resolver, + connector, } } + fn config_mut(&mut self) -> &mut TcpOptions { + // If the are HttpConnector clones, this will clone the inner + // config. So mutating the config won't ever affect previous + // clones. + Arc::make_mut(&mut self.options) + } +} + +impl HttpTransport for HttpConnector +where + R: InternalResolve + Clone + Send + Sync + 'static, + R::Future: Send, + S: TcpConnector, +{ /// Option to enforce all `Uri`s have the `http` scheme. /// /// Enabled by default. #[inline] - pub fn enforce_http(&mut self, is_enforced: bool) { + fn enforce_http(&mut self, is_enforced: bool) { self.config_mut().enforce_http = is_enforced; } - /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration - /// to remain idle before sending TCP keepalive probes. - /// - /// If `None`, keepalive is disabled. - /// - /// Default is `None`. - #[inline] - pub fn set_keepalive(&mut self, time: Option) { - self.config_mut().tcp_keepalive_config.time = time; - } - - /// Set the duration between two successive TCP keepalive retransmissions, - /// if acknowledgement to the previous keepalive transmission is not received. - #[inline] - pub fn set_keepalive_interval(&mut self, interval: Option) { - self.config_mut().tcp_keepalive_config.interval = interval; - } - - /// Set the number of retransmissions to be carried out before declaring that remote end is not - /// available. - #[inline] - pub fn set_keepalive_retries(&mut self, retries: Option) { - self.config_mut().tcp_keepalive_config.retries = retries; - } - /// Set that all sockets have `SO_NODELAY` set to the supplied value `nodelay`. /// /// Default is `false`. #[inline] - pub fn set_nodelay(&mut self, nodelay: bool) { + fn set_nodelay(&mut self, nodelay: bool) { self.config_mut().nodelay = nodelay; } /// Sets the value of the SO_SNDBUF option on the socket. #[inline] - pub fn set_send_buffer_size(&mut self, size: Option) { + fn set_send_buffer_size(&mut self, size: Option) { self.config_mut().send_buffer_size = size; } /// Sets the value of the SO_RCVBUF option on the socket. #[inline] - pub fn set_recv_buffer_size(&mut self, size: Option) { + fn set_recv_buffer_size(&mut self, size: Option) { self.config_mut().recv_buffer_size = size; } - /// Set the connect options to be used when connecting. + /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`. + /// + /// Default is `false`. #[inline] - pub fn set_connect_options(&mut self, opts: TcpConnectOptions) { - let this = self.config_mut(); - - #[cfg(any( - target_os = "android", - target_os = "fuchsia", - target_os = "illumos", - target_os = "ios", - target_os = "linux", - target_os = "macos", - target_os = "solaris", - target_os = "tvos", - target_os = "visionos", - target_os = "watchos", - ))] - if let Some(interface) = opts.interface { - this.tcp_connect_options.interface = Some(interface); - } - - if let Some(local_ipv4) = opts.local_ipv4 { - this.tcp_connect_options - .set_local_address(Some(local_ipv4.into())); - } + fn set_reuse_address(&mut self, reuse_address: bool) { + self.config_mut().reuse_address = reuse_address; + } - if let Some(local_ipv6) = opts.local_ipv6 { - this.tcp_connect_options - .set_local_address(Some(local_ipv6.into())); - } + /// Sets the value of the TCP_USER_TIMEOUT option on the socket. + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + #[inline] + fn set_tcp_user_timeout(&mut self, time: Option) { + self.config_mut().tcp_user_timeout = time; } /// Set the connect timeout. @@ -446,7 +233,7 @@ impl HttpConnector { /// /// Default is `None`. #[inline] - pub fn set_connect_timeout(&mut self, dur: Option) { + fn set_connect_timeout(&mut self, dur: Option) { self.config_mut().connect_timeout = dur; } @@ -463,55 +250,100 @@ impl HttpConnector { /// /// [RFC 6555]: https://tools.ietf.org/html/rfc6555 #[inline] - pub fn set_happy_eyeballs_timeout(&mut self, dur: Option) { + fn set_happy_eyeballs_timeout(&mut self, dur: Option) { self.config_mut().happy_eyeballs_timeout = dur; } - /// Set that all socket have `SO_REUSEADDR` set to the supplied value `reuse_address`. + /// Set that all sockets have `SO_KEEPALIVE` set with the supplied duration + /// to remain idle before sending TCP keepalive probes. /// - /// Default is `false`. + /// If `None`, keepalive is disabled. + /// + /// Default is `None`. #[inline] - pub fn set_reuse_address(&mut self, reuse_address: bool) -> &mut Self { - self.config_mut().reuse_address = reuse_address; - self + fn set_keepalive(&mut self, time: Option) { + self.config_mut().tcp_keepalive_config.time = time; } - /// Sets the value of the TCP_USER_TIMEOUT option on the socket. - #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + /// Set the duration between two successive TCP keepalive retransmissions, + /// if acknowledgement to the previous keepalive transmission is not received. #[inline] - pub fn set_tcp_user_timeout(&mut self, time: Option) { - self.config_mut().tcp_user_timeout = time; + fn set_keepalive_interval(&mut self, interval: Option) { + self.config_mut().tcp_keepalive_config.interval = interval; } - // private - - fn config_mut(&mut self) -> &mut Config { - // If the are HttpConnector clones, this will clone the inner - // config. So mutating the config won't ever affect previous - // clones. - Arc::make_mut(&mut self.config) + /// Set the number of retransmissions to be carried out before declaring that remote end is not + /// available. + #[inline] + fn set_keepalive_retries(&mut self, retries: Option) { + self.config_mut().tcp_keepalive_config.retries = retries; } -} -static INVALID_NOT_HTTP: &str = "invalid URI, scheme is not http"; -static INVALID_MISSING_SCHEME: &str = "invalid URI, scheme is missing"; -static INVALID_MISSING_HOST: &str = "invalid URI, host is missing"; + /// Sets the name of the interface to bind sockets produced by this + /// connector. + /// + /// On Linux, this sets the `SO_BINDTODEVICE` option on this socket (see + /// [`man 7 socket`] for details). On macOS (and macOS-derived systems like + /// iOS), illumos, and Solaris, this will instead use the `IP_BOUND_IF` + /// socket option (see [`man 7p ip`]). + /// + /// If a socket is bound to an interface, only packets received from that particular + /// interface are processed by the socket. Note that this only works for some socket + /// types, particularly `AF_INET`` sockets. + /// + /// On Linux it can be used to specify a [VRF], but the binary needs + /// to either have `CAP_NET_RAW` or to be run as root. + /// + /// This function is only available on the following operating systems: + /// - Linux, including Android + /// - Fuchsia + /// - illumos and Solaris + /// - macOS, iOS, visionOS, watchOS, and tvOS + /// + /// [VRF]: https://www.kernel.org/doc/Documentation/networking/vrf.txt + /// [`man 7 socket`]: https://man7.org/linux/man-pages/man7/socket.7.html + /// [`man 7p ip`]: https://docs.oracle.com/cd/E86824_01/html/E54777/ip-7p.html + #[cfg(any( + target_os = "android", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "solaris", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos", + ))] + fn set_interface>>(&mut self, interface: I) { + self.config_mut() + .socket_bind_options + .set_interface(interface); + } -// R: Debug required for now to allow adding it to debug output later... -impl fmt::Debug for HttpConnector { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.debug_struct("HttpConnector").finish() + /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's + /// preferences) before connection. + fn set_local_addresses(&mut self, addr_ipv4: V4, addr_ipv6: V6) + where + V4: Into>, + V6: Into>, + { + self.config_mut() + .socket_bind_options + .set_local_addresses(addr_ipv4, addr_ipv6); } } -impl tower::Service for HttpConnector +impl Service for HttpConnector where R: InternalResolve + Clone + Send + Sync + 'static, R::Future: Send, + S: TcpConnector, + S::TcpStream: From, { - type Response = TcpStream; + type Response = S::Connection; type Error = ConnectError; - type Future = HttpConnecting; + type Future = HttpConnecting; #[inline] fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { @@ -519,15 +351,42 @@ where } fn call(&mut self, dst: Uri) -> Self::Future { - let mut self_ = self.clone(); + let mut this = self.clone(); + + let fut = async move { + let options = &this.options; + + let (host, port) = get_host_port(options, &dst)?; + let host = host.trim_start_matches('[').trim_end_matches(']'); + + let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) { + addrs + } else { + let addrs = dns::resolve(&mut this.resolver, dns::Name::new(host.into())) + .await + .map_err(ConnectError::dns)?; + let addrs = addrs + .map(|mut addr| { + set_port(&mut addr, port, dst.port().is_some()); + addr + }) + .collect(); + dns::SocketAddrs::new(addrs) + }; + + ConnectingTcp::new(addrs, options, this.connector) + .connect(options) + .await + }; + HttpConnecting { - fut: Box::pin(async move { self_.call_async(dst).await }), + fut: Box::pin(fut), _marker: PhantomData, } } } -fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> { +fn get_host_port<'u>(options: &TcpOptions, dst: &'u Uri) -> Result<(&'u str, u16), ConnectError> { trace!( "Http::connect; scheme={:?}, host={:?}, port={:?}", dst.scheme(), @@ -535,7 +394,7 @@ fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), Co dst.port(), ); - if config.enforce_http { + if options.enforce_http { if dst.scheme() != Some(&Scheme::HTTP) { return Err(ConnectError { msg: INVALID_NOT_HTTP, @@ -575,57 +434,13 @@ fn get_host_port<'u>(config: &Config, dst: &'u Uri) -> Result<(&'u str, u16), Co Ok((host, port)) } -impl HttpConnector -where - R: InternalResolve, -{ - async fn call_async(&mut self, dst: Uri) -> Result { - let config = &self.config; - - let (host, port) = get_host_port(config, &dst)?; - let host = host.trim_start_matches('[').trim_end_matches(']'); - - // If the host is already an IP addr (v4 or v6), - // skip resolving the dns and start connecting right away. - let addrs = if let Some(addrs) = dns::SocketAddrs::try_parse(host, port) { - addrs - } else { - let addrs = resolve(&mut self.resolver, dns::Name::new(host.into())) - .await - .map_err(ConnectError::dns)?; - let addrs = addrs - .map(|mut addr| { - set_port(&mut addr, port, dst.port().is_some()); - addr - }) - .collect(); - dns::SocketAddrs::new(addrs) - }; - - let c = ConnectingTcp::new(addrs, config); - - let sock = c.connect().await?; - - if let Err(_e) = sock.set_nodelay(config.nodelay) { - warn!("tcp set_nodelay error: {_e}"); - } - - Ok(sock) - } -} - -impl Connection for TcpStream { - fn connected(&self) -> Connected { - let connected = Connected::new(); - if let (Ok(remote_addr), Ok(local_addr)) = (self.peer_addr(), self.local_addr()) { - connected.extra(HttpInfo { - remote_addr, - local_addr, - }) - } else { - connected - } - } +/// Respect explicit ports in the URI, if none, either +/// keep non `0` ports resolved from a custom dns resolver, +/// or use the default port for the scheme. +fn set_port(addr: &mut SocketAddr, host_port: u16, explicit: bool) { + if explicit || addr.port() == 0 { + addr.set_port(host_port) + }; } impl HttpInfo { @@ -647,377 +462,22 @@ pin_project! { // so that users don't rely on it fitting in a `Pin>` slot // (and thus we can change the type in the future). #[must_use = "futures do nothing unless polled"] - pub struct HttpConnecting { + pub struct HttpConnecting { #[pin] - fut: BoxConnecting, + fut: BoxConnecting, _marker: PhantomData, } } -type ConnectResult = Result; -type BoxConnecting = Pin + Send>>; - -impl Future for HttpConnecting { - type Output = ConnectResult; +impl Future for HttpConnecting +where + R: InternalResolve, + S: TcpConnector, +{ + type Output = ConnectResult; + #[inline] fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { self.project().fut.poll(cx) } } - -// Not publicly exported (so missing_docs doesn't trigger). -pub struct ConnectError { - msg: &'static str, - addr: Option, - cause: Option, -} - -impl ConnectError { - fn new(msg: &'static str, cause: E) -> ConnectError - where - E: Into, - { - ConnectError { - msg, - addr: None, - cause: Some(cause.into()), - } - } - - fn dns(cause: E) -> ConnectError - where - E: Into, - { - ConnectError::new("dns error", cause) - } - - fn m(msg: &'static str) -> impl FnOnce(E) -> ConnectError - where - E: Into, - { - move |cause| ConnectError::new(msg, cause) - } -} - -impl fmt::Debug for ConnectError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let mut b = f.debug_tuple("ConnectError"); - b.field(&self.msg); - if let Some(ref addr) = self.addr { - b.field(addr); - } - if let Some(ref cause) = self.cause { - b.field(cause); - } - b.finish() - } -} - -impl fmt::Display for ConnectError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(self.msg) - } -} - -impl StdError for ConnectError { - fn source(&self) -> Option<&(dyn StdError + 'static)> { - self.cause.as_ref().map(|e| &**e as _) - } -} - -struct ConnectingTcp<'a> { - preferred: ConnectingTcpRemote, - fallback: Option, - config: &'a Config, -} - -impl<'a> ConnectingTcp<'a> { - fn new(remote_addrs: dns::SocketAddrs, config: &'a Config) -> Self { - if let Some(fallback_timeout) = config.happy_eyeballs_timeout { - let (preferred_addrs, fallback_addrs) = remote_addrs.split_by_preference( - config.tcp_connect_options.local_ipv4, - config.tcp_connect_options.local_ipv6, - ); - if fallback_addrs.is_empty() { - return ConnectingTcp { - preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout), - fallback: None, - config, - }; - } - - ConnectingTcp { - preferred: ConnectingTcpRemote::new(preferred_addrs, config.connect_timeout), - fallback: Some(ConnectingTcpFallback { - delay: tokio::time::sleep(fallback_timeout), - remote: ConnectingTcpRemote::new(fallback_addrs, config.connect_timeout), - }), - config, - } - } else { - ConnectingTcp { - preferred: ConnectingTcpRemote::new(remote_addrs, config.connect_timeout), - fallback: None, - config, - } - } - } -} - -struct ConnectingTcpFallback { - delay: Sleep, - remote: ConnectingTcpRemote, -} - -struct ConnectingTcpRemote { - addrs: dns::SocketAddrs, - connect_timeout: Option, -} - -impl ConnectingTcpRemote { - fn new(addrs: dns::SocketAddrs, connect_timeout: Option) -> Self { - let connect_timeout = connect_timeout.and_then(|t| t.checked_div(addrs.len() as u32)); - - Self { - addrs, - connect_timeout, - } - } -} - -impl ConnectingTcpRemote { - async fn connect(&mut self, config: &Config) -> Result { - let mut err = None; - for addr in &mut self.addrs { - debug!("connecting to {}", addr); - match connect(&addr, config, self.connect_timeout)?.await { - Ok(tcp) => { - debug!("connected to {}", addr); - return Ok(tcp); - } - Err(mut e) => { - e.addr = Some(addr); - // Only return the first error; assume it’s the most relevant. - if err.is_none() { - err = Some(e); - } - } - } - } - - match err { - Some(e) => Err(e), - None => Err(ConnectError::new( - "tcp connect error", - io::Error::new(io::ErrorKind::NotConnected, "Network unreachable"), - )), - } - } -} - -fn bind_local_address( - socket: &socket2::Socket, - dst_addr: &SocketAddr, - local_addr_ipv4: &Option, - local_addr_ipv6: &Option, -) -> io::Result<()> { - match (*dst_addr, local_addr_ipv4, local_addr_ipv6) { - (SocketAddr::V4(_), Some(addr), _) => { - socket.bind(&SocketAddr::new((*addr).into(), 0).into())?; - } - (SocketAddr::V6(_), _, Some(addr)) => { - socket.bind(&SocketAddr::new((*addr).into(), 0).into())?; - } - _ => { - if cfg!(windows) { - // Windows requires a socket be bound before calling connect - let any: SocketAddr = match *dst_addr { - SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(), - SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(), - }; - socket.bind(&any.into())?; - } - } - } - - Ok(()) -} - -fn connect( - addr: &SocketAddr, - config: &Config, - connect_timeout: Option, -) -> Result>, ConnectError> { - // TODO(eliza): if Tokio's `TcpSocket` gains support for setting the - // keepalive timeout, it would be nice to use that instead of socket2, - // and avoid the unsafe `into_raw_fd`/`from_raw_fd` dance... - use socket2::{Domain, Protocol, Socket, Type}; - - let domain = Domain::for_address(*addr); - let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)) - .map_err(ConnectError::m("tcp open error"))?; - - // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is - // responsible for ensuring O_NONBLOCK is set. - socket - .set_nonblocking(true) - .map_err(ConnectError::m("tcp set_nonblocking error"))?; - - if let Some(tcp_keepalive) = &config.tcp_keepalive_config.into_tcpkeepalive() { - if let Err(_e) = socket.set_tcp_keepalive(tcp_keepalive) { - warn!("tcp set_keepalive error: {_e}"); - } - } - - // That this only works for some socket types, particularly AF_INET sockets. - #[cfg(any( - target_os = "android", - target_os = "fuchsia", - target_os = "illumos", - target_os = "ios", - target_os = "linux", - target_os = "macos", - target_os = "solaris", - target_os = "tvos", - target_os = "visionos", - target_os = "watchos", - ))] - if let Some(interface) = &config.tcp_connect_options.interface { - // On Linux-like systems, set the interface to bind using - // `SO_BINDTODEVICE`. - #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] - socket - .bind_device(Some(interface.as_bytes())) - .map_err(ConnectError::m("tcp bind interface error"))?; - - // On macOS-like and Solaris-like systems, we instead use `IP_BOUND_IF`. - // This socket option desires an integer index for the interface, so we - // must first determine the index of the requested interface name using - // `if_nametoindex`. - #[cfg(any( - target_os = "illumos", - target_os = "ios", - target_os = "macos", - target_os = "solaris", - target_os = "tvos", - target_os = "visionos", - target_os = "watchos", - ))] - { - #[allow(unsafe_code)] - let idx = unsafe { libc::if_nametoindex(interface.as_ptr()) }; - let idx = std::num::NonZeroU32::new(idx).ok_or_else(|| { - // If the index is 0, check errno and return an I/O error. - ConnectError::new( - "error converting interface name to index", - io::Error::last_os_error(), - ) - })?; - - // Different setsockopt calls are necessary depending on whether the - // address is IPv4 or IPv6. - match addr { - SocketAddr::V4(_) => socket.bind_device_by_index_v4(Some(idx)), - SocketAddr::V6(_) => socket.bind_device_by_index_v6(Some(idx)), - } - .map_err(ConnectError::m("tcp bind interface error"))?; - } - } - - #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] - if let Some(tcp_user_timeout) = &config.tcp_user_timeout { - if let Err(_e) = socket.set_tcp_user_timeout(Some(*tcp_user_timeout)) { - warn!("tcp set_tcp_user_timeout error: {_e}"); - } - } - - bind_local_address( - &socket, - addr, - &config.tcp_connect_options.local_ipv4, - &config.tcp_connect_options.local_ipv6, - ) - .map_err(ConnectError::m("tcp bind local error"))?; - - // Safely convert socket2::Socket to tokio TcpSocket. - let socket = TcpSocket::from_std_stream(socket.into()); - - if config.reuse_address { - if let Err(_e) = socket.set_reuseaddr(true) { - warn!("tcp set_reuse_address error: {_e}"); - } - } - - if let Some(size) = config.send_buffer_size { - if let Err(_e) = socket.set_send_buffer_size(size.try_into().unwrap_or(u32::MAX)) { - warn!("tcp set_buffer_size error: {_e}"); - } - } - - if let Some(size) = config.recv_buffer_size { - if let Err(_e) = socket.set_recv_buffer_size(size.try_into().unwrap_or(u32::MAX)) { - warn!("tcp set_recv_buffer_size error: {_e}"); - } - } - - let connect = socket.connect(*addr); - Ok(async move { - match connect_timeout { - Some(dur) => match tokio::time::timeout(dur, connect).await { - Ok(Ok(s)) => Ok(s), - Ok(Err(e)) => Err(e), - Err(e) => Err(io::Error::new(io::ErrorKind::TimedOut, e)), - }, - None => connect.await, - } - .map_err(ConnectError::m("tcp connect error")) - }) -} - -impl ConnectingTcp<'_> { - async fn connect(mut self) -> Result { - match self.fallback { - None => self.preferred.connect(self.config).await, - Some(mut fallback) => { - let preferred_fut = self.preferred.connect(self.config); - futures_util::pin_mut!(preferred_fut); - - let fallback_fut = fallback.remote.connect(self.config); - futures_util::pin_mut!(fallback_fut); - - let fallback_delay = fallback.delay; - futures_util::pin_mut!(fallback_delay); - - let (result, future) = - match futures_util::future::select(preferred_fut, fallback_delay).await { - Either::Left((result, _fallback_delay)) => { - (result, Either::Right(fallback_fut)) - } - Either::Right(((), preferred_fut)) => { - // Delay is done, start polling both the preferred and the fallback - futures_util::future::select(preferred_fut, fallback_fut) - .await - .factor_first() - } - }; - - if result.is_err() { - // Fallback to the remaining future (could be preferred or fallback) - // if we get an error - future.await - } else { - result - } - } - } - } -} - -/// Respect explicit ports in the URI, if none, either -/// keep non `0` ports resolved from a custom dns resolver, -/// or use the default port for the scheme. -fn set_port(addr: &mut SocketAddr, host_port: u16, explicit: bool) { - if explicit || addr.port() == 0 { - addr.set_port(host_port) - }; -} diff --git a/src/client/conn/tcp.rs b/src/client/conn/tcp.rs new file mode 100644 index 00000000..342b40e9 --- /dev/null +++ b/src/client/conn/tcp.rs @@ -0,0 +1,664 @@ +pub mod tokio; + +#[cfg(any( + target_os = "illumos", + target_os = "ios", + target_os = "macos", + target_os = "solaris", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos", + target_os = "android", + target_os = "fuchsia", + target_os = "linux", +))] +use std::borrow::Cow; +use std::{ + error::Error as StdError, + fmt, + future::Future, + io, + net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr}, + pin::pin, + time::Duration, +}; + +use futures_util::future::Either; +use socket2::TcpKeepalive; + +use super::Connection; +use crate::{dns, error::BoxError}; + +/// A builder for tcp connections. +pub trait TcpConnector: Clone + Send + Sync + 'static { + /// The underlying stream type. + type TcpStream: From + Send + Sync + 'static; + + /// The type of connection returned by this builder. + type Connection: ::tokio::io::AsyncRead + + ::tokio::io::AsyncWrite + + Connection + + Send + + Unpin + + 'static; + + /// The type of error returned by this builder. + type Error: Into>; + + /// The future type returned by this builder. + type Future: Future> + Send + 'static; + + /// The future type returned by this builder's sleep. + type Sleep: Future + Send + 'static; + + /// Build a connection from the given socket and connect to the address. + fn connect(&self, socket: Self::TcpStream, addr: SocketAddr) -> Self::Future; + + /// Return a future that sleeps for the given duration. + fn sleep(&self, duration: Duration) -> Self::Sleep; +} + +pub(super) struct ConnectingTcp { + preferred: ConnectingTcpRemote, + fallback: Option>, +} + +struct ConnectingTcpFallback { + delay: S::Sleep, + remote: ConnectingTcpRemote, +} + +struct ConnectingTcpRemote { + addrs: dns::SocketAddrs, + connect_timeout: Option, + connector: S, +} + +impl ConnectingTcp +where + S::TcpStream: From, +{ + pub(super) fn new(remote_addrs: dns::SocketAddrs, config: &TcpOptions, connector: S) -> Self { + if let Some(fallback_timeout) = config.happy_eyeballs_timeout { + let (preferred_addrs, fallback_addrs) = remote_addrs.split_by_preference( + config.socket_bind_options.local_address_ipv4, + config.socket_bind_options.local_address_ipv6, + ); + if fallback_addrs.is_empty() { + return ConnectingTcp { + preferred: ConnectingTcpRemote::new( + preferred_addrs, + config.connect_timeout, + connector, + ), + fallback: None, + }; + } + + ConnectingTcp { + preferred: ConnectingTcpRemote::new( + preferred_addrs, + config.connect_timeout, + connector.clone(), + ), + fallback: Some(ConnectingTcpFallback { + delay: connector.sleep(fallback_timeout), + remote: ConnectingTcpRemote::new( + fallback_addrs, + config.connect_timeout, + connector, + ), + }), + } + } else { + ConnectingTcp { + preferred: ConnectingTcpRemote::new( + remote_addrs, + config.connect_timeout, + connector, + ), + fallback: None, + } + } + } +} + +impl ConnectingTcpRemote +where + S::TcpStream: From, +{ + fn new(addrs: dns::SocketAddrs, connect_timeout: Option, connector: S) -> Self { + let connect_timeout = connect_timeout.and_then(|t| t.checked_div(addrs.len() as u32)); + + Self { + addrs, + connect_timeout, + connector, + } + } + + async fn connect(&mut self, config: &TcpOptions) -> Result { + let mut err = None; + for addr in &mut self.addrs { + debug!("connecting to {}", addr); + match connect(&addr, config, self.connect_timeout, &self.connector) { + Ok(fut) => match fut.await { + Ok(tcp) => { + debug!("connected to {}", addr); + return Ok(tcp); + } + Err(mut e) => { + trace!("connect error for {}: {:?}", addr, e); + e.addr = Some(addr); + if err.is_none() { + err = Some(e); + } + } + }, + Err(mut e) => { + trace!("connect error for {}: {:?}", addr, e); + e.addr = Some(addr); + if err.is_none() { + err = Some(e); + } + } + } + } + + match err { + Some(e) => Err(e), + None => Err(ConnectError::new( + "tcp connect error", + std::io::Error::new(std::io::ErrorKind::NotConnected, "Network unreachable"), + )), + } + } +} + +fn bind_local_address( + socket: &socket2::Socket, + dst_addr: &SocketAddr, + local_addr_ipv4: &Option, + local_addr_ipv6: &Option, +) -> io::Result<()> { + match (*dst_addr, local_addr_ipv4, local_addr_ipv6) { + (SocketAddr::V4(_), Some(addr), _) => { + socket.bind(&SocketAddr::new((*addr).into(), 0).into())?; + } + (SocketAddr::V6(_), _, Some(addr)) => { + socket.bind(&SocketAddr::new((*addr).into(), 0).into())?; + } + _ => { + if cfg!(windows) { + // Windows requires a socket be bound before calling connect + let any: SocketAddr = match *dst_addr { + SocketAddr::V4(_) => ([0, 0, 0, 0], 0).into(), + SocketAddr::V6(_) => ([0, 0, 0, 0, 0, 0, 0, 0], 0).into(), + }; + socket.bind(&any.into())?; + } + } + } + + Ok(()) +} + +fn connect( + addr: &SocketAddr, + config: &TcpOptions, + connect_timeout: Option, + connector: &S, +) -> Result>, ConnectError> +where + S::TcpStream: From, +{ + use socket2::{Domain, Protocol, Socket, Type}; + + let domain = Domain::for_address(*addr); + let socket = Socket::new(domain, Type::STREAM, Some(Protocol::TCP)) + .map_err(ConnectError::m("tcp open error"))?; + + // When constructing a Tokio `TcpSocket` from a raw fd/socket, the user is + // responsible for ensuring O_NONBLOCK is set. + socket + .set_nonblocking(true) + .map_err(ConnectError::m("tcp set_nonblocking error"))?; + + if let Some(tcp_keepalive) = &config.tcp_keepalive_config.into_tcpkeepalive() { + if let Err(_e) = socket.set_tcp_keepalive(tcp_keepalive) { + warn!("tcp set_keepalive error: {_e}"); + } + } + + // That this only works for some socket types, particularly AF_INET sockets. + #[cfg(any( + target_os = "android", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "solaris", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos", + ))] + if let Some(interface) = &config.socket_bind_options.interface { + // On Linux-like systems, set the interface to bind using + // `SO_BINDTODEVICE`. + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + socket + .bind_device(Some(interface.as_bytes())) + .map_err(ConnectError::m("tcp bind interface error"))?; + + // On macOS-like and Solaris-like systems, we instead use `IP_BOUND_IF`. + // This socket option desires an integer index for the interface, so we + // must first determine the index of the requested interface name using + // `if_nametoindex`. + #[cfg(any( + target_os = "illumos", + target_os = "ios", + target_os = "macos", + target_os = "solaris", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos", + ))] + if let Ok(interface) = std::ffi::CString::new(interface.as_bytes()) { + #[allow(unsafe_code)] + let idx = unsafe { libc::if_nametoindex(interface.as_ptr()) }; + let idx = std::num::NonZeroU32::new(idx).ok_or_else(|| { + // If the index is 0, check errno and return an I/O error. + ConnectError::new( + "error converting interface name to index", + io::Error::last_os_error(), + ) + })?; + + // Different setsockopt calls are necessary depending on whether the + // address is IPv4 or IPv6. + match addr { + SocketAddr::V4(_) => socket.bind_device_by_index_v4(Some(idx)), + SocketAddr::V6(_) => socket.bind_device_by_index_v6(Some(idx)), + } + .map_err(ConnectError::m("tcp bind interface error"))?; + } + } + + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + if let Some(tcp_user_timeout) = &config.tcp_user_timeout { + if let Err(_e) = socket.set_tcp_user_timeout(Some(*tcp_user_timeout)) { + warn!("tcp set_tcp_user_timeout error: {_e}"); + } + } + + bind_local_address( + &socket, + addr, + &config.socket_bind_options.local_address_ipv4, + &config.socket_bind_options.local_address_ipv6, + ) + .map_err(ConnectError::m("tcp bind local error"))?; + + if config.reuse_address { + if let Err(_e) = socket.set_reuse_address(true) { + warn!("tcp set_reuse_address error: {_e}"); + } + } + + if let Some(size) = config.send_buffer_size { + if let Err(_e) = socket.set_send_buffer_size(size) { + warn!("tcp set_buffer_size error: {_e}"); + } + } + + if let Some(size) = config.recv_buffer_size { + if let Err(_e) = socket.set_recv_buffer_size(size) { + warn!("tcp set_recv_buffer_size error: {_e}"); + } + } + + let connect = connector.connect(socket.into(), *addr); + let sleep = connect_timeout.map(|dur| connector.sleep(dur)); + + Ok(async move { + match sleep { + Some(sleep) => match futures_util::future::select(pin!(sleep), pin!(connect)).await { + Either::Left(((), _)) => { + Err(io::Error::new(io::ErrorKind::TimedOut, "connect timeout").into()) + } + Either::Right((Ok(s), _)) => Ok(s), + Either::Right((Err(e), _)) => Err(e.into()), + }, + None => connect.await.map_err(Into::into), + } + .map_err(ConnectError::m("tcp connect error")) + }) +} + +impl ConnectingTcp +where + S::TcpStream: From, +{ + pub(super) async fn connect( + mut self, + config: &TcpOptions, + ) -> Result { + match self.fallback { + None => self.preferred.connect(config).await, + Some(mut fallback) => { + let preferred_fut = self.preferred.connect(config); + futures_util::pin_mut!(preferred_fut); + + let fallback_fut = fallback.remote.connect(config); + futures_util::pin_mut!(fallback_fut); + + let fallback_delay = fallback.delay; + futures_util::pin_mut!(fallback_delay); + + let (result, future) = + match futures_util::future::select(preferred_fut, fallback_delay).await { + Either::Left((result, _fallback_delay)) => { + (result, Either::Right(fallback_fut)) + } + Either::Right(((), preferred_fut)) => { + // Delay is done, start polling both the preferred and the fallback + futures_util::future::select(preferred_fut, fallback_fut) + .await + .factor_first() + } + }; + + if result.is_err() { + // Fallback to the remaining future (could be preferred or fallback) + // if we get an error + future.await + } else { + result + } + } + } + } +} + +// Not publicly exported (so missing_docs doesn't trigger). +pub struct ConnectError { + pub(super) msg: &'static str, + pub(super) addr: Option, + pub(super) cause: Option, +} + +impl ConnectError { + pub(super) fn new(msg: &'static str, cause: E) -> ConnectError + where + E: Into, + { + ConnectError { + msg, + addr: None, + cause: Some(cause.into()), + } + } + + pub(super) fn dns(cause: E) -> ConnectError + where + E: Into, + { + ConnectError::new("dns error", cause) + } + + pub(super) fn m(msg: &'static str) -> impl FnOnce(E) -> ConnectError + where + E: Into, + { + move |cause| ConnectError::new(msg, cause) + } +} + +impl fmt::Debug for ConnectError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut b = f.debug_tuple("ConnectError"); + b.field(&self.msg); + if let Some(ref addr) = self.addr { + b.field(addr); + } + if let Some(ref cause) = self.cause { + b.field(cause); + } + b.finish() + } +} + +impl fmt::Display for ConnectError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.msg) + } +} + +impl StdError for ConnectError { + fn source(&self) -> Option<&(dyn StdError + 'static)> { + self.cause.as_ref().map(|e| &**e as _) + } +} + +/// Options for configuring socket bind behavior for outbound connections. +#[derive(Debug, Clone, Hash, PartialEq, Eq, Default)] +pub(crate) struct SocketBindOptions { + #[cfg(any( + target_os = "illumos", + target_os = "ios", + target_os = "macos", + target_os = "solaris", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos", + target_os = "android", + target_os = "fuchsia", + target_os = "linux", + ))] + pub interface: Option>, + pub local_address_ipv4: Option, + pub local_address_ipv6: Option, +} + +impl SocketBindOptions { + /// Sets the name of the network interface to bind the socket to. + /// + /// ## Platform behavior + /// - On Linux/Fuchsia/Android: sets `SO_BINDTODEVICE` + /// - On macOS/illumos/Solaris/iOS/etc.: sets `IP_BOUND_IF` + /// + /// If `interface` is `None`, the socket will not be explicitly bound to any device. + /// + /// # Errors + /// + /// On platforms that require a `CString` (e.g. macOS), this will return an error if the + /// interface name contains an internal null byte (`\0`), which is invalid in C strings. + /// + /// # See Also + /// - [VRF documentation](https://www.kernel.org/doc/Documentation/networking/vrf.txt) + /// - [`man 7 socket`](https://man7.org/linux/man-pages/man7/socket.7.html) + /// - [`man 7p ip`](https://docs.oracle.com/cd/E86824_01/html/E54777/ip-7p.html) + #[cfg(any( + target_os = "android", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "solaris", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos", + ))] + #[inline] + pub fn set_interface(&mut self, interface: I) -> &mut Self + where + I: Into>, + { + self.interface = Some(interface.into()); + self + } + + /// Set that all sockets are bound to the configured address before connection. + /// + /// If `None`, the sockets will not be bound. + /// + /// Default is `None`. + #[inline] + pub fn set_local_address(&mut self, local_addr: Option) { + match local_addr { + Some(IpAddr::V4(a)) => { + self.local_address_ipv4 = Some(a); + } + Some(IpAddr::V6(a)) => { + self.local_address_ipv6 = Some(a); + } + _ => {} + }; + + let (v4, v6) = match local_addr { + Some(IpAddr::V4(a)) => (Some(a), None), + Some(IpAddr::V6(a)) => (None, Some(a)), + _ => (None, None), + }; + + self.local_address_ipv4 = v4; + self.local_address_ipv6 = v6; + } + + /// Set that all sockets are bound to the configured IPv4 or IPv6 address (depending on host's + /// preferences) before connection. + #[inline] + pub fn set_local_addresses(&mut self, local_ipv4: V4, local_ipv6: V6) + where + V4: Into>, + V6: Into>, + { + self.local_address_ipv4 = local_ipv4.into(); + self.local_address_ipv6 = local_ipv6.into(); + } +} + +#[derive(Clone)] +pub(crate) struct TcpOptions { + pub connect_timeout: Option, + pub enforce_http: bool, + pub happy_eyeballs_timeout: Option, + pub tcp_keepalive_config: TcpKeepaliveOptions, + pub socket_bind_options: SocketBindOptions, + pub nodelay: bool, + pub reuse_address: bool, + pub send_buffer_size: Option, + pub recv_buffer_size: Option, + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + pub tcp_user_timeout: Option, +} + +#[derive(Default, Debug, Clone, Copy)] +pub(crate) struct TcpKeepaliveOptions { + pub time: Option, + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "visionos", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + target_os = "windows", + target_os = "cygwin", + ))] + pub interval: Option, + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "visionos", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + target_os = "cygwin", + target_os = "windows", + ))] + pub retries: Option, +} + +impl TcpKeepaliveOptions { + /// Converts into a `socket2::TcpKeealive` if there is any keep alive configuration. + pub(crate) fn into_tcpkeepalive(self) -> Option { + let mut dirty = false; + let mut ka = TcpKeepalive::new(); + if let Some(time) = self.time { + ka = ka.with_time(time); + dirty = true + } + + // Set the value of the `TCP_KEEPINTVL` option. On Windows, this sets the + // value of the `tcp_keepalive` struct's `keepaliveinterval` field. + // + // Sets the time interval between TCP keepalive probes. + // + // Some platforms specify this value in seconds, so sub-second + // specifications may be omitted. + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "visionos", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + target_os = "windows", + target_os = "cygwin", + ))] + { + if let Some(interval) = self.interval { + dirty = true; + ka = ka.with_interval(interval) + }; + } + + // Set the value of the `TCP_KEEPCNT` option. + // + // Set the maximum number of TCP keepalive probes that will be sent before + // dropping a connection, if TCP keepalive is enabled on this socket. + #[cfg(any( + target_os = "android", + target_os = "dragonfly", + target_os = "freebsd", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "visionos", + target_os = "linux", + target_os = "macos", + target_os = "netbsd", + target_os = "tvos", + target_os = "watchos", + target_os = "cygwin", + target_os = "windows", + ))] + if let Some(retries) = self.retries { + dirty = true; + ka = ka.with_retries(retries) + }; + + if dirty { Some(ka) } else { None } + } +} diff --git a/src/client/conn/tcp/tokio.rs b/src/client/conn/tcp/tokio.rs new file mode 100644 index 00000000..c5539f06 --- /dev/null +++ b/src/client/conn/tcp/tokio.rs @@ -0,0 +1,52 @@ +use std::{future::Future, io, net::SocketAddr, pin::Pin, time::Duration}; + +use tokio::net::{TcpSocket, TcpStream}; + +use super::TcpConnector; +use crate::client::{Connected, Connection, conn::HttpInfo}; + +/// A connector that uses `tokio` for TCP connections. +#[derive(Clone, Copy, Debug, Default)] +pub struct TokioTcpConnector { + _priv: (), +} + +impl TokioTcpConnector { + /// Create a new [`TokioTcpConnector`]. + pub fn new() -> Self { + Self { _priv: () } + } +} + +impl TcpConnector for TokioTcpConnector { + type TcpStream = std::net::TcpStream; + type Connection = TcpStream; + type Error = io::Error; + type Future = Pin> + Send>>; + type Sleep = tokio::time::Sleep; + + #[inline] + fn connect(&self, socket: Self::TcpStream, addr: SocketAddr) -> Self::Future { + let socket = TcpSocket::from_std_stream(socket); + Box::pin(socket.connect(addr)) + } + + #[inline] + fn sleep(&self, duration: Duration) -> Self::Sleep { + tokio::time::sleep(duration) + } +} + +impl Connection for TcpStream { + fn connected(&self) -> Connected { + let connected = Connected::new(); + if let (Ok(remote_addr), Ok(local_addr)) = (self.peer_addr(), self.local_addr()) { + connected.extra(HttpInfo { + remote_addr, + local_addr, + }) + } else { + connected + } + } +} diff --git a/src/client/conn/uds.rs b/src/client/conn/uds.rs index 6df1cdee..b3a217c4 100644 --- a/src/client/conn/uds.rs +++ b/src/client/conn/uds.rs @@ -15,7 +15,16 @@ type ConnectResult = io::Result; type BoxConnecting = Pin + Send>>; #[derive(Clone)] -pub struct UnixConnector(pub(crate) Arc); +pub struct UnixConnector { + path: Arc, +} + +impl UnixConnector { + /// Create a new [`UnixConnector`]. + pub fn new(path: impl Into>) -> Self { + Self { path: path.into() } + } +} impl tower::Service for UnixConnector { type Response = UnixStream; @@ -28,7 +37,7 @@ impl tower::Service for UnixConnector { } fn call(&mut self, _: Uri) -> Self::Future { - let fut = UnixStream::connect(self.0.clone()); + let fut = UnixStream::connect(self.path.clone()); Box::pin(async move { let io = fut.await?; Ok::<_, io::Error>(io) diff --git a/src/client/core/conn/http2.rs b/src/client/core/conn/http2.rs index 80a564ae..880206b5 100644 --- a/src/client/core/conn/http2.rs +++ b/src/client/core/conn/http2.rs @@ -4,6 +4,7 @@ use std::{ future::Future, marker::PhantomData, pin::Pin, + sync::Arc, task::{Context, Poll, ready}, }; @@ -20,7 +21,7 @@ use crate::client::core::{ self, http2::{Http2Options, ping}, }, - rt::{ArcTimer, Time, Timer, bounds::Http2ClientConnExec}, + rt::{Time, Timer, bounds::Http2ClientConnExec}, }; /// The sender side of an established connection. @@ -189,7 +190,7 @@ where where M: Timer + Send + Sync + 'static, { - self.timer = Time::Timer(ArcTimer::new(timer)); + self.timer = Time::Timer(Arc::new(timer)); } /// Provide a options configuration for the HTTP/2 connection. diff --git a/src/client/core/proto/http2/ping.rs b/src/client/core/proto/http2/ping.rs index ca326d68..27e85478 100644 --- a/src/client/core/proto/http2/ping.rs +++ b/src/client/core/proto/http2/ping.rs @@ -31,7 +31,7 @@ use crate::{ client::core::{ self, error::{Error, Kind}, - rt::{Sleep, Time}, + rt::{Sleep, Time, Timer}, }, sync::Mutex, }; diff --git a/src/client/core/rt.rs b/src/client/core/rt.rs index d78a2f3e..0ec828e8 100644 --- a/src/client/core/rt.rs +++ b/src/client/core/rt.rs @@ -12,7 +12,7 @@ mod timer; mod tokio; pub use self::{ - timer::{ArcTimer, Sleep, Time, Timer}, + timer::{Sleep, Time, Timer}, tokio::{TokioExecutor, TokioTimer}, }; diff --git a/src/client/core/rt/timer.rs b/src/client/core/rt/timer.rs index d6efcd44..5b2a5aa9 100644 --- a/src/client/core/rt/timer.rs +++ b/src/client/core/rt/timer.rs @@ -41,21 +41,10 @@ pub trait Sleep: Send + Sync + Future { } } -/// A handle to a shared timer instance. -/// -/// `TimerHandle` provides a reference-counted, thread-safe handle to any type implementing the -/// [`Timer`] trait. It allows cloning and sharing a timer implementation across multiple components -/// or tasks. -/// -/// This is typically used to abstract over different timer backends and to provide a unified -/// interface for spawning sleep futures or scheduling timeouts. -#[derive(Clone)] -pub struct ArcTimer(Arc); - /// A user-provided timer to time background tasks. #[derive(Clone)] pub enum Time { - Timer(ArcTimer), + Timer(Arc), Empty, } @@ -91,35 +80,10 @@ impl dyn Sleep { } } -// =====impl ArcTimer ===== - -impl ArcTimer { - pub(crate) fn new(inner: T) -> Self - where - T: Timer + Send + Sync + 'static, - { - Self(Arc::new(inner)) - } -} +// ===== impl Time ===== -impl Timer for ArcTimer { +impl Timer for Time { fn sleep(&self, duration: Duration) -> Pin> { - self.0.sleep(duration) - } - - fn now(&self) -> Instant { - tokio::time::Instant::now().into() - } - - fn sleep_until(&self, deadline: Instant) -> Pin> { - self.0.sleep_until(deadline) - } -} - -// =====impl Time ===== - -impl Time { - pub(crate) fn sleep(&self, duration: Duration) -> Pin> { match *self { Time::Empty => { panic!("You must supply a timer.") @@ -128,14 +92,23 @@ impl Time { } } - pub(crate) fn now(&self) -> Instant { + fn now(&self) -> Instant { match *self { Time::Empty => Instant::now(), Time::Timer(ref t) => t.now(), } } - pub(crate) fn reset(&self, sleep: &mut Pin>, new_deadline: Instant) { + fn sleep_until(&self, deadline: Instant) -> Pin> { + match *self { + Time::Empty => { + panic!("You must supply a timer.") + } + Time::Timer(ref t) => t.sleep_until(deadline), + } + } + + fn reset(&self, sleep: &mut Pin>, new_deadline: Instant) { match *self { Time::Empty => { panic!("You must supply a timer.") diff --git a/src/client/core/rt/tokio.rs b/src/client/core/rt/tokio.rs index 3128a8bb..1c8e5ec1 100644 --- a/src/client/core/rt/tokio.rs +++ b/src/client/core/rt/tokio.rs @@ -82,9 +82,12 @@ impl TokioTimer { } } +// ===== impl TokioSleep ===== + impl Future for TokioSleep { type Output = (); + #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { self.project().inner.poll(cx) } @@ -93,6 +96,7 @@ impl Future for TokioSleep { impl Sleep for TokioSleep {} impl TokioSleep { + #[inline] fn reset(self: Pin<&mut Self>, deadline: Instant) { self.project().inner.as_mut().reset(deadline.into()); } diff --git a/src/client/http.rs b/src/client/http.rs index 76e22fc3..dcf65fe8 100644 --- a/src/client/http.rs +++ b/src/client/http.rs @@ -38,7 +38,8 @@ use super::ws::WebSocketRequestBuilder; use super::{ Body, EmulationFactory, conn::{ - BoxedConnectorLayer, BoxedConnectorService, Conn, Connector, TcpConnectOptions, Unnameable, + BoxedConnectorLayer, BoxedConnectorService, Conn, Connector, HttpTransport, + SocketBindOptions, Unnameable, }, core::{ body::Incoming, @@ -206,7 +207,7 @@ struct Config { tcp_send_buffer_size: Option, tcp_recv_buffer_size: Option, tcp_happy_eyeballs_timeout: Option, - tcp_connect_options: TcpConnectOptions, + socket_bind_options: SocketBindOptions, proxies: Vec, auto_sys_proxy: bool, retry_policy: retry::Policy, @@ -282,12 +283,12 @@ impl Client { tcp_keepalive_retries: Some(3), #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] tcp_user_timeout: Some(Duration::from_secs(30)), - tcp_connect_options: TcpConnectOptions::default(), tcp_nodelay: true, tcp_reuse_address: false, tcp_send_buffer_size: None, tcp_recv_buffer_size: None, tcp_happy_eyeballs_timeout: Some(Duration::from_millis(300)), + socket_bind_options: SocketBindOptions::default(), proxies: Vec::new(), auto_sys_proxy: true, retry_policy: retry::Policy::default(), @@ -511,24 +512,22 @@ impl ClientBuilder { let connector = Connector::builder(config.proxies, resolver) .timeout(config.connect_timeout) .tls_info(config.tls_info) - .tls_options(tls_options) .tcp_nodelay(config.tcp_nodelay) .verbose(config.connection_verbose) .with_tls(|tls| { - let alpn_protocol = match config.http_version_pref { + tls.alpn_protocol(match config.http_version_pref { HttpVersionPref::Http1 => Some(AlpnProtocol::HTTP1), HttpVersionPref::Http2 => Some(AlpnProtocol::HTTP2), _ => None, - }; - tls.alpn_protocol(alpn_protocol) - .max_version(config.max_tls_version) - .min_version(config.min_tls_version) - .tls_sni(config.tls_sni) - .verify_hostname(config.verify_hostname) - .cert_verification(config.cert_verification) - .cert_store(config.cert_store) - .identity(config.identity) - .keylog(config.keylog) + }) + .keylog(config.keylog) + .cert_store(config.cert_store) + .identity(config.identity) + .max_version(config.max_tls_version) + .min_version(config.min_tls_version) + .tls_sni(config.tls_sni) + .verify_hostname(config.verify_hostname) + .cert_verification(config.cert_verification) }) .with_http(|http| { http.enforce_http(false); @@ -536,16 +535,37 @@ impl ClientBuilder { http.set_keepalive_interval(config.tcp_keepalive_interval); http.set_keepalive_retries(config.tcp_keepalive_retries); http.set_reuse_address(config.tcp_reuse_address); - http.set_connect_options(config.tcp_connect_options); http.set_connect_timeout(config.connect_timeout); http.set_nodelay(config.tcp_nodelay); http.set_send_buffer_size(config.tcp_send_buffer_size); http.set_recv_buffer_size(config.tcp_recv_buffer_size); http.set_happy_eyeballs_timeout(config.tcp_happy_eyeballs_timeout); + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] http.set_tcp_user_timeout(config.tcp_user_timeout); + + #[cfg(any( + target_os = "android", + target_os = "fuchsia", + target_os = "illumos", + target_os = "ios", + target_os = "linux", + target_os = "macos", + target_os = "solaris", + target_os = "tvos", + target_os = "visionos", + target_os = "watchos", + ))] + if let Some(interface) = config.socket_bind_options.interface { + http.set_interface(interface); + } + + http.set_local_addresses( + config.socket_bind_options.local_address_ipv4, + config.socket_bind_options.local_address_ipv6, + ); }) - .build(config.connector_layers)?; + .build(tls_options, config.connector_layers)?; // Build client HttpClient::builder(TokioExecutor::new()) @@ -589,7 +609,10 @@ impl ClientBuilder { .service(service); let service = ServiceBuilder::new() - .layer(ResponseBodyTimeoutLayer::new(config.timeout_options)) + .layer(ResponseBodyTimeoutLayer::new( + TokioTimer::new(), + config.timeout_options, + )) .layer(ConfigServiceLayer::new( config.https_only, config.headers, @@ -1244,7 +1267,7 @@ impl ClientBuilder { T: Into>, { self.config - .tcp_connect_options + .socket_bind_options .set_local_address(addr.into()); self } @@ -1270,7 +1293,7 @@ impl ClientBuilder { V6: Into>, { self.config - .tcp_connect_options + .socket_bind_options .set_local_addresses(ipv4, ipv6); self } @@ -1339,7 +1362,7 @@ impl ClientBuilder { where T: Into>, { - self.config.tcp_connect_options.set_interface(interface); + self.config.socket_bind_options.set_interface(interface); self } diff --git a/src/client/http/client.rs b/src/client/http/client.rs index a2d2b8ea..511b2ce7 100644 --- a/src/client/http/client.rs +++ b/src/client/http/client.rs @@ -41,7 +41,7 @@ use crate::{ dispatch::TrySendError as ConnTrySendError, http1::Http1Options, http2::Http2Options, - rt::{ArcTimer, Executor, Timer}, + rt::{Executor, Time, Timer}, }, layer::config::RequestOptions, }, @@ -183,7 +183,7 @@ where let options = RequestConfig::::remove(req.extensions_mut()); // Apply HTTP/1 and HTTP/2 options if provided - if let Some(opts) = options.as_ref().map(RequestOptions::transport_opts) { + if let Some(opts) = options.as_ref().map(RequestOptions::transport_options) { if let Some(opts) = opts.http1_options() { this.h1_builder.options(opts.clone()); } @@ -818,7 +818,7 @@ pub struct Builder { h1_builder: conn::http1::Builder, h2_builder: conn::http2::Builder, pool_config: pool::Config, - pool_timer: Option, + pool_timer: Time, } // ===== impl Builder ===== @@ -845,7 +845,7 @@ impl Builder { max_idle_per_host: usize::MAX, max_pool_size: None, }, - pool_timer: None, + pool_timer: Time::Empty, } } /// Set an optional timeout for idle sockets being kept-alive. @@ -943,7 +943,7 @@ impl Builder { where M: Timer + Clone + Send + Sync + 'static, { - self.pool_timer = Some(ArcTimer::new(timer)); + self.pool_timer = Time::Timer(Arc::new(timer)); self } diff --git a/src/client/http/client/extra.rs b/src/client/http/client/extra.rs index cf08b4ce..5ab47e6e 100644 --- a/src/client/http/client/extra.rs +++ b/src/client/http/client/extra.rs @@ -4,7 +4,7 @@ use http::{Uri, Version}; use crate::{ client::{ - conn::TcpConnectOptions, + conn::SocketBindOptions, layer::config::{RequestOptions, TransportOptions}, }, hash::HashMemo, @@ -65,13 +65,13 @@ impl ConnectExtra { pub fn tls_options(&self) -> Option<&TlsOptions> { self.extra .as_ref() - .map(RequestOptions::transport_opts) + .map(RequestOptions::transport_options) .and_then(TransportOptions::tls_options) } - /// Return a reference to the [`TcpConnectOptions`]. + /// Return a reference to the [`SocketBindOptions`]. #[inline] - pub fn tcp_options(&self) -> Option<&TcpConnectOptions> { - self.extra.as_ref().map(RequestOptions::tcp_connect_opts) + pub fn tcp_options(&self) -> Option<&SocketBindOptions> { + self.extra.as_ref().map(RequestOptions::socket_bind_options) } } diff --git a/src/client/http/client/pool.rs b/src/client/http/client/pool.rs index 16cef920..6fc30aed 100644 --- a/src/client/http/client/pool.rs +++ b/src/client/http/client/pool.rs @@ -18,7 +18,7 @@ use lru::LruCache; use super::exec::{self, Exec}; use crate::{ - client::core::rt::{ArcTimer, Executor, Timer}, + client::core::rt::{Executor, Time, Timer}, sync::Mutex, }; @@ -94,7 +94,7 @@ struct PoolInner { // the Pool completely drops. That way, the interval can cancel immediately. idle_interval_ref: Option>, exec: Exec, - timer: Option, + timer: Time, timeout: Option, } @@ -116,10 +116,9 @@ impl Config { } impl Pool { - pub fn new(config: Config, executor: E, timer: Option) -> Pool + pub fn new(config: Config, executor: E, timer: Time) -> Pool where E: Executor + Send + Sync + Clone + 'static, - M: Timer + Send + Sync + Clone + 'static, { let inner = if config.is_enabled() { Some(Arc::new(Mutex::new(PoolInner { @@ -131,7 +130,7 @@ impl Pool { max_idle_per_host: config.max_idle_per_host, waiters: HashMap::default(), exec: Exec::new(executor), - timer: timer.map(ArcTimer::new), + timer, timeout: config.idle_timeout, }))) } else { @@ -302,7 +301,7 @@ impl<'a, T: Poolable + 'a, K: Debug> IdlePopper<'a, T, K> { impl PoolInner { fn now(&self) -> Instant { - self.timer.as_ref().map_or_else(Instant::now, ArcTimer::now) + self.timer.now() } fn put(&mut self, key: &K, value: T, __pool_ref: &Arc>>) { @@ -399,11 +398,9 @@ impl PoolInner { return; } - let timer = if let Some(timer) = self.timer.clone() { - timer - } else { + if matches!(self.timer, Time::Empty) { return; - }; + } // While someone might want a shorter duration, and it will be respected // at checkout time, there's no need to wake up and proactively evict @@ -422,7 +419,7 @@ impl PoolInner { self.idle_interval_ref = Some(tx); let interval = IdleTask { - timer: timer.clone(), + timer: self.timer.clone(), duration: dur, pool: WeakOpt::downgrade(pool_ref), pool_drop_notifier: rx, @@ -750,7 +747,7 @@ impl Expiration { } struct IdleTask { - timer: ArcTimer, + timer: Time, duration: Duration, pool: WeakOpt>>, // This allows the IdleTask to be notified as soon as the entire @@ -811,13 +808,14 @@ mod tests { hash::Hash, num::NonZero, pin::Pin, + sync::Arc, task::{self, Poll}, time::Duration, }; use super::{Connecting, Key, Pool, Poolable, Reservation, WeakOpt}; use crate::{ - client::core::rt::{ArcTimer, TokioExecutor, TokioTimer}, + client::core::rt::{Time, TokioExecutor, TokioTimer}, sync::MutexGuard, }; @@ -865,7 +863,7 @@ mod tests { max_pool_size: None, }, TokioExecutor::new(), - Option::::None, + Time::Empty, ) } @@ -979,7 +977,7 @@ mod tests { max_pool_size: None, }, TokioExecutor::new(), - Some(TokioTimer::new()), + Time::Timer(Arc::new(TokioTimer::new())), ); let key = host_key("foo"); @@ -1099,7 +1097,7 @@ mod tests { max_pool_size: Some(NonZero::new(2).expect("max pool size")), }, TokioExecutor::new(), - Option::::None, + Time::Empty, ); let key1 = host_key("foo"); let key2 = host_key("bar"); diff --git a/src/client/layer/config/options.rs b/src/client/layer/config/options.rs index f7b91d45..b697bd28 100644 --- a/src/client/layer/config/options.rs +++ b/src/client/layer/config/options.rs @@ -2,7 +2,7 @@ use http::Version; use crate::{ client::{ - conn::TcpConnectOptions, + conn::SocketBindOptions, core::{http1::Http1Options, http2::Http2Options}, }, proxy::Matcher, @@ -16,7 +16,7 @@ use crate::{ pub struct RequestOptions { proxy_matcher: Option, enforced_version: Option, - tcp_connect_opts: TcpConnectOptions, + socket_bind_options: SocketBindOptions, transport_opts: TransportOptions, } @@ -130,27 +130,27 @@ impl RequestOptions { &mut self.enforced_version } - /// Get a reference to the TCP connection options. + /// Get a reference to the socket bind options. #[inline] - pub fn tcp_connect_opts(&self) -> &TcpConnectOptions { - &self.tcp_connect_opts + pub fn socket_bind_options(&self) -> &SocketBindOptions { + &self.socket_bind_options } - /// Get a mutable reference to the TCP connection options. + /// Get a mutable reference to the socket bind options. #[inline] - pub fn tcp_connect_opts_mut(&mut self) -> &mut TcpConnectOptions { - &mut self.tcp_connect_opts + pub fn socket_bind_options_mut(&mut self) -> &mut SocketBindOptions { + &mut self.socket_bind_options } /// Get a reference to the transport options. #[inline] - pub fn transport_opts(&self) -> &TransportOptions { + pub fn transport_options(&self) -> &TransportOptions { &self.transport_opts } /// Get a mutable reference to the transport options. #[inline] - pub fn transport_opts_mut(&mut self) -> &mut TransportOptions { + pub fn transport_options_mut(&mut self) -> &mut TransportOptions { &mut self.transport_opts } } diff --git a/src/client/layer/cookie.rs b/src/client/layer/cookie.rs index 2bd161d5..148186cf 100644 --- a/src/client/layer/cookie.rs +++ b/src/client/layer/cookie.rs @@ -79,7 +79,7 @@ where impl CookieServiceLayer { /// Create a new [`CookieServiceLayer`]. #[inline(always)] - pub const fn new(store: Option>) -> Self { + pub fn new(store: Option>) -> Self { Self { store: RequestConfig::new(store), } diff --git a/src/client/layer/decoder.rs b/src/client/layer/decoder.rs index 0b521c8d..9c4acd1b 100644 --- a/src/client/layer/decoder.rs +++ b/src/client/layer/decoder.rs @@ -60,7 +60,7 @@ impl_request_config_value!(AcceptEncoding); impl DecompressionLayer { /// Creates a new [`DecompressionLayer`] with the specified [`AcceptEncoding`]. #[inline(always)] - pub const fn new(accept: AcceptEncoding) -> Self { + pub fn new(accept: AcceptEncoding) -> Self { Self { accept } } } diff --git a/src/client/layer/redirect.rs b/src/client/layer/redirect.rs index 11a33ebb..c42b23f7 100644 --- a/src/client/layer/redirect.rs +++ b/src/client/layer/redirect.rs @@ -64,7 +64,7 @@ pub struct FollowRedirectLayer

{ impl

FollowRedirectLayer

{ /// Create a new [`FollowRedirectLayer`] with the given redirection [`Policy`]. #[inline(always)] - pub const fn with_policy(policy: P) -> Self { + pub fn with_policy(policy: P) -> Self { FollowRedirectLayer { policy } } } @@ -95,7 +95,7 @@ where { /// Create a new [`FollowRedirect`] with the given redirection [`Policy`]. #[inline(always)] - pub const fn with_policy(inner: S, policy: P) -> Self { + pub fn with_policy(inner: S, policy: P) -> Self { FollowRedirect { inner, policy } } } @@ -111,7 +111,7 @@ where type Error = S::Error; type Future = ResponseFuture; - #[inline] + #[inline(always)] fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } diff --git a/src/client/layer/timeout.rs b/src/client/layer/timeout.rs index 9a706358..3b94f7b4 100644 --- a/src/client/layer/timeout.rs +++ b/src/client/layer/timeout.rs @@ -4,6 +4,7 @@ mod body; mod future; use std::{ + sync::Arc, task::{Context, Poll}, time::Duration, }; @@ -13,7 +14,11 @@ use tower::{Layer, Service}; pub use self::body::TimeoutBody; use self::future::{ResponseBodyTimeoutFuture, ResponseFuture}; -use crate::{config::RequestConfig, error::BoxError}; +use crate::{ + client::core::rt::{Time, Timer}, + config::RequestConfig, + error::BoxError, +}; /// Options for configuring timeouts. #[derive(Clone, Copy, Default)] @@ -49,8 +54,7 @@ pub struct TimeoutLayer { impl TimeoutLayer { /// Create a new [`TimeoutLayer`]. - #[inline(always)] - pub const fn new(options: TimeoutOptions) -> Self { + pub fn new(options: TimeoutOptions) -> Self { TimeoutLayer { timeout: RequestConfig::new(Some(options)), } @@ -104,14 +108,18 @@ where // This layer allows you to set a total timeout and a read timeout for the response body. #[derive(Clone)] pub struct ResponseBodyTimeoutLayer { + timer: Time, timeout: RequestConfig, } impl ResponseBodyTimeoutLayer { /// Creates a new [`ResponseBodyTimeoutLayer`]. - #[inline(always)] - pub const fn new(options: TimeoutOptions) -> Self { + pub fn new(timer: M, options: TimeoutOptions) -> Self + where + M: Timer + Send + Sync + 'static, + { Self { + timer: Time::Timer(Arc::new(timer)), timeout: RequestConfig::new(Some(options)), } } @@ -125,6 +133,7 @@ impl Layer for ResponseBodyTimeoutLayer { ResponseBodyTimeout { inner, timeout: self.timeout, + timer: self.timer.clone(), } } } @@ -135,6 +144,7 @@ impl Layer for ResponseBodyTimeoutLayer { pub struct ResponseBodyTimeout { inner: S, timeout: RequestConfig, + timer: Time, } impl Service> for ResponseBodyTimeout @@ -157,6 +167,7 @@ where inner: self.inner.call(req), total_timeout, read_timeout, + timer: self.timer.clone(), } } } diff --git a/src/client/layer/timeout/body.rs b/src/client/layer/timeout/body.rs index c7952086..8188eb43 100644 --- a/src/client/layer/timeout/body.rs +++ b/src/client/layer/timeout/body.rs @@ -7,10 +7,10 @@ use std::{ use http_body::Body; use pin_project_lite::pin_project; -use tokio::time::{Sleep, sleep}; use crate::{ Error, + client::core::rt::{Sleep, Time, Timer}, error::{BoxError, TimedOut}, }; @@ -46,7 +46,7 @@ pin_project! { pub struct TotalTimeoutBody { #[pin] body: B, - timeout: Pin>, + timeout: Pin>, } } @@ -58,17 +58,23 @@ pin_project! { pub struct ReadTimeoutBody { timeout: Duration, #[pin] - sleep: Option, + sleep: Option>>, #[pin] body: B, + timer: Time, } } /// ==== impl TimeoutBody ==== impl TimeoutBody { /// Creates a new [`TimeoutBody`] with no timeout. - pub fn new(deadline: Option, read_timeout: Option, body: B) -> Self { - let deadline = deadline.map(sleep).map(Box::pin); + pub fn new( + timer: Time, + deadline: Option, + read_timeout: Option, + body: B, + ) -> Self { + let deadline = deadline.map(|deadline| timer.sleep(deadline)); match (deadline, read_timeout) { (Some(total_timeout), Some(read_timeout)) => TimeoutBody::CombinedTimeout { body: TotalTimeoutBody { @@ -77,6 +83,7 @@ impl TimeoutBody { timeout: read_timeout, sleep: None, body, + timer, }, }, }, @@ -88,6 +95,7 @@ impl TimeoutBody { timeout, sleep: None, body, + timer, }, }, (None, None) => TimeoutBody::Plain { body }, @@ -199,7 +207,7 @@ where // Error if the timeout has expired. if this.sleep.is_none() { - this.sleep.set(Some(sleep(*this.timeout))); + this.sleep.set(Some(this.timer.sleep(*this.timeout))); } // Error if the timeout has expired. diff --git a/src/client/layer/timeout/future.rs b/src/client/layer/timeout/future.rs index 5800a9b1..faf7743f 100644 --- a/src/client/layer/timeout/future.rs +++ b/src/client/layer/timeout/future.rs @@ -10,7 +10,10 @@ use pin_project_lite::pin_project; use tokio::time::Sleep; use super::body::TimeoutBody; -use crate::error::{BoxError, Error, TimedOut}; +use crate::{ + client::core::rt::Time, + error::{BoxError, Error, TimedOut}, +}; pin_project! { /// [`Timeout`] response future @@ -71,6 +74,7 @@ pin_project! { pub(super) inner: Fut, pub(super) total_timeout: Option, pub(super) read_timeout: Option, + pub(super) timer: Time, } } @@ -81,10 +85,11 @@ where type Output = Result>, E>; fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let timer = self.timer.clone(); let total_timeout = self.total_timeout; let read_timeout = self.read_timeout; let res = ready!(self.project().inner.poll(cx))? - .map(|body| TimeoutBody::new(total_timeout, read_timeout, body)); + .map(|body| TimeoutBody::new(timer, total_timeout, read_timeout, body)); Poll::Ready(Ok(res)) } } diff --git a/src/client/request.rs b/src/client/request.rs index 8d452d8d..04573794 100644 --- a/src/client/request.rs +++ b/src/client/request.rs @@ -632,7 +632,7 @@ impl RequestBuilder { if let Ok(ref mut req) = self.request { req.config_mut::() .get_or_insert_default() - .tcp_connect_opts_mut() + .socket_bind_options_mut() .set_local_address(local_address.into()); } self @@ -647,7 +647,7 @@ impl RequestBuilder { if let Ok(ref mut req) = self.request { req.config_mut::() .get_or_insert_default() - .tcp_connect_opts_mut() + .socket_bind_options_mut() .set_local_addresses(ipv4, ipv6); } self @@ -719,7 +719,7 @@ impl RequestBuilder { if let Ok(ref mut req) = self.request { req.config_mut::() .get_or_insert_default() - .tcp_connect_opts_mut() + .socket_bind_options_mut() .set_interface(interface); } self @@ -735,7 +735,7 @@ impl RequestBuilder { let (transport_opts, default_headers, orig_headers) = emulation.into_parts(); req.config_mut::() .get_or_insert_default() - .transport_opts_mut() + .transport_options_mut() .apply_transport_options(transport_opts); self = self.headers(default_headers).orig_headers(orig_headers); } diff --git a/src/tls/conn.rs b/src/tls/conn.rs index eafc6737..db1aa1d6 100644 --- a/src/tls/conn.rs +++ b/src/tls/conn.rs @@ -62,14 +62,7 @@ pub struct HandshakeSettings { #[derive(Clone)] pub struct HttpsConnector { http: T, - inner: Inner, -} - -#[derive(Clone)] -struct Inner { - ssl: SslConnector, - cache: Option>>>, - settings: HandshakeSettings, + tls: TlsConnector, } /// A builder for creating a `TlsConnector`. @@ -90,7 +83,9 @@ pub struct TlsConnectorBuilder { /// A layer which wraps services in an `SslConnector`. #[derive(Clone)] pub struct TlsConnector { - inner: Inner, + ssl: SslConnector, + cache: Option>>>, + settings: HandshakeSettings, } // ===== impl HttpsConnector ===== @@ -104,24 +99,40 @@ where { /// Creates a new [`HttpsConnector`] with a given [`TlsConnector`]. #[inline] - pub fn with_connector(http: S, connector: TlsConnector) -> HttpsConnector { - HttpsConnector { - http, - inner: connector.inner, - } + pub fn new(http: S, tls: TlsConnector) -> HttpsConnector { + HttpsConnector { http, tls } } /// Disables ALPN negotiation. #[inline] pub fn no_alpn(&mut self) -> &mut Self { - self.inner.settings.alpn_protocols = None; + self.tls.settings.alpn_protocols = None; self } } -// ===== impl Inner ===== +// ===== impl TlsConnector ===== + +impl TlsConnector { + /// Creates a new [`TlsConnectorBuilder`] with the given configuration. + pub fn builder() -> TlsConnectorBuilder { + const DEFAULT_SESSION_CACHE_CAPACITY: usize = 8; + TlsConnectorBuilder { + session_cache: Arc::new(Mutex::new(SessionCache::with_capacity( + DEFAULT_SESSION_CACHE_CAPACITY, + ))), + alpn_protocol: None, + min_version: None, + max_version: None, + identity: None, + tls_sni: true, + verify_hostname: true, + cert_store: None, + cert_verification: true, + keylog: None, + } + } -impl Inner { fn setup_ssl(&self, uri: Uri) -> Result { let cfg = self.ssl.configure()?; let host = uri.host().ok_or("URI missing host")?; @@ -299,7 +310,12 @@ impl TlsConnectorBuilder { } /// Build the `TlsConnector` with the provided configuration. - pub fn build(&self, opts: &TlsOptions) -> crate::Result { + pub fn build<'a, T>(&self, opts: T) -> crate::Result + where + T: Into>, + { + let opts = opts.into(); + // Replace the default configuration with the provided one let max_tls_version = opts.max_tls_version.or(self.max_version); let min_tls_version = opts.min_tls_version.or(self.min_version); @@ -449,38 +465,13 @@ impl TlsConnectorBuilder { }); Ok(TlsConnector { - inner: Inner { - ssl: connector.build(), - cache, - settings, - }, + ssl: connector.build(), + cache, + settings, }) } } -// ===== impl TlsConnector ===== - -impl TlsConnector { - /// Creates a new `TlsConnectorBuilder` with the given configuration. - pub fn builder() -> TlsConnectorBuilder { - const DEFAULT_SESSION_CACHE_CAPACITY: usize = 8; - TlsConnectorBuilder { - session_cache: Arc::new(Mutex::new(SessionCache::with_capacity( - DEFAULT_SESSION_CACHE_CAPACITY, - ))), - alpn_protocol: None, - min_version: None, - max_version: None, - identity: None, - cert_store: None, - cert_verification: true, - tls_sni: true, - verify_hostname: true, - keylog: None, - } - } -} - /// A stream which may be wrapped with TLS. pub enum MaybeHttpsStream { /// A raw HTTP stream. diff --git a/src/tls/conn/service.rs b/src/tls/conn/service.rs index 18173362..7d011ab5 100644 --- a/src/tls/conn/service.rs +++ b/src/tls/conn/service.rs @@ -46,7 +46,7 @@ where fn call(&mut self, uri: Uri) -> Self::Future { let connect = self.http.call(uri.clone()); - let inner = self.inner.clone(); + let tls = self.tls.clone(); let f = async move { let conn = connect.await.map_err(Into::into)?; @@ -56,7 +56,7 @@ where return Ok(MaybeHttpsStream::Http(conn)); } - let ssl = inner.setup_ssl(uri)?; + let ssl = tls.setup_ssl(uri)?; perform_handshake(ssl, conn).await }; @@ -83,7 +83,7 @@ where fn call(&mut self, req: ConnectRequest) -> Self::Future { let uri = req.uri().clone(); let connect = self.http.call(uri.clone()); - let inner = self.inner.clone(); + let tls = self.tls.clone(); let f = async move { let conn = connect.await.map_err(Into::into)?; @@ -93,7 +93,7 @@ where return Ok(MaybeHttpsStream::Http(conn)); } - let ssl = inner.setup_ssl2(req)?; + let ssl = tls.setup_ssl2(req)?; perform_handshake(ssl, conn).await }; @@ -119,14 +119,14 @@ where } fn call(&mut self, conn: EstablishedConn) -> Self::Future { - let inner = self.inner.clone(); + let tls = self.tls.clone(); let fut = async move { // Early return if it is not a tls scheme if conn.req.uri().is_http() { return Ok(MaybeHttpsStream::Http(conn.io)); } - let ssl = inner.setup_ssl2(conn.req)?; + let ssl = tls.setup_ssl2(conn.req)?; perform_handshake(ssl, conn.io).await };