Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
c1054ce
Implement prototype local join and train loop
jackgerrits May 9, 2022
63fd388
Merge branch 'master' of https://github.com/VowpalWabbit/reinforcemen…
byronxu99 Oct 13, 2022
f78183a
Merge branch 'master' of https://github.com/VowpalWabbit/reinforcemen…
byronxu99 Oct 14, 2022
349e555
Fix build errors
byronxu99 Oct 14, 2022
15f4143
clang format
byronxu99 Oct 14, 2022
99cac6b
Merge branch 'master' of https://github.com/VowpalWabbit/reinforcemen…
byronxu99 Oct 19, 2022
daee5a6
Bump vw submodule to include delta serialization
jackgerrits Oct 20, 2022
8ed5ab7
add local client implementation (#521)
jackgerrits Oct 21, 2022
e473556
Merge branch 'master' into local_loop_prototype
jackgerrits Nov 9, 2022
7355faa
Merge remote-tracking branch 'origin/master' into local_loop_prototype
jackgerrits Nov 9, 2022
0f0e1f8
Add local joining and training for federated learning (#523)
byronxu99 Jan 10, 2023
13e744e
Merge branch 'master' of github.com:VowpalWabbit/reinforcement_learni…
byronxu99 Jan 10, 2023
3477f23
remove old implementation of local model (#539)
jackgerrits Jan 17, 2023
f553882
fix model refresh from background thread (#540)
jackgerrits Jan 17, 2023
5506c1f
improve logging and dont learn on newline example (#541)
jackgerrits Jan 18, 2023
d592156
fix segfault from using wrong finish (#542)
jackgerrits Jan 18, 2023
7703abb
Add ctr to cb and add an option to refresh after every n examples (#543)
jackgerrits Jan 20, 2023
ab01bee
add schema for model update (#545)
jackgerrits Jan 24, 2023
05f6d82
Change name ModelUpdate -> ModelUpdateEvent in schema (#564)
jackgerrits Feb 10, 2023
9f4ff7d
Update order of steps in local loop controller (#563)
byronxu99 Feb 13, 2023
5b7aa3f
merge
bassmang Apr 5, 2023
60d3f4e
clang
bassmang Apr 5, 2023
a8e4782
Merge branch 'master' into local_loop_prototype
bassmang Apr 7, 2023
f699c63
feat: add id and counter (#580)
bassmang Apr 10, 2023
4d2cbf0
Merge branch 'master' into local_loop_prototype
bassmang Apr 10, 2023
60e3e3c
Merge branch 'master' into local_loop_prototype
bassmang Apr 26, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions .scripts/linux/run-clang-tidy.sh
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ cd "$REPO_DIR"
cmake -S . -B build -DCMAKE_EXPORT_COMPILE_COMMANDS=On

# generate flatbuffers files
cmake --build build --target fbgenerator_v1
cmake --build build --target fbgenerator_v2
cmake --build build --target fbgen

# check that compile_commands.json was generated
cd build
Expand Down
28 changes: 22 additions & 6 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,17 @@ option(RL_BUILD_NUGET "Build the Nuget package" OFF)
option(RL_BUILD_EXTERNAL_PARSER "Build external parser version of VW" OFF)
option(RL_USE_ASAN "Compile with AddressSanitizer" OFF)
option(RL_USE_UBSAN "Compile with UndefinedBehaviorSanitizer" OFF)
option(RL_BUILD_FEDERATION "Build code for Federated Learning" ON)
option(rlclientlib_BUILD_ONNXRUNTIME_EXTENSION "Build OnnxRuntime Inference Extension" OFF)
option(rlclientlib_BUILD_DOTNET "Build .NET bindings" OFF)
option(rlclientlib_DOTNET_USE_MSPROJECT "[Experimental] Use import_external_msproject to build .NET csproj files." OFF)

# Federated learning support.
# Defines the RL_BUILD_FEDERATION preprocessor symbol for all targets and forces the
# external parser on, because the joiner code used by federation depends on it.
if(RL_BUILD_FEDERATION)
  add_compile_definitions(RL_BUILD_FEDERATION)
  # Tell the user when their RL_BUILD_EXTERNAL_PARSER=OFF choice is being overridden.
  if(NOT RL_BUILD_EXTERNAL_PARSER)
    message(STATUS "RL_BUILD_FEDERATION is ON: enabling RL_BUILD_EXTERNAL_PARSER (needed for joiner code)")
  endif()
  # Needed for joiner code. Keep the option's help string instead of blanking it in the cache.
  set(RL_BUILD_EXTERNAL_PARSER ON CACHE BOOL "Build external parser version of VW" FORCE)
endif()

if(RL_USE_ASAN)
add_compile_definitions(RL_USE_ASAN VW_USE_ASAN)
if(MSVC)
Expand All @@ -81,8 +88,9 @@ if(RL_USE_UBSAN)
if(MSVC)
message(FATAL_ERROR "UBSan not supported on MSVC")
else()
add_compile_options(-fsanitize=undefined -fno-sanitize-recover -fno-omit-frame-pointer -g3)
add_link_options(-fsanitize=undefined -fno-sanitize-recover -fno-omit-frame-pointer -g3)
# Flatbuffers gives errors with misaligned pointer sanitization, so disable it
add_compile_options(-fsanitize=undefined -fno-sanitize=alignment -fno-sanitize-recover -fno-omit-frame-pointer -g3)
add_link_options(-fsanitize=undefined -fno-sanitize=alignment -fno-sanitize-recover -fno-omit-frame-pointer -g3)
endif()
endif()

Expand Down Expand Up @@ -140,17 +148,18 @@ include(GNUInstallDirs)

include(ext_libs/ext_libs.cmake)

if(RL_BUILD_EXTERNAL_PARSER)
add_subdirectory(external_parser)
endif()

add_subdirectory(rlclientlib)
add_subdirectory(rlclientlib/extensions)

add_subdirectory(examples)
add_subdirectory(test_tools/joiner)
add_subdirectory(test_tools/sender_test)
add_subdirectory(test_tools/example_gen)

if(RL_BUILD_EXTERNAL_PARSER)
add_subdirectory(external_parser)
endif()

# enable_testing should be run after ext_libs so that the vw unit tests aren't turned on.
enable_testing()

Expand All @@ -168,6 +177,13 @@ if(RL_BUILD_BENCHMARKS)
add_subdirectory(benchmarks)
endif()

# Add a target to generate all flatbuffer header files, used for clang-tidy.
# fbgen itself runs no commands; it only aggregates the individual generator targets,
# so `cmake --build build --target fbgen` (as run by .scripts/linux/run-clang-tidy.sh)
# builds all of them.
add_custom_target(fbgen)
add_dependencies(fbgen fbgenerator_v1 fbgenerator_v2)
# fbgen_external_parser is created in external_parser/CMakeLists.txt, which is only
# added when RL_BUILD_EXTERNAL_PARSER is on, so depend on it only if it exists.
if(TARGET fbgen_external_parser)
add_dependencies(fbgen fbgen_external_parser)
endif()

# Add the nuget subdirectory last
if(RL_BUILD_NUGET)
if(WIN32)
Expand Down
16 changes: 16 additions & 0 deletions examples/rl_sim_cpp/local_loop_client.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
{
"model.vw.initial_command_line": "--cb_explore_adf --json --epsilon 0.0 --preserve_performance_counters -q :: --driver_output_off",
"ApplicationID": "local_loop",
"IsExplorationEnabled": true,
"InitialExplorationEpsilon": 0.2,
"protocol.version": "2",
"model.source": "LOCAL_LOOP_MODEL_DATA",
"interaction.sender.implementation": "LOCAL_LOOP_SENDER",
"observation.sender.implementation": "LOCAL_LOOP_SENDER",
"eud.duration": "0:0:1",
"joiner.problem.type": "PROBLEM_TYPE_CB",
"joiner.reward.function": "REWARD_FUNCTION_EARLIEST",
"joiner.learning.mode": "ONLINE",
"model.refreshintervalms": "5000",
"time_provider.implementation": "CLOCK_TIME_PROVIDER"
}
5 changes: 4 additions & 1 deletion examples/rl_sim_cpp/main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,10 @@ po::variables_map process_cmd_line(const int argc, char** argv)
"random_seed", po::value<uint64_t>()->default_value(rand()), "Random seed. Default is random")("delay",
po::value<int64_t>()->default_value(2000),
"Delay between events in ms")("quiet", po::bool_switch(), "Suppress logs")(
"random_ids", po::value<bool>()->default_value(true), "Use randomly generated Event IDs. Default is true");
"random_ids", po::value<bool>()->default_value(true), "Use randomly generated Event IDs. Default is true")(
"refresh_model_period", po::value<uint64_t>()->default_value(0),
"Call refresh model after every N examples. 0 turns off explicit model refresh and relies on background refresh. "
"Must disable background refresh in client.json with key 'model.backgroundrefresh'");

po::variables_map vm;
store(parse_command_line(argc, argv, desc), vm);
Expand Down
72 changes: 64 additions & 8 deletions examples/rl_sim_cpp/rl_sim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,21 @@ int rl_sim::cb_loop()
{
std::cout << " " << stats.count() << ", ctxt, " << p.id() << ", action, " << chosen_action << ", outcome, "
<< outcome << ", dist, " << get_dist_str(response) << ", " << stats.get_stats(p.id(), chosen_action)
<< std::endl;
<< ", ctr: " << stats.get_ctr() << std::endl;
}

// refresh model every _model_refresh_period events
std::cerr << "Current events: " << _current_events << std::endl;
if (_model_refresh_period != 0 && (_current_events % _model_refresh_period) == 0)
{
r::api_status status;
if (_rl->refresh_model(&status) != err::success)
{
std::cout << status.get_error_msg() << std::endl;
continue;
}
}

std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
}

Expand Down Expand Up @@ -163,6 +176,18 @@ int rl_sim::multistep_loop()
continue;
}

// refresh model every _model_refresh_period events
// Treat each episode as a single event
if (_model_refresh_period != 0 && (_current_events / episode_length) % _model_refresh_period == 0)
{
r::api_status status;
if (_rl->refresh_model(&status) != err::success)
{
std::cout << status.get_error_msg() << std::endl;
continue;
}
}

std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
}
return 0;
Expand Down Expand Up @@ -203,6 +228,17 @@ int rl_sim::ca_loop()
<< stats.get_stats(joint.id(), chosen_action) << std::endl;
}

// refresh model every _model_refresh_period events
if (_model_refresh_period != 0 && _current_events % _model_refresh_period == 0)
{
r::api_status status;
if (_rl->refresh_model(&status) != err::success)
{
std::cout << status.get_error_msg() << std::endl;
continue;
}
}

std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
}
return 0;
Expand Down Expand Up @@ -257,6 +293,17 @@ int rl_sim::ccb_loop()
index++;
}

// refresh model every _model_refresh_period events
if (_model_refresh_period != 0 && _current_events % _model_refresh_period == 0)
{
r::api_status status;
if (_rl->refresh_model(&status) != err::success)
{
std::cout << status.get_error_msg() << std::endl;
continue;
}
}

std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
}

Expand Down Expand Up @@ -327,6 +374,17 @@ int rl_sim::slates_loop()
continue;
}

// refresh model every _model_refresh_period events
if (_model_refresh_period != 0 && _current_events % _model_refresh_period == 0)
{
r::api_status status;
if (_rl->refresh_model(&status) != err::success)
{
std::cout << status.get_error_msg() << std::endl;
continue;
}
}

std::this_thread::sleep_for(std::chrono::milliseconds(_delay));
}

Expand Down Expand Up @@ -373,6 +431,7 @@ int rl_sim::init_rl()
{
r::api_status status;
u::configuration config;

// Load configuration from json config file
const auto config_file = _options["json_config"].as<std::string>();
if (load_config_from_json(config_file, config, &status) != err::success)
Expand All @@ -387,11 +446,7 @@ int rl_sim::init_rl()
config.set(r::name::OBSERVATION_SENDER_IMPLEMENTATION, r::value::OBSERVATION_FILE_SENDER);
}

if (!_options["get_model"].as<bool>())
{
// Set the time provider to the clock time provider
config.set(r::name::MODEL_SRC, r::value::NO_MODEL_DATA);
}
if (!_options["get_model"].as<bool>()) { config.set(r::name::MODEL_SRC, r::value::NO_MODEL_DATA); }

if (_options["log_timestamp"].as<bool>())
{
Expand Down Expand Up @@ -539,7 +594,7 @@ std::string rl_sim::create_context_json(const std::string& cntxt, const std::str

std::string rl_sim::create_event_id()
{
if (_num_events > 0 && ++_current_events >= _num_events) { _run_loop = false; }
if (++_current_events >= _num_events && _num_events > 0) { _run_loop = false; }

if (_random_ids) { return boost::uuids::to_string(boost::uuids::random_generator()()); }

Expand All @@ -548,7 +603,7 @@ std::string rl_sim::create_event_id()
return oss.str();
}

rl_sim::rl_sim(boost::program_options::variables_map vm) : _options(std::move(vm)), _loop_kind(CB)
rl_sim::rl_sim(const boost::program_options::variables_map& vm) : _options(vm), _loop_kind(CB)
{
if (_options["ccb"].as<bool>()) { _loop_kind = CCB; }
else if (_options["slates"].as<bool>()) { _loop_kind = Slates; }
Expand All @@ -560,6 +615,7 @@ rl_sim::rl_sim(boost::program_options::variables_map vm) : _options(std::move(vm
_delay = _options["delay"].as<int64_t>();
_quiet = _options["quiet"].as<bool>();
_random_ids = _options["random_ids"].as<bool>();
_model_refresh_period = _options["refresh_model_period"].as<uint64_t>();
}

std::string get_dist_str(const reinforcement_learning::ranking_response& response)
Expand Down
3 changes: 2 additions & 1 deletion examples/rl_sim_cpp/rl_sim.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class rl_sim
*
* @param vm User defined options
*/
explicit rl_sim(boost::program_options::variables_map vm);
explicit rl_sim(const boost::program_options::variables_map& vm);

/**
* @brief Simulation loop
Expand Down Expand Up @@ -177,4 +177,5 @@ class rl_sim
int64_t _delay = 2000;
bool _quiet = false;
bool _random_ids = true;
uint64_t _model_refresh_period = 0;
};
4 changes: 4 additions & 0 deletions examples/rl_sim_cpp/simulation_stats.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class simulation_stats
auto& item_count = _item_stats[id];
++item_count;
++_total_events;
_total_reward += outcome;
}

std::string get_stats(const std::string& id, T chosen_action)
Expand All @@ -29,8 +30,11 @@ class simulation_stats

int count() const { return _total_events; }

// Average reward per recorded event: the accumulated outcome total divided by the
// number of events. Returns 0 when no events have been recorded yet, instead of
// dividing by zero and producing NaN.
float get_ctr() const
{
  if (_total_events == 0) { return 0.f; }
  return _total_reward / static_cast<float>(_total_events);
}

private:
std::map<std::pair<std::string, T>, std::pair<int, int>> _action_stats;
std::map<std::string, int> _item_stats;
int _total_events = 0;
float _total_reward = 0.f;
};
17 changes: 13 additions & 4 deletions external_parser/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ set(RL_FLAT_BUFFER_FILES
)

add_flatbuffer_schema(
TARGET fbgen
TARGET fbgen_external_parser
SCHEMAS ${RL_FLAT_BUFFER_FILES}
OUTPUT_DIR ${CMAKE_CURRENT_LIST_DIR}/generated/v2/
FLATC_EXE ${flatc_location}
Expand Down Expand Up @@ -154,10 +154,16 @@ set(binary_parser_sources
)

add_library(rl_binary_parser STATIC ${binary_parser_headers} ${binary_parser_sources})
# Compile the static library as position-independent code.
# NOTE(review): presumably so it can be linked into shared libraries - confirm.
set_target_properties(rl_binary_parser PROPERTIES POSITION_INDEPENDENT_CODE ON)
# On Windows, append "d" to the library file name for Debug configuration builds.
if(WIN32)
set_target_properties(rl_binary_parser PROPERTIES DEBUG_POSTFIX d)
endif()

target_link_libraries(rl_binary_parser PUBLIC vw_core RapidJSON PRIVATE libzstd_static)
target_include_directories(rl_binary_parser
PUBLIC
${CMAKE_CURRENT_LIST_DIR}/
${CMAKE_CURRENT_LIST_DIR}/generated/v2/
${CMAKE_CURRENT_LIST_DIR}/../ext_libs/zstd/lib/
${CMAKE_CURRENT_LIST_DIR}/../ext_libs/date/
)
Expand All @@ -169,17 +175,20 @@ if(TARGET flatbuffers::flatbuffers)
else()
target_include_directories(rl_binary_parser PRIVATE ${FLATBUFFERS_INCLUDE_DIR})
endif()
add_dependencies(rl_binary_parser fbgen)
add_dependencies(rl_binary_parser fbgen_external_parser)

add_executable(rl_binary_parser_bin main.cc)
target_link_libraries(rl_binary_parser_bin PUBLIC rl_binary_parser)
set_target_properties(rl_binary_parser_bin PROPERTIES OUTPUT_NAME "vw")
if (NOT rlclientlib_BUILD_DOTNET)
# The build for .NET bindings configures all binary output to a single directory
# In this case we can't name the binary parser "vw", since this will overwrite the real vw executable
set_target_properties(rl_binary_parser_bin PROPERTIES OUTPUT_NAME "vw")
endif()

if(STATIC_LINK_BINARY_PARSER AND NOT APPLE)
target_link_libraries(rl_binary_parser_bin PRIVATE -static)
endif()


# Tests
# -----

Expand Down
Loading