diff --git a/crates/anemo-tls/Cargo.toml b/crates/anemo-tls/Cargo.toml new file mode 100644 index 00000000..04bd1fb8 --- /dev/null +++ b/crates/anemo-tls/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "anemo-tls" +version = "0.0.0" +license = "Apache-2.0 OR MIT" +authors = [ "Andrew Schran " ] +edition = "2021" + +[dependencies] +anyhow = "1.0.56" +futures = "0.3.21" +futures-util = "0.3.21" +rustls = { version = "0.21.2", features = ["dangerous_configuration"] } +tokio = { version = "1.17.0", features = ["sync", "rt", "macros", "io-util"] } +tokio-rustls = { version = "0.24", features = ["dangerous_configuration"] } +tokio-util = { version = "0.7.4", features = ["compat"] } +tracing = "0.1.32" +yamux = { git = "https://github.com/aschran/rust-yamux.git", rev = "ed1bd31a75305ca7f13bdc0ccd309905fa2a42c0" } diff --git a/crates/anemo-tls/src/lib.rs b/crates/anemo-tls/src/lib.rs new file mode 100644 index 00000000..d3860761 --- /dev/null +++ b/crates/anemo-tls/src/lib.rs @@ -0,0 +1,351 @@ +use anyhow::Result; +use futures_util::{ + io::{ReadHalf, WriteHalf}, + StreamExt, +}; +use std::{ + net::SocketAddr, + pin::Pin, + sync::Arc, + task::{Context, Poll}, +}; +use tokio::{ + net::{TcpListener, TcpSocket, TcpStream}, + sync::{mpsc, Mutex}, +}; +use tokio_rustls::{TlsAcceptor, TlsStream}; +use tokio_util::compat::{ + Compat, FuturesAsyncReadCompatExt, FuturesAsyncWriteCompatExt, TokioAsyncReadCompatExt, +}; + +/// Configuration for outbound connections. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct ClientConfig { + pub(crate) tls: Arc, + pub(crate) socket_send_buffer_size: Option, + pub(crate) socket_receive_buffer_size: Option, + pub(crate) allow_failed_socket_buffer_size_setting: bool, +} + +impl ClientConfig { + pub fn new( + tls: tokio_rustls::rustls::ClientConfig, + socket_send_buffer_size: Option, + socket_receive_buffer_size: Option, + allow_failed_socket_buffer_size_setting: bool, + ) -> Self { + Self { + tls: Arc::new(tls), + socket_send_buffer_size, + socket_receive_buffer_size, + allow_failed_socket_buffer_size_setting, + } + } +} + +/// Configuration for inbound connections. +#[derive(Debug, Clone)] +#[non_exhaustive] +pub struct ServerConfig { + /// Transport configuration to use + pub(crate) tls: Arc, +} + +impl ServerConfig { + pub fn new(tls: tokio_rustls::rustls::ServerConfig) -> Self { + Self { tls: Arc::new(tls) } + } +} + +/// A TLS connection. +/// +/// May be cloned to obtain another handle to the same connection. +#[derive(Clone)] +pub struct Connection(ConnectionRef); + +impl Connection { + fn new(stream: TlsStream, peer_address: SocketAddr, mode: yamux::Mode) -> Self { + let (_, state) = stream.get_ref(); + let peer_identity = state.peer_certificates().map(|certs| certs[0].to_owned()); + + let (control, connection) = yamux::Control::new(yamux::Connection::new( + stream.compat(), + yamux::Config::default(), + mode, + )); + + // Weird quirk alert: + // yamux requires us to constantly drive the ControlledConnection or else new *outbound* + // streams will not be started, even if we never intend to accept inbound streams. + let (tx, rx) = mpsc::channel(1); + tokio::spawn(Self::yield_streams(connection, tx)); + + Self(ConnectionRef(Arc::new(ConnectionInner { + state: Mutex::new(ConnectionInnerState { rx_streams: rx }), + control, + peer_address, + peer_identity, + }))) + } + + async fn yield_streams( + mut connection: yamux::ControlledConnection>>, + tx: mpsc::Sender, + ) { + while let Some(stream) = connection.next().await { + match stream { + Ok(stream) => { + if tx.send(stream).await.is_err() { + // The receiver is gone, so we can stop. + break; + } + } + Err(e) => { + tracing::trace!("yamux stream error: {e}"); + break; + } + } + } + } + + pub fn peer_identity(&self) -> Option<&rustls::Certificate> { + self.0 .0.peer_identity.as_ref() + } + + pub fn stable_id(&self) -> usize { + &*self.0 .0 as *const _ as usize + } + + pub fn peer_address(&self) -> SocketAddr { + self.0 .0.peer_address + } + + pub async fn open_stream(&self) -> Result<(SendStream, RecvStream)> { + let stream = self.0 .0.control.clone().open_stream().await?; + let display_str = stream.to_string(); + let (read, write) = futures_util::AsyncReadExt::split(stream); + Ok(( + SendStream { + stream: write.compat_write(), + display_str: display_str.to_owned(), + }, + RecvStream { + stream: read.compat(), + display_str, + }, + )) + } + + pub async fn accept_stream(&self) -> Result<(SendStream, RecvStream)> { + let stream = self + .0 + .0 + .state + .lock() + .await + .rx_streams + .recv() + .await + .ok_or(anyhow::anyhow!("connection closed"))?; + let display_str = stream.to_string(); + let (read, write) = futures_util::AsyncReadExt::split(stream); + Ok(( + SendStream { + stream: write.compat_write(), + display_str: display_str.to_owned(), + }, + RecvStream { + stream: read.compat(), + display_str, + }, + )) + } +} + +#[derive(Clone)] +pub(crate) struct ConnectionRef(Arc); + +pub(crate) struct ConnectionInner { + state: Mutex, + control: yamux::Control, + peer_address: SocketAddr, + peer_identity: Option, +} + +pub(crate) struct ConnectionInnerState { + rx_streams: mpsc::Receiver, +} + +pub struct SendStream { + stream: Compat>, + display_str: String, +} + +impl std::fmt::Display for SendStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.display_str) + } +} + +impl std::ops::Deref for SendStream { + type Target = Compat>; + + fn deref(&self) -> &Self::Target { + &self.stream + } +} + +impl std::ops::DerefMut for SendStream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.stream + } +} + +impl tokio::io::AsyncWrite for SendStream { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &[u8], + ) -> Poll> { + Pin::new(&mut self.stream).poll_write(cx, buf) + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.stream).poll_flush(cx) + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + Pin::new(&mut self.stream).poll_shutdown(cx) + } +} + +pub struct RecvStream { + stream: Compat>, + display_str: String, +} + +impl std::fmt::Display for RecvStream { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.display_str) + } +} + +impl std::ops::Deref for RecvStream { + type Target = Compat>; + + fn deref(&self) -> &Self::Target { + &self.stream + } +} + +impl std::ops::DerefMut for RecvStream { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.stream + } +} + +impl tokio::io::AsyncRead for RecvStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + Pin::new(&mut self.stream).poll_read(cx, buf) + } +} + +impl futures::AsyncRead for RecvStream { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + Pin::new(self.stream.get_mut()).poll_read(cx, buf) + } +} + +/// A TLS endpoint. +/// +/// An endpoint may host many connections, and may act as both client and server for different +/// connections. +/// +/// May be cloned to obtain another handle to the same endpoint. +#[derive(Clone)] +pub struct Endpoint { + pub(crate) inner: EndpointRef, +} + +impl Endpoint { + pub fn new(config: ServerConfig, listener: TcpListener) -> Self { + let acceptor = TlsAcceptor::from(config.tls); + Self { + inner: EndpointRef(Arc::new(EndpointInner { listener, acceptor })), + } + } + + /// Connect to a remote endpoint. + pub async fn connect_with( + &self, + config: ClientConfig, + addr: std::net::SocketAddr, + server_name: &str, + ) -> std::io::Result { + let parsed_server_name = rustls::ServerName::try_from(server_name).map_err(|_| { + std::io::Error::new( + std::io::ErrorKind::InvalidInput, + format!("invalid server_name: {server_name}"), + ) + })?; + + let connector = tokio_rustls::TlsConnector::from(config.tls.clone()); + + let socket = if addr.is_ipv4() { + TcpSocket::new_v4()? + } else { + TcpSocket::new_v6()? + }; + if let Some(size) = config.socket_send_buffer_size { + let result = socket.set_send_buffer_size(size as u32); + if !config.allow_failed_socket_buffer_size_setting { + result? + } + } + if let Some(size) = config.socket_receive_buffer_size { + let result = socket.set_recv_buffer_size(size as u32); + if !config.allow_failed_socket_buffer_size_setting { + result? + } + } + let stream = socket.connect(addr).await?; + + let stream = connector.connect(parsed_server_name, stream).await?; + Ok(Connection::new( + TlsStream::Client(stream), + addr, + yamux::Mode::Client, + )) + } + + pub async fn accept(&self) -> Result { + let (stream, peer_address) = self.inner.0.listener.accept().await?; + // TODO: This will drop/lose the TCP connection if the future is dropped before + // completion because ConnectionManager select loop finishes something else first. + // Ideally the connection should somehow be saved here, but it's not a huge deal + // since clients can just retry connecting. + let stream = self.inner.0.acceptor.accept(stream).await?; + Ok(Connection::new( + TlsStream::Server(stream), + peer_address, + yamux::Mode::Server, + )) + } +} + +#[derive(Clone)] +pub(crate) struct EndpointRef(Arc); + +pub(crate) struct EndpointInner { + pub(crate) listener: TcpListener, + pub(crate) acceptor: TlsAcceptor, +} diff --git a/crates/anemo/Cargo.toml b/crates/anemo/Cargo.toml index a36da897..8cf47db5 100644 --- a/crates/anemo/Cargo.toml +++ b/crates/anemo/Cargo.toml @@ -22,11 +22,11 @@ quinn-proto = "^0.10.0" rand = "0.8.5" ring = "0.16.7" rcgen = "0.9.2" -rustls = { version = "0.21.0", features = ["dangerous_configuration"] } +rustls = { version = "0.21.2", features = ["dangerous_configuration"] } serde = { version = "1.0.136", features = ["derive"] } serde_json = "1.0.83" tokio = { version = "1.17.0", features = ["sync", "rt", "macros", "io-util"] } -tokio-util = { version = "0.7.1", features = ["codec"] } +tokio-util = { version = "0.7.4", features = ["codec"] } tower = { version = "0.4.12", default-features = false, features = ["full"] } tracing = "0.1.32" webpki = { version = "0.22.0", features = ["alloc", "std"] } @@ -36,6 +36,9 @@ tap = "1.0.1" thiserror = "1.0.24" socket2 = "0.5.2" +anemo-tls = { path = "../anemo-tls" } + [dev-dependencies] +rstest = "0.17.0" tracing-subscriber = { version = "0.3.11", features = ["env-filter"] } tokio = { version = "1.17.0", features = ["full", "test-util"] } diff --git a/crates/anemo/src/config.rs b/crates/anemo/src/config.rs index a2f01c4d..6733199b 100644 --- a/crates/anemo/src/config.rs +++ b/crates/anemo/src/config.rs @@ -2,6 +2,7 @@ use crate::{ crypto::{CertVerifier, ExpectedCertVerifier}, PeerId, Result, }; +use anyhow::anyhow; use pkcs8::EncodePrivateKey; use quinn::VarInt; use rcgen::{CertificateParams, KeyPair, SignatureAlgorithm}; @@ -13,9 +14,11 @@ use std::{sync::Arc, time::Duration}; #[serde(rename_all = "kebab-case")] #[non_exhaustive] pub struct Config { - /// Configuration for the underlying QUIC transport. + /// Configuration for transport layer. + /// + /// If unspecified, uses default QUIC configuration. #[serde(skip_serializing_if = "Option::is_none")] - pub quic: Option, + pub transport: Option, /// Size of the internal `ConnectionManager`s mailbox. /// @@ -121,7 +124,42 @@ pub struct Config { pub shutdown_idle_timeout_ms: Option, } -/// Configuration for the underlying QUIC transport. +/// Configuration for the transport layer. +#[derive(Clone, Debug, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +pub enum TransportConfig { + Quic(QuicConfig), + Tls(TlsConfig), +} + +impl Default for TransportConfig { + fn default() -> Self { + Self::Quic(QuicConfig::default()) + } +} + +impl TransportConfig { + pub fn socket_send_buffer_size(&self) -> Option { + match self { + TransportConfig::Quic(config) => config.socket_send_buffer_size, + TransportConfig::Tls(config) => config.socket_send_buffer_size, + } + } + pub fn socket_receive_buffer_size(&self) -> Option { + match self { + TransportConfig::Quic(config) => config.socket_receive_buffer_size, + TransportConfig::Tls(config) => config.socket_receive_buffer_size, + } + } + pub fn allow_failed_socket_buffer_size_setting(&self) -> bool { + match self { + TransportConfig::Quic(config) => config.allow_failed_socket_buffer_size_setting, + TransportConfig::Tls(config) => config.allow_failed_socket_buffer_size_setting, + } + } +} + +/// Configuration for QUIC transport. #[derive(Clone, Debug, Default, Serialize, Deserialize)] #[serde(rename_all = "kebab-case")] #[non_exhaustive] @@ -205,14 +243,30 @@ pub struct QuicConfig { pub allow_failed_socket_buffer_size_setting: bool, } -impl Config { - pub(crate) fn transport_config(&self) -> quinn::TransportConfig { - self.quic - .as_ref() - .map(QuicConfig::transport_config) - .unwrap_or_default() - } +/// Configuration for TLS transport. +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +#[serde(rename_all = "kebab-case")] +#[non_exhaustive] +pub struct TlsConfig { + /// Size of the send buffer on the TCP socket (`SO_SNDBUF`). + /// + /// If unspecified, this will use the operating system default. + #[serde(skip_serializing_if = "Option::is_none")] + pub socket_send_buffer_size: Option, + + /// Size of the receive buffer on the TCP socket (`SO_RCVBUF`). + /// + /// If unspecified, this will use the operating system default. + #[serde(skip_serializing_if = "Option::is_none")] + pub socket_receive_buffer_size: Option, + + /// If true, failure to set UDP socket buffer sizes as requested above will not + /// prevent a Network from starting. + #[serde(default)] + pub allow_failed_socket_buffer_size_setting: bool, +} +impl Config { pub(crate) fn connection_manager_channel_capacity(&self) -> usize { const CONNECTION_MANAGER_CHANNEL_CAPACITY: usize = 128; @@ -360,7 +414,7 @@ pub(crate) struct EndpointConfigBuilder { /// initiating outbound connections. pub alternate_server_name: Option, - pub transport_config: Option, + pub transport_config: Option, } impl EndpointConfigBuilder { @@ -378,8 +432,13 @@ impl EndpointConfigBuilder { self } - pub fn transport_config(mut self, transport_config: quinn::TransportConfig) -> Self { - self.transport_config = Some(transport_config); + pub fn transport_config(mut self, transport_config: &TransportConfig) -> Self { + self.transport_config = Some(match transport_config { + TransportConfig::Quic(quic) => { + InnerTransportConfig::Quic(Arc::new(quic.transport_config())) + } + TransportConfig::Tls(tls) => InnerTransportConfig::Tls(tls.to_owned()), + }); self } @@ -407,10 +466,13 @@ impl EndpointConfigBuilder { // Derive our quic reset key from our private key using an HKDF let reset_key = crate::crypto::construct_reset_key(&keypair.secret_key); - let quinn_endpoint_config = quinn::EndpointConfig::new(Arc::new(reset_key)); + let transport_endpoint_config = + TransportEndpointConfig::Quic(quinn::EndpointConfig::new(Arc::new(reset_key))); let primary_server_name = self.server_name.unwrap(); - let transport_config = Arc::new(self.transport_config.unwrap_or_default()); + let transport_config = self.transport_config.unwrap_or_else(|| { + InnerTransportConfig::Quic(Arc::new(quinn::TransportConfig::default())) + }); let cert_verifier = Arc::new(CertVerifier { server_names: vec![primary_server_name.clone()], @@ -423,7 +485,7 @@ impl EndpointConfigBuilder { primary_certificate.clone(), pkcs8_der.clone(), cert_verifier.clone(), - transport_config.clone(), + &transport_config, )?; let alternate_server_name = self.alternate_server_name; @@ -441,14 +503,14 @@ impl EndpointConfigBuilder { ], pkcs8_der.clone(), cert_verifier, - transport_config.clone(), + &transport_config, ) } _ => Self::server_config( vec![(primary_server_name.clone(), primary_certificate.clone())], pkcs8_der.clone(), cert_verifier, - transport_config.clone(), + &transport_config, ), }?; @@ -458,11 +520,11 @@ impl EndpointConfigBuilder { peer_id, client_certificate: primary_certificate, pkcs8_der, - quinn_server_config: server_config, - quinn_client_config: client_config, + server_config, + client_config, server_name: primary_server_name, transport_config, - quinn_endpoint_config, + transport_endpoint_config, }) } @@ -481,8 +543,8 @@ impl EndpointConfigBuilder { certs: Vec<(String, rustls::Certificate)>, pkcs8_der: rustls::PrivateKey, cert_verifier: Arc, - transport_config: Arc, - ) -> Result { + transport_config: &InnerTransportConfig, + ) -> Result { let mut server_cert_resolver = rustls::server::ResolvesServerCertUsingSni::new(); let key = rustls::sign::any_supported_type(&pkcs8_der) .map_err(|_| anyhow::anyhow!("invalid private key"))?; @@ -496,25 +558,120 @@ impl EndpointConfigBuilder { .with_client_cert_verifier(cert_verifier) .with_cert_resolver(Arc::new(server_cert_resolver)); - let mut server = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); - server.transport = transport_config; - Ok(server) + match transport_config { + InnerTransportConfig::Quic(transport_config) => { + let mut server = quinn::ServerConfig::with_crypto(Arc::new(server_crypto)); + server.transport = transport_config.clone(); + Ok(ServerConfig::Quic(server)) + } + InnerTransportConfig::Tls(_) => Ok(ServerConfig::Tls(anemo_tls::ServerConfig::new( + server_crypto, + ))), + } } fn client_config( cert: rustls::Certificate, pkcs8_der: rustls::PrivateKey, cert_verifier: Arc, - transport_config: Arc, - ) -> Result { + transport_config: &InnerTransportConfig, + ) -> Result { let client_crypto = rustls::ClientConfig::builder() .with_safe_defaults() .with_custom_certificate_verifier(cert_verifier) .with_single_cert(vec![cert], pkcs8_der)?; - let mut client = quinn::ClientConfig::new(Arc::new(client_crypto)); - client.transport_config(transport_config); - Ok(client) + match transport_config { + InnerTransportConfig::Quic(quic_config) => { + let mut client = quinn::ClientConfig::new(Arc::new(client_crypto)); + client.transport_config(quic_config.clone()); + Ok(ClientConfig::Quic(client)) + } + InnerTransportConfig::Tls(tls_config) => { + Ok(ClientConfig::Tls(anemo_tls::ClientConfig::new( + client_crypto, + tls_config.socket_send_buffer_size, + tls_config.socket_receive_buffer_size, + tls_config.allow_failed_socket_buffer_size_setting, + ))) + } + } + } +} + +#[derive(Clone, Debug)] +pub(crate) enum ServerConfig { + Quic(quinn::ServerConfig), + Tls(anemo_tls::ServerConfig), +} + +impl ServerConfig { + pub fn try_quic(&self) -> Result<&quinn::ServerConfig> { + match self { + ServerConfig::Quic(config) => Ok(config), + _ => Err(anyhow!("called try_quic on a non-Quic ServerConfig")), + } + } + pub fn try_tls(&self) -> Result<&anemo_tls::ServerConfig> { + match self { + ServerConfig::Tls(config) => Ok(config), + _ => Err(anyhow!("called try_tls on a non-Tls ServerConfig")), + } + } +} + +#[derive(Clone, Debug)] +pub(crate) enum ClientConfig { + Quic(quinn::ClientConfig), + Tls(anemo_tls::ClientConfig), +} + +impl ClientConfig { + pub fn try_quic(&self) -> Result<&quinn::ClientConfig> { + match self { + ClientConfig::Quic(inner) => Ok(inner), + _ => Err(anyhow!("called try_quic on a non-Quic ClientConfig")), + } + } + pub fn try_tls(&self) -> Result<&anemo_tls::ClientConfig> { + match self { + ClientConfig::Tls(inner) => Ok(inner), + _ => Err(anyhow!("called try_quic on a non-Quic ClientConfig")), + } + } +} + +#[derive(Clone, Debug)] +pub(crate) enum InnerTransportConfig { + Quic(Arc), + Tls(TlsConfig), +} + +impl InnerTransportConfig { + pub fn try_quic(&self) -> Result> { + match self { + InnerTransportConfig::Quic(inner) => Ok(inner.clone()), + _ => Err(anyhow!("called try_quic on a non-Quic TransportConfig")), + } + } + pub fn try_tls(&self) -> Result { + match self { + InnerTransportConfig::Tls(inner) => Ok(inner.clone()), + _ => Err(anyhow!("called try_tls on a non-TLS TransportConfig")), + } + } +} + +#[derive(Clone, Debug)] +pub(crate) enum TransportEndpointConfig { + Quic(quinn::EndpointConfig), +} + +impl TransportEndpointConfig { + pub fn as_quic(&self) -> &quinn::EndpointConfig { + match self { + TransportEndpointConfig::Quic(inner) => inner, + } } } @@ -524,16 +681,16 @@ pub(crate) struct EndpointConfig { // Store client certificate for outbound connections initiation client_certificate: rustls::Certificate, pkcs8_der: rustls::PrivateKey, - quinn_server_config: quinn::ServerConfig, - quinn_client_config: quinn::ClientConfig, + server_config: ServerConfig, + client_config: ClientConfig, /// Note that the end-entity certificate must have the /// [Subject Alternative Name](https://tools.ietf.org/html/rfc6125#section-4.1) /// extension to describe, e.g., the valid DNS name. server_name: String, - transport_config: Arc, - quinn_endpoint_config: quinn::EndpointConfig, + transport_config: InnerTransportConfig, + transport_endpoint_config: TransportEndpointConfig, } impl EndpointConfig { @@ -549,22 +706,19 @@ impl EndpointConfig { &self.server_name } - pub fn quinn_endpoint_config(&self) -> quinn::EndpointConfig { - self.quinn_endpoint_config.clone() + pub fn transport_endpoint_config(&self) -> &TransportEndpointConfig { + &self.transport_endpoint_config } - pub fn server_config(&self) -> &quinn::ServerConfig { - &self.quinn_server_config + pub fn server_config(&self) -> &ServerConfig { + &self.server_config } - pub fn client_config(&self) -> &quinn::ClientConfig { - &self.quinn_client_config + pub fn client_config(&self) -> &ClientConfig { + &self.client_config } - pub fn client_config_with_expected_server_identity( - &self, - peer_id: PeerId, - ) -> quinn::ClientConfig { + pub fn client_config_with_expected_server_identity(&self, peer_id: PeerId) -> ClientConfig { let server_cert_verifier = ExpectedCertVerifier( CertVerifier { server_names: vec![self.server_name().into()], @@ -580,16 +734,52 @@ impl EndpointConfig { ) .unwrap(); - let mut client = quinn::ClientConfig::new(Arc::new(client_crypto)); - client.transport_config(self.transport_config.clone()); - client + match self.client_config { + ClientConfig::Quic(_) => { + let mut client = quinn::ClientConfig::new(Arc::new(client_crypto)); + client.transport_config( + self.transport_config + .try_quic() + .expect("config variants must match"), + ); + ClientConfig::Quic(client) + } + ClientConfig::Tls(_) => { + let transport = self + .transport_config + .try_tls() + .expect("config variants must match"); + let client = anemo_tls::ClientConfig::new( + client_crypto, + transport.socket_send_buffer_size, + transport.socket_receive_buffer_size, + transport.allow_failed_socket_buffer_size_setting, + ); + ClientConfig::Tls(client) + } + } + } + + #[cfg(test)] + pub(crate) fn random_quic(server_name: &str) -> Self { + Self::builder() + .random_private_key() + .server_name(server_name) + .build() + .unwrap() } #[cfg(test)] - pub(crate) fn random(server_name: &str) -> Self { + pub(crate) fn random_tls(server_name: &str) -> Self { + let transport_config = TransportConfig::Tls(TlsConfig { + socket_receive_buffer_size: None, + socket_send_buffer_size: None, + allow_failed_socket_buffer_size_setting: true, + }); Self::builder() .random_private_key() .server_name(server_name) + .transport_config(&transport_config) .build() .unwrap() } diff --git a/crates/anemo/src/connection.rs b/crates/anemo/src/connection.rs index 8760b82f..cfc9632e 100644 --- a/crates/anemo/src/connection.rs +++ b/crates/anemo/src/connection.rs @@ -1,5 +1,4 @@ use crate::{ConnectionOrigin, PeerId, Result}; -use quinn::{ConnectionError, RecvStream}; use quinn_proto::ConnectionStats; use std::{ fmt, io, @@ -12,7 +11,7 @@ use tracing::trace; #[derive(Clone)] pub(crate) struct Connection { - inner: quinn::Connection, + inner: ConnectionInner, peer_id: PeerId, origin: ConnectionOrigin, @@ -21,7 +20,7 @@ pub(crate) struct Connection { } impl Connection { - pub fn new(inner: quinn::Connection, origin: ConnectionOrigin) -> Result { + pub fn new(inner: ConnectionInner, origin: ConnectionOrigin) -> Result { let peer_id = Self::try_peer_id(&inner)?; Ok(Self { inner, @@ -32,19 +31,27 @@ impl Connection { } /// Try to query Cryptographic identity of the peer - fn try_peer_id(connection: &quinn::Connection) -> Result { - // Query the certificate chain provided by a [TLS - // Connection](https://docs.rs/rustls/0.20.4/rustls/enum.Connection.html#method.peer_certificates). - // The first cert in the chain is guaranteed to be the peer - let peer_cert = &connection - .peer_identity() - .unwrap() - .downcast::>() - .unwrap()[0]; - - let peer_id = crate::crypto::peer_id_from_certificate(peer_cert)?; - - Ok(peer_id) + fn try_peer_id(connection: &ConnectionInner) -> Result { + match connection { + ConnectionInner::Quic(connection) => { + // Query the certificate chain provided by a [TLS + // Connection](https://docs.rs/rustls/0.20.4/rustls/enum.Connection.html#method.peer_certificates). + // The first cert in the chain is guaranteed to be the peer + let cert = &connection + .peer_identity() + .unwrap() + .downcast::>() + .unwrap()[0]; + crate::crypto::peer_id_from_certificate(cert) + } + ConnectionInner::Tls(connection) => { + let cert = connection.peer_identity().ok_or(anyhow::anyhow!( + "TLS connection does not have a peer identity" + ))?; + crate::crypto::peer_id_from_certificate(cert) + } + } + .map_err(Into::into) } /// PeerId of the Remote Peer @@ -68,18 +75,33 @@ impl Connection { /// Peer addresses and connection IDs can change, but this value will remain /// fixed for the lifetime of the connection. pub fn stable_id(&self) -> usize { - self.inner.stable_id() + match &self.inner { + ConnectionInner::Quic(connection) => connection.stable_id(), + ConnectionInner::Tls(connection) => connection.stable_id(), + } } /// Current best estimate of this connection's latency (round-trip-time) #[allow(unused)] pub fn rtt(&self) -> Duration { - self.inner.rtt() + match &self.inner { + ConnectionInner::Quic(connection) => connection.rtt(), + ConnectionInner::Tls(connection) => { + // TODO: Implement this for TLS connections, or change the interface. + Duration::ZERO + } + } } /// Returns connection statistics pub fn stats(&self) -> ConnectionStats { - self.inner.stats() + match &self.inner { + ConnectionInner::Quic(connection) => connection.stats(), + ConnectionInner::Tls(_connection) => { + // TODO: Implement this for TLS connections, or change the interface. + ConnectionStats::default() + } + } } /// The peer's UDP address @@ -87,15 +109,27 @@ impl Connection { /// If `ServerConfig::migration` is `true`, clients may change addresses at will, e.g. when /// switching to a cellular internet connection. pub fn remote_address(&self) -> SocketAddr { - self.inner.remote_address() + match &self.inner { + ConnectionInner::Quic(connection) => connection.remote_address(), + ConnectionInner::Tls(connection) => connection.peer_address(), + } } /// Open a unidirection stream to the peer. /// /// Messages sent over the stream will arrive at the peer in the order they were sent. - #[allow(dead_code)] pub async fn open_uni(&self) -> Result { - self.inner.open_uni().await.map(SendStream) + match &self.inner { + ConnectionInner::Quic(connection) => connection + .open_uni() + .await + .map(SendStream::Quic) + .map_err(Into::into), + ConnectionInner::Tls(connection) => { + let (send, _recv) = connection.open_stream().await?; + Ok(SendStream::Tls(send)) + } + } } /// Open a bidirectional stream to the peer. @@ -105,10 +139,18 @@ impl Connection { /// /// Messages sent over the stream will arrive at the peer in the order they were sent. pub async fn open_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> { - self.inner - .open_bi() - .await - .map(|(send, recv)| (SendStream(send), recv)) + match &self.inner { + ConnectionInner::Quic(connection) => connection + .open_bi() + .await + .map(|(send, recv)| (SendStream::Quic(send), RecvStream::Quic(recv))) + .map_err(Into::into), + ConnectionInner::Tls(connection) => connection + .open_stream() + .await + .map(|(send, recv)| (SendStream::Tls(send), RecvStream::Tls(recv))) + .map_err(Into::into), + } } /// Close the connection immediately. @@ -117,25 +159,44 @@ impl Connection { /// unfinished streams is not guaranteed to be delivered. pub fn close(&self) { trace!("Closing Connection"); - self.inner.close(0_u32.into(), b"connection closed") + match &self.inner { + ConnectionInner::Quic(connection) => { + connection.close(0_u32.into(), b"connection closed") + } + // TODO: Implement close for TLS. + ConnectionInner::Tls(_connection) => (), + } } /// Accept the next incoming uni-directional stream pub async fn accept_uni(&self) -> Result { - self.inner.accept_uni().await + match &self.inner { + ConnectionInner::Quic(connection) => connection + .accept_uni() + .await + .map(RecvStream::Quic) + .map_err(Into::into), + ConnectionInner::Tls(connection) => { + let (_send, recv) = connection.accept_stream().await?; + Ok(RecvStream::Tls(recv)) + } + } } /// Accept the next incoming bidirectional stream pub async fn accept_bi(&self) -> Result<(SendStream, RecvStream), ConnectionError> { - self.inner - .accept_bi() - .await - .map(|(send, recv)| (SendStream(send), recv)) - } - - /// Receive an application datagram - pub async fn read_datagram(&self) -> Result { - self.inner.read_datagram().await + match &self.inner { + ConnectionInner::Quic(connection) => connection + .accept_bi() + .await + .map(|(send, recv)| (SendStream::Quic(send), RecvStream::Quic(recv))) + .map_err(Into::into), + ConnectionInner::Tls(connection) => connection + .accept_stream() + .await + .map(|(send, recv)| (SendStream::Tls(send), RecvStream::Tls(recv))) + .map_err(Into::into), + } } } @@ -150,46 +211,137 @@ impl fmt::Debug for Connection { } } -/// A wrapper around a [quinn::SendStream] that enforces that the stream is shut down immediately -/// when dropped. The proper way to ensure that all data has been successfully transmitted and -/// Ack'd by the remote side is to call [quinn::SendStream::finish] prior to dropping the stream. -pub(crate) struct SendStream(quinn::SendStream); +#[derive(Clone)] +pub(crate) enum ConnectionInner { + Quic(quinn::Connection), + Tls(anemo_tls::Connection), +} -impl Drop for SendStream { - fn drop(&mut self) { - // We don't care if the stream has already been closed - let _ = self.0.reset(0u8.into()); +#[derive(Debug, thiserror::Error)] +pub enum ConnectionError { + #[error(transparent)] + Quic(#[from] quinn::ConnectionError), + #[error(transparent)] + Tls(#[from] anyhow::Error), +} + +/// A wrapper around a transport layer SendStream that enforces that the stream is shut down +/// immediately when dropped. The proper way to ensure that all data has been successfully +/// transmitted and Ack'd by the remote side is to call [SendStream::finish] prior to dropping +/// the stream. +pub(crate) enum SendStream { + Quic(quinn::SendStream), + Tls(anemo_tls::SendStream), +} + +impl fmt::Display for SendStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SendStream::Quic(stream) => write!(f, "Quic({})", stream.id()), + SendStream::Tls(stream) => write!(f, "Tls({})", stream), + } } } -impl std::ops::Deref for SendStream { - type Target = quinn::SendStream; +impl SendStream { + /// Shut down the send stream gracefully. No new data may be written after calling this method. + pub async fn finish(&mut self) -> Result<()> { + match self { + SendStream::Quic(stream) => stream.finish().await.map_err(Into::into), + SendStream::Tls(stream) => tokio::io::AsyncWriteExt::shutdown(stream) + .await + .map_err(Into::into), + } + } - fn deref(&self) -> &Self::Target { - &self.0 + /// Completes if/when the peer stops the stream. + pub async fn stopped(&mut self) { + match self { + SendStream::Quic(stream) => { + let _ = stream.stopped().await; + } + // Stream cannot be stopped/reset for yamux. + SendStream::Tls(_stream) => futures::future::pending().await, + } } } -impl std::ops::DerefMut for SendStream { - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.0 +impl Drop for SendStream { + fn drop(&mut self) { + match self { + SendStream::Quic(stream) => { + // We don't care if the stream has already been closed + let _ = stream.reset(0u8.into()); + } + // Nothing to do on drop for TLS. + SendStream::Tls(_stream) => (), + } } } impl tokio::io::AsyncWrite for SendStream { fn poll_write( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - Pin::new(&mut self.0).poll_write(cx, buf) + match self.get_mut() { + SendStream::Quic(stream) => Pin::new(stream).poll_write(cx, buf), + SendStream::Tls(stream) => Pin::new(stream).poll_write(cx, buf), + } } - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - Pin::new(&mut self.0).poll_flush(cx) + fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match self.get_mut() { + SendStream::Quic(stream) => Pin::new(stream).poll_flush(cx), + SendStream::Tls(stream) => Pin::new(stream).poll_flush(cx), + } } - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { - Pin::new(&mut self.0).poll_shutdown(cx) + fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll> { + match self.get_mut() { + SendStream::Quic(stream) => Pin::new(stream).poll_shutdown(cx), + SendStream::Tls(stream) => Pin::new(stream).poll_shutdown(cx), + } + } +} + +pub(crate) enum RecvStream { + Quic(quinn::RecvStream), + Tls(anemo_tls::RecvStream), +} + +impl fmt::Display for RecvStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + RecvStream::Quic(stream) => write!(f, "Quic({})", stream.id()), + RecvStream::Tls(stream) => write!(f, "Tls({})", stream), + } + } +} + +impl tokio::io::AsyncRead for RecvStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> Poll> { + match self.get_mut() { + RecvStream::Quic(stream) => Pin::new(stream).poll_read(cx, buf), + RecvStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf), + } + } +} + +impl futures::AsyncRead for RecvStream { + fn poll_read( + self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + match self.get_mut() { + RecvStream::Quic(stream) => Pin::new(stream).poll_read(cx, buf), + RecvStream::Tls(stream) => Pin::new(stream).poll_read(cx, buf), + } } } diff --git a/crates/anemo/src/endpoint.rs b/crates/anemo/src/endpoint.rs index f94bc06e..1961958e 100644 --- a/crates/anemo/src/endpoint.rs +++ b/crates/anemo/src/endpoint.rs @@ -1,88 +1,146 @@ +use crate::config::ClientConfig; +use crate::connection::ConnectionInner; use crate::{ config::EndpointConfig, connection::Connection, types::Address, ConnectionOrigin, PeerId, Result, }; +use std::net::SocketAddr; use std::sync::Arc; use std::time::Duration; -use std::{ - future::Future, - net::SocketAddr, - pin::Pin, - sync::RwLock, - task::{Context, Poll}, -}; -use tap::Pipe; use tokio::time::timeout; use tracing::{trace, warn}; -/// A QUIC endpoint. +enum Transport { + Quic(quinn::Endpoint), + Tls(anemo_tls::Endpoint), +} + +impl Transport { + async fn wait_idle(&self) { + match self { + Transport::Quic(inner) => inner.wait_idle().await, + Transport::Tls(_inner) => (), // TODO: add wait_idle + } + } + + fn drop_socket(&self) -> std::io::Result<()> { + match self { + Transport::Quic(inner) => { + let socket = std::net::UdpSocket::bind((std::net::Ipv4Addr::LOCALHOST, 0)).unwrap(); + inner.rebind(socket) + } + Transport::Tls(_inner) => unimplemented!(), + } + } +} + +/// A transport endpoint. /// -/// An endpoint corresponds to a single UDP socket, may host many connections, and may act as both -/// client and server for different connections. -#[derive(Debug)] +/// An endpoint may host many connections, and may act as both client and server for different +/// connections. pub(crate) struct Endpoint { - inner: quinn::Endpoint, - local_addr: RwLock, config: EndpointConfig, + local_addr: SocketAddr, + transport: Transport, } impl Endpoint { - pub fn new(config: EndpointConfig, socket: std::net::UdpSocket) -> Result { - let local_addr = socket.local_addr()?.pipe(RwLock::new); + pub fn new_quic(config: EndpointConfig, socket: std::net::UdpSocket) -> Result { + let local_addr = socket.local_addr()?; let server_config = config.server_config().clone(); let endpoint = quinn::Endpoint::new( - config.quinn_endpoint_config(), - Some(server_config), + config.transport_endpoint_config().as_quic().clone(), + Some(server_config.try_quic()?.clone()), socket, Arc::new(quinn::TokioRuntime), )?; let endpoint = Self { - inner: endpoint, - local_addr, config, + local_addr, + transport: Transport::Quic(endpoint), }; Ok(endpoint) } - #[cfg(test)] - fn new_with_address>(config: EndpointConfig, addr: A) -> Result { - let socket = std::net::UdpSocket::bind(addr.into())?; - Self::new(config, socket) + /// WARNING: TLS support is unstable, experimental, and incomplete. + pub fn new_tls(config: EndpointConfig, listener: tokio::net::TcpListener) -> Result { + let local_addr = listener.local_addr()?; + let endpoint = + anemo_tls::Endpoint::new(config.server_config().try_tls()?.clone(), listener); + Ok(Self { + config, + local_addr, + transport: Transport::Tls(endpoint), + }) } - pub fn connect(&self, address: Address) -> Result { + pub async fn connect(&self, address: Address) -> Result { self.connect_with_client_config(self.config.client_config().clone(), address) + .await } - pub fn connect_with_expected_peer_id( + pub async fn connect_with_expected_peer_id( &self, address: Address, peer_id: PeerId, - ) -> Result { + ) -> Result { let config = self .config .client_config_with_expected_server_identity(peer_id); - self.connect_with_client_config(config, address) + self.connect_with_client_config(config, address).await } - fn connect_with_client_config( + async fn connect_with_client_config( &self, - config: quinn::ClientConfig, + config: ClientConfig, address: Address, - ) -> Result { + ) -> Result { let addr = address.resolve()?; - - self.inner - .connect_with(config, addr, self.config.server_name()) - .map_err(Into::into) - .map(Connecting::new_outbound) + match self.transport { + Transport::Quic(ref inner) => inner + .connect_with( + config.try_quic()?.to_owned(), + addr, + self.config.server_name(), + )? + .await + .map_err(anyhow::Error::from) + .and_then(|connection| { + Connection::new( + ConnectionInner::Quic(connection), + ConnectionOrigin::Outbound, + ) + }), + Transport::Tls(ref inner) => { + inner + .connect_with( + config.try_tls()?.to_owned(), + addr, + self.config.server_name(), + ) + .await + .map_err(anyhow::Error::from) + .and_then(|connection| { + Connection::new( + ConnectionInner::Tls(connection), + ConnectionOrigin::Outbound, + ) + }) + } + .map_err(|e| { + anyhow::anyhow!( + "failed establishing {} connection: {e}", + ConnectionOrigin::Outbound + ) + }), + } } /// Returns the socket address that this Endpoint is bound to. pub fn local_addr(&self) -> SocketAddr { - *self.local_addr.read().unwrap() + self.local_addr } pub fn peer_id(&self) -> PeerId { @@ -96,7 +154,11 @@ impl Endpoint { /// Close all of this endpoint's connections immediately and cease accepting new connections. pub fn close(&self) { trace!("Closing endpoint"); - self.inner.close(0_u32.into(), b"endpoint closed") + match self.transport { + Transport::Quic(ref inner) => inner.close(0_u32.into(), b"endpoint closed"), + // TODO: add close for TLS. + Transport::Tls(ref _inner) => (), + } } /// Wait for all connections on the endpoint to be cleanly shut down @@ -113,7 +175,10 @@ impl Endpoint { /// /// [`close()`]: Endpoint::close pub async fn wait_idle(&self, max_timeout: Duration) { - if timeout(max_timeout, self.inner.wait_idle()).await.is_err() { + if timeout(max_timeout, self.transport.wait_idle()) + .await + .is_err() + { warn!( "Max timeout reached {}s while waiting for connections clean shutdown", max_timeout.as_secs_f64() @@ -121,110 +186,97 @@ impl Endpoint { } } - /// Switch to a new UDP socket + /// Ensures that the underlying socket we're bound to is dropped and immediately able to be + /// rebound to once this function exits. /// - /// Allows the endpoint's address to be updated live, affecting all active connections. Incoming - /// connections and connections to servers unreachable from the new address will be lost. - /// - /// On error, the old UDP socket is retained. - pub fn rebind(&self, socket: std::net::UdpSocket) -> std::io::Result<()> { - let local_addr = socket.local_addr()?; - self.inner.rebind(socket)?; - *self.local_addr.write().unwrap() = local_addr; - - Ok(()) + /// Behavior of the endpoint is undefined after this function is called. + pub fn drop_socket(&self) -> std::io::Result<()> { + self.transport.drop_socket() } - /// Get the next incoming connection attempt from a client - /// - /// Yields [`Connecting`] futures that must be `await`ed to obtain the final `Connection`, or + /// Get the next incoming connection attempt from a client, or /// `None` if the endpoint is [`close`](Self::close)d. - pub(crate) fn accept(&self) -> Accept<'_> { - Accept { - inner: self.inner.accept(), + pub(crate) async fn accept(&self) -> Option> { + match self.transport { + Transport::Quic(ref inner) => { + let connecting = inner.accept().await; + if let Some(connecting) = connecting { + Some( + connecting + .await + .map_err(anyhow::Error::from) + .and_then(|connection| { + Connection::new( + ConnectionInner::Quic(connection), + ConnectionOrigin::Inbound, + ) + }), + ) + } else { + None + } + } + Transport::Tls(ref inner) => Some( + inner + .accept() + .await + .map_err(anyhow::Error::from) + .and_then(|connection| { + Connection::new(ConnectionInner::Tls(connection), ConnectionOrigin::Inbound) + }), + ), } } } -pin_project_lite::pin_project! { - /// Future produced by [`Endpoint::accept`] - #[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"] - pub(crate) struct Accept<'a> { - #[pin] - inner: quinn::Accept<'a>, - } -} - -impl<'a> Future for Accept<'a> { - type Output = Option; - - fn poll(self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll { - self.project() - .inner - .poll(ctx) - .map(|maybe_connecting| maybe_connecting.map(Connecting::new_inbound)) - } -} - -#[derive(Debug)] -#[must_use = "futures/streams/sinks do nothing unless you `.await` or poll them"] -pub(crate) struct Connecting { - inner: quinn::Connecting, - origin: ConnectionOrigin, -} - -impl Connecting { - pub(crate) fn new(inner: quinn::Connecting, origin: ConnectionOrigin) -> Self { - Self { inner, origin } - } - - pub(crate) fn new_inbound(inner: quinn::Connecting) -> Self { - Self::new(inner, ConnectionOrigin::Inbound) - } - - pub(crate) fn new_outbound(inner: quinn::Connecting) -> Self { - Self::new(inner, ConnectionOrigin::Outbound) - } -} - -impl Future for Connecting { - type Output = Result; - - fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { - Pin::new(&mut self.inner).poll(cx).map(|result| { - result - .map_err(anyhow::Error::from) - .and_then(|connection| Connection::new(connection, self.origin)) - .map_err(|e| anyhow::anyhow!("failed establishing {} connection: {e}", self.origin)) - }) - } -} - #[cfg(test)] mod test { use super::*; use futures::{future::join, io::AsyncReadExt}; + use rstest::*; use std::time::Duration; + use tokio::io::AsyncWriteExt; + + #[fixture] + async fn quic_endpoint() -> Endpoint { + let config = EndpointConfig::random_quic("test"); + let socket = std::net::UdpSocket::bind("localhost:0").unwrap(); + Endpoint::new_quic(config, socket).unwrap() + } + #[fixture] + async fn tls_endpoint() -> Endpoint { + let config = EndpointConfig::random_tls("test"); + let listener = tokio::net::TcpListener::bind("localhost:0").await.unwrap(); + Endpoint::new_tls(config, listener).unwrap() + } + + #[rstest] + #[case::quic(quic_endpoint(), quic_endpoint())] + #[case::tls(tls_endpoint(), tls_endpoint())] + #[awt] #[tokio::test] - async fn basic_endpoint() -> Result<()> { + async fn basic_endpoint( + #[future] + #[case] + endpoint_1: Endpoint, + #[future] + #[case] + endpoint_2: Endpoint, + ) -> Result<()> { let _guard = crate::init_tracing_for_testing(); let msg = b"hello"; - let config_1 = EndpointConfig::random("test"); - let endpoint_1 = Endpoint::new_with_address(config_1, "localhost:0")?; let peer_id_1 = endpoint_1.config.peer_id(); println!("1: {}", endpoint_1.local_addr()); - let config_2 = EndpointConfig::random("test"); - let endpoint_2 = Endpoint::new_with_address(config_2, "localhost:0")?; let peer_id_2 = endpoint_2.config.peer_id(); let addr_2 = endpoint_2.local_addr(); println!("2: {}", endpoint_2.local_addr()); let peer_1 = async move { - let connection = endpoint_1.connect(addr_2.into()).unwrap().await.unwrap(); + let connection = endpoint_1.connect(addr_2.into()).await.unwrap(); assert_eq!(connection.peer_id(), peer_id_2); { let mut send_stream = connection.open_uni().await.unwrap(); @@ -232,12 +284,12 @@ mod test { send_stream.finish().await.unwrap(); } endpoint_1.close(); - endpoint_1.inner.wait_idle().await; + endpoint_1.transport.wait_idle().await; // Result::<()>::Ok(()) }; let peer_2 = async move { - let connection = endpoint_2.accept().await.unwrap().await.unwrap(); + let connection = endpoint_2.accept().await.unwrap().unwrap(); assert_eq!(connection.peer_id(), peer_id_1); let mut recv = connection.accept_uni().await.unwrap(); @@ -248,7 +300,7 @@ mod test { println!("from remote: {}", buf.escape_ascii()); assert_eq!(buf, msg); endpoint_2.close(); - endpoint_2.inner.wait_idle().await; + endpoint_2.transport.wait_idle().await; // Result::<()>::Ok(()) }; @@ -258,27 +310,34 @@ mod test { // Test to verify that multiple connections to the same endpoint can be open simultaneously. // While we don't currently allow for this, we may want to eventually enable/allow for it. + #[rstest] + #[case::quic(quic_endpoint(), quic_endpoint())] + #[case::tls(tls_endpoint(), tls_endpoint())] + #[awt] #[tokio::test] - async fn multiple_connections() -> Result<()> { + async fn multiple_connections( + #[future] + #[case] + endpoint_1: Endpoint, + #[future] + #[case] + endpoint_2: Endpoint, + ) -> Result<()> { let _guard = crate::init_tracing_for_testing(); let msg = b"hello"; - let config_1 = EndpointConfig::random("test"); - let endpoint_1 = Endpoint::new_with_address(config_1, "localhost:0")?; let peer_id_1 = endpoint_1.config.peer_id(); println!("1: {}", endpoint_1.local_addr()); - let config_2 = EndpointConfig::random("test"); - let endpoint_2 = Endpoint::new_with_address(config_2, "localhost:0")?; let peer_id_2 = endpoint_2.config.peer_id(); let addr_2 = endpoint_2.local_addr(); println!("2: {}", endpoint_2.local_addr()); let peer_1 = async move { - let connection_1 = endpoint_1.connect(addr_2.into()).unwrap().await.unwrap(); + let connection_1 = endpoint_1.connect(addr_2.into()).await.unwrap(); assert_eq!(connection_1.peer_id(), peer_id_2); - let connection_2 = endpoint_1.connect(addr_2.into()).unwrap().await.unwrap(); + let connection_2 = endpoint_1.connect(addr_2.into()).await.unwrap(); assert_eq!(connection_2.peer_id(), peer_id_2); let req_1 = async { let mut send_stream = connection_2.open_uni().await.unwrap(); @@ -292,15 +351,15 @@ mod test { }; join(req_1, req_2).await; endpoint_1.close(); - endpoint_1.inner.wait_idle().await; + endpoint_1.transport.wait_idle().await; // Result::<()>::Ok(()) }; let peer_2 = async move { - let connection_1 = endpoint_2.accept().await.unwrap().await.unwrap(); + let connection_1 = endpoint_2.accept().await.unwrap().unwrap(); assert_eq!(connection_1.peer_id(), peer_id_1); - let connection_2 = endpoint_2.accept().await.unwrap().await.unwrap(); + let connection_2 = endpoint_2.accept().await.unwrap().unwrap(); assert_eq!(connection_2.peer_id(), peer_id_1); assert_ne!(connection_1.stable_id(), connection_2.stable_id()); @@ -328,7 +387,7 @@ mod test { join(req_1, req_2).await; endpoint_2.close(); - endpoint_2.inner.wait_idle().await; + endpoint_2.transport.wait_idle().await; // Result::<()>::Ok(()) }; @@ -336,21 +395,27 @@ mod test { Ok(()) } + #[rstest] + #[case::quic(quic_endpoint(), quic_endpoint())] + #[case::tls(tls_endpoint(), tls_endpoint())] + #[awt] #[tokio::test] - async fn peers_concurrently_finishing_uni_stream_before_accepting() -> Result<()> { + async fn peers_concurrently_finishing_uni_stream_before_accepting( + #[future] + #[case] + endpoint_1: Endpoint, + #[future] + #[case] + endpoint_2: Endpoint, + ) -> Result<()> { let _guard = crate::init_tracing_for_testing(); let msg = b"hello"; - let config_1 = EndpointConfig::random("test"); - let endpoint_1 = Endpoint::new_with_address(config_1, "localhost:0")?; - - let config_2 = EndpointConfig::random("test"); - let endpoint_2 = Endpoint::new_with_address(config_2, "localhost:0")?; let addr_2 = endpoint_2.local_addr(); let (connection_1_to_2, connection_2_to_1) = timeout(join( - async { endpoint_1.connect(addr_2.into()).unwrap().await.unwrap() }, - async { endpoint_2.accept().await.unwrap().await.unwrap() }, + async { endpoint_1.connect(addr_2.into()).await.unwrap() }, + async { endpoint_2.accept().await.unwrap().unwrap() }, )) .await .unwrap(); diff --git a/crates/anemo/src/lib.rs b/crates/anemo/src/lib.rs index a36093c4..ce07aaf3 100644 --- a/crates/anemo/src/lib.rs +++ b/crates/anemo/src/lib.rs @@ -9,7 +9,7 @@ mod routing; pub mod rpc; pub mod types; -pub use config::{Config, QuicConfig}; +pub use config::{Config, QuicConfig, TlsConfig, TransportConfig}; pub use error::{Error, Result}; pub use network::{Builder, KnownPeers, Network, NetworkRef, Peer}; pub use routing::Router; diff --git a/crates/anemo/src/network/connection_manager.rs b/crates/anemo/src/network/connection_manager.rs index 322b8fcf..ec80855b 100644 --- a/crates/anemo/src/network/connection_manager.rs +++ b/crates/anemo/src/network/connection_manager.rs @@ -2,7 +2,7 @@ use super::request_handler::InboundRequestHandler; use crate::{ config::Config, connection::Connection, - endpoint::{Connecting, Endpoint}, + endpoint::Endpoint, types::{Address, DisconnectReason, PeerAffinity, PeerEvent, PeerInfo}, ConnectionOrigin, PeerId, Request, Response, Result, }; @@ -17,7 +17,7 @@ use tokio::{ task::JoinSet, }; use tower::util::BoxCloneService; -use tracing::{debug, info, instrument, trace}; +use tracing::{debug, info, instrument, trace, warn}; #[derive(Debug)] pub enum ConnectionManagerRequest { @@ -188,14 +188,9 @@ impl ConnectionManager { self.endpoint .wait_idle(self.config.shutdown_idle_timeout()) .await; - - // This is a small hack in order to ensure that the underlying socket we're bound to is - // dropped and immediately available to be rebound to once this function exits. - // In essence we construct a new UdpSocket, bound on some ephemeral localhost port, and - // swap it in for the socket the endpoint is currently bound to, causing it to be dropped - // and freed. - let socket = std::net::UdpSocket::bind((std::net::Ipv4Addr::LOCALHOST, 0)).unwrap(); - self.endpoint.rebind(socket).unwrap(); + if let Err(e) = self.endpoint.drop_socket() { + warn!("unable to drop socket: {e}"); + } } /// This method adds an established connection with a peer to the map of active peers. @@ -227,11 +222,11 @@ impl ConnectionManager { self.dial_peer(address, peer_id, oneshot); } - fn handle_incoming(&mut self, connecting: Connecting) { + fn handle_incoming(&mut self, connection: Result) { trace!("received new incoming connection"); self.pending_connections.spawn(Self::handle_incoming_task( - connecting, + connection, self.config.clone(), self.active_peers.clone(), self.known_peers.clone(), @@ -239,13 +234,13 @@ impl ConnectionManager { } async fn handle_incoming_task( - connecting: Connecting, + connection: Result, config: Arc, active_peers: ActivePeers, known_peers: KnownPeers, ) -> ConnectingOutput { let fut = async { - let connection = connecting.await?; + let connection = connection?; // TODO close the connection explicitly with a reason once we have machine // readable errors. See https://github.com/MystenLabs/anemo/issues/13 for more info. @@ -421,33 +416,31 @@ impl ConnectionManager { peer_id: Option, oneshot: oneshot::Sender>, ) { - let target_address = address.clone(); - let maybe_connecting = if let Some(peer_id) = peer_id { - self.endpoint - .connect_with_expected_peer_id(address, peer_id) - } else { - self.endpoint.connect(address) - }; self.pending_connections.spawn(Self::dial_peer_task( - maybe_connecting, - target_address, + self.endpoint.clone(), + address, peer_id, oneshot, self.config.clone(), )); } - // TODO maybe look at cloning the endpoint so that we can try multiple addresses in the event - // Address resolves to multiple ips. + // TODO maybe try multiple addresses in the event Address resolves to multiple ips. async fn dial_peer_task( - maybe_connecting: Result, + endpoint: Arc, target_address: Address, peer_id: Option, oneshot: oneshot::Sender>, config: Arc, ) -> ConnectingOutput { let fut = async { - let connection = maybe_connecting?.await?; + let connection = if let Some(peer_id) = peer_id { + endpoint + .connect_with_expected_peer_id(target_address.clone(), peer_id) + .await + } else { + endpoint.connect(target_address.clone()).await + }?; super::wire::handshake(connection).await }; @@ -788,11 +781,11 @@ mod tests { } #[tokio::test] - async fn shutdown() { + async fn shutdown_quic() { let socket = std::net::UdpSocket::bind((std::net::Ipv4Addr::LOCALHOST, 0)).unwrap(); let address = socket.local_addr().unwrap(); - let config = crate::config::EndpointConfig::random("test"); - let endpoint = Arc::new(Endpoint::new(config, socket).unwrap()); + let config = crate::config::EndpointConfig::random_quic("test"); + let endpoint = Arc::new(Endpoint::new_quic(config, socket).unwrap()); let (connection_manager, sender) = ConnectionManager::new( Default::default(), endpoint, diff --git a/crates/anemo/src/network/mod.rs b/crates/anemo/src/network/mod.rs index abe540f9..d814674b 100644 --- a/crates/anemo/src/network/mod.rs +++ b/crates/anemo/src/network/mod.rs @@ -1,5 +1,5 @@ use crate::{ - config::EndpointConfig, + config::{EndpointConfig, TransportConfig}, endpoint::Endpoint, middleware::{add_extension::AddExtensionLayer, timeout}, types::{Address, DisconnectReason, PeerEvent}, @@ -127,13 +127,13 @@ impl Builder { >>::Future: Send + 'static, { let config = self.config.unwrap_or_default(); - let quic_config = config.quic.clone().unwrap_or_default(); + let transport_config = config.transport.clone().unwrap_or_default(); let primary_server_name = self.server_name.unwrap(); let alternate_server_name = self.alternate_server_name; let private_key = self.private_key.unwrap(); let endpoint_config = EndpointConfig::builder() - .transport_config(config.transport_config()) + .transport_config(&transport_config) .server_name(primary_server_name) .alternate_server_name(alternate_server_name) .private_key(private_key) @@ -143,8 +143,14 @@ impl Builder { let socket = (|| { let mut result = Err(anyhow!("no addresses to bind to")); for addr in addrs.iter() { - let socket = - Socket::new(Domain::for_address(*addr), Type::DGRAM, Some(Protocol::UDP))?; + let socket = match transport_config { + TransportConfig::Quic(_) => { + Socket::new(Domain::for_address(*addr), Type::DGRAM, Some(Protocol::UDP))? + } + TransportConfig::Tls(_) => { + Socket::new(Domain::for_address(*addr), Type::STREAM, None)? + } + }; result = socket .bind(&socket2::SockAddr::from(*addr)) .map_err(|e| e.into()); @@ -155,11 +161,11 @@ impl Builder { Err(result.unwrap_err()) })()?; let socket_send_buf_size = if let Some(send_buffer_size) = - quic_config.socket_send_buffer_size + transport_config.socket_send_buffer_size() { let result = socket.set_send_buffer_size(send_buffer_size); if let Err(e) = result { - if quic_config.allow_failed_socket_buffer_size_setting { + if transport_config.allow_failed_socket_buffer_size_setting() { warn!("failed to set socket send buffer size to {send_buffer_size}: {e}",); } else { return Err(e.into()); @@ -171,7 +177,7 @@ impl Builder { let msg = format!( "expected socket send buffer size to be at least {send_buffer_size}, got {buf_size}" ); - if quic_config.allow_failed_socket_buffer_size_setting { + if transport_config.allow_failed_socket_buffer_size_setting() { warn!(msg); } else { return Err(anyhow!(msg)); @@ -182,11 +188,11 @@ impl Builder { socket.send_buffer_size()? }; let socket_receive_buf_size = if let Some(receive_buffer_size) = - quic_config.socket_receive_buffer_size + transport_config.socket_receive_buffer_size() { let result = socket.set_recv_buffer_size(receive_buffer_size); if let Err(e) = result { - if quic_config.allow_failed_socket_buffer_size_setting { + if transport_config.allow_failed_socket_buffer_size_setting() { warn!("failed to set socket receive buffer size to {receive_buffer_size}: {e}",); } else { return Err(e.into()); @@ -198,7 +204,7 @@ impl Builder { let msg = format!( "expected socket receive buffer size to be at least {receive_buffer_size}, got {buf_size}", ); - if quic_config.allow_failed_socket_buffer_size_setting { + if transport_config.allow_failed_socket_buffer_size_setting() { warn!(msg); } else { return Err(anyhow!(msg)); @@ -209,8 +215,18 @@ impl Builder { socket.recv_buffer_size()? }; - let endpoint = Endpoint::new(endpoint_config, socket.into())?; - + let endpoint = match transport_config { + TransportConfig::Quic(_) => Endpoint::new_quic(endpoint_config, socket.into())?, + TransportConfig::Tls(_tls) => { + socket.listen(128)?; + let tcp_listener: std::net::TcpListener = socket.into(); + tcp_listener.set_nonblocking(true)?; + Endpoint::new_tls( + endpoint_config, + tokio::net::TcpListener::from_std(tcp_listener)?, + )? + } + }; let config = Arc::new(config); let endpoint = Arc::new(endpoint); let active_peers = ActivePeers::new(config.peer_event_broadcast_channel_capacity()); diff --git a/crates/anemo/src/network/request_handler.rs b/crates/anemo/src/network/request_handler.rs index 11f8b5fb..45c07558 100644 --- a/crates/anemo/src/network/request_handler.rs +++ b/crates/anemo/src/network/request_handler.rs @@ -3,11 +3,10 @@ use super::{ ActivePeers, }; use crate::{ - connection::{Connection, SendStream}, + connection::{Connection, RecvStream, SendStream}, Config, Request, Response, Result, }; use bytes::Bytes; -use quinn::RecvStream; use std::convert::Infallible; use std::sync::Arc; use tokio_util::codec::{FramedRead, FramedWrite, LengthDelimitedCodec}; @@ -48,21 +47,10 @@ impl InboundRequestHandler { let close_reason = loop { tokio::select! { - // anemo does not currently use uni streams so we can - // just ignore and drop the stream - uni = self.connection.accept_uni() => { - match uni { - Ok(recv_stream) => trace!("incoming uni stream! {}", recv_stream.id()), - Err(e) => { - trace!("error listening for incoming uni streams: {e}"); - break e; - } - } - }, bi = self.connection.accept_bi() => { match bi { Ok((bi_tx, bi_rx)) => { - trace!("incoming bi stream! {}", bi_tx.id()); + trace!("incoming bi stream! {}", bi_tx); let request_handler = BiStreamRequestHandler::new(&self.config, self.connection.clone(), self.service.clone(), bi_tx, bi_rx); inflight_requests.spawn(request_handler.handle()); @@ -73,17 +61,6 @@ impl InboundRequestHandler { } } }, - // anemo does not currently use datagrams so we can - // just ignore them - datagram = self.connection.read_datagram() => { - match datagram { - Ok(datagram) => trace!("incoming datagram of length: {}", datagram.len()), - Err(e) => { - trace!("error listening for datagrams: {e}"); - break e; - } - } - }, Some(completed_request) = inflight_requests.join_next() => { match completed_request { Ok(()) => { @@ -107,7 +84,7 @@ impl InboundRequestHandler { self.active_peers.remove_with_stable_id( self.connection.peer_id(), self.connection.stable_id(), - crate::types::DisconnectReason::from_quinn_error(&close_reason), + crate::types::DisconnectReason::from_connection_error(&close_reason), ); inflight_requests.shutdown().await; diff --git a/crates/anemo/src/network/tests.rs b/crates/anemo/src/network/tests.rs index da67a148..37c15fdf 100644 --- a/crates/anemo/src/network/tests.rs +++ b/crates/anemo/src/network/tests.rs @@ -10,8 +10,8 @@ async fn basic_network() -> Result<()> { let msg = b"The Way of Kings"; - let network_1 = build_network()?; - let network_2 = build_network()?; + let network_1 = build_network().await?; + let network_2 = build_network().await?; let peer = network_1.connect(network_2.local_addr()).await?; let response = network_1 @@ -32,8 +32,8 @@ async fn basic_network() -> Result<()> { async fn connect() -> Result<()> { let _guard = crate::init_tracing_for_testing(); - let network_1 = build_network()?; - let network_2 = build_network()?; + let network_1 = build_network().await?; + let network_2 = build_network().await?; let peer = network_1.connect(network_2.local_addr()).await?; assert_eq!(peer, network_2.peer_id()); @@ -45,8 +45,8 @@ async fn connect() -> Result<()> { async fn connect_with_peer_id() -> Result<()> { let _guard = crate::init_tracing_for_testing(); - let network_1 = build_network()?; - let network_2 = build_network()?; + let network_1 = build_network().await?; + let network_2 = build_network().await?; let peer = network_1 .connect_with_peer_id(network_2.local_addr(), network_2.peer_id()) @@ -60,9 +60,9 @@ async fn connect_with_peer_id() -> Result<()> { async fn connect_with_invalid_peer_id() -> Result<()> { let _guard = crate::init_tracing_for_testing(); - let network_1 = build_network()?; - let network_2 = build_network()?; - let network_3 = build_network()?; + let network_1 = build_network().await?; + let network_2 = build_network().await?; + let network_3 = build_network().await?; // Try to dial network 2, but with network 3's peer id network_1 @@ -77,9 +77,9 @@ async fn connect_with_invalid_peer_id() -> Result<()> { async fn connect_with_invalid_peer_id_ensure_server_doesnt_succeed() -> Result<()> { let _guard = crate::init_tracing_for_testing(); - let network_1 = build_network()?; - let network_2 = build_network()?; - let network_3 = build_network()?; + let network_1 = build_network().await?; + let network_2 = build_network().await?; + let network_3 = build_network().await?; let (mut subscriber_2, _) = network_2.subscribe().unwrap(); @@ -123,9 +123,9 @@ async fn connect_with_invalid_peer_id_ensure_server_doesnt_succeed() -> Result<( async fn connect_with_hostname() -> Result<()> { let _guard = crate::init_tracing_for_testing(); - let network_1 = build_network()?; - let network_2 = build_network()?; - let network_3 = build_network()?; + let network_1 = build_network().await?; + let network_2 = build_network().await?; + let network_3 = build_network().await?; let peer = network_1 .connect_with_peer_id( @@ -161,7 +161,7 @@ async fn max_concurrent_connections_0() -> Result<()> { .config(config) .start(echo_service())?; - let network_2 = build_network()?; + let network_2 = build_network().await?; network_2 .connect_with_peer_id(network_1.local_addr(), network_1.peer_id()) @@ -186,8 +186,8 @@ async fn max_concurrent_connections_1() -> Result<()> { .config(config) .start(echo_service())?; - let network_2 = build_network()?; - let network_3 = build_network()?; + let network_2 = build_network().await?; + let network_3 = build_network().await?; // first connection succeeds network_2 @@ -214,8 +214,8 @@ async fn max_concurrent_connections_1() -> Result<()> { async fn reject_peer_with_affinity_never() -> Result<()> { let _guard = crate::init_tracing_for_testing(); - let network_1 = build_network()?; - let network_2 = build_network()?; + let network_1 = build_network().await?; + let network_2 = build_network().await?; // Configure peer 2 with affinity never let peer_info_2 = crate::types::PeerInfo { @@ -238,9 +238,9 @@ async fn reject_peer_with_affinity_never() -> Result<()> { async fn peers_with_affinity_never_are_not_dialed_in_the_background() -> Result<()> { let _guard = crate::init_tracing_for_testing(); - let network_1 = build_network()?; - let network_2 = build_network()?; - let network_3 = build_network()?; + let network_1 = build_network().await?; + let network_2 = build_network().await?; + let network_3 = build_network().await?; let mut subscriber_1 = network_1.subscribe()?.0; @@ -279,11 +279,11 @@ async fn peers_with_affinity_never_are_not_dialed_in_the_background() -> Result< Ok(()) } -fn build_network() -> Result { - build_network_with_addr("localhost:0") +async fn build_network() -> Result { + build_network_with_addr("localhost:0").await } -fn build_network_with_addr(addr: &str) -> Result { +async fn build_network_with_addr(addr: &str) -> Result { let network = Network::bind(addr) .random_private_key() .server_name("test") @@ -294,6 +294,8 @@ fn build_network_with_addr(addr: &str) -> Result { peer_id =% network.peer_id(), "starting network" ); + // Give time for spawned I/O tasks to spin up. + tokio::time::sleep(std::time::Duration::from_millis(100)).await; Ok(network) } @@ -312,8 +314,8 @@ fn echo_service() -> BoxCloneService, Response, Infallible async fn ip6_calling_ip4() -> Result<()> { let _guard = crate::init_tracing_for_testing(); - let network_1 = build_network_with_addr("[::]:0")?; - let network_2 = build_network_with_addr("127.0.0.1:0")?; + let network_1 = build_network_with_addr("[::]:0").await?; + let network_2 = build_network_with_addr("127.0.0.1:0").await?; let msg = b"The Way of Kings"; let peer = network_1.connect(network_2.local_addr()).await?; @@ -330,8 +332,8 @@ async fn ip6_calling_ip4() -> Result<()> { async fn localhost_calling_anyaddr() -> Result<()> { let _guard = crate::init_tracing_for_testing(); - let network_1 = build_network_with_addr("0.0.0.0:0")?; - let network_2 = build_network_with_addr("127.0.0.1:0")?; + let network_1 = build_network_with_addr("0.0.0.0:0").await?; + let network_2 = build_network_with_addr("127.0.0.1:0").await?; let msg = b"The Way of Kings"; let peer = network_2 @@ -357,8 +359,8 @@ async fn localhost_calling_anyaddr() -> Result<()> { async fn dropped_connection() -> Result<()> { let _guard = crate::init_tracing_for_testing(); - let network_1 = build_network()?; - let network_2 = build_network()?; + let network_1 = build_network().await?; + let network_2 = build_network().await?; let msg = b"The Way of Kings"; let peer = network_1.connect(network_2.local_addr()).await?; @@ -385,8 +387,8 @@ async fn basic_connectivity_check() -> Result<()> { let _guard = crate::init_tracing_for_testing(); - let network_1 = build_network()?; - let network_2 = build_network()?; + let network_1 = build_network().await?; + let network_2 = build_network().await?; let peer_id_1 = network_1.peer_id(); let peer_id_2 = network_2.peer_id(); @@ -497,7 +499,7 @@ async fn drop_shutdown() -> Result<()> { .server_name("test") .start(service)?; - let network_2 = build_network()?; + let network_2 = build_network().await?; let peer = network_2.connect(network.local_addr()).await?; let _response = network_2.rpc(peer, Request::new(Bytes::new())).await?; @@ -554,7 +556,7 @@ async fn explicit_shutdown() -> Result<()> { .server_name("test") .start(service)?; - let network_2 = build_network()?; + let network_2 = build_network().await?; let peer = network_2.connect(network.local_addr()).await?; let _response = network_2.rpc(peer, Request::new(Bytes::new())).await?; @@ -588,7 +590,7 @@ async fn explicit_shutdown() -> Result<()> { #[tokio::test] async fn subscribe_channel_closes_on_shutdown() -> Result<()> { let _guard = crate::init_tracing_for_testing(); - let network = build_network()?; + let network = build_network().await?; let mut subscriber = network.subscribe()?.0; drop(network); @@ -604,7 +606,7 @@ async fn subscribe_channel_closes_on_shutdown() -> Result<()> { #[tokio::test] async fn subscribe_channel_closes_on_explicit_shutdown() -> Result<()> { let _guard = crate::init_tracing_for_testing(); - let network = build_network()?; + let network = build_network().await?; let mut subscriber = network.subscribe()?.0; network.shutdown().await?; @@ -656,7 +658,7 @@ async fn early_termination_of_request_handlers() { .start(service) .unwrap(); - let network_2 = build_network().unwrap(); + let network_2 = build_network().await.unwrap(); let peer = network_2.connect(network.local_addr()).await.unwrap(); @@ -753,6 +755,8 @@ async fn user_provided_client_service_layer() { let (network_1, server_counter_1, client_counter_1) = create_network(); let (network_2, server_counter_2, client_counter_2) = create_network(); + // Give time for spawned I/O tasks to spin up. + tokio::time::sleep(std::time::Duration::from_millis(100)).await; let peer_id = network_1.connect(network_2.local_addr()).await.unwrap(); @@ -798,7 +802,7 @@ async fn network_ref_via_extension() -> Result<()> { .server_name("test") .random_private_key() .start(svc)?; - let network_2 = build_network()?; + let network_2 = build_network().await?; let peer = network_2.connect(network_1.local_addr()).await?; let mut response = network_2 diff --git a/crates/anemo/src/types/mod.rs b/crates/anemo/src/types/mod.rs index 2d5a8dec..675603e4 100644 --- a/crates/anemo/src/types/mod.rs +++ b/crates/anemo/src/types/mod.rs @@ -7,7 +7,8 @@ pub use address::Address; pub use peer_id::{ConnectionOrigin, Direction, PeerId}; pub use http::Extensions; -use quinn::ConnectionError; + +use crate::connection::ConnectionError; #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord)] #[repr(u16)] @@ -79,15 +80,18 @@ pub enum DisconnectReason { } impl DisconnectReason { - pub fn from_quinn_error(error: &ConnectionError) -> Self { + pub fn from_connection_error(error: &ConnectionError) -> Self { match error { - ConnectionError::VersionMismatch => DisconnectReason::VersionMismatch, - ConnectionError::TransportError(_) => DisconnectReason::TransportError, - ConnectionError::ConnectionClosed(_) => DisconnectReason::ConnectionClosed, - ConnectionError::ApplicationClosed(_) => DisconnectReason::ApplicationClosed, - ConnectionError::Reset => DisconnectReason::Reset, - ConnectionError::TimedOut => DisconnectReason::TimedOut, - ConnectionError::LocallyClosed => DisconnectReason::LocallyClosed, + ConnectionError::Quic(error) => match error { + quinn::ConnectionError::VersionMismatch => DisconnectReason::VersionMismatch, + quinn::ConnectionError::TransportError(_) => DisconnectReason::TransportError, + quinn::ConnectionError::ConnectionClosed(_) => DisconnectReason::ConnectionClosed, + quinn::ConnectionError::ApplicationClosed(_) => DisconnectReason::ApplicationClosed, + quinn::ConnectionError::Reset => DisconnectReason::Reset, + quinn::ConnectionError::TimedOut => DisconnectReason::TimedOut, + quinn::ConnectionError::LocallyClosed => DisconnectReason::LocallyClosed, + }, + ConnectionError::Tls(_) => DisconnectReason::ConnectionClosed, } } }