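// factor_var3.h
// (The web-viewer navigation text "Skip to content / Snippets Groups Projects /
// Newer Older" was captured together with this file; it is not part of the source.)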
#ifndef _SPEL_BAYES_FACTOR_VAR_H_
#define _SPEL_BAYES_FACTOR_VAR_H_


#include "pedigree.h"



struct bn_message_type {
    typedef std::map<genotype_comb_type::key_list, double>::iterator iterator;
    typedef std::map<genotype_comb_type::key_list, double>::const_iterator const_iterator;
    bn_message_type() : m_map(), m_default_val(0) {}
    bn_message_type(double default_val) : m_map(), m_default_val(default_val) {}

    /*double&*/
        /*operator [] (const genotype_comb_type::key_list& keys) { return m_map[keys]; }*/

    void
        set(const genotype_comb_type::key_list& keys, double d) { if (d != m_default_val) { m_map[keys] = d; } }

    void
        accumulate(const genotype_comb_type::key_list& keys, double d)
        {
            auto it = m_map.find(keys);
            if (it == m_map.end()) {
                if (d == 0) {
                    return;
                } else {
                    m_map.emplace(keys, d);
                }
            }
            it->second += d;
        }


    double
        operator [] (const genotype_comb_type::key_list& keys) const
        {
            auto it = m_map.find(keys);
            if (it == m_map.end()) {
                return m_default_val;
            }
            return it->second;
        }

    double
        default_val() const { return m_default_val; }
    void
        default_val(double p) { m_default_val = p; }

    double
        delta(const bn_message_type& other) const
        {
            /*MSG_DEBUG_INDENT_EXPR("[delta] ");*/
            double accum = 0;
            for (const auto& kv: m_map) {
                accum += fabs(kv.second - other[kv.first]);
                /*MSG_DEBUG("on " << kv.first << ", " << kv.second << "; other[" << kv.first << "] = " << other[kv.first] << "; accum = " << accum);*/
            }
            for (const auto& kv: other.m_map) {
                if (m_map.find(kv.first) == m_map.end()) {
                    accum += fabs(default_val() - kv.second);
                    /*MSG_DEBUG("on " << kv.first << ", " << kv.second << "; accum = " << accum);*/
                }
            /*MSG_DEBUG("delta(" << (*this) << ", " << other << ") = " << accum);*/
            /*MSG_DEBUG_DEDENT;*/
            return accum;
        }

    iterator begin() { return m_map.begin(); }
    iterator end() { return m_map.end(); }
    const_iterator begin() const { return m_map.begin(); }
    const_iterator end() const { return m_map.end(); }
    const_iterator cbegin() const { return m_map.cbegin(); }
    const_iterator cend() const { return m_map.cend(); }

    void
        extract_variable(size_t var, const std::vector<bn_label_type>& domain, bn_message_type& output) const
        {
            std::map<bn_label_type, bool> visited;
            for (const auto& kv: m_map) {
                for (const auto& key: kv.first) {
                    if (key.parent == var) {
                        genotype_comb_type::key_list kl(key);
                        output.m_map[kl] += kv.second;
                        visited[key.state] = true;
                        break;
                    } else if (key.parent > var) {
                        break;
                    }
                }
            }
            for (const bn_label_type& label: domain) {
                if (!visited[label]) {
                    genotype_comb_type::key_list kl({var, label});
                    output.m_map[kl] = default_val();
                }
            }
        }

    void
        clear() { m_map.clear(); }

    friend
        std::ostream&
        operator << (std::ostream& os, const bn_message_type& msg)
        {
            os << '{';
            for (const auto& kv: msg) {
                os << kv.first << '=' << kv.second << ' ';
            }
            return os << "default=" << msg.default_val() << '}';
        }

private:
    std::map<genotype_comb_type::key_list, double> m_map;
    double m_default_val;
};


/* Forward declarations: interfaces hold pointers to factors, and factors hold
 * pointers to interfaces. */
struct bn_factor_type;
struct bn_factor_interface_type;


/**
 * Link between two factors (f1 and f2) over a list of variables.
 *
 * For each direction the interface keeps two message buffers, selected by a
 * buffer index (0 or 1); all four buffers start empty with default value 1.
 *
 * NOTE(review): the interface owns its factors through shared_ptr and
 * bn_factor_type stores shared_ptr to its interfaces, so the two keep each
 * other alive -- confirm this cycle is intended.
 */
struct bn_factor_interface_type {
    bn_factor_interface_type(
            const std::vector<size_t>& variables,
            std::shared_ptr<bn_factor_type> f1,
            std::shared_ptr<bn_factor_type> f2)
        : m_variables(variables)
        , m_f1(f1), m_f2(f2)
        /* Plain brace-initialization of the arrays; the former "({1., 1.})"
         * spelling is not accepted by every compiler. */
        , m_msg_to_f1{1., 1.}, m_msg_to_f2{1., 1.}
    {}

    /** Sum, over both directions, of the difference between buffer 0 and buffer 1. */
    double
        delta() const
        {
            return m_msg_to_f1[0].delta(m_msg_to_f1[1]) + m_msg_to_f2[0].delta(m_msg_to_f2[1]);
        }

    /** Returns the factor on the other side of source (f2 if source is f1, f1 otherwise). */
    const bn_factor_type&
        get_target_from(const bn_factor_type* source) const
        {
            return source == m_f1.get() ? *m_f2 : *m_f1;
        }

    /**
     * Returns the message addressed to dest held in buffer buffer_index.
     * Any dest other than f1 is treated as f2. buffer_index must be 0 or 1.
     */
    const bn_message_type&
        get_message_to(const bn_factor_type* dest, size_t buffer_index) const
        {
            return dest == m_f1.get() ? m_msg_to_f1[buffer_index] : m_msg_to_f2[buffer_index];
        }

    /** Recomputes the messages in both directions (defined after bn_factor_type). */
    void
        update_messages(const bn_message_type& observations, size_t buffer_index);

    const std::vector<size_t>&
        variables() const { return m_variables; }

    /** Orders interfaces by their variable lists. */
    bool
        operator < (const bn_factor_interface_type& other) const
        {
            return m_variables < other.m_variables;
        }

    /** Empties all four message buffers (their default values are kept). */
    void clear()
    {
        for (size_t i = 0; i < 2; ++i) {
            m_msg_to_f1[i].clear();
            m_msg_to_f2[i].clear();
        }
    }

    const bn_factor_type* f1() const { return m_f1.get(); }
    const bn_factor_type* f2() const { return m_f2.get(); }

    /* Needs friendship: the printer reads the private factors and buffers. */
    friend
        std::ostream&
        operator << (std::ostream& os, const bn_factor_interface_type& interf);

private:
    std::vector<size_t> m_variables;
    std::shared_ptr<bn_factor_type> m_f1, m_f2;

    bn_message_type m_msg_to_f1[2], m_msg_to_f2[2];
};


/* Prints a (key list, value) pair as "keys:value". */
inline
std::ostream&
operator << (std::ostream& os, const std::pair<genotype_comb_type::key_list, double>& kd)
{
    os << kd.first;
    os << ':';
    os << kd.second;
    return os;
}

/**
 * A factor: a joint probability table over a list of variables, together with
 * the interfaces that connect it to other factors.
 */
struct bn_factor_type {
    bn_factor_type() : m_variables(), m_joint_prob_table(), m_interfaces(), m_leaves() {}

    /** Builds a factor from a copy of joint; the variable list is get_parents(joint). */
    bn_factor_type(const genotype_comb_type& joint)
        : m_variables(), m_joint_prob_table(joint), m_interfaces(), m_leaves()
    {
        m_variables = get_parents(m_joint_prob_table);
    }

    bn_factor_type(const bn_factor_type& other)
        : m_variables(other.m_variables), m_joint_prob_table(other.m_joint_prob_table), m_interfaces(other.m_interfaces), m_leaves(other.m_leaves)
    {}

    bn_factor_type(bn_factor_type&& other)
        : m_variables(std::move(other.m_variables)), m_joint_prob_table(std::move(other.m_joint_prob_table)),
          m_interfaces(std::move(other.m_interfaces)), m_leaves(std::move(other.m_leaves))
    {}

    /** Builds a factor by taking over joint; the variable list is get_parents of the table. */
    bn_factor_type(genotype_comb_type&& joint)
        : m_variables(), m_joint_prob_table(std::move(joint)), m_interfaces(), m_leaves()
    {
        m_variables = get_parents(m_joint_prob_table);
    }

    /**
     * Appends to m_leaves every variable of this factor for which none of the
     * entries of T.count_ancestors(v) is itself a variable of this factor.
     */
    void
        compute_leaves(const pedigree_tree_type& T)
        {
            /* size() - 1 wraps around on an empty factor and would make
             * reserve() throw, so only reserve when there is something to hold. */
            if (m_variables.size() > 1) {
                m_leaves.reserve(m_variables.size() - 1);
            }
            for (size_t v: m_variables) {
                auto anc = T.count_ancestors(v);
                bool add = true;
                for (const auto& kv: anc) {
                    if (std::find(m_variables.begin(), m_variables.end(), kv.first) != m_variables.end()) {
                        add = false;
                        break;
                    }
                }
                if (add) {
                    m_leaves.push_back(v);
                }
            }
        }

    /**
     * Returns the table restricted to project_variables.
     * When that list equals this factor's variables the table is returned as
     * is; otherwise ::project is called with the leaves that are not in
     * project_variables as third argument.
     * Both m_leaves and project_variables must be sorted (std::set_difference).
     */
    genotype_comb_type
        project(const std::vector<size_t>& project_variables)
        {
            if (project_variables == m_variables) {
                return m_joint_prob_table;
            }
            std::vector<size_t> norm_variables(m_leaves.size());
            auto it = std::set_difference(m_leaves.begin(), m_leaves.end(),
                                          project_variables.begin(), project_variables.end(),
                                          norm_variables.begin());
            norm_variables.resize(it - norm_variables.begin());
            return ::project(m_joint_prob_table, project_variables, norm_variables);
        }

    /**
     * Sums the coefficients of the table per sub-element obtained by
     * extract(targets), then replaces each sum by its inverse.
     * The returned message has default value 1.
     */
    bn_message_type
        compute_norm_factors(const std::vector<size_t>& targets)
        {
            bn_message_type norm(1.);
#if 1
            /*size_t debug_i = m_joint_prob_table.size();*/
            auto i = m_joint_prob_table.begin(), j = m_joint_prob_table.end();
            /*for (; i != j && debug_i != 0; ++i, --debug_i) {*/
            for (; i != j; ++i) {
                /*MSG_DEBUG("normalizing " << (++debug_i) << "...");*/
                /*MSG_QUEUE_FLUSH();*/
                /*MSG_DEBUG("normalizing on element " << (*i));*/
                /*MSG_QUEUE_FLUSH();*/
                auto ke = i->extract(targets);
                norm.accumulate(ke.second.keys, ke.second.coef);
                /*norm.accumulate(ke.first, ke.second.coef);*/
            }
            MSG_DEBUG("norm factors " << norm);
            for (auto& kv: norm) {
                kv.second = 1. / kv.second;
            }
#endif
            return norm;
        }

    /**
     * Computes the message this factor sends through interface.
     *
     * Each table element contributes its coefficient multiplied by the
     * observation of each of its keys and by the incoming message (buffer
     * buffer_index) of every other interface; interfaces over the same
     * variables as the target are skipped. Contributions are summed per output
     * key (the part of the element extracted on the interface variables) and
     * the result is scaled so that all contributions sum to 1, unless that sum
     * is 0.
     */
    bn_message_type
        compute_message_for(const bn_factor_interface_type* interface, const bn_message_type& observations, size_t buffer_index)
        {
            double accum = 0;
            bn_message_type ret;
            /* NOTE(review): norm is only referenced by the commented-out
             * variants below; the call is kept for its debug trace. */
            bn_message_type norm = compute_norm_factors(interface->variables());
            std::vector<std::pair<genotype_comb_type::key_list, double>> debug;

            MSG_DEBUG("joint_prob_table.size=" << m_joint_prob_table.size());
            MSG_QUEUE_FLUSH();

            for (const auto& e: m_joint_prob_table) {
                debug.clear();
                /*genotype_comb_type::key_list output_key = e.keys % interface->variables();*/
                genotype_comb_type::key_list output_key;
                genotype_comb_type::element_type sub_element;
                std::tie(output_key, sub_element) = e.extract(interface->variables());
                double prob = e.coef;
                /*double prob = e.coef * norm[sub_element.keys];*/
                /*double prob = e.coef * norm[output_key];*/
                /*norm.accumulate(output_key, prob);*/
                /*MSG_DEBUG("initial coef: " << prob);*/
                for (const auto& key: e.keys) {  /* FIXME: all keys or all keys BUT the output ones? */
                    prob *= observations[key];
                    debug.emplace_back(key, observations[key]);
                    /*MSG_DEBUG("(obs) prob: " << prob << " obs[" << key << "]=" << observations[key]);*/
                }
                for (const auto& i: m_interfaces) {
                    /*MSG_DEBUG("using interface " << (*i));*/
                    if (i.get() == interface || i->variables() == interface->variables()) {
                        continue;
                    }
                    genotype_comb_type::key_list interface_key = e.keys % i->variables();
                    prob *= i->get_message_to(this, buffer_index)[interface_key];
                    debug.emplace_back(interface_key, i->get_message_to(this, buffer_index)[interface_key]);
                    /*MSG_DEBUG("(itf) prob: " << prob << " itf[" << interface_key << "]=" << i->get_message_to(this, buffer_index)[interface_key]);*/
                }
                ret.accumulate(output_key, prob);
                accum += prob;
                MSG_DEBUG("output_key=" << output_key << " coef=" << e.coef << " probs { " << debug << " } result=" << prob);
            }
            MSG_DEBUG("RAW MESSAGE: " << ret);
            /*for (auto& kv: ret) {*/
                /*kv.second /= norm[kv.first];*/
            /*}*/
            if (accum != 0) {
                accum = 1. / accum;
                for (auto& kv: ret) {
                    kv.second *= accum;
                }
            }
            /*MSG_DEBUG("NORMALIZED MESSAGE: " << ret);*/
            return ret;
        }

    /**
     * Computes the state of this factor: for each table element, its
     * coefficient multiplied by the observation of each of its keys and by the
     * incoming message (buffer buffer_index) of every interface. Values are
     * stored per full key list and scaled to sum to 1, unless the sum is 0.
     */
    bn_message_type
        compute_state(const bn_message_type& observations, size_t buffer_index)
        {
            bn_message_type ret;
            double accum = 0;
            for (const auto& e: m_joint_prob_table) {
                double prob = e.coef;
                for (const auto& key: e.keys) {  /* FIXME: all keys or all keys BUT the output ones? */
                    prob *= observations[key];
                }
                for (const auto& i: m_interfaces) {
                    genotype_comb_type::key_list interface_key = e.keys % i->variables();
                    prob *= i->get_message_to(this, buffer_index)[interface_key];
                }
                ret.set(e.keys, prob);
                accum += prob;
            }
            if (accum != 0) {
                accum = 1. / accum;
                for (auto& kv: ret) {
                    kv.second *= accum;
                }
            }
            return ret;
        }

    /**
     * Returns the variables shared with other.
     * Both variable lists must be sorted (std::set_intersection).
     */
    std::vector<size_t>
        common_variables(const bn_factor_type& other) const
        {
            std::vector<size_t> ret(std::min(m_variables.size(), other.m_variables.size()));
            auto end = std::set_intersection(m_variables.begin(), m_variables.end(), other.m_variables.begin(), other.m_variables.end(), ret.begin());
            ret.resize(end - ret.begin());
            return ret;
        }

    /** Full dump: address, variables, table and every interface. */
    friend
        std::ostream&
        operator << (std::ostream& os, const bn_factor_type& factor)
        {
            os << "FACTOR @" << (&factor) << " on variables {" << factor.m_variables << '}' << std::endl;
            os << "joint prob. table: " << factor.m_joint_prob_table << std::endl;
            os << "interfaces:" << std::endl;
            for (const auto& i: factor.interfaces()) {
                os << (*i) << std::endl;
            }
            return os;
        }

    /** Short form used when printing a pointer to a factor: "{variables}". */
    friend
        std::ostream&
        operator << (std::ostream& os, std::shared_ptr<bn_factor_type> factor)
        {
            return os << '{' << factor->variables() << '}';
        }

    const std::vector<size_t>&
        variables() const { return m_variables; }

    const std::vector<std::shared_ptr<bn_factor_interface_type>>&
        interfaces() const { return m_interfaces; }

    /** Registers one more interface on this factor. */
    void
        add_interface(std::shared_ptr<bn_factor_interface_type> interf)
        {
            m_interfaces.emplace_back(interf);
        }

    const genotype_comb_type&
        table() const { return m_joint_prob_table; }

private:
    std::vector<size_t> m_variables;        /* get_parents() of the table */
    genotype_comb_type m_joint_prob_table;
    std::vector<std::shared_ptr<bn_factor_interface_type>> m_interfaces;
    std::vector<size_t> m_leaves;           /* filled by compute_leaves() */
};

/*
 * Recomputes both messages of the interface. Each factor computes its outgoing
 * message from the buffers selected by buffer_index, and the result is written
 * to the other buffer (!buffer_index).
 *
 * Declared inline because this definition lives in a header: without it, every
 * translation unit including the header would emit its own definition.
 */
inline
void
    bn_factor_interface_type::update_messages(const bn_message_type& observations, size_t buffer_index)
    {
        MSG_DEBUG("buffer_index=" << buffer_index);
        MSG_QUEUE_FLUSH();
        MSG_DEBUG_INDENT_EXPR("[to_f2] ");
        m_msg_to_f2[!buffer_index] = m_f1->compute_message_for(this, observations, buffer_index);
        MSG_DEBUG_DEDENT;
        MSG_DEBUG_INDENT_EXPR("[to_f1] ");
        m_msg_to_f1[!buffer_index] = m_f2->compute_message_for(this, observations, buffer_index);
        MSG_DEBUG_DEDENT;
    }

/*
 * Prints an interface as
 *   [@f1 (to_f1[0], to_f1[1])  <--(variables)-->  (to_f2[0], to_f2[1]) @f2]
 */
inline
std::ostream&
operator << (std::ostream& os, const bn_factor_interface_type& interf)
{
    os << "[@" << interf.m_f1;
    os << " (" << interf.m_msg_to_f1[0] << ", " << interf.m_msg_to_f1[1];
    os << ")  <--(" << interf.variables() << ")-->  (";
    os << interf.m_msg_to_f2[0] << ", " << interf.m_msg_to_f2[1];
    os << ") @" << interf.m_f2 << ']';
    return os;
}



struct compute_labels {
    bn_label_type
        find_label(size_t n, const genotype_comb_type::element_type& labels)
        {
            auto it
                = std::find_if(labels.keys.begin(), labels.keys.end(),
                               [=] (const genotype_comb_type::key_type& k) { return k.parent == n; }
                  );
            if (it == labels.keys.end()) {
                MSG_ERROR("COULDN'T FIND LABEL FOR " << n << " IN " << labels, "");
                MSG_QUEUE_FLUSH();
                return {};
            }
            return it->state;
        }

    bn_label_type
        operator () (const pedigree_tree_type& tree, size_t n, const genotype_comb_type::element_type& labels, const std::vector<bool>& recompute)
        {
            if (tree.get_p2(n) == NONE) {
                /* gamete or ancestor */
                if (tree.get_p1(n) == NONE) {
                    /* ancestor */
                    return find_label(n, labels);
                } else {
                    auto gl = find_label(n, labels);
                    auto sub = operator () (tree, tree.get_p1(n), labels, recompute);
                    if (gl.first == GAMETE_L) {
                        return {sub.first, 0, sub.first_allele, 0};
                    } else {
                        return {sub.second, 0, sub.second_allele, 0};
                    }
                }
            } else  if (recompute[n]) {
                auto subl = operator () (tree, tree.get_p1(n), labels, recompute);
                auto subr = operator () (tree, tree.get_p2(n), labels, recompute);
                return {subl.first, subr.first, subl.first_allele, subr.first_allele};
            } else {
                return find_label(n, labels);
            }
        }

    std::vector<bn_label_type>
        operator () (const pedigree_tree_type& tree, size_t n, const genotype_comb_type& comb, const std::vector<bool>& recompute)
        {
            std::vector<bn_label_type> ret;
            ret.reserve(comb.m_combination.size());
            for (const auto& e: comb) {
                ret.emplace_back(operator () (tree, n, e, recompute));
            }
            return ret;
        }

    /*static*/
        /*genotype_comb_type*/
        /*make_comb(const pedigree_tree_type& tree, size_t n, const genotype_comb_type& comb)*/
        /*{*/
            /*return state_to_combination(n, compute_labels()(tree, n, comb));*/
        /*}*/

    /*static*/
        /*genotype_comb_type*/
        /*add_labels(const pedigree_tree_type& tree, size_t n, const genotype_comb_type& comb)*/
        /*{*/
            /*auto labcomb = make_comb(tree, n, comb);*/
            /*return hadamard(labcomb, comb);*/
        /*}*/
};




struct factor_graph {
    /* Labels per variable; printed under "Variable domains:" by operator<<. */
    std::map<size_t, std::vector<bn_label_type>> m_variable_domains;
    /* Factors of the graph; filled by build_factors(). */
    std::vector<std::shared_ptr<bn_factor_type>> m_factors;
    /* Interfaces of the graph; printed under "Ordered m_interfaces:" by operator<<. */
    std::vector<std::shared_ptr<bn_factor_interface_type>> m_interfaces;
    /* Observations; constructed with default value 1. */
    bn_message_type m_observations;
    /* Starts at 0; presumably selects the current message buffer -- TODO confirm. */
    size_t m_buffer_index;

    factor_graph(const pedigree_type& ped)
        : m_variable_domains(), m_factors(), m_interfaces(), m_observations(1.), m_buffer_index(0)
        /*compute_factors_and_domains(ped);*/
        /*compute_interfaces();*/
        build_factors(ped);
/* Disabled code: returns the entries of ancestors_to_join found in
 * T.count_ancestors(p_node), followed by p_node itself. */
#if 0
    std::vector<size_t>
        get_joint_ancestry(const pedigree_tree_type& T, size_t p_node, const std::vector<size_t>& ancestors_to_join)
        {
            auto anc = T.count_ancestors(p_node);
            std::vector<size_t> ret;
            ret.reserve(ancestors_to_join.size() + 1);
            for (size_t ja: ancestors_to_join) {
                if (anc.find(ja) != anc.end()) {
                    ret.push_back(ja);
                }
            }
            ret.push_back(p_node);
            return ret;
        }
#endif

    struct factor_creation_list_type {
        /*
         * One planned factor creation: either a cross producing `progeny`
         * from its parents' tables, or (progeny == (size_t) -1) the plain
         * product of the tables described by f1_vars and f2_vars.
         */
        struct factor_creation_op {
            /* Variables of the factor to create; find_compatible_factor runs
             * std::includes on this list, so it must be sorted. */
            std::vector<size_t> variables;
            /* Variables to take from the first / second input table. */
            std::vector<size_t> f1_vars, f2_vars;
            /* Node created by the cross, or (size_t) -1 when there is none. */
            size_t progeny;

            /*
             * Executes the operation and appends the resulting factor to
             * `factors`.
             *
             * Each input table is either the domain of a single ancestor
             * variable or the projection of an already built factor. That
             * factor is addressed by the index find_compatible_factor returns
             * into the operation list, which assumes `factors` holds one
             * factor per operation, in the same order. An index with no built
             * factor behind it (none found, or not built yet) now throws
             * std::out_of_range instead of reading out of bounds.
             *
             * When there is a progeny, the parent table is combined with both
             * gamete tables, each state is labelled with compute_labels, the
             * gamete variables are summed out, and the progeny's domain is
             * recorded in fcl.
             */
            void
                cross(const pedigree_tree_type& T, factor_creation_list_type& fcl, std::vector<std::shared_ptr<bn_factor_type>>& factors) const
                {
                    static std::vector<bn_label_type> label_g = {{GAMETE_L, 0, 0, 0}, {GAMETE_R, 0, 0, 0}};
                    genotype_comb_type p1, parents;
                    if (f1_vars.size() == 1 && T[f1_vars.front()].is_ancestor()) {
                        size_t n1 = f1_vars.front();
                        p1 = state_to_combination(n1, fcl.get_domain(n1));
                    } else {
                        size_t comp_fac = fcl.find_compatible_factor(f1_vars);
                        MSG_DEBUG("finding factor that provides {" << f1_vars << "} => " << comp_fac << " (array size is " << factors.size() << ')');
                        MSG_QUEUE_FLUSH();
                        p1 = factors.at(comp_fac)->project(f1_vars);
                        MSG_DEBUG("resulting table: " << p1);
                        MSG_QUEUE_FLUSH();
                    }
                    if (f2_vars.size() == 1 && T[f2_vars.front()].is_ancestor()) {
                        size_t n2 = f2_vars.front();
                        parents = kronecker(p1, state_to_combination(n2, fcl.get_domain(n2)));
                    } else if (f2_vars.size() > 0) {
                        size_t comp_fac = fcl.find_compatible_factor(f2_vars);
                        MSG_DEBUG("finding factor that provides {" << f2_vars << "} => " << comp_fac << " (array size is " << factors.size() << ')');
                        MSG_QUEUE_FLUSH();
                        parents = kronecker(p1, factors.at(comp_fac)->project(f2_vars));
                        MSG_DEBUG("resulting table: " << parents);
                        MSG_QUEUE_FLUSH();
                    } else {
                        parents = p1;
                    }
                    if (progeny != (size_t) -1) {
                        /* The gamete tables are only built here: with no
                         * progeny there is no node to query the tree about. */
                        genotype_comb_type
                            G1 = state_to_combination((size_t) T.get_p1(progeny), label_g) * .5;
                        genotype_comb_type
                            G2 = state_to_combination((size_t) T.get_p2(progeny), label_g)
                            * (T.get_p1(progeny) != T.get_p2(progeny) ? .5 : 1);
                        genotype_comb_type
                            unmarked_cross = kronecker(parents, kronecker(G1, G2));
                        MSG_DEBUG("unmarked_cross " << unmarked_cross);
                        MSG_QUEUE_FLUSH();
                        /* Only the progeny itself gets its label recomputed. */
                        std::vector<bool> recompute(progeny + 1, false);
                        recompute.back() = true;
                        auto label_per_state = compute_labels()(T, progeny, unmarked_cross, recompute);
                        auto new_jp_table
                            = fold(sum_over(hadamard(unmarked_cross, state_to_combination(progeny, label_per_state)),
                                        {(size_t) T.get_p1(progeny), (size_t) T.get_p2(progeny)}));
                        factors.emplace_back(std::make_shared<bn_factor_type>(new_jp_table));
                        fcl.add_ind_domain(progeny, label_per_state);
                    } else {
                        factors.emplace_back(std::make_shared<bn_factor_type>(parents));
                    }
                    factors.back()->compute_leaves(T);
                }

            /* Prints "{variables}: [progeny = ]{f1_vars} ⨝ {f2_vars}". */
            friend
                std::ostream&
                operator << (std::ostream& os, const factor_creation_op& op)
                {
                    if (op.progeny != (size_t) -1) {
                        return os << '{' << op.variables << "}: " << op.progeny << " = {" << op.f1_vars << "} ⨝ {" << op.f2_vars << '}';
                    } else {
                        return os << '{' << op.variables << "}: {" << op.f1_vars << "} ⨝ {" << op.f2_vars << '}';
                    }
                }
        };

        /* Returns the labels recorded for variable n, or an empty list when
         * none were recorded. */
        const std::vector<bn_label_type>&
            get_domain(size_t n) const
            {
                static std::vector<bn_label_type> no_labels;
                auto found = variable_domains.find(n);
                if (found != variable_domains.end()) {
                    return found->second;
                }
                return no_labels;
            }

        /*
         * Returns the index of the first operation whose variables include
         * every entry of interface, or (size_t) -1 when there is none.
         * Both lists must be sorted (std::includes).
         */
        size_t
            find_compatible_factor(const std::vector<size_t>& interface) const
            {
                for (size_t idx = 0; idx < operations.size(); ++idx) {
                    const std::vector<size_t>& provided = operations[idx].variables;
                    if (std::includes(provided.begin(), provided.end(),
                                      interface.begin(), interface.end())) {
                        return idx;
                    }
                }
                return (size_t) -1;
            }

        /* Returns, in their original order, the entries of reent that appear
         * in T.count_ancestors(node). */
        std::vector<size_t>
            joint_ancestors(const pedigree_tree_type& T, size_t node, const std::vector<size_t>& reent) const
            {
                auto ancestors = T.count_ancestors(node);
                std::vector<size_t> kept;
                kept.reserve(reent.size());
                for (auto r = reent.begin(); r != reent.end(); ++r) {
                    if (ancestors.find(*r) == ancestors.end()) {
                        continue;
                    }
                    kept.push_back(*r);
                }

                return kept;
            }

        /* Returns the sorted, duplicate-free union of v1, v2 and n; n is left
         * out when it is (size_t) -1. */
        std::vector<size_t>
            unite(size_t n, const std::vector<size_t>& v1, const std::vector<size_t>& v2) const
            {
                std::set<size_t> merged(v1.begin(), v1.end());
                merged.insert(v2.begin(), v2.end());
                if (n != (size_t) -1) {
                    merged.insert(n);
                }
                return std::vector<size_t>(merged.begin(), merged.end());
            }

        /*
         * Makes sure an operation providing p_node together with those entries
         * of reent that are among its ancestors exists, planning new
         * operations recursively for the parents when needed.
         * Returns that interface (the joint ancestors followed by p_node).
         */
        std::vector<size_t>
            ensure_factor(const pedigree_tree_type& T, size_t p_node, const std::vector<size_t>& reent)
            {
                MSG_DEBUG("... ensure_factor(" << p_node << ", " << reent << ')');
                auto joint_anc = joint_ancestors(T, p_node, reent);
                std::vector<size_t> interface = joint_anc;
                interface.push_back(p_node);
                size_t f = find_compatible_factor(interface);
                if (f != (size_t) -1) {
                    /* factor exists, OK. */
                    MSG_DEBUG("... ... factor exists " << operations[f]);
                    return interface;
                }

                /* Reserve the slot first so the operation keeps its position,
                 * but remember it by index: the recursive calls below may grow
                 * `operations` and invalidate any reference into it. */
                const size_t op_index = operations.size();
                operations.emplace_back();

                std::vector<size_t> f1 = ensure_factor(T, T.get_p1(T.get_p1(p_node)), joint_anc);
                std::vector<size_t> f2 = ensure_factor(T, T.get_p1(T.get_p2(p_node)), joint_anc);

                auto& new_op = operations[op_index];
                new_op.f1_vars = std::move(f1);
                new_op.f2_vars = std::move(f2);

                /* create cross {p_node} U itf1 U itf2 */
                new_op.progeny = (size_t) -1;
                new_op.variables = unite(new_op.progeny, new_op.f1_vars, new_op.f2_vars);
                MSG_DEBUG("... ... created new factor " << new_op);
                return interface;
            }

        void
            add_ind_domain(size_t ind_node, const std::vector<bn_label_type>& table)
            {
                std::set<bn_label_type> uniq(table.begin(), table.end());
                variable_domains[ind_node].assign(uniq.begin(), uniq.end());
            }

        /*
         * Plans what is needed for node ind_node.
         *  - Ancestor: only its domain is recorded, one label per allele
         *    built from the ancestor's letter.
         *  - Node with reentrant ancestors: operations providing each parent
         *    are ensured, a joint-parents operation is added unless one
         *    already provides both parents, then the cross itself is added.
         *  - Otherwise: a single cross of the two parents is added.
         */
        void
            add_ind(const pedigree_type& ped, size_t ind_node)
            {
                if (ped.tree[ind_node].is_ancestor()) {
                    MSG_DEBUG("add_ind(" << ind_node << ')');
                    MSG_DEBUG("... is ancestor");
                    std::vector<bn_label_type> labels;
                    /* NOTE(review): assumes every ancestor has an entry in ancestor_letters. */
                    char letter = ped.ancestor_letters.find(ind_node)->second;
                    for (size_t i = 0; i < ped.n_alleles; ++i) {
                        labels.emplace_back(letter, letter, i, i);
                    }
                    variable_domains[ind_node] = labels;
                    return;
                }

                auto reent = ped.tree.cleanup_reentrants(ind_node);
                std::vector<size_t> itf1, itf2;
                size_t p1 = (size_t) ped.tree.get_p1(ped.tree.get_p1(ind_node));
                size_t p2 = (size_t) ped.tree.get_p1(ped.tree.get_p2(ind_node));
                /* Variable lists are kept sorted, so order the parents once. */
                const size_t lo = std::min(p1, p2);
                const size_t hi = std::max(p1, p2);
                MSG_DEBUG("add_ind(" << ind_node << ", " << p1 << ", " << p2 << ')');
                if (reent.size()) {
                    MSG_DEBUG("... has reentrants");
                    std::vector<size_t> R;
                    R.reserve(reent.size());
                    for (const auto& kv: reent) { R.push_back(kv.first); }
                    itf1 = ensure_factor(ped.tree, p1, R);
                    itf2 = ensure_factor(ped.tree, p2, R);
                    /* The lookup uses std::includes and needs a sorted list;
                     * {p1, p2} was passed unsorted before. */
                    if (find_compatible_factor({lo, hi}) == (size_t) -1) {
                        operations.emplace_back();
                        auto& op = operations.back();
                        op.progeny = (size_t) -1;
                        op.variables = unite(op.progeny, itf1, itf2);
                        op.f1_vars = itf1;
                        op.f2_vars = itf2;
                        MSG_DEBUG("... joint parents for #" << ind_node << ": " << op);
                    }
                    {
                        operations.emplace_back();
                        auto& op = operations.back();
                        op.variables = {lo, hi, ind_node};
                        op.progeny = ind_node;
                        op.f1_vars = {lo, hi};
                        op.f2_vars = {};
                        MSG_DEBUG("... result for #" << ind_node << ": " << op);
                    }
                } else {
                    MSG_DEBUG("... simple cross");
                    operations.emplace_back();
                    auto& op = operations.back();
                    op.variables = {lo, hi, ind_node};
                    op.f1_vars = {p1};
                    op.f2_vars = {p2};
                    op.progeny = ind_node;
                    MSG_DEBUG("... result for #" << ind_node << ": " << op);
                }
            }

        /* Plans every node listed in the tree's individual-to-node table
         * (entries equal to NONE are skipped), then traces all operations. */
        void
            add_all(const pedigree_type& ped)
            {
                for (size_t node: ped.tree.m_ind_number_to_node_number) {
                    if (node != (size_t) NONE) {
                        add_ind(ped, node);
                    }
                }

                for (size_t i = 0; i < operations.size(); ++i) {
                    MSG_DEBUG("[OP] " << operations[i]);
                }
            }

        void
            cleanup()
            {
                std::vector<bool> included(operations.size(), false);
                size_t total = operations.size();
                for (size_t i1 = 0; i1 < operations.size(); ++i1) {
                    if (included[i1]) { continue; }
                    const auto& o1 = operations[i1];
                    for (size_t i2 = 0; i2 < operations.size(); ++i2) {
                        if (included[i2] || i1 == i2) { continue; }
                        const auto& o2 = operations[i2];
                        if (o1.variables == o2.variables) {
                            included[std::max(i1, i2)] = true;
                            continue;
                        }
                        if (std::includes(o2.variables.begin(), o2.variables.end(), o1.variables.begin(), o1.variables.end())) {
                            total -= !included[i1];
                            included[i1] = true;
                        }
                    }
                }
                std::vector<factor_creation_op> tmp;
                tmp.reserve(total);
                for (size_t i = 0; i < included.size(); ++i) {
                    if (!included[i]) {
                        tmp.emplace_back(operations[i]);
                    }
                }
                operations.swap(tmp);

                for (const auto& op: operations) {
                    MSG_DEBUG("[POST CLEANUP OP] " << op);
                }
            }

        void
            compute_factors(const pedigree_type& ped, std::vector<std::shared_ptr<bn_factor_type>>& factors)
            {
                for (const auto& op: operations) {
                    op.cross(ped.tree, *this, factors);
                }
            }

    private:
        /* Planned operations, in the order they were added. */
        std::vector<factor_creation_op> operations;
        /* Labels recorded per variable (ancestors in add_ind, progenies in add_ind_domain). */
        std::map<size_t, std::vector<bn_label_type>> variable_domains;
    };

    /* Plan of factor-creation operations; filled and executed by build_factors(). */
    factor_creation_list_type factor_creation_operations;

    void
        build_factors(const pedigree_type& ped)
        {
            factor_creation_operations.add_all(ped);
            /*factor_creation_operations.cleanup();*/
            factor_creation_operations.compute_factors(ped, m_factors);
            cleanup_factor_list();
        }

    /* Dumps the graph: its address, the variable domains, every factor and
     * every interface. */
    friend std::ostream& operator << (std::ostream& os, const factor_graph& fg)
    {
        os << "FACTOR GRAPH @" << (&fg) << std::endl;
        os << "Variable domains:" << std::endl;
        for (auto dom = fg.m_variable_domains.begin(); dom != fg.m_variable_domains.end(); ++dom) {
            os << "  - " << dom->first << ": " << dom->second << std::endl;
        }
        for (size_t f = 0; f < fg.m_factors.size(); ++f) {
            os << *fg.m_factors[f] << std::endl;
        }
        os << "Ordered m_interfaces:" << std::endl;
        for (size_t i = 0; i < fg.m_interfaces.size(); ++i) {
            os << *fg.m_interfaces[i] << std::endl;
        }
        return os;
    }

#if 0
    /*
     * Unfinished draft (this region is compiled out by the enclosing #if 0).
     * For a genotype node it only assembles the list of variables
     * {n, p1, p2} U reentrant_variables; no factor is built and no value is
     * returned on any path.
     *
     * NOTE(review): std::set_union expects both input ranges to be sorted;
     * base_variables is written as {n, p1, p2} and is not sorted here --
     * confirm the ordering before enabling this code.
     */
    bn_factor_type
        compute_factor2(const pedigree_type& ped, size_t n,
                        const std::vector<size_t>& reentrant_variables)
        {
            if (ped.tree[n].is_ancestor()) {
                /* ancestor case: left empty */

            } else if (ped.tree[n].is_genotype()) {
                /* room for n, its two (grand-)parent nodes and every reentrant variable */
                std::vector<size_t> project_variables(3 + reentrant_variables.size());
                /* p1/p2: first parent of each of n's two parent nodes */
                size_t p1 = (size_t) ped.tree.get_p1(ped.tree.get_p1(n));
                size_t p2 = (size_t) ped.tree.get_p1(ped.tree.get_p2(n));
                std::vector<size_t> base_variables = {n, p1, p2};
                /* `it` (end of the merged range) is not used afterwards */
                auto it = std::set_union(base_variables.begin(), base_variables.end(),
                                         reentrant_variables.begin(), reentrant_variables.end(),
                                         project_variables.begin());

            }
        }

    /*
     * Disabled lookup helper (this region is compiled out by the enclosing #if 0).
     * Returns the first factor in m_factors whose variable list includes
     * reentrant_variables plus n; when none matches, appends the result of
     * compute_factor2() to m_factors and returns that new entry.
     *
     * NOTE(review): std::includes expects sorted ranges; `variables` is
     * reentrant_variables with n pushed at the back, so it is only sorted if
     * n is greater than every reentrant variable -- confirm before enabling.
     * NOTE(review): compute_factor2() returns a bn_factor_type by value while
     * this function returns a std::shared_ptr<bn_factor_type> taken from
     * m_factors -- check that the emplace_back below builds the intended element.
     */
    std::shared_ptr<bn_factor_type>
        find_factor(const pedigree_type& ped, size_t n, const std::vector<size_t>& reentrant_variables)
        {
            /* variables the wanted factor must cover */
            std::vector<size_t> variables = reentrant_variables;
            variables.push_back(n);
            auto it = std::find_if(m_factors.begin(), m_factors.end(),
                                   [&] (std::shared_ptr<bn_factor_type> fac)
                                   {
                                       const auto& fv = fac->variables();
                                       return std::includes(fv.begin(), fv.end(), variables.begin(), variables.end());
                                   });
            if (it == m_factors.end()) {
                /* nothing suitable yet: create it and hand back the freshly appended entry */
                m_factors.emplace_back(compute_factor2(ped, n, reentrant_variables));
                return m_factors.back();
            }
            return *it;
        }
#endif


    /*
     * Recursively assemble the (not yet reduced) table for node n while
     * computing the factor of node n0.
     *
     *   ped        pedigree; ped.tree is walked through get_p1()/get_p2().
     *   n0         node whose factor is being computed; its variable domain
     *              is (re)written into m_variable_domains when reached.
     *   n          node currently visited.
     *   recompute  per-node flags: a flagged non-ancestor, non-gamete node is
     *              expanded through both parents, any other such node is
     *              taken directly from m_variable_domains.
     *   gametes    output: every gamete node met is appended (unsorted,
     *              possibly with repeats) so the caller can sum them out.
     *   visited    currently unused (the guard below is commented out).
     *
     * Returns the combination table built for n.
     *
     * The closing braces of the `if (n == n0)` block in the ancestor branch,
     * of the final `else` and of the function body were missing, which left
     * the function (and everything after it) syntactically unbalanced; they
     * are restored here without altering any statement.
     */
    genotype_comb_type
        compute_factor_rec(const pedigree_type& ped, int n0, int n, const std::vector<bool>& recompute, std::vector<size_t>& gametes, std::vector<bool>& visited)
        {
            /*if (visited[n]) {*/
                /*return {1.};*/
            /*}*/
            /*visited[n] = true;*/
            if (ped.tree[n].is_ancestor()) {
                std::vector<bn_label_type> labels;
                /* FIXME allow heterozygous ancestors allele-wise? */
                if (n == n0) {
                    /* one homozygous label (letter, letter, i, i) per allele index */
                    char letter = ped.ancestor_letters.find(n)->second;
                    for (size_t i = 0; i < ped.n_alleles; ++i) {
                        labels.emplace_back(letter, letter, i, i);
                    }
                    m_variable_domains[n] = labels;
                }
                return state_to_combination((size_t) n, m_variable_domains[n]);
            } else if (ped.tree[n].is_gamete()) {
                /* two-state gamete variable (GAMETE_L / GAMETE_R), each weighted .5 */
                static std::vector<bn_label_type> label_g = {{GAMETE_L, 0, 0, 0}, {GAMETE_R, 0, 0, 0}};
                genotype_comb_type G = state_to_combination((size_t) n, label_g) * .5;
                gametes.push_back((size_t) n);
                return kronecker(compute_factor_rec(ped, n0, ped.tree.get_p1(n), recompute, gametes, visited), G);
            } else if (recompute[n]) {
                /* expand through both parents, then attach the labels computed for n */
                auto tmp = kronecker(
                        compute_factor_rec(ped, n0, ped.tree.get_p1(n), recompute, gametes, visited),
                        compute_factor_rec(ped, n0, ped.tree.get_p2(n), recompute, gametes, visited));
                auto label_per_state = compute_labels()(ped.tree, n, tmp, recompute);
                if (n == n0) {
                    /* the domain of n0 is the sorted set of distinct labels */
                    std::set<bn_label_type> uniq_sorted(label_per_state.begin(), label_per_state.end());
                    m_variable_domains[n].assign(uniq_sorted.begin(), uniq_sorted.end());
                }
                MSG_DEBUG("intermediary  " << tmp);
                auto labels = state_to_combination((size_t) n, label_per_state);
                MSG_DEBUG("adding labels " << labels);
                return hadamard(labels, tmp);
            } else {
                /* node not flagged for recomputation: use its stored domain as is */
                return state_to_combination((size_t) n, m_variable_domains[n]);
            }
        }
#if 1
    /*
     * Build the table relating node n to the joint table of its parents.
     *
     *   ped   pedigree; get_p1(n)/get_p2(n) give the two gamete nodes g1, g2.
     *   p1p2  joint table over the parent variables.
     *   n     node being produced by the crossing.
     *
     * p1p2 is combined (kronecker) with a two-state GAMETE_L/GAMETE_R
     * variable for g1 and for g2, each weighted .5; the labels of n are
     * computed on that product, attached with hadamard(), and g1 and g2 are
     * then summed out and the result folded.
     */
    genotype_comb_type
        compute_joint_crossing(const pedigree_type& ped, const genotype_comb_type& p1p2, int n)
        {
            static std::vector<bn_label_type> label_g = {{GAMETE_L, 0, 0, 0}, {GAMETE_R, 0, 0, 0}};
            size_t g1 = ped.tree.get_p1(n);
            size_t g2 = ped.tree.get_p2(n);
            genotype_comb_type cross = kronecker(kronecker(
                        p1p2, state_to_combination(g1, label_g)) * .5,
                        state_to_combination(g2, label_g) * .5);
            /* The flag vector is indexed by n, g1 and g2 below, so it needs at
             * least max(n, g1, g2) + 1 entries; a vector of size n made the
             * write to recompute[n] out of bounds. */
            size_t flag_count = (size_t) n + 1;
            if (g1 >= flag_count) {
                flag_count = g1 + 1;
            }
            if (g2 >= flag_count) {
                flag_count = g2 + 1;
            }
            std::vector<bool> recompute(flag_count, false);
            recompute[n] = recompute[g1] = recompute[g2] = true;
            auto label_per_state = compute_labels()(ped.tree, n, cross, recompute);
            return fold(sum_over(hadamard(cross, state_to_combination((size_t) n, label_per_state)), {g1, g2}));
        }
#endif

    /*
     * Compute the factor of node n: run compute_factor_rec() with n as both
     * target and starting node, using the recompute flags stored for n in
     * ped.tree.m_must_recompute, then sum out every gamete variable collected
     * on the way and fold the result.
     */
    genotype_comb_type
        compute_raw_factor(const pedigree_type& ped, int n)
        {
            std::vector<bool> seen(n + 1, false);
            std::vector<size_t> gamete_vars;
            auto unreduced = compute_factor_rec(ped, n, n, ped.tree.m_must_recompute[n], gamete_vars, seen);
            /* sum_over() is handed the gamete variables in ascending order */
            std::sort(gamete_vars.begin(), gamete_vars.end());
            MSG_DEBUG("raw     factor " << unreduced);
            auto reduced = fold(sum_over(unreduced, gamete_vars));
            MSG_DEBUG("cleaned factor " << reduced);
            return reduced;
        }

    void
        compute_factors_and_domains(const pedigree_type& ped)
        {
            for (int ind_node: ped.tree.m_ind_number_to_node_number) {
                MSG_DEBUG_INDENT_EXPR("[ind_node " << ind_node << "] ");
                MSG_QUEUE_FLUSH();
                if (ind_node == NONE) {
                    MSG_DEBUG("...not an individual...");
                    MSG_DEBUG_DEDENT;
                    continue;
                }
                if (ped.tree[ind_node].is_ancestor()) {
                    MSG_DEBUG("...ancestor...");
                    (void) compute_raw_factor(ped, ind_node);  /* to fill the variable domain */
                    MSG_DEBUG_DEDENT;
                    continue;
                }
                auto raw_factor = compute_raw_factor(ped, ind_node);
                if (!ped.tree[ind_node].is_ancestor()) {
                    const auto& recomp = ped.tree.m_must_recompute[ind_node];
                    size_t p1 = (size_t) ped.tree.get_p1(ped.tree.get_p1(ind_node));
                    size_t p2 = (size_t) ped.tree.get_p1(ped.tree.get_p2(ind_node));
                    if (recomp[p1] || recomp[p2]) {
                        MSG_DEBUG("Decomposing factor.");
                        /* create factor for joint probability of p1 and p2 */
                        m_factors.emplace_back(std::make_shared<bn_factor_type>(fold(sum_over(raw_factor, {(size_t) ind_node}))));
                        MSG_DEBUG("Joint parent probability:" << std::endl << (*m_factors.back()));
                        /* create factor for crossing */
                        /*std::vector<size_t> fold_vars;*/
                        /*for (size_t v: m_factors.back()->variables()) {*/
                            /*if (v != p1 && v != p2) {*/
                                /*fold_vars.push_back(v);*/
                            /*}*/
                        /*}*/
                        /*m_factors.emplace_back(std::make_shared<bn_factor_type>(fold(sum_over_dual(raw_factor, {(size_t) ind_node, (size_t) p1, (size_t) p2}))));*/
                        /*m_factors.emplace_back(std::make_shared<bn_factor_type>(fold(sum_over(raw_factor, fold_vars))));*/
                        std::set<genotype_comb_type::key_list> uniq_p1p2;
                        for (const auto& e: raw_factor) {
                            auto k = e.keys % std::vector<size_t>{p1, p2};
                            uniq_p1p2.insert(k);
                        }
                        genotype_comb_type p1p2;