Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion core/include/userver/engine/io/common.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <cstddef>
#include <memory>
#include <optional>
#include <span>

#include <userver/engine/deadline.hpp>
#include <userver/utils/assert.hpp>
Expand Down Expand Up @@ -106,13 +107,21 @@ class WritableBase : public WriteAwaiter {
/// @note Can return less than len if stream is closed by peer.
[[nodiscard]] virtual size_t WriteAll(const void* buf, size_t len, Deadline deadline) = 0;

[[nodiscard]] virtual size_t WriteAll(std::initializer_list<IoData> list, Deadline deadline) {
/// @brief Sends IoData array using vector I/O (e.g. writev).
/// @note Can return less than total bytes if stream is closed by peer.
[[nodiscard]] virtual size_t WriteAll(std::span<const IoData> list, Deadline deadline) {
size_t result{0};
for (const auto& io_data : list) {
result += WriteAll(io_data.data, io_data.len, deadline);
}
return result;
}

/// @brief Sends IoData initializer list using vector I/O (e.g. writev).
/// @note Can return less than total bytes if stream is closed by peer.
[[nodiscard]] virtual size_t WriteAll(std::initializer_list<IoData> list, Deadline deadline) {
return WriteAll(std::span<const IoData>{list.begin(), list.size()}, deadline);
}
};

/// @ingroup userver_base_classes
Expand Down
6 changes: 4 additions & 2 deletions core/include/userver/engine/io/socket.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

#include <sys/socket.h>

#include <initializer_list>
#include <span>

#include <userver/engine/deadline.hpp>
#include <userver/engine/io/common.hpp>
Expand Down Expand Up @@ -102,7 +102,9 @@ class [[nodiscard]] Socket final : public RwBase {
/// @snippet src/engine/io/socket_test.cpp send vector data in socket
[[nodiscard]] size_t SendAll(std::initializer_list<IoData> list, Deadline deadline);

[[nodiscard]] size_t WriteAll(std::initializer_list<IoData> list, Deadline deadline) override {
[[nodiscard]] size_t SendAll(std::span<const IoData> list, Deadline deadline);

[[nodiscard]] size_t WriteAll(std::span<const IoData> list, Deadline deadline) override {
return SendAll(list, deadline);
}

Expand Down
7 changes: 6 additions & 1 deletion core/include/userver/engine/io/tls_wrapper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
/// @file userver/engine/io/tls_wrapper.hpp
/// @brief TLS socket wrappers

#include <span>
#include <string>
#include <vector>

Expand Down Expand Up @@ -115,7 +116,11 @@ class [[nodiscard]] TlsWrapper final : public RwBase {
return SendAll(buf, len, deadline);
}

[[nodiscard]] size_t WriteAll(std::initializer_list<IoData> list, Deadline deadline) override;
[[nodiscard]] size_t WriteAll(std::span<const IoData> list, Deadline deadline) override;

[[nodiscard]] size_t WriteAll(std::initializer_list<IoData> list, Deadline deadline) override {
return WriteAll(std::span<const IoData>{list.begin(), list.size()}, deadline);
}

int GetRawFd();

Expand Down
6 changes: 5 additions & 1 deletion core/src/engine/io/socket.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -287,8 +287,12 @@ std::optional<size_t> Socket::RecvNoblock(void* buf, size_t len) {
throw IoException("Attempt to RecvNoblock from closed socket");
}

size_t Socket::SendAll(std::span<const IoData> list, Deadline deadline) {
return SendAll(list.data(), list.size(), deadline);
}

size_t Socket::SendAll(std::initializer_list<IoData> list, Deadline deadline) {
return SendAll(list.begin(), list.size(), deadline);
return SendAll(std::span<const IoData>{list.begin(), list.size()}, deadline);
}

size_t Socket::SendAll(const IoData* list, std::size_t list_size, Deadline deadline) {
Expand Down
2 changes: 1 addition & 1 deletion core/src/engine/io/tls_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -582,7 +582,7 @@ size_t TlsWrapper::SendAll(const void* buf, size_t len, Deadline deadline) {
);
}

[[nodiscard]] size_t TlsWrapper::WriteAll(std::initializer_list<IoData> list, Deadline deadline) {
[[nodiscard]] size_t TlsWrapper::WriteAll(std::span<const IoData> list, Deadline deadline) {
static constexpr std::size_t kBufSize = 4'096;
std::byte buf[kBufSize];

Expand Down
5 changes: 1 addition & 4 deletions core/src/server/http/http2_session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -239,10 +239,7 @@ int Http2Session::OnDataFrameSend(
auto& stream = *static_cast<Stream*>(source->ptr);

const auto frame_header{ToStringView(framehd, kFrameHeaderSize)};
// TODO: doesn't work with TLS?!
UASSERT(dynamic_cast<engine::io::Socket*>(parser.socket_));
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-static-cast-downcast)
stream.Send(*static_cast<engine::io::Socket*>(parser.socket_), frame_header, max_len);
stream.Send(*parser.socket_, frame_header, max_len);
return 0;
}

Expand Down
4 changes: 2 additions & 2 deletions core/src/server/http/http2_stream.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ ssize_t Stream::GetMaxSize(std::size_t max_len, std::uint32_t* flags) {
return res;
}

void Stream::Send(engine::io::Socket& socket, std::string_view data_frame_header, std::size_t max_len) {
void Stream::Send(engine::io::RwBase& socket, std::string_view data_frame_header, std::size_t max_len) {
boost::container::small_vector<engine::io::IoData, 16> parts{};
parts.push_back({data_frame_header.data(), data_frame_header.size()});
auto budget = max_len;
Expand All @@ -126,7 +126,7 @@ void Stream::Send(engine::io::Socket& socket, std::string_view data_frame_header
budget -= part.size();
}
UASSERT(!parts.empty());
[[maybe_unused]] const auto res = socket.SendAll(parts.data(), parts.size(), {});
[[maybe_unused]] const auto res = socket.WriteAll(std::span<const engine::io::IoData>{parts.data(), parts.size()}, {});
const auto send_parts_count = pos_in_first_chunk_ == 0 ? parts.size() - 1 : parts.size() - 2; // data_frame_header
// doesn't matter
chunks_.erase(chunks_.begin(), chunks_.begin() + send_parts_count);
Expand Down
2 changes: 1 addition & 1 deletion core/src/server/http/http2_stream.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class Stream final {
bool CheckUrlComplete();
void PushChunk(std::string&& chunk);
ssize_t GetMaxSize(std::size_t max_len, std::uint32_t* flags);
void Send(engine::io::Socket& socket, std::string_view data_frame_header, std::size_t max_len);
void Send(engine::io::RwBase& socket, std::string_view data_frame_header, std::size_t max_len);
nghttp2_data_provider* GetNativeProvider() { return &nghttp2_provider_; }

private:
Expand Down
6 changes: 5 additions & 1 deletion core/src/server/server.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -161,8 +161,11 @@ ServerImpl::ServerImpl(
for (auto& port : config_.listener.ports) {
port.ReadTlsSettings(secdist);
port.InitSslCtx();
}

if (port.ssl_ctx) {
port.ssl_ctx->SetHttpVersion(config.listener.connection_config.http_version);
}
}
main_port_info_.Init(config_, config_.listener, component_context, false);
if (config_.max_response_size_in_flight) {
main_port_info_.data_accounter.SetMaxPendingResponsesSizeInBytes(*config_.max_response_size_in_flight);
Expand Down Expand Up @@ -424,3 +427,4 @@ void Server::WriteMetrics(utils::statistics::Writer& writer) const {
} // namespace server

USERVER_NAMESPACE_END

7 changes: 7 additions & 0 deletions universal/include/userver/crypto/ssl_ctx.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,13 @@
/// @brief @copybrief crypto::SslCtx

#include <memory>
#include <span>
#include <string_view>
#include <vector>

#include <userver/crypto/certificate.hpp>
#include <userver/crypto/private_key.hpp>
#include <userver/http/http_version.hpp>

USERVER_NAMESPACE_BEGIN

Expand Down Expand Up @@ -43,6 +45,9 @@ class SslCtx {

void* GetRawSslCtx() const noexcept;

void SetHttpVersion(http::HttpVersion);
[[nodiscard]] std::span<const unsigned char> GetAlpn() const noexcept;

private:
void AddCertAuthorities(const std::vector<Certificate>& cert_authorities);
void EnableVerifyClientCertificate();
Expand All @@ -55,8 +60,10 @@ class SslCtx {
std::unique_ptr<Impl> impl_{};

explicit SslCtx(std::unique_ptr<Impl>&& impl);
std::span<const unsigned char> alpn_;
};

} // namespace crypto

USERVER_NAMESPACE_END

59 changes: 59 additions & 0 deletions universal/src/crypto/ssl_ctx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,64 @@ std::unique_ptr<SslCtx::Impl> SslCtx::Impl::MakeSslCtx() {

void* SslCtx::GetRawSslCtx() const noexcept { return static_cast<void*>(impl_->Get()); }


// https://www.rfc-editor.org/rfc/rfc7540#section-3.1
static constexpr unsigned char kAlpnHttp1Only[] = "\x08http/1.1";
// https://www.rfc-editor.org/rfc/rfc9112.html#section-12.4
static constexpr unsigned char kAlpnHttp2Only[] = "\x02h2";
static constexpr unsigned char kAlpnHttp2FallbackHttp1[] = "\x02h2\x08http/1.1";

static int AlpnSelectCallback(SSL *,
const unsigned char **out,
unsigned char *outlen,
const unsigned char *in,
unsigned int inlen,
void *arg) {

auto * context = reinterpret_cast<SslCtx*>(arg);

if (SSL_select_next_proto(const_cast<unsigned char **>(out), outlen, context->GetAlpn().data(), context->GetAlpn().size(),
in, inlen)
!= OPENSSL_NPN_NEGOTIATED)
{
LOG_ERROR() << crypto::FormatSslError("SSL_select_next_proto failed");
return SSL_TLSEXT_ERR_ALERT_FATAL;
}

LOG_INFO() << "successfully negotiated ALPN";

return SSL_TLSEXT_ERR_OK;
}

void SslCtx::SetHttpVersion(http::HttpVersion http_version) {

switch (http_version) {
case http::HttpVersion::k10:
case http::HttpVersion::k11:
alpn_ = kAlpnHttp1Only;
LOG_INFO() << "set ALPN for HTTP/1.1 only";
break;
case http::HttpVersion::k2:
alpn_ = kAlpnHttp2FallbackHttp1;
LOG_INFO() << "set ALPN for HTTP/2 with fallback to HTTP/1.1";
break;
case http::HttpVersion::k2Tls:
case http::HttpVersion::k2PriorKnowledge:
LOG_INFO() << "set ALPN for HTTP/2 only";
alpn_ = kAlpnHttp2Only;
break;
default:
LOG_INFO() << "skip setting ALPN";
return;
}

SSL_CTX_set_alpn_select_cb(impl_->Get(), AlpnSelectCallback, this);
}

std::span<const unsigned char> SslCtx::GetAlpn() const noexcept {
return alpn_;
}

SslCtx::SslCtx(std::unique_ptr<Impl>&& impl)
: impl_(std::move(impl))
{}
Expand Down Expand Up @@ -197,3 +255,4 @@ SslCtx SslCtx::CreateServerTlsContext(
} // namespace crypto

USERVER_NAMESPACE_END

Loading