Commit 1afc0863 authored by Damien Leroux
Browse files

New factor creation method.

parent 7fb21b53
This diff is collapsed.
......@@ -84,7 +84,7 @@ struct palette {
int idx;
bool bg;
color256(int r, int g, int b, bool back=false) : idx(closest_index_256(r, g, b)), bg(back) {}
color256(int i) : idx(i) {}
color256(int i) : idx(i), bg(false) {}
friend std::ostream& operator << (std::ostream& os, const color256& col)
{
return os << "\x1b[" << (3 + col.bg) << "8;5;" << std::dec << col.idx << 'm';
......
......@@ -25,5 +25,30 @@ template <typename X> using SelfAdjointEigenSolver = Eigen::SelfAdjointEigenSolv
#pragma clang diagnostic pop
#endif
namespace Eigen {
/**
 * Eigen scalar traits for std::string, so that std::string can be used as the
 * scalar type of an Eigen matrix.
 *
 * The cost constants are rough, deliberately large estimates (a string is far
 * more expensive to read, write or combine than a built-in number).
 * AddCost is provided alongside ReadCost and MulCost because Eigen's
 * NumTraits contract lists all three; expressions that consult the missing
 * constant would otherwise fail to compile.
 *
 * NOTE(review): IsInteger = 1 and IsSigned = 0 are kept as they were; confirm
 * they are the intended classification for strings.
 */
template<>
struct NumTraits<std::string> {
    using Real = std::string;
    using NonInteger = std::string;
    using Nested = std::string;
    enum {
        IsComplex = 0,
        IsInteger = 1,
        ReadCost = 100,
        WriteCost = 100,
        AddCost = 100,
        MulCost = 1000,
        IsSigned = 0,
        /* std::string has a non-trivial constructor, so coefficients must be constructed. */
        RequireInitialization = 1
    };
    /* There is no meaningful numeric bound or precision for strings: every query returns the empty string. */
    static Real epsilon() { return std::string(); }
    static Real dummy_precision() { return std::string(); }
    static std::string highest() { return std::string(); }
    static std::string lowest() { return std::string(); }
};
}
#endif
......@@ -36,30 +36,6 @@ std::ostream& operator << (std::ostream& os, const std::vector<VALUE_TYPE>& v)
return os;
}
namespace Eigen {
/**
 * Eigen scalar traits for std::string, allowing std::string to be used as an
 * Eigen matrix scalar type.
 *
 * Costs are coarse, intentionally high estimates. AddCost is declared next to
 * ReadCost and MulCost because Eigen's NumTraits contract expects all three
 * cost constants to exist for a custom scalar.
 *
 * NOTE(review): IsInteger = 1 / IsSigned = 0 are preserved unchanged; confirm
 * that this classification is intended for strings.
 */
template<>
struct NumTraits<std::string> {
    using Real = std::string;
    using NonInteger = std::string;
    using Nested = std::string;
    enum {
        IsComplex = 0,
        IsInteger = 1,
        ReadCost = 100,
        WriteCost = 100,
        AddCost = 100,
        MulCost = 1000,
        IsSigned = 0,
        /* Coefficients need their constructor to be run (std::string is not trivially constructible). */
        RequireInitialization = 1
    };
    /* No numeric notion of precision or range applies to strings: all of these yield the empty string. */
    static Real epsilon() { return std::string(); }
    static Real dummy_precision() { return std::string(); }
    static std::string highest() { return std::string(); }
    static std::string lowest() { return std::string(); }
};
}
template <typename MATRIX, typename ROW_LABEL, typename COL_LABEL=ROW_LABEL>
struct labelled_matrix {
typedef MATRIX matrix_type;
......
......@@ -4,6 +4,8 @@
#include <list>
#include "eigen.h"
#define _EPSILON 1.e-10
template <typename PARENT_TYPE, typename STATE_TYPE=size_t>
struct combination_type {
static const size_t nopar = (size_t) -1;
......@@ -186,8 +188,8 @@ struct combination_type {
element_type() : keys(), coef(1) {}
element_type& operator = (const element_type& e) { keys = e.keys; coef = e.coef; return *this; }
element_type& operator = (element_type&& e) { keys = std::move(e.keys); coef = e.coef; return *this; }
bool operator == (const element_type& other) const { return keys == other.keys; }
bool operator != (const element_type& other) const { return keys != other.keys; }
bool operator == (const element_type& other) const { return keys == other.keys && fabs(coef - other.coef) < _EPSILON; }
/* Defined as the exact negation of operator == so that both comparisons treat coef the same way;
 * comparing coef exactly here while operator == uses a tolerance would let two elements be
 * simultaneously "equal" and "different". */
bool operator != (const element_type& other) const { return !(*this == other); }
bool operator < (const element_type& other) const { return keys < other.keys; }
friend
......@@ -334,6 +336,12 @@ struct combination_type {
m_combination.emplace_back(coef);
}
/* Mutable access to the i-th element of the combination; forwards directly to m_combination[i]. */
element_type&
operator [] (size_t i) { return m_combination[i]; }
/* Read-only access to the i-th element of the combination; forwards directly to m_combination[i]. */
const element_type&
operator [] (size_t i) const { return m_combination[i]; }
sum_iterator_type begin() { return m_combination.begin(); }
sum_iterator_type end() { return m_combination.end(); }
sum_const_iterator_type begin() const { return m_combination.begin(); }
......@@ -367,11 +375,14 @@ struct combination_type {
ret.m_combination = {*ai, *bi};
break;
} else*/
if (*ai < *bi) { /*MSG_DEBUG("a<b");*/ if (ai->coef) { ret.m_combination.push_back(*ai); } ++ai; }
else if (*bi < *ai) { /*MSG_DEBUG("b<a");*/ if (bi->coef) { ret.m_combination.push_back(*bi); } ++bi; }
if (*ai < *bi) { /*MSG_DEBUG("a<b");*/ if (ai->coef > _EPSILON || ai->coef < -_EPSILON) { ret.m_combination.push_back(*ai); } ++ai; }
/* b sorts before a: keep b's term only when b's own coefficient lies outside [-_EPSILON, _EPSILON]
 * (both bounds must be tested on bi, the element that is actually pushed). */
else if (*bi < *ai) { /*MSG_DEBUG("b<a");*/ if (bi->coef > _EPSILON || bi->coef < -_EPSILON) { ret.m_combination.push_back(*bi); } ++bi; }
else {
/*MSG_DEBUG("a=b");*/
ret.m_combination.emplace_back((*ai) + bi->coef);
auto tmp = (*ai) + bi->coef;
if (tmp.coef > _EPSILON || tmp.coef < -_EPSILON) {
ret.m_combination.emplace_back(tmp);
}
++ai;
++bi;
}
......@@ -618,6 +629,28 @@ sum_over(const combination_type<PARENT_TYPE, STATE_TYPE>& comb, const std::vecto
}
/**
 * Rescale the coefficients of comb so that, within every group of elements
 * sharing the same value of (keys % variables), the coefficients sum to 1.
 *
 * The result holds one element per element of comb, in the same order, with
 * the same keys and the rescaled coefficient.
 *
 * NOTE(review): a group whose coefficients sum to 0 yields a division by zero
 * (infinite scale factor); confirm whether callers can produce such groups.
 * NOTE(review): the meaning of `keys % variables` is defined by key_list and
 * is not visible here; presumably it restricts the keys to `variables`.
 */
template <typename PARENT_TYPE, typename STATE_TYPE>
combination_type<PARENT_TYPE, STATE_TYPE>
normalize_over(const combination_type<PARENT_TYPE, STATE_TYPE>& comb, const std::vector<PARENT_TYPE>& variables)
{
    using result_type = combination_type<PARENT_TYPE, STATE_TYPE>;
    using key_list = typename result_type::key_list;

    result_type ret;
    std::map<key_list, double> norm_coef;

    /* Pass 1: accumulate the total coefficient of each group. */
    for (auto term = comb.begin(); term != comb.end(); ++term) {
        norm_coef[term->keys % variables] += term->coef;
    }
    MSG_DEBUG("NORM FACTORS: " << norm_coef);

    /* Pass 2: turn every total into the multiplicative factor applied to its group. */
    for (auto group = norm_coef.begin(); group != norm_coef.end(); ++group) {
        group->second = 1. / group->second;
    }

    /* Pass 3: emit every element with its coefficient scaled by its group's factor. */
    for (auto term = comb.begin(); term != comb.end(); ++term) {
        const double factor = norm_coef[term->keys % variables];
        ret.m_combination.emplace_back(term->keys, term->coef * factor);
    }
    return ret;
}
template <typename PARENT_TYPE, typename STATE_TYPE>
std::map<typename combination_type<PARENT_TYPE, STATE_TYPE>::key_list, combination_type<PARENT_TYPE, STATE_TYPE>>
sum_over_dual(const combination_type<PARENT_TYPE, STATE_TYPE>& comb, const std::vector<PARENT_TYPE>& parents)
......@@ -626,13 +659,17 @@ sum_over_dual(const combination_type<PARENT_TYPE, STATE_TYPE>& comb, const std::
for (const auto& e: comb) {
auto ke = e.extract(parents);
ret[ke.first] += ke.second;
/*ret[ke.first] += ke.second;*/
ret[ke.second.keys] += {ke.first, ke.second.coef};
}
MSG_DEBUG("sum_over_dual(" << comb << ", " << parents << ") => " << ret);
return ret;
}
#if 0
template <typename PARENT_TYPE, typename STATE_TYPE>
std::map<typename combination_type<PARENT_TYPE, STATE_TYPE>::key_list, combination_type<PARENT_TYPE, STATE_TYPE>>
sum_over_and_normalize(const combination_type<PARENT_TYPE, STATE_TYPE>& comb, const std::vector<PARENT_TYPE>& parents)
......@@ -655,7 +692,7 @@ sum_over_and_normalize(const combination_type<PARENT_TYPE, STATE_TYPE>& comb, co
return ret;
}
#endif
template <typename PARENT_TYPE, typename STATE_TYPE>
combination_type<PARENT_TYPE, STATE_TYPE>
......@@ -668,6 +705,27 @@ fold(const std::map<typename combination_type<PARENT_TYPE, STATE_TYPE>::key_list
return ret;
}
/**
 * Project comb onto project_variables: the combination is grouped with
 * sum_over_dual() and the resulting map is collapsed with fold().
 *
 * When norm_variables is non-empty, comb is first passed through
 * normalize_over(comb, norm_variables); otherwise it is used as is.
 */
template <typename PARENT_TYPE, typename STATE_TYPE>
combination_type<PARENT_TYPE, STATE_TYPE>
project(const combination_type<PARENT_TYPE, STATE_TYPE>& comb, const std::vector<PARENT_TYPE>& project_variables, const std::vector<PARENT_TYPE>& norm_variables)
{
    /* Nothing to normalize over: project the combination directly. */
    if (norm_variables.empty()) {
        return fold(sum_over_dual(comb, project_variables));
    }
    return fold(sum_over_dual(normalize_over(comb, norm_variables), project_variables));
}
/**
 * Normalize comb over norm_variables, group the result with sum_over_dual()
 * on project_variables, and collapse the groups with fold().
 *
 * Unlike project(), the normalization step is applied unconditionally, even
 * when norm_variables is empty.
 */
template <typename PARENT_TYPE, typename STATE_TYPE>
combination_type<PARENT_TYPE, STATE_TYPE>
project_dual(const combination_type<PARENT_TYPE, STATE_TYPE>& comb, const std::vector<PARENT_TYPE>& project_variables, const std::vector<PARENT_TYPE>& norm_variables)
{
    const auto normalized = normalize_over(comb, norm_variables);
    const auto grouped = sum_over_dual(normalized, project_variables);
    return fold(grouped);
}
namespace Eigen {
template<typename PARENT_TYPE>
struct NumTraits<combination_type<PARENT_TYPE>> {
......
......@@ -650,7 +650,7 @@ struct pedigree_type {
int n = tree.add_node();
MSG_DEBUG("node=" << n << " ind=" << ind);
compute_generation(generation_name, n);
compute_LC(n);
/*compute_LC(n);*/
MSG_DEBUG_DEDENT;
}
break;
......@@ -662,8 +662,8 @@ struct pedigree_type {
int g1 = spawn_gamete("M", tree.ind2node(p1));
int n = tree.add_node(g1, g1);
MSG_DEBUG("node=" << n << " ind=" << ind);
compute_generation(generation_name, n);
compute_data_for_bn(n);
/*compute_generation(generation_name, n);*/
/*compute_data_for_bn(n);*/
MSG_DEBUG_DEDENT;
}
break;
......@@ -681,8 +681,8 @@ struct pedigree_type {
int g2 = spawn_gamete("F", n2);
int n = tree.add_node(g1, g2);
MSG_DEBUG("node=" << n << " ind=" << ind);
compute_generation(generation_name, n);
compute_data_for_bn(n);
/*compute_generation(generation_name, n);*/
/*compute_data_for_bn(n);*/
MSG_DEBUG_DEDENT;
}
break;
......
......@@ -277,7 +277,7 @@ struct pedigree_tree_type {
}
ancestor_node_list_type
cleanup_reentrants(int node)
cleanup_reentrants(int node) const
{
auto A = count_ancestors(node);
auto Ap1 = count_ancestors(m_nodes[node].p1);
......
......@@ -30,7 +30,7 @@ COV_OBJ=$(subst .cc,.cov.o,$(SRC))
ALL_BUT_MAIN_OBJ=$(subst main.o ,,$(OBJ))
DEBUG_OPTS=-ggdb
DEBUG_OPTS=-ggdb -O
#OPT_OPTS=-O3 -DEIGEN_NO_DEBUG -DNDEBUG
#DEBUG_OPTS=-ggdb -DNDEBUG
#OPT_OPTS=-O3 -DNDEBUG
......
......@@ -722,6 +722,13 @@ int main(int argc, char** argv)
MSG_DEBUG((*ped.get_gen(F1)));
size_t BCR = ped.crossing("BCR", F1, A); (void) BCR;
size_t BCL = ped.crossing("BCL", A, F1); (void) BCL;
factor_graph fg(ped);
MSG_DEBUG("" << fg);
bn_message_type marginals;
double delta = fg.run(marginals);
MSG_DEBUG("Marginal probabilities:");
MSG_DEBUG("" << marginals);
}
/* 2 */
......@@ -735,7 +742,7 @@ int main(int argc, char** argv)
* au reste de son/leur (futur ex) groupe
*/
pedigree_type ped;
ped.n_alleles = 2;
/*ped.n_alleles = 2;*/
/*ped.max_states = 4;*/
MSG_DEBUG("#####################################################################");
MSG_DEBUG("A");
......@@ -760,6 +767,39 @@ int main(int argc, char** argv)
size_t F3 = ped.selfing("F3", F2); (void) F3;
/*MSG_DEBUG((*ped.get_gen(F3)));*/
MSG_DEBUG("#####################################################################");
MSG_DEBUG("F4");
size_t F4 = ped.selfing("F4", F3);
/*MSG_DEBUG((*ped.get_gen(F4)));*/
MSG_DEBUG("#####################################################################");
MSG_DEBUG("F5");
size_t F5 = ped.selfing("F5", F4);
MSG_DEBUG("#####################################################################");
MSG_DEBUG("F6");
size_t F6 = ped.selfing("F6", F5); (void) F6;
size_t F6_2 = ped.selfing("F6", F5); (void) F6_2;
factor_graph fg(ped);
MSG_DEBUG("" << fg);
bn_message_type marginals;
/*fg.clear_evidence()*/
/*.evidence(ped.tree.ind2node(F6), {'a', 'a', 0, 0}, 0)*/
/*.evidence(ped.tree.ind2node(F6), {'a', 'b', 0, 0}, .5)*/
/*.evidence(ped.tree.ind2node(F6), {'b', 'a', 0, 0}, .5)*/
/*.evidence(ped.tree.ind2node(F6), {'b', 'b', 0, 0}, 0);*/
/*double delta = fg.run(marginals); (void) delta;*/
/*MSG_DEBUG("Marginal probabilities:");*/
/*MSG_DEBUG("" << marginals);*/
fg.clear_evidence()
.evidence(ped.tree.ind2node(F6), {'a', 'a', 0, 0}, 1)
.evidence(ped.tree.ind2node(F6), {'a', 'b', 0, 0}, 0)
.evidence(ped.tree.ind2node(F6), {'b', 'a', 0, 0}, 0)
.evidence(ped.tree.ind2node(F6), {'b', 'b', 0, 0}, 0)
.run(marginals);
MSG_DEBUG("Marginal probabilities:");
MSG_DEBUG("" << marginals);
#if 0
MSG_DEBUG("#####################################################################");
MSG_DEBUG("F4");
......@@ -860,6 +900,7 @@ int main(int argc, char** argv)
size_t F2_CP = ped.selfing("F2_CP", CP);
(void) F2_CP;
#if 0
ped.save("test_pedigree_save.dat");
pedigree_type ped2;
ped2.load("test_pedigree_save.dat");
......@@ -882,9 +923,18 @@ int main(int argc, char** argv)
MSG_DEBUG("LC eq: " << (ped2.LC == ped.LC));
MSG_DEBUG("factor_messages eq: " << (ped2.factor_messages == ped.factor_messages));
MSG_QUEUE_FLUSH();
#endif
factor_graph fg(ped);
MSG_DEBUG("" << fg);
bn_message_type marginals;
/*fg.evidence(ped.tree.ind2node(CP), {'a', 'c', 0, 0}, .5);*/
/*fg.evidence(ped.tree.ind2node(CP), {'b', 'c', 0, 0}, .5);*/
/*fg.evidence(ped.tree.ind2node(CP), {'a', 'd', 0, 0}, 0);*/
/*fg.evidence(ped.tree.ind2node(CP), {'b', 'd', 0, 0}, 0);*/
double delta = fg.run(marginals); (void) delta;
MSG_DEBUG("Marginal probabilities:");
MSG_DEBUG("" << marginals);
}
/* 8 */
......@@ -954,12 +1004,19 @@ int main(int argc, char** argv)
size_t F3_2 = ped.crossing("F3", F2_1, F2_2);
MSG_DEBUG("F4s");
size_t F4_1 = ped.crossing("F4", F3_1, F3_2); (void) F4_1;
/*size_t F4_2 = ped.crossing(F3_1, F3_2); (void) F4_2;*/
/*MSG_DEBUG("F5s");*/
/*size_t F5_1 = ped.crossing(F4_1, F4_2); (void) F5_1;*/
size_t F4_2 = ped.crossing("F4", F3_1, F3_2); (void) F4_2;
MSG_DEBUG("F5s");
size_t F5_1 = ped.crossing("F5", F4_1, F4_2); (void) F5_1;
size_t F5_2 = ped.crossing("F5", F4_1, F4_2); (void) F5_2;
/*size_t F5_3 = ped.crossing("F5", F4_1, F4_2); (void) F5_3;*/
factor_graph fg(ped);
MSG_DEBUG("" << fg);
bn_message_type marginals;
double delta = fg.run(marginals);
MSG_DEBUG("Marginal probabilities:");
MSG_DEBUG("" << marginals);
}
/* 64 */
......@@ -1165,9 +1222,29 @@ int main(int argc, char** argv)
}
/* 65536 */
#if 0
if (test_case())
{
pedigree_type ped;
size_t A = ped.ancestor("A");
size_t B = ped.ancestor("B");
size_t F1 = ped.crossing("F1", A, B);
size_t BC1 = ped.crossing("BC", A, F1);
size_t BC2 = ped.crossing("BC2", A, BC1);
size_t BC3_1 = ped.crossing("BC3", A, BC2);
size_t BC3_2 = ped.crossing("BC3", A, BC2);
size_t BC3_3 = ped.crossing("BC3", A, BC2);
factor_graph fg(ped);
MSG_DEBUG("" << fg);
bn_message_type marginals;
fg.clear_evidence()
.evidence(ped.tree.ind2node(BC3_1), {'a', 'a', 0, 0}, 0)
.evidence(ped.tree.ind2node(BC3_1), {'a', 'b', 0, 0}, 1)
;
double delta = fg.run(marginals); (void) delta;
MSG_DEBUG("Marginal probabilities:");
MSG_DEBUG("" << marginals);
#if 0
geno_matrix A = ancestor_matrix("A", 'a');
geno_matrix B = ancestor_matrix("B", 'b');
geno_matrix F1 = lump(A|B);
......@@ -1183,9 +1260,11 @@ int main(int argc, char** argv)
MSG_DEBUG("" << ones_4);
geno_matrix F3_1 = lump(F2_1 | ones_4);
geno_matrix F3_2 = lump(F2_2 | ones_4);
#endif
}
/* 131072 */
#if 0
if (test_case())
{
pedigree_type ped;
......@@ -1246,6 +1325,13 @@ int main(int argc, char** argv)
size_t F1_1 = ped.crossing("F1_1", BCa, BCc);
size_t F1_2 = ped.crossing("F1_2", BCc, BCf);
size_t F2 = ped.crossing("F2", F1_1, F1_2); (void) F2;
factor_graph fg(ped);
MSG_DEBUG("" << fg);
bn_message_type marginals;
double delta = fg.run(marginals);
MSG_DEBUG("Marginal probabilities:");
MSG_DEBUG("" << marginals);
}
if (0)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment