Skip to content
Snippets Groups Projects
svinterval.py 11.6 KiB
Newer Older
#! /usr/bin/env python


import argparse
import os.path
import sys
from itertools import combinations
from math import fabs

import numpy
import pysam
import vcf

from svreader import SVInter, SVRecord, SVReader
from svreader.vcf_utils import get_template

# A small wrapper to print to stderr    
def eprint(*args, **kwargs):
    '''
        Print to standard error.

        Positional and keyword arguments are forwarded unchanged to print();
        the output stream is always sys.stderr, looked up at call time.
    '''
    stream = sys.stderr
    print(*args, file=stream, **kwargs)

class CNVR(SVInter):
    '''
        CNVR object :
          Copy Number Variation Region

        Merges a group of SV calls into a single region.  The region spans the
        median start and median end of its members, unless one member is
        selected as a precise representative, in which case the position and
        END of that member's record are used instead.

        Member calls are stored keyed by their id, in id order, so that all
        per-member accessors report the members in the same reproducible
        order (relies on insertion-ordered dicts, i.e. Python 3.7+).
    '''

    def __init__(self, svs):
        '''
            svs : non-empty iterable of SV calls to merge.  Each call must
                  expose chrom, start, end, id and record.
        '''
        # Sort once by id: callers in this file pass a set, whose iteration
        # order is arbitrary, and several accessors depend on member order.
        svs = sorted(svs, key=lambda sv: sv.id)
        svinter = next(iter(svs))
        start = int(numpy.median(numpy.array([x.start for x in svs])))
        end = int(numpy.median(numpy.array([x.end for x in svs])))
        super(CNVR, self).__init__(svinter.chrom, start, end, ".")
        self.__svs = {x.id: x for x in svs}
        # The selected representative call, or None when no member qualifies.
        self._repr_cnv = self._select_repr_cnv()
        self._bp_precision()
        self._CIntervals()
        if self._repr_cnv:
            eprint(self._repr_cnv.id, "Precise")
            # A precise representative overrides the median coordinates.
            self.start = self._repr_cnv.record.pos
            self.end = self._repr_cnv.record.info['END']

    @property
    def svs(self):
        '''Mapping from call id to member call.'''
        return self.__svs

    def __str__(self):
        return self.chrom+":"+str(self.start)+"-"+str(self.end)+"\t"+self.id+"\t"+str(self.length())

    def overlaps(self):
        '''Length of each member divided by the region length, in id order.'''
        return [self.svs[s].length()/float(self.length()) for s in sorted(self.svs)]

    def intervals(self):
        '''Ids of the member calls, in id order.'''
        return [self.svs[s].id for s in sorted(self.svs)]

    def CopyNumbers(self):
        '''
            Comma-joined CN() values of the member calls.
            The values are joined directly, so CN() must return strings.
        '''
        return ",".join([sv.CN() for sv in self.__svs.values()])

    def NumCNV(self):
        '''Number of member calls.'''
        return len(self.__svs)

    def Callers(self):
        '''
            Sorted list of the distinct callers of the member calls.
            The caller is the part of the call id before the first "_".
        '''
        return sorted(set([c.split('_')[0] for c in self.__svs]))

    def CallersVsamples(self):
        '''
            One "caller:s1-s2-..." string for each caller whose call carries
            a VSAMPLES entry in its record info.
        '''
        vsamples = []
        for caller in self.Callers():
            cnv = self.gettoolCNV(caller)
            if 'VSAMPLES' in cnv.record.info:
                vsamples.append(caller+":"+'-'.join(cnv.record.info['VSAMPLES']))
        return vsamples

    def NumCallers(self):
        '''Number of distinct callers among the member calls.'''
        return len(self.Callers())

    def CNV(self):
        '''Ids of the member calls, as a list.'''
        return list(self.__svs.keys())

    def cipos(self):
        '''Confidence interval around the start, comma-joined.'''
        return ",".join([str(x) for x in self._cipos])

    def ciend(self):
        '''Confidence interval around the end, comma-joined.'''
        return ",".join([str(x) for x in self._ciend])

    def bedformat(self):
        '''
            Tab-separated line: chrom, start, end, name, number of members,
            ".", member ids and member/region length ratios (two decimals).
        '''
        return "\t".join(map(str,[self.chrom,self.start,self.end,self.name,len(self.__svs),".",",".join(self.intervals()),",".join(["%0.2f" % x for x in self.overlaps()]),]))

    def precision(self):
        '''[start_range, end_range] as computed by _bp_precision().'''
        return self._precision

    def IsPrecise(self):
        '''True when a representative call was selected.'''
        return bool(self._repr_cnv)

    def repr_cnv(self):
        '''Record id of the representative call, or None when there is none.'''
        if not self._repr_cnv:
            return None
        return self._repr_cnv.record.id

    def _bp_precision(self):
        # Trying to infer the precision breakpoints according to the the different CNVs merged:
        # the spread (max - min) of the member starts and of the member ends.
        svs = self.__svs
        starts = [svs[x].start for x in svs]
        ends = [svs[x].end for x in svs]
        start_range = max(starts) - min(starts)
        end_range = max(ends) - min(ends)
        self._precision = [start_range, end_range]

    def _select_repr_cnv(self):
        '''
           Trying to identify a cnv among the cnv of this CNVR that is a representative
             either Pindel (more than 10 supp reads), or PRECISE, delly, lumpy or genomeSTRIP

           Returns the selected call, or None.
        '''
        # NOTE(review): because this is an if/elif chain, only the first
        # caller present (pindel, then delly, lumpy, genomestrip) is examined;
        # when its call fails the test no other caller is tried.  Confirm
        # that this is intended.
        callers = self.Callers()
        selected = None
        if "pindel" in callers:
            cnv = self.gettoolCNV("pindel")
            if cnv.record.info["SU"][0] > 10:
                selected = cnv
        elif "delly" in callers:
            cnv = self.gettoolCNV("delly")
            if "PRECISE" in cnv.record.info and cnv.record.info["SR"] > 10:
                selected = cnv
        elif "lumpy" in callers:
            cnv = self.gettoolCNV("lumpy")
            if "IMPRECISE" not in cnv.record.info and cnv.record.info["SR"][0] > 10:
                selected = cnv
        elif "genomestrip" in callers:
            cnv = self.gettoolCNV("genomestrip")
            if "IMPRECISE" not in cnv.record.info:
                selected = cnv

        return selected

    def _CIntervals(self):
        # Trying to infer CIPOS and CIEND from the called CNVs
        # we search for the CNV CIPOS and CIEND from the most confident
        # to the least confident SV detection tool

        if self._repr_cnv:
            self._cipos = self._repr_cnv.record.info['CIPOS']
            self._ciend = self._repr_cnv.record.info['CIEND']
            return

        # default is large confidence interval
        self._cipos = (-50,50)
        self._ciend = (-50,50)
        tools = ["genomestrip","lumpy","delly","pindel"]
        for tool in tools:
            cnv = self.gettoolCNV(tool)
            if cnv:
                self._cipos = cnv.record.info['CIPOS']
                self._ciend = cnv.record.info['CIEND']
                break # if a cnv from that tool was found we keep the values from this call

    def gettoolCNV(self, toolname):
        '''
            Find CNV made by a specific tool: the first member (in id order)
            whose id contains toolname, or None when there is none.
        '''
        for name, sv in self.__svs.items():
            if toolname in name:
                return sv
        return None
    
def to_vcf_record(cnv):
    '''
        Build the VCF record describing a merged region.

        cnv : object exposing chrom, start, end, name, length(), NumCNV(),
              CNV() and precision() (a CNVR in this file).

        Returns a vcf.model._Record.  The record is always reported as a
        deletion (SVTYPE and ALT are hard-coded to DEL) with reference
        allele "N".
    '''
    # precision() holds numbers, so convert them before joining.
    precision = ",".join(str(p) for p in cnv.precision())
    info = {
        "SVLEN": cnv.length(),
        "SVTYPE": "DEL",
        "END": cnv.end,
        "NB_CNV": cnv.NumCNV(),
        "NAMES": cnv.CNV(),
        "PRECISION": precision,
    }
    alt = [vcf.model._SV("DEL")]

    # Positional order presumably follows PyVCF's _Record signature:
    # CHROM, POS, ID, REF, ALT, QUAL, FILTER, INFO, FORMAT, sample_indexes,
    # samples -- TODO confirm against the installed PyVCF version.
    vcf_record = vcf.model._Record(cnv.chrom,
                                   cnv.start,
                                   cnv.name,
                                   "N",
                                   alt,
                                   ".",
                                   ".",
                                   info,
                                   "",
                                   [0],
                                   [])
    return vcf_record

def trim_column_header(vcf_reader):
    '''
        Remove the last entry of the reader's _column_headers, in place.
    '''
    headers = vcf_reader._column_headers
    headers.pop()
    
def passed_variant(record):
    """Did this variant pass?

    A record passes when its filter attribute is None or empty.
    """
    filters = record.filter
    if filters is None:
        return True
    return len(filters) == 0

def intersect(a,b):
    '''
        Returns the length shared by two intervals (objects exposing chrom,
        start and end), or 0 when they lie on different chromosomes or do
        not overlap.
    '''
    if a.chrom != b.chrom:
        return 0
    # Order the pair so that `first` starts strictly before `second`
    # (on equal starts, b is taken as first).
    if a.start < b.start:
        first, second = a, b
    else:
        first, second = b, a
    if first.end < second.start:     # first ends before second begins
        return 0
    if first.end > second.end:       # first includes second
        return second.end - second.start
    # second.start is within first.start and first.end
    return first.end - second.start

def roverlap(a,b,cutoff):
    '''
        Returns true if the two intervals overlap reciprocally, i.e. the
        shared length is at least `cutoff` times the length of a AND at
        least `cutoff` times the length of b.

        When the intervals share nothing both fractions are 0, so the result
        is True only for a cutoff <= 0.
    '''
    shared = intersect(a, b)
    if shared:
        frac_a = float(shared) / a.length()
        frac_b = float(shared) / b.length()
    else:
        frac_a = frac_b = 0
    return bool(frac_a >= cutoff and frac_b >= cutoff)

def breakpoint_left_precision(a,b):
    '''
        Returns the precision of the left breakpoint: the absolute
        difference between the two starts, as a float.
    '''
    return fabs(a.start - b.start)
    
def breakpoint_right_precision(a,b):
    '''
        Returns the precision of the right breakpoint: the absolute
        difference between the two ends, as a float.
    '''
    return fabs(a.end - b.end)

def construct_overlap_graph(data,cutoff,left_precision,right_precision):
    '''
        Graph construction : add a link whenever two SV overlap

        data            : iterable of SV objects exposing add_link()
        cutoff          : minimal reciprocal overlap (see roverlap)
        left_precision  : maximal distance allowed between the two starts
        right_precision : maximal distance allowed between the two ends

        Every unordered pair is examined once, in input order.  A pair that
        satisfies the three criteria is linked with first.add_link(second);
        whether the link is also recorded on `second` depends on add_link.
        Returns None.
    '''
    for first, second in combinations(data, 2):
        if (roverlap(first, second, cutoff)
                and breakpoint_left_precision(first, second) <= left_precision
                and breakpoint_right_precision(first, second) <= right_precision):
            first.add_link(second)
            
def connected_components(nodes):
    '''
        Constructing connected component of the graph
        input  : an iterable of hashable nodes; each node exposes `links`,
                 the collection of nodes it is linked to
        output : a list of groups (sets of nodes), one group for each
                 connected component; the order of the list is arbitrary

        Links are followed in the direction in which they are stored, so they
        must be symmetric for the groups to be independent of the node the
        traversal happens to start from.  The nodes' `links` are only read,
        never modified.
    '''
    # List of connected components found.
    result = []
    # Make a copy of the input, so we can consume it.
    remaining = set(nodes)
    while remaining:
        seed = remaining.pop()   # Any node not yet assigned to a group.
        group = {seed}           # Nodes connected to the seed found so far.
        pending = [seed]         # Nodes whose neighbors are still to be visited.
        # When nothing is pending, the whole component has been visited.
        # The visiting order does not matter (a group is a set), so the list
        # is used as a stack: popping from the end is O(1).
        while pending:
            current = pending.pop()
            # Work on a copy: `links` may be the node's own set.
            fresh = set(current.links) - group
            remaining.difference_update(fresh)   # They now belong to this group.
            group.update(fresh)
            pending.extend(fresh)
        result.append(group)
    # Return the list of groups.
    return result


    
# --------------------------------------
# main function

def main():
    '''
        Merge the SV calls of several VCF files into regions and write the
        regions as VCF records on standard output.

        Reads the comma-separated input file list, the output name prefix
        and the three linking thresholds from the command line (get_args).
        Progress messages go to standard error.  Returns None.
    '''
    # parse the command line args
    args = get_args()

    infiles = args.vcf
    prefix  = args.prefix

    overlap_cutoff   = args.overlap_cutoff
    left_precision   = args.left_precision
    right_precision  = args.right_precision

    # NOTE(review): the existence of the files is not checked here; a missing
    # file presumably surfaces as whatever SVReader raises -- confirm.
    filenames = infiles.split(",")

    # Reading all the vcf files, keeping only the records that passed the filters
    SVSet = []
    for infile in filenames:
        eprint(" Reading file %s" % (infile))
        for record in SVReader(infile):
            if passed_variant(record):
                SVSet.append(record)

    # Computing connected components according to reciprocal overlaps, left and right precision
    # reciprocal overlaps : only SV intervals with reciprocal overlap > = overlap_cutoff are linked
    # left, right precision : only SV intervals with left (right) within left_precision are linked
    eprint("Constructing conected component")
    construct_overlap_graph(SVSet, overlap_cutoff, left_precision, right_precision)

    # One region per connected component
    cnvr = []
    for number, components in enumerate(connected_components(SVSet), start=1):
        names = ", ".join(sorted(node.name for node in components))
        eprint("Group #%i: %s" % (number, names))
        cnvr.append(CNVR(components))

    # Writing the merge file in a single vcf file
    vcf_template_reader = get_template("merge")
    vcf_writer = vcf.Writer(sys.stdout, vcf_template_reader)

    # Regions are named <prefix>_<rank>, the rank following the start position;
    # only the part of the prefix before the first "." is used.
    prefix_name = prefix.split(".")[0]
    for number, cnv in enumerate(sorted(cnvr, key=lambda k: k.start), start=1):
        cnv.name = prefix_name + "_" + str(number)
        vcf_writer.write_record(to_vcf_record(cnv))
    vcf_writer.close()
    
# initialize the script
if __name__ == '__main__':
    # Exceptions raised by main() propagate unchanged; the return value of
    # main() is used as the exit status.
    sys.exit(main())