graphnode.h 111 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
/* Spell-QTL  Software suite for the QTL analysis of modern datasets.
 * Copyright (C) 2016,2017  Damien Leroux <damien.leroux@inra.fr>, Sylvain Jasson <sylvain.jasson@inra.fr>
 *
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 */

18
19
20
#ifndef _SPEL_BAYES_GRAPH_NODE_H_
#define _SPEL_BAYES_GRAPH_NODE_H_

21
#include "graphnode_base.h"
22

23
#include "../pedigree.h"
24

25
26
27
28
29
30
31
32
33
34
struct graph_type {
    node_vec rank;
    node_vec represented_by;
    std::vector<node_type> type;
    std::vector<colour_proxy> colour;
    std::vector<node_vec> neighbours_in;
    std::vector<node_vec> neighbours_out;
    std::vector<node_vec> inner_nodes;
    std::vector<var_vec> rules;
    var_vec variables;
35
    std::vector<var_vec> node_variables;
36
    std::map<variable_index_type, VariableIO> io;
37

38
39
    std::map<variable_index_type, node_index_type> interface_to_node;
    std::map<node_index_type, variable_index_type> node_to_interface;
40

41
42
43
    std::vector<std::shared_ptr<graph_type>> subgraphs;

    std::vector<std::shared_ptr<message_type>> tables;
44
    /*std::vector<message_type> state;*/
45
    std::map<var_vec, genotype_comb_type> domains;
46
    /* TODO remove joint_parent_domains and all related code */
47
    std::map<var_vec, genotype_comb_type> joint_parent_domains;
48
    std::map<variable_index_type, char> ancestor_letters;
49
    /*std::vector<compute_state_operation_type> compute_state_ops;*/
50
51
52
53
54
55
56

    const graph_type* parent;
    node_index_type index_in_parent;

    bool aggregate_cycles;
    bool generate_interfaces;

57
    size_t n_alleles;
58

59
60
61
    /* This does not have to be saved/loaded: it is temporary state used to compute the sequences of operations. */
    std::vector<var_vec> annotations;

62
63
    std::vector<bool> is_dh;

64
    graph_type& operator = (graph_type&& other) = delete;
65

66
    /* node index 0 is the trash bin */
67

68
    graph_type(const graph_type& other) = delete;
69

70
71
72
    /*
     * Default constructor: an empty graph for a single allele.
     * Delegates to the size_t constructor so the member initialisation lives in one place.
     */
    graph_type()
        : graph_type(static_cast<size_t>(1))
    {}
73

74
75
76
    /*
     * Builds an empty graph for n_al alleles.
     * Every per-node container starts with exactly one default entry, so that
     * slot 0 exists before any real node is added (node index 0 is the trash bin).
     * There is no parent graph, and both aggregate_cycles and generate_interfaces start enabled.
     */
    graph_type(size_t n_al)
        : rank(1)
        , represented_by(1)
        , type(1)
        , colour(1)
        , neighbours_in(1)
        , neighbours_out(1)
        , inner_nodes(1)
        , rules(1)
        , variables(1)
        , node_variables(1)
        , io()
        , interface_to_node()
        , node_to_interface()
        , subgraphs(1)
        , tables(1)
        , parent(nullptr)
        , index_in_parent(0)
        , aggregate_cycles(true)
        , generate_interfaces(true)
        , n_alleles(n_al)
        , annotations(1)
        , is_dh(1, false)
    {}
77

78
    /* True when the node is tagged Aggregate in the type table (the size of inner_nodes is not consulted). */
    bool is_aggregate(node_index_type node) const
    {
        const node_type& kind = type[node];
        return kind == Aggregate;
    }
79

80
81
82
83
84
85
86
87
88
89
90
91
92
    /*
     * True when the node is an aggregate and every one of its inner nodes is
     * an Interface. An aggregate without inner nodes also yields true.
     */
    bool
        is_compound_interface(node_index_type node) const
        {
            if (!is_aggregate(node)) {
                return false;
            }
            const auto& members = inner_nodes[node];
            return std::all_of(members.begin(), members.end(),
                               [this] (node_index_type inner) { return type[inner] == Interface; });
        }
93

94
95
96
97
98
99
100
101
102
103
104
105
106
    /*
     * True when the node is either a plain Interface node, or an aggregate
     * made exclusively of Interface nodes (see is_compound_interface).
     */
    bool
        is_interface(node_index_type node) const
        {
            if (is_aggregate(node)) {
                return is_compound_interface(node);
            }
            return type[node] == Interface;
        }
107

108
109
110
111
112
    /* A node is computable exactly when it is not an interface (plain or compound). */
    bool
        is_computable(node_index_type node) const
        {
            const bool interface_node = is_interface(node);
            return !interface_node;
        }
113
114
115

    /* Number of node slots, including the reserved slot 0 (taken from the rank table). */
    size_t size() const
    {
        return rank.size();
    }

116
    /*
     * Lists the nodes that represent themselves (represented_by[i] == i),
     * in increasing index order. Slot 0 is never reported.
     */
    node_vec
        active_nodes() const
        {
            node_vec self_represented;
            self_represented.reserve(represented_by.size());
            for (node_index_type node = 1; node < represented_by.size(); ++node) {
                if (represented_by[node] != node) {
                    continue;
                }
                self_represented.push_back(node);
            }
            return self_represented;
        }

129
130
    /*
     * Maps every node of vec through resolve(), drops the results that are
     * not strictly positive, then sorts the result and removes duplicates.
     */
    node_vec
        resolve_vector(const node_vec& vec) const
        {
            node_vec resolved;
            resolved.reserve(vec.size());
            for (node_index_type raw: vec) {
                const node_index_type target = resolve(raw);
                if (target > 0) {
                    resolved.push_back(target);
                }
            }
            sort_and_unique(resolved);
            return resolved;
        }

144
145
    /*
     * For each variable in inputs, looks up its node in interface_to_node and
     * resolves it; returns the sorted, duplicate-free list of the results.
     *
     * NOTE(review): every variable in inputs must be a key of interface_to_node.
     * The lookup result is dereferenced without checking it against end().
     */
    var_vec
        interface_nodes(var_vec inputs) const
        {
            var_vec nodes;
            for (variable_index_type variable: inputs) {
                const auto entry = interface_to_node.find(variable);
                nodes.push_back(resolve(entry->second));
            }
            sort_and_unique(nodes);
            return nodes;
        }

155
    std::vector<edge_type>
156
157
        active_edges() const
        {
158
            std::vector<edge_type> ret;
159
            for (node_index_type n: active_nodes()) {
160
                for (node_index_type o: nei_out(n)) {
161
                    ret.emplace_back(this, n, o);
162
163
164
165
166
                }
            }
            return ret;
        }

167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
    /* Resolved incoming neighbours of node n (sorted, without duplicates; see resolve_vector). */
    node_vec
        nei_in(node_index_type n) const
        {
            const node_vec& stored = neighbours_in[n];
            return resolve_vector(stored);
        }

    /* Resolved outgoing neighbours of node n (sorted, without duplicates; see resolve_vector). */
    node_vec
        nei_out(node_index_type n) const
        {
            const node_vec& stored = neighbours_out[n];
            return resolve_vector(stored);
        }

    /* Resolved incoming and outgoing neighbours of node n, combined with operator+. */
    node_vec
        all_nei(node_index_type n) const
        {
            node_vec incoming = nei_in(n);
            return incoming + nei_out(n);
        }

    /*
     * Removes the nin -> nout link from both adjacency tables.
     * As a side effect, the two affected neighbour lists are stored back in resolved form.
     * NOTE(review): nin and nout are compared as given, without being resolved first.
     */
    void
        remove_link(node_index_type nin, node_index_type nout)
        {
            node_vec& outputs_of_source = neighbours_out[nin];
            outputs_of_source = resolve_vector(outputs_of_source) - node_vec{nout};

            node_vec& inputs_of_target = neighbours_in[nout];
            inputs_of_target = resolve_vector(inputs_of_target) - node_vec{nin};
        }
191
192
193

    /*
     * Writes a description of node n to the debug output: rank, kind,
     * creation rule, representative, colour, raw and resolved neighbours,
     * inner nodes and variables. When node_domains has one entry per node
     * its entry is printed too, and when the node has several inner nodes
     * and a subgraph, the subgraph's active nodes are dumped with an indent.
     */
    void
        dump_node(node_index_type n)
        {
            MSG_DEBUG(
                '[' << rank[n] << "] "
                << (!is_interface(n) ? (inner_nodes[n][0] == n ? "FACTOR " : "AGGREGATE ") : "INTERFACE ")
                << n << std::endl
                << "  creation rule " << rules[n] << std::endl
                << "  represented by " << represented_by[n] << " (" << resolve(n) << ')' << std::endl
                << "  colour " << get_colour_impl(colour[n]) << std::endl
                << "  inputs " << neighbours_in[n] << " (" << resolve_vector(neighbours_in[n]) << ')' << std::endl
                << "  outputs " << neighbours_out[n] << " (" << resolve_vector(neighbours_out[n]) << ')' << std::endl
                << "  inner nodes " << inner_nodes[n] << std::endl
                << "  variable(s) " << variables_of(n) << std::endl
            );

            if (node_domains.size() == rank.size()) {
                MSG_DEBUG("  TABLE " << node_domains[n] << std::endl);
            }

            if (inner_nodes[n].size() > 1 && subgraphs[n]) {
                scoped_indent _("  | ");
                subgraphs[n]->dump_active();
            }

            MSG_DEBUG("");
        }

217
218
219
220
221
#define DUMP_SZ(_x) << " * " #_x " " << _x.size() << std::endl

    void
        dump_sizes() const
        {
222
            MSG_DEBUG(""
223
224
225
226
227
228
229
230
231
                DUMP_SZ(rank)
                DUMP_SZ(type)
                DUMP_SZ(rules)
                DUMP_SZ(colour)
                DUMP_SZ(variables)
                DUMP_SZ(inner_nodes)
                DUMP_SZ(neighbours_in)
                DUMP_SZ(neighbours_out)
                DUMP_SZ(represented_by)
232
                );
233
234
        }

235
236
237
    void
        dump()
        {
238
            MSG_DEBUG("ALL NODES");
239
            /*dump_sizes();*/
240
            for (node_index_type i = 1; i < rank.size(); ++i) {
241
242
243
244
245
246
247
                dump_node(i);
            }
        }

    void
        dump_active()
        {
248
            MSG_DEBUG("ACTIVE NODES");
249
            /*dump_sizes();*/
250
            for (node_index_type i: active_nodes()) {
251
252
253
254
                dump_node(i);
            }
        }

255
256
257
258
259
260
261
262
263
264
265
266
267
    void
        compute_ranks()
        {
            std::vector<bool> visited(rank.size(), false);
            compute_ranks(active_nodes(), visited);
        }

    /*
     * Recursive worker for compute_ranks().
     *
     * For every not-yet-visited node of nodes, first ranks its incoming
     * neighbours, then sets its own rank to 0 when it has no incoming
     * neighbour, or to 1 + the highest rank among them otherwise. This is
     * the same rule add_node() applies when a node is created.
     *
     * nodes    nodes to rank.
     * visited  per-node flags, indexed by node; updated in place so each
     *          node is processed at most once.
     */
    void
        compute_ranks(const node_vec& nodes, std::vector<bool>& visited)
        {
            for (node_index_type n: nodes) {
                if (visited[n]) { continue; }
                visited[n] = true;
                auto nin = nei_in(n);
                /* incoming neighbours must be ranked before this node */
                compute_ranks(nin, visited);
                rank[n] = 0;
                if (nin.size()) {
                    for (node_index_type i: nin) {
                        if (rank[n] < rank[i]) {
                            /* keep the highest rank seen (this used to be a self-assignment) */
                            rank[n] = rank[i];
                        }
                    }
                    ++rank[n];
                }
            }
        }

282
283
284
285
286
    typedef std::pair<size_t, var_vec> emitter_and_interface_type;
    struct compare_eai {
        bool operator () (const emitter_and_interface_type& e1, const emitter_and_interface_type& e2) const { return e1.first < e2.first || (e1.first == e2.first && e1.second < e2.second); }
    };

287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
    /* Maps an emitter_and_interface_type key to a size_t value, ordered by compare_eai. */
    using interface_map_type = std::map<emitter_and_interface_type, size_t, compare_eai>;

    /*
     * Creates the interface carrying the given variable set and returns its node index.
     *
     * - Exactly one variable: a single interface node with a fresh colour.
     * - Otherwise: one interface node per variable, wrapped in a new Aggregate
     *   node that represents them all and whose colour they share.
     *
     * No neighbour is attached to the interfaces created here.
     */
    node_index_type
        create_interface(const var_vec& varset)
        {
            if (varset.size() == 1) {
                node_index_type single = add_interface(node_vec{}, varset.front());
                colour[single] = create_colour();
                return single;
            }

            node_vec members;
            members.reserve(varset.size());
            for (variable_index_type variable: varset) {
                members.push_back(add_interface(node_vec{}, variable));
            }

            node_index_type compound = add_node(node_vec{}, node_vec{}, var_vec{}, create_colour(), Aggregate, members, -1);
            for (node_index_type member: members) {
                represented_by[member] = compound;
                colour[member] = colour[compound];
            }
            return compound;
        }

    /*
     * Creates an interface for the `%` combination of the variable sets of n1 and n2
     * (presumably the variables they share; confirm against operator%), gives it the
     * colour of n1, and records n1 as its input and n2 as its output.
     * The adjacency lists of n1 and n2 themselves are not touched here.
     */
    node_index_type
        create_interface_between(node_index_type n1, node_index_type n2)
        {
            node_index_type itf = create_interface(variables_of(n1) % variables_of(n2));
            colour[itf] = get_colour_impl(colour[n1]);
            neighbours_in[itf].push_back(n1);
            neighbours_out[itf].push_back(n2);
            return itf;
        }

322
    void
323
        rebuild_interface_between(node_index_type n1, node_index_type n2, interface_map_type& interface_map)
324
        {
325
326
327
328
329
            auto varset = variables_of(n1) % variables_of(n2);
            node_index_type& i = interface_map[{n1, varset}];
            if (i == 0) {
                i = create_interface_between(n1, n2);
            } else {
330
                neighbours_out[i] = neighbours_out[i] + node_vec{n2};
331
            }
332
333
334
335
336
337
            neighbours_in[n2] = nei_in(n2) + node_vec{i} - nei_in(i);
            neighbours_out[n1] = nei_out(n1) + node_vec{i} - nei_out(i);
            remove_link(n1, n2);
            /*filter_out_and_replace_by(neighbours_out[n1], node_vec{n2}, i);*/
            /*filter_out_and_replace_by(neighbours_in[n2], node_vec{n1}, i);*/
            /*MSG_DEBUG("REBUILD " << n1 << ' ' << i << ' ' << n2);*/
338
339
340
341
342
            /*dump_node(n1);*/
            /*dump_node(i);*/
            /*dump_node(n2);*/
        }

343
    void
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
        remove_nei(node_vec& neighbours, node_index_type n)
        {
            node_vec new_nei;
            new_nei.reserve(neighbours.size());
            auto N = resolve(n);
            for (auto x: resolve_vector(neighbours)) {
                if (x != N) {
                    new_nei.push_back(x);
                }
            }
            new_nei.swap(neighbours);
        }

    /*
     * Repeatedly looks for cycles and aggregates them until a full pass finds none.
     *
     * For each active node, every pair of its incoming neighbours is tested with
     * find_path_between_parents(); the first non-empty path found, completed with
     * the node itself, is handed to aggregate_path(). A pass keeps scanning the
     * remaining nodes of its snapshot after an aggregation, and another pass is
     * scheduled.
     *
     * Returns true when at least one path was aggregated.
     */
    bool
        suppress_all_cycles()
        {
            bool any_aggregated = false;
            bool found_cycle;
            do {
                found_cycle = false;
                for (node_index_type target: active_nodes()) {
                    auto parents = nei_in(target);
                    for (auto first = parents.begin(), last = parents.end(); first != last; ++first) {
                        std::list<node_index_type> cycle;
                        /* stop at the first partner that yields a path */
                        for (auto second = first + 1; second != last && cycle.empty(); ++second) {
                            cycle = find_path_between_parents(*first, *second, target);
                        }
                        if (cycle.empty()) {
                            continue;
                        }
                        MSG_DEBUG("FOUND CYCLE ENDING ON " << target << " === " << cycle);
                        cycle.push_back(target);
                        MSG_DEBUG("path to aggregate " << cycle);
                        aggregate_path(cycle);
                        found_cycle = true;
                        any_aggregated = true;
                        break;
                    }
                }
            } while (found_cycle);
            return any_aggregated;
        }

    void
        finalize_stage1(bool reconstruct_itf=true)
396
        {
397
            /*check_neighbours();*/
398
            debug_graph("finalize", 10);
399
            /*scoped_indent _("[stage1] ");*/
400
401
            /* search for cycles and aggregate them */
            interface_map_type interface_map;
402
            if (reconstruct_itf) {
403
404
                /*dump_active();*/
                auto edges = active_edges();
405
406
407
408
                for (const auto& e: edges) {
                    if (is_interface(e.first) ^ !is_interface(e.second)) {
                        /* By construction, it's a factor->factor edge, never an interface->interface edge */
                        rebuild_interface_between(e.first, e.second, interface_map);
409
410
411
                    }
                }
            }
412
            /*check_neighbours();*/
413
            debug_graph("finalize", 11);
414
415
            suppress_all_cycles();
            /*check_neighbours();*/
416
            debug_graph("finalize", 12);
417
418

            if (size()) {
419
420
421
422
                node_vec A = active_nodes();

                for (node_index_type n: A) {
                    if (is_aggregate(n) && !is_interface(n)) {
423
                        /*MSG_DEBUG("subgraphs.size() = " << subgraphs.size() << " n = " << n);*/
424
425
426
                        subgraphs[n] = subgraph(n);
                        /* protect all interfaced variables with this level */
                        var_vec var;
427
                        for (node_index_type ni: nei_in(n)) {
428
429
430
                            auto nv = variables_of(ni);
                            var.insert(var.end(), nv.begin(), nv.end());
                        }
431
                        for (node_index_type no: nei_out(n)) {
432
433
434
435
436
437
438
439
                            auto nv = variables_of(no);
                            var.insert(var.end(), nv.begin(), nv.end());
                        }
                        sort_and_unique(var);
                        for (auto v: var) {
                            subgraphs[n]->io[v] = Input | Output;
                        }
                    }
440
441
                    /* then optimize */
                    /*subgraphs[n]->optimize();*/
442
                }
443
                debug_graph("finalize", 13);
444
445
446
447
448
449
450
            }
        }

    void
        finalize_stage2(bool reconstruct_itf)
        {
            if (reconstruct_itf) {
451
                /* Create interfaces for all variables used in this layer not already represented in a top-level interface */
452
453
454
455
456
457
458
459
                var_vec all_variables;
                std::vector<bool> var_represented(1 + *std::max_element(variables.begin(), variables.end()), false);
                node_vec A = active_nodes();
                for (node_index_type n: A) {
                    if (is_interface(n)) {
                        for (variable_index_type v: variables_of(n)) {
                            var_represented[v] = true;
                        }
460
                    } else if (is_aggregate(n) && nei_in(n).size() == 0) {
461
462
                        var_vec vv;
                        for (node_index_type subn: subgraphs[n]->active_nodes()) {
463
                            if (subgraphs[n]->is_interface(subn) && subgraphs[n]->nei_in(subn).size() == 0) {
464
465
466
467
468
469
                                vv = vv + subgraphs[n]->variables_of(subn);
                            }
                        }
                        if (vv.size()) {
                            node_index_type i = create_interface(vv);
                            neighbours_out[i].push_back(n);
470
                            neighbours_in[n] = nei_in(n) + node_vec{i};
471
                            colour[i] = get_colour_impl(colour[n]);
472
473
474
475
                            for (variable_index_type v: vv) {
                                var_represented[v] = true;
                            }
                        }
476
477
                    }
                }
478
                debug_graph("finalize", 21);
479

480
481
482
483
484
                /* Discover input interfaces in order to merge distinct trees */
                /* Actually we need to discover their neighbours so we can sort-unique them by colours */
                if (parent != NULL) {
                    node_vec itf_nei;
                    for (node_index_type n: A) {
485
486
                        if (is_interface(n) && nei_in(n).size() == 0) {
                            itf_nei = itf_nei + nei_out(n);
487
488
489
490
                        } else if (!is_interface(n)) {
                            for (variable_index_type v: variables_of(n)) {
                                if (var_represented[v]) { continue; }
                                node_index_type i = add_interface(node_vec{n}, v);
491
                                nei_out(n).push_back(i);
492
493
494
495
496
497
498
499
500
                            }
                        }
                    }
                    std::sort(itf_nei.begin(), itf_nei.end(), [this] (node_index_type a, node_index_type b) { return colour[a] < colour[b]; });
                    itf_nei.erase(std::unique(itf_nei.begin(), itf_nei.end(), [this] (node_index_type a, node_index_type b) { return colour_equal(colour[a], colour[b]); }), itf_nei.end());

                    node_vec itf;
                    node_vec out, inner;
                    for (node_index_type nei: itf_nei) {
501
502
                        for (node_index_type n: nei_in(nei)) {
                            if (nei_in(n).size() == 0) {
503
                                itf.push_back(n);
504
                                out = out + nei_out(n);
505
506
507
508
509
                                inner = inner + inner_nodes[n];
                            }
                        }
                    }

510
511
512
513
514
515
516
517
518
519
520
                    /* now also add single interfaces that have no neighbours */

                    for (node_index_type i: A) {
                        if (all_nei(i).size() == 0) {
                            itf.push_back(i);
                            inner = inner + inner_nodes[i];
                        }
                    }
                    std::sort(itf.begin(), itf.end());
                    std::sort(inner.begin(), inner.end());

521
522
                    /* Merge inputs of distinct trees to make one single tree */
#if 1
523
524
525
526
                    /*MSG_DEBUG("Input interfaces " << itf);*/
                    /*for (auto i: itf) {*/
                        /*MSG_DEBUG(" * " << i << " has colour " << get_colour_impl(colour[i]));*/
                    /*}*/
527
528
                    /*std::sort(itf.begin(), itf.end(), [this] (node_index_type a, node_index_type b) { return colour[a] < colour[b]; });*/
                    /*itf.erase(std::unique(itf.begin(), itf.end(), [this] (node_index_type a, node_index_type b) { return colour_equal(colour[a], colour[b]); }), itf.end());*/
529
530
531
532
                    /*MSG_DEBUG("Colour unique input interfaces " << itf);*/
                    /*for (auto i: itf) {*/
                        /*MSG_DEBUG(" * " << i << " has colour " << get_colour_impl(colour[i]));*/
                    /*}*/
533
534
535
536
                    if (itf.size() > 1) {
                        /* FIXME merge interface IIF factor colours are different. Otherwise this creates a cycle. */
                        node_index_type agr = add_node(node_vec{}, out, var_vec{}, create_colour(), Aggregate, inner, -1);
                        node_vec va = {agr};
537
538
539
                        for (node_index_type i: itf) {
                            represented_by[i] = agr;
                        }
540
                        for (node_index_type n: out) {
541
                            neighbours_in[n] = nei_in(n) - itf + va;
542
543
544
545
546
                            if (!colour_equal(colour[n], colour[agr])) {
                                assign_colour_impl(colour[n], colour[agr]);
                            }
                        }
                    }
547
                    debug_graph("finalize", 22);
548
#endif
549
550
551
552
553
554
555
556
557
558
                /*} else {*/
                    /* FIXME we shouldn't have to recompute this here */
                    A = active_nodes();
                    for (node_index_type n: A) {
                        if (is_interface(n)) {
                            for (variable_index_type v: variables_of(n)) {
                                var_represented[v] = true;
                            }
                        }
                    }
559
560
561
                    for (node_index_type n: A) {
                        if (!is_interface(n)) {
                            var_vec all_output;
562
                            for (node_index_type o: nei_out(n)) {
563
564
565
566
567
568
569
570
                                all_output = all_output + variables_of(o);
                            }
                            for (variable_index_type v: variables_of(n)) {
                                if (var_represented[v] || std::find(all_output.begin(), all_output.end(), v) == all_output.end()) { continue; }
                                node_index_type i = add_interface(node_vec{n}, v);
                                neighbours_out[n].push_back(i);
                                colour[i] = colour[n];
                            }
571
572
                        }
                    }
573
                    debug_graph("finalize", 23);
574
                }
575
            }
576
577
            /* unbox single-factor layers */
            for (node_index_type n: active_nodes()) {
578
                for (node_index_type o: nei_out(n)) {
579
580
581
582
583
                    if (!colour_equal(colour[n], colour[o])) {
                        assign_colour_impl(colour[o], colour[n]);
                    }
                }
            }
584
#if 1
585
586
587
588
589
590
591
592
593
594
595
596
597
            for (node_index_type n: active_nodes()) {
                if (type[n] == Aggregate && !is_interface(n)) {
                    size_t count_factors = 0;
                    node_index_type factor_node = (node_index_type) -1;
                    for (node_index_type i: inner_nodes[n]) {
                        if (type[i] == Factor) {
                            ++count_factors;
                            factor_node = i;
                        }
                    }
                    MSG_DEBUG("Aggregate #" << n << " has " << count_factors << " factors");
                    if (count_factors == 1) {
                        represented_by[factor_node] = factor_node;
598
599
                        neighbours_in[factor_node] = nei_in(n);
                        neighbours_out[factor_node] = nei_out(n);
600
601
                        represented_by[n] = factor_node;
                        subgraphs[n].reset();
602
603
604
605
606
607
                        for (node_index_type i: nei_in(n)) {
                            neighbours_out[i] = nei_out(i);
                        }
                        for (node_index_type o: nei_out(n)) {
                            neighbours_in[o] = nei_in(o);
                        }
608
609
610
611
                    }
                }
            }
#endif
612
        }
613

614
615
616
617
618
619
    /* For every active aggregate node that is not an interface, gather the
     * variable sets of its input neighbours.
     * The gathered list is local to each iteration and is neither stored nor
     * returned. */
    void
        finalize_import_inputs()
        {
            for (node_index_type node: active_nodes()) {
                if (!is_aggregate(node) || is_interface(node)) {
                    continue;
                }
                std::vector<var_vec> input_vars;
                for (node_index_type source: nei_in(node)) {
                    input_vars.push_back(variables_of(source));
                }
            }
        }
626

627
628
629
630
631
632
    void
        finalize(bool reconstruct_itf=true)
        {
            scoped_indent _("[finalize] ");
            finalize_stage1(reconstruct_itf);
            finalize_stage2(reconstruct_itf);
633
            update_all_ranks();
634
635
636
637
638
            /*if (reconstruct_itf) {*/
                /*finalize_import_inputs();*/
            /*}*/
        }

639
640
641
642
    /* Append one node to every per-node table and return its index.
     *   in, out : initial input / output neighbour lists (copied)
     *   rule    : rule stored for the node
     *   col     : colour stored for the node
     *   t       : node type
     *   inner   : inner nodes; when empty the node is its own single inner node
     *   var     : variable stored for the node
     * The rank is 0 without inputs, otherwise one more than the highest rank
     * among the inputs. The node starts as its own representative, with empty
     * subgraph, table, annotation and node-variable entries, and is_dh false.
     * The neighbour lists of the nodes named in in/out are NOT modified here. */
    node_index_type
        add_node(const node_vec& in, const node_vec& out,
                 const var_vec& rule, colour_proxy col, node_type t,
                 const node_vec& inner, variable_index_type var)
        {
            /* the new node takes the next free index */
            const node_index_type new_index = rank.size();

            /* derive the rank from the inputs before touching any table */
            size_t new_rank = 0;
            if (in.size()) {
                for (node_index_type source: in) {
                    new_rank = std::max(new_rank, rank[source]);
                }
                ++new_rank;
            }

            /* topology */
            neighbours_out.emplace_back(out);
            neighbours_in.emplace_back(in);
            rank.push_back(new_rank);
            represented_by.push_back(new_index);

            /* payload */
            rules.emplace_back(rule);
            colour.emplace_back(col);
            type.push_back(t);
            variables.push_back(var);
            node_variables.emplace_back();

            /* per-node slots that start out empty */
            subgraphs.emplace_back();
            tables.emplace_back();
            annotations.emplace_back();
            is_dh.push_back(false);

            if (!inner.size()) {
                inner_nodes.emplace_back(node_vec{new_index});
            } else {
                inner_nodes.emplace_back(inner);
            }
            return new_index;
        }

    /* Return the node currently standing for variable var.
     * A variable already present in interface_to_node yields the resolved
     * representative of its mapped node. Otherwise a fresh producer-less
     * interface is created when generate_interfaces is set, and
     * (node_index_type) -1 is returned when it is not. */
    node_index_type
        resolve_interface(variable_index_type var)
        {
            auto known = interface_to_node.find(var);
            if (known != interface_to_node.end()) {
                return resolve(known->second);
            }
            if (!generate_interfaces) {
                return (node_index_type) -1;
            }
            return add_interface(node_vec{}, var);
        }

    /* Create an Interface node for variable var, with the nodes in producer as
     * its inputs and a freshly created colour, and return its index.
     * The var <-> node mapping is recorded only when var has no mapping yet,
     * or unconditionally when force is true.
     * The new node is appended to neighbours_out of every producer. */
    node_index_type
        add_interface(const node_vec& producer, variable_index_type var, bool force=false)
        {
            const node_index_type itf = add_node(producer, node_vec{}, var_vec{}, create_colour(), Interface, node_vec{}, var);

            const bool must_register = force || interface_to_node.count(var) == 0;
            if (must_register) {
                interface_to_node[var] = itf;
                node_to_interface[itf] = var;
            }

            for (node_index_type source: producer) {
                neighbours_out[source].push_back(itf);
            }
            return itf;
        }

    /* Create a Factor node producing variable var, with the given rule and
     * colour, and return its index.
     * The factor is created without any input or output neighbour; the
     * two- and three-variable add_factor overloads fill neighbours_in/out
     * after calling this one.
     * When generate_interfaces is set, an Interface node for var is also
     * created as the factor's output (same colour) and var is mapped to that
     * interface; otherwise var is mapped to the factor itself.
     * An existing mapping for var is overwritten in both cases. */
    node_index_type
        add_factor(const var_vec& rule, colour_proxy col,
                 variable_index_type var)
        {
            /* Disabled: automatic wiring of the factor to the nodes carrying
             * the variables of its rule. Re-enabling it requires a local
             * `node_vec in`, passed to add_node() below and mirrored into
             * neighbours_out[] of each input.
            if (generate_interfaces) {
                for (variable_index_type v: rule) {
                    in.push_back(resolve_interface(v));
                }
            } else {
                for (variable_index_type v: rule) {
                    auto it = interface_to_node.find(v);
                    if (it != interface_to_node.end()) {
                        in.push_back(resolve(it->second));
                    }
                }
            }
            sort_and_unique(in);
            */
            node_index_type ret = add_node(node_vec{}, node_vec{}, rule, col, Factor, node_vec{}, var);
            if (generate_interfaces) {
                /* colour[ret] is passed by value, so growing the colour table
                 * inside add_node cannot invalidate it */
                node_index_type i = add_node(node_vec{ret}, node_vec{}, var_vec{}, colour[ret], Interface, node_vec{}, var);
                interface_to_node[var] = i;
                node_to_interface[i] = var;
                neighbours_out[ret].push_back(i);
            } else {
                interface_to_node[var] = ret;
                node_to_interface[ret] = var;
            }
            return ret;
        }

    /* Create a factor for variable var with an empty rule and a freshly
     * created colour; returns the index of the factor node. */
    node_index_type
        add_factor(variable_index_type var)
        {
            return add_factor(var_vec{}, create_colour(), var);
        }

753
754
    /* Follow the represented_by chain starting at n until reaching a node
     * that represents itself, and return that node. */
    node_index_type
        resolve(node_index_type n) const
        {
            for (node_index_type next = represented_by[n]; next != n; next = represented_by[n]) {
                n = next;
            }
            return n;
        }

760
761
    /* Create a factor producing variable var from the single parent variable
     * v1, and return the index of the factor node.
     *
     * Without generate_interfaces, an unknown v1 (no entry in
     * interface_to_node) yields a parentless factor whose rule is still set
     * to {v1}. With generate_interfaces, resolve_interface(v1) supplies the
     * parent node, creating an interface if needed.
     *
     * The new factor inherits the parent's colour, is linked to the parent in
     * both directions, and all ranks are recomputed. */
    node_index_type
        add_factor(variable_index_type v1, variable_index_type var)
        {
            node_index_type p1r;
            if (generate_interfaces) {
                p1r = resolve_interface(v1);
            } else {
                auto it = interface_to_node.find(v1);
                if (it == interface_to_node.end()) {
                    node_index_type orphan = add_factor(var);
                    /* address the rule by index rather than rules.back() so it
                     * always targets the factor that was just created */
                    rules[orphan] = {v1};
                    return orphan;
                }
                p1r = resolve(it->second);
            }
            node_index_type ret = add_factor(var_vec{v1}, colour[p1r], var);
            neighbours_in[ret] = {p1r};
            /* append instead of assigning: the parent may already feed other
             * nodes, and those edges must survive (the two-parent overload
             * appends as well) */
            neighbours_out[p1r].push_back(ret);
            compute_ranks();
            return ret;
        }

785
    void
786
        aggregate_path(std::list<node_index_type>& path)
787
788
        {
            parents_of_max_rank aggr_first;
789
            while (path.size() > 2 && ((aggr_first = find_parents_of_max_rank(path)), !aggregate(*aggr_first.p1, *aggr_first.p2, *aggr_first.child, path, aggr_first.p1))) {
790
791
792
793
794
795
                path.erase(aggr_first.p1);
                path.erase(aggr_first.p2);
                path.erase(aggr_first.child);
            }
        }

796
797
    /* Create a factor producing variable var from the two parent variables v1
     * and v2, and return the index of the factor node.
     *
     * Steps:
     *  1. look for an active Factor whose variables contain both v1 and v2;
     *  2. pick the parent nodes p1r/p2r:
     *     - without generate_interfaces: the common factor if one was found,
     *       otherwise the resolved nodes mapped to v1 and v2; when one of them
     *       is unmapped the call degrades to the single-parent overload on the
     *       other variable (the rule is still set to {v1, v2});
     *     - with generate_interfaces: if a common factor was found, an output
     *       of that factor whose variables contain both v1 and v2, or else a
     *       new Aggregate wrapping two new interfaces for v1 and v2; with no
     *       common factor, the resolved interfaces of v1 and v2;
     *  3. create the factor with the colour of p1r and link it to its parents;
     *     two distinct parents sharing a colour are flagged as a cycle,
     *     otherwise their colours are merged through assign_colour_impl();
     *  4. on a cycle, when aggregate_cycles is set, compute a path between the
     *     parents and pass it to aggregate_path();
     *  5. recompute all ranks. */
    node_index_type
        add_factor(variable_index_type v1,variable_index_type v2, variable_index_type var)
        {
            /*MSG_DEBUG("add_factor(" << v1 << ", " << v2 << ", " << var << ')');*/
            node_index_type p1r, p2r;
            /* first, search for a factor/aggregate containing both parents */
            var_vec sorted_rule = v1 < v2 ? var_vec{v1, v2} : var_vec{v2, v1};
            bool found = false;
            node_index_type common = 0;
            for (auto n: active_nodes()) {
                if (type[n] != Factor) {
                    continue;
                }
                auto varset = variables_of(n);
                auto result = sorted_rule % varset;
                MSG_DEBUG("searching for rule " << sorted_rule << " in #" << n << ' ' << varset << " => " << result);
                if (result == sorted_rule) {
                    MSG_DEBUG("Found!");
                    found = true;
                    common = n;
                    break;
                }
            }
            if (!generate_interfaces) {
                if (found) {
                    p1r = p2r = common;
                } else {
                    auto it = interface_to_node.find(v1);
                    if (it == interface_to_node.end()) {
                        /* v1 has no node: fall back to a factor on v2 only */
                        node_index_type ret = add_factor(v2, var);
                        rules[ret] = {v1, v2};
                        return ret;
                    } else {
                        p1r = resolve(it->second);
                    }
                    it = interface_to_node.find(v2);
                    if (it == interface_to_node.end()) {
                        /* v2 has no node: fall back to a factor on v1 only */
                        node_index_type ret = add_factor(v1, var);
                        rules[ret] = {v1, v2};
                        return ret;
                    } else {
                        p2r = resolve(it->second);
                    }
                }
            } else {
                if (found) {
                    /* reuse an output of the common factor that already
                     * carries both variables, if there is one */
                    found = false;
                    node_index_type agr;
                    for (auto n: nei_out(common)) {
                        if (sorted_rule % variables_of(n) == sorted_rule) {
                            agr = n;
                            found = true;
                            break;
                        }
                    }
                    if (!found) {
                        /* none yet: expose v1 and v2 through two new
                         * interfaces represented by a new aggregate */
                        node_vec common_fac = {common};
                        node_index_type i1, i2;
                        i1 = add_interface(common_fac, v1);
                        i2 = add_interface(common_fac, v2);
                        agr = add_node(common_fac, node_vec{}, var_vec{}, colour[common_fac.front()], Aggregate, node_vec{i1, i2}, -1);
                        represented_by[i1] = represented_by[i2] = agr;
                    }
                    p1r = p2r = agr;
                } else {
                    p1r = resolve_interface(v1);
                    p2r = resolve_interface(v2);
                }

                /*p1r = resolve_interface(v1);*/
                /*p2r = resolve_interface(v2);*/
            }

            MSG_DEBUG("p1r=" << p1r << " p2r=" << p2r);

            /* keep the parents ordered by node index */
            if (p1r > p2r) {
                std::swap(p1r, p2r);
            }

            node_index_type ret;

            if (p1r != (node_index_type) -1) {
                ret = add_factor(var_vec{v1, v2}, colour[p1r], var);

                bool cycle = false;
                if (p1r != p2r) {
                    neighbours_in[ret] = {p1r, p2r};
                    neighbours_out[p1r].push_back(ret);
                    neighbours_out[p2r].push_back(ret);
                    if (colour_equal(colour[p1r], colour[p2r])) {
                        /* cycle! */
                        MSG_DEBUG("cycle!");
                        cycle = true;
                    } else {
                        MSG_DEBUG("no cycle.");
                        assign_colour_impl(colour[p1r], colour[p2r]);
                    }
                } else {
                    neighbours_in[ret] = {p1r};
                    neighbours_out[p1r].push_back(ret);
                }
                rank[ret] = std::max(rank[p1r], rank[p2r]) + 1;
                dump();
                MSG_QUEUE_FLUSH();

                if (cycle && aggregate_cycles) {
                    MSG_DEBUG("search a path between " << p1r << " and " << p2r);
                    std::vector<bool> visited(rank.size(), false);
                    visited[ret] = true;
                    /* the result of find_shortest_path() is only logged; the
                     * path that gets aggregated is the one returned by
                     * find_path_between_parents() below */
                    auto path = find_shortest_path(node_vec{p1r, p2r},
                            [&, this](node_index_type n) { return nei_in(n) + nei_out(n); },
                            [&] (node_index_type n) { return visited[n]; },
                            [&] (node_index_type n) { visited[n] = true; });
                    path.push_back(ret);
                    MSG_DEBUG("New algo path: " << path);
                    path = find_path_between_parents(p1r, p2r, ret);
                    MSG_DEBUG("Old algo path: " << path);
                    /*dump_active();*/
                    MSG_DEBUG("Found path: "; for (size_t n: path) { std::cout << ' ' << n; } std::cout);
                    aggregate_path(path);
                }
            } else {
                ret = add_factor(var_vec{v1, v2}, create_colour(), var);
            }

            compute_ranks();
            /*dump();*/
            dump_active();
            return ret;
        }
929

930
931
932
    /* Register ancestor id: assign it a letter and append its genotypes to
     * domains[{id}].
     * The letter is 'a' offset by the number of entries already present in
     * ancestor_letters. One element is appended per allele index al in
     * [0, n_alleles), carrying {letter, letter, al, al} for id with
     * coefficient 1. */
    void
        add_ancestor(variable_index_type id)
        {
            char letter = ancestor_letters.size() + 'a';
            ancestor_letters[id] = letter;

            auto& combinations = domains[{id}].m_combination;
            for (char al = 0; al < (char) n_alleles; ++al) {
                combinations.emplace_back(genotype_comb_type::element_type{{{id, {letter, letter, al, al}}}, 1.});
            }
        }
941

942
943
944
    /* Add individual id as the offspring of parents p1 and p2; returns the
     * index of the factor node created by the two-parent add_factor().
     * No factor table is computed here. */
    node_index_type
        add_cross(variable_index_type p1, variable_index_type p2, variable_index_type id)
        {
            return add_factor(p1, p2, id);
        }
954

955
956
957
    /* Add individual id as derived from the single parent p1 and flag the
     * resulting factor node in is_dh; returns the index of that node.
     * No factor table is computed here. */
    node_index_type
        add_dh(variable_index_type p1, variable_index_type id)
        {
            const node_index_type factor = add_factor(p1, id);
            is_dh[factor] = true;
            return factor;
        }
967

968
969
970
    /* Add individual id as derived from the single parent p1; returns the
     * index of the factor node created by the single-parent add_factor().
     * Unlike add_dh(), the is_dh flag of the node is left untouched.
     * No factor table is computed here. */
    node_index_type
        add_selfing(variable_index_type p1, variable_index_type id)
        {
            return add_factor(p1, id);
        }

    bool
981
        find_ascending_path(node_index_type p1, node_index_type p2, node_vec& path, std::vector<bool>& visited)
982
        {
983
            /*MSG_DEBUG("path from " << p1 << " to " << p2);*/
984
985
986
            if (p1 == p2) {
                return true;
            }
987
            for (node_index_type i: nei_in(p1)) {
988
                i = resolve(i);
989
990
991
992
993
                if (visited[i]) {
                    return false;
                }
                visited[i] = true;
                if (find_ascending_path(i, p2, path, visited)) {
994
995
996
997
998
999
1000
                    path.push_back(i);
                    return true;
                }
            }
            return false;
        }

1001
1002
1003
1004
1005
1006
    /* Build a fresh set of per-node flags: one entry per node currently in
     * the rank table, all cleared. */
    std::vector<bool>
        create_visited()
        {
            std::vector<bool> flags(rank.size(), false);
            return flags;
        }

1007
1008
    /* Return a chain of nodes joining p1 and p2.
     * The search starts from whichever node has the strictly higher rank and
     * ascends towards the other through find_ascending_path(); on success the
     * starting node is appended to the nodes found. When the ranks are equal
     * or no ascending path exists, the result is simply {p1, p2}. */
    node_vec
        find_aggregate_chain(node_index_type p1, node_index_type p2)
        {
            node_vec chain;
            auto seen_from_p1 = create_visited();
            auto seen_from_p2 = create_visited();

            bool linked = false;
            if (rank[p1] > rank[p2]) {
                linked = find_ascending_path(p1, p2, chain, seen_from_p1);
                if (linked) {
                    chain.push_back(p1);
                }
            } else if (rank[p2] > rank[p1]) {
                linked = find_ascending_path(p2, p1, chain, seen_from_p2);
                if (linked) {
                    chain.push_back(p2);
                }
            }
            if (!linked) {
                chain = {p1, p2};
            }
            return chain;
        }

    bool
1027
        filter_out(node_vec& vec, const node_vec& aggr)
1028
        {
1029
            node_vec tmp(vec.size());
1030
1031
1032
1033
1034
1035
1036
            auto it = std::set_difference(vec.begin(), vec.end(), aggr.begin(), aggr.end(), tmp.begin());
            tmp.resize(it - tmp.begin());
            tmp.swap(vec);
            return tmp.size() != vec.size();
        }

    void
1037
        filter_out_and_replace_by(node_vec& vec, const node_vec& aggr, node_index_type new_node)
1038
        {
1039
            /*MSG_DEBUG("filter_out_and_replace_by(" << vec << ", " << aggr << ", " << new_node << ')');*/
1040
1041
1042
1043
1044
1045
1046
            /*if (filter_out(vec, aggr)) {*/
                /*vec.push_back(new_node);*/
            /*}*/
            node_vec tmp;
            tmp.reserve(vec.size());
            for (auto n: vec) {
                node_index_type nr = resolve(n);
1047
                /*MSG_DEBUG("" << nr << " vs " << aggr);*/
1048
1049
1050
1051
1052
1053
1054
                if (std::find(aggr.begin(), aggr.end(), nr) == aggr.end()) {
                    tmp.push_back(n);
                }
            }
            sort_and_unique(tmp);
            if (tmp.size() != vec.size()) {
                tmp.push_back(new_node);
1055
            }
1056
            vec.swap(tmp);
1057
            /*MSG_DEBUG(" => " << vec);*/
1058
1059
1060
        }

    bool
1061
        update_rank(node_index_type node)
1062
        {
1063
            node_index_type r = 0;
1064
1065
            if (nei_in(node).size()) {
                for (node_index_type i: nei_in(node)) {
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
                    if (rank[i] > r) {
                        r = rank[i];
                    }
                }
                ++r;
            }
            if (rank[node] != r) {
                rank[node] = r;
                return true;
            }
            return false;
        }

    void
1080
        propagate_update_rank(node_index_type source)
1081
        {
1082
1083
            auto nout = nei_out(source);
            std::deque<node_index_type> stack(nout.begin(), nout.end());
1084
            while (stack.size()) {
1085
                node_index_type n = stack.front();
1086
1087
                stack.pop_front();
                if (update_rank(n)) {
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
                    nout = nei_out(n);
                    stack.insert(stack.end(), nout.begin(), nout.end());
                }
            }
        }

    void
        update_all_ranks()
        {
            std::deque<node_index_type> stack;
            std::vector<bool> visited(rank.size(), false);
            for (node_index_type n: active_nodes()) {
                if (nei_in(n).size() == 0) {
                    rank[n] = 0;
                    auto out = nei_out(n);
                    stack.insert(stack.end(), out.begin(), out.end());
                    visited[n] = true;
                }
            }
            while (stack.size()) {
                node_index_type n = stack.front();
                stack.pop_front();
                if (!visited[n]) {
                    update_rank(n);
                    auto out = nei_out(n);
                    stack.insert(stack.end(), out.begin(), out.end());
                    visited[n] = true;
1115
1116
1117
1118
1119
1120
1121
                }
            }
        }

    void
        check_neighbours()
        {
1122
1123
            for (node_index_type i : active_nodes()) {
                for (node_index_type n: nei_in(i)) {
1124
                    if (represented_by[n] != n) {
1125
                        MSG_DEBUG("error neighbour_in[" << i << "] " << n << " is represented by " << represented_by[n]);
1126
                    }
1127
1128
1129
1130
                    auto tmp = nei_out(n);
                    if (std::find(tmp.begin(), tmp.end(), i) == tmp.end()) {
                        MSG_DEBUG("error neighbour_in[" << i << "] " << n << " doesn't have " << i << " as a neighbour_out " << tmp);
                    }