Commit 2aa73e16 authored by Gauthier Quesnel's avatar Gauthier Quesnel
Browse files

core: add observation system

parent 2a42ad4a
Pipeline #10684 passed with stage
in 1 minute and 5 seconds
This diff is collapsed.
......@@ -1652,7 +1652,7 @@ public:
if (new_capacity != m_capacity) {
if (m_items) {
if constexpr (!std::is_trivial_v<T>)
for (auto i = 0u; i != m_capacity; ++i)
for (auto i = 0u; i != m_size; ++i)
m_items[i].~T();
if (m_items)
......@@ -1675,7 +1675,7 @@ public:
void clear() noexcept
{
if constexpr (!std::is_trivial_v<T>)
for (auto i = 0u; i != m_capacity; ++i)
for (auto i = 0u; i != m_size; ++i)
m_items[i].~T();
m_size = 0;
......@@ -1800,6 +1800,30 @@ public:
{
return m_capacity;
}
reference front() noexcept
{
assert(m_size > 0);
return m_items[0];
}
const_reference front() const noexcept
{
assert(m_size > 0);
return m_items[0];
}
reference back() noexcept
{
assert(m_size > 0);
return m_items[m_size - 1];
}
const_reference back() const noexcept
{
assert(m_size > 0);
return m_items[m_size - 1];
}
};
template<typename T>
......@@ -1951,6 +1975,7 @@ enum class message_id : std::uint64_t;
enum class input_port_id : std::uint64_t;
enum class output_port_id : std::uint64_t;
enum class init_port_id : std::uint64_t;
enum class observer_id : std::uint64_t;
template<typename T>
constexpr u32
......@@ -2464,7 +2489,7 @@ public:
{
data.free(id);
}
/**
* @brief Accessor to the id part of the item
*
......@@ -2610,22 +2635,21 @@ public:
return;
}
if (m_size > 0) {
if (m_size > 0) {
m_size--;
detach_subheap(elem);
elem = merge_subheaps(elem);
root = merge(root, elem);
}
else {
root = nullptr;
}
} else {
root = nullptr;
}
//assert(m_size > 0);
// assert(m_size > 0);
//m_size--;
//detach_subheap(elem);
//elem = merge_subheaps(elem);
//root = merge(root, elem);
// m_size--;
// detach_subheap(elem);
// elem = merge_subheaps(elem);
// root = merge(root, elem);
}
void pop() noexcept
......@@ -2800,11 +2824,38 @@ struct model
heap::handle handle{ nullptr };
dynamics_id id{ 0 };
observer_id obs_id{ 0 };
dynamics_type type{ dynamics_type::none };
small_string<7> name;
};
struct observer
{
observer() noexcept = default;
observer(const time time_step_, const char* name_, void* user_data_)
: time_step(std::clamp(time_step_, 0.0, time_domain<time>::infinity))
, name(name_)
, user_data(user_data_)
{}
double tl = 0.0;
double time_step = 0.0;
small_string<8> name;
void* user_data = nullptr;
void (*initialize)(const observer& obs, const time t) noexcept = nullptr;
void (*observe)(const observer& obs,
const time t,
const message& msg) noexcept = nullptr;
void (*free)(const observer& obs, const time t) noexcept = nullptr;
};
struct init_message
{
message_id msg;
......@@ -2864,6 +2915,10 @@ using transition_function_t = decltype(
status (T::*)(data_array<input_port, input_port_id>&, time, time, time),
&T::transition>{});
template<class T>
using observation_function_t =
decltype(detail::helper<message (T::*)(time) const, &T::observation>{});
template<class T>
using initialize_function_t =
decltype(detail::helper<status (T::*)(data_array<message, message_id>&),
......@@ -2962,6 +3017,16 @@ struct adder
return status::success;
}
message observation(time /*t*/) const noexcept
{
double ret = 0.0;
for (size_t i = 0; i != PortNumber; ++i)
ret += input_coeffs[i] * values[i];
return message(ret);
}
};
template<size_t PortNumber>
......@@ -3039,6 +3104,16 @@ struct mult
return status::success;
}
message observation(time /*t*/) const noexcept
{
double ret = 1.0;
for (size_t i = 0; i != PortNumber; ++i)
ret *= std::pow(values[i], input_coeffs[i]);
return message(ret);
}
};
using adder_2 = adder<2>;
......@@ -3084,6 +3159,11 @@ struct counter
{
return status::success;
}
message observation(time /*t*/) const noexcept
{
return message(number);
}
};
struct generator
......@@ -3148,6 +3228,11 @@ struct constant
return status::success;
}
message observation(time /*t*/) const noexcept
{
return message(value);
}
};
struct time_func
......@@ -3189,6 +3274,11 @@ struct time_func
return status::success;
}
message observation(time /*t*/) const noexcept
{
return message(value);
}
};
struct cross
......@@ -3293,6 +3383,11 @@ struct cross
return status::success;
}
message observation(time /*t*/) const noexcept
{
return message(value, if_value, else_value);
}
};
struct none
......@@ -3468,6 +3563,11 @@ struct integrator
return status::success;
}
message observation(time /*t*/) const noexcept
{
return message(last_output_value);
}
status ta() noexcept
{
if (st == state::running) {
......@@ -3704,6 +3804,11 @@ struct quantifier
return status::success;
}
message observation(time /*t*/) const noexcept
{
return message(m_upthreshold, m_downthreshold);
}
private:
status ta() noexcept
{
......@@ -3985,6 +4090,8 @@ struct simulation
data_array<cross, dynamics_id> cross_models;
data_array<time_func, dynamics_id> time_func_models;
data_array<observer, observer_id> observers;
scheduller sched;
time begin = time_domain<time>::zero;
......@@ -3994,8 +4101,8 @@ struct simulation
{
constexpr size_t ten{ 10 };
irt_return_if_bad(model_list_allocator.init(model_capacity));
irt_return_if_bad(message_list_allocator.init(messages_capacity));
irt_return_if_bad(model_list_allocator.init(model_capacity * ten));
irt_return_if_bad(message_list_allocator.init(messages_capacity * ten));
irt_return_if_bad(input_port_list_allocator.init(model_capacity));
irt_return_if_bad(output_port_list_allocator.init(model_capacity));
irt_return_if_bad(emitting_output_port_allocator.init(model_capacity));
......@@ -4025,6 +4132,8 @@ struct simulation
irt_return_if_bad(cross_models.init(model_capacity));
irt_return_if_bad(time_func_models.init(model_capacity));
irt_return_if_bad(observers.init(model_capacity));
return status::success;
}
......@@ -4035,7 +4144,7 @@ struct simulation
/**
* @brief cleanup simulation object
*
*
* Clean scheduller and input/output port from message. This function
* must be call at the end of the simulation.
*/
......@@ -4088,6 +4197,8 @@ struct simulation
cross_models.clear();
time_func_models.clear();
observers.clear();
begin = time_domain<time>::zero;
end = time_domain<time>::infinity;
}
......@@ -4346,6 +4457,13 @@ struct simulation
while (models.next(mdl))
irt_return_if_bad(make_initialize(*mdl, t));
irt::observer* obs = nullptr;
while (observers.next(obs)) {
obs->tl = t;
if (obs->initialize)
obs->initialize(*obs, t);
}
return status::success;
}
......@@ -4474,6 +4592,21 @@ struct simulation
if (auto* port = input_ports.try_to_get(dyn.x[i]); port)
port->messages.clear();
if constexpr (is_detected_v<observation_function_t, Dynamics>) {
if (mdl.obs_id != static_cast<observer_id>(0)) {
if (auto* observer = observers.try_to_get(mdl.obs_id);
observer && observer->observe) {
if (observer->time_step == 0.0 ||
t - observer->tl >= observer->time_step) {
observer->observe(*observer, t, dyn.observation(t));
observer->tl = t;
}
} else {
mdl.obs_id = static_cast<observer_id>(0);
}
}
}
mdl.tl = t;
mdl.tn = t + dyn.sigma;
......
......@@ -94,6 +94,43 @@ f(double t) noexcept
return t * t;
}
struct file_output
{
file_output(const char* file_path) noexcept
: os(std::fopen(file_path, "w"))
{}
~file_output() noexcept
{
if (os)
std::fclose(os);
}
std::FILE* os = nullptr;
};
void
file_output_initialize(const irt::observer& obs, const irt::time /*t*/) noexcept
{
if (!obs.user_data)
return;
auto* output = reinterpret_cast<file_output*>(obs.user_data);
fmt::print(output->os, "t,{}\n", obs.name.c_str());
}
void
file_output_observe(const irt::observer& obs,
const irt::time t,
const irt::message& msg) noexcept
{
if (!obs.user_data)
return;
auto* output = reinterpret_cast<file_output*>(obs.user_data);
fmt::print(output->os, "{},{}\n", t, msg.to_real_64(0));
}
int
main()
{
......@@ -952,27 +989,29 @@ main()
dot_graph_save(sim, stdout);
file_output fo_a("lotka-volterra_a.csv");
file_output fo_b("lotka-volterra_b.csv");
auto& obs_a = sim.observers.alloc(0.01, "A", static_cast<void*>(&fo_a));
auto& obs_b = sim.observers.alloc(0.01, "B", static_cast<void*>(&fo_b));
obs_a.initialize = &file_output_initialize;
obs_a.observe = &file_output_observe;
obs_b.initialize = &file_output_initialize;
obs_b.observe = &file_output_observe;
sim.models.get(integrator_a.id).obs_id = sim.observers.get_id(obs_a);
sim.models.get(integrator_b.id).obs_id = sim.observers.get_id(obs_b);
expect(fo_a.os != nullptr);
expect(fo_b.os != nullptr);
irt::time t = 0.0;
expect(sim.initialize(t) == irt::status::success);
!expect(sim.sched.size() == 7_ul);
std::FILE* os = std::fopen("output.csv", "w");
!expect(os != nullptr);
fmt::print(os, "t,x,y\n");
do {
auto st = sim.run(t);
expect(st == irt::status::success);
fmt::print(os,
"{},{},{}\n",
t,
integrator_a.last_output_value,
integrator_b.last_output_value);
} while (t < 15.0);
std::fclose(os);
};
"izhikevitch_simulation"_test = [] {
......@@ -1131,10 +1170,21 @@ main()
dot_graph_save(sim, stdout);
file_output fo_a("izhikevitch_a.csv");
file_output fo_b("izhikevitch_b.csv");
auto& obs_a = sim.observers.alloc(0.01, "A", static_cast<void*>(&fo_a));
auto& obs_b = sim.observers.alloc(0.01, "B", static_cast<void*>(&fo_b));
obs_a.initialize = &file_output_initialize;
obs_a.observe = &file_output_observe;
obs_b.initialize = &file_output_initialize;
obs_b.observe = &file_output_observe;
sim.models.get(integrator_a.id).obs_id = sim.observers.get_id(obs_a);
sim.models.get(integrator_b.id).obs_id = sim.observers.get_id(obs_b);
expect(fo_a.os != nullptr);
expect(fo_b.os != nullptr);
irt::time t = 0.0;
std::FILE* os = std::fopen("output_izhikevitch.csv", "w");
!expect(os != nullptr);
fmt::print(os, "t,v,u\n");
expect(irt::status::success == sim.initialize(t));
!expect(sim.sched.size() == 14_ul);
......@@ -1142,14 +1192,6 @@ main()
do {
irt::status st = sim.run(t);
expect(st == irt::status::success);
fmt::print(os,
"{},{},{}\n",
t,
integrator_a.last_output_value,
integrator_b.last_output_value);
} while (t < 120);
std::fclose(os);
};
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment