diff --git a/grpc-google/src/lib.rs b/grpc-google/src/lib.rs index aa3641372..1dd9f599f 100644 --- a/grpc-google/src/lib.rs +++ b/grpc-google/src/lib.rs @@ -37,9 +37,9 @@ use grpc::credentials::SecurityLevel; use grpc::credentials::call::CallCredentials; use grpc::credentials::call::CallDetails; use grpc::credentials::call::ClientConnectionSecurityInfo; +use grpc::metadata::AsciiMetadataValue; +use grpc::metadata::MetadataMap; use tonic::async_trait; -use tonic::metadata::AsciiMetadataValue; -use tonic::metadata::MetadataMap; const DEFAULT_CLOUD_PLATFORM_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform"; @@ -162,7 +162,7 @@ mod tests { assert!(res.is_ok()); let auth_header = metadata.get("authorization").unwrap(); - assert_eq!(auth_header.to_str().unwrap(), "Bearer valid_token"); + assert_eq!(auth_header.to_str(), "Bearer valid_token"); } #[tokio::test] @@ -184,7 +184,7 @@ mod tests { async fn non_ascii_token_internal_error() { let creds = GcpCallCredentials { provider: MockTokenProvider { - result: Ok("invalid character\n".into()), + result: Ok("invalid\ncharacter".into()), }, }; let (cd, auth_info) = fake_args(); diff --git a/grpc/Cargo.toml b/grpc/Cargo.toml index 30b55cd3e..da82e25dc 100644 --- a/grpc/Cargo.toml +++ b/grpc/Cargo.toml @@ -28,6 +28,7 @@ _runtime-tokio = [ "tokio/time", "dep:socket2", "dep:tower", + "dep:futures", ] # Used for testing with udeps as it wants this feature to exist # to be able to do its checks. @@ -44,6 +45,7 @@ tls-rustls = [ [dependencies] base64 = "0.22" bytes = "1.10.1" +futures = { version = "0.3", default-features = false, optional = true } hickory-resolver = { version = "0.25.1", optional = true } http = "1.1.0" http-body = "1.0.1" diff --git a/grpc/src/client/load_balancing/graceful_switch.rs b/grpc/src/client/load_balancing/graceful_switch.rs index f06701e49..70bc784f3 100644 --- a/grpc/src/client/load_balancing/graceful_switch.rs +++ b/grpc/src/client/load_balancing/graceful_switch.rs @@ -252,8 +252,6 @@ mod test { use std::sync::mpsc; use std::time::Duration; - use tonic::metadata::MetadataMap; - use crate::client::load_balancing::ChannelController; use crate::client::load_balancing::LbPolicy; use crate::client::load_balancing::LbState; @@ -276,6 +274,7 @@ mod test { use crate::client::name_resolution::Endpoint; use crate::client::name_resolution::ResolverUpdate; use crate::core::RequestHeaders; + use crate::metadata::MetadataMap; use crate::rt::default_runtime; const DEFAULT_TEST_SHORT_TIMEOUT: Duration = Duration::from_millis(10); diff --git a/grpc/src/client/load_balancing/mod.rs b/grpc/src/client/load_balancing/mod.rs index 2cbff629c..487d495a4 100644 --- a/grpc/src/client/load_balancing/mod.rs +++ b/grpc/src/client/load_balancing/mod.rs @@ -29,8 +29,6 @@ use std::fmt::Debug; use std::fmt::Display; use std::sync::Arc; -use tonic::metadata::MetadataMap; - use crate::StatusCodeError; use crate::StatusError; use crate::client::ConnectivityState; @@ -39,6 +37,7 @@ use crate::client::load_balancing::subchannel::SubchannelState; use crate::client::name_resolution::Address; use crate::client::name_resolution::ResolverUpdate; use crate::core::RequestHeaders; +use crate::metadata::MetadataMap; use crate::rt::GrpcRuntime; pub(crate) mod child_manager; diff --git a/grpc/src/client/load_balancing/round_robin.rs b/grpc/src/client/load_balancing/round_robin.rs index 987a6ea2d..efa2ce61a 100644 --- a/grpc/src/client/load_balancing/round_robin.rs +++ b/grpc/src/client/load_balancing/round_robin.rs @@ -257,8 +257,6 @@ mod test { use std::sync::Arc; use std::sync::mpsc; - use tonic::metadata::MetadataMap; - use crate::StatusCodeError; use crate::client::ConnectivityState; use crate::client::load_balancing::ChannelController; @@ -286,6 +284,7 @@ mod test { use crate::client::name_resolution::Endpoint; use crate::client::name_resolution::ResolverUpdate; use crate::core::RequestHeaders; + use crate::metadata::MetadataMap; use crate::rt::default_runtime; const DEFAULT_TEST_SHORT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(100); diff --git a/grpc/src/client/load_balancing/subchannel_sharing.rs b/grpc/src/client/load_balancing/subchannel_sharing.rs index 7217016bc..c600ebd2c 100644 --- a/grpc/src/client/load_balancing/subchannel_sharing.rs +++ b/grpc/src/client/load_balancing/subchannel_sharing.rs @@ -291,8 +291,6 @@ mod tests { use std::sync::Mutex; use std::sync::mpsc; - use tonic::metadata::MetadataMap; - use super::*; use crate::client::ConnectivityState; use crate::client::load_balancing::LbPolicy; @@ -309,6 +307,7 @@ mod tests { use crate::client::load_balancing::test_utils::new_request_headers; use crate::client::name_resolution::Address; use crate::client::name_resolution::ResolverUpdate; + use crate::metadata::MetadataMap; use crate::rt::default_runtime; fn test_lb_policy_options(tx_events: mpsc::Sender) -> LbPolicyOptions { diff --git a/grpc/src/client/metadata_utils.rs b/grpc/src/client/metadata_utils.rs index da20c1e6c..42c476dcf 100644 --- a/grpc/src/client/metadata_utils.rs +++ b/grpc/src/client/metadata_utils.rs @@ -23,7 +23,6 @@ */ use tokio::sync::oneshot; -use tonic::metadata::MetadataMap; use crate::client::CallOptions; use crate::client::InvokeOnce; @@ -31,6 +30,7 @@ use crate::client::RecvStream; use crate::client::interceptor::Intercept; use crate::client::interceptor::InterceptOnce; use crate::core::RequestHeaders; +use crate::metadata::MetadataMap; /// An interceptor that attaches metadata to outgoing RPC headers. pub struct AttachHeadersInterceptor { @@ -53,16 +53,16 @@ impl Intercept for AttachHeadersInterceptor { options: CallOptions, next: I, ) -> (Self::SendStream, Self::RecvStream) { - headers - .metadata_mut() - .as_mut() - .extend(self.md.as_ref().clone()); - - let md = headers.metadata_mut(); - for entry in self.md.iter() { - match entry { - tonic::metadata::KeyAndValueRef::Ascii(k, v) => _ = md.insert(k, v.clone()), - tonic::metadata::KeyAndValueRef::Binary(k, v) => _ = md.insert_bin(k, v.clone()), + let incoming_meta = headers.metadata_mut(); + incoming_meta.reserve(self.md.len()); + for kv in self.md.iter() { + match kv { + crate::metadata::KeyAndValueRef::Ascii(key, value) => { + incoming_meta.append(key, value.clone()) + } + crate::metadata::KeyAndValueRef::Binary(key, value) => { + incoming_meta.append_bin(key, value.clone()) + } } } next.invoke_once(headers, options).await @@ -179,6 +179,7 @@ mod tests { use crate::core::ClientResponseStreamItem; use crate::core::ResponseHeaders; use crate::core::Trailers; + use crate::metadata::BinaryMetadataValue; #[tokio::test] async fn test_attach_headers_interceptor() { @@ -187,7 +188,7 @@ mod tests { md.insert("x-test-header", "test-value".parse().unwrap()); md.insert_bin( "x-test-header-bin", - tonic::metadata::MetadataValue::from_bytes(b"test-bin"), + BinaryMetadataValue::from_bytes(b"test-bin"), ); let interceptor = AttachHeadersInterceptor::new(md); diff --git a/grpc/src/client/transport/tonic/mod.rs b/grpc/src/client/transport/tonic/mod.rs index b7b387e0f..be62a1fd4 100644 --- a/grpc/src/client/transport/tonic/mod.rs +++ b/grpc/src/client/transport/tonic/mod.rs @@ -28,6 +28,7 @@ use std::net::SocketAddr; use std::path::PathBuf; use std::pin::Pin; use std::str::FromStr; +use std::sync::Arc; use std::task::Context; use std::task::Poll; use std::time::Instant; @@ -35,12 +36,14 @@ use std::time::Instant; use bytes::Buf; use bytes::BufMut as _; use bytes::Bytes; +use futures::stream::StreamExt; use http::Request as HttpRequest; use http::Response as HttpResponse; use http::Uri; use http::uri::PathAndQuery; use hyper::client::conn::http2::Builder; use hyper::client::conn::http2::SendRequest; +use tokio::sync::Notify; use tokio::sync::mpsc; use tokio::sync::oneshot; use tokio_stream::Stream; @@ -56,7 +59,7 @@ use tonic::codec::Codec; use tonic::codec::Decoder; use tonic::codec::EncodeBuf; use tonic::codec::Encoder; -use tonic::metadata::MetadataMap; +use tonic::metadata::MetadataMap as TonicMeta; use tower::ServiceBuilder; use tower::buffer::Buffer; use tower::buffer::future::ResponseFuture as BufferResponseFuture; @@ -151,10 +154,18 @@ impl Invoke for TonicTransport { options: CallOptions, ) -> (Self::SendStream, Self::RecvStream) { let (req_tx, req_rx) = mpsc::channel(1); - let request_stream = ReceiverStream::new(req_rx); + let stop_notify = Arc::new(Notify::new()); + + // Tonic runs the outbound request stream in a background task. It will + // NOT automatically stop sending if the inbound response stream is + // dropped. We use `take_until` with this Notify to explicitly force the + // stream to yield `None`, which tells Tonic to cancel the stream. + let stop_notify_clone = stop_notify.clone(); + let request_stream = + ReceiverStream::new(req_rx).take_until(stop_notify_clone.notified_owned()); let mut request = TonicRequest::new(Box::pin(request_stream)); let (method, metadata) = headers.into_parts(); - *request.metadata_mut() = metadata; + *request.metadata_mut() = metadata.into(); let Ok(path) = PathAndQuery::from_maybe_shared(method) else { return err_streams(StatusError::new(StatusCodeError::Internal, "invalid path")); @@ -184,6 +195,7 @@ impl Invoke for TonicTransport { TonicSendStream { sender: Ok(req_tx) }, TonicRecvStream { state: StreamState::AwaitingHeaders(resp_rx), + stop_notify: Some(stop_notify), }, ) } @@ -192,28 +204,28 @@ impl Invoke for TonicTransport { // Converts from a tonic status to a trailers stream item. fn trailers_from_tonic_status( status: TonicStatus, - md: Option, + md: Option, ) -> ClientResponseStreamItem { - let mut trailers = Trailers::new(if status.code() == Code::Ok { - Ok(()) - } else { - Err(StatusError::new( - StatusCodeError::from(status.code() as i32), + let status_res = match status.code() { + Code::Ok => Ok(()), + code => Err(StatusError::new( + StatusCodeError::from(code as i32), status.message(), - )) - }); - if let Some(md) = md { - trailers = trailers.with_metadata(md); - } - ClientResponseStreamItem::Trailers(trailers) + )), + }; + trailers_from_status(status_res, md) } // Builds a trailers with a status -fn trailers_from_status(status: Status, md: Option) -> ClientResponseStreamItem { - let mut trailers = Trailers::new(status); - if let Some(md) = md { - trailers = trailers.with_metadata(md); - } +fn trailers_from_status(status: Status, md: Option) -> ClientResponseStreamItem { + let trailers = match md.map(TryInto::try_into) { + Some(Err(e)) => Trailers::new(Err(StatusError::new( + StatusCodeError::Internal, + format!("failed to parse metadata: {e}"), + ))), + Some(Ok(metadata)) => Trailers::new(status).with_metadata(metadata), + None => Trailers::new(status), + }; ClientResponseStreamItem::Trailers(trailers) } @@ -238,6 +250,7 @@ impl SendStream for TonicSendStream { struct TonicRecvStream { state: StreamState, + stop_notify: Option>, } enum StreamState { @@ -262,11 +275,34 @@ impl RecvStream for TonicRecvStream { StreamState::AwaitingHeaders(rx) => match rx.await { Ok(Ok(response)) => { let (metadata, stream, _extensions) = response.into_parts(); - // Start streaming and return the headers. - self.state = StreamState::Streaming(stream); - ClientResponseStreamItem::Headers( - ResponseHeaders::new().with_metadata(metadata), - ) + // Tonic decodes base64-encoded binary headers lazily. It + // does not fail the RPC upon receiving invalid base64 data; + // the error only surfaces when the application attempts to + // read the metadata. + // In contrast, standard gRPC implementations eagerly decode + // these headers and immediately fail the RPC with an + // Internal status. + match metadata.try_into() { + Ok(md) => { + // Start streaming and return the headers. + self.state = StreamState::Streaming(stream); + ClientResponseStreamItem::Headers( + ResponseHeaders::new().with_metadata(md), + ) + } + Err(e) => { + if let Some(notify) = self.stop_notify.take() { + notify.notify_one(); + } + trailers_from_status( + Err(StatusError::new( + StatusCodeError::Internal, + format!("error decoding response: {e}"), + )), + None, + ) + } + } } // Stay closed after sending trailers. Err(_) => trailers_from_status( @@ -282,16 +318,18 @@ impl RecvStream for TonicRecvStream { self.state = StreamState::Streaming(stream); ClientResponseStreamItem::Message } - // TODO: in this case, tonic believes the stream is still - // running, but our decoding failed -- do we need to terminate - // the request stream now even though the Streaming is dropped? - Err(e) => trailers_from_status( - Err(StatusError::new( - StatusCodeError::Internal, - format!("error decoding response: {e}"), - )), - None, - ), + Err(e) => { + if let Some(notify) = self.stop_notify.take() { + notify.notify_one(); + } + trailers_from_status( + Err(StatusError::new( + StatusCodeError::Internal, + format!("error decoding response: {e}"), + )), + None, + ) + } }, // Stay closed after sending trailers. Err(status) => { @@ -309,11 +347,20 @@ impl RecvStream for TonicRecvStream { } } +impl Drop for TonicRecvStream { + fn drop(&mut self) { + if let Some(notify) = &self.stop_notify { + notify.notify_one(); + } + } +} + fn err_streams(status: StatusError) -> (TonicSendStream, TonicRecvStream) { ( TonicSendStream { sender: Err(()) }, TonicRecvStream { state: StreamState::Error(status), + stop_notify: None, }, ) } diff --git a/grpc/src/client/transport/tonic/test.rs b/grpc/src/client/transport/tonic/test.rs index 62ef6bcc1..bfd6a216e 100644 --- a/grpc/src/client/transport/tonic/test.rs +++ b/grpc/src/client/transport/tonic/test.rs @@ -25,12 +25,16 @@ use std::fs; use std::path::PathBuf; use std::pin::Pin; +use std::result::Result; use std::sync::Arc; use std::sync::Once; use std::time::Duration; use bytes::Buf; use bytes::Bytes; +use http::HeaderMap; +use http::HeaderName; +use http::HeaderValue; use tokio::net::TcpListener; use tokio::sync::Notify; use tokio::sync::oneshot; @@ -41,7 +45,7 @@ use tokio_stream::wrappers::ReceiverStream; use tokio_stream::wrappers::TcpListenerStream; use tonic::Response; use tonic::async_trait; -use tonic::metadata::MetadataMap; +use tonic::metadata::MetadataMap as TonicMetadata; use tonic::transport::Server; use tonic_prost::prost::Message as ProstMessage; @@ -80,6 +84,8 @@ use crate::echo_pb::EchoRequest; use crate::echo_pb::EchoResponse; use crate::echo_pb::echo_server::Echo; use crate::echo_pb::echo_server::EchoServer; +use crate::metadata::AsciiMetadataKey; +use crate::metadata::MetadataMap; use crate::rt::GrpcRuntime; use crate::rt::tokio::TokioRuntime; @@ -103,8 +109,7 @@ impl CallCredentials for MockCallCredentials { } for (key, val) in &self.metadata { metadata.insert( - key.parse::>() - .unwrap(), + key.parse::().unwrap(), val.parse().unwrap(), ); } @@ -129,7 +134,9 @@ pub(crate) async fn tonic_transport_rpc() { let shutdown_notify_copy = shutdown_notify.clone(); println!("EchoServer listening on: {addr}"); let server_handle = tokio::spawn(async move { - let echo_server = EchoService {}; + let echo_server = EchoService { + response_headers: None, + }; let svc = EchoServer::new(echo_server); let _ = Server::builder() .add_service(svc) @@ -228,7 +235,9 @@ async fn grpc_invoke_tonic_unary() { // Spawn a task for the server. let server_handle = tokio::spawn(async move { - let echo_server = EchoService {}; + let echo_server = EchoService { + response_headers: None, + }; let svc = EchoServer::new(echo_server); let _ = Server::builder() .add_service(svc) @@ -284,7 +293,9 @@ mod unix_tests { let shutdown_notify_copy = shutdown_notify.clone(); let server_handle = tokio::spawn(async move { - let echo_server = EchoService {}; + let echo_server = EchoService { + response_headers: None, + }; let svc = EchoServer::new(echo_server); let _ = Server::builder() .add_service(svc) @@ -419,7 +430,9 @@ async fn grpc_invoke_tonic_unary_tls() { // Spawn a task for the server. let server_handle = tokio::spawn(async move { - let echo_server = EchoService {}; + let echo_server = EchoService { + response_headers: None, + }; let svc = EchoServer::new(echo_server); let _ = Server::builder() .tls_config(tls_config) @@ -472,7 +485,9 @@ async fn grpc_invoke_failure_cases() { let shutdown_notify_copy = shutdown_notify.clone(); tokio::spawn(async move { - let echo_server = EchoService {}; + let echo_server = EchoService { + response_headers: None, + }; let svc = EchoServer::new(echo_server); let _ = Server::builder() .add_service(svc) @@ -622,6 +637,152 @@ async fn perform_unary_echo_failure(channel: &Channel) -> Trailers { t } +#[tokio::test] +async fn tonic_transport_invalid_base64_headers() { + super::reg(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let shutdown_notify = Arc::new(Notify::new()); + let shutdown_notify_copy = shutdown_notify.clone(); + + let mut headers = HeaderMap::new(); + headers.insert( + HeaderName::from_static("test-bin"), + HeaderValue::from_static("invalid base64 data"), + ); + let response_headers = Some(TonicMetadata::from_headers(headers)); + + let server_handle = tokio::spawn(async move { + let echo_server = EchoService { response_headers }; + let svc = EchoServer::new(echo_server); + let _ = Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown( + TcpListenerStream::new(listener), + shutdown_notify_copy.notified(), + ) + .await; + }); + + let builder = GLOBAL_TRANSPORT_REGISTRY + .get_transport(TCP_IP_NETWORK_TYPE) + .unwrap(); + let config = Arc::new(TransportOptions::default()); + let securty_opts = SecurityOpts { + credentials: LocalChannelCredentials::new_arc(), + authority: Authority::new("localhost".to_string(), None), + handshake_info: ClientHandshakeInfo::default(), + }; + let (conn, _sec_info, _disconnection_listener) = builder + .dyn_connect( + addr.to_string(), + GrpcRuntime::new(TokioRuntime::default()), + &securty_opts, + &config, + ) + .await + .unwrap(); + + let (mut tx, mut rx) = conn + .dyn_invoke( + RequestHeaders::new() + .with_method_name("/grpc.examples.echo.Echo/BidirectionalStreamingEcho"), + CallOptions::default(), + ) + .await; + + let mut dummy_msg = WrappedEchoResponse(EchoResponse { message: "".into() }); + + match rx.next(&mut dummy_msg).await { + ClientResponseStreamItem::Trailers(trailers) => { + println!("Got trailers as expected due to invalid headers"); + let status = trailers.status().as_ref().unwrap_err(); + assert_eq!(status.code(), StatusCodeError::Internal); + } + item => panic!("Expected Trailers with error, got {:?}", item), + } + + let request = EchoRequest { + message: "hello".into(), + }; + let req = WrappedEchoRequest(request); + + tokio::time::timeout(DEFAULT_TEST_DURATION, async { + while tx.send(&req, SendOptions::default()).await.is_ok() {} + }) + .await + .expect("timed out waiting for stream to close"); + + shutdown_notify.notify_one(); + server_handle.await.unwrap(); +} + +#[tokio::test] +async fn tonic_transport_recv_drop_cancels_send() { + super::reg(); + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let shutdown_notify = Arc::new(Notify::new()); + let shutdown_notify_copy = shutdown_notify.clone(); + + let server_handle = tokio::spawn(async move { + let echo_server = EchoService { + response_headers: None, + }; + let svc = EchoServer::new(echo_server); + let _ = Server::builder() + .add_service(svc) + .serve_with_incoming_shutdown( + TcpListenerStream::new(listener), + shutdown_notify_copy.notified(), + ) + .await; + }); + + let builder = GLOBAL_TRANSPORT_REGISTRY + .get_transport(TCP_IP_NETWORK_TYPE) + .unwrap(); + let config = Arc::new(TransportOptions::default()); + let securty_opts = SecurityOpts { + credentials: InsecureChannelCredentials::new_arc(), + authority: Authority::new("localhost".to_string(), None), + handshake_info: ClientHandshakeInfo::default(), + }; + let (conn, _sec_info, _disconnection_listener) = builder + .dyn_connect( + addr.to_string(), + GrpcRuntime::new(TokioRuntime::default()), + &securty_opts, + &config, + ) + .await + .unwrap(); + + let (mut tx, rx) = conn + .dyn_invoke( + RequestHeaders::new() + .with_method_name("/grpc.examples.echo.Echo/BidirectionalStreamingEcho"), + CallOptions::default(), + ) + .await; + + drop(rx); + + let request = EchoRequest { + message: "hello".into(), + }; + let req = WrappedEchoRequest(request); + + tokio::time::timeout(DEFAULT_TEST_DURATION, async { + while tx.send(&req, SendOptions::default()).await.is_ok() {} + }) + .await + .expect("timed out waiting for stream to close"); + + shutdown_notify.notify_one(); + server_handle.await.unwrap(); +} + struct WrappedEchoRequest(EchoRequest); struct WrappedEchoResponse(EchoResponse); @@ -640,14 +801,16 @@ impl RecvMessage for WrappedEchoResponse { } #[derive(Debug)] -pub(crate) struct EchoService {} +struct EchoService { + response_headers: Option, +} #[async_trait] impl Echo for EchoService { async fn unary_echo( &self, request: tonic::Request, - ) -> std::result::Result, tonic::Status> { + ) -> Result, tonic::Status> { let metadata = request.metadata().clone(); let message = request.into_inner().message; let mut response = tonic::Response::new(EchoResponse { message }); @@ -664,14 +827,14 @@ impl Echo for EchoService { async fn server_streaming_echo( &self, _: tonic::Request, - ) -> std::result::Result, tonic::Status> { + ) -> Result, tonic::Status> { unimplemented!() } async fn client_streaming_echo( &self, _: tonic::Request>, - ) -> std::result::Result, tonic::Status> { + ) -> Result, tonic::Status> { unimplemented!() } type BidirectionalStreamingEchoStream = @@ -680,8 +843,7 @@ impl Echo for EchoService { async fn bidirectional_streaming_echo( &self, request: tonic::Request>, - ) -> std::result::Result, tonic::Status> - { + ) -> Result, tonic::Status> { let metadata = request.metadata().clone(); if let Some(val) = metadata.get("x-test-metadata") && val == "test-value" @@ -702,8 +864,11 @@ impl Echo for EchoService { println!("Server closing stream"); }; - Ok(Response::new( - Box::pin(outbound) as Self::BidirectionalStreamingEchoStream - )) + let mut response = + Response::new(Box::pin(outbound) as Self::BidirectionalStreamingEchoStream); + if let Some(headers) = &self.response_headers { + *response.metadata_mut() = headers.clone(); + } + Ok(response) } } diff --git a/grpc/src/core/mod.rs b/grpc/src/core/mod.rs index 44c886072..90440711f 100644 --- a/grpc/src/core/mod.rs +++ b/grpc/src/core/mod.rs @@ -29,8 +29,8 @@ use std::any::TypeId; use bytes::Buf; -use tonic::metadata::MetadataMap; +use crate::metadata::MetadataMap; use crate::status::Status; #[allow(unused)] diff --git a/grpc/src/credentials/call.rs b/grpc/src/credentials/call.rs index 9d852d183..37f5d9415 100644 --- a/grpc/src/credentials/call.rs +++ b/grpc/src/credentials/call.rs @@ -26,11 +26,11 @@ use std::fmt::Debug; use std::sync::Arc; use tonic::async_trait; -use tonic::metadata::MetadataMap; use crate::StatusError; use crate::attributes::Attributes; use crate::credentials::SecurityLevel; +use crate::metadata::MetadataMap; /// Details regarding the call. /// @@ -175,9 +175,9 @@ impl CallCredentials for CompositeCallCredentials { #[cfg(test)] mod tests { - use tonic::metadata::MetadataValue; - use super::*; + use crate::metadata::AsciiMetadataKey; + use crate::metadata::AsciiMetadataValue; #[derive(Debug)] struct MockCallCredentials { @@ -195,10 +195,8 @@ mod tests { metadata: &mut MetadataMap, ) -> Result<(), StatusError> { metadata.insert( - self.key - .parse::>() - .unwrap(), - MetadataValue::try_from(&self.value).unwrap(), + self.key.parse::().unwrap(), + AsciiMetadataValue::try_from(&self.value).unwrap(), ); Ok(()) } diff --git a/grpc/src/credentials/client.rs b/grpc/src/credentials/client.rs index 45a642ea8..63ffb596d 100644 --- a/grpc/src/credentials/client.rs +++ b/grpc/src/credentials/client.rs @@ -203,8 +203,6 @@ impl ChannelCredentials for CompositeChannelCredentials Result<(), StatusError> { metadata.insert( - self.key - .parse::>() - .unwrap(), - MetadataValue::try_from(self.value).unwrap(), + self.key.parse::().unwrap(), + AsciiMetadataValue::try_from(self.value).unwrap(), ); Ok(()) } diff --git a/grpc/src/metadata/map.rs b/grpc/src/metadata/map.rs index 79fd06038..1ec3b614d 100644 --- a/grpc/src/metadata/map.rs +++ b/grpc/src/metadata/map.rs @@ -158,7 +158,11 @@ impl MetadataMap { } /// Convert an HTTP HeaderMap to a MetadataMap - pub(crate) fn from_headers(headers: HeaderMap) -> Self { + /// + /// # Errors + /// + /// Returns an error if base64 decoding of a binary metadata value fails. + pub(crate) fn from_headers(headers: HeaderMap) -> Result { let mut ret = Vec::with_capacity(headers.len()); let mut current_key: Option = None; @@ -180,16 +184,17 @@ impl MetadataMap { mv.set_sensitive(value.is_sensitive()); ret.push((k.clone(), mv.into_inner())); } - } else if Binary::is_valid_key(key_str) - && let Ok(b) = Binary::decode(value.as_bytes(), private::Internal) - { + } else if Binary::is_valid_key(key_str) { + let b = Binary::decode(value.as_bytes(), private::Internal).map_err(|e| { + format!("failed to decode base64 value for key '{key_str}': {e}") + })?; let mut mv = unsafe { MetadataValue::::from_shared_unchecked(b) }; mv.set_sensitive(value.is_sensitive()); ret.push((k.clone(), mv.into_inner())); } } - Self { headers: ret } + Ok(Self { headers: ret }) } /// Convert a MetadataMap into a HTTP HeaderMap. @@ -911,8 +916,10 @@ impl MetadataMap { } } -impl From for MetadataMap { - fn from(tonic_map: tonic::metadata::MetadataMap) -> Self { +impl TryFrom for MetadataMap { + type Error = String; + + fn try_from(tonic_map: tonic::metadata::MetadataMap) -> Result { Self::from_headers(tonic_map.into_headers()) } } @@ -1418,10 +1425,7 @@ mod tests { // in gRPC MetadataValue. http_map.insert("x-invalid-ascii", HeaderValue::from_bytes(&[0xFA]).unwrap()); - // Invalid Binary value (not valid base64) - http_map.insert("invalid-bin", "not-base64-!!!".parse().unwrap()); - - let map = MetadataMap::from_headers(http_map); + let map = MetadataMap::from_headers(http_map).unwrap(); assert_eq!(map.len(), 2); assert_eq!(map.get("x-host").unwrap(), "example.com"); @@ -1429,7 +1433,17 @@ mod tests { assert!(!map.contains_key("x-host!")); assert!(!map.contains_key("x-invalid-ascii")); - assert!(!map.contains_key("invalid-bin")); + } + + #[test] + fn test_from_headers_fails_on_invalid_bin_header() { + let mut http_map = http::HeaderMap::new(); + + // Invalid Binary value (not valid base64) + http_map.insert("invalid-bin", "not-base64-!!!".parse().unwrap()); + + let result = MetadataMap::from_headers(http_map); + assert!(result.is_err()); } #[test] @@ -1560,7 +1574,7 @@ mod tests { assert_eq!(tonic_map.get("x-host").unwrap(), "example.com"); assert_eq!(tonic_map.get_bin("trace-proto-bin").unwrap(), "Hello!!"); - let back_map: MetadataMap = tonic_map.into(); + let back_map: MetadataMap = tonic_map.try_into().unwrap(); assert_eq!(back_map.len(), 2); assert_eq!(back_map.get("x-host").unwrap(), "example.com"); assert_eq!(back_map.get_bin("trace-proto-bin").unwrap(), "Hello!!"); diff --git a/grpc/src/metadata/value.rs b/grpc/src/metadata/value.rs index 368d79817..519696622 100644 --- a/grpc/src/metadata/value.rs +++ b/grpc/src/metadata/value.rs @@ -440,7 +440,7 @@ impl MetadataValue { /// # Examples /// /// ``` - /// # use tonic::metadata::*; + /// # use grpc::metadata::*; /// let val = BinaryMetadataValue::from_bytes(b"hello\xfa"); /// assert_eq!(val, &b"hello\xfa"[..]); /// ``` diff --git a/interop/src/client_protobuf.rs b/interop/src/client_protobuf.rs index 7736c394f..e063be946 100644 --- a/interop/src/client_protobuf.rs +++ b/interop/src/client_protobuf.rs @@ -28,12 +28,12 @@ use grpc::client::Channel; use grpc::client::metadata_utils::AttachHeadersInterceptor; use grpc::client::metadata_utils::CaptureHeadersInterceptor; use grpc::client::metadata_utils::CaptureTrailersInterceptor; +use grpc::metadata::MetadataMap; +use grpc::metadata::MetadataValue; use grpc_protobuf::CallBuilder; use protobuf::message_eq; use protobuf::proto; use tonic::async_trait; -use tonic::metadata::MetadataMap; -use tonic::metadata::MetadataValue; use crate::TestAssertion; use crate::client::InteropTest;