#! /usr/bin/env python
"""Merge SV calls read from several VCF files into CNV regions (CNVR).

Calls that reciprocally overlap and whose breakpoints are close enough are
linked in a graph; every connected component of that graph becomes one CNVR,
which is written to standard output as a VCF record.
"""

import argparse
import sys
import os.path
from collections import deque
from itertools import combinations
from math import fabs

import numpy
import pysam
import vcf

from svreader.vcf_utils import get_template
from svreader import SVInter, SVRecord, SVReader


def eprint(*args, **kwargs):
    """Print to stderr; accepts the same arguments as ``print``."""
    print(*args, file=sys.stderr, **kwargs)


class CNVR(SVInter):
    """Copy Number Variation Region: a group of SV calls merged together.

    The region boundaries are the medians of the member start and end
    coordinates, unless one member is selected as a "representative" call
    (see ``_select_repr_cnv``), in which case its coordinates are used.
    """

    def __init__(self, svs):
        """Build the region from *svs*, a non-empty collection of SV calls.

        Each call must expose ``chrom``, ``start``, ``end``, ``id`` and
        ``record`` (these are the attributes read below).

        Raises:
            ValueError: if *svs* is empty.
        """
        svs = list(svs)
        if not svs:
            raise ValueError("CNVR requires at least one SV")
        start = int(numpy.median(numpy.array([sv.start for sv in svs])))
        end = int(numpy.median(numpy.array([sv.end for sv in svs])))
        super(CNVR, self).__init__(svs[0].chrom, start, end, ".")
        self.__svs = {sv.id: sv for sv in svs}
        # The selector has its own name so that the attribute below does not
        # shadow the method that computes it.
        self._repr_cnv = self._select_repr_cnv()
        self._bp_precision()
        self._CIntervals()
        if self._repr_cnv:
            eprint(self._repr_cnv.id, "Precise")
            # A representative call overrides the median coordinates.
            self.start = self._repr_cnv.record.pos
            self.end = self._repr_cnv.record.info['END']

    @property
    def svs(self):
        """Mapping of SV id -> SV call for the members of this region."""
        return self.__svs

    def __str__(self):
        return "{}:{}-{}\t{}\t{}".format(
            self.chrom, self.start, self.end, self.id, self.length())

    def overlaps(self):
        """Return, for each member (sorted by id), its length / region length."""
        region_length = float(self.length())
        return [self.__svs[key].length() / region_length
                for key in sorted(self.__svs)]

    def intervals(self):
        """Return the member ids, sorted."""
        return [self.__svs[key].id for key in sorted(self.__svs)]

    def CopyNumbers(self):
        """Return the members' ``CN()`` values as a comma-separated string."""
        # str() makes the join safe whatever type CN() returns.
        return ",".join(str(sv.CN()) for sv in self.__svs.values())

    def NumCNV(self):
        """Return the number of member calls."""
        return len(self.__svs)

    def Callers(self):
        """Return the distinct caller names (id prefix before the first '_')."""
        return list({key.split('_')[0] for key in self.__svs})

    def CallersVsamples(self):
        """Return ``caller:sample-sample-...`` strings for callers with VSAMPLES."""
        vsamples = []
        for caller in self.Callers():
            cnv = self.gettoolCNV(caller)
            if 'VSAMPLES' in cnv.record.info:
                vsamples.append(
                    caller + ":" + '-'.join(cnv.record.info['VSAMPLES']))
        return vsamples

    def NumCallers(self):
        """Return the number of distinct callers."""
        return len(self.Callers())

    def CNV(self):
        """Return the list of member ids."""
        return list(self.__svs.keys())

    def cipos(self):
        """Return the start confidence interval as a comma-separated string."""
        return ",".join(str(x) for x in self._cipos)

    def ciend(self):
        """Return the end confidence interval as a comma-separated string."""
        return ",".join(str(x) for x in self._ciend)

    def bedformat(self):
        """Return a tab-separated BED-like line describing the region."""
        fields = [
            self.chrom,
            self.start,
            self.end,
            self.name,
            len(self.__svs),
            ".",
            ",".join(self.intervals()),
            ",".join("%0.2f" % x for x in self.overlaps()),
        ]
        return "\t".join(map(str, fields))

    def precision(self):
        """Return ``[start_range, end_range]`` computed by ``_bp_precision``."""
        return self._precision

    def IsPrecise(self):
        """Return True when a representative call was selected."""
        return bool(self._repr_cnv)

    def repr_cnv(self):
        """Return the record id of the representative call, or None."""
        if not self._repr_cnv:
            return None
        return self._repr_cnv.record.id

    def _bp_precision(self):
        """Store the spread (max - min) of member starts and ends."""
        starts = [sv.start for sv in self.__svs.values()]
        ends = [sv.end for sv in self.__svs.values()]
        self._precision = [max(starts) - min(starts), max(ends) - min(ends)]

    def _select_repr_cnv(self):
        """Pick a member call to act as representative, or return None.

        Tools are considered in the order pindel, delly, lumpy, genomestrip;
        only the FIRST tool present among the callers is examined.

        NOTE(review): because this is an if/elif chain, a pindel call with
        SU <= 10 prevents a precise delly/lumpy/genomestrip call from being
        considered. Confirm whether falling through was intended.
        """
        callers = self.Callers()
        selected = None
        if "pindel" in callers:
            cnv = self.gettoolCNV("pindel")
            if cnv.record.info["SU"][0] > 10:
                selected = cnv
        elif "delly" in callers:
            cnv = self.gettoolCNV("delly")
            if "PRECISE" in cnv.record.info and cnv.record.info["SR"] > 10:
                selected = cnv
        elif "lumpy" in callers:
            cnv = self.gettoolCNV("lumpy")
            if ("IMPRECISE" not in cnv.record.info
                    and cnv.record.info["SR"][0] > 10):
                selected = cnv
        elif "genomestrip" in callers:
            cnv = self.gettoolCNV("genomestrip")
            if "IMPRECISE" not in cnv.record.info:
                selected = cnv
        return selected

    def _CIntervals(self):
        """Set ``_cipos`` / ``_ciend`` from the most trusted available call.

        The representative call wins when there is one; otherwise the first
        tool found in the order genomestrip, lumpy, delly, pindel is used,
        and (-50, 50) is the default when none of them is present.
        """
        if self._repr_cnv:
            self._cipos = self._repr_cnv.record.info['CIPOS']
            self._ciend = self._repr_cnv.record.info['CIEND']
            return
        # Default is a large confidence interval.
        self._cipos = (-50, 50)
        self._ciend = (-50, 50)
        for tool in ("genomestrip", "lumpy", "delly", "pindel"):
            cnv = self.gettoolCNV(tool)
            if cnv:
                # Keep the values of the first tool that has a call.
                self._cipos = cnv.record.info['CIPOS']
                self._ciend = cnv.record.info['CIEND']
                break

    def gettoolCNV(self, toolname):
        """Return the first member whose id contains *toolname*, or None."""
        for key in self.CNV():
            if toolname in key:
                return self.__svs[key]
        return None


def to_vcf_record(cnv):
    """Convert a CNVR into a PyVCF record of type DEL.

    PRECISION is written as "start_range,end_range" taken from
    ``cnv.precision()``.
    """
    info = {
        "SVLEN": cnv.length(),
        "SVTYPE": "DEL",
        "END": cnv.end,
        "NB_CNV": cnv.NumCNV(),
        "NAMES": cnv.CNV(),
        # Values are ints, so convert before joining; there is no `self`
        # in this module-level function, the data lives on `cnv`.
        "PRECISION": ",".join(str(p) for p in cnv.precision()),
    }
    alt = [vcf.model._SV("DEL")]
    # NOTE(review): positional order assumed to be CHROM, POS, ID, REF, ALT,
    # QUAL, FILTER, INFO, FORMAT, sample_indexes, samples - confirm against
    # the installed PyVCF version.
    return vcf.model._Record(cnv.chrom, cnv.start, cnv.name, "N", alt,
                             ".", ".", info, "", [0], [])


def trim_column_header(vcf_reader):
    """Drop the last entry of the reader's column headers (in place)."""
    vcf_reader._column_headers.pop()


def passed_variant(record):
    """Return True when the record has no filter set (None or empty)."""
    return record.filter is None or len(record.filter) == 0


def intersect(a, b):
    """Return the length of the overlap between intervals *a* and *b*.

    Returns 0 when the intervals are on different chromosomes or do not
    overlap (intervals that merely touch also give 0).
    """
    if a.chrom != b.chrom:
        return 0
    return max(0, min(a.end, b.end) - max(a.start, b.start))


def roverlap(a, b, cutoff):
    """Return True when *a* and *b* reciprocally overlap by >= *cutoff*.

    The overlap is expressed as a fraction of each interval's length and
    both fractions must reach the cutoff.
    """
    inter = intersect(a, b)
    if inter:
        frac_a = float(inter) / a.length()
        frac_b = float(inter) / b.length()
    else:
        frac_a = frac_b = 0
    return frac_a >= cutoff and frac_b >= cutoff


def breakpoint_left_precision(a, b):
    """Return the absolute distance between the two start coordinates."""
    return fabs(a.start - b.start)


def breakpoint_right_precision(a, b):
    """Return the absolute distance between the two end coordinates."""
    return fabs(a.end - b.end)


def construct_overlap_graph(data, cutoff, left_precision, right_precision):
    """Link every pair of SVs that overlap and have close breakpoints.

    A link is added (via ``add_link``) when the pair reciprocally overlaps
    by at least *cutoff* and the start / end coordinates differ by at most
    *left_precision* / *right_precision* respectively.

    NOTE(review): only ``first.add_link(second)`` is called; this presumes
    ``add_link`` records the link on both nodes - confirm in SVInter.
    """
    for first, second in combinations(data, 2):
        if (roverlap(first, second, cutoff)
                and breakpoint_left_precision(first, second) <= left_precision
                and breakpoint_right_precision(first, second) <= right_precision):
            first.add_link(second)


def connected_components(nodes):
    """Return the connected components of the link graph.

    Input: an iterable of hashable nodes exposing a ``links`` collection.
    Output: a list of sets, one per connected component, in arbitrary order.
    """
    result = []
    # Work on a copy so the caller's collection is left untouched.
    remaining = set(nodes)
    while remaining:
        seed = remaining.pop()
        group = {seed}
        queue = deque([seed])
        # Breadth-first walk; an empty queue means the component is complete.
        while queue:
            current = queue.popleft()
            # Build a new set instead of calling difference_update on
            # current.links, so the node's own links are never modified
            # even if `links` hands back the underlying set.
            neighbors = set(current.links) - group
            remaining.difference_update(neighbors)
            group.update(neighbors)
            queue.extend(neighbors)
        result.append(group)
    # Return the list of groups.
    return result


def get_args():
    """Parse the command line and return the argparse namespace.

    NOTE(review): this parser was missing from the file although main()
    calls it. Attribute names match what main() reads; the flag spellings
    and default values below are assumptions to be confirmed.
    """
    parser = argparse.ArgumentParser(
        description="Merge SV calls from several VCF files into CNV regions")
    parser.add_argument("-i", "--vcf", required=True,
                        help="comma-separated list of input VCF files")
    parser.add_argument("-p", "--prefix", required=True,
                        help="prefix used to name the merged records")
    parser.add_argument("--no_index", action="store_true",
                        help="currently unused by main()")
    parser.add_argument("-c", "--overlap_cutoff", type=float, default=0.5,
                        help="minimum reciprocal overlap (default: 0.5)")
    parser.add_argument("-l", "--left_precision", type=int,
                        default=sys.maxsize,
                        help="maximum distance between start coordinates "
                             "(default: unlimited)")
    parser.add_argument("-r", "--right_precision", type=int,
                        default=sys.maxsize,
                        help="maximum distance between end coordinates "
                             "(default: unlimited)")
    return parser.parse_args()


# --------------------------------------
# main function

def main():
    """Read the input VCFs, merge the calls and write a VCF to stdout."""
    args = get_args()
    filenames = args.vcf.split(",")

    # Read every VCF file, keeping only the records without a filter.
    sv_set = []
    for infile in filenames:
        eprint(" Reading file %s" % (infile))
        sv_set.extend(
            record for record in SVReader(infile) if passed_variant(record))

    # Two calls are linked when their reciprocal overlap is >= overlap_cutoff
    # and their start (end) coordinates are within left (right) precision.
    eprint("Constructing conected component")
    construct_overlap_graph(sv_set, args.overlap_cutoff,
                            args.left_precision, args.right_precision)

    cnvr = []
    for number, components in enumerate(connected_components(sv_set), 1):
        names = ", ".join(sorted(node.name for node in components))
        eprint("Group #%i: %s" % (number, names))
        cnvr.append(CNVR(components))

    # Write the merged regions as a single VCF on stdout.
    vcf_template_reader = get_template("merge")
    vcf_writer = vcf.Writer(sys.stdout, vcf_template_reader)
    prefix_name = args.prefix.split(".")[0]
    # Sort on (chrom, start) so records of one chromosome stay contiguous;
    # for single-chromosome input this is the same order as sorting on start.
    ordered = sorted(cnvr, key=lambda region: (region.chrom, region.start))
    for number, cnv in enumerate(ordered, 1):
        cnv.name = prefix_name + "_" + str(number)
        vcf_writer.write_record(to_vcf_record(cnv))
    vcf_writer.close()


# initialize the script
if __name__ == '__main__':
    sys.exit(main())