Skip to content
Snippets Groups Projects
Commit 87109464 authored by Floreal Cabanettes's avatar Floreal Cabanettes
Browse files

Fix cases of tp+fn ==0

parent 792f5079
No related branches found
No related tags found
No related merge requests found
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
# CNV detection benchmark # CNV detection benchmark
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
import pandas as pd import pandas as pd
import numpy as np import numpy as np
from collections import OrderedDict from collections import OrderedDict
import math import math
import re import re
import os import os
import json import json
import pylab as P import pylab as P
import matplotlib as mpl import matplotlib as mpl
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import seaborn as sns import seaborn as sns
sns.set(style="whitegrid", color_codes=True) sns.set(style="whitegrid", color_codes=True)
%matplotlib inline %matplotlib inline
from pysam import VariantFile from pysam import VariantFile
from collections import defaultdict from collections import defaultdict
from collections import Counter from collections import Counter
from IPython.display import display, HTML from IPython.display import display, HTML
# Read tsv file # Read tsv file
results_df = pd.read_table("results_sv_per_tools.tsv", header=0, index_col=0) results_df = pd.read_table("results_sv_per_tools.tsv", header=0, index_col=0)
# Retrieve list of tools # Retrieve list of tools
tools = set() tools = set()
for col in results_df.columns: for col in results_df.columns:
if col.endswith("__Start") and col != "Real_data__Start" and col != "Filtered_results__Start": if col.endswith("__Start") and col != "Real_data__Start" and col != "Filtered_results__Start":
tools.add(col.split("__")[0]) tools.add(col.split("__")[0])
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
### Recall & Precision ### Recall & Precision
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
def compute_tp_fp_fn(svs, tool): def compute_tp_fp_fn(svs, tool):
tp = 0 tp = 0
fp = 0 fp = 0
fn = 0 fn = 0
start_values = svs["{0}__Start".format(tool)] start_values = svs["{0}__Start".format(tool)]
for i in range(0, len(start_values)): for i in range(0, len(start_values)):
if math.isnan(start_values[i]) and not math.isnan(svs["Real_data__Start"][i]): if math.isnan(start_values[i]) and not math.isnan(svs["Real_data__Start"][i]):
fn += 1 fn += 1
elif not math.isnan(start_values[i]) and not math.isnan(svs["Real_data__Start"][i]): elif not math.isnan(start_values[i]) and not math.isnan(svs["Real_data__Start"][i]):
tp += 1 tp += 1
elif not math.isnan(start_values[i]) and math.isnan(svs["Real_data__Start"][i]): elif not math.isnan(start_values[i]) and math.isnan(svs["Real_data__Start"][i]):
fp += 1 fp += 1
return tp, fp, fn return tp, fp, fn
recall = OrderedDict() recall = OrderedDict()
precision = OrderedDict() precision = OrderedDict()
for tool in tools: for tool in tools:
if tool + "__Start" in results_df: if tool + "__Start" in results_df:
tp, fp, fn = compute_tp_fp_fn(results_df, tool) tp, fp, fn = compute_tp_fp_fn(results_df, tool)
recall[tool] = [tp / (tp+fn) * 100] recall[tool] = [tp / (tp+fn) * 100]
precision[tool] = [tp / (tp+fp) * 100] precision[tool] = [tp / (tp+fp) * 100]
plt.figure(1, figsize=(20,10)) plt.figure(1, figsize=(20,10))
# Plot recall # Plot recall
plt.subplot(121) plt.subplot(121)
recall_df = pd.DataFrame.from_dict(recall, orient="columns") recall_df = pd.DataFrame.from_dict(recall, orient="columns")
plot = sns.barplot(data=recall_df) plot = sns.barplot(data=recall_df)
plot.set_title("Recall", fontsize=30) plot.set_title("Recall", fontsize=30)
plot.set_ylabel("Recall (%)", fontsize=20) plot.set_ylabel("Recall (%)", fontsize=20)
plot.set_xlabel("Tool", fontsize=20) plot.set_xlabel("Tool", fontsize=20)
plot.tick_params(labelsize=14) plot.tick_params(labelsize=14)
plot.set_ylim([0,100]) plot.set_ylim([0,100])
# Plot precision # Plot precision
plt.subplot(122) plt.subplot(122)
precision_df = pd.DataFrame.from_dict(precision, orient="columns") precision_df = pd.DataFrame.from_dict(precision, orient="columns")
plot2 = sns.barplot(data=precision_df) plot2 = sns.barplot(data=precision_df)
plot2.set_title("Precision", fontsize=30) plot2.set_title("Precision", fontsize=30)
plot2.set_ylabel("Precision (%)", fontsize=20) plot2.set_ylabel("Precision (%)", fontsize=20)
plot2.set_xlabel("Tool", fontsize=20) plot2.set_xlabel("Tool", fontsize=20)
plot2.tick_params(labelsize=14) plot2.tick_params(labelsize=14)
plt.show() plt.show()
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
### Influence of variant size on recall ### Influence of variant size on recall
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
groups = [] groups = []
with open("rules.sim", "r") as rules: with open("rules.sim", "r") as rules:
for line in rules: for line in rules:
line = line.rstrip() line = line.rstrip()
if line != "": if line != "":
parts = re.split(r"\s+", line) parts = re.split(r"\s+", line)
groups.append((int(parts[1]), int(parts[2]))) groups.append((int(parts[1]), int(parts[2])))
groups.sort(key=lambda x: x[0]) groups.sort(key=lambda x: x[0])
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
#### By tool #### By tool
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
nrows = math.ceil(len(tools)/2) nrows = math.ceil(len(tools)/2)
ncols = min(2, len(tools)) ncols = min(2, len(tools))
palettes = ["Blues_d", "Greens_d", "Reds_d", "Purples_d", "YlOrBr_d", "PuBu_d"] palettes = ["Blues_d", "Greens_d", "Reds_d", "Purples_d", "YlOrBr_d", "PuBu_d"]
plt.figure(1, figsize=(20,nrows * 8)) plt.figure(1, figsize=(20,nrows * 8))
nplot=0 nplot=0
for tool in tools: for tool in tools:
npalette = nplot npalette = nplot
while npalette >= len(palettes): while npalette >= len(palettes):
npalette -= len(palettes) npalette -= len(palettes)
nplot += 1 nplot += 1
results_by_group = OrderedDict() results_by_group = OrderedDict()
for group in groups: for group in groups:
tmp_res = results_df[(results_df.Real_data__Length >= group[0]) & (results_df.Real_data__Length < group[1])] tmp_res = results_df[(results_df.Real_data__Length >= group[0]) & (results_df.Real_data__Length < group[1])]
tp, fp, fn = compute_tp_fp_fn(tmp_res, tool) tp, fp, fn = compute_tp_fp_fn(tmp_res, tool)
results_by_group["-".join(map(str,group))] = [tp / (tp+fn) * 100] results_by_group["-".join(map(str,group))] = [tp / (tp+fn) * 100 if tp+fn >0 else 0]
recall_df = pd.DataFrame.from_dict(results_by_group, orient="columns") recall_df = pd.DataFrame.from_dict(results_by_group, orient="columns")
plt.subplot(nrows, ncols, nplot) plt.subplot(nrows, ncols, nplot)
plot = sns.barplot(data=recall_df, palette=palettes[npalette]) plot = sns.barplot(data=recall_df, palette=palettes[npalette])
plot.set_title(tool, fontsize=25) plot.set_title(tool, fontsize=25)
plot.set_ylabel("Recall (%)", fontsize=20) plot.set_ylabel("Recall (%)", fontsize=20)
plot.set_xlabel("Variant size (bp)", fontsize=20) plot.set_xlabel("Variant size (bp)", fontsize=20)
plot.tick_params(labelsize=14) plot.tick_params(labelsize=14)
plot.set_ylim([0,100]) plot.set_ylim([0,100])
plt.show() plt.show()
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
#### Global #### Global
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
results_by_group = OrderedDict() results_by_group = OrderedDict()
for group in groups: for group in groups:
tmp_res = results_df[(results_df.Real_data__Length >= group[0]) & (results_df.Real_data__Length < group[1])] tmp_res = results_df[(results_df.Real_data__Length >= group[0]) & (results_df.Real_data__Length < group[1])]
tp, fp, fn = compute_tp_fp_fn(tmp_res, "Filtered_results") tp, fp, fn = compute_tp_fp_fn(tmp_res, "Filtered_results")
results_by_group["-".join(map(str,group))] = [tp / (tp+fn) * 100] results_by_group["-".join(map(str,group))] = [tp / (tp+fn) * 100 if tp+fn >0 else 0]
plt.figure(1, figsize=(15,8)) plt.figure(1, figsize=(15,8))
recall_df = pd.DataFrame.from_dict(results_by_group, orient="columns") recall_df = pd.DataFrame.from_dict(results_by_group, orient="columns")
plot = sns.barplot(data=recall_df, color='black') plot = sns.barplot(data=recall_df, color='black')
plot.set_title("Recall", fontsize=30) plot.set_title("Recall", fontsize=30)
plot.set_ylabel("Recall (%)", fontsize=20) plot.set_ylabel("Recall (%)", fontsize=20)
plot.set_xlabel("Variant size (bp)", fontsize=20) plot.set_xlabel("Variant size (bp)", fontsize=20)
plot.tick_params(labelsize=14) plot.tick_params(labelsize=14)
plot.set_ylim([0,100]) plot.set_ylim([0,100])
plt.show() plt.show()
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
### Shared variant by tool ### Shared variant by tool
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
def compute_found_sv(svs: pd.DataFrame, tool: str): def compute_found_sv(svs: pd.DataFrame, tool: str):
variants = [] variants = []
start_values = svs["{0}__Start".format(tool)] start_values = svs["{0}__Start".format(tool)]
for idx in svs.index: for idx in svs.index:
if not math.isnan(start_values[idx]) and not math.isnan(svs["Real_data__Start"][idx]): if not math.isnan(start_values[idx]) and not math.isnan(svs["Real_data__Start"][idx]):
variants.append(idx) variants.append(idx)
return variants return variants
variants_by_tool = [] variants_by_tool = []
for tool in tools: for tool in tools:
variants_by_tool.append({"name": tool, "data": compute_found_sv(results_df, tool)}) variants_by_tool.append({"name": tool, "data": compute_found_sv(results_df, tool)})
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
html=HTML("<script type='text/javascript' src='http://jvenn.toulouse.inra.fr/app/js/canvas2svg.js'></script> \ html=HTML("<script type='text/javascript' src='http://jvenn.toulouse.inra.fr/app/js/canvas2svg.js'></script> \
<script type='text/javascript' src='http://jvenn.toulouse.inra.fr/app/js/jvenn.min.js'></script> \ <script type='text/javascript' src='http://jvenn.toulouse.inra.fr/app/js/jvenn.min.js'></script> \
<div id='draw'></div> \ <div id='draw'></div> \
<script type='text/javascript'> \ <script type='text/javascript'> \
$(document).ready(function(){ \ $(document).ready(function(){ \
$('#draw').jvenn({ \ $('#draw').jvenn({ \
series: " + json.dumps(variants_by_tool) + " \ series: " + json.dumps(variants_by_tool) + " \
}); \ }); \
}); \ }); \
</script>") </script>")
display(html) display(html)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
### Breakpoints precision ### Breakpoints precision
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
plt.figure(1, figsize=(20,8)) plt.figure(1, figsize=(20,8))
# Start precision # Start precision
df_diffs=pd.read_table("results_sv_diffs_per_tools.tsv", header=0, index_col=0) df_diffs=pd.read_table("results_sv_diffs_per_tools.tsv", header=0, index_col=0)
all_diffs_soft_start = pd.DataFrame() all_diffs_soft_start = pd.DataFrame()
for tool in tools: for tool in tools:
all_diffs_soft_start[tool] = df_diffs[tool + "__Start"].abs() all_diffs_soft_start[tool] = df_diffs[tool + "__Start"].abs()
plt.subplot(121) plt.subplot(121)
plot = sns.stripplot(data=all_diffs_soft_start, jitter=True) plot = sns.stripplot(data=all_diffs_soft_start, jitter=True)
plot.set_ylim([0,150]) plot.set_ylim([0,150])
plot.tick_params(labelsize=15) plot.tick_params(labelsize=15)
plot.set_title("Start position", fontsize=28, y=1.04) plot.set_title("Start position", fontsize=28, y=1.04)
plot.set_ylabel("Diff from real data (abs)", fontsize=20) plot.set_ylabel("Diff from real data (abs)", fontsize=20)
# End precision # End precision
df_diffs=pd.read_table("results_sv_diffs_per_tools.tsv", header=0, index_col=0) df_diffs=pd.read_table("results_sv_diffs_per_tools.tsv", header=0, index_col=0)
all_diffs_soft_end = pd.DataFrame() all_diffs_soft_end = pd.DataFrame()
for tool in tools: for tool in tools:
all_diffs_soft_end[tool] = df_diffs[tool + "__End"].abs() all_diffs_soft_end[tool] = df_diffs[tool + "__End"].abs()
plt.subplot(122) plt.subplot(122)
plot = sns.stripplot(data=all_diffs_soft_end, jitter=True) plot = sns.stripplot(data=all_diffs_soft_end, jitter=True)
plot.set_ylim([0,150]) plot.set_ylim([0,150])
plot.tick_params(labelsize=15) plot.tick_params(labelsize=15)
plot.set_title("End position", fontsize=28, y=1.04) plot.set_title("End position", fontsize=28, y=1.04)
plot.set_ylabel("Diff from real data (abs)", fontsize=20) plot.set_ylabel("Diff from real data (abs)", fontsize=20)
plt.show() plt.show()
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
## 2. CNV Genotyping ## 2. CNV Genotyping
We compare quality of genotyping We compare quality of genotyping
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# Simulated deletions # Simulated deletions
total = len(results_df[results_df.Real_data__Start.notnull()]) total = len(results_df[results_df.Real_data__Start.notnull()])
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# Predicted deletions genotyped # Predicted deletions genotyped
df_svtyper=pd.read_csv("results_genotype.tsv",sep='\t',index_col=False,names=['del','software','delsize','recall','left','right']) df_svtyper=pd.read_csv("results_genotype.tsv",sep='\t',index_col=False,names=['del','software','delsize','recall','left','right'])
df_svtyper['precision'] = [ "precise" if (x+y)<20 else "unprecise" for (x,y) in zip(df_svtyper.left,df_svtyper.right)] df_svtyper['precision'] = [ "precise" if (x+y)<20 else "unprecise" for (x,y) in zip(df_svtyper.left,df_svtyper.right)]
df_svtyper['deltype'] = ['small' if x<200 else 'medium' for x in df_svtyper.delsize] df_svtyper['deltype'] = ['small' if x<200 else 'medium' for x in df_svtyper.delsize]
counts_svtyper = np.unique(df_svtyper.software,return_counts=True)[1] counts_svtyper = np.unique(df_svtyper.software,return_counts=True)[1]
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
#### Genotype recall for each software prediction #### Genotype recall for each software prediction
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
gt_tools = list(tools) + ["pass"] gt_tools = list(tools) + ["pass"]
plt.figure(1, figsize=(8,5)) plt.figure(1, figsize=(8,5))
sns.stripplot(x="software", y="recall", jitter=True, palette="Set2", dodge=True, linewidth=1, edgecolor='gray', order=gt_tools,data=df_svtyper) sns.stripplot(x="software", y="recall", jitter=True, palette="Set2", dodge=True, linewidth=1, edgecolor='gray', order=gt_tools,data=df_svtyper)
axes = sns.boxplot(x="software", y="recall", palette="Set2", order=gt_tools,data=df_svtyper) axes = sns.boxplot(x="software", y="recall", palette="Set2", order=gt_tools,data=df_svtyper)
axes.title.set_position([.5, 1.2]) axes.title.set_position([.5, 1.2])
axes.set_title(str(total)+" simulated variants",size=25) axes.set_title(str(total)+" simulated variants",size=25)
axes.axes.xaxis.label.set_size(20) axes.axes.xaxis.label.set_size(20)
axes.axes.yaxis.label.set_size(20) axes.axes.yaxis.label.set_size(20)
axes.tick_params(labelsize=15) axes.tick_params(labelsize=15)
ymax = axes.get_ylim()[1] ymax = axes.get_ylim()[1]
for i in range(0, len(counts_svtyper)): for i in range(0, len(counts_svtyper)):
t1=axes.text(-0.1 + (i * 1), ymax+ymax/100, counts_svtyper[i], fontsize=15) t1=axes.text(-0.1 + (i * 1), ymax+ymax/100, counts_svtyper[i], fontsize=15)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
#### Inluence of precision : precise means less than 20bp #### Inluence of precision : precise means less than 20bp
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
plt.figure(1, figsize=(8,5)) plt.figure(1, figsize=(8,5))
sns.stripplot(x="software", y="recall", jitter=True, hue="precision", palette="Set2", dodge=True, linewidth=1, edgecolor='gray', order=gt_tools,data=df_svtyper) sns.stripplot(x="software", y="recall", jitter=True, hue="precision", palette="Set2", dodge=True, linewidth=1, edgecolor='gray', order=gt_tools,data=df_svtyper)
axes = sns.boxplot(x="software", y="recall",hue="precision",palette="Set2", order=gt_tools,data=df_svtyper) axes = sns.boxplot(x="software", y="recall",hue="precision",palette="Set2", order=gt_tools,data=df_svtyper)
axes.title.set_position([.5, 1.2]) axes.title.set_position([.5, 1.2])
axes.set_ylim(0.15, 1.05) axes.set_ylim(0.15, 1.05)
axes.set_title(str(total)+" simulated variants",size=25) axes.set_title(str(total)+" simulated variants",size=25)
axes.axes.xaxis.label.set_size(20) axes.axes.xaxis.label.set_size(20)
axes.axes.yaxis.label.set_size(20) axes.axes.yaxis.label.set_size(20)
axes.tick_params(labelsize=15) axes.tick_params(labelsize=15)
ymax = axes.get_ylim()[1] ymax = axes.get_ylim()[1]
for i in range(0, len(counts_svtyper)): for i in range(0, len(counts_svtyper)):
t1=axes.text(-0.1 + (i * 1), ymax+ymax/100, counts_svtyper[i], fontsize=15) t1=axes.text(-0.1 + (i * 1), ymax+ymax/100, counts_svtyper[i], fontsize=15)
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
#### Inluence of deletion size : small means less than 200bp #### Inluence of deletion size : small means less than 200bp
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
plt.figure(1, figsize=(8,5)) plt.figure(1, figsize=(8,5))
sns.stripplot(x="software", y="recall", jitter=True, hue="deltype", palette="Set2", dodge=True, linewidth=1, edgecolor='gray', order=gt_tools,data=df_svtyper) sns.stripplot(x="software", y="recall", jitter=True, hue="deltype", palette="Set2", dodge=True, linewidth=1, edgecolor='gray', order=gt_tools,data=df_svtyper)
axes = sns.boxplot(x="software", y="recall",hue="deltype",palette="Set2", order=gt_tools,data=df_svtyper) axes = sns.boxplot(x="software", y="recall",hue="deltype",palette="Set2", order=gt_tools,data=df_svtyper)
axes.title.set_position([.5, 1.2]) axes.title.set_position([.5, 1.2])
axes.set_ylim(0.15, 1.05) axes.set_ylim(0.15, 1.05)
axes.set_title(str(total)+" simulated variants",size=25) axes.set_title(str(total)+" simulated variants",size=25)
axes.axes.xaxis.label.set_size(20) axes.axes.xaxis.label.set_size(20)
axes.axes.yaxis.label.set_size(20) axes.axes.yaxis.label.set_size(20)
axes.tick_params(labelsize=15) axes.tick_params(labelsize=15)
ymax = axes.get_ylim()[1] ymax = axes.get_ylim()[1]
for i in range(0, len(counts_svtyper)): for i in range(0, len(counts_svtyper)):
t1=axes.text(-0.1 + (i * 1), ymax+ymax/100, counts_svtyper[i], fontsize=15) t1=axes.text(-0.1 + (i * 1), ymax+ymax/100, counts_svtyper[i], fontsize=15)
plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.) plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
##### Size VS Precision ##### Size VS Precision
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
plt.figure(1, figsize=(8,5)) plt.figure(1, figsize=(8,5))
sns.stripplot(x="precision", y="delsize", jitter=True, palette="Set2", dodge=True, linewidth=1, edgecolor='gray', order=['precise', 'unprecise'],data=df_svtyper) sns.stripplot(x="precision", y="delsize", jitter=True, palette="Set2", dodge=True, linewidth=1, edgecolor='gray', order=['precise', 'unprecise'],data=df_svtyper)
axes = sns.boxplot(x="precision", y="delsize",palette="Set2", order=['precise', 'unprecise'],data=df_svtyper) axes = sns.boxplot(x="precision", y="delsize",palette="Set2", order=['precise', 'unprecise'],data=df_svtyper)
axes.axes.xaxis.label.set_size(20) axes.axes.xaxis.label.set_size(20)
axes.axes.yaxis.label.set_size(20) axes.axes.yaxis.label.set_size(20)
axes.tick_params(labelsize=15) axes.tick_params(labelsize=15)
``` ```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment