Commit 014604ce authored by Damien Leroux's avatar Damien Leroux
Browse files

Implemented min-cut to handle table products in huge joint domains.

parent fcc66c6b
......@@ -2,10 +2,22 @@
<project version="4">
<component name="CMakeWorkspace" PROJECT_DIR="$PROJECT_DIR$" />
<component name="CidrRootsConfiguration">
<sourceRoots>
<file path="$PROJECT_DIR$/3rd-party" />
<file path="$PROJECT_DIR$/include" />
<file path="$PROJECT_DIR$/src" />
<file path="$PROJECT_DIR$/tests" />
</sourceRoots>
<excludeRoots>
<file path="$PROJECT_DIR$/CMakeFiles" />
<file path="$PROJECT_DIR$/cmake-build-debug/CMakeFiles" />
<file path="$PROJECT_DIR$/cmake-build-release/CMakeFiles" />
<file path="$PROJECT_DIR$/include/bayes" />
<file path="$PROJECT_DIR$/include/cache" />
<file path="$PROJECT_DIR$/include/computations" />
<file path="$PROJECT_DIR$/include/data" />
<file path="$PROJECT_DIR$/include/input" />
<file path="$PROJECT_DIR$/include/model" />
<file path="$PROJECT_DIR$/tests/TestCodeCoverage" />
</excludeRoots>
</component>
</project>
\ No newline at end of file
......@@ -2,7 +2,7 @@
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/spell-qtl.iml" filepath="$PROJECT_DIR$/.idea/spell-qtl.iml" />
<module fileurl="file://$PROJECT_DIR$/.idea/spel.iml" filepath="$PROJECT_DIR$/.idea/spel.iml" />
</modules>
</component>
</project>
\ No newline at end of file
This diff is collapsed.
cmake_minimum_required(VERSION 3.5)
project(spell_qtl)
#set(CMAKE_VERBOSE_MAKEFILE ON)
set(CMAKE_CONFIGURATION_TYPES Debug Release CACHE TYPE INTERNAL FORCE)
#set(CMAKE_CXX_STANDARD 11)
set(CMAKE_VERBOSE_MAKEFILE ON)
MESSAGE(STATUS "CMAKE VERSION ${CMAKE_VERSION}")
MESSAGE(STATUS "${CMAKE_CURRENT_SOURCE_DIR}")
execute_process(
COMMAND bash -c "git rev-list --count `git rev-list --tags --max-count=1`..HEAD 2>/dev/null || echo 0"
COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/get_patch_number.sh
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE VERSION_PATCH
OUTPUT_STRIP_TRAILING_WHITESPACE
)
......@@ -29,6 +33,8 @@ LIST(GET VERSION 1 VERSION_MINOR)
set(CMAKE_CXX_STANDARD 11)
SET(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Wextra -Wall -Wno-unused-parameter -pthread -fPIC")
SET(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS} -O3 -DNDEBUG -DEIGEN_NO_DEBUG")
SET(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS} -O0 -ggdb")
add_definitions(-DEIGEN_NO_DEPRECATED_WARNING -DVERSION_MAJOR=\"${VERSION_MAJOR}\" -DVERSION_MINOR=\"${VERSION_MINOR}\" -DVERSION_PATCH=\"${VERSION_PATCH}\")
......@@ -36,8 +42,8 @@ MESSAGE(STATUS "CXX FLAGS ${CMAKE_CXX_FLAGS}")
find_library(EXPAT_LIBRARY_NAMES expat)
find_path(EXPAT_INCLUDE_DIR expat.h HINTS /usr/include /usr/local/include /usr/include/expat/ /usr/local/include/expat/)
find_path(X2C_INCLUDE_DIR x2c.h HINTS /usr/include /usr/local/include /usr/include/x2c/ /usr/local/include/x2c/)
find_path(EIGEN_INCLUDE_DIR Eigen/Eigen HINTS /usr/include /usr/local/include /usr/include/eigen3/ /usr/local/include/eigen3/)
find_path(X2C_INCLUDE_DIR x2c/x2c.h HINTS /usr/include/ /usr/local/include/ /home/daleroux/include/)
find_path(EIGEN_INCLUDE_DIR Eigen/Eigen HINTS /usr/include /usr/local/include /usr/include/eigen3/ /usr/local/include/eigen3/ /home/daleroux/include/eigen3/)
include_directories(AFTER 3rd-party/ThreadPool)
include_directories(AFTER include/ include/input/ include/bayes/ ${EIGEN_INCLUDE_DIR})
......@@ -61,12 +67,12 @@ set(SPELL_QTL_SRC
src/computations/basic_data.cc src/computations/probabilities.cc src/computations/model.cc src/computations/frontends.cc
)
add_executable(spell_pedigree ${SPELL_PEDIGREE_SRC})
add_executable(spell_marker ${SPELL_MARKER_SRC})
add_executable(spell_qtl ${SPELL_QTL_SRC})
add_executable(spell-pedigree ${SPELL_PEDIGREE_SRC})
add_executable(spell-marker ${SPELL_MARKER_SRC})
add_executable(spell-qtl ${SPELL_QTL_SRC})
target_link_libraries(spell_marker expat dl)
target_link_libraries(spell_qtl expat dl)
target_link_libraries(spell-marker expat dl)
target_link_libraries(spell-qtl expat dl)
set(CMAKE_EXE_LINKER_FLAGS "-rdynamic")
SET(EXECUTABLE_OUTPUT_PATH ${PROJECT_BINARY_DIR}/bin)
......@@ -120,7 +120,7 @@ struct state_index_type {
for (int i = n_blocks - 1; i >= 0; ++i) {
uint64_t mask = 1ULL << 63;
while (mask && !(s.data[i] & mask)) { mask >>= 1; }
while (mask) { os << (!!(s.data[i] & mask)); mask >>= 1; }
while (mask) { os << ((s.data[i] & mask) == 0); mask >>= 1; }
}
return os;
}
......@@ -393,6 +393,8 @@ struct joint_variable_product_type {
bool invalid_product = false;
const std::map<std::vector<int>, genotype_comb_type>* all_domains = nullptr;
inline
void
set_output(const std::vector<int>& variables)
......@@ -549,6 +551,7 @@ struct joint_variable_product_type {
void
compile_domains(const std::map<std::vector<int>, genotype_comb_type>& all_variable_domains)
{
all_domains = &all_variable_domains;
total_bits = 0;
for (int v: all_variable_names) {
const auto& domain = all_variable_domains.find(std::vector<int>{v})->second;
......@@ -944,6 +947,51 @@ struct joint_variable_product_type {
genotype_comb_type
compute_generic()
{
if (total_bits >= 32) {
std::vector<size_t> output_variable_indices;
for (size_t i = 0; i < output_variables_in_use.size(); ++i) {
output_variable_indices.push_back(i);
}
/* Build adjacency matrix */
std::vector<std::vector<int>> weights(tables.size());
for (size_t i = 0; i < tables.size(); ++i) { weights[i].resize(tables.size(), 0); }
for (size_t i = 0; i < tables.size(); ++i) {
const auto& Vi = tables[i].variable_indices;
for (size_t j = i + 1; j < tables.size(); ++j) {
const auto& Vj = tables[j].variable_indices;
int w = 0;
for (size_t varidx: (Vi % Vj) - output_variable_indices) {
w += domains[varidx].bit_width;
}
weights[i][j] = w;
weights[j][i] = w;
}
}
/* Find minimum cut */
std::pair<int, std::vector<int>> cut = get_min_cut(weights);
std::vector<bool> first_part(tables.size(), false);
for (int c: cut.second) {
first_part[c] = true;
}
joint_variable_product_type p1, p2;
auto ti = tables.begin();
for (auto b: first_part) {
(b ? p1 : p2).add_table(*ti->data);
}
auto common = p1.all_variable_names % p2.all_variable_names;
p1.set_output(output_variable_names + common);
p2.set_output(output_variable_names + common);
p1.compile(*all_domains);
p2.compile(*all_domains);
genotype_comb_type t1 = p1.compute_generic<key_computing_variant>();
genotype_comb_type t2 = p2.compute_generic<key_computing_variant>();
joint_variable_product_type final;
final.add_table(t1);
final.add_table(t2);
final.compile(*all_domains);
return final.compute_generic<key_computing_variant>();
}
key_computing_variant key_update;
if (invalid_product) { return {}; }
......@@ -1012,6 +1060,45 @@ struct joint_variable_product_type {
inline genotype_comb_type compute() { return compute_generic<update_key>(); }
inline genotype_comb_type build_factor() { return compute_generic<update_key_build_factor>(); }
/* Algorithm adapted from wikipedia page https://en.wikipedia.org/wiki/Stoer%E2%80%93Wagner_algorithm */
std::pair<int, std::vector<int>>
get_min_cut(std::vector<std::vector<int>> &weights) {
int N = weights.size();
std::vector<int> used(N), cut, best_cut;
int best_weight = std::numeric_limits<int>::max();
int best_size = std::numeric_limits<int>::max();
int ws = weights.size();
for (int phase = N-1; phase >= 0; phase--) {
std::vector<int> w = weights[0];
std::vector<int> added = used;
int prev, last = 0;
for (int i = 0; i < phase; i++) {
prev = last;
last = -1;
for (int j = 1; j < N; j++)
if (!added[j] && (last == -1 || w[j] > w[last])) last = j;
if (i == phase - 1) {
for (int j = 0; j < N; j++) weights[prev][j] += weights[last][j];
for (int j = 0; j < N; j++) weights[j][prev] = weights[prev][j];
used[last] = true;
cut.push_back(last);
int cs = ws - 2 * cut.size();
int sz = cs < 0 ? -cs : cs;
if ((w[last] == best_weight && sz < best_size) || w[last] < best_weight) {
best_cut = cut;
best_weight = w[last];
best_size = sz;
}
} else {
for (int j = 0; j < N; j++)
w[j] += weights[last][j];
added[last] = true;
}
}
}
return std::make_pair(best_weight, best_cut);
}
};
......@@ -1121,4 +1208,33 @@ extract_domain(const genotype_comb_type& factor, const std::vector<int>& variabl
}
}
#include <valarray>
struct hierarchical_product {
static constexpr size_t no_table = (size_t) -1;
struct hp_node {
size_t table_index;
std::vector<hp_node> sub_products;
};
hp_node root;
std::vector<const genotype_comb_type*> tables;
void
add_table(const genotype_comb_type& t)
{
tables.push_back(&t);
}
void
finalize()
{
}
};
#endif
......@@ -816,9 +816,9 @@ struct factor_graph_type : public recursive_graph_type<factor_graph_type> {
for (auto v: vv) {
auto nn = filter(nodes(), [&] (node_index_type n) { auto vv = variables_of(n); return std::find(vv.begin(), vv.end(), v) != vv.end(); });
/*MSG_DEBUG("variable " << v << " nodes " << nn);*/
for (auto n: nn) {
/*MSG_DEBUG(" node colour " << colour_of(n));*/
}
// for (auto n: nn) {
// MSG_DEBUG(" node colour " << colour_of(n));
// }
if (nn.size()) {
auto col = colour_of(nn.front());
nv[v].push_back(nn);
......
......@@ -251,7 +251,7 @@ struct multiple_product_type {
for (const auto& kv: bins) {
if ((output % varset_bins[kv.first]).size()) {
if (kv.second.size()) {
MSG_QUEUE_FLUSH();
// MSG_QUEUE_FLUSH();
/*if (kv.second.size() > 1) {*/
tmp.emplace_back(compute_product(kv.second.begin(), kv.second.end(), output, domains));
/*} else {*/
......@@ -669,7 +669,7 @@ struct graph_base_type {
/*MSG_DEBUG("Creating aggregate " << ret);*/
foreach_in(nodes, [&,this] (node_index_type n) { move_node_to_aggregate(n, ret); });
/*MSG_DEBUG("updating ranks...");*/
MSG_QUEUE_FLUSH();
// MSG_QUEUE_FLUSH();
update_ranks(ret);
std::sort(m_nodes[ret].inner_nodes.begin(), m_nodes[ret].inner_nodes.end());
remove_edge(ret, ret);
......@@ -684,7 +684,7 @@ struct graph_base_type {
rec_update_rank(node_index_type n, const std::vector<bool>& modified, std::vector<bool>& updated)
{
/*MSG_DEBUG("rec_update_rank on " << n);*/
MSG_QUEUE_FLUSH();
// MSG_QUEUE_FLUSH();
size_t r = 0;
for (node_index_type i: nei_in(n)) {
if (modified[i] && !updated[i]) {
......
......@@ -50,7 +50,7 @@ dispatch_geno_probs(
MSG_DEBUG("labels " << labels);
MSG_DEBUG("geno_probs " << geno_probs);
MSG_DEBUG("state_prob " << state_prob);
MSG_QUEUE_FLUSH();
// MSG_QUEUE_FLUSH();
std::vector<double> ret(lc.size());
std::map<label_type, double> norms;
for (size_t i = 0; i < ret.size(); ++i) {
......@@ -315,7 +315,7 @@ job_registry = {
}
MSG_DEBUG(ss.str());
}
MSG_QUEUE_FLUSH();
// MSG_QUEUE_FLUSH();
std::map<size_t, std::vector<double>> state_prob, output_prob;
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment