Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .scripts/linux/run-clang-tidy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ cd "$REPO_DIR"
cmake -S . -B build -DCMAKE_EXPORT_COMPILE_COMMANDS=On

# generate flatbuffers files
cmake --build build --target fbgenerator_v1
cmake --build build --target fbgenerator_v2
cmake --build build --target fbgen

# check that compile_commands.json was generated
cd build
Expand Down
27 changes: 21 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ option(RL_USE_UBSAN "Compile with UndefinedBehaviorSanitizer" OFF)
option(rlclientlib_BUILD_ONNXRUNTIME_EXTENSION "Build OnnxRuntime Inference Extension" OFF)
option(rlclientlib_BUILD_DOTNET "Build .NET bindings" OFF)
option(rlclientlib_DOTNET_USE_MSPROJECT "[Experimental] Use import_external_msproject to build .NET csproj files." OFF)
option(RL_BUILD_FEDERATION "Build code for Federated Learning" OFF)

# When Federated Learning support is requested, expose RL_BUILD_FEDERATION as a
# preprocessor definition for targets created in this directory and below.
if(RL_BUILD_FEDERATION)
add_compile_definitions(RL_BUILD_FEDERATION)
# Needed for joiner code: switch the external parser build on.
# NOTE: FORCE overwrites whatever value is already stored in the cache for
# RL_BUILD_EXTERNAL_PARSER, including one passed explicitly with -D.
set(RL_BUILD_EXTERNAL_PARSER ON CACHE BOOL "" FORCE)
endif()

if(RL_USE_ASAN)
add_compile_definitions(RL_USE_ASAN VW_USE_ASAN)
Expand All @@ -81,8 +88,9 @@ if(RL_USE_UBSAN)
if(MSVC)
message(FATAL_ERROR "UBSan not supported on MSVC")
else()
add_compile_options(-fsanitize=undefined -fno-sanitize-recover -fno-omit-frame-pointer -g3)
add_link_options(-fsanitize=undefined -fno-sanitize-recover -fno-omit-frame-pointer -g3)
# Flatbuffers gives errors with misaligned pointer sanitization, so disable it
add_compile_options(-fsanitize=undefined -fno-sanitize=alignment -fno-sanitize-recover -fno-omit-frame-pointer -g3)
add_link_options(-fsanitize=undefined -fno-sanitize=alignment -fno-sanitize-recover -fno-omit-frame-pointer -g3)
endif()
endif()

Expand Down Expand Up @@ -140,17 +148,17 @@ include(GNUInstallDirs)

include(ext_libs/ext_libs.cmake)

if(RL_BUILD_EXTERNAL_PARSER)
add_subdirectory(external_parser)
endif()

add_subdirectory(rlclientlib)
add_subdirectory(rlclientlib/extensions)
add_subdirectory(examples)
add_subdirectory(test_tools/joiner)
add_subdirectory(test_tools/sender_test)
add_subdirectory(test_tools/example_gen)

if(RL_BUILD_EXTERNAL_PARSER)
add_subdirectory(external_parser)
endif()

# enable_testing should be run after ext_libs so that the vw unit tests aren't turned on.
enable_testing()

Expand All @@ -168,6 +176,13 @@ if(RL_BUILD_BENCHMARKS)
add_subdirectory(benchmarks)
endif()

# Umbrella target: building `fbgen` generates every flatbuffer header in the
# tree. Used by the clang-tidy script so headers exist before analysis runs.
add_custom_target(fbgen)

# Schema generators that are always expected to exist.
add_dependencies(fbgen fbgenerator_v1)
add_dependencies(fbgen fbgenerator_v2)

# The external parser's generator only exists when that subdirectory was added.
if(TARGET fbgen_external_parser)
  add_dependencies(fbgen fbgen_external_parser)
endif()


# Add the nuget subdirectory last
if(RL_BUILD_NUGET)
if(WIN32)
Expand Down
16 changes: 16 additions & 0 deletions examples/rl_sim_cpp/local_loop_client.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"model.vw.initial_command_line": "--cb_explore_adf --json --epsilon 0.0 --preserve_performance_counters -q :: --driver_output_off",
"ApplicationID": "local_loop",
"IsExplorationEnabled": true,
"InitialExplorationEpsilon": 0.2,
"protocol.version": "2",
"model.source": "LOCAL_LOOP_MODEL_DATA",
"interaction.sender.implementation": "LOCAL_LOOP_SENDER",
"observation.sender.implementation": "LOCAL_LOOP_SENDER",
"eud.duration": "0:0:1",
"joiner.problem.type": "PROBLEM_TYPE_CB",
"joiner.reward.function": "REWARD_FUNCTION_EARLIEST",
"joiner.learning.mode": "ONLINE",
"model.refreshintervalms": "5000",
"time_provider.implementation": "CLOCK_TIME_PROVIDER"
}
33 changes: 20 additions & 13 deletions examples/rl_sim_cpp/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,19 +26,26 @@ int main(int argc, char** argv)
po::variables_map process_cmd_line(const int argc, char** argv)
{
po::options_description desc("Options");
desc.add_options()("help", "produce help message")("json_config,j",
po::value<std::string>()->default_value("client.json"), "JSON file with config information for hosted RL loop")(
"log_to_file,l", po::value<bool>()->default_value(false), "Log interactions and observations to local files")(
"get_model,m", po::value<bool>()->default_value(true), "Download model from model source")(
"log_timestamp,t", po::value<bool>()->default_value(true), "Apply timestamp to all logged message")("ccb",
po::value<bool>()->default_value(false), "Run in ccb mode")("slates", po::value<bool>()->default_value(false),
"Run in slates mode")("ca", po::value<bool>()->default_value(false), "Run in continuous actions mode")(
"multistep", po::value<bool>()->default_value(false), "Run in multistep mode")(
"num_events", po::value<int>()->default_value(0), "Number of event series' to be sent. 0 is infinite.")(
"random_seed", po::value<uint64_t>()->default_value(rand()), "Random seed. Default is random")(
"delay", po::value<int64_t>()->default_value(2000), "Delay between events in ms")(
"quiet", po::bool_switch(), "Suppress logs")("random_ids", po::value<bool>()->default_value(true),
"Use randomly generated Event IDs. Default is true")("throughput", "print throughput stats");
desc.add_options()(
"help", "produce help message")(
"json_config,j", po::value<std::string>()->default_value("client.json"),
"JSON file with config information for hosted RL loop")(
"log_to_file,l", po::value<bool>()->default_value(false), "Log interactions and observations to local files")(
"get_model,m", po::value<bool>()->default_value(true), "Download model from model source")(
"log_timestamp,t", po::value<bool>()->default_value(true), "Apply timestamp to all logged message")(
"ccb", po::value<bool>()->default_value(false), "Run in ccb mode")(
"slates", po::value<bool>()->default_value(false), "Run in slates mode")(
"ca", po::value<bool>()->default_value(false), "Run in continuous actions mode")(
"multistep", po::value<bool>()->default_value(false), "Run in multistep mode")(
"num_events", po::value<int>()->default_value(0), "Number of event series' to be sent. 0 is infinite.")(
"random_seed", po::value<uint64_t>()->default_value(rand()), "Random seed. Default is random")(
"delay", po::value<int64_t>()->default_value(2000), "Delay between events in ms")(
"quiet", po::bool_switch(), "Suppress logs")(
"throughput", "print throughput stats")(
"random_ids", po::value<bool>()->default_value(true), "Use randomly generated Event IDs. Default is true")(
"refresh_model_period", po::value<uint64_t>()->default_value(0),
"Call refresh model after every N examples. 0 turns off explicit model refresh and relies on background refresh. "
"Must disable background refresh in client.json with key 'model.backgroundrefresh'");

po::variables_map vm;
store(parse_command_line(argc, argv, desc), vm);
Expand Down
69 changes: 61 additions & 8 deletions examples/rl_sim_cpp/rl_sim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,19 @@ int rl_sim::cb_loop()
{
std::cout << " " << stats.count() << ", ctxt, " << p.id() << ", action, " << chosen_action << ", outcome, "
<< outcome << ", dist, " << get_dist_str(response) << ", " << stats.get_stats(p.id(), chosen_action)
<< std::endl;
<< ", ctr: " << stats.get_ctr() << std::endl;
}
// Explicitly refresh the model every _model_refresh_period events.
// A period of 0 disables the explicit refresh performed here.
// (The per-iteration "Current events" debug print to stderr was removed: it
// ignored --quiet and none of the other simulation loops emit it.)
if (_model_refresh_period != 0 && (_current_events % _model_refresh_period) == 0)
{
  r::api_status status;
  if (_rl->refresh_model(&status) != err::success)
  {
    // Report the failure and move on to the next iteration of the loop.
    std::cout << status.get_error_msg() << std::endl;
    continue;
  }
}
std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
}

Expand Down Expand Up @@ -165,6 +176,18 @@ int rl_sim::multistep_loop()
{
std::cout << status.get_error_msg() << std::endl;
continue;

}
// refresh model every _model_refresh_period events
// Treat each episode as a single event
if (_model_refresh_period != 0 && (_current_events / episode_length) % _model_refresh_period == 0)
{
r::api_status status;
if (_rl->refresh_model(&status) != err::success)
{
std::cout << status.get_error_msg() << std::endl;
continue;
}
}

std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
Expand Down Expand Up @@ -207,6 +230,17 @@ int rl_sim::ca_loop()
<< stats.get_stats(joint.id(), chosen_action) << std::endl;
}

// refresh model every _model_refresh_period events
if (_model_refresh_period != 0 && _current_events % _model_refresh_period == 0)
{
r::api_status status;
if (_rl->refresh_model(&status) != err::success)
{
std::cout << status.get_error_msg() << std::endl;
continue;
}
}

std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
}
return 0;
Expand Down Expand Up @@ -260,6 +294,17 @@ int rl_sim::ccb_loop()
}
index++;
}

// refresh model every _model_refresh_period events
if (_model_refresh_period != 0 && _current_events % _model_refresh_period == 0)
{
r::api_status status;
if (_rl->refresh_model(&status) != err::success)
{
std::cout << status.get_error_msg() << std::endl;
continue;
}
}

std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
}
Expand Down Expand Up @@ -331,6 +376,17 @@ int rl_sim::slates_loop()
continue;
}

// refresh model every _model_refresh_period events
if (_model_refresh_period != 0 && _current_events % _model_refresh_period == 0)
{
r::api_status status;
if (_rl->refresh_model(&status) != err::success)
{
std::cout << status.get_error_msg() << std::endl;
continue;
}
}

std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
}

Expand Down Expand Up @@ -454,11 +510,7 @@ int rl_sim::init_rl()
config.set(r::name::OBSERVATION_SENDER_IMPLEMENTATION, r::value::OBSERVATION_FILE_SENDER);
}

if (!_options["get_model"].as<bool>())
{
// Set the time provider to the clock time provider
config.set(r::name::MODEL_SRC, r::value::NO_MODEL_DATA);
}
if (!_options["get_model"].as<bool>()) { config.set(r::name::MODEL_SRC, r::value::NO_MODEL_DATA); }

if (_options["log_timestamp"].as<bool>())
{
Expand Down Expand Up @@ -628,7 +680,7 @@ std::string rl_sim::create_context_json(const std::string& cntxt, const std::str

std::string rl_sim::create_event_id()
{
if (_num_events > 0 && ++_current_events >= _num_events) { _run_loop = false; }
if (++_current_events >= _num_events && _num_events > 0) { _run_loop = false; }

if (_random_ids) { return boost::uuids::to_string(boost::uuids::random_generator()()); }

Expand All @@ -637,7 +689,7 @@ std::string rl_sim::create_event_id()
return oss.str();
}

rl_sim::rl_sim(boost::program_options::variables_map vm) : _options(std::move(vm)), _loop_kind(CB)
rl_sim::rl_sim(const boost::program_options::variables_map& vm) : _options(vm), _loop_kind(CB)
{
if (_options["ccb"].as<bool>()) { _loop_kind = CCB; }
else if (_options["slates"].as<bool>()) { _loop_kind = Slates; }
Expand All @@ -649,6 +701,7 @@ rl_sim::rl_sim(boost::program_options::variables_map vm) : _options(std::move(vm
_delay = _options["delay"].as<int64_t>();
_quiet = _options["quiet"].as<bool>();
_random_ids = _options["random_ids"].as<bool>();
_model_refresh_period = _options["refresh_model_period"].as<uint64_t>();
}

std::string get_dist_str(const reinforcement_learning::ranking_response& response)
Expand Down
3 changes: 2 additions & 1 deletion examples/rl_sim_cpp/rl_sim.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class rl_sim
*
* @param vm User defined options
*/
explicit rl_sim(boost::program_options::variables_map vm);
explicit rl_sim(const boost::program_options::variables_map& vm);

/**
* @brief Simulation loop
Expand Down Expand Up @@ -177,4 +177,5 @@ class rl_sim
int64_t _delay = 2000;
bool _quiet = false;
bool _random_ids = true;
uint64_t _model_refresh_period = 0;
};
4 changes: 4 additions & 0 deletions examples/rl_sim_cpp/simulation_stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class simulation_stats
auto& item_count = _item_stats[id];
++item_count;
++_total_events;
_total_reward += outcome;
}

std::string get_stats(const std::string& id, T chosen_action)
Expand All @@ -29,8 +30,11 @@ class simulation_stats

int count() const { return _total_events; }

// Average reward per recorded event (total reward / total events).
// Returns 0 when no event has been recorded yet, instead of the NaN that an
// unguarded 0.f / 0 division would produce.
float get_ctr() const
{
  return _total_events > 0 ? _total_reward / static_cast<float>(_total_events) : 0.f;
}

private:
std::map<std::pair<std::string, T>, std::pair<int, int>> _action_stats;
std::map<std::string, int> _item_stats;
int _total_events = 0;
float _total_reward = 0.f;
};
16 changes: 13 additions & 3 deletions external_parser/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ set(RL_FLAT_BUFFER_FILES
)

add_flatbuffer_schema(
TARGET fbgen
TARGET fbgen_external_parser
SCHEMAS ${RL_FLAT_BUFFER_FILES}
OUTPUT_DIR ${CMAKE_CURRENT_LIST_DIR}/generated/v2/
FLATC_EXE ${flatc_location}
Expand Down Expand Up @@ -154,10 +154,16 @@ set(binary_parser_sources
)

add_library(rl_binary_parser STATIC ${binary_parser_headers} ${binary_parser_sources})
set_target_properties(rl_binary_parser PROPERTIES POSITION_INDEPENDENT_CODE ON)
if(WIN32)
set_target_properties(rl_binary_parser PROPERTIES DEBUG_POSTFIX d)
endif()

target_link_libraries(rl_binary_parser PUBLIC vw_core RapidJSON PRIVATE libzstd_static)
target_include_directories(rl_binary_parser
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/
${CMAKE_CURRENT_LIST_DIR}/generated/v2/
${CMAKE_CURRENT_LIST_DIR}/../ext_libs/zstd/lib/
${CMAKE_CURRENT_LIST_DIR}/../ext_libs/date/
)
Expand All @@ -169,11 +175,15 @@ if(TARGET flatbuffers::flatbuffers)
else()
target_include_directories(rl_binary_parser PRIVATE ${FLATBUFFERS_INCLUDE_DIR})
endif()
add_dependencies(rl_binary_parser fbgen)
add_dependencies(rl_binary_parser fbgen_external_parser)

add_executable(rl_binary_parser_bin main.cc)
target_link_libraries(rl_binary_parser_bin PUBLIC rl_binary_parser)
set_target_properties(rl_binary_parser_bin PROPERTIES OUTPUT_NAME "vw")
if (NOT rlclientlib_BUILD_DOTNET)
# The build for .NET bindings configures all binary output to a single directory
# In this case we can't name the binary parser "vw", since this will overwrite the real vw executable
set_target_properties(rl_binary_parser_bin PROPERTIES OUTPUT_NAME "vw")
endif()

if(STATIC_LINK_BINARY_PARSER AND NOT APPLE)
target_link_libraries(rl_binary_parser_bin PRIVATE -static)
Expand Down
Loading