From d2c794aee92386af7195b844b401f2d59b686ce8 Mon Sep 17 00:00:00 2001
From: Floreal Cabanettes <floreal.cabanettes@inra.fr>
Date: Thu, 29 Mar 2018 18:17:21 +0200
Subject: [PATCH] Add build of jupyter notebook summary

---
 Summarized_results.ipynb | 490 +++++++++++++++++++++++++++++++++++++++
 build_results.py         |  88 +++----
 lib/genotype_results.py  |  45 ++--
 lib/vcf.py               |  41 ++++
 4 files changed, 593 insertions(+), 71 deletions(-)
 create mode 100644 Summarized_results.ipynb
 create mode 100644 lib/vcf.py

diff --git a/Summarized_results.ipynb b/Summarized_results.ipynb
new file mode 100644
index 0000000..e5ae770
--- /dev/null
+++ b/Summarized_results.ipynb
@@ -0,0 +1,490 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# CNV detection benchmark"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%html\n",
+    "\n",
+    "<script>\n",
+    "  function code_toggle() {\n",
+    "    if (code_shown){\n",
+    "      $('div.input').hide('500');\n",
+    "      $('#toggleButton').val('Show Code')\n",
+    "    } else {\n",
+    "      $(\"div.input:not(:first)\").show('500');\n",
+    "      $('#toggleButton').val('Hide Code')\n",
+    "    }\n",
+    "    code_shown = !code_shown\n",
+    "  }\n",
+    "\n",
+    "  $( document ).ready(function(){\n",
+    "    code_shown=false;\n",
+    "    $('div.input').hide()\n",
+    "  });\n",
+    "</script>\n",
+    "<form action=\"javascript:code_toggle()\" style=\"position:fixed;left:10px;top:10px\"><input type=\"submit\" id=\"toggleButton\" value=\"Show Code\"></form>"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import pandas as pd\n",
+    "import numpy as np\n",
+    "from collections import OrderedDict\n",
+    "\n",
+    "import math\n",
+    "\n",
+    "import re\n",
+    "import os\n",
+    "import json\n",
+    "import pylab as P\n",
+    "import matplotlib as mpl\n",
+    "import matplotlib.pyplot as plt\n",
+    "import seaborn as sns\n",
+    "sns.set(style=\"whitegrid\", color_codes=True)\n",
+    "%matplotlib inline\n",
+    "\n",
+    "from pysam import VariantFile\n",
+    "from collections import defaultdict\n",
+    "from collections import Counter\n",
+    "\n",
+    "from IPython.display import display, HTML\n",
+    "\n",
+    "# Read tsv file\n",
+    "results_df = pd.read_table(\"results_sv_per_tools.tsv\", header=0, index_col=0)\n",
+    "\n",
+    "# Retrieve list of tools\n",
+    "tools = set()\n",
+    "for col in results_df.columns:\n",
+    "    if col.endswith(\"__Start\") and col != \"Real_data__Start\" and col != \"Filtered_results__Start\":\n",
+    "        tools.add(col.split(\"__\")[0])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Recall & Precision"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def compute_tp_fp_fn(svs, tool):\n",
+    "    tp = 0\n",
+    "    fp = 0\n",
+    "    fn = 0\n",
+    "    start_values = svs[\"{0}__Start\".format(tool)]\n",
+    "    for i in range(0, len(start_values)):\n",
+    "        if math.isnan(start_values[i]) and not math.isnan(svs[\"Real_data__Start\"][i]):\n",
+    "            fn += 1\n",
+    "        elif not math.isnan(start_values[i]) and not math.isnan(svs[\"Real_data__Start\"][i]):\n",
+    "            tp += 1\n",
+    "        elif not math.isnan(start_values[i]) and math.isnan(svs[\"Real_data__Start\"][i]):\n",
+    "            fp += 1\n",
+    "    return tp, fp, fn\n",
+    "\n",
+    "recall = OrderedDict()\n",
+    "precision = OrderedDict()\n",
+    "for tool in tools:\n",
+    "    if tool + \"__Start\" in results_df:\n",
+    "        tp, fp, fn = compute_tp_fp_fn(results_df, tool)\n",
+    "        recall[tool] = [tp / (tp+fn) * 100]\n",
+    "        precision[tool] = [tp / (tp+fp) * 100]\n",
+    "        \n",
+    "plt.figure(1, figsize=(20,10))\n",
+    "\n",
+    "# Plot recall\n",
+    "plt.subplot(121)\n",
+    "recall_df = pd.DataFrame.from_dict(recall, orient=\"columns\")\n",
+    "plot = sns.barplot(data=recall_df)\n",
+    "plot.set_title(\"Recall\", fontsize=30)\n",
+    "plot.set_ylabel(\"Recall (%)\", fontsize=20)\n",
+    "plot.set_xlabel(\"Tool\", fontsize=20)\n",
+    "plot.tick_params(labelsize=14)\n",
+    "plot.set_ylim([0,100])\n",
+    "\n",
+    "# Plot precision\n",
+    "plt.subplot(122)\n",
+    "precision_df = pd.DataFrame.from_dict(precision, orient=\"columns\")\n",
+    "plot2 = sns.barplot(data=precision_df)\n",
+    "plot2.set_title(\"Precision\", fontsize=30)\n",
+    "plot2.set_ylabel(\"Precision (%)\", fontsize=20)\n",
+    "plot2.set_xlabel(\"Tool\", fontsize=20)\n",
+    "plot2.tick_params(labelsize=14)\n",
+    "\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Influence of variant size on recall"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "groups = []\n",
+    "with open(\"rules.sim\", \"r\") as rules:\n",
+    "    for line in rules:\n",
+    "        line = line.rstrip()\n",
+    "        if line != \"\":\n",
+    "            parts = re.split(r\"\\s+\", line)\n",
+    "            groups.append((int(parts[1]), int(parts[2])))\n",
+    "\n",
+    "groups.sort(key=lambda x: x[0])"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### By tool"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "nrows = math.ceil(len(tools)/2)\n",
+    "ncols = min(2, len(tools))\n",
+    "\n",
+    "palettes = [\"Blues_d\", \"Greens_d\", \"Reds_d\", \"Purples_d\", \"YlOrBr_d\", \"PuBu_d\"]\n",
+    "\n",
+    "plt.figure(1, figsize=(20,nrows * 8))\n",
+    "\n",
+    "nplot=0\n",
+    "\n",
+    "for tool in tools:         \n",
+    "    npalette = nplot\n",
+    "    while npalette >= len(palettes):\n",
+    "        npalette -= len(palettes)\n",
+    "    nplot += 1\n",
+    "    results_by_group = OrderedDict()\n",
+    "    for group in groups:\n",
+    "        tmp_res = results_df[(results_df.Real_data__Length >= group[0]) & (results_df.Real_data__Length < group[1])]\n",
+    "        tp, fp, fn = compute_tp_fp_fn(tmp_res, tool)\n",
+    "        results_by_group[\"-\".join(map(str,group))] = [tp / (tp+fn) * 100]\n",
+    "\n",
+    "    recall_df = pd.DataFrame.from_dict(results_by_group, orient=\"columns\")\n",
+    "    \n",
+    "    plt.subplot(nrows, ncols, nplot)\n",
+    "    plot = sns.barplot(data=recall_df, palette=palettes[npalette])\n",
+    "    plot.set_title(tool, fontsize=25)\n",
+    "    plot.set_ylabel(\"Recall (%)\", fontsize=20)\n",
+    "    plot.set_xlabel(\"Variant size (bp)\", fontsize=20)\n",
+    "    plot.tick_params(labelsize=14)\n",
+    "    plot.set_ylim([0,100])\n",
+    "\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Global"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "results_by_group = OrderedDict()\n",
+    "            \n",
+    "for group in groups:\n",
+    "    tmp_res = results_df[(results_df.Real_data__Length >= group[0]) & (results_df.Real_data__Length < group[1])]\n",
+    "    tp, fp, fn = compute_tp_fp_fn(tmp_res, \"Filtered_results\")\n",
+    "    results_by_group[\"-\".join(map(str,group))] = [tp / (tp+fn) * 100]\n",
+    "    \n",
+    "plt.figure(1, figsize=(15,8))\n",
+    "    \n",
+    "recall_df = pd.DataFrame.from_dict(results_by_group, orient=\"columns\")\n",
+    "plot = sns.barplot(data=recall_df, color='black')\n",
+    "plot.set_title(\"Recall\", fontsize=30)\n",
+    "plot.set_ylabel(\"Recall (%)\", fontsize=20)\n",
+    "plot.set_xlabel(\"Variant size (bp)\", fontsize=20)\n",
+    "plot.tick_params(labelsize=14)\n",
+    "plot.set_ylim([0,100])\n",
+    "\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Shared variant by tool"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def compute_found_sv(svs: pd.DataFrame, tool: str):\n",
+    "    variants = []\n",
+    "    start_values = svs[\"{0}__Start\".format(tool)]\n",
+    "    for idx in svs.index:\n",
+    "        if not math.isnan(start_values[idx]) and not math.isnan(svs[\"Real_data__Start\"][idx]):\n",
+    "            variants.append(idx)\n",
+    "    return variants\n",
+    "\n",
+    "variants_by_tool = []\n",
+    "\n",
+    "for tool in tools:\n",
+    "    variants_by_tool.append({\"name\": tool, \"data\": compute_found_sv(results_df, tool)})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "html=HTML(\"<script type='text/javascript' src='http://jvenn.toulouse.inra.fr/app/js/canvas2svg.js'></script> \\\n",
+    "<script type='text/javascript' src='http://jvenn.toulouse.inra.fr/app/js/jvenn.min.js'></script> \\\n",
+    "<div id='draw'></div> \\\n",
+    "<script type='text/javascript'> \\\n",
+    "  $(document).ready(function(){ \\\n",
+    "     $('#draw').jvenn({ \\\n",
+    "       series: \" + json.dumps(variants_by_tool) + \" \\\n",
+    "     }); \\\n",
+    "  }); \\\n",
+    "</script>\")\n",
+    "\n",
+    "display(html)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Breakpoints precision"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.figure(1, figsize=(20,8))\n",
+    "\n",
+    "# Start precision\n",
+    "df_diffs=pd.read_table(\"results_sv_diffs_per_tools.tsv\", header=0, index_col=0)\n",
+    "all_diffs_soft_start = pd.DataFrame()\n",
+    "for tool in tools:\n",
+    "    all_diffs_soft_start[tool] = df_diffs[tool + \"__Start\"].abs()\n",
+    "\n",
+    "plt.subplot(121)\n",
+    "plot = sns.stripplot(data=all_diffs_soft_start, jitter=True)\n",
+    "plot.set_ylim([0,150])\n",
+    "plot.tick_params(labelsize=15)\n",
+    "plot.set_title(\"Start position\", fontsize=28, y=1.04)\n",
+    "plot.set_ylabel(\"Diff from real data (abs)\", fontsize=20)\n",
+    "\n",
+    "# End precision\n",
+    "df_diffs=pd.read_table(\"results_sv_diffs_per_tools.tsv\", header=0, index_col=0)\n",
+    "all_diffs_soft_end = pd.DataFrame()\n",
+    "for tool in tools:\n",
+    "    all_diffs_soft_end[tool] = df_diffs[tool + \"__End\"].abs()\n",
+    "\n",
+    "plt.subplot(122)\n",
+    "plot = sns.stripplot(data=all_diffs_soft_end, jitter=True)\n",
+    "plot.set_ylim([0,150])\n",
+    "plot.tick_params(labelsize=15)\n",
+    "plot.set_title(\"End position\", fontsize=28, y=1.04)\n",
+    "plot.set_ylabel(\"Diff from real data (abs)\", fontsize=20)\n",
+    "\n",
+    "plt.show()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## 2. CNV Genotyping\n",
+    "\n",
+    "We compare quality of genotyping"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Simulated deletions\n",
+    "total = len(results_df[results_df.Real_data__Start.notnull()])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Predicted deletions genotyped\n",
+    "\n",
+    "df_svtyper=pd.read_csv(\"results_genotype.tsv\",sep='\\t',index_col=False,names=['del','software','delsize','recall','left','right'])\n",
+    "df_svtyper['precision'] = [ \"precise\" if (x+y)<20 else \"unprecise\" for (x,y) in zip(df_svtyper.left,df_svtyper.right)]\n",
+    "df_svtyper['deltype'] = ['small' if x<200 else 'medium' for x in df_svtyper.delsize]\n",
+    "counts_svtyper = np.unique(df_svtyper.software,return_counts=True)[1]"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Genotype recall for each software prediction"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "gt_tools = list(tools) + [\"pass\"]\n",
+    "plt.figure(1, figsize=(8,5))\n",
+    "sns.stripplot(x=\"software\", y=\"recall\", jitter=True, palette=\"Set2\", dodge=True, linewidth=1, edgecolor='gray', order=gt_tools,data=df_svtyper)\n",
+    "axes = sns.boxplot(x=\"software\", y=\"recall\", palette=\"Set2\", order=gt_tools,data=df_svtyper)\n",
+    "axes.title.set_position([.5, 1.2])\n",
+    "axes.set_title(str(total)+\" simulated variants\",size=25)\n",
+    "axes.axes.xaxis.label.set_size(20)\n",
+    "axes.axes.yaxis.label.set_size(20)\n",
+    "axes.tick_params(labelsize=15)\n",
+    "ymax = axes.get_ylim()[1]\n",
+    "for i in range(0, len(counts_svtyper)):\n",
+    "    t1=axes.text(-0.1 + (i * 1), ymax+ymax/100, counts_svtyper[i], fontsize=15)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Inluence of precision : precise means less than 20bp"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.figure(1, figsize=(8,5))\n",
+    "sns.stripplot(x=\"software\", y=\"recall\", jitter=True, hue=\"precision\", palette=\"Set2\", dodge=True, linewidth=1, edgecolor='gray', order=gt_tools,data=df_svtyper)\n",
+    "axes = sns.boxplot(x=\"software\", y=\"recall\",hue=\"precision\",palette=\"Set2\", order=gt_tools,data=df_svtyper)\n",
+    "axes.title.set_position([.5, 1.2])\n",
+    "axes.set_ylim(0.15, 1.05)\n",
+    "axes.set_title(str(total)+\" simulated variants\",size=25)\n",
+    "axes.axes.xaxis.label.set_size(20)\n",
+    "axes.axes.yaxis.label.set_size(20)\n",
+    "axes.tick_params(labelsize=15)\n",
+    "ymax = axes.get_ylim()[1]\n",
+    "for i in range(0, len(counts_svtyper)):\n",
+    "    t1=axes.text(-0.1 + (i * 1), ymax+ymax/100, counts_svtyper[i], fontsize=15)\n",
+    "plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "#### Inluence of deletion size : small means less than 200bp"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.figure(1, figsize=(8,5))\n",
+    "sns.stripplot(x=\"software\", y=\"recall\", jitter=True, hue=\"deltype\", palette=\"Set2\", dodge=True, linewidth=1, edgecolor='gray', order=gt_tools,data=df_svtyper)\n",
+    "axes = sns.boxplot(x=\"software\", y=\"recall\",hue=\"deltype\",palette=\"Set2\", order=gt_tools,data=df_svtyper)\n",
+    "axes.title.set_position([.5, 1.2])\n",
+    "axes.set_ylim(0.15, 1.05)\n",
+    "axes.set_title(str(total)+\" simulated variants\",size=25)\n",
+    "axes.axes.xaxis.label.set_size(20)\n",
+    "axes.axes.yaxis.label.set_size(20)\n",
+    "axes.tick_params(labelsize=15)\n",
+    "ymax = axes.get_ylim()[1]\n",
+    "for i in range(0, len(counts_svtyper)):\n",
+    "    t1=axes.text(-0.1 + (i * 1), ymax+ymax/100, counts_svtyper[i], fontsize=15)\n",
+    "plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "##### Size VS Precision"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "plt.figure(1, figsize=(8,5))\n",
+    "sns.stripplot(x=\"precision\", y=\"delsize\", jitter=True, palette=\"Set2\", dodge=True, linewidth=1, edgecolor='gray', order=['precise', 'unprecise'],data=df_svtyper)\n",
+    "axes = sns.boxplot(x=\"precision\", y=\"delsize\",palette=\"Set2\", order=['precise', 'unprecise'],data=df_svtyper)\n",
+    "axes.axes.xaxis.label.set_size(20)\n",
+    "axes.axes.yaxis.label.set_size(20)\n",
+    "axes.tick_params(labelsize=15)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.5.2"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/build_results.py b/build_results.py
index dcaae65..2b9981d 100755
--- a/build_results.py
+++ b/build_results.py
@@ -18,9 +18,13 @@ import argparse
 import string
 import os
 import re
+import shutil
 
 from pysam import VariantFile
 
+import lib.genotype_results as gtres
+import lib.vcf as vcf
+
 import sys
 prg_path = os.path.dirname(os.path.realpath(__file__))
 sys.path.insert(0, os.path.join(prg_path, "svlib"))
@@ -38,8 +42,6 @@ COLOR_IS_KEPT = "#81F781"
 COLOR_FALSE_POSITIVE = "#FE642E"
 COLOR_WRONG_GT = "#B40404"
 
-ALLOW_VARIANTS = ['del', 'inv']
-
 
 def get_args():
     """
@@ -51,15 +53,16 @@ Build Results \n \
 description: Build results of the simulated data detection")
     parser.add_argument('-v', '--vcfs', type=str, required=True, help='File listing vcf files for each detection tool')
     parser.add_argument('-t', '--true-vcf', type=str, required=True, help='VCF file containing the simulated deletions')
-    parser.add_argument('-f', '--filtered-vcf', type=str, required=False,
+    parser.add_argument('-f', '--filtered-vcf', type=str, required=True,
                         help='File listing VCF files containing the filtered results')
-    parser.add_argument('-y', '--type', required=True, type=str, choices=ALLOW_VARIANTS, help="Type of variant")
+    parser.add_argument('-y', '--type', required=True, type=str, choices=vcf.ALLOW_VARIANTS, help="Type of variant")
     parser.add_argument('--overlap_cutoff',  type=float, default=0.5, help='cutoff for reciprocal overlap')
     parser.add_argument('--left_precision',  type=int, default=-1, help='left breakpoint precision')
     parser.add_argument('--right_precision',  type=int, default=-1, help='right breakpoint precision')
     parser.add_argument('-o', '--output', type=str, default="results", help='output folder')
     parser.add_argument('--no-xls', action='store_const', const=True, default=False, help='Do not build Excel file')
     parser.add_argument('--haploid', action='store_const', const=True, default=False, help='The organism is haploid')
+    parser.add_argument('-r', '--rules', type=str, required=False, help="Simulation rule file")
     # parse the arguments
     args = parser.parse_args()
 
@@ -69,6 +72,9 @@ description: Build results of the simulated data detection")
     if args.right_precision == -1:
         args.right_precision = sys.maxsize
 
+    if args.rules is None:
+        args.rules = os.path.join(prg_path, "defaults.rules")
+
     # send back the user input
     return args
 
@@ -79,35 +85,6 @@ def eprint(*args, **kwargs):
     """
     print(*args, file=sys.stderr, **kwargs)
 
-
-def passed_variant(record):
-    """
-    Did this variant pass?
-    :param record: vcf record object
-    :return: True if pass, False else
-    """
-    return record.filter is None or len(record.filter) == 0 or "PASS" in record.filter
-
-
-def read_vcf_file(infile, type_v):
-    """
-    Read a vcf file
-    :param infile: vcf file path
-    :param type_v: type of variant ("del" or "inv")
-    :return: set or records, list of records ids
-    """
-    if type_v.lower() not in ALLOW_VARIANTS:
-        raise ValueError("Invalid variant type: %s" % type_v)
-    SVSet=[]
-    ids = []
-    for record in VCFReader(infile):
-        if record.sv_type.lower() == type_v:
-            if not passed_variant(record):
-                continue
-            SVSet.append(record)
-            ids.append(record.id)
-    return SVSet, ids
-
     
 def svsort(sv, records):
     """
@@ -818,7 +795,7 @@ def build_xlsx_cols():
             XLSX_COLS.append(alp + j)
 
 
-def init(output, vcf_files, true_vcf, filtered_vcfs=None, type_v="del", overlap_cutoff=0.5,
+def init(output, vcf_files, true_vcf, rules, filtered_vcfs=None, type_v="del", overlap_cutoff=0.5,
          left_precision=sys.maxsize, right_precision=sys.maxsize, no_xls=False, haploid=False):
 
     if not os.path.exists(output):
@@ -831,15 +808,18 @@ def init(output, vcf_files, true_vcf, filtered_vcfs=None, type_v="del", overlap_
 
     nb_inds = 0
 
+    filtered = None
+    filtered_all = None
     filtered_records = None
     do_genotype = False
 
     if filtered_vcfs:
-        filtered_records = []
+        filtered = {}
+        filtered_all = {}
         for filtered_vcf in filtered_vcfs:
             eprint(" Reading file %s" % filtered_vcf)
-            filtered_records += read_vcf_file(filtered_vcf, type_v)[1]
-
+            vcf.readvariants(filtered_vcf, type_v, filtered, True, filtered_all)
+        filtered_records = filtered.keys()
         genotypes, gt_quality, nb_inds = get_genotypes(filtered_vcfs, true_vcf)
         do_genotype = True
 
@@ -848,20 +828,20 @@ def init(output, vcf_files, true_vcf, filtered_vcfs=None, type_v="del", overlap_
     for infile in vcf_files:
         eprint(" Reading file %s" % infile)
         try:
-            sv_set += read_vcf_file(infile, type_v)[0]
+            sv_set += vcf.readvariants(infile, type_v).values()
         except:
             print("Ignoreing file %s" % infile)
 
     eprint(" Reading file %s" % true_vcf)
-    sv_set_to, true_ones_records = read_vcf_file(true_vcf, type_v)
-    sv_set += sv_set_to
+    true_variants = vcf.readvariants(true_vcf, type_v)
+    sv_set += true_variants.values()
 
     # Compute connected components:
     eprint("Computing Connected components")
     construct_overlap_graph(sv_set, overlap_cutoff, left_precision, right_precision)
 
     # Build records:
-    records, tools, orphans = build_records(genotypes, sv_set, true_ones_records, filtered_records, gt_quality)
+    records, tools, orphans = build_records(genotypes, sv_set, true_variants.values(), filtered_records, gt_quality)
 
     nb_records = len(records)
 
@@ -911,6 +891,29 @@ def init(output, vcf_files, true_vcf, filtered_vcfs=None, type_v="del", overlap_
                         nb_tools + (2 if filtered_records is not None else 1),
                         nb_inds, (2, nb_records + 2))
 
+    ###############################
+    # Build genotypes result file #
+    ###############################
+
+    gtres.build(true_genotypes=true_variants,
+                pred_genotypes=filtered_all,
+                filtered_genotypes=filtered,
+                output=os.path.join(output, "results_genotype.tsv"))
+
+    #######################################
+    # Build Jupyter HTML notebook summary #
+    #######################################
+
+    # Copy necessary files:
+
+    files_to_copy = [rules, "Summarized_results.ipynb"]
+    for file in files_to_copy:
+        shutil.copy(file, os.path.join(output, os.path.basename(file)))
+
+    # Build HTML summary:
+    ipynb = os.path.join(output, "Summarized_results.ipynb")
+    os.popen("jupyter nbconvert --to html --template basic --execute %s" % ipynb)
+
     print_results(nb_records, orphans, with_xlsx, output, do_genotype)
 
 
@@ -944,7 +947,8 @@ def main():
          left_precision=args.left_precision,
          right_precision=args.right_precision,
          no_xls=args.no_xls,
-         haploid=args.haploid)
+         haploid=args.haploid,
+         rules=args.rules)
 
 
 # initialize the script
diff --git a/lib/genotype_results.py b/lib/genotype_results.py
index 2199f1f..d0ca823 100755
--- a/lib/genotype_results.py
+++ b/lib/genotype_results.py
@@ -2,7 +2,7 @@
 
 from pybedtools import BedTool
 from pybedtools import create_interval_from_list
-from pysam import VariantFile
+import lib.vcf as vcf
 
 
 def variants_to_pybed(variants):
@@ -17,23 +17,6 @@ def variants_to_pybed(variants):
     return BedTool(intervals).sort()
 
 
-def passed_variant(record):
-    """
-    Did this variant pass?
-    :param record: vcf record object
-    :return: True if pass, False else
-    """
-    return record.filter is None or len(record.filter) == 0 or "PASS" in record.filter
-
-
-def readgenotypes(vcffile: str, variants: dict, type_v, dofilter: bool=False):
-    vcfin = VariantFile(vcffile)
-    for r in vcfin:
-        if (not dofilter or passed_variant(r)) and r.alts[0][1:-1].lower() == type_v:
-            variants[r.id] = r
-    return variants
-
-
 def canonize(geno):
     if geno == (1, 0):
         return 0, 1
@@ -65,7 +48,7 @@ def getvarsize(variant):
     return variant.stop - variant.start + 1
 
 
-def build(true_genotypes, pred_genotypes, filtered_genotypes, output):
+def build(true_genotypes: dict, pred_genotypes: dict, filtered_genotypes: dict, output: str):
     true_pybed = variants_to_pybed(true_genotypes)
     pred_pybed = variants_to_pybed(pred_genotypes)
 
@@ -89,23 +72,27 @@ def build(true_genotypes, pred_genotypes, filtered_genotypes, output):
 
 
 def main(genotypes, predicted, filtered, output, type_v="del", verbose=False):
-    true_genotypes = readgenotypes(vcffile=genotypes,
-                                   variants={},
-                                   type_v=type_v)
+    true_genotypes = vcf.readvariants(vcffile=genotypes,
+                                      variants={},
+                                      type_v=type_v)
     pred_genotypes = {}
     filtered_genotypes = {}
     with open(predicted, "r") as preds:
         for pred in preds:
             pred = pred.rstrip()
             if pred != "":
-                readgenotypes(vcffile=pred,
-                              variants=pred_genotypes,
-                              type_v=type_v)
                 if filtered:
-                    readgenotypes(vcffile=pred,
-                                  variants=filtered_genotypes,
-                                  type_v=type_v,
-                                  dofilter=True)
+                    vcf.readvariants(vcffile=pred,
+                                     variants=filtered_genotypes,
+                                     type_v=type_v,
+                                     dofilter=True,
+                                     all_variants=pred_genotypes)
+                else:
+                    if filtered:
+                        vcf.readvariants(vcffile=pred,
+                                         variants=pred_genotypes,
+                                         type_v=type_v,
+                                         dofilter=False)
 
     pred_with_true = build(true_genotypes=true_genotypes,
                            pred_genotypes=pred_genotypes,
diff --git a/lib/vcf.py b/lib/vcf.py
new file mode 100644
index 0000000..ed4db7a
--- /dev/null
+++ b/lib/vcf.py
@@ -0,0 +1,41 @@
+from collections import OrderedDict
+from pysam import VariantFile
+
+
+ALLOW_VARIANTS = ['del', 'inv']
+
+
+def passed_variant(record):
+    """
+    Did this variant pass?
+    :param record: vcf record object
+    :return: True if pass, False else
+    """
+    return record.filter is None or len(record.filter) == 0 or "PASS" in record.filter
+
+
+def readvariants(vcffile: str, type_v, variants: dict=None, dofilter: bool=False, all_variants: dict=None):
+    """
+    Read variants from a VCF file
+    :param vcffile: input vcf file
+    :param type_v: type of variant to get
+    :param variants: dict containing variants (only filtered if dofilter)
+    :param dofilter: filter to get only PASS variants
+    :param all_variants: all variants (ignored if is None or if dofilter is False)
+    :return:
+    """
+    if variants is None:
+        variants = {}
+    if type_v.lower() not in ALLOW_VARIANTS:
+        raise ValueError("Invalid variant type: %s" % type_v)
+
+    vcfin = VariantFile(vcffile)
+    for r in vcfin:
+        if r.alts[0][1:-1].lower() == type_v:
+            if not dofilter or passed_variant(r):
+                variants[r.id] = r
+            if dofilter and all_variants is not None:
+                all_variants[r.id] = r
+    if dofilter and all_variants is not None:
+        return variants, all_variants
+    return variants
-- 
GitLab