Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions grpc-google/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ use grpc::credentials::SecurityLevel;
use grpc::credentials::call::CallCredentials;
use grpc::credentials::call::CallDetails;
use grpc::credentials::call::ClientConnectionSecurityInfo;
use grpc::metadata::AsciiMetadataValue;
use grpc::metadata::MetadataMap;
use tonic::async_trait;
use tonic::metadata::AsciiMetadataValue;
use tonic::metadata::MetadataMap;

const DEFAULT_CLOUD_PLATFORM_SCOPE: &str = "https://www.googleapis.com/auth/cloud-platform";

Expand Down Expand Up @@ -162,7 +162,7 @@ mod tests {
assert!(res.is_ok());

let auth_header = metadata.get("authorization").unwrap();
assert_eq!(auth_header.to_str().unwrap(), "Bearer valid_token");
assert_eq!(auth_header.to_str(), "Bearer valid_token");
}

#[tokio::test]
Expand All @@ -184,7 +184,7 @@ mod tests {
async fn non_ascii_token_internal_error() {
let creds = GcpCallCredentials {
provider: MockTokenProvider {
result: Ok("invalid character\n".into()),
result: Ok("invalid\ncharacter".into()),
},
};
let (cd, auth_info) = fake_args();
Expand Down
2 changes: 2 additions & 0 deletions grpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ _runtime-tokio = [
"tokio/time",
"dep:socket2",
"dep:tower",
"dep:futures",
]
# Used for testing with udeps as it wants this feature to exist
# to be able to do its checks.
Expand All @@ -44,6 +45,7 @@ tls-rustls = [
[dependencies]
base64 = "0.22"
bytes = "1.10.1"
futures = { version = "0.3", default-features = false, optional = true }
hickory-resolver = { version = "0.25.1", optional = true }
http = "1.1.0"
http-body = "1.0.1"
Expand Down
3 changes: 1 addition & 2 deletions grpc/src/client/load_balancing/graceful_switch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -252,8 +252,6 @@ mod test {
use std::sync::mpsc;
use std::time::Duration;

use tonic::metadata::MetadataMap;

use crate::client::load_balancing::ChannelController;
use crate::client::load_balancing::LbPolicy;
use crate::client::load_balancing::LbState;
Expand All @@ -276,6 +274,7 @@ mod test {
use crate::client::name_resolution::Endpoint;
use crate::client::name_resolution::ResolverUpdate;
use crate::core::RequestHeaders;
use crate::metadata::MetadataMap;
use crate::rt::default_runtime;

const DEFAULT_TEST_SHORT_TIMEOUT: Duration = Duration::from_millis(10);
Expand Down
3 changes: 1 addition & 2 deletions grpc/src/client/load_balancing/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@ use std::fmt::Debug;
use std::fmt::Display;
use std::sync::Arc;

use tonic::metadata::MetadataMap;

use crate::StatusCodeError;
use crate::StatusError;
use crate::client::ConnectivityState;
Expand All @@ -39,6 +37,7 @@ use crate::client::load_balancing::subchannel::SubchannelState;
use crate::client::name_resolution::Address;
use crate::client::name_resolution::ResolverUpdate;
use crate::core::RequestHeaders;
use crate::metadata::MetadataMap;
use crate::rt::GrpcRuntime;

pub(crate) mod child_manager;
Expand Down
3 changes: 1 addition & 2 deletions grpc/src/client/load_balancing/round_robin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,6 @@ mod test {
use std::sync::Arc;
use std::sync::mpsc;

use tonic::metadata::MetadataMap;

use crate::StatusCodeError;
use crate::client::ConnectivityState;
use crate::client::load_balancing::ChannelController;
Expand Down Expand Up @@ -286,6 +284,7 @@ mod test {
use crate::client::name_resolution::Endpoint;
use crate::client::name_resolution::ResolverUpdate;
use crate::core::RequestHeaders;
use crate::metadata::MetadataMap;
use crate::rt::default_runtime;

const DEFAULT_TEST_SHORT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(100);
Expand Down
3 changes: 1 addition & 2 deletions grpc/src/client/load_balancing/subchannel_sharing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -291,8 +291,6 @@ mod tests {
use std::sync::Mutex;
use std::sync::mpsc;

use tonic::metadata::MetadataMap;

use super::*;
use crate::client::ConnectivityState;
use crate::client::load_balancing::LbPolicy;
Expand All @@ -309,6 +307,7 @@ mod tests {
use crate::client::load_balancing::test_utils::new_request_headers;
use crate::client::name_resolution::Address;
use crate::client::name_resolution::ResolverUpdate;
use crate::metadata::MetadataMap;
use crate::rt::default_runtime;

fn test_lb_policy_options(tx_events: mpsc::Sender<TestEvent>) -> LbPolicyOptions {
Expand Down
25 changes: 13 additions & 12 deletions grpc/src/client/metadata_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@
*/

use tokio::sync::oneshot;
use tonic::metadata::MetadataMap;

use crate::client::CallOptions;
use crate::client::InvokeOnce;
use crate::client::RecvStream;
use crate::client::interceptor::Intercept;
use crate::client::interceptor::InterceptOnce;
use crate::core::RequestHeaders;
use crate::metadata::MetadataMap;

/// An interceptor that attaches metadata to outgoing RPC headers.
pub struct AttachHeadersInterceptor {
Expand All @@ -53,16 +53,16 @@ impl<I: InvokeOnce> Intercept<I> for AttachHeadersInterceptor {
options: CallOptions,
next: I,
) -> (Self::SendStream, Self::RecvStream) {
headers
.metadata_mut()
.as_mut()
.extend(self.md.as_ref().clone());

let md = headers.metadata_mut();
for entry in self.md.iter() {
match entry {
tonic::metadata::KeyAndValueRef::Ascii(k, v) => _ = md.insert(k, v.clone()),
tonic::metadata::KeyAndValueRef::Binary(k, v) => _ = md.insert_bin(k, v.clone()),
let incoming_meta = headers.metadata_mut();
incoming_meta.reserve(self.md.len());
for kv in self.md.iter() {
match kv {
crate::metadata::KeyAndValueRef::Ascii(key, value) => {
incoming_meta.append(key, value.clone())
}
crate::metadata::KeyAndValueRef::Binary(key, value) => {
incoming_meta.append_bin(key, value.clone())
}
}
}
next.invoke_once(headers, options).await
Expand Down Expand Up @@ -179,6 +179,7 @@ mod tests {
use crate::core::ClientResponseStreamItem;
use crate::core::ResponseHeaders;
use crate::core::Trailers;
use crate::metadata::BinaryMetadataValue;

#[tokio::test]
async fn test_attach_headers_interceptor() {
Expand All @@ -187,7 +188,7 @@ mod tests {
md.insert("x-test-header", "test-value".parse().unwrap());
md.insert_bin(
"x-test-header-bin",
tonic::metadata::MetadataValue::from_bytes(b"test-bin"),
BinaryMetadataValue::from_bytes(b"test-bin"),
);
let interceptor = AttachHeadersInterceptor::new(md);

Expand Down
117 changes: 82 additions & 35 deletions grpc/src/client/transport/tonic/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,22 @@ use std::net::SocketAddr;
use std::path::PathBuf;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use std::time::Instant;

use bytes::Buf;
use bytes::BufMut as _;
use bytes::Bytes;
use futures::stream::StreamExt;
use http::Request as HttpRequest;
use http::Response as HttpResponse;
use http::Uri;
use http::uri::PathAndQuery;
use hyper::client::conn::http2::Builder;
use hyper::client::conn::http2::SendRequest;
use tokio::sync::Notify;
use tokio::sync::mpsc;
use tokio::sync::oneshot;
use tokio_stream::Stream;
Expand All @@ -56,7 +59,7 @@ use tonic::codec::Codec;
use tonic::codec::Decoder;
use tonic::codec::EncodeBuf;
use tonic::codec::Encoder;
use tonic::metadata::MetadataMap;
use tonic::metadata::MetadataMap as TonicMeta;
use tower::ServiceBuilder;
use tower::buffer::Buffer;
use tower::buffer::future::ResponseFuture as BufferResponseFuture;
Expand Down Expand Up @@ -151,10 +154,18 @@ impl Invoke for TonicTransport {
options: CallOptions,
) -> (Self::SendStream, Self::RecvStream) {
let (req_tx, req_rx) = mpsc::channel(1);
let request_stream = ReceiverStream::new(req_rx);
let stop_notify = Arc::new(Notify::new());

// Tonic runs the outbound request stream in a background task. It will
// NOT automatically stop sending if the inbound response stream is
// dropped. We use `take_until` with this Notify to explicitly force the
// stream to yield `None`, which tells Tonic to cancel the stream.
let stop_notify_clone = stop_notify.clone();
let request_stream =
ReceiverStream::new(req_rx).take_until(stop_notify_clone.notified_owned());
let mut request = TonicRequest::new(Box::pin(request_stream));
let (method, metadata) = headers.into_parts();
*request.metadata_mut() = metadata;
*request.metadata_mut() = metadata.into();

let Ok(path) = PathAndQuery::from_maybe_shared(method) else {
return err_streams(StatusError::new(StatusCodeError::Internal, "invalid path"));
Expand Down Expand Up @@ -184,6 +195,7 @@ impl Invoke for TonicTransport {
TonicSendStream { sender: Ok(req_tx) },
TonicRecvStream {
state: StreamState::AwaitingHeaders(resp_rx),
stop_notify: Some(stop_notify),
},
)
}
Expand All @@ -192,28 +204,28 @@ impl Invoke for TonicTransport {
// Converts from a tonic status to a trailers stream item.
fn trailers_from_tonic_status(
status: TonicStatus,
md: Option<MetadataMap>,
md: Option<TonicMeta>,
) -> ClientResponseStreamItem {
let mut trailers = Trailers::new(if status.code() == Code::Ok {
Ok(())
} else {
Err(StatusError::new(
StatusCodeError::from(status.code() as i32),
let status_res = match status.code() {
Code::Ok => Ok(()),
code => Err(StatusError::new(
StatusCodeError::from(code as i32),
status.message(),
))
});
if let Some(md) = md {
trailers = trailers.with_metadata(md);
}
ClientResponseStreamItem::Trailers(trailers)
)),
};
trailers_from_status(status_res, md)
}

// Builds a trailers with a status
fn trailers_from_status(status: Status, md: Option<MetadataMap>) -> ClientResponseStreamItem {
let mut trailers = Trailers::new(status);
if let Some(md) = md {
trailers = trailers.with_metadata(md);
}
fn trailers_from_status(status: Status, md: Option<TonicMeta>) -> ClientResponseStreamItem {
let trailers = match md.map(TryInto::try_into) {
Some(Err(e)) => Trailers::new(Err(StatusError::new(
StatusCodeError::Internal,
format!("failed to parse metadata: {e}"),
))),
Some(Ok(metadata)) => Trailers::new(status).with_metadata(metadata),
None => Trailers::new(status),
};
ClientResponseStreamItem::Trailers(trailers)
}

Expand All @@ -238,6 +250,7 @@ impl SendStream for TonicSendStream {

struct TonicRecvStream {
state: StreamState,
stop_notify: Option<Arc<Notify>>,
}

enum StreamState {
Expand All @@ -262,11 +275,34 @@ impl RecvStream for TonicRecvStream {
StreamState::AwaitingHeaders(rx) => match rx.await {
Ok(Ok(response)) => {
let (metadata, stream, _extensions) = response.into_parts();
// Start streaming and return the headers.
self.state = StreamState::Streaming(stream);
ClientResponseStreamItem::Headers(
ResponseHeaders::new().with_metadata(metadata),
)
// Tonic decodes base64-encoded binary headers lazily. It
// does not fail the RPC upon receiving invalid base64 data;
// the error only surfaces when the application attempts to
// read the metadata.
// In contrast, standard gRPC implementations eagerly decode
// these headers and immediately fail the RPC with an
// Internal status.
match metadata.try_into() {
Ok(md) => {
// Start streaming and return the headers.
self.state = StreamState::Streaming(stream);
ClientResponseStreamItem::Headers(
ResponseHeaders::new().with_metadata(md),
)
}
Err(e) => {
if let Some(notify) = self.stop_notify.take() {
notify.notify_one();
}
trailers_from_status(
Err(StatusError::new(
StatusCodeError::Internal,
format!("error decoding response: {e}"),
)),
None,
)
}
}
}
// Stay closed after sending trailers.
Err(_) => trailers_from_status(
Expand All @@ -282,16 +318,18 @@ impl RecvStream for TonicRecvStream {
self.state = StreamState::Streaming(stream);
ClientResponseStreamItem::Message
}
// TODO: in this case, tonic believes the stream is still
// running, but our decoding failed -- do we need to terminate
// the request stream now even though the Streaming is dropped?
Err(e) => trailers_from_status(
Err(StatusError::new(
StatusCodeError::Internal,
format!("error decoding response: {e}"),
)),
None,
),
Err(e) => {
if let Some(notify) = self.stop_notify.take() {
notify.notify_one();
}
trailers_from_status(
Err(StatusError::new(
StatusCodeError::Internal,
format!("error decoding response: {e}"),
)),
None,
)
}
},
// Stay closed after sending trailers.
Err(status) => {
Expand All @@ -309,11 +347,20 @@ impl RecvStream for TonicRecvStream {
}
}

impl Drop for TonicRecvStream {
fn drop(&mut self) {
if let Some(notify) = &self.stop_notify {
notify.notify_one();
}
}
}

fn err_streams(status: StatusError) -> (TonicSendStream, TonicRecvStream) {
(
TonicSendStream { sender: Err(()) },
TonicRecvStream {
state: StreamState::Error(status),
stop_notify: None,
},
)
}
Expand Down
Loading
Loading