Commit b6d0579e authored by Gauthier Quesnel's avatar Gauthier Quesnel
Browse files

core: remove value-type in message

parent 5e545d72
...@@ -162,7 +162,7 @@ observation_output_observe(const irt::observer& obs, ...@@ -162,7 +162,7 @@ observation_output_observe(const irt::observer& obs,
return; return;
auto* output = reinterpret_cast<observation_output*>(obs.user_data); auto* output = reinterpret_cast<observation_output*>(obs.user_data);
const auto value = static_cast<float>(msg.cast_to_real_64(0)); const auto value = static_cast<float>(msg[0]);
if (match(output->observation_type, if (match(output->observation_type,
observation_output::type::multiplot, observation_output::type::multiplot,
......
...@@ -51,7 +51,7 @@ file_output_observe(const irt::observer& obs, ...@@ -51,7 +51,7 @@ file_output_observe(const irt::observer& obs,
return; return;
auto* output = reinterpret_cast<file_output*>(obs.user_data); auto* output = reinterpret_cast<file_output*>(obs.user_data);
fmt::print(output->os, "{},{}\n", t, msg.to_real_64(0)); fmt::print(output->os, "{},{}\n", t, msg.real[0]);
} }
......
...@@ -9,13 +9,18 @@ ...@@ -9,13 +9,18 @@
#include <limits> #include <limits>
#include <string_view> #include <string_view>
#include <cassert>
#include <cmath> #include <cmath>
#include <cstdint> #include <cstdint>
#include <cstring> #include <cstring>
#include <vector> #include <vector>
// You can override the default assert handler by editing imconfig.h
#ifndef irt_assert
#include <cassert>
#define irt_assert(_expr) assert(_expr)
#endif
namespace irt { namespace irt {
using i8 = int8_t; using i8 = int8_t;
...@@ -76,11 +81,9 @@ enum class status ...@@ -76,11 +81,9 @@ enum class status
model_adder_empty_init_message, model_adder_empty_init_message,
model_adder_bad_init_message, model_adder_bad_init_message,
model_adder_bad_external_message,
model_mult_empty_init_message, model_mult_empty_init_message,
model_mult_bad_init_message, model_mult_bad_init_message,
model_mult_bad_external_message,
model_integrator_dq_error, model_integrator_dq_error,
model_integrator_X_error, model_integrator_X_error,
...@@ -89,20 +92,14 @@ enum class status ...@@ -89,20 +92,14 @@ enum class status
model_integrator_output_error, model_integrator_output_error,
model_integrator_running_without_x_dot, model_integrator_running_without_x_dot,
model_integrator_ta_with_bad_x_dot, model_integrator_ta_with_bad_x_dot,
model_integrator_bad_external_message,
model_quantifier_bad_quantum_parameter, model_quantifier_bad_quantum_parameter,
model_quantifier_bad_archive_length_parameter, model_quantifier_bad_archive_length_parameter,
model_quantifier_shifting_value_neg, model_quantifier_shifting_value_neg,
model_quantifier_shifting_value_less_1, model_quantifier_shifting_value_less_1,
model_quantifier_bad_external_message,
model_cross_bad_external_message,
model_time_func_bad_init_message, model_time_func_bad_init_message,
model_accumulator_bad_external_message,
gui_not_enough_memory, gui_not_enough_memory,
io_file_format_error, io_file_format_error,
...@@ -454,78 +451,6 @@ public: ...@@ -454,78 +451,6 @@ public:
* *
****************************************************************************/ ****************************************************************************/
enum class value_type : i8
{
none,
integer_8,
integer_32,
integer_64,
real_32,
real_64
};
template<typename T, size_t length>
struct span
{
public:
using value_type = T;
using difference_type = std::ptrdiff_t;
using pointer = T*;
using reference = T&;
using const_reference = T&;
using iterator = T*;
using const_iterator = const T*;
private:
const T* data_;
public:
constexpr span(const T* ptr)
: data_(ptr)
{}
constexpr T* data() noexcept
{
return data_;
}
constexpr const T* data() const noexcept
{
return data_;
}
constexpr size_t size() const noexcept
{
return length;
}
constexpr T operator[](size_t i) const noexcept
{
assert(i < length);
return data_[i];
}
constexpr iterator begin() noexcept
{
return iterator{ data_ };
}
constexpr iterator end() noexcept
{
return iterator{ data_ + length };
}
constexpr const_iterator begin() const noexcept
{
return iterator{ data_ };
}
constexpr const_iterator end() const noexcept
{
return iterator{ data_ + length };
}
};
template<class T, class... Rest> template<class T, class... Rest>
constexpr bool constexpr bool
are_all_same() noexcept are_all_same() noexcept
...@@ -538,188 +463,41 @@ struct message ...@@ -538,188 +463,41 @@ struct message
using size_type = std::size_t; using size_type = std::size_t;
using difference_type = std::ptrdiff_t; using difference_type = std::ptrdiff_t;
union double real[3];
{ i8 length;
i8 integer_8[32]; // 32 bytes
i32 integer_32[8]; // 8 * 4 bytes
i64 integer_64[4]; // 4 * 8 bytes
float real_32[8]; // 8 * 4 bytes
double real_64[4]; // 4 * 8 bytes
};
u8 length;
value_type type;
constexpr std::size_t size() const noexcept constexpr std::size_t size() const noexcept
{ {
return length; assert(length >= 0);
return static_cast<std::size_t>(length);
} }
constexpr message() noexcept constexpr message() noexcept
: integer_8{ 0 } : real{ 0.0, 0.0, 0.0 }
, length{ 0 } , length{ 0 }
, type{ value_type::none }
{} {}
template<typename... T> template<typename... T>
constexpr message(T... args) noexcept constexpr message(T... args) noexcept
{ {
if constexpr (are_all_same<i8, T...>()) { if constexpr (are_all_same<double, T...>()) {
static_assert(sizeof...(args) <= 32, "i8 message limited to 32"); static_assert(sizeof...(args) <= 3, "double message limited to 3");
using unused = i8[];
length = 0;
(void)unused{ i8(0), (integer_8[length++] = args, i8(0))... };
type = value_type::integer_8;
} else if constexpr (are_all_same<i32, T...>()) {
static_assert(sizeof...(args) <= 8, "i32 message limited to 32");
using unused = i32[];
length = 0;
(void)unused{ i32(0), (integer_32[length++] = args, i32(0))... };
type = value_type::integer_32;
} else if constexpr (are_all_same<i64, T...>()) {
static_assert(sizeof...(args) <= 4, "i64 message limited to 32");
using unused = i64[];
length = 0;
(void)unused{ i64(0), (integer_64[length++] = args, i64(0))... };
type = value_type::integer_64;
} else if constexpr (are_all_same<float, T...>()) {
static_assert(sizeof...(args) <= 8, "float message limited to 8");
using unused = float[];
length = 0;
(void)unused{ 0.0f, (real_32[length++] = args, 0.0f)... };
type = value_type::real_32;
return;
} else if constexpr (are_all_same<double, T...>()) {
static_assert(sizeof...(args) <= 4, "double message limited to 4");
using unused = double[]; using unused = double[];
length = 0; length = 0;
(void)unused{ 0.0, (real_64[length++] = args, 0.0)... }; (void)unused{ 0.0, (real[length++] = args, 0.0)... };
type = value_type::real_64;
}
} }
constexpr span<i8, 32> to_integer_8() const
{
assert(type == value_type::integer_8);
return span<i8, 32>(integer_8);
}
constexpr span<i32, 8> to_integer_32() const
{
assert(type == value_type::integer_32);
return span<i32, 8>(integer_32);
}
constexpr span<i64, 4> to_integer_64() const
{
assert(type == value_type::integer_64);
return span<i64, 4>(integer_64);
}
constexpr span<float, 8> to_real_32() const
{
assert(type == value_type::real_32);
return span<float, 8>(real_32);
}
constexpr span<double, 4> to_real_64() const
{
assert(type == value_type::real_64);
return span<double, 4>(real_64);
}
template<typename T>
constexpr i8 to_integer_8(T i) const
{
static_assert(std::is_integral_v<T>, "need [unsigned] integer");
if constexpr (std::is_signed_v<T>)
assert(i >= 0);
assert(type == value_type::integer_8);
assert(i < static_cast<T>(length));
return integer_8[i];
}
template<typename T>
constexpr i32 to_integer_32(T i) const
{
static_assert(std::is_integral_v<T>, "need [unsigned] integer");
if constexpr (std::is_signed_v<T>)
assert(i >= 0);
assert(type == value_type::integer_32);
assert(i < static_cast<T>(length));
return integer_32[i];
}
template<typename T>
constexpr i64 to_integer_64(T i) const
{
static_assert(std::is_integral_v<T>, "need [unsigned] integer");
if constexpr (std::is_signed_v<T>)
assert(i >= 0);
assert(type == value_type::integer_64);
assert(i < static_cast<T>(length));
return integer_64[i];
} }
template<typename T> double operator[](const difference_type i) const noexcept
constexpr float to_real_32(T i) const
{ {
static_assert(std::is_integral_v<T>, "need [unsigned] integer"); irt_assert(i < static_cast<std::ptrdiff_t>(length));
return real[i];
if constexpr (std::is_signed_v<T>)
assert(i >= 0);
assert(type == value_type::real_32);
assert(i < static_cast<T>(length));
return real_32[i];
}
template<typename T>
constexpr double to_real_64(T i) const
{
static_assert(std::is_integral_v<T>, "need [unsigned] integer");
if constexpr (std::is_signed_v<T>)
assert(i >= 0);
assert(type == value_type::real_64);
assert(i < static_cast<T>(length));
return real_64[i];
} }
template<typename T> double& operator[](const difference_type i) noexcept
constexpr double cast_to_real_64(T i) const
{ {
if constexpr (std::is_signed_v<T>) irt_assert(i < static_cast<std::ptrdiff_t>(length));
assert(i >= 0); return real[i];
switch (type) {
case value_type::integer_8:
return static_cast<double>(to_integer_8(i));
case value_type::integer_32:
return static_cast<double>(to_integer_32(i));
case value_type::integer_64:
return static_cast<double>(to_integer_64(i));
case value_type::real_32:
return static_cast<double>(to_real_32(i));
case value_type::real_64:
return static_cast<double>(to_real_64(i));
default:
return 0.0;
}
return 0.0;
} }
}; };
...@@ -3137,12 +2915,10 @@ struct integrator ...@@ -3137,12 +2915,10 @@ struct integrator
time t) noexcept time t) noexcept
{ {
for (const auto& msg : port_quanta.messages) { for (const auto& msg : port_quanta.messages) {
irt_return_if_fail(msg.type == value_type::real_64 && irt_assert(msg.size() == 2);
msg.size() == 2,
status::model_integrator_bad_external_message);
up_threshold = msg.to_real_64(0); up_threshold = msg.real[0];
down_threshold = msg.to_real_64(1); down_threshold = msg.real[1];
if (st == state::wait_for_quanta) if (st == state::wait_for_quanta)
st = state::running; st = state::running;
...@@ -3152,11 +2928,9 @@ struct integrator ...@@ -3152,11 +2928,9 @@ struct integrator
} }
for (const auto& msg : port_x_dot.messages) { for (const auto& msg : port_x_dot.messages) {
irt_return_if_fail(msg.type == value_type::real_64 && irt_assert(msg.size() == 1);
msg.size() == 1,
status::model_integrator_bad_external_message);
archive.emplace_back(msg.to_real_64(0), t); archive.emplace_back(msg.real[0], t);
if (st == state::wait_for_x_dot) if (st == state::wait_for_x_dot)
st = state::running; st = state::running;
...@@ -3166,11 +2940,9 @@ struct integrator ...@@ -3166,11 +2940,9 @@ struct integrator
} }
for (const auto& msg : port_reset.messages) { for (const auto& msg : port_reset.messages) {
irt_return_if_fail(msg.type == value_type::real_64 && irt_assert(msg.size() == 1);
msg.size() == 1,
status::model_integrator_bad_external_message);
reset_value = msg.to_real_64(0); reset_value = msg.real[0];
reset = true; reset = true;
} }
...@@ -3299,16 +3071,11 @@ struct integrator ...@@ -3299,16 +3071,11 @@ struct integrator
val += (t - archive.back().date) * archive.back().x_dot; val += (t - archive.back().date) * archive.back().x_dot;
if(up_threshold < val) if (up_threshold < val) {
{
return up_threshold; return up_threshold;
} } else if (down_threshold > val) {
else if (down_threshold > val)
{
return down_threshold; return down_threshold;
} } else {
else
{
return val; return val;
} }
} }
...@@ -3424,19 +3191,15 @@ struct qss1_integrator ...@@ -3424,19 +3191,15 @@ struct qss1_integrator
bool reset = false; bool reset = false;
for (const auto& msg : port_x.messages) { for (const auto& msg : port_x.messages) {
irt_return_if_fail(msg.type == value_type::real_64 && irt_assert(msg.size() == 1);
msg.size() == 1,
status::model_integrator_bad_external_message);
value_x = msg.to_real_64(0); value_x = msg.real[0];
} }
for (const auto& msg : port_r.messages) { for (const auto& msg : port_r.messages) {
irt_return_if_fail(msg.type == value_type::real_64 && irt_assert(msg.size() == 1);
msg.size() == 1,
status::model_integrator_bad_external_message);
X = msg.to_real_64(0); X = msg.real[0];
reset = true; reset = true;
} }
...@@ -3609,20 +3372,16 @@ struct qss2_integrator ...@@ -3609,20 +3372,16 @@ struct qss2_integrator
bool reset = false; bool reset = false;
for (const auto& msg : port_x.messages) { for (const auto& msg : port_x.messages) {
irt_return_if_fail(msg.type == value_type::real_64 && irt_assert(msg.size() == 2);
msg.size() == 2,
status::model_integrator_bad_external_message);
value_x = msg.to_real_64(0); value_x = msg.real[0];
value_slope = msg.to_real_64(1); value_slope = msg.real[1];
} }
for (const auto& msg : port_r.messages) { for (const auto& msg : port_r.messages) {
irt_return_if_fail(msg.type == value_type::real_64 && irt_assert(msg.size() == 1);
msg.size() == 1,
status::model_integrator_bad_external_message);
X = msg.to_real_64(0); X = msg.real[0];
reset = true; reset = true;
} }
...@@ -3712,8 +3471,8 @@ struct qss2_sum ...@@ -3712,8 +3471,8 @@ struct qss2_sum
values[i] += slopes[i] * e; values[i] += slopes[i] * e;
} else { } else {
for (const auto& msg : input_ports.get(x[i]).messages) { for (const auto& msg : input_ports.get(x[i]).messages) {
values[i] = msg.to_real_64(0); values[i] = msg[0];
slopes[i] = msg.size() > 1 ? msg.to_real_64(1) : 0.0; slopes[i] = msg.size() > 1 ? msg[1] : 0.0;
message = true; message = true;
} }
} }
...@@ -3778,15 +3537,15 @@ struct qss2_multiplier ...@@ -3778,15 +3537,15 @@ struct qss2_multiplier
sigma = time_domain<time>::infinity; sigma = time_domain<time>::infinity;
for (const auto& msg : input_ports.get(x[0]).messages) { for (const auto& msg : input_ports.get(x[0]).messages) {
values[0] = msg.to_real_64(0); values[0] = msg.real[0];
slopes[0] = msg.size() > 1 ? msg.to_real_64(1) : 0.0; slopes[0] = msg.size() > 1 ? msg.real[1] : 0.0;
message_port_0 = true; message_port_0 = true;
sigma = time_domain<time>::zero; sigma = time_domain<time>::zero;
} }
for (const auto& msg : input_ports.get(x[1]).messages) { for (const auto& msg : input_ports.get(x[1]).messages) {
values[1] = msg.to_real_64(0); values[1] = msg.real[0];
slopes[1] = msg.size() > 1 ? msg.to_real_64(1) : 0.0; slopes[1] = msg.size() > 1 ? msg.real[1] : 0.0;
message_port_1 = true; message_port_1 = true;
sigma = time_domain<time>::zero; sigma = time_domain<time>::zero;
} }
...@@ -3864,8 +3623,8 @@ struct qss2_wsum ...@@ -3864,8 +3623,8 @@ struct qss2_wsum
values[i] += slopes[i] * e; values[i] += slopes[i] * e;
} else { } else {