bayes.cc 12.3 KB
Newer Older
1
2
3
4
5
6
7
8
#include <string>
#include <iostream>
#include <iomanip>
#include <fstream>
#include "script_parser.h"
#include "malloc.h"
#include "computations.h"
#include "chrono.h"
9
#include "input/read_mark.h"
Damien Leroux's avatar
Damien Leroux committed
10
#include "pedigree.h"
11
12
13

#include "bayes.h"

14
15
#include "bayes/factor_var2.h"

16
#include "generation_rs.h"
17

18
19
20
21
22
23
24
25
26
27
28
29
#include <boost/dynamic_bitset.hpp>


/* COMMAND-LINE HANDLING */

#include "commandline.h"

/* testing the implementation of bra|ket */
#include "bracket.h"



30
31
/* All-bits-set size_t value (the largest representable size_t).
 * NOTE(review): presumably a "no variable" marker; its users are not visible here. */
size_t novar = static_cast<size_t>(-1);

32
/*
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
namespace direction_data {
#define V1 0, 1, 2, 3
#define V2 4, 5, 6, 7
#define V3 8, 9, 10, 11
    std::vector<std::vector<size_t>>
        normalize = {{V1}, {V2}, {V3}};

#define V11 {0, 0}, {1, 1}, {2, 2}, {3, 3}
#define V21 {4, 0}, {5, 1}, {6, 2}, {7, 3}
#define V22 {4, 4}, {5, 5}, {6, 6}, {7, 7}
#define V32 {8, 4}, {9, 5}, {10, 6}, {11, 7}
    std::vector<std::vector<std::pair<size_t, size_t>>>
        product_mapping_binary = {{V21}, {V11}},
        product_mapping_ternary = {{V21, V32}, {V11, V32}, {V11, V22}};

    std::vector<std::vector<size_t>>
49
50
        sum_mapping_binary = {{V2}, {V1}},
        sum_mapping_ternary = {{V2, V3}, {V1, V3}, {V1, V2}};
51
52

}
53
*/
54
55


56

57
58
59
60
61
62
63
64
65
66
67
68
/*
 * Print a whole design: a "Design:" header line followed by one line per
 * generation, in the map's key order. Each line is flushed (std::endl).
 * Every mapped pointer is dereferenced, so none may be null.
 */
inline std::ostream& operator << (std::ostream& os, const std::map<std::string, generation_rs*>& d)
{
    os << "Design:" << std::endl;
    for (const auto& name_and_gen: d) {
        os << (*name_and_gen.second) << std::endl;
    }
    return os;
}


69
std::map<std::string, std::vector<int>>
70
pedigree_families(const std::vector<pedigree_item>& pedigree, std::map<std::string, generation_rs*>& design, std::map<size_t, const generation_rs*>& ped_gen)
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
{
    std::map<int, std::string> family_by_id;
    std::string fname;
    auto new_gen = [&]() { return design.find(fname) == design.end(); };
    char haplo = 'a';
    for (const auto& pi: pedigree) {
        if (pi.is_ancestor()) {
            fname = family_by_id[pi.id] = pi.gen_name;
            if (new_gen()) {
                std::vector<std::pair<char, char>> alleles = {{0, 0}};
                design[fname] = generation_rs::ancestor(fname, haplo, alleles);
                ++haplo;
            }
        } else if (pi.is_self()) {
            fname = family_by_id[pi.id] = MESSAGE('S' << family_by_id[pi.p1]);
            if (new_gen()) {
                design[fname] = design[family_by_id[pi.p1]]->selfing(fname);
            }
        } else if (pi.is_dh()) {
            fname = family_by_id[pi.id] = MESSAGE("D" << family_by_id[pi.p1]);
            if (new_gen()) {
                design[fname] = design[family_by_id[pi.p1]]->to_doubled_haploid(fname);
            }
        } else if (pi.is_cross()) {
            fname = family_by_id[pi.id] = MESSAGE('(' << family_by_id[pi.p1] << '*' << family_by_id[pi.p2] << ')');
            if (new_gen()) {
                design[fname] = design[family_by_id[pi.p1]]->crossing(fname, design[family_by_id[pi.p2]]);
            }
        }
100
        ped_gen[pi.id] = design[fname];
101
102
103
104
105
    }
    std::map<std::string, std::vector<int>> families;
    for (const auto& kv: family_by_id) {
        families[kv.second].push_back(kv.first);
    }
106
    MSG_DEBUG(design);
107
108
109
110
111
112
113
114
    return families;
}


struct pedigree_bayesian_network {
    bayesian_network bn;
    std::map<std::string, std::vector<size_t>> evidence_by_gen_name;
    std::map<std::string, generation_rs*> generations;
115
    std::map<const generation_rs*, SparseMatrix<double>> geno_to_state;
116
    std::map<std::string, std::vector<int>> families;
117
118
119
120
121
122
123
124
125
126
127
    std::map<size_t, const generation_rs*> gen_by_id;
    std::vector<pedigree_item> pedigree;
    std::map<int, size_t> vars;
    struct vtsh_item {
        size_t var;
        size_t p1;
        size_t p2;
        const generation_rs* gen;
        vtsh_item(size_t v, size_t a, size_t b, const generation_rs* g) : var(v), p1(a), p2(b), gen(g) {}
    };
    std::vector<vtsh_item> var_to_state_helper;
128
129
130
131
    double tol;
    pedigree_bayesian_network(size_t n_parents, size_t n_alleles, double noise, double tolerance)
        : bn(n_parents, n_alleles, noise)
        , evidence_by_gen_name()
132
133
134
135
136
        , generations()
        , geno_to_state()
        , families()
        , pedigree()
        , var_to_state_helper()
137
138
139
        , tol(tolerance)
    {}

140
    std::map<std::string, std::vector<VectorXd>>
141
142
143
144
145
146
147
148
        run(std::function<VectorXd(const std::string&, size_t)> get_obs, size_t verbosity=0) const
        {
            auto comp = bn.instance();
            for (const auto& gv: evidence_by_gen_name) {
                for (size_t i = 0; i < gv.second.size(); ++i) {
                    comp.evidence(gv.second[i]) = get_obs(gv.first, i);
                }
            }
149
            chrono::start("LoopyBP");
150
            comp.run(tol, verbosity);
151
152
153
154
155
156
157
            chrono::stop("LoopyBP");
            std::map<std::string, std::vector<VectorXd>> ret;
#if 1
            chrono::start("Geno->State");
            std::map<size_t, VectorXd> state_vectors;
            /*std::map<const generation_rs*, SparseMatrix<double>> lincombs;*/
            /*for (const auto& kv: generations) {*/
158
                /*VectorLC lc = kv.second->design->lincomb();*/
159
160
161
                /*SparseMatrix<double> mat(*/
                /*lincombs*/
            /*}*/
162
            std::map<const generation_rs*, VectorLC> lincombs;
163
            for (const auto& kv: generations) {
164
                lincombs[kv.second] = kv.second->this_lincomb;
165
166
167
            }
            state_vectors[0] = VectorXd::Ones(1);
            for (const auto& vts: var_to_state_helper) {
168
169
                /*VectorLC lc = vts.gen->design->lincomb();*/
                auto par = vts.gen->design->get_parents();
170
171
172
173
174
                const auto& lc = lincombs[vts.gen];
                VectorXd v1(lc.size());
                VectorXd v2(lc.size());
                VectorXd p1 = state_vectors[vts.p1], p2 = state_vectors[vts.p2];
                for (int i = 0; i < lc.size(); ++i) {
175
176
                    /* FIXME the parent key should be {gen, #id} or #id. NOT gen only. */
                    v1(i) = lc(i, 0).apply({{par.first, p1}, {par.second, p2}});
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
                }
                /*v1 = (v1.array() == 0).select(VectorXd::Zero(v1.size()), VectorXd::Ones(v1.size()));*/
                const SparseMatrix<double>& g2s = geno_to_state.find(vts.gen)->second;
                VectorXd b = comp.parental_origin_belief(vts.var);
                VectorXd bp = g2s.transpose() * v1;
                VectorXd norm2 = (b.array() / (bp.array() == 0).select(VectorXd::Ones(bp.size()), bp).array()).matrix();
                v2 = g2s * norm2;
                /*v2 = g2s * b;*/
                /*VectorXd b01 = (b.array() == 0).select(VectorXd::Zero(b.size()), VectorXd::Ones(b.size()));*/
                /*VectorXd bs = (g2s.array().rowwise() * b01.array().transpose()).array().colwise().sum();*/
                /*VectorXd norm2 = g2s * bs;*/
                /*v2.array() /= (norm2.array() == 0).select(VectorXd::Ones(norm2.size()), norm2).array();*/
                VectorXd tmp = (v1.array() * v2.array()).matrix();
                double s = tmp.sum();
                state_vectors[vts.var] = tmp / (s ? s : 1.);
                /*MSG_DEBUG(vts.gen->name << ':' << vts.var);*/
                /*MSG_DEBUG("p1 " << p1.transpose());*/
                /*MSG_DEBUG("p2 " << p2.transpose());*/
                /*MSG_DEBUG("v1 " << v1.transpose());*/
                /*MSG_DEBUG("b " << b.transpose());*/
                /*MSG_DEBUG("bp " << bp.transpose());*/
                /*MSG_DEBUG("norm2 " << norm2.transpose());*/
                /*MSG_DEBUG("v2 " << v2.transpose());*/
                /*MSG_DEBUG(vts.var << "   " << state_vectors[vts.var].transpose());*/
            }
            std::map<std::string, size_t> sizes;
            for (const auto& pi: pedigree) {
                ++sizes[pi.gen_name];
205
            }
206
207
208
209
210
211
212
213
            for (const auto& kv: sizes) {
                ret[kv.first].reserve(kv.second);
            }
            for (const auto& pi: pedigree) {
                ret[pi.gen_name].emplace_back(state_vectors[vars.find(pi.id)->second]);
            }
            chrono::stop("Geno->State");
#endif
214
215
216
217
            return ret;
        }
};

218
219
220
221
222
223
224
225
226

/*
 * Kind of observation attached to a generation. make_bn() creates an allele
 * observation node for ODAllele and an ancestor observation node for any
 * other value.
 */
enum ObservationDomain {
    ODAncestor,
    ODAllele
};

/*
 * Description of an observation alphabet: one vector per symbol, plus the
 * domain the symbols belong to.
 * NOTE(review): not referenced in the visible part of this file.
 */
struct observation_type {
    /* Symbol character -> vector associated with it (presumably the evidence vector; confirm against users). */
    std::map<char, VectorXd> symbols;
    /* Whether the symbols describe ancestors or alleles. */
    ObservationDomain domain;
};


227
/*
 * Build the pedigree_bayesian_network for a pedigree.
 *
 * pedigree:   the individuals; parents must come before their offspring.
 * obs_gen:    generation name -> kind of observation. Every individual of a
 *             listed generation gets an observation node: an allele one for
 *             ODAllele, an ancestor one otherwise.
 * n_alleles, obs_noise: forwarded to the network constructor.
 * tolerance:  stored in the result and used by its run().
 *
 * The number of ancestor items in the pedigree is the network's parent count.
 * Returns the fully populated network, with its messages initialised.
 */
pedigree_bayesian_network
make_bn(const std::vector<pedigree_item>& pedigree, const std::map<std::string, ObservationDomain>& obs_gen, /*const std::string& query_gen,*/ size_t n_alleles=1, double obs_noise=0, double tolerance=1.e-10)
{
    size_t n_par = 0;
    for (const auto& item: pedigree) {
        if (item.is_ancestor()) {
            n_par += 1;
        }
    }

    pedigree_bayesian_network ret(n_par, n_alleles, obs_noise, tolerance);
    ret.pedigree = pedigree;
    ret.families = pedigree_families(pedigree, ret.generations, ret.gen_by_id);
    ret.var_to_state_helper.reserve(pedigree.size());

    bayesian_network& bn = ret.bn;
    auto& obs_map = ret.evidence_by_gen_name;

    for (const auto& item: pedigree) {
        /* References into a std::map stay valid across the insertions below. */
        size_t& var = ret.vars[item.id];
        if (item.is_ancestor()) {
            var = bn.ancestor();
        } else if (item.is_self()) {
            var = bn.selfing(ret.vars[item.p1]);
        } else if (item.is_dh()) {
            var = bn.dhing(ret.vars[item.p1]);
        } else if (item.is_cross()) {
            var = bn.crossing(ret.vars[item.p1], ret.vars[item.p2]);
        }

        /* TODO: create state vector nodes here */
        const auto observed = obs_gen.find(item.gen_name);
        if (observed != obs_gen.end()) {
            auto& evidence_nodes = obs_map[item.gen_name];
            if (observed->second == ODAllele) {
                evidence_nodes.push_back(bn.allele_obs(var));
            } else {
                evidence_nodes.push_back(bn.ancestor_obs(var));
            }
        }

        /* The parent lookups intentionally use operator[]: a parent id with
         * no variable yet (e.g. the parents of an ancestor) maps to variable 0. */
        ret.var_to_state_helper.emplace_back(var, ret.vars[item.p1], ret.vars[item.p2], ret.gen_by_id[item.id]);
        bn.varname(var) = MESSAGE(item.gen_name << ':' << item.id);
    }

    const size_t dom = n_par * n_par;

    /* One matrix per generation: row i has a single non-zero entry, in the
     * column selected by the two ancestor letters of row label i, equal to
     * 1 over the number of rows sharing that column. */
    for (const auto& name_and_gen: ret.generations) {
        const generation_rs* g = name_and_gen.second;
        const auto& labels = g->main_process().row_labels;
        SparseMatrix<double>& M = ret.geno_to_state[g];
        M.resize(labels.size(), dom);

        /* Column of row i: ancestor letters are offsets from 'a'. */
        auto column_of = [&](size_t i) -> size_t {
            int r = labels[i].first.ancestor - 'a';
            int c = labels[i].second.ancestor - 'a';
            return r + n_par * c;
        };

        /* Count the rows per column, then invert the counts. Columns that no
         * row uses end up infinite but are never read. */
        std::vector<double> weight(dom, 0.);
        for (size_t i = 0; i < labels.size(); ++i) {
            weight[column_of(i)] += 1.;
        }
        for (double& w: weight) {
            w = 1. / w;
        }

        std::vector<Eigen::Triplet<double>> entries;
        entries.reserve(labels.size());
        for (size_t i = 0; i < labels.size(); ++i) {
            const size_t col = column_of(i);
            entries.emplace_back(i, col, weight[col]);
        }
        M.setFromTriplets(entries.begin(), entries.end());
    }

    bn.init_messages();
    return ret;
}


312
313
314
315
316
317
318
319
320
321
322
323
324
/*
 * Total mass of m: sums m over every one of its dimensions with sum_over()
 * and returns element 0 of the result.
 */
template <typename MATRIX_TYPE>
double mass(const MATRIX_TYPE& m)
{
    const size_t n_dims = m.dimensions().size();
    /* The list 0, 1, ..., n_dims - 1, i.e. every dimension of m. */
    std::vector<size_t> all_dims(n_dims);
    for (size_t k = 0; k < n_dims; ++k) {
        all_dims[k] = k;
    }
    auto summed = sum_over(m, all_dims);
    auto reader = summed.accessor();
    return reader.get(0);
}


325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
/* Print a generation pointer as the generation's name. g must not be null. */
inline std::ostream& operator << (std::ostream& os, const impl::generation_rs* g)
{
    os << g->name;
    return os;
}


/*
 * Debug-print a generation and its linear combination.
 *
 * g:      the generation to dump.
 * expand: when non-empty, applied to the linear combination with operator|
 *         before printing.
 *
 * Entries whose m_combination is empty are not printed; each run of them is
 * summarised by a "...skipped N" line just before the next printed entry.
 * A run at the very end is therefore not reported.
 */
void dump_gen(const generation_rs* g, const std::map<const impl::generation_rs*, VectorLC>& expand)
{
    auto combination = g->this_lincomb;
    if (!expand.empty()) {
        combination = combination | expand;
    }
    const auto& row_labels = g->main_process().row_labels;
    MSG_DEBUG((*g));
    size_t skip = 0;
    for (int row = 0; row < combination.size(); ++row) {
        if (!combination(row).m_combination.size()) {
            ++skip;
            continue;
        }
        if (skip) {
            MSG_DEBUG("...skipped " << skip);
            skip = 0;
        }
        MSG_DEBUG(row_labels[row] << "   " << combination(row));
    }
}