Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions bindings/cs/rl.net.cli/PerfTestCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,12 @@ private PerfTestStepProvider DoWork(string tag)
};

Console.WriteLine(stepProvider.DataSize);
RLDriver rlDriver = new RLDriver(liveModel, loopKind: this.GetLoopKind())
using (RLDriver rlDriver = new RLDriver(liveModel, loopKind: this.GetLoopKind()))
{
StepInterval = TimeSpan.FromMilliseconds(this.SleepIntervalMs)
};
rlDriver.Run(stepProvider);
rlDriver.StepInterval = TimeSpan.FromMilliseconds(this.SleepIntervalMs);
rlDriver.Run(stepProvider);
}

stepProvider.Stats.Print();
return stepProvider;
}
Expand Down
27 changes: 26 additions & 1 deletion bindings/cs/rl.net.cli/RLDriver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ internal interface IOutcomeReporter<TOutcome>
bool TryQueueOutcomeEvent(RunContext runContext, string eventId, string slotId, TOutcome outcome);
}

public class RLDriver : IOutcomeReporter<float>, IOutcomeReporter<string>
public class RLDriver : IOutcomeReporter<float>, IOutcomeReporter<string>, IDisposable
{
private LiveModel liveModel;
private LoopKind loopKind;
Expand Down Expand Up @@ -250,5 +250,30 @@ private void SafeRaiseError(ApiStatus errorStatus)
localHandler(this, errorStatus);
}
}

#region IDisposable Support
private bool disposedValue = false; // To detect redundant calls

protected virtual void Dispose(bool disposing)
{
if (!disposedValue)
{
if (disposing)
{
this.liveModel?.Dispose();
this.liveModel = null;
}

disposedValue = true;
}
}


public void Dispose()
{
// Do not change this code. Put cleanup code in Dispose(bool disposing) above.
Dispose(true);
}
#endregion
}
}
26 changes: 25 additions & 1 deletion bindings/cs/rl.net.cli/RLSimulator.cs
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ public float GetContinuousActionOutcome(float action, float pdf_value)
}
}

internal class RLSimulator
internal class RLSimulator : IDisposable
{
private RLDriver driver;

Expand Down Expand Up @@ -322,5 +322,29 @@ public event EventHandler<ApiStatus> OnError
this.driver.OnError -= value;
}
}

#region IDisposable Support
private bool disposedValue = false; // To detect redundant calls
Comment thread
kumpera marked this conversation as resolved.

protected virtual void Dispose(bool disposing)
{
Console.WriteLine("Disponsing rlsim");
if (!disposedValue)
{
if (disposing)
{
this.driver?.Dispose();
this.driver = null;
}

disposedValue = true;
}
}

public void Dispose()
{
Dispose(true);
}
#endregion
}
}
16 changes: 9 additions & 7 deletions bindings/cs/rl.net.cli/ReplayCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,17 @@ class ReplayCommand : CommandBase
public override void Run()
{
LiveModel liveModel = Helpers.CreateLiveModelOrExit(this.ConfigPath);
RLDriver rlDriver = new RLDriver(liveModel, loopKind: this.GetLoopKind());
rlDriver.StepInterval = TimeSpan.FromMilliseconds(this.SleepIntervalMs);

using (TextReader textReader = File.OpenText(this.LogPath))
using (RLDriver rlDriver = new RLDriver(liveModel, loopKind: this.GetLoopKind()))
{
IEnumerable<string> dsJsonLines = textReader.LazyReadLines();
ReplayStepProvider stepProvider = new ReplayStepProvider(dsJsonLines);
rlDriver.StepInterval = TimeSpan.FromMilliseconds(this.SleepIntervalMs);

using (TextReader textReader = File.OpenText(this.LogPath))
{
IEnumerable<string> dsJsonLines = textReader.LazyReadLines();
ReplayStepProvider stepProvider = new ReplayStepProvider(dsJsonLines);

rlDriver.Run(stepProvider);
rlDriver.Run(stepProvider);
}
}
}
}
Expand Down
10 changes: 6 additions & 4 deletions bindings/cs/rl.net.cli/RunSimulatorCommand.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@ public override void Run()
{
LiveModel liveModel = Helpers.CreateLiveModelOrExit(this.ConfigPath);

RLSimulator rlSim = new RLSimulator(liveModel, loopKind: this.GetLoopKind());
rlSim.StepInterval = TimeSpan.FromMilliseconds(this.SleepIntervalMs);
rlSim.OnError += (sender, apiStatus) => Helpers.WriteStatusAndExit(apiStatus);
rlSim.Run(this.Steps);
using (RLSimulator rlSim = new RLSimulator(liveModel, loopKind: this.GetLoopKind()))
{
rlSim.StepInterval = TimeSpan.FromMilliseconds(this.SleepIntervalMs);
rlSim.OnError += (sender, apiStatus) => Helpers.WriteStatusAndExit(apiStatus);
rlSim.Run(this.Steps);
}
}
}
}
2 changes: 2 additions & 0 deletions bindings/cs/rl.net.native/binding_tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ namespace rl_net_native {
: context(_context) {
}

binding_tracer::~binding_tracer() { }

void binding_tracer::log(int log_level, const std::string& msg) {
if (context.trace_logger_callback != nullptr) {
context.trace_logger_callback(log_level, msg.c_str());
Expand Down
1 change: 1 addition & 0 deletions bindings/cs/rl.net.native/binding_tracer.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ namespace rl_net_native {
// Inherited via i_trace
binding_tracer(livemodel_context& _context);
void log(int log_level, const std::string &msg) override;
virtual ~binding_tracer();
private:
livemodel_context& context;
};
Expand Down
4 changes: 3 additions & 1 deletion rlclientlib/dedup.cc
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,7 @@ class dedup_extensions : public logger::i_logger_extensions
logger::i_logger_extensions(c), _dedup_state(c, use_compression, use_dedup, time_provider), _use_dedup(use_dedup), _use_compression(use_compression) {}

logger::i_async_batcher<generic_event>* create_batcher(logger::i_message_sender* sender, utility::watchdog& watchdog,
error_callback_fn* perror_cb, const char* section) override {
error_callback_fn* perror_cb, i_trace* trace, const char* section) override {
auto config = utility::get_batcher_config(_config, section);

if(_use_dedup) {
Expand All @@ -290,6 +290,7 @@ class dedup_extensions : public logger::i_logger_extensions
watchdog,
_dedup_state,
perror_cb,
trace,
config);
} else {
int _dummy = 0;
Expand All @@ -298,6 +299,7 @@ class dedup_extensions : public logger::i_logger_extensions
watchdog,
_dummy,
perror_cb,
trace,
config);
}

Expand Down
11 changes: 9 additions & 2 deletions rlclientlib/live_model_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,13 @@ namespace reinforcement_learning {
_learning_mode = learning::to_learning_mode(_configuration.get(name::LEARNING_MODE, value::LEARNING_MODE_ONLINE));
}

live_model_impl::~live_model_impl() {
if (_interaction_logger)
Comment thread
ataymano marked this conversation as resolved.
_interaction_logger->flush();
if (_outcome_logger)
_outcome_logger->flush();
}

int live_model_impl::init_trace(api_status* status) {
const auto trace_impl = _configuration.get(name::TRACE_LOG_IMPLEMENTATION, value::NULL_TRACE_LOGGER);
i_trace* plogger;
Expand Down Expand Up @@ -442,7 +449,7 @@ namespace reinforcement_learning {
RETURN_IF_FAIL(_time_provider_factory->create(&ranking_time_provider, time_provider_impl, _configuration, _trace_logger.get(), status));

// Create a logger for interactions that will use msg sender to send interaction messages
_interaction_logger.reset(new logger::interaction_logger_facade(_model->model_type(), _configuration, ranking_msg_sender, _watchdog, ranking_time_provider, *_logger_extensions.get(), &_error_cb));
_interaction_logger.reset(new logger::interaction_logger_facade(_model->model_type(), _configuration, ranking_msg_sender, _watchdog, ranking_time_provider, *_logger_extensions.get(), _trace_logger.get(), &_error_cb));
RETURN_IF_FAIL(_interaction_logger->init(status));

// Get the name of raw data (as opposed to message) sender for observations.
Expand All @@ -463,7 +470,7 @@ namespace reinforcement_learning {
RETURN_IF_FAIL(_time_provider_factory->create(&observation_time_provider, time_provider_impl, _configuration, _trace_logger.get(), status));

// Create a logger for observations that will use msg sender to send observation messages
_outcome_logger.reset(new logger::observation_logger_facade(_configuration, outcome_msg_sender, _watchdog, observation_time_provider, &_error_cb));
_outcome_logger.reset(new logger::observation_logger_facade(_configuration, outcome_msg_sender, _watchdog, observation_time_provider, _trace_logger.get(), &_error_cb));
RETURN_IF_FAIL(_outcome_logger->init(status));

return error_code::success;
Expand Down
1 change: 1 addition & 0 deletions rlclientlib/live_model_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ namespace reinforcement_learning
model_factory_t* m_factory,
sender_factory_t* sender_factory,
time_provider_factory_t* time_provider_factory);
~live_model_impl();

live_model_impl(const live_model_impl&) = delete;
live_model_impl(live_model_impl&&) = delete;
Expand Down
26 changes: 18 additions & 8 deletions rlclientlib/logger/async_batcher.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#include "message_sender.h"
#include "utility/config_helper.h"
#include "utility/object_pool.h"
#include "trace_logger.h"
#include "str_util.h"

namespace reinforcement_learning {
class error_callback_fn;
Expand All @@ -30,6 +32,7 @@ namespace reinforcement_learning { namespace logger {
virtual int append(TEvent& evt, api_status* status = nullptr) = 0;

virtual int run_iteration(api_status* status) = 0;
virtual void flush() = 0; //TODO surface errors
};

// This class takes uses a queue and a background thread to accumulate events, and send them by batch asynchronously.
Expand All @@ -46,27 +49,30 @@ namespace reinforcement_learning { namespace logger {

int run_iteration(api_status* status) override;

void flush() override; //flush all batches

private:
int fill_buffer(std::shared_ptr<utility::data_buffer>& retbuffer,
size_t& remaining,
api_status* status);

void flush(); //flush all batches

public:
async_batcher(i_message_sender* sender,
utility::watchdog& watchdog,
shared_state_t& shared_state,
error_callback_fn* perror_cb,
i_trace *tracer,
Comment thread
kumpera marked this conversation as resolved.
const utility::async_batcher_config& config);
~async_batcher();
virtual ~async_batcher();

private:
std::unique_ptr<i_message_sender> _sender;

event_queue<TEvent> _queue; // A queue to accumulate batch of events.
size_t _send_high_water_mark;
error_callback_fn* _perror_cb;
i_trace *_trace;
shared_state_t& _shared_state;

utility::periodic_background_proc<async_batcher> _periodic_background_proc;
Expand Down Expand Up @@ -122,27 +128,32 @@ namespace reinforcement_learning { namespace logger {
TEvent evt;
TSerializer<TEvent> collection_serializer(*buffer.get(), _batch_content_encoding, _shared_state);

int event_count = 0;
while (remaining > 0 && collection_serializer.size() < _send_high_water_mark) {
if (_queue.pop(&evt)) {
if (queue_mode_enum::BLOCK == _queue_mode) {
_cv.notify_one();
}
RETURN_IF_FAIL(collection_serializer.add(evt, status));
--remaining;
++event_count;
}
}

RETURN_IF_FAIL(collection_serializer.finalize(status));

TRACE_INFO(_trace, utility::concat("async_batcher.fill_buffer: created batch with ",
event_count, " events and ",
collection_serializer.size(), " bytes"));

return error_code::success;
}

template<typename TEvent, template<typename> class TSerializer>
void async_batcher<TEvent, TSerializer>::flush() {
const auto queue_size = _queue.size();

// Early exit if queue is empty.
if (queue_size == 0) {
TRACE_INFO(_trace, "async_batcher.flush: empty queue");
return;
}

Expand All @@ -157,7 +168,7 @@ namespace reinforcement_learning { namespace logger {
ERROR_CALLBACK(_perror_cb, status);
}

if (_sender->send(TSerializer<TEvent>::message_id(), buffer, &status) != error_code::success) {
if (_sender->send(TSerializer<TEvent>::message_id(), buffer, &status) != error_code::success) {
ERROR_CALLBACK(_perror_cb, status);
}
}
Expand All @@ -169,11 +180,13 @@ namespace reinforcement_learning { namespace logger {
utility::watchdog& watchdog,
typename TSerializer<TEvent>::shared_state_t& shared_state,
error_callback_fn* perror_cb,
i_trace *trace,
const utility::async_batcher_config& config)
: _sender(sender)
, _queue(config.send_queue_max_capacity)
, _send_high_water_mark(config.send_high_water_mark)
, _perror_cb(perror_cb)
, _trace(trace)
, _shared_state(shared_state)
, _periodic_background_proc(static_cast<int>(config.send_batch_interval_ms), watchdog, "Async batcher thread", perror_cb)
, _pass_prob(0.5)
Expand All @@ -185,8 +198,5 @@ namespace reinforcement_learning { namespace logger {
async_batcher<TEvent, TSerializer>::~async_batcher() {
// Stop the background procedure the queue before exiting
_periodic_background_proc.stop();
if (_queue.size() > 0) {
flush();
}
}
}}
9 changes: 9 additions & 0 deletions rlclientlib/logger/event_logger.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ namespace reinforcement_learning { namespace logger {

int init(api_status* status);

void flush();

protected:
int append(TEvent&& item, api_status* status);
int append(TEvent& item, api_status* status);
Expand All @@ -55,6 +57,13 @@ namespace reinforcement_learning { namespace logger {
return error_code::success;
}


template<typename TEvent>
void event_logger<TEvent>::flush()
{
_batcher->flush();
}

template<typename TEvent>
int event_logger<TEvent>::append(TEvent&& item, api_status* status) {
if (!_initialized) {
Expand Down
3 changes: 2 additions & 1 deletion rlclientlib/logger/logger_extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@ class default_extensions : public i_logger_extensions
delete provider; //We don't use it
}

i_async_batcher<generic_event>* create_batcher(i_message_sender* sender, utility::watchdog& watchdog, error_callback_fn* perror_cb, const char* section) override {
i_async_batcher<generic_event>* create_batcher(i_message_sender* sender, utility::watchdog& watchdog, error_callback_fn* perror_cb, i_trace* trace, const char* section) override {
auto config = utility::get_batcher_config(_config, section);
int _dummy = 0;
return new async_batcher<generic_event, fb_collection_serializer>(
sender,
watchdog,
_dummy,
perror_cb,
trace,
config);
}

Expand Down
Loading