Commit 7fb21b53 authored by Damien Leroux's avatar Damien Leroux
Browse files

Moved stuff to attic. Started new impl for BN.

parent 19e8ed44
This diff is collapsed.
#ifndef _SPEL_BAYES_FACTOR_VAR_H_
#define _SPEL_BAYES_FACTOR_VAR_H_
#include <cmath>
#include "pedigree.h"
/*
 * Sparse message: associates a double with each key list. Key lists that are
 * not stored in the underlying map read back as a configurable default value.
 */
struct bn_message_type {
    typedef std::map<genotype_comb_type::key_list, double>::iterator iterator;
    typedef std::map<genotype_comb_type::key_list, double>::const_iterator const_iterator;

    /* Empty message whose default value is 0. */
    bn_message_type() : m_map(), m_default_val(0) {}
    /* Empty message with the given default value.
     * Intentionally not explicit: bn_factor_interface_type builds its message
     * buffers from plain doubles. */
    bn_message_type(double default_val) : m_map(), m_default_val(default_val) {}

    /*double&*/
    /*operator [] (const genotype_comb_type::key_list& keys) { return m_map[keys]; }*/

    /* Store d for keys, unless d equals the default value (in which case
     * nothing is written and any previously stored entry is left as is). */
    void
    set(const genotype_comb_type::key_list& keys, double d) { if (d != m_default_val) { m_map[keys] = d; } }

    /*
     * Add d to the value stored for keys.
     * When keys is not stored yet, a new entry holding d is created, except
     * when d is 0 (no entry is created then).
     */
    void
    accumulate(const genotype_comb_type::key_list& keys, double d)
    {
        auto it = m_map.find(keys);
        if (it != m_map.end()) {
            it->second += d;
        } else if (d != 0) {
            /* The previous implementation fell through to `it->second += d`
             * with `it` still equal to end() after the insertion. */
            m_map.emplace(keys, d);
        }
    }

    /* Value stored for keys, or the default value when keys is not stored. */
    double
    operator [] (const genotype_comb_type::key_list& keys) const
    {
        auto it = m_map.find(keys);
        if (it == m_map.end()) {
            return m_default_val;
        }
        return it->second;
    }

    double
    default_val() const { return m_default_val; }

    void
    default_val(double p) { m_default_val = p; }

    /*
     * Sum of absolute differences between this message and other, taken over
     * the keys stored in THIS message only (keys that exist only in other are
     * not visited).
     */
    double
    delta(const bn_message_type& other) const
    {
        double accum = 0;
        for (const auto& kv: m_map) {
            /* std::abs selects the floating-point overload; the unqualified
             * abs could resolve to the integer one and truncate. */
            accum += std::abs(kv.second - other[kv.first]);
        }
        return accum;
    }

    /* Iteration over the explicitly stored (key list, value) entries. */
    iterator begin() { return m_map.begin(); }
    iterator end() { return m_map.end(); }
    const_iterator begin() const { return m_map.begin(); }
    const_iterator end() const { return m_map.end(); }
    const_iterator cbegin() const { return m_map.cbegin(); }
    const_iterator cend() const { return m_map.cend(); }

    /* Prints "{key=value key=value ... default=D}". */
    friend
    std::ostream&
    operator << (std::ostream& os, const bn_message_type& msg)
    {
        os << '{';
        for (const auto& kv: msg) {
            os << kv.first << '=' << kv.second << ' ';
        }
        return os << "default=" << msg.default_val() << '}';
    }

private:
    std::map<genotype_comb_type::key_list, double> m_map;
    double m_default_val;
};
struct bn_factor_type;
struct bn_factor_interface_type;
struct bn_factor_interface_type {
bn_factor_interface_type(
const std::vector<size_t>& variables,
std::shared_ptr<bn_factor_type> f1,
std::shared_ptr<bn_factor_type> f2)
: m_variables(variables)
, m_f1(f1), m_f2(f2)
, m_msg_to_f1({1., 1.}), m_msg_to_f2({1., 1.})
{}
double
delta() const
{
return m_msg_to_f1[0].delta(m_msg_to_f1[1]) + m_msg_to_f2[0].delta(m_msg_to_f2[1]);
}
const bn_factor_type&
get_target_from(const bn_factor_type* source) const
{
return source == m_f1.get() ? *m_f2 : *m_f1;
}
const bn_message_type&
get_message_to(const bn_factor_type* dest, size_t buffer_index) const
{
return dest == m_f1.get() ? m_msg_to_f1[buffer_index] : m_msg_to_f2[buffer_index];
}
void
update_messages(const bn_message_type& observations, size_t buffer_index);
const std::vector<size_t>&
variables() const { return m_variables; }
bool
operator < (const bn_factor_interface_type& other) const
{
return m_variables < other.m_variables;
}
friend
std::ostream&
operator << (std::ostream& os, const bn_factor_interface_type& interf)
{
return os << "[@" << interf.m_f1
<< " (" << interf.m_msg_to_f1[0] << ", " << interf.m_msg_to_f1[1]
<< ") <--(" << interf.variables() << ")--> ("
<< interf.m_msg_to_f2[0] << ", " << interf.m_msg_to_f2[1]
<< ") @" << interf.m_f2 << ']';
}
private:
std::vector<size_t> m_variables;
std::shared_ptr<bn_factor_type> m_f1, m_f2;
bn_message_type m_msg_to_f1[2], m_msg_to_f2[2];
};
/*
 * A factor: a joint probability table over a set of variables, plus the
 * interfaces that connect it to the other factors it shares variables with.
 */
struct bn_factor_type {
    bn_factor_type() : m_variables(), m_joint_prob_table(), m_interfaces() {}

    /* Build a factor from a copy of joint; the variable list is whatever
     * get_parents() returns for the stored table. */
    bn_factor_type(const genotype_comb_type& joint)
        : m_variables(), m_joint_prob_table(joint), m_interfaces()
    {
        m_variables = get_parents(m_joint_prob_table);
    }

    /* Same as above, but takes ownership of joint. */
    bn_factor_type(genotype_comb_type&& joint)
        : m_variables(), m_joint_prob_table(std::move(joint)), m_interfaces()
    {
        m_variables = get_parents(m_joint_prob_table);
    }

    /* Member-wise copy/move is all that is needed. Declaring the whole set as
     * defaulted also makes the type assignable (hand-written copy/move
     * constructors alone suppress both assignment operators). */
    bn_factor_type(const bn_factor_type&) = default;
    bn_factor_type(bn_factor_type&&) = default;
    bn_factor_type& operator = (const bn_factor_type&) = default;
    bn_factor_type& operator = (bn_factor_type&&) = default;

    /*
     * Compute the message this factor sends through the given interface.
     *
     * For each entry of the joint probability table, the entry coefficient is
     * multiplied by the observation of each of its keys and by the incoming
     * message (read from buffer buffer_index) of every OTHER interface of this
     * factor. The product is accumulated under the key
     * `e.keys % interface->variables()`.
     * The stored values are finally scaled so that they sum to 1, unless the
     * total is 0, in which case they are left untouched.
     *
     * Does not modify the factor.
     */
    bn_message_type
    compute_message_for(const bn_factor_interface_type* interface, const bn_message_type& observations, size_t buffer_index) const
    {
        bn_message_type ret;
        double accum = 0;
        for (const auto& e: m_joint_prob_table) {
            genotype_comb_type::key_list output_key = e.keys % interface->variables();
            double prob = e.coef;
            for (const auto& key: e.keys) { /* FIXME: all keys or all keys BUT the output ones? */
                prob *= observations[key];
            }
            for (const auto& i: m_interfaces) {
                if (i.get() == interface) {
                    /* never feed an interface with its own incoming message */
                    continue;
                }
                genotype_comb_type::key_list interface_key = e.keys % i->variables();
                prob *= i->get_message_to(this, buffer_index)[interface_key];
            }
            ret.accumulate(output_key, prob);
            accum += prob;
        }
        if (accum != 0) {
            accum = 1. / accum;
            for (auto& kv: ret) {
                kv.second *= accum;
            }
        }
        return ret;
    }

    /* Variables present in both this factor and other.
     * Relies on std::set_intersection, so both variable lists must be sorted. */
    std::vector<size_t>
    common_variables(const bn_factor_type& other) const
    {
        std::vector<size_t> ret(std::min(m_variables.size(), other.m_variables.size()));
        auto end = std::set_intersection(m_variables.begin(), m_variables.end(), other.m_variables.begin(), other.m_variables.end(), ret.begin());
        ret.resize(end - ret.begin());
        return ret;
    }

    /* Prints the factor address, its variables, its table and each interface. */
    friend
    std::ostream&
    operator << (std::ostream& os, const bn_factor_type& factor)
    {
        os << "FACTOR @" << (&factor) << " on variables {" << factor.m_variables << '}' << std::endl;
        os << "joint prob. table: " << factor.m_joint_prob_table << std::endl;
        os << "interfaces:" << std::endl;
        for (const auto& i: factor.interfaces()) {
            os << (*i) << std::endl;
        }
        return os;
    }

    const std::vector<size_t>&
    variables() const { return m_variables; }

    const std::vector<std::shared_ptr<bn_factor_interface_type>>&
    interfaces() const { return m_interfaces; }

    /* Register an interface this factor takes part in. */
    void
    add_interface(std::shared_ptr<bn_factor_interface_type> interf)
    {
        m_interfaces.emplace_back(interf);
    }

private:
    std::vector<size_t> m_variables;
    genotype_comb_type m_joint_prob_table;
    std::vector<std::shared_ptr<bn_factor_interface_type>> m_interfaces;
};
/*
 * Recompute the messages flowing through this interface in both directions.
 * Each factor reads its incoming messages from buffer buffer_index; the new
 * messages are written to the other buffer (!buffer_index, i.e. 1 when
 * buffer_index is 0 and 0 otherwise).
 *
 * Declared inline because this definition lives in a header: without it,
 * including the header from more than one translation unit produces
 * multiple definitions of the same function.
 */
inline
void
bn_factor_interface_type::update_messages(const bn_message_type& observations, size_t buffer_index)
{
    m_msg_to_f2[!buffer_index] = m_f1->compute_message_for(this, observations, buffer_index);
    m_msg_to_f1[!buffer_index] = m_f2->compute_message_for(this, observations, buffer_index);
}
struct compute_labels {
bn_label_type
find_label(size_t n, const genotype_comb_type::element_type& labels)
{
auto it
= std::find_if(labels.keys.begin(), labels.keys.end(),
[=] (const genotype_comb_type::key_type& k) { return k.parent == n; }
);
if (it == labels.keys.end()) {
MSG_ERROR("COULDN'T FIND LABEL FOR " << n << " IN " << labels, "");
MSG_QUEUE_FLUSH();
return {};
}
return it->state;
}
bn_label_type
operator () (const pedigree_tree_type& tree, size_t n, const genotype_comb_type::element_type& labels, const std::vector<bool>& recompute)
{
if (tree.get_p2(n) == NONE) {
/* gamete or ancestor */
if (tree.get_p1(n) == NONE) {
/* ancestor */
return find_label(n, labels);
} else {
auto gl = find_label(n, labels);
auto sub = operator () (tree, tree.get_p1(n), labels, recompute);
if (gl.first == GAMETE_L) {
return {sub.first, 0, sub.first_allele, 0};
} else {
return {sub.second, 0, sub.second_allele, 0};
}
}
} else if (recompute[n]) {
auto subl = operator () (tree, tree.get_p1(n), labels, recompute);
auto subr = operator () (tree, tree.get_p2(n), labels, recompute);
return {subl.first, subr.first, subl.first_allele, subr.first_allele};
} else {
return find_label(n, labels);
}
}
std::vector<bn_label_type>
operator () (const pedigree_tree_type& tree, size_t n, const genotype_comb_type& comb, const std::vector<bool>& recompute)
{
std::vector<bn_label_type> ret;
ret.reserve(comb.m_combination.size());
for (const auto& e: comb) {
ret.emplace_back(operator () (tree, n, e, recompute));
}
return ret;
}
/*static*/
/*genotype_comb_type*/
/*make_comb(const pedigree_tree_type& tree, size_t n, const genotype_comb_type& comb)*/
/*{*/
/*return state_to_combination(n, compute_labels()(tree, n, comb));*/
/*}*/
/*static*/
/*genotype_comb_type*/
/*add_labels(const pedigree_tree_type& tree, size_t n, const genotype_comb_type& comb)*/
/*{*/
/*auto labcomb = make_comb(tree, n, comb);*/
/*return hadamard(labcomb, comb);*/
/*}*/
};
struct factor_graph {
std::map<size_t, std::vector<bn_label_type>> variable_domains;
std::vector<std::shared_ptr<bn_factor_type>> factors;
std::vector<std::shared_ptr<bn_factor_interface_type>> interfaces;
bn_message_type observations;
size_t buffer_index;
factor_graph(const pedigree_type& ped)
: variable_domains(), factors(), interfaces(), buffer_index(0)
{
compute_factors_and_domains(ped);
compute_interfaces();
}
friend
std::ostream&
operator << (std::ostream& os, const factor_graph& fg)
{
os << "FACTOR GRAPH @" << (&fg) << std::endl;
os << "Variable domains:" << std::endl;
/*for (size_t i = 1; i < fg.variable_domains.size(); ++i) {*/
/*os << " - " << i << ": " << fg.variable_domains[i] << std::endl;*/
for (const auto& kv: fg.variable_domains) {
os << " - " << kv.first << ": " << kv.second << std::endl;
}
for (const auto& f: fg.factors) {
os << (*f) << std::endl;
}
os << "Ordered interfaces:" << std::endl;
for (const auto& i: fg.interfaces) {
os << (*i) << std::endl;
}
return os;
}
genotype_comb_type
compute_factor_rec(const pedigree_type& ped, int n0, int n, const std::vector<bool>& recompute, std::vector<size_t>& gametes, std::vector<bool>& visited)
{
/*if (visited[n]) {*/
/*return {1.};*/
/*}*/
/*visited[n] = true;*/
if (ped.tree[n].is_ancestor()) {
std::vector<bn_label_type> labels;
/* FIXME allow heterozygous ancestors allele-wise? */
if (n == n0) {
char letter = ped.ancestor_letters.find(n)->second;
for (size_t i = 0; i < ped.n_alleles; ++i) {
labels.emplace_back(letter, letter, i, i);
}
variable_domains[n] = labels;
}
return state_to_combination((size_t) n, variable_domains[n]);
} else if (ped.tree[n].is_gamete()) {
static std::vector<bn_label_type> label_g = {{GAMETE_L, 0, 0, 0}, {GAMETE_R, 0, 0, 0}};
genotype_comb_type G = state_to_combination((size_t) n, label_g) * .5;
gametes.push_back((size_t) n);
return kronecker(compute_factor_rec(ped, n0, ped.tree.get_p1(n), recompute, gametes, visited), G);
} else if (recompute[n]) {
auto tmp = kronecker(
compute_factor_rec(ped, n0, ped.tree.get_p1(n), recompute, gametes, visited),
compute_factor_rec(ped, n0, ped.tree.get_p2(n), recompute, gametes, visited));
auto label_per_state = compute_labels()(ped.tree, n, tmp, recompute);
if (n == n0) {
std::set<bn_label_type> uniq_sorted(label_per_state.begin(), label_per_state.end());
variable_domains[n].assign(uniq_sorted.begin(), uniq_sorted.end());
}
MSG_DEBUG("intermediary " << tmp);
auto labels = state_to_combination((size_t) n, label_per_state);
MSG_DEBUG("adding labels " << labels);
return hadamard(labels, tmp);
} else {
return state_to_combination((size_t) n, variable_domains[n]);
}
}
genotype_comb_type
compute_raw_factor(const pedigree_type& ped, int n)
{
std::vector<size_t> gametes;
std::vector<bool> visited(n + 1, false);
auto raw_fac = compute_factor_rec(ped, n, n, ped.tree.m_must_recompute[n], gametes, visited);
std::sort(gametes.begin(), gametes.end());
MSG_DEBUG("raw factor " << raw_fac);
auto clean = fold(sum_over(raw_fac, gametes));
MSG_DEBUG("cleaned factor " << clean);
return clean;
}
void
compute_factors_and_domains(const pedigree_type& ped)
{
for (int ind_node: ped.tree.m_ind_number_to_node_number) {
MSG_DEBUG("ind_node " << ind_node);
MSG_QUEUE_FLUSH();
if (ind_node == NONE) {
MSG_DEBUG("...not an individual...");
continue;
}
if (ped.tree[ind_node].is_ancestor()) {
MSG_DEBUG("...ancestor...");
(void) compute_raw_factor(ped, ind_node); /* to fill the variable domain */
continue;
}
auto raw_factor = compute_raw_factor(ped, ind_node);
if (!ped.tree[ind_node].is_ancestor()) {
const auto& recomp = ped.tree.m_must_recompute[ind_node];
int p1 = ped.tree.get_p1(ped.tree.get_p1(ind_node));
int p2 = ped.tree.get_p1(ped.tree.get_p2(ind_node));
if (recomp[p1] || recomp[p2]) {
/* create factor for joint probability of p1 and p2 */
factors.emplace_back(std::make_shared<bn_factor_type>(fold(sum_over(raw_factor, {(size_t) ind_node}))));
/* create factor for crossing */
factors.emplace_back(std::make_shared<bn_factor_type>(fold(sum_over_dual(raw_factor, {(size_t) ind_node, (size_t) p1, (size_t) p2}))));
} else {
/* otherwise, their product probabilities will be sufficient */
factors.emplace_back(std::make_shared<bn_factor_type>(raw_factor));
}
}
}
/*
* cleanup
* - a factor is dismissed if its variables are all included in at least another factor's variables
*/
MSG_DEBUG("BEFORE CLEANUP: " << factors.size() << " FACTORS.");
std::vector<bool> factor_included(factors.size(), false);
size_t total = factors.size();
for (size_t fsmall = 0; fsmall < factors.size(); ++fsmall) {
if (factor_included[fsmall]) { continue; }
for (size_t fbig = 0; fbig < factors.size(); ++fbig) {
if (factor_included[fbig]) { continue; }
if (fbig == fsmall) { continue; }
if (factors[fbig]->variables() == factors[fsmall]->variables()) {
factor_included[std::max(fsmall, fbig)] = true;
continue;
}
if (std::includes(factors[fbig]->variables().begin(), factors[fbig]->variables().end(),
factors[fsmall]->variables().begin(), factors[fsmall]->variables().end())) {
total -= !factor_included[fsmall]; /* decrease if not already detected as an included-inside-another-factor factor. */
factor_included[fsmall] = true;
MSG_DEBUG("FACTOR #" << fsmall << ' ' << factors[fsmall]->variables() << " INCLUDED IN #" << fbig << ' ' << factors[fbig]->variables());
}
}
}
std::vector<std::shared_ptr<bn_factor_type>> tmp_factors;
tmp_factors.reserve(total);
for (size_t i = 0; i < factor_included.size(); ++i) {
if (!factor_included[i]) {
tmp_factors.emplace_back(factors[i]);
}
}
factors.swap(tmp_factors);
MSG_DEBUG("AFTER CLEANUP: " << factors.size() << " FACTORS.");
}
void
compute_interfaces()
{
for (size_t f1 = 0; f1 < factors.size(); ++f1) {
for (size_t f2 = f1 + 1; f2 < factors.size(); ++f2) {
std::vector<size_t> common = factors[f1]->common_variables(*factors[f2]);
if (common.size()) {
auto interf = std::make_shared<bn_factor_interface_type>(common, factors[f1], factors[f2]);
interfaces.emplace_back(interf);
factors[f1]->add_interface(interf);
factors[f2]->add_interface(interf);
}
}
}
std::sort(
interfaces.begin(), interfaces.end(),
[](std::shared_ptr<bn_factor_interface_type> i1, std::shared_ptr<bn_factor_interface_type> i2)
{
return *i1 < *i2;
});
}
};
#endif
......@@ -10,7 +10,7 @@
#include "eigen.h"
#include "error.h"
#include "generation_rs_fwd.h"
/*#include "generation_rs_fwd.h"*/
#include "input/read_trait.h"
/** FOURCC **/
......@@ -95,6 +95,23 @@ int read_int(std::ifstream& ifs)
return ret;
}
/** CHAR **/
/* Write the single character sz to ofs as one raw byte. */
inline
void write_char(std::ofstream& ofs, char sz)
{
    const char* byte = &sz;
    ofs.write(byte, sizeof(char));
}
/*
 * Read one raw byte from ifs and return it.
 * When the read fails (stream not open, end of file, ...), the stream's error
 * state is set by read() and '\0' is returned, rather than an indeterminate
 * value.
 */
inline
char read_char(std::ifstream& ifs)
{
    char ret = 0;
    ifs.read(&ret, sizeof ret);
    return ret;
}
#if 0
/** FAST_POLYNOM **/
inline
......@@ -200,7 +217,7 @@ void read_genomatrix(std::ifstream& ifs, GenoMatrix& mat)
/** GENERATION_RS **/
inline void write_generation_rs(std::ofstream& ofs, const generation_rs* gen)
inline void write_geno_matrix(std::ofstream& ofs, const geno_matrix* gen)
{
write_fourcc(ofs, "SGRS");
write_str(ofs, gen->name);
......@@ -212,7 +229,7 @@ inline void write_generation_rs(std::ofstream& ofs, const generation_rs* gen)
}
inline
generation_rs* read_generation_rs(std::ifstream& ifs)
geno_matrix* read_geno_matrix(std::ifstream& ifs)
{
if (check_fourcc(ifs, "SGRS")) {
MSG_ERROR("File is not valid or has been corrupted", "");
......@@ -220,7 +237,7 @@ generation_rs* read_generation_rs(std::ifstream& ifs)
/*MSG_DEBUG("pouet 1"); MSG_QUEUE_FLUSH();*/
std::string name = read_str(ifs);
/*MSG_DEBUG("pouet 2"); MSG_QUEUE_FLUSH();*/
generation_rs* ret = generation_rs::blank(name);
geno_matrix* ret = geno_matrix::blank(name);
/*MSG_DEBUG("pouet 3"); MSG_QUEUE_FLUSH();*/
size_t n_p = read_size(ifs);
/*MSG_DEBUG("Have " << n_p << " processes"); MSG_QUEUE_FLUSH();*/
......@@ -236,6 +253,8 @@ generation_rs* read_generation_rs(std::ifstream& ifs)
ret->precompute();
return ret;
}
#endif