graphnode.h 110 KB
Newer Older
1
2
3
#ifndef _SPEL_BAYES_GRAPH_NODE_H_
#define _SPEL_BAYES_GRAPH_NODE_H_

4
#include "graphnode_base.h"
5

6
#include "../pedigree.h"
7

8
9
10
11
12
13
14
15
16
17
struct graph_type {
    node_vec rank;
    node_vec represented_by;
    std::vector<node_type> type;
    std::vector<colour_proxy> colour;
    std::vector<node_vec> neighbours_in;
    std::vector<node_vec> neighbours_out;
    std::vector<node_vec> inner_nodes;
    std::vector<var_vec> rules;
    var_vec variables;
18
    std::vector<var_vec> node_variables;
19
    std::map<variable_index_type, VariableIO> io;
20

21
22
    std::map<variable_index_type, node_index_type> interface_to_node;
    std::map<node_index_type, variable_index_type> node_to_interface;
23

24
25
26
    std::vector<std::shared_ptr<graph_type>> subgraphs;

    std::vector<std::shared_ptr<message_type>> tables;
27
    /*std::vector<message_type> state;*/
28
    std::map<var_vec, genotype_comb_type> domains;
29
    /* TODO suppress joint_parent_domains and all afferent code */
30
    std::map<var_vec, genotype_comb_type> joint_parent_domains;
31
    std::map<variable_index_type, char> ancestor_letters;
32
    /*std::vector<compute_state_operation_type> compute_state_ops;*/
33
34
35
36
37
38
39

    const graph_type* parent;
    node_index_type index_in_parent;

    bool aggregate_cycles;
    bool generate_interfaces;

40
    size_t n_alleles;
41

42
43
44
    /* THIS DOESN'T HAVE  TO BE SAVED/LOADED. TEMPORARY STATE IN ORDER TO COMPUTE THE SEQUENCES OF OPERATIONS. */
    std::vector<var_vec> annotations;

45
46
    std::vector<bool> is_dh;

47
    graph_type& operator = (graph_type&& other) = delete;
48

49
    /* node index 0 is the trash bin */
50

51
    graph_type(const graph_type& other) = delete;
52

53
54
55
    graph_type()
        : rank(1), represented_by(1), type(1), colour(1), neighbours_in(1), neighbours_out(1), inner_nodes(1), rules(1), variables(1), node_variables(1), io(), interface_to_node(), node_to_interface(), subgraphs(1), tables(1), /*state(1),*/ parent(nullptr), index_in_parent(0),aggregate_cycles(true), generate_interfaces(true), n_alleles(1), annotations(1), is_dh(1, false)
    {}
56

57
58
59
    graph_type(size_t n_al)
        : rank(1), represented_by(1), type(1), colour(1), neighbours_in(1), neighbours_out(1), inner_nodes(1), rules(1), variables(1), node_variables(1), io(), interface_to_node(), node_to_interface(), subgraphs(1), tables(1), /*state(1),*/ parent(nullptr), index_in_parent(0), aggregate_cycles(true), generate_interfaces(true), n_alleles(n_al), annotations(1), is_dh(1, false)
    {}
60

61
    /* True when the node's type tag is Aggregate; inner_nodes is not consulted. */
    bool is_aggregate(node_index_type node) const
    {
        const node_type kind = type[node];
        return kind == Aggregate;
    }
62

63
64
65
66
67
68
69
70
71
72
73
74
75
    /* True when `node` is an aggregate whose inner nodes are all Interface nodes
     * (vacuously true for an aggregate without inner nodes). */
    bool
        is_compound_interface(node_index_type node) const
        {
            if (!is_aggregate(node)) {
                return false;
            }
            for (node_index_type member: inner_nodes[node]) {
                if (type[member] != Interface) {
                    return false;
                }
            }
            return true;
        }
76

77
78
79
80
81
82
83
84
85
86
87
88
89
    /* True for a plain Interface node, or for an aggregate made only of
     * Interface nodes (the is_compound_interface() test). */
    bool
        is_interface(node_index_type node) const
        {
            return is_aggregate(node)
                   ? is_compound_interface(node)
                   : type[node] == Interface;
        }
90

91
92
93
94
95
    /* A node is computable exactly when it is not an interface. */
    bool
        is_computable(node_index_type node) const
        {
            const bool interface_node = is_interface(node);
            return !interface_node;
        }
96
97
98

    /* Number of node slots, including the reserved slot 0. */
    size_t size() const
    {
        return rank.size();
    }

99
    /* Indices of the nodes that represent themselves (not folded into another
     * node), in increasing order; slot 0 is never included. */
    node_vec
        active_nodes() const
        {
            node_vec active;
            active.reserve(represented_by.size());
            node_index_type n = 0;
            while (++n < represented_by.size()) {
                if (represented_by[n] == n) {
                    active.push_back(n);
                }
            }
            return active;
        }

112
113
    /* Map every node of `vec` to its current representative, drop entries
     * that do not resolve to a positive index (0 is the trash bin), then sort
     * and deduplicate. */
    node_vec
        resolve_vector(const node_vec& vec) const
        {
            node_vec resolved;
            resolved.reserve(vec.size());
            for (node_index_type n: vec) {
                const node_index_type rep = resolve(n);
                if (rep <= 0) {
                    continue;
                }
                resolved.push_back(rep);
            }
            sort_and_unique(resolved);
            return resolved;
        }

127
128
    /* For each variable in `inputs`, the resolved node currently exposing it
     * as an interface; the result is sorted and deduplicated.
     * Variables with no registered interface are skipped. (Previously
     * find()->second was dereferenced unconditionally, which is undefined
     * behaviour when the key is absent.) */
    var_vec
        interface_nodes(var_vec inputs) const
        {
            var_vec ret;
            ret.reserve(inputs.size());
            for (variable_index_type v: inputs) {
                auto it = interface_to_node.find(v);
                if (it == interface_to_node.end()) {
                    continue;
                }
                ret.push_back(resolve(it->second));
            }
            sort_and_unique(ret);
            return ret;
        }

138
    std::vector<edge_type>
139
140
        active_edges() const
        {
141
            std::vector<edge_type> ret;
142
            for (node_index_type n: active_nodes()) {
143
                for (node_index_type o: nei_out(n)) {
144
                    ret.emplace_back(this, n, o);
145
146
147
148
149
                }
            }
            return ret;
        }

150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
    /* Resolved, deduplicated inputs of node `n`. */
    node_vec
        nei_in(node_index_type n) const
        {
            const node_vec& raw = neighbours_in[n];
            return resolve_vector(raw);
        }

    /* Resolved, deduplicated outputs of node `n`. */
    node_vec
        nei_out(node_index_type n) const
        {
            const node_vec& raw = neighbours_out[n];
            return resolve_vector(raw);
        }

    /* Resolved inputs and outputs of `n`, combined with node_vec's operator+. */
    node_vec
        all_nei(node_index_type n) const
        {
            node_vec inputs = nei_in(n);
            return inputs + nei_out(n);
        }

    /* Drop the edge nin -> nout from both endpoints. Both lists are also
     * re-resolved and deduplicated as a side effect. */
    void
        remove_link(node_index_type nin, node_index_type nout)
        {
            node_vec& outs = neighbours_out[nin];
            outs = resolve_vector(outs) - node_vec{nout};
            node_vec& ins = neighbours_in[nout];
            ins = resolve_vector(ins) - node_vec{nin};
        }
174
175
176

    /* Print node `n` through MSG_DEBUG. The output covers:
     *  - rank and kind: INTERFACE; FACTOR when n is its own first inner node;
     *    AGGREGATE otherwise;
     *  - creation rule, representative and colour;
     *  - raw and resolved neighbours, inner nodes and variables.
     * Also prints its table when node_domains has one entry per node. When n
     * has several inner nodes and a sub-graph, the sub-graph's active nodes are
     * printed, indented by "  | ". */
    void
        dump_node(node_index_type n)
        {
            const char* kind = is_interface(n)
                               ? "INTERFACE "
                               : (n == inner_nodes[n][0] ? "FACTOR " : "AGGREGATE ");
            MSG_DEBUG(
                   '[' << rank[n] << "] " << kind << n << std::endl
                << "  creation rule " << rules[n] << std::endl
                << "  represented by " << represented_by[n] << " (" << resolve(n) << ')' << std::endl
                << "  colour " << get_colour_impl(colour[n]) << std::endl
                << "  inputs " << neighbours_in[n] << " (" << resolve_vector(neighbours_in[n]) << ')' << std::endl
                << "  outputs " << neighbours_out[n] << " (" << resolve_vector(neighbours_out[n]) << ')' << std::endl
                << "  inner nodes " << inner_nodes[n] << std::endl
                << "  variable(s) " << variables_of(n) << std::endl
                );
            const bool has_tables = node_domains.size() == rank.size();
            if (has_tables) {
                MSG_DEBUG("  TABLE " << node_domains[n] << std::endl);
            }
            const bool nested = inner_nodes[n].size() > 1 && subgraphs[n];
            if (nested) {
                scoped_indent indent_guard("  | ");
                subgraphs[n]->dump_active();
            }
            MSG_DEBUG("");
        }

200
201
202
203
204
#define DUMP_SZ(_x) << " * " #_x " " << _x.size() << std::endl

    void
        dump_sizes() const
        {
205
            MSG_DEBUG(""
206
207
208
209
210
211
212
213
214
                DUMP_SZ(rank)
                DUMP_SZ(type)
                DUMP_SZ(rules)
                DUMP_SZ(colour)
                DUMP_SZ(variables)
                DUMP_SZ(inner_nodes)
                DUMP_SZ(neighbours_in)
                DUMP_SZ(neighbours_out)
                DUMP_SZ(represented_by)
215
                );
216
217
        }

218
219
220
    void
        dump()
        {
221
            MSG_DEBUG("ALL NODES");
222
            /*dump_sizes();*/
223
            for (node_index_type i = 1; i < rank.size(); ++i) {
224
225
226
227
228
229
230
                dump_node(i);
            }
        }

    void
        dump_active()
        {
231
            MSG_DEBUG("ACTIVE NODES");
232
            /*dump_sizes();*/
233
            for (node_index_type i: active_nodes()) {
234
235
236
237
                dump_node(i);
            }
        }

238
239
240
241
242
243
244
245
246
247
248
249
250
    void
        compute_ranks()
        {
            std::vector<bool> visited(rank.size(), false);
            compute_ranks(active_nodes(), visited);
        }

    /* Depth-first rank computation over `nodes` and, recursively, their inputs.
     * A node without (resolved) inputs gets rank 0; any other node gets one more
     * than the highest rank among its inputs. `visited` is indexed by node and
     * ensures each node is processed once.
     * Fix: the maximum was previously never taken (`rank[n] = rank[n]`), so every
     * node with inputs ended up with rank 1 regardless of depth. */
    void
        compute_ranks(const node_vec& nodes, std::vector<bool>& visited)
        {
            for (node_index_type n: nodes) {
                if (visited[n]) { continue; }
                visited[n] = true;
                auto nin = nei_in(n);
                /* Rank the inputs first so their ranks are up to date below. */
                compute_ranks(nin, visited);
                rank[n] = 0;
                if (nin.size()) {
                    for (node_index_type i: nin) {
                        if (rank[n] < rank[i]) {
                            rank[n] = rank[i];
                        }
                    }
                    ++rank[n];
                }
            }
        }

265
266
267
268
269
    typedef std::pair<size_t, var_vec> emitter_and_interface_type;
    struct compare_eai {
        bool operator () (const emitter_and_interface_type& e1, const emitter_and_interface_type& e2) const { return e1.first < e2.first || (e1.first == e2.first && e1.second < e2.second); }
    };

270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
    typedef std::map<emitter_and_interface_type, size_t, compare_eai> interface_map_type;

    /* Create an interface carrying `varset` and return the node that stands for it.
     * - One variable: a plain Interface node, then given a fresh colour.
     * - Otherwise: one Interface per variable, grouped under a new Aggregate.
     *   The Aggregate represents the members, and the members share its colour.
     * None of the new nodes has neighbours yet. */
    node_index_type
        create_interface(const var_vec& varset)
        {
            if (varset.size() == 1) {
                const node_index_type single = add_interface(node_vec{}, varset.front());
                colour[single] = create_colour();
                return single;
            }
            node_vec members;
            members.reserve(varset.size());
            for (variable_index_type v: varset) {
                members.push_back(add_interface(node_vec{}, v));
            }
            const node_index_type group = add_node(node_vec{}, node_vec{}, var_vec{}, create_colour(), Aggregate, members, -1);
            for (node_index_type m: members) {
                represented_by[m] = group;
                colour[m] = colour[group];
            }
            return group;
        }

    /* Create an interface for variables_of(n1) % variables_of(n2) (presumably the
     * shared variables) and colour it like n1. Only the interface's own lists
     * get n1 as input and n2 as output; n1's and n2's lists are left untouched. */
    node_index_type
        create_interface_between(node_index_type n1, node_index_type n2)
        {
            const auto shared = variables_of(n1) % variables_of(n2);
            const node_index_type itf = create_interface(shared);
            colour[itf] = get_colour_impl(colour[n1]);
            neighbours_in[itf].push_back(n1);
            neighbours_out[itf].push_back(n2);
            return itf;
        }

305
    void
306
        rebuild_interface_between(node_index_type n1, node_index_type n2, interface_map_type& interface_map)
307
        {
308
309
310
311
312
            auto varset = variables_of(n1) % variables_of(n2);
            node_index_type& i = interface_map[{n1, varset}];
            if (i == 0) {
                i = create_interface_between(n1, n2);
            } else {
313
                neighbours_out[i] = neighbours_out[i] + node_vec{n2};
314
            }
315
316
317
318
319
320
            neighbours_in[n2] = nei_in(n2) + node_vec{i} - nei_in(i);
            neighbours_out[n1] = nei_out(n1) + node_vec{i} - nei_out(i);
            remove_link(n1, n2);
            /*filter_out_and_replace_by(neighbours_out[n1], node_vec{n2}, i);*/
            /*filter_out_and_replace_by(neighbours_in[n2], node_vec{n1}, i);*/
            /*MSG_DEBUG("REBUILD " << n1 << ' ' << i << ' ' << n2);*/
321
322
323
324
325
            /*dump_node(n1);*/
            /*dump_node(i);*/
            /*dump_node(n2);*/
        }

326
    void
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
        remove_nei(node_vec& neighbours, node_index_type n)
        {
            node_vec new_nei;
            new_nei.reserve(neighbours.size());
            auto N = resolve(n);
            for (auto x: resolve_vector(neighbours)) {
                if (x != N) {
                    new_nei.push_back(x);
                }
            }
            new_nei.swap(neighbours);
        }

    /* Collapse cycles until none is left.
     * For each active node n, look (via find_path_between_parents) for a path
     * linking two of n's inputs, i.e. a cycle closing on n. When one is found,
     * n is appended to it and the result is handed to aggregate_path(); the
     * scan then moves on to the next node. Another full pass, starting from the
     * then-current active nodes, runs whenever a pass aggregated something.
     * Returns true if aggregate_path() was called at least once.
     * NOTE(review): within a pass the scan continues over the node list taken at
     * the start of that pass, even after an aggregation has folded some of those
     * nodes away — confirm this is intended. */
    bool
        suppress_all_cycles()
        {
            bool aggregated_any = false;
            for (bool rescan = true; rescan; ) {
                rescan = false;
                for (node_index_type n: active_nodes()) {
                    node_vec parents = nei_in(n);
                    for (auto a = parents.begin(); a != parents.end(); ++a) {
                        std::list<node_index_type> cycle;
                        for (auto b = a + 1; b != parents.end() && cycle.empty(); ++b) {
                            cycle = find_path_between_parents(*a, *b, n);
                        }
                        if (cycle.empty()) {
                            continue;
                        }
                        MSG_DEBUG("FOUND CYCLE ENDING ON " << n << " === " << cycle);
                        cycle.push_back(n);
                        MSG_DEBUG("path to aggregate " << cycle);
                        aggregate_path(cycle);
                        rescan = true;
                        aggregated_any = true;
                        break;
                    }
                }
            }
            return aggregated_any;
        }

    void
        finalize_stage1(bool reconstruct_itf=true)
379
        {
380
            /*check_neighbours();*/
381
            debug_graph("finalize", 10);
382
            /*scoped_indent _("[stage1] ");*/
383
384
            /* search for cycles and aggregate them */
            interface_map_type interface_map;
385
            if (reconstruct_itf) {
386
387
                /*dump_active();*/
                auto edges = active_edges();
388
389
390
391
                for (const auto& e: edges) {
                    if (is_interface(e.first) ^ !is_interface(e.second)) {
                        /* By construction, it's a factor->factor edge, never an interface->interface edge */
                        rebuild_interface_between(e.first, e.second, interface_map);
392
393
394
                    }
                }
            }
395
            /*check_neighbours();*/
396
            debug_graph("finalize", 11);
397
398
            suppress_all_cycles();
            /*check_neighbours();*/
399
            debug_graph("finalize", 12);
400
401

            if (size()) {
402
403
404
405
                node_vec A = active_nodes();

                for (node_index_type n: A) {
                    if (is_aggregate(n) && !is_interface(n)) {
406
                        /*MSG_DEBUG("subgraphs.size() = " << subgraphs.size() << " n = " << n);*/
407
408
409
                        subgraphs[n] = subgraph(n);
                        /* protect all interfaced variables with this level */
                        var_vec var;
410
                        for (node_index_type ni: nei_in(n)) {
411
412
413
                            auto nv = variables_of(ni);
                            var.insert(var.end(), nv.begin(), nv.end());
                        }
414
                        for (node_index_type no: nei_out(n)) {
415
416
417
418
419
420
421
422
                            auto nv = variables_of(no);
                            var.insert(var.end(), nv.begin(), nv.end());
                        }
                        sort_and_unique(var);
                        for (auto v: var) {
                            subgraphs[n]->io[v] = Input | Output;
                        }
                    }
423
424
                    /* then optimize */
                    /*subgraphs[n]->optimize();*/
425
                }
426
                debug_graph("finalize", 13);
427
428
429
430
431
432
433
            }
        }

    void
        finalize_stage2(bool reconstruct_itf)
        {
            if (reconstruct_itf) {
434
                /* Create interfaces for all variables used in this layer not already represented in a top-level interface */
435
436
437
438
439
440
441
442
                var_vec all_variables;
                std::vector<bool> var_represented(1 + *std::max_element(variables.begin(), variables.end()), false);
                node_vec A = active_nodes();
                for (node_index_type n: A) {
                    if (is_interface(n)) {
                        for (variable_index_type v: variables_of(n)) {
                            var_represented[v] = true;
                        }
443
                    } else if (is_aggregate(n) && nei_in(n).size() == 0) {
444
445
                        var_vec vv;
                        for (node_index_type subn: subgraphs[n]->active_nodes()) {
446
                            if (subgraphs[n]->is_interface(subn) && subgraphs[n]->nei_in(subn).size() == 0) {
447
448
449
450
451
452
                                vv = vv + subgraphs[n]->variables_of(subn);
                            }
                        }
                        if (vv.size()) {
                            node_index_type i = create_interface(vv);
                            neighbours_out[i].push_back(n);
453
                            neighbours_in[n] = nei_in(n) + node_vec{i};
454
                            colour[i] = get_colour_impl(colour[n]);
455
456
457
458
                            for (variable_index_type v: vv) {
                                var_represented[v] = true;
                            }
                        }
459
460
                    }
                }
461
                debug_graph("finalize", 21);
462

463
464
465
466
467
                /* Discover input interfaces in order to merge distinct trees */
                /* Actually we need to discover their neighbours so we can sort-unique them by colours */
                if (parent != NULL) {
                    node_vec itf_nei;
                    for (node_index_type n: A) {
468
469
                        if (is_interface(n) && nei_in(n).size() == 0) {
                            itf_nei = itf_nei + nei_out(n);
470
471
472
473
                        } else if (!is_interface(n)) {
                            for (variable_index_type v: variables_of(n)) {
                                if (var_represented[v]) { continue; }
                                node_index_type i = add_interface(node_vec{n}, v);
474
                                nei_out(n).push_back(i);
475
476
477
478
479
480
481
482
483
                            }
                        }
                    }
                    std::sort(itf_nei.begin(), itf_nei.end(), [this] (node_index_type a, node_index_type b) { return colour[a] < colour[b]; });
                    itf_nei.erase(std::unique(itf_nei.begin(), itf_nei.end(), [this] (node_index_type a, node_index_type b) { return colour_equal(colour[a], colour[b]); }), itf_nei.end());

                    node_vec itf;
                    node_vec out, inner;
                    for (node_index_type nei: itf_nei) {
484
485
                        for (node_index_type n: nei_in(nei)) {
                            if (nei_in(n).size() == 0) {
486
                                itf.push_back(n);
487
                                out = out + nei_out(n);
488
489
490
491
492
                                inner = inner + inner_nodes[n];
                            }
                        }
                    }

493
494
495
496
497
498
499
500
501
502
503
                    /* now also add single interfaces that have no neighbours */

                    for (node_index_type i: A) {
                        if (all_nei(i).size() == 0) {
                            itf.push_back(i);
                            inner = inner + inner_nodes[i];
                        }
                    }
                    std::sort(itf.begin(), itf.end());
                    std::sort(inner.begin(), inner.end());

504
505
                    /* Merge inputs of distinct trees to make one single tree */
#if 1
506
507
508
509
                    /*MSG_DEBUG("Input interfaces " << itf);*/
                    /*for (auto i: itf) {*/
                        /*MSG_DEBUG(" * " << i << " has colour " << get_colour_impl(colour[i]));*/
                    /*}*/
510
511
                    /*std::sort(itf.begin(), itf.end(), [this] (node_index_type a, node_index_type b) { return colour[a] < colour[b]; });*/
                    /*itf.erase(std::unique(itf.begin(), itf.end(), [this] (node_index_type a, node_index_type b) { return colour_equal(colour[a], colour[b]); }), itf.end());*/
512
513
514
515
                    /*MSG_DEBUG("Colour unique input interfaces " << itf);*/
                    /*for (auto i: itf) {*/
                        /*MSG_DEBUG(" * " << i << " has colour " << get_colour_impl(colour[i]));*/
                    /*}*/
516
517
518
519
                    if (itf.size() > 1) {
                        /* FIXME merge interface IIF factor colours are different. Otherwise this creates a cycle. */
                        node_index_type agr = add_node(node_vec{}, out, var_vec{}, create_colour(), Aggregate, inner, -1);
                        node_vec va = {agr};
520
521
522
                        for (node_index_type i: itf) {
                            represented_by[i] = agr;
                        }
523
                        for (node_index_type n: out) {
524
                            neighbours_in[n] = nei_in(n) - itf + va;
525
526
527
528
529
                            if (!colour_equal(colour[n], colour[agr])) {
                                assign_colour_impl(colour[n], colour[agr]);
                            }
                        }
                    }
530
                    debug_graph("finalize", 22);
531
#endif
532
533
534
535
536
537
538
539
540
541
                /*} else {*/
                    /* FIXME we shouldn't have to recompute this here */
                    A = active_nodes();
                    for (node_index_type n: A) {
                        if (is_interface(n)) {
                            for (variable_index_type v: variables_of(n)) {
                                var_represented[v] = true;
                            }
                        }
                    }
542
543
544
                    for (node_index_type n: A) {
                        if (!is_interface(n)) {
                            var_vec all_output;
545
                            for (node_index_type o: nei_out(n)) {
546
547
548
549
550
551
552
553
                                all_output = all_output + variables_of(o);
                            }
                            for (variable_index_type v: variables_of(n)) {
                                if (var_represented[v] || std::find(all_output.begin(), all_output.end(), v) == all_output.end()) { continue; }
                                node_index_type i = add_interface(node_vec{n}, v);
                                neighbours_out[n].push_back(i);
                                colour[i] = colour[n];
                            }
554
555
                        }
                    }
556
                    debug_graph("finalize", 23);
557
                }
558
            }
559
560
            /* unbox single-factor layers */
            for (node_index_type n: active_nodes()) {
561
                for (node_index_type o: nei_out(n)) {
562
563
564
565
566
                    if (!colour_equal(colour[n], colour[o])) {
                        assign_colour_impl(colour[o], colour[n]);
                    }
                }
            }
567
#if 1
568
569
570
571
572
573
574
575
576
577
578
579
580
            for (node_index_type n: active_nodes()) {
                if (type[n] == Aggregate && !is_interface(n)) {
                    size_t count_factors = 0;
                    node_index_type factor_node = (node_index_type) -1;
                    for (node_index_type i: inner_nodes[n]) {
                        if (type[i] == Factor) {
                            ++count_factors;
                            factor_node = i;
                        }
                    }
                    MSG_DEBUG("Aggregate #" << n << " has " << count_factors << " factors");
                    if (count_factors == 1) {
                        represented_by[factor_node] = factor_node;
581
582
                        neighbours_in[factor_node] = nei_in(n);
                        neighbours_out[factor_node] = nei_out(n);
583
584
                        represented_by[n] = factor_node;
                        subgraphs[n].reset();
585
586
587
588
589
590
                        for (node_index_type i: nei_in(n)) {
                            neighbours_out[i] = nei_out(i);
                        }
                        for (node_index_type o: nei_out(n)) {
                            neighbours_in[o] = nei_in(o);
                        }
591
592
593
594
                    }
                }
            }
#endif
595
        }
596

597
598
599
600
601
602
    /* For each active non-interface aggregate, gather the variable sets of its
     * inputs. NOTE(review): the gathered list is discarded, so this currently has
     * no effect; the call in finalize() is commented out. */
    void
        finalize_import_inputs()
        {
            for (node_index_type n: active_nodes()) {
                if (!is_aggregate(n) || is_interface(n)) {
                    continue;
                }
                std::vector<var_vec> inputs;
                for (node_index_type producer: nei_in(n)) {
                    inputs.push_back(variables_of(producer));
                }
            }
        }
609

610
611
612
613
614
615
    void
        finalize(bool reconstruct_itf=true)
        {
            scoped_indent _("[finalize] ");
            finalize_stage1(reconstruct_itf);
            finalize_stage2(reconstruct_itf);
616
            update_all_ranks();
617
618
619
620
621
            /*if (reconstruct_itf) {*/
                /*finalize_import_inputs();*/
            /*}*/
        }

622
623
624
625
    /* Append a node to every per-node array and return its index.
     *   in, out : initial input/output lists, stored as given. The neighbours'
     *             own lists are NOT updated here.
     *   rule    : creation rule (rules[])
     *   col     : colour handle
     *   t       : node type
     *   inner   : member nodes; when empty the node lists only itself
     *   var     : variable carried by the node (variables[])
     * Rank is 0 without inputs, otherwise one more than the highest input rank. */
    node_index_type
        add_node(const node_vec& in, const node_vec& out,
                 const var_vec& rule, colour_proxy col, node_type t,
                 const node_vec& inner, variable_index_type var)
        {
            const node_index_type id = rank.size();

            size_t new_rank = 0;
            if (!in.empty()) {
                size_t highest = 0;
                for (node_index_type producer: in) {
                    highest = std::max(highest, rank[producer]);
                }
                new_rank = highest + 1;
            }

            neighbours_out.emplace_back(out);
            neighbours_in.emplace_back(in);
            rules.emplace_back(rule);
            colour.emplace_back(col);
            rank.push_back(new_rank);
            type.push_back(t);
            variables.push_back(var);
            node_variables.emplace_back();
            represented_by.push_back(id);
            subgraphs.emplace_back();
            tables.emplace_back();
            annotations.emplace_back();
            is_dh.push_back(false);
            inner_nodes.emplace_back(inner.empty() ? node_vec{id} : inner);
            return id;
        }

    /* Resolved node exposing `var` as an interface. If none is registered:
     * create a fresh source interface when generate_interfaces is set,
     * otherwise return (node_index_type) -1. */
    node_index_type
        resolve_interface(variable_index_type var)
        {
            const auto found = interface_to_node.find(var);
            if (found != interface_to_node.end()) {
                return resolve(found->second);
            }
            return generate_interfaces ? add_interface(node_vec{}, var) : (node_index_type) -1;
        }

    /* Create an Interface node for `var` with a fresh colour. The node is fed by
     * every node in `producer`, and each producer gets it appended to its outputs.
     * The var <-> node mapping is recorded when `var` has none yet, or always
     * when `force` is set. Returns the new node. */
    node_index_type
        add_interface(const node_vec& producer, variable_index_type var, bool force=false)
        {
            const node_index_type itf = add_node(producer, node_vec{}, var_vec{}, create_colour(), Interface, node_vec{}, var);
            const bool already_mapped = interface_to_node.count(var) != 0;
            if (force || !already_mapped) {
                interface_to_node[var] = itf;
                node_to_interface[itf] = var;
            }
            for (node_index_type p: producer) {
                neighbours_out[p].push_back(itf);
            }
            return itf;
        }

    /* Add a Factor node for `var` with creation rule `rule` and colour `col`,
     * and return it.
     * - generate_interfaces set: an Interface node of the same colour is created
     *   downstream of the factor and registered as var's interface.
     * - otherwise: the factor itself is registered as var's interface.
     * The factor is created without inputs; the code that turned rule variables
     * into input nodes is disabled and kept below. */
    node_index_type
        add_factor(const var_vec& rule, colour_proxy col,
                 variable_index_type var)
        {
            node_vec in;
            /* Disabled input resolution:
            if (generate_interfaces) {
                for (variable_index_type v: rule) {
                    in.push_back(resolve_interface(v));
                }
            } else {
                for (variable_index_type v: rule) {
                    auto it = interface_to_node.find(v);
                    if (it != interface_to_node.end()) {
                        in.push_back(resolve(it->second));
                    }
                }
            }
            sort_and_unique(in);
            */
            const node_index_type factor = add_node(in, node_vec{}, rule, col, Factor, node_vec{}, var);
            for (node_index_type producer: in) {
                neighbours_out[producer].push_back(factor);
            }
            node_index_type exposed = factor;
            if (generate_interfaces) {
                exposed = add_node(node_vec{factor}, node_vec{}, var_vec{}, colour[factor], Interface, node_vec{}, var);
                neighbours_out[factor].push_back(exposed);
            }
            interface_to_node[var] = exposed;
            node_to_interface[exposed] = var;
            return factor;
        }

    /* Add a source factor for `var` (empty rule, fresh colour) and return it. */
    node_index_type
        add_factor(variable_index_type var)
        {
            return add_factor(var_vec{}, create_colour(), var);
        }

736
737
    /* Follow represented_by links from `n` to the node that represents itself. */
    node_index_type
        resolve(node_index_type n) const
        {
            node_index_type current = n;
            for (node_index_type next = represented_by[current]; next != current; next = represented_by[current]) {
                current = next;
            }
            return current;
        }

743
744
    /* Add a factor for `var` with the single parent variable `v1`, and return it.
     * The parent node p1r is v1's resolved interface. Exception: when
     * generate_interfaces is off and v1 has no registered node, a parentless
     * factor with rule {v1} is created instead.
     * The factor inherits p1r's colour, the edge p1r -> factor is added, and all
     * ranks are recomputed.
     * Fix: p1r's output list was overwritten (`= {ret}`), which silently dropped
     * edges to p1r's other children. The edge is now appended, like every other
     * linking site in this struct. */
    node_index_type
        add_factor(variable_index_type v1, variable_index_type var)
        {
            node_index_type p1r;
            if (!generate_interfaces) {
                auto it = interface_to_node.find(v1);
                if (it == interface_to_node.end()) {
                    auto ret = add_factor(var);
                    /* add_factor(var) created a single node, so rules.back() is its rule. */
                    rules.back() = {v1};
                    return ret;
                }
                p1r = resolve(it->second);
            } else {
                p1r = resolve_interface(v1);
            }
            node_index_type ret = add_factor(var_vec{v1}, colour[p1r], var);
            /* The new factor has no inputs yet, so assignment is equivalent to appending. */
            neighbours_in[ret] = {p1r};
            /* p1r may already feed other factors (e.g. siblings): keep those edges. */
            neighbours_out[p1r].push_back(ret);
            compute_ranks();
            return ret;
        }

768
    void
769
        aggregate_path(std::list<node_index_type>& path)
770
771
        {
            parents_of_max_rank aggr_first;
772
            while (path.size() > 2 && ((aggr_first = find_parents_of_max_rank(path)), !aggregate(*aggr_first.p1, *aggr_first.p2, *aggr_first.child, path, aggr_first.p1))) {
773
774
775
776
777
778
                path.erase(aggr_first.p1);
                path.erase(aggr_first.p2);
                path.erase(aggr_first.child);
            }
        }

779
780
    node_index_type
        add_factor(variable_index_type v1,variable_index_type v2, variable_index_type var)
        {
            /* Add a factor for `var` with the two parent variables v1 and v2 (used by add_cross).
             *
             * 1. Look for an active Factor node whose variables cover both v1 and v2
             *    (checked with operator% against the sorted rule).
             * 2. Resolve each parent to the node currently standing for it:
             *    - without interfaces: a common factor serves as both parents; a parent
             *      missing from interface_to_node makes the factor depend on the other
             *      parent only, with the rule then set back to {v1, v2};
             *    - with interfaces and a common factor: both parents are one Aggregate
             *      below that factor, reused if an out-neighbour already covers {v1, v2},
             *      otherwise built from two new interfaces on the common factor.
             * 3. Link the factor below its parent(s). If the two parents already have
             *    equal colours, the new edges close a cycle, which is aggregated along a
             *    path between the parents when aggregate_cycles is set; otherwise
             *    assign_colour_impl() is applied to the two colours.
             *
             * Returns the index of the new factor node.
             */
            /* "No node" sentinel. NOTE(review): presumably what resolve_interface() returns
             * for an unresolvable variable — confirm. */
            const node_index_type no_node = static_cast<node_index_type>(-1);
            node_index_type p1r, p2r;

            /* first, search for a factor/aggregate containing both parents */
            var_vec sorted_rule = v1 < v2 ? var_vec{v1, v2} : var_vec{v2, v1};
            bool found = false;
            node_index_type common = 0;
            for (auto n: active_nodes()) {
                if (type[n] != Factor) {
                    continue;
                }
                auto varset = variables_of(n);
                auto result = sorted_rule % varset;
                MSG_DEBUG("searching for rule " << sorted_rule << " in #" << n << ' ' << varset << " => " << result);
                if (result == sorted_rule) {
                    MSG_DEBUG("Found!");
                    found = true;
                    common = n;
                    break;
                }
            }

            if (!generate_interfaces) {
                if (found) {
                    p1r = p2r = common;
                } else {
                    auto it = interface_to_node.find(v1);
                    if (it == interface_to_node.end()) {
                        /* v1 unknown here: depend on v2 only, but keep both in the rule. */
                        node_index_type ret = add_factor(v2, var);
                        rules.back() = {v1, v2};
                        return ret;
                    } else {
                        p1r = resolve(it->second);
                    }
                    it = interface_to_node.find(v2);
                    if (it == interface_to_node.end()) {
                        /* v2 unknown here: depend on v1 only, but keep both in the rule. */
                        node_index_type ret = add_factor(v1, var);
                        rules.back() = {v1, v2};
                        return ret;
                    } else {
                        p2r = resolve(it->second);
                    }
                }
            } else {
                if (found) {
                    /* Reuse an aggregate below the common factor that already covers both parents. */
                    found = false;
                    node_index_type agr = 0;
                    for (auto n: nei_out(common)) {
                        if (sorted_rule % variables_of(n) == sorted_rule) {
                            agr = n;
                            found = true;
                            break;
                        }
                    }
                    if (!found) {
                        /* None yet: expose both parents as interfaces of the common factor
                         * and group them under a new Aggregate that represents them. */
                        node_vec common_fac = {common};
                        node_index_type i1 = add_interface(common_fac, v1);
                        node_index_type i2 = add_interface(common_fac, v2);
                        agr = add_node(common_fac, node_vec{}, var_vec{}, colour[common_fac.front()], Aggregate, node_vec{i1, i2}, -1);
                        represented_by[i1] = represented_by[i2] = agr;
                    }
                    p1r = p2r = agr;
                } else {
                    p1r = resolve_interface(v1);
                    p2r = resolve_interface(v2);
                }
            }

            MSG_DEBUG("p1r=" << p1r << " p2r=" << p2r);

            /* Order the parents so that p1r <= p2r. */
            if (p1r > p2r) {
                const node_index_type tmp = p1r;
                p1r = p2r;
                p2r = tmp;
            }

            node_index_type ret;

            /* Both parents must be real nodes before they are used as indices. Checking
             * only p1r is not enough: with an unsigned index type the sentinel compares
             * greatest and ends up in p2r after the ordering above. */
            if (p1r != no_node && p2r != no_node) {
                ret = add_factor(var_vec{v1, v2}, colour[p1r], var);

                bool cycle = false;
                if (p1r != p2r) {
                    neighbours_in[ret] = {p1r, p2r};
                    neighbours_out[p1r].push_back(ret);
                    neighbours_out[p2r].push_back(ret);
                    if (colour_equal(colour[p1r], colour[p2r])) {
                        /* cycle! */
                        MSG_DEBUG("cycle!");
                        cycle = true;
                    } else {
                        MSG_DEBUG("no cycle.");
                        assign_colour_impl(colour[p1r], colour[p2r]);
                    }
                } else {
                    neighbours_in[ret] = {p1r};
                    neighbours_out[p1r].push_back(ret);
                }
                rank[ret] = std::max(rank[p1r], rank[p2r]) + 1;
                /* NOTE(review): unconditional dump/flush looks like a debugging leftover. */
                dump();
                MSG_QUEUE_FLUSH();

                if (cycle && aggregate_cycles) {
                    MSG_DEBUG("search a path between " << p1r << " and " << p2r);
                    std::vector<bool> visited(rank.size(), false);
                    visited[ret] = true;
                    /* The "new algo" path is only logged for comparison; the path actually
                     * aggregated is the one from find_path_between_parents(). */
                    auto path = find_shortest_path(node_vec{p1r, p2r},
                            [&, this](node_index_type n) { return nei_in(n) + nei_out(n); },
                            [&] (node_index_type n) { return visited[n]; },
                            [&] (node_index_type n) { visited[n] = true; });
                    path.push_back(ret);
                    MSG_DEBUG("New algo path: " << path);
                    path = find_path_between_parents(p1r, p2r, ret);
                    MSG_DEBUG("Old algo path: " << path);
                    MSG_DEBUG("Found path: "; for (size_t n: path) { std::cout << ' ' << n; } std::cout);
                    aggregate_path(path);
                }
            } else {
                ret = add_factor(var_vec{v1, v2}, create_colour(), var);
            }

            compute_ranks();
            dump_active();
            return ret;
        }
912

913
914
915
    void
        add_ancestor(variable_index_type id)
        {
            /* Register `id` as an ancestor.
             *
             * It is assigned the letter 'a' + (number of ancestors already registered).
             * Its single-variable domain domains[{id}] then gets one entry per allele
             * `al` in [0, n_alleles): the element {letter, letter, al, al} paired with
             * the value 1. */
            const char letter = 'a' + ancestor_letters.size();
            ancestor_letters[id] = letter;
            auto& target = domains[{id}];
            const char allele_count = (char) n_alleles;
            for (char allele = 0; allele < allele_count; ++allele) {
                target.m_combination.emplace_back(genotype_comb_type::element_type{{{id, {letter, letter, allele, allele}}}, 1.});
            }
        }
924

925
926
927
    node_index_type
        add_cross(variable_index_type p1, variable_index_type p2, variable_index_type id)
        {
            /* A cross of p1 and p2 producing `id` is a two-parent factor;
             * returns the index of the created factor node. */
            return add_factor(p1, p2, id);
        }
937

938
939
940
    node_index_type
        add_dh(variable_index_type p1, variable_index_type id)
        {
            /* A doubled-haploid of p1 producing `id` is a single-parent factor whose node
             * is additionally flagged in is_dh; returns the index of that factor node. */
            const node_index_type factor_node = add_factor(p1, id);
            is_dh[factor_node] = true;
            return factor_node;
        }
950

951
952
953
    node_index_type
        add_selfing(variable_index_type p1, variable_index_type id)
        {
            /* A selfing of p1 producing `id` is a single-parent factor (not flagged in
             * is_dh, unlike add_dh); returns the index of the created factor node. */
            return add_factor(p1, id);
        }

    bool
964
        find_ascending_path(node_index_type p1, node_index_type p2, node_vec& path, std::vector<bool>& visited)
965
        {
966
            /*MSG_DEBUG("path from " << p1 << " to " << p2);*/
967
968
969
            if (p1 == p2) {
                return true;
            }
970
            for (node_index_type i: nei_in(p1)) {
971
                i = resolve(i);
972
973
974
975
976
                if (visited[i]) {
                    return false;
                }
                visited[i] = true;
                if (find_ascending_path(i, p2, path, visited)) {
977
978
979
980
981
982
983
                    path.push_back(i);
                    return true;
                }
            }
            return false;
        }

984
985
986
987
988
989
    std::vector<bool>
        create_visited()
        {
            /* Fresh "visited" flags for graph searches: one cleared entry per node slot
             * (rank.size() slots, index 0 being the trash bin). */
            std::vector<bool> flags;
            flags.assign(rank.size(), false);
            return flags;
        }

990
991
    node_vec
        find_aggregate_chain(node_index_type p1, node_index_type p2)
992
        {
993
            node_vec path;
994
995
996
997
            auto vis1 = create_visited(), vis2 = create_visited();

            if (rank[p1] > rank[p2] && find_ascending_path(p1, p2, path, vis1)) {
                /*MSG_DEBUG("found ascending path from " << p1 << " to " << p2);*/
998
                path.push_back(p1);
999
1000
            } else if (rank[p2] > rank[p1] && find_ascending_path(p2, p1, path, vis2)) {
                /*MSG_DEBUG("found ascending path from " << p2 << " to " << p1);*/