Commit 04ad4293 authored by Damien Leroux's avatar Damien Leroux
Browse files

Pipeline works, but there is an issue with big products.

parent 4f0af463
......@@ -4,6 +4,7 @@
#include "pedigree.h"
#include <cstring>
#include <unordered_map>
#include <boost/dynamic_bitset.hpp>
......@@ -16,13 +17,17 @@
struct state_index_type {
/* TODO make data an array in case we need more than 64 bits of indexing space. */
typedef uint64_t block_type;
block_type data;
static constexpr size_t n_blocks = 4;
std::array<block_type, n_blocks> data;
inline
state_index_type&
operator = (const state_index_type& other)
{
data = other.data;
for (size_t i = 0; i < n_blocks; ++i) {
data[i] = other.data[i];
}
return *this;
}
......@@ -30,7 +35,10 @@ struct state_index_type {
state_index_type&
operator = (uint64_t value)
{
data = value;
data[0] = value;
for (size_t i = 1; i < n_blocks; ++i) {
data[i] = 0;
}
return *this;
}
......@@ -38,11 +46,15 @@ struct state_index_type {
size_t
hash() const
{
return std::hash<uint64_t>()(data);
size_t ret = std::hash<block_type>()(data[0]);
for (size_t i = 0; i < n_blocks; ++i) {
ret = (ret * 0xd3adb3ef) ^ std::hash<block_type>()(data[i]);
}
return ret;
}
inline
void reset() { data = 0; }
void reset() { *this = 0; }
/* Aide à la compilation des données nécessaires à calculer un produit.
* Donne le bitshift suivant en s'assurant qu'on ne franchit pas une barrière de 64 bits.
......@@ -91,13 +103,13 @@ struct state_index_type {
bool operator <= (const state_index_type& other) const { return data <= other.data; }
inline
bool is_bad() const { return data == (uint64_t) -1; }
bool is_bad() const { return data[n_blocks - 1] == (uint64_t) -1; }
static constexpr
state_index_type
bad()
{
return {(block_type) -1};
return {{~0ULL, ~0ULL, ~0ULL, ~0ULL}};
}
friend
......@@ -105,13 +117,30 @@ struct state_index_type {
std::ostream&
operator << (std::ostream& os, const state_index_type& s)
{
uint64_t mask = 1ULL << 63;
while (mask && !(s.data & mask)) { mask >>= 1; }
while (mask) { os << (!!(s.data & mask)); mask >>= 1; }
for (int i = n_blocks - 1; i >= 0; ++i) {
uint64_t mask = 1ULL << 63;
while (mask && !(s.data[i] & mask)) { mask >>= 1; }
while (mask) { os << (!!(s.data[i] & mask)); mask >>= 1; }
}
return os;
}
protected:
/* accède au bon indice pour un shift donné */
inline
block_type&
select_data(size_t shift)
{
return data[shift >> 6];
}
inline
const block_type&
select_data(size_t shift) const
{
return data[shift >> 6];
}
/* Donne la valeur max shiftée d'un indice, pour comparaison avec extract_unshifted() */
static
inline
......@@ -127,15 +156,15 @@ protected:
extract(size_t shift, block_type mask) const
{
/* data offset is shift >> 6 */
return (data & mask) >> (shift & 0x3f);
return (select_data(shift) & mask) >> (shift & 0x3f);
}
inline
void
assign(size_t shift, block_type mask, uint64_t value)
{
data &= ~mask;
data |= value << (shift & 0x3f);
select_data(shift) &= ~mask;
select_data(shift) |= value << (shift & 0x3f);
}
/* Incrémente un indice et retourne carry */
......@@ -143,12 +172,13 @@ protected:
bool
incr(size_t shift, block_type mask, block_type max)
{
block_type sub = (data & mask);
auto& D = select_data(shift);
block_type sub = (D & mask);
if (sub == max) {
data &= ~mask;
D &= ~mask;
return true;
} else {
data += 1 << (shift & 0x3f);
D += 1 << (shift & 0x3f);
return false;
}
}
......@@ -158,8 +188,7 @@ protected:
void
zero(size_t shift, block_type mask)
{
(void) shift;
data &= ~mask;
select_data(shift) &= ~mask;
}
/* fusionne deux curseurs étant donné un masque (utilise les dimensions de *this pour mask et d'other pour !mask) */
......@@ -169,7 +198,9 @@ protected:
{
/* iterate over block index */
state_index_type ret;
ret.data = data | (other.data & ~mask.data);
for (size_t i = 0; i < n_blocks; ++i) {
ret.data[i] = data[i] | (other.data[i] & ~mask.data[i]);
}
return ret;
}
......@@ -178,8 +209,7 @@ protected:
block_type
extract_unshifted(size_t shift, block_type mask) const
{
(void) shift;
return data & mask;
return select_data(shift) & mask;
}
/* calcule le résultat de *this & m. Permet de passer d'un curseur global à un curseur local à une table. */
......@@ -188,7 +218,9 @@ protected:
mask(const state_index_type& m) const
{
state_index_type ret;
ret.data = data & m.data;
for (size_t i = 0; i < n_blocks; ++i) {
ret.data[i] = data[i] & m.data[i];
}
return ret;
}
};
......
......@@ -467,7 +467,7 @@ struct factor_graph_type : public recursive_graph_type<factor_graph_type> {
if (std::find(input_gen.begin(), input_gen.end(), p.gen_name) != input_gen.end()) {
io |= Input;
}
if (std::find(output_gen.begin(), output_gen.end(), p.gen_name) != output_gen.end()) {
if (output_gen.size() == 0 || std::find(output_gen.begin(), output_gen.end(), p.gen_name) != output_gen.end()) {
io |= Output;
}
g->io[p.id] = io;
......@@ -884,7 +884,7 @@ struct factor_graph_type : public recursive_graph_type<factor_graph_type> {
get_joint_domain(const var_vec& varset)
{
/*MSG_QUEUE_FLUSH();*/
scoped_indent _(MESSAGE("[get_joint_domain " << varset << "] "));
/*scoped_indent _(MESSAGE("[get_joint_domain " << varset << "] "));*/
/*MSG_DEBUG("[get_joint_domain " << varset << "] ");*/
auto path = find_vpath(varset);
message_type ret;
......@@ -896,7 +896,7 @@ struct factor_graph_type : public recursive_graph_type<factor_graph_type> {
for (node_index_type n: path) {
accumulate(tmp, get_node_domain(n), domains);
tmp %= varset + variables_of(n);
MSG_DEBUG("tmp " << tmp);
/*MSG_DEBUG("tmp " << tmp);*/
}
accumulate(accum, tmp, domains);
accum %= varset;
......@@ -985,7 +985,7 @@ struct factor_graph_type : public recursive_graph_type<factor_graph_type> {
void
compute_node_domain(node_index_type n)
{
scoped_indent _(MESSAGE("[compute_node_domain #" << n << "] "));
/*scoped_indent _(MESSAGE("[compute_node_domain #" << n << "] "));*/
auto inputs = nei_in(n);
auto vv = variables_of(n);
if (node_is_interface(n)) {
......@@ -1006,9 +1006,9 @@ struct factor_graph_type : public recursive_graph_type<factor_graph_type> {
} else if (node_is_subgraph(n)) {
subgraph(n)->compute_domains_and_factors();
} else /* factor */ {
scoped_indent _(MESSAGE("[compute_factor {" << rule_of(n) << "} => " << own_variable_of(n) << "] "));
MSG_DEBUG(" ### ### COMPUTING FACTOR {" << rule_of(n) << "} => " << own_variable_of(n) << " ### ###");
MSG_QUEUE_FLUSH();
/*scoped_indent _(MESSAGE("[compute_factor {" << rule_of(n) << "} => " << own_variable_of(n) << "] "));*/
/*MSG_DEBUG(" ### ### COMPUTING FACTOR {" << rule_of(n) << "} => " << own_variable_of(n) << " ### ###");*/
/*MSG_QUEUE_FLUSH();*/
joint_variable_product_type jvp;
var_vec varset = variables_of(n);
......@@ -1036,9 +1036,9 @@ struct factor_graph_type : public recursive_graph_type<factor_graph_type> {
/*jvp.compile(domains);*/
/*auto jpar_dom = jvp.compute();*/
genotype_comb_type jpar_dom;
MSG_QUEUE_FLUSH();
/*MSG_QUEUE_FLUSH();*/
if (stack.size() > 1) {
MSG_DEBUG("Domains " << (parent() ? MESSAGE(parent() << "->" << index_in_parent()) : std::string("top-level")) << ' ' << domains);
/*MSG_DEBUG("Domains " << (parent() ? MESSAGE(parent() << "->" << index_in_parent()) : std::string("top-level")) << ' ' << domains);*/
jpar_dom = compute_product(stack.begin(), stack.end(), varset, domains);
} else {
jpar_dom = stack.front();
......@@ -1054,7 +1054,7 @@ struct factor_graph_type : public recursive_graph_type<factor_graph_type> {
e.coef = 1;
}
propagate_spawnling_domain({spawnling}, dom);
MSG_DEBUG("Domain for spawnling #" << spawnling << ": " << domains[{spawnling}]);
/*MSG_DEBUG("Domain for spawnling #" << spawnling << ": " << domains[{spawnling}]);*/
}
/*MSG_DEBUG("COMPUTED DOMAIN " << node_domains[n]);*/
dump_node(n);
......@@ -1152,7 +1152,7 @@ struct factor_graph_type : public recursive_graph_type<factor_graph_type> {
void
squeeze_factor(const node_vec& f1in, node_index_type f1, node_index_type i1, node_index_type f2)
{
MSG_DEBUG("squeeze_factor f1in " << f1in << " f1 " << f1 << " i1 " << i1 << " f2 " << f2);
/*MSG_DEBUG("squeeze_factor f1in " << f1in << " f1 " << f1 << " i1 " << i1 << " f2 " << f2);*/
var_vec var_f1in;
for (node_index_type i: f1in) {
var_f1in = var_f1in + variables_of(i);
......@@ -1471,8 +1471,8 @@ struct instance_type {
message_type
compute_message(const message_compute_operation_type& op, const std::map<var_vec, genotype_comb_type>& domains)
{
/*scoped_indent _(MESSAGE("[compute message " << op.emitter << " % " << op.output << " -> " << op.receiver << "] "));*/
/*MSG_DEBUG("" << op);*/
scoped_indent _(MESSAGE("[compute message " << op.emitter << " % " << op.output << " -> " << op.receiver << "] "));
MSG_DEBUG("" << op);
if (sub_instances[op.emitter]) {
return sub_instances[op.emitter]->compute(op.receiver, domains);
}
......@@ -1497,8 +1497,8 @@ struct instance_type {
mp.add(messages[*inci]);
}
auto ret = mp.compute(op.output, domains);
/*MSG_DEBUG("" << ret);*/
/*MSG_DEBUG("");*/
MSG_DEBUG("" << ret);
MSG_DEBUG("");
return ret;
}
......@@ -1506,7 +1506,7 @@ struct instance_type {
compute(node_index_type n, const std::map<var_vec, genotype_comb_type>& domains) /* compute an external message (this -> external node #n) */
{
const auto& variant = variants[n];
scoped_indent _(MESSAGE("[compute " << n << "] "));
/*scoped_indent _(MESSAGE("[compute " << n << "] "));*/
/*MSG_DEBUG("" << variant);*/
clear_internal_evidence();
for (size_t i = 0; i < variant.outer_inputs.size(); ++i) {
......@@ -1527,14 +1527,14 @@ struct instance_type {
auto tmp = compute_message(op, domains);
ret.insert(ret.end(), tmp.begin(), tmp.end());
}
MSG_DEBUG("MESSAGES");
for (const auto& kv: message_index) {
MSG_DEBUG(std::setw(4) << kv.first.first << " -> " << std::setw(4) << kv.first.second << " ==" << kv.second << "== " << messages[kv.second] << std::endl);
}
MSG_DEBUG("OUTPUT");
for (const auto& t: ret) {
MSG_DEBUG("" << t << std::endl);
}
/*MSG_DEBUG("MESSAGES");*/
/*for (const auto& kv: message_index) {*/
/*MSG_DEBUG(std::setw(4) << kv.first.first << " -> " << std::setw(4) << kv.first.second << " ==" << kv.second << "== " << messages[kv.second] << std::endl);*/
/*}*/
/*MSG_DEBUG("OUTPUT");*/
/*for (const auto& t: ret) {*/
/*MSG_DEBUG("" << t << std::endl);*/
/*}*/
return ret;
}
......@@ -1577,11 +1577,12 @@ struct instance_type {
V.output_operations.emplace_back();
auto& output_operation = V.output_operations.back();
/*output_operation.output = g->variables_of(output);*/
for (variable_index_type v: g->variables_of(output)) {
if (g->var_is_output(v)) {
output_operation.output.push_back(v);
}
}
/*for (variable_index_type v: g->variables_of(output)) {*/
/*if (g->var_is_output(v)) {*/
/*output_operation.output.push_back(v);*/
/*}*/
/*}*/
output_operation.output = g->parent()->variables_of(towards);
output_operation.receiver = 0;
output_operation.emitter = output;
for (node_index_type i: g->all_nei(output)) {
......@@ -1621,12 +1622,12 @@ struct instance_type {
{
/* FIXME: anchoring should account for forests. Maybe return a map [colour] => { [size_t] => node_index_type } */
std::map<size_t, node_index_type> all_incoming;
MSG_DEBUG("Anchor points " << g->anchor_points);
/*MSG_DEBUG("Anchor points " << g->anchor_points);*/
for (const auto& ext_anchor: g->anchor_points) {
node_index_type external = ext_anchor.first;
if (!g->parent()->node_is_deleted(external)) {
MSG_DEBUG("ext " << ext_anchor.first << " anchor " << ext_anchor.second);
MSG_QUEUE_FLUSH();
/*MSG_DEBUG("ext " << ext_anchor.first << " anchor " << ext_anchor.second);*/
/*MSG_QUEUE_FLUSH();*/
all_incoming[parent->get_message_index(external, g->index_in_parent())] = ext_anchor.second;
}
}
......
......@@ -412,6 +412,7 @@ std::vector<test_descr> all_tests = {
#if 0
std::vector<pedigree_item>
filter_pedigree(const std::vector<pedigree_item>& full_pedigree, const std::vector<std::string>& inputs, const std::vector<std::string>& outputs)
{
......@@ -437,7 +438,7 @@ filter_pedigree(const std::vector<pedigree_item>& full_pedigree, const std::vect
}
return ret;
}
#endif
inline bool ends_with(std::string const & value, std::string const & ending)
......@@ -553,6 +554,7 @@ main(int argc, char** argv)
MSG_DEBUG("F4 * F5 (squeeze F4) size " << factors[100].size());
MSG_DEBUG("F4 * F5 * F6 (squeeze F4&F5) size " << factors[101].size());
#endif
#if 0
std::unique_ptr<factor_graph_type> g;
if (argc == 1) {
......@@ -639,6 +641,7 @@ main(int argc, char** argv)
}
return 0;
#endif
#if 0
msg_handler_t::set_color(true);
......@@ -670,7 +673,7 @@ main(int argc, char** argv)
MSG_DEBUG("performed " << (good + failed) << " tests, " << good << " good, " << failed << " failed.");
return failed == 0;
#endif
#if 0
#if 1
joint_variable_product_type jvp;
bn_label_type
s0 = {'a', 'a', 0, 0},
......@@ -789,6 +792,7 @@ main(int argc, char** argv)
/*std::cout << std::endl;*/
#endif
#if 0
random_test T;
if (argc > 1) {
......@@ -803,6 +807,7 @@ main(int argc, char** argv)
auto output = T.output_vars(3);
T.test_equality(output);
/*T.measure_speed(output, 1000);*/
#endif
#endif
return 0;
(void) argc; (void) argv;
......
......@@ -38,13 +38,14 @@ operator << (std::ostream& os, const std::map<bn_label_type, double>& obs)
void
dispatch_geno_probs(
const std::vector<gencomb_type>& lc,
const std::vector<label_type>& labels,
const std::map<label_type, double>& geno_probs,
size_t ind,
std::map<size_t, std::vector<double>>& state_prob)
const std::vector<gencomb_type>& lc, /* table parent states -> spawnling states */
const std::vector<label_type>& labels, /* spawnling labels */
const std::map<label_type, double>& geno_probs, /* spawnling genotype probabilities */
size_t ind, /* spawnling number */
std::map<size_t, std::vector<double>>& state_prob) /* all computed locus vectors to fetch the parents' states */
{
scoped_indent _(MESSAGE("[dispatchGP] "));
MSG_DEBUG("ind " << ind);
MSG_DEBUG("lc " << lc);
MSG_DEBUG("labels " << labels);
MSG_DEBUG("geno_probs " << geno_probs);
......@@ -164,7 +165,7 @@ job_registry = {
/*MSG_DEBUG("Pedigree" << std::endl << settings->pedigree);*/
/*MSG_DEBUG("COMPUTING FACTOR GRAPH FOR " << unique_n_alleles[n] << " ALLELES.");*/
/*factor_graph fg(settings->pedigree, unique_n_alleles[n], settings->noise);*/
auto fg = factor_graph_type::from_pedigree(settings->pedigree.items, unique_n_alleles[n], settings->input_generations, settings->output_generations);
auto fg = factor_graph_type::from_pedigree(settings->pedigree.items, unique_n_alleles[n], settings->input_generations, /*settings->output_generations*/ {});
/*MSG_DEBUG("COMPUTED FACTOR GRAPH" << std::endl << fg);*/
/*fg.finalize();*/
/*fg.save(settings->job_filename("factor-graph", unique_n_alleles[n]));*/
......@@ -290,15 +291,74 @@ job_registry = {
/*auto q = fg.build_query_operation({i.id});*/
/*marginals.insert(marginals.end(), q.begin(), q.end());*/
/*}*/
std::map<int, genotype_comb_type> marginals;
/*std::map<int, genotype_comb_type> marginals;*/
/*fg.compute_messages(evidence.begin(), evidence.end(), messages);*/
/*fg.compute_full_factor_state(messages, marginals);*/
message_type output = I->compute(0, domains);
MSG_DEBUG("OUTPUT:");
std::map<variable_index_type, std::map<label_type, double>> marginals;
for (const auto& t: output) {
MSG_DEBUG(" * " << t << std::endl);
for (const auto& e: t) {
for (const auto& k: e.keys) {
label_type l = {k.state.first, k.state.second};
marginals[k.parent][l] += e.coef;
}
}
}
MSG_DEBUG("PROJECTED MARGINALS");
for (const auto& km: marginals) {
std::stringstream ss;
ss << " * " << km.first;
for (const auto& kv: km.second) {
ss << ' ' << kv.first << '=' << kv.second;
}
MSG_DEBUG(ss.str());
}
MSG_QUEUE_FLUSH();
std::map<size_t, std::vector<double>> state_prob, output_prob;
for (size_t ind = 1; ind < settings->pedigree.tree.m_ind_number_to_node_number.size(); ++ind) {
size_t node = settings->pedigree.tree.m_ind_number_to_node_number[ind];
auto gen = settings->pedigree.get_gen(ind);
const auto& labels = gen->labels;
const auto& LC = settings->pedigree.LC[node];
int variable = settings->pedigree.ind2id(ind);
dispatch_geno_probs(LC, labels, marginals[variable], node, state_prob);
if (std::find(settings->output_generations.begin(), settings->output_generations.end(), gen->name) != settings->output_generations.end()) {
output_prob[node] = state_prob[node];
}
}
MSG_DEBUG("LOCUS VECTORS");
for (const auto& kv: state_prob) {
MSG_DEBUG(" * " << kv.first << ' ' << kv.second);
}
ofile ofs(settings->job_filename("compute-LV", mark));
rw_base() (ofs, output_prob);
return true;
/*for (const auto& pop_obs: settings->observed_mark) {*/
/*MSG_DEBUG("Obs in generation " << pop_obs.first);*/
/*size_t n = 0;*/
/*for (size_t i: settings->pedigree.individuals_by_generation_name.find(pop_obs.first)->second) {*/
/*size_t node = settings->pedigree.ind(i);*/
/*size_t igen = settings->pedigree.node_generations[node];*/
/*const auto& gen = settings->pedigree.generations[igen];*/
/*gen->labels;*/
/*}*/
/*}*/
/**/
/*std::map<size_t, std::map<label_type, double>> prob_by_ind;*/
/*for (const auto& kv: marginals) {*/
/*genotype_comb_type::key_type key = kv.first.keys[0];*/
/*prob_by_ind[key.parent][{key.state.first, key.state.second}] += kv.second;*/
/*}*/
/*std::map<size_t, std::vector<double>> state_prob;*/
#if 0
MSG_DEBUG("MARGINALS:");
for (const auto& t: output) {
......@@ -442,7 +502,7 @@ job_registry = {
rw_base() (ofs, state_prob);
return true;
#else
/*#else*/
return false;
#endif
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment