From d2c794aee92386af7195b844b401f2d59b686ce8 Mon Sep 17 00:00:00 2001 From: Floreal Cabanettes <floreal.cabanettes@inra.fr> Date: Thu, 29 Mar 2018 18:17:21 +0200 Subject: [PATCH] Add build of jupyter notebook summary --- Summarized_results.ipynb | 490 +++++++++++++++++++++++++++++++++++++++ build_results.py | 88 +++---- lib/genotype_results.py | 45 ++-- lib/vcf.py | 41 ++++ 4 files changed, 593 insertions(+), 71 deletions(-) create mode 100644 Summarized_results.ipynb create mode 100644 lib/vcf.py diff --git a/Summarized_results.ipynb b/Summarized_results.ipynb new file mode 100644 index 0000000..e5ae770 --- /dev/null +++ b/Summarized_results.ipynb @@ -0,0 +1,490 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CNV detection benchmark" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%%html\n", + "\n", + "<script>\n", + " function code_toggle() {\n", + " if (code_shown){\n", + " $('div.input').hide('500');\n", + " $('#toggleButton').val('Show Code')\n", + " } else {\n", + " $(\"div.input:not(:first)\").show('500');\n", + " $('#toggleButton').val('Hide Code')\n", + " }\n", + " code_shown = !code_shown\n", + " }\n", + "\n", + " $( document ).ready(function(){\n", + " code_shown=false;\n", + " $('div.input').hide()\n", + " });\n", + "</script>\n", + "<form action=\"javascript:code_toggle()\" style=\"position:fixed;left:10px;top:10px\"><input type=\"submit\" id=\"toggleButton\" value=\"Show Code\"></form>" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "from collections import OrderedDict\n", + "\n", + "import math\n", + "\n", + "import re\n", + "import os\n", + "import json\n", + "import pylab as P\n", + "import matplotlib as mpl\n", + "import matplotlib.pyplot as plt\n", + "import seaborn as sns\n", + "sns.set(style=\"whitegrid\", color_codes=True)\n", + "%matplotlib inline\n", + "\n", + "from pysam import VariantFile\n", + "from collections import defaultdict\n", + "from collections import Counter\n", + "\n", + "from IPython.display import display, HTML\n", + "\n", + "# Read tsv file\n", + "results_df = pd.read_table(\"results_sv_per_tools.tsv\", header=0, index_col=0)\n", + "\n", + "# Retrieve list of tools\n", + "tools = set()\n", + "for col in results_df.columns:\n", + " if col.endswith(\"__Start\") and col != \"Real_data__Start\" and col != \"Filtered_results__Start\":\n", + " tools.add(col.split(\"__\")[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Recall & Precision" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def compute_tp_fp_fn(svs, tool):\n", + " tp = 0\n", + " fp = 0\n", + " fn = 0\n", + " start_values = svs[\"{0}__Start\".format(tool)]\n", + " for i in range(0, len(start_values)):\n", + " if math.isnan(start_values[i]) and not math.isnan(svs[\"Real_data__Start\"][i]):\n", + " fn += 1\n", + " elif not math.isnan(start_values[i]) and not math.isnan(svs[\"Real_data__Start\"][i]):\n", + " tp += 1\n", + " elif not math.isnan(start_values[i]) and math.isnan(svs[\"Real_data__Start\"][i]):\n", + " fp += 1\n", + " return tp, fp, fn\n", + "\n", + "recall = OrderedDict()\n", + "precision = OrderedDict()\n", + "for tool in tools:\n", + " if tool + \"__Start\" in results_df:\n", + " tp, fp, fn = compute_tp_fp_fn(results_df, tool)\n", + " recall[tool] = [tp / (tp+fn) * 100]\n", + " precision[tool] = [tp / (tp+fp) * 100]\n", + " \n", + "plt.figure(1, figsize=(20,10))\n", + "\n", + "# Plot recall\n", + "plt.subplot(121)\n", + "recall_df = pd.DataFrame.from_dict(recall, orient=\"columns\")\n", + "plot = sns.barplot(data=recall_df)\n", + "plot.set_title(\"Recall\", fontsize=30)\n", + "plot.set_ylabel(\"Recall (%)\", fontsize=20)\n", + "plot.set_xlabel(\"Tool\", fontsize=20)\n", + "plot.tick_params(labelsize=14)\n", + "plot.set_ylim([0,100])\n", + "\n", + "# Plot precision\n", + "plt.subplot(122)\n", + "precision_df = pd.DataFrame.from_dict(precision, orient=\"columns\")\n", + "plot2 = sns.barplot(data=precision_df)\n", + "plot2.set_title(\"Precision\", fontsize=30)\n", + "plot2.set_ylabel(\"Precision (%)\", fontsize=20)\n", + "plot2.set_xlabel(\"Tool\", fontsize=20)\n", + "plot2.tick_params(labelsize=14)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Influence of variant size on recall" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "groups = []\n", + "with open(\"rules.sim\", \"r\") as rules:\n", + " for line in rules:\n", + " line = line.rstrip()\n", + " if line != \"\":\n", + " parts = re.split(r\"\\s+\", line)\n", + " groups.append((int(parts[1]), int(parts[2])))\n", + "\n", + "groups.sort(key=lambda x: x[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### By tool" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "nrows = math.ceil(len(tools)/2)\n", + "ncols = min(2, len(tools))\n", + "\n", + "palettes = [\"Blues_d\", \"Greens_d\", \"Reds_d\", \"Purples_d\", \"YlOrBr_d\", \"PuBu_d\"]\n", + "\n", + "plt.figure(1, figsize=(20,nrows * 8))\n", + "\n", + "nplot=0\n", + "\n", + "for tool in tools: \n", + " npalette = nplot\n", + " while npalette >= len(palettes):\n", + " npalette -= len(palettes)\n", + " nplot += 1\n", + " results_by_group = OrderedDict()\n", + " for group in groups:\n", + " tmp_res = results_df[(results_df.Real_data__Length >= group[0]) & (results_df.Real_data__Length < group[1])]\n", + " tp, fp, fn = compute_tp_fp_fn(tmp_res, tool)\n", + " results_by_group[\"-\".join(map(str,group))] = [tp / (tp+fn) * 100]\n", + "\n", + " recall_df = pd.DataFrame.from_dict(results_by_group, orient=\"columns\")\n", + " \n", + " plt.subplot(nrows, ncols, nplot)\n", + " plot = sns.barplot(data=recall_df, palette=palettes[npalette])\n", + " plot.set_title(tool, fontsize=25)\n", + " plot.set_ylabel(\"Recall (%)\", fontsize=20)\n", + " plot.set_xlabel(\"Variant size (bp)\", fontsize=20)\n", + " plot.tick_params(labelsize=14)\n", + " plot.set_ylim([0,100])\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Global" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "results_by_group = OrderedDict()\n", + " \n", + "for group in groups:\n", + " tmp_res = results_df[(results_df.Real_data__Length >= group[0]) & (results_df.Real_data__Length < group[1])]\n", + " tp, fp, fn = compute_tp_fp_fn(tmp_res, \"Filtered_results\")\n", + " results_by_group[\"-\".join(map(str,group))] = [tp / (tp+fn) * 100]\n", + " \n", + "plt.figure(1, figsize=(15,8))\n", + " \n", + "recall_df = pd.DataFrame.from_dict(results_by_group, orient=\"columns\")\n", + "plot = sns.barplot(data=recall_df, color='black')\n", + "plot.set_title(\"Recall\", fontsize=30)\n", + "plot.set_ylabel(\"Recall (%)\", fontsize=20)\n", + "plot.set_xlabel(\"Variant size (bp)\", fontsize=20)\n", + "plot.tick_params(labelsize=14)\n", + "plot.set_ylim([0,100])\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Shared variant by tool" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def compute_found_sv(svs: pd.DataFrame, tool: str):\n", + " variants = []\n", + " start_values = svs[\"{0}__Start\".format(tool)]\n", + " for idx in svs.index:\n", + " if not math.isnan(start_values[idx]) and not math.isnan(svs[\"Real_data__Start\"][idx]):\n", + " variants.append(idx)\n", + " return variants\n", + "\n", + "variants_by_tool = []\n", + "\n", + "for tool in tools:\n", + " variants_by_tool.append({\"name\": tool, \"data\": compute_found_sv(results_df, tool)})" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "html=HTML(\"<script type='text/javascript' src='http://jvenn.toulouse.inra.fr/app/js/canvas2svg.js'></script> \\\n", + "<script type='text/javascript' src='http://jvenn.toulouse.inra.fr/app/js/jvenn.min.js'></script> \\\n", + "<div id='draw'></div> \\\n", + "<script type='text/javascript'> \\\n", + " $(document).ready(function(){ \\\n", + " $('#draw').jvenn({ \\\n", + " series: \" + json.dumps(variants_by_tool) + \" \\\n", + " }); \\\n", + " }); \\\n", + "</script>\")\n", + "\n", + "display(html)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Breakpoints precision" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(1, figsize=(20,8))\n", + "\n", + "# Start precision\n", + "df_diffs=pd.read_table(\"results_sv_diffs_per_tools.tsv\", header=0, index_col=0)\n", + "all_diffs_soft_start = pd.DataFrame()\n", + "for tool in tools:\n", + " all_diffs_soft_start[tool] = df_diffs[tool + \"__Start\"].abs()\n", + "\n", + "plt.subplot(121)\n", + "plot = sns.stripplot(data=all_diffs_soft_start, jitter=True)\n", + "plot.set_ylim([0,150])\n", + "plot.tick_params(labelsize=15)\n", + "plot.set_title(\"Start position\", fontsize=28, y=1.04)\n", + "plot.set_ylabel(\"Diff from real data (abs)\", fontsize=20)\n", + "\n", + "# End precision\n", + "df_diffs=pd.read_table(\"results_sv_diffs_per_tools.tsv\", header=0, index_col=0)\n", + "all_diffs_soft_end = pd.DataFrame()\n", + "for tool in tools:\n", + " all_diffs_soft_end[tool] = df_diffs[tool + \"__End\"].abs()\n", + "\n", + "plt.subplot(122)\n", + "plot = sns.stripplot(data=all_diffs_soft_end, jitter=True)\n", + "plot.set_ylim([0,150])\n", + "plot.tick_params(labelsize=15)\n", + "plot.set_title(\"End position\", fontsize=28, y=1.04)\n", + "plot.set_ylabel(\"Diff from real data (abs)\", fontsize=20)\n", + "\n", + "plt.show()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2. CNV Genotyping\n", + "\n", + "We compare quality of genotyping" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Simulated deletions\n", + "total = len(results_df[results_df.Real_data__Start.notnull()])" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Predicted deletions genotyped\n", + "\n", + "df_svtyper=pd.read_csv(\"results_genotype.tsv\",sep='\\t',index_col=False,names=['del','software','delsize','recall','left','right'])\n", + "df_svtyper['precision'] = [ \"precise\" if (x+y)<20 else \"unprecise\" for (x,y) in zip(df_svtyper.left,df_svtyper.right)]\n", + "df_svtyper['deltype'] = ['small' if x<200 else 'medium' for x in df_svtyper.delsize]\n", + "counts_svtyper = np.unique(df_svtyper.software,return_counts=True)[1]" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Genotype recall for each software prediction" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "scrolled": false + }, + "outputs": [], + "source": [ + "gt_tools = list(tools) + [\"pass\"]\n", + "plt.figure(1, figsize=(8,5))\n", + "sns.stripplot(x=\"software\", y=\"recall\", jitter=True, palette=\"Set2\", dodge=True, linewidth=1, edgecolor='gray', order=gt_tools,data=df_svtyper)\n", + "axes = sns.boxplot(x=\"software\", y=\"recall\", palette=\"Set2\", order=gt_tools,data=df_svtyper)\n", + "axes.title.set_position([.5, 1.2])\n", + "axes.set_title(str(total)+\" simulated variants\",size=25)\n", + "axes.axes.xaxis.label.set_size(20)\n", + "axes.axes.yaxis.label.set_size(20)\n", + "axes.tick_params(labelsize=15)\n", + "ymax = axes.get_ylim()[1]\n", + "for i in range(0, len(counts_svtyper)):\n", + " t1=axes.text(-0.1 + (i * 1), ymax+ymax/100, counts_svtyper[i], fontsize=15)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Inluence of precision : precise means less than 20bp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(1, figsize=(8,5))\n", + "sns.stripplot(x=\"software\", y=\"recall\", jitter=True, hue=\"precision\", palette=\"Set2\", dodge=True, linewidth=1, edgecolor='gray', order=gt_tools,data=df_svtyper)\n", + "axes = sns.boxplot(x=\"software\", y=\"recall\",hue=\"precision\",palette=\"Set2\", order=gt_tools,data=df_svtyper)\n", + "axes.title.set_position([.5, 1.2])\n", + "axes.set_ylim(0.15, 1.05)\n", + "axes.set_title(str(total)+\" simulated variants\",size=25)\n", + "axes.axes.xaxis.label.set_size(20)\n", + "axes.axes.yaxis.label.set_size(20)\n", + "axes.tick_params(labelsize=15)\n", + "ymax = axes.get_ylim()[1]\n", + "for i in range(0, len(counts_svtyper)):\n", + " t1=axes.text(-0.1 + (i * 1), ymax+ymax/100, counts_svtyper[i], fontsize=15)\n", + "plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Inluence of deletion size : small means less than 200bp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(1, figsize=(8,5))\n", + "sns.stripplot(x=\"software\", y=\"recall\", jitter=True, hue=\"deltype\", palette=\"Set2\", dodge=True, linewidth=1, edgecolor='gray', order=gt_tools,data=df_svtyper)\n", + "axes = sns.boxplot(x=\"software\", y=\"recall\",hue=\"deltype\",palette=\"Set2\", order=gt_tools,data=df_svtyper)\n", + "axes.title.set_position([.5, 1.2])\n", + "axes.set_ylim(0.15, 1.05)\n", + "axes.set_title(str(total)+\" simulated variants\",size=25)\n", + "axes.axes.xaxis.label.set_size(20)\n", + "axes.axes.yaxis.label.set_size(20)\n", + "axes.tick_params(labelsize=15)\n", + "ymax = axes.get_ylim()[1]\n", + "for i in range(0, len(counts_svtyper)):\n", + " t1=axes.text(-0.1 + (i * 1), ymax+ymax/100, counts_svtyper[i], fontsize=15)\n", + "plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "##### Size VS Precision" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plt.figure(1, figsize=(8,5))\n", + "sns.stripplot(x=\"precision\", y=\"delsize\", jitter=True, palette=\"Set2\", dodge=True, linewidth=1, edgecolor='gray', order=['precise', 'unprecise'],data=df_svtyper)\n", + "axes = sns.boxplot(x=\"precision\", y=\"delsize\",palette=\"Set2\", order=['precise', 'unprecise'],data=df_svtyper)\n", + "axes.axes.xaxis.label.set_size(20)\n", + "axes.axes.yaxis.label.set_size(20)\n", + "axes.tick_params(labelsize=15)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.5.2" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/build_results.py b/build_results.py index dcaae65..2b9981d 100755 --- a/build_results.py +++ b/build_results.py @@ -18,9 +18,13 @@ import argparse import string import os import re +import shutil from pysam import VariantFile +import lib.genotype_results as gtres +import lib.vcf as vcf + import sys prg_path = os.path.dirname(os.path.realpath(__file__)) sys.path.insert(0, os.path.join(prg_path, "svlib")) @@ -38,8 +42,6 @@ COLOR_IS_KEPT = "#81F781" COLOR_FALSE_POSITIVE = "#FE642E" COLOR_WRONG_GT = "#B40404" -ALLOW_VARIANTS = ['del', 'inv'] - def get_args(): """ @@ -51,15 +53,16 @@ Build Results \n \ description: Build results of the simulated data detection") parser.add_argument('-v', '--vcfs', type=str, required=True, help='File listing vcf files for each detection tool') parser.add_argument('-t', '--true-vcf', type=str, required=True, help='VCF file containing the simulated deletions') - parser.add_argument('-f', '--filtered-vcf', type=str, required=False, + parser.add_argument('-f', '--filtered-vcf', type=str, required=True, help='File listing VCF files containing the filtered results') - parser.add_argument('-y', '--type', required=True, type=str, choices=ALLOW_VARIANTS, help="Type of variant") + parser.add_argument('-y', '--type', required=True, type=str, choices=vcf.ALLOW_VARIANTS, help="Type of variant") parser.add_argument('--overlap_cutoff', type=float, default=0.5, help='cutoff for reciprocal overlap') parser.add_argument('--left_precision', type=int, default=-1, help='left breakpoint precision') parser.add_argument('--right_precision', type=int, default=-1, help='right breakpoint precision') parser.add_argument('-o', '--output', type=str, default="results", help='output folder') parser.add_argument('--no-xls', action='store_const', const=True, default=False, help='Do not build Excel file') parser.add_argument('--haploid', action='store_const', const=True, default=False, help='The organism is haploid') + parser.add_argument('-r', '--rules', type=str, required=False, help="Simulation rule file") # parse the arguments args = parser.parse_args() @@ -69,6 +72,9 @@ description: Build results of the simulated data detection") if args.right_precision == -1: args.right_precision = sys.maxsize + if args.rules is None: + args.rules = os.path.join(prg_path, "defaults.rules") + # send back the user input return args @@ -79,35 +85,6 @@ def eprint(*args, **kwargs): """ print(*args, file=sys.stderr, **kwargs) - -def passed_variant(record): - """ - Did this variant pass? - :param record: vcf record object - :return: True if pass, False else - """ - return record.filter is None or len(record.filter) == 0 or "PASS" in record.filter - - -def read_vcf_file(infile, type_v): - """ - Read a vcf file - :param infile: vcf file path - :param type_v: type of variant ("del" or "inv") - :return: set or records, list of records ids - """ - if type_v.lower() not in ALLOW_VARIANTS: - raise ValueError("Invalid variant type: %s" % type_v) - SVSet=[] - ids = [] - for record in VCFReader(infile): - if record.sv_type.lower() == type_v: - if not passed_variant(record): - continue - SVSet.append(record) - ids.append(record.id) - return SVSet, ids - def svsort(sv, records): """ @@ -818,7 +795,7 @@ def build_xlsx_cols(): XLSX_COLS.append(alp + j) -def init(output, vcf_files, true_vcf, filtered_vcfs=None, type_v="del", overlap_cutoff=0.5, +def init(output, vcf_files, true_vcf, rules, filtered_vcfs=None, type_v="del", overlap_cutoff=0.5, left_precision=sys.maxsize, right_precision=sys.maxsize, no_xls=False, haploid=False): if not os.path.exists(output): @@ -831,15 +808,18 @@ def init(output, vcf_files, true_vcf, filtered_vcfs=None, type_v="del", overlap_ nb_inds = 0 + filtered = None + filtered_all = None filtered_records = None do_genotype = False if filtered_vcfs: - filtered_records = [] + filtered = {} + filtered_all = {} for filtered_vcf in filtered_vcfs: eprint(" Reading file %s" % filtered_vcf) - filtered_records += read_vcf_file(filtered_vcf, type_v)[1] - + vcf.readvariants(filtered_vcf, type_v, filtered, True, filtered_all) + filtered_records = filtered.keys() genotypes, gt_quality, nb_inds = get_genotypes(filtered_vcfs, true_vcf) do_genotype = True @@ -848,20 +828,20 @@ def init(output, vcf_files, true_vcf, filtered_vcfs=None, type_v="del", overlap_ for infile in vcf_files: eprint(" Reading file %s" % infile) try: - sv_set += read_vcf_file(infile, type_v)[0] + sv_set += vcf.readvariants(infile, type_v).values() except: print("Ignoreing file %s" % infile) eprint(" Reading file %s" % true_vcf) - sv_set_to, true_ones_records = read_vcf_file(true_vcf, type_v) - sv_set += sv_set_to + true_variants = vcf.readvariants(true_vcf, type_v) + sv_set += true_variants.values() # Compute connected components: eprint("Computing Connected components") construct_overlap_graph(sv_set, overlap_cutoff, left_precision, right_precision) # Build records: - records, tools, orphans = build_records(genotypes, sv_set, true_ones_records, filtered_records, gt_quality) + records, tools, orphans = build_records(genotypes, sv_set, true_variants.values(), filtered_records, gt_quality) nb_records = len(records) @@ -911,6 +891,29 @@ def init(output, vcf_files, true_vcf, filtered_vcfs=None, type_v="del", overlap_ nb_tools + (2 if filtered_records is not None else 1), nb_inds, (2, nb_records + 2)) + ############################### + # Build genotypes result file # + ############################### + + gtres.build(true_genotypes=true_variants, + pred_genotypes=filtered_all, + filtered_genotypes=filtered, + output=os.path.join(output, "results_genotype.tsv")) + + ####################################### + # Build Jupyter HTML notebook summary # + ####################################### + + # Copy necessary files: + + files_to_copy = [rules, "Summarized_results.ipynb"] + for file in files_to_copy: + shutil.copy(file, os.path.join(output, os.path.basename(file))) + + # Build HTML summary: + ipynb = os.path.join(output, "Summarized_results.ipynb") + os.popen("jupyter nbconvert --to html --template basic --execute %s" % ipynb) + print_results(nb_records, orphans, with_xlsx, output, do_genotype) @@ -944,7 +947,8 @@ def main(): left_precision=args.left_precision, right_precision=args.right_precision, no_xls=args.no_xls, - haploid=args.haploid) + haploid=args.haploid, + rules=args.rules) # initialize the script diff --git a/lib/genotype_results.py b/lib/genotype_results.py index 2199f1f..d0ca823 100755 --- a/lib/genotype_results.py +++ b/lib/genotype_results.py @@ -2,7 +2,7 @@ from pybedtools import BedTool from pybedtools import create_interval_from_list -from pysam import VariantFile +import lib.vcf as vcf def variants_to_pybed(variants): @@ -17,23 +17,6 @@ def variants_to_pybed(variants): return BedTool(intervals).sort() -def passed_variant(record): - """ - Did this variant pass? - :param record: vcf record object - :return: True if pass, False else - """ - return record.filter is None or len(record.filter) == 0 or "PASS" in record.filter - - -def readgenotypes(vcffile: str, variants: dict, type_v, dofilter: bool=False): - vcfin = VariantFile(vcffile) - for r in vcfin: - if (not dofilter or passed_variant(r)) and r.alts[0][1:-1].lower() == type_v: - variants[r.id] = r - return variants - - def canonize(geno): if geno == (1, 0): return 0, 1 @@ -65,7 +48,7 @@ def getvarsize(variant): return variant.stop - variant.start + 1 -def build(true_genotypes, pred_genotypes, filtered_genotypes, output): +def build(true_genotypes: dict, pred_genotypes: dict, filtered_genotypes: dict, output: str): true_pybed = variants_to_pybed(true_genotypes) pred_pybed = variants_to_pybed(pred_genotypes) @@ -89,23 +72,27 @@ def build(true_genotypes, pred_genotypes, filtered_genotypes, output): def main(genotypes, predicted, filtered, output, type_v="del", verbose=False): - true_genotypes = readgenotypes(vcffile=genotypes, - variants={}, - type_v=type_v) + true_genotypes = vcf.readvariants(vcffile=genotypes, + variants={}, + type_v=type_v) pred_genotypes = {} filtered_genotypes = {} with open(predicted, "r") as preds: for pred in preds: pred = pred.rstrip() if pred != "": - readgenotypes(vcffile=pred, - variants=pred_genotypes, - type_v=type_v) if filtered: - readgenotypes(vcffile=pred, - variants=filtered_genotypes, - type_v=type_v, - dofilter=True) + vcf.readvariants(vcffile=pred, + variants=filtered_genotypes, + type_v=type_v, + dofilter=True, + all_variants=pred_genotypes) + else: + if filtered: + vcf.readvariants(vcffile=pred, + variants=pred_genotypes, + type_v=type_v, + dofilter=False) pred_with_true = build(true_genotypes=true_genotypes, pred_genotypes=pred_genotypes, diff --git a/lib/vcf.py b/lib/vcf.py new file mode 100644 index 0000000..ed4db7a --- /dev/null +++ b/lib/vcf.py @@ -0,0 +1,41 @@ +from collections import OrderedDict +from pysam import VariantFile + + +ALLOW_VARIANTS = ['del', 'inv'] + + +def passed_variant(record): + """ + Did this variant pass? + :param record: vcf record object + :return: True if pass, False else + """ + return record.filter is None or len(record.filter) == 0 or "PASS" in record.filter + + +def readvariants(vcffile: str, type_v, variants: dict=None, dofilter: bool=False, all_variants: dict=None): + """ + Read variants from a VCF file + :param vcffile: input vcf file + :param type_v: type of variant to get + :param variants: dict containing variants (only filtered if dofilter) + :param dofilter: filter to get only PASS variants + :param all_variants: all variants (ignored if is None or if dofilter is False) + :return: + """ + if variants is None: + variants = {} + if type_v.lower() not in ALLOW_VARIANTS: + raise ValueError("Invalid variant type: %s" % type_v) + + vcfin = VariantFile(vcffile) + for r in vcfin: + if r.alts[0][1:-1].lower() == type_v: + if not dofilter or passed_variant(r): + variants[r.id] = r + if dofilter and all_variants is not None: + all_variants[r.id] = r + if dofilter and all_variants is not None: + return variants, all_variants + return variants -- GitLab