From aa96035d522c9eef78dce559924f87212b51bead Mon Sep 17 00:00:00 2001
From: Thomas Faraut <Thomas.Faraut@inra.fr>
Date: Tue, 10 Dec 2019 09:17:32 +0100
Subject: [PATCH] annotating tools and some corrections

---
 svreader/annotation.py | 160 ++++++++++++++++++++++++++++++++++++++---
 1 file changed, 151 insertions(+), 9 deletions(-)

diff --git a/svreader/annotation.py b/svreader/annotation.py
index 3b3fac6..11c2b98 100644
--- a/svreader/annotation.py
+++ b/svreader/annotation.py
@@ -2,6 +2,8 @@
 import sys
 
 from collections import defaultdict
+from networkx import Graph, connected_components
+
 import numpy as np
 from math import isnan
 
@@ -40,6 +42,13 @@ def Variant(sample):
         return False
 
 
+def Heterozygote(sample):
+    if sample.get('GT') in [HET_VAR]:
+        return True
+    else:
+        return False
+
+
 class AnnotateRecord(VCFRecord):
     """
     A lightweight object to annotated the final records
@@ -78,8 +87,8 @@ class AnnotateRecord(VCFRecord):
     def num_samples(self):
         return len(self.record.samples.keys())
 
-    def setNewId(self, identifier):
-        self.new_id = identifier
+    def setNewId(self, new_id):
+        self.new_id = new_id
 
     def rename(self):
         try:
@@ -93,16 +102,31 @@ class AnnotateRecord(VCFRecord):
         return sum([Variant(s) for s in self.samples.values()])
 
     def variant_read_support(self):
-        return max([s.get('AO')[0] for s in self.samples.values()])
+        support = []
+        for s in self.samples.values():
+            if s.get('AO') is not None:
+                support.append(s.get('AO')[0])
+        return max(support)
 
     def qual(self):
-        return sum([s.get('SQ') for s in self.samples.values()])
+        variant_qual = []
+        for s in self.samples.values():
+            if s.get('SQ') is not None:
+                variant_qual.append(s.get('SQ'))
+        return sum(variant_qual)
+
+    def GQ_samples(self):
+        genotype_qual = []
+        for s in self.samples.values():
+            if s.get('GQ') is not None:
+                genotype_qual.append(s.get('GQ'))
+        return genotype_qual
 
     def GQ_sum_score(self):
-        return sum([s.get('GQ') for s in self.samples.values()])
+        return sum(self.GQ_samples())
 
     def maxGQ(self):
-        return max([s.get('GQ') for s in self.samples.values()])
+        return max(self.GQ_samples())
 
     def setQual(self):
         self.record.qual = self.qual()
@@ -128,12 +152,22 @@ class AnnotateRecord(VCFRecord):
             sys.exit(1)
 
     def CallRate(self, cutoff):
-        num_qual_call = sum([(s.get('GQ') > cutoff) for s in self.samples.values()])
+        call_qual = []
+        for s in self.samples.values():
+            if s.get('GQ') is not None:
+                call_qual.append(s.get('GQ'))
+        num_qual_call = sum([(qual > cutoff) for qual in call_qual])
         return num_qual_call/self.num_samples
 
     def VariantCallRate(self, cutoff):
-        num_qual_call = sum([(s.get('GQ') > cutoff) for s in self.samples.values() if Variant(s)])
-        return num_qual_call/self.num_variant_samples()
+        samples = self.samples.values()
+        num_qual_var = 0
+        for s in samples:
+            if s.get("GQ") is not None and s.get("GQ") > cutoff and Variant(s):
+                num_qual_var += 1
+        num_var_samples = self.num_variant_samples()
+        var_call_rate = num_qual_var/num_var_samples if num_var_samples else 0
+        return var_call_rate
 
     def UnifiedPass(self):
         """
@@ -379,6 +413,9 @@ def add_redundancy_infos_header(reader):
     # Adding NONDUPLICATEOVERLAP
     reader.addInfo("NONDUPLICATEOVERLAP", 1, "Float",
                    "Amount of overlap with a non-duplicate")
+    # Adding TOOLSUPPORT
+    reader.addInfo("TOOLSUPPORT", ".", "String",
+                   "Tools supporting (detecting) the sv")
 
 
 def GenomeSTRIPLikeRedundancyAnnotator(SVSet, reader,
@@ -483,6 +520,7 @@ def add_filter_infos_header(reader):
     reader.addFilter("MONOMORPH", "All samples have the same genotype")
     reader.addFilter("DUPLICATE", "GSDUPLICATESCORE>0")
     reader.addFilter("OVERLAP", "NONDUPLICATEOVERLAP>0.7")
+    reader.addFilter("ABFREQ", "AB frequency <0.3 for >50% heterosamples")
 
 
 def GenomeSTRIPLikefiltering(SVSet, reader):
@@ -511,3 +549,107 @@ def GenomeSTRIPLikefiltering(SVSet, reader):
             sv.filter.add("OVERLAP")
         if "DUPLICATESCORE" in info is not None and info['DUPLICATESCORE'] > -2:
             sv.filter.add("DUPLICATE")
+
+
+def ABFreqFiltering(SVSet):
+    """ Filtering the candidate CNVs according to the following criteria
+          - more than 50% of variant samples should have AB freq > 0.3
+    """
+
+    for sv in SVSet:
+        ABfreqOK = []
+        for s in sv.record.samples.values():
+            if Heterozygote(s):
+                ABfreqOK.append((s.get('AB')[0] > 0.3))
+        if len(ABfreqOK) > 0 and sum(ABfreqOK) < len(ABfreqOK)/2:
+            sv.filter.add("ABFREQ")
+
+
+def GetConnectedDuplicates(SVSet):
+    """
+    Construct connected components of duplicates and rename the variants
+    """
+    undirected = Graph()
+    variant_dict = defaultdict()
+    representatives = defaultdict()
+    for s in SVSet:
+        variant_dict[s.id] = s
+        if "DUPLICATE" in s.filter:
+            for dupli_repr in s.record.info["DUPLICATEOF"]:
+                undirected.add_edge(s.id, dupli_repr)
+    for component in connected_components(undirected):
+        for c in component:
+            if "DUPLICATEOF" in variant_dict[c].record.info:
+                # the current variant is a duplicate
+                continue
+            rep = c  # the representative of the equivalence class
+            break
+        duplicates = component
+        duplicates.remove(rep)
+        representatives[rep] = duplicates
+    add_duplicate_infos(representatives, variant_dict)
+
+
+def add_duplicate_infos(representatives, sv_dict):
+    for rep, elements in representatives.items():
+        for d in elements:
+            sv_dict[d].record.info['DUPLICATEOF'] = rep
+        duplicates = list(elements)
+        if 'DUPLICATES' in sv_dict[rep].record.info:
+            print(sv_dict[rep].record.info['DUPLICATES'])
+            duplicates.extend(sv_dict[rep].record.info['DUPLICATES'])
+        if len(duplicates) > 0:
+            sv_dict[rep].record.info['DUPLICATES'] = duplicates
+
+
+def get_tool_name(sv_ident):
+    return sv_ident.split("_")[0]
+
+
+def SetSupportingTools(SVSet):
+    for sv in SVSet:
+        tools = {get_tool_name(sv.id)}
+        if "DUPLICATES" in sv.record.info:
+            duplicates = sv.record.info['DUPLICATES']
+            # print(duplicates)
+            for dupli in duplicates:
+                tools.add(get_tool_name(dupli))
+        if 'TOOLSUPPORT' in sv.record.info:
+            supporting = set(sv.record.info['TOOLSUPPORT'])
+            tools.union(supporting)
+        sv.record.info['TOOLSUPPORT'] = list(tools)
+
+
+def static_vars(**kwargs):
+    def decorate(func):
+        for k in kwargs:
+            setattr(func, k, kwargs[k])
+        return func
+    return decorate
+
+
+@static_vars(counter=0)
+def new_id_str(sv):
+    new_id_str.counter += 1
+    return "_".join(["cnvpipeline", sv.chrom, sv.svtype,
+                     str(new_id_str.counter)])
+
+
+def rename_info_field(sv, key, sv_dict):
+    if key in sv.record.info:
+        info_oldid = sv.record.info[key]
+        info_newid = [sv_dict[id] for id in info_oldid]
+        sv.record.info[key] = info_newid
+
+
+def RenameSV(SVSet):
+    sv_dict = defaultdict()
+    for sv in SVSet:
+        new_id = new_id_str(sv)
+        sv.setNewId(new_id)
+        sv_dict[sv.id] = new_id
+    for sv in SVSet:
+        rename_info_field(sv, "DUPLICATEOF", sv_dict)
+        rename_info_field(sv, "DUPLICATES", sv_dict)
+        sv.record.info['SOURCEID'] = sv.id
+        sv.rename()
-- 
GitLab