Commit eae0fd1c authored by Damien Leroux's avatar Damien Leroux
Browse files

New implementation of graph_type&co.

This one is good, the old code was so bad and ugly.
parent 91141d66
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -39,6 +39,7 @@ enum msg_channel { Out, Err, Log };
struct message_struc {
msg_channel channel;
std::string message;
message_struc(msg_channel c, std::string&& msg) : channel(c), message(std::move(msg)) {}
};
typedef std::shared_ptr<message_struc> message_handle;
......
......@@ -75,8 +75,8 @@ $(OBJ):%.o: %.cc
test_graph: test_graph.cc ../static_data.o ../../include/bayes/graphnode.h
$C -O0 -ggdb $< ../static_data.o -o $@
generalized_product.debug: generalized_product.cc ../static_data.o ../../include/bayes/generalized_product.h ../../include/bayes/graphnode.h
generalized_product.debug: generalized_product.cc ../static_data.o ../../include/bayes/generalized_product.h ../../include/bayes/graphnode2.h ../../include/bayes/graphnode_base.h
$C -pg -O0 -ggdb $< ../static_data.o -o $@
generalized_product: generalized_product.cc ../static_data.o ../../include/bayes/generalized_product.h ../../include/bayes/graphnode.h
generalized_product: generalized_product.cc ../static_data.o ../../include/bayes/generalized_product.h ../../include/bayes/graphnode2.h ../../include/bayes/graphnode_base.h
$C $< ../static_data.o -o $@
#define SPELL_UNSAFE_OUTPUT
#include "bayes/generalized_product.h"
#include "graphnode.h"
/*#include "graphnode.h"*/
#include "graphnode2.h"
#include <algorithm>
#include <random>
......@@ -450,6 +451,8 @@ inline bool ends_with(std::string const & value, std::string const & ending)
int
main(int argc, char** argv)
{
msg_handler_t::set_color(true);
#if 0
std::map<std::vector<int>, genotype_comb_type> domains;
std::map<int, genotype_comb_type> factors;
......@@ -550,10 +553,10 @@ main(int argc, char** argv)
MSG_DEBUG("F4 * F5 (squeeze F4) size " << factors[100].size());
MSG_DEBUG("F4 * F5 * F6 (squeeze F4&F5) size " << factors[101].size());
#endif
std::unique_ptr<graph_type> g;
std::unique_ptr<factor_graph_type> g;
if (argc == 1) {
g = graph_type::from_pedigree(read_csv("random.ped", ';'), 2, {}, {}, "boiboite");
g = factor_graph_type::from_pedigree(read_csv("random.ped", ';'), 2, {}, {}, "boiboite");
g->to_image("boiboite", "png");
/*g.dump_active();*/
......@@ -561,7 +564,7 @@ main(int argc, char** argv)
/*g.to_image("boiboite12", "png");*/
/*g.dump_active();*/
g = graph_type::from_pedigree(read_csv("/home/daleroux/devel/spel/sample-data/data_magic/pedigree_clean_DH.csv", ';'), 1, {"Cervil", "Levovil", "Criollo", "Stupicke", "Plovdiv", "LA1420", "Ferum", "LA0147", "ICS3"}, {"ICS3"});
g = factor_graph_type::from_pedigree(read_csv("/home/daleroux/devel/spel/sample-data/data_magic/pedigree_clean_DH.csv", ';'), 1, {"Cervil", "Levovil", "Criollo", "Stupicke", "Plovdiv", "LA1420", "Ferum", "LA0147", "ICS3"}, {"ICS3"});
g->to_image("magic-dh", "png");
/*g = graph_type::from_pedigree(read_csv("/home/daleroux/devel/spel/sample-data/data_magic/pedigree_clean-norepeat.csv", ';'), 1, {}, {});*/
......@@ -597,17 +600,20 @@ main(int argc, char** argv)
/*g.to_image(MESSAGE(prefix << "-pre-optim"), "png");*/
/*g.optimize();*/
/*g.to_image(prefix, "png");*/
g = graph_type::from_pedigree(filter_pedigree(read_csv(pedfile, ';'), in, out), 2, in, out, MESSAGE("debug-" << prefix));
g->to_image(MESSAGE(prefix << "-filtered-pre-optim"), "png");
g = factor_graph_type::from_pedigree(filter_pedigree(read_csv(pedfile, ';'), in, out), 2, in, out, MESSAGE("debug-" << prefix));
/*g->build_subgraphs();*/
g->dump();
/*g->to_image(MESSAGE(prefix << "-filtered-pre-optim"), "png");*/
/*g->dump_active();*/
g->optimize();
MSG_DEBUG("Creating png...");
/*g->optimize();*/
/*MSG_DEBUG("Creating png...");*/
g->to_image(MESSAGE(prefix << "-filtered"), "png");
MSG_DEBUG("Creating instance...");
auto I = g->instance();
ofile of(MESSAGE(prefix << "-instance.data"));
I->file_io(of);
rw_comb<int, bn_label_type>()(of, g->domains);
/*MSG_DEBUG("Creating instance...");*/
/*auto I = g->instance();*/
/*ofile of(MESSAGE(prefix << "-instance.data"));*/
/*I->file_io(of);*/
/*rw_comb<int, bn_label_type>()(of, g->domains);*/
#if 0
} else {
std::unique_ptr<graph_type::instance_type> I(new graph_type::instance_type());
std::map<var_vec, genotype_comb_type> domains;
......@@ -624,6 +630,7 @@ main(int argc, char** argv)
for (const auto& t: ret) {
MSG_DEBUG("" << t);
}
#endif
}
}
......
......@@ -91,6 +91,15 @@ get_domain(std::map<var_vec, genotype_comb_type>& domains, variable_index_type v
genotype_comb_type
build_evidence(variable_index_type id, const std::map<bn_label_type, double>& obs)
{
genotype_comb_type tmp;
for (const auto& kv: obs) {
tmp.m_combination.emplace_back(genotype_comb_type::element_type{{{id, kv.first}}, kv.second});
}
return tmp;
}
std::map<std::string, std::pair<count_jobs_type, job_type>>
job_registry = {
......@@ -160,6 +169,7 @@ job_registry = {
/*fg.finalize();*/
/*fg.save(settings->job_filename("factor-graph", unique_n_alleles[n]));*/
fg->optimize();
fg->dump_active();
auto I = fg->instance();
ofile of(settings->job_filename("factor-graph", unique_n_alleles[n]));
I->file_io(of);
......@@ -204,7 +214,7 @@ job_registry = {
obs_vec = obs_spec.score(obsdat);
genotype_comb_type::key_type key;
key.parent = n;
for (const auto& label: fg.get_domain(n)) {
for (const auto& label: get_domain(domains, n)) {
ind_obs[label] = 0;
}
MSG_DEBUG("i " << i);
......@@ -221,7 +231,7 @@ job_registry = {
MSG_DEBUG("obs for #" << n << " '" << pop_obs.observations.data.find(settings->marker_names[mark])->second[i] << "': " << obs_vec);
if (obs_spec.domain == ODAllele) {
for (const auto& obs: obs_vec) {
for (const auto& label: fg.get_domain(n)) {
for (const auto& label: get_domain(domains, n)) {
if (label.first_allele == allele_obs_to_idx[obs.first] && label.second_allele == allele_obs_to_idx[obs.second]) {
ind_obs[label] = 1;
}
......@@ -229,7 +239,7 @@ job_registry = {
}
} else {
for (const auto& obs: obs_vec) {
for (const auto& label: fg.get_domain(n)) {
for (const auto& label: get_domain(domains, n)) {
if (label.first == obs.first && label.second == obs.second) {
ind_obs[label] = 1;
}
......@@ -242,7 +252,7 @@ job_registry = {
/*auto instance = fg.instance();*/
/*instance.clear_evidence();*/
std::vector<std::shared_ptr<message_type>> evidence;
/*std::vector<std::shared_ptr<message_type>> evidence;*/
for (const auto& pop_obs: settings->observed_mark) {
/*MSG_DEBUG("Obs in generation " << pop_obs.first);*/
......@@ -250,9 +260,10 @@ job_registry = {
for (size_t i: settings->pedigree.individuals_by_generation_name.find(pop_obs.first)->second) {
int id = settings->pedigree.ind2id(i);
auto obs = make_obs(pop_obs.second, n, id);
evidence.emplace_back(std::make_shared<message_type>(message_type{fg.build_evidence(id, obs)}));
/*evidence.emplace_back(std::make_shared<message_type>(message_type{build_evidence(id, obs)}));*/
I->add_evidence(id, build_evidence(id, obs));
MSG_DEBUG("observations for individual #" << n << " (id #" << id << ") = " << obs);
MSG_DEBUG("id=" << id << " id=" << id << " evidence " << evidence);
/*MSG_DEBUG("id=" << id << " id=" << id << " evidence " << evidence);*/
double accum = 0;
/*for (const auto& kv: obs) {*/
/*evidence.force_set({{node, kv.first}}, kv.second);*/
......@@ -279,9 +290,16 @@ job_registry = {
/*marginals.insert(marginals.end(), q.begin(), q.end());*/
/*}*/
std::map<int, genotype_comb_type> marginals;
fg.compute_messages(evidence.begin(), evidence.end(), messages);
fg.compute_full_factor_state(messages, marginals);
/*fg.compute_messages(evidence.begin(), evidence.end(), messages);*/
/*fg.compute_full_factor_state(messages, marginals);*/
message_type output = I->compute(0, domains);
MSG_DEBUG("MARGINALS:");
for (const auto& t: output) {
var_vec varz = get_parents(t);
for (variable_index_type v: varz) {
marginals[v] = project(t, {v}, {});
}
}
#if 0
std::vector<std::vector<double>> gen_dispatch;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment