#!/usr/bin/env python3
# coding: utf-8
import os.path
from os import path
from collections import defaultdict
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.Alphabet import IUPAC
from Bio import SeqIO
import pysam
import sys

#### command line arguments (all positional, read as soon as the module is loaded)
inputMarkersFlankingGenes = sys.argv[1] # tab-delimited table of genes with their mapped flanking ISBP markers
outputDir = sys.argv[2] # output directory (main() creates it when it does not exist)
genomeQuery = sys.argv[3] # fasta file of the v1 genome, used as query (must be fai indexed)
genomeTarget = sys.argv[4] # fasta file of the v2 genome, used as target (must be fai indexed)
markerPosOnTarget = sys.argv[5] # bed file with the new coordinates of the markers on the target reference genome


#### generic data structure
# dict meant to summarise the blat alignments: target region, coverage, indels, mismatches, anchoring status
# NOTE(review): defaultdict() without a default factory behaves like a plain dict, and this
# variable is not referenced anywhere else in the code visible in this file.
blatResults = defaultdict()


def main():
	"""Try to anchor every gene of the input table on the target genome.

	For each line of ``inputMarkersFlankingGenes`` a working directory is
	created under ``outputDir/temp/<batch>/<line number>``, the coordinates
	of the two flanking markers are looked up on the target genome, and, when
	both anchors are on the same target chromosome, the gene sequence (query)
	and the target region between the anchors are written to FASTA files,
	aligned with BLAT and evaluated by ``checkPerfectHit``.

	A missing 5' (resp. 3') marker, flagged by '.' in the table, is replaced
	by a placeholder anchor at the start (resp. end) of the chromosome.
	Genes whose anchors are on different chromosomes, or whose anchors cannot
	be determined, are reported and skipped.

	Relies on the module-level command line variables; returns None.
	"""

	### check for input files and output directory
	checkFile(file=inputMarkersFlankingGenes)
	if checkDir(directory=outputDir):
		sys.stdout.write( " Output directory already exists, no need to create it\n")
	else:
		sys.stdout.write( " Creating output directory %s" % str(outputDir))
		os.mkdir( outputDir, 0o750 )

	### open for reading the query and target reference file
	queryFasta = pysam.FastaFile(genomeQuery)
	targetFasta = pysam.FastaFile(genomeTarget)

	### save coords of marker on target reference
	markersOnTarget_dict = getCoordsFromBed(bedfile=markerPosOnTarget)

	### read input file containing the flanking ISBP markers for each gene
	##### variables
	numlines=0
	index_result_dir=0
	prefixResultDir=outputDir+'/temp/'+str(index_result_dir).zfill(6)

	with open(inputMarkersFlankingGenes) as file:
		# iterate lazily over the table instead of loading every line in memory
		for line in file:
			# increment counters:
			# to avoid having a big result dir with +100k subdirs,
			# the script starts a new batch subdir every 1k genes
			numlines+=1
			if numlines % 1000 == 0:
				index_result_dir+=1
				prefixResultDir=outputDir+'/temp/'+str(index_result_dir).zfill(6)

			# define the working dir for the current gene analysis
			workingDir=prefixResultDir+'/'+str(numlines).zfill(6)
			# exist_ok: re-running on an existing output directory must not crash
			os.makedirs(workingDir, 0o750, exist_ok=True)

			# deal with data
			lineArray = line.rstrip("\n").split("\t")
			print("\n\n------------------------------------------")
			print(" Results of gene %s going into subdir %s" % (lineArray[0], str(workingDir)))
			print(" array of data: ")
			print(lineArray)

			# name the 17 columns of the input table
			geneDict={ 'geneId': lineArray[0],
						'geneChrom':lineArray[1],
						'geneStart':lineArray[2],
						'geneStop':lineArray[3],
						'geneScore':lineArray[4],
						'geneStrand':lineArray[5],
						'marker5pChrom':lineArray[6],
						'marker5pStart':lineArray[7],
						'marker5pStop':lineArray[8],
						'marker5pId':lineArray[9],
						'marker5pDistance':lineArray[10],
						'marker3pChrom':lineArray[11],
						'marker3pStart':lineArray[12],
						'marker3pStop':lineArray[13],
						'marker3pId':lineArray[14],
						'marker3pDistance':lineArray[15],
						'regionLengthOnQuery':lineArray[16]
					}
			print(" Gene directory:")
			print(geneDict)

			anchor5pTargetCoods_dict= {'chrom': 0, 'start':0, 'stop':0}
			anchor3pTargetCoods_dict= {}

			if geneDict['marker5pChrom'] == '.' or geneDict['marker3pChrom'] == '.':
				# deal with genes with one anchor missing (basically start and end of chromosomes)
				sys.stderr.write(" current gene %s has missing anchor on one side \n" % geneDict['geneId'])

				if geneDict['marker5pId'] == '.':
					print(' Missing 5prime anchor: using genomic sequence from the start of chromosome')
					geneDict['marker5pChrom'] = geneDict['marker3pChrom'].replace('chr','Chr')
					geneDict['marker5pStart'] = 0
					geneDict['marker5pStop'] = 0
					geneDict['marker5pId'] = geneDict['marker5pChrom']
					# NOTE(review): the placeholder starts at 1 while the fetch below is done
					# through pysam; confirm whether the first base is meant to be skipped
					anchor5pTargetCoods_dict.update(chrom= geneDict['marker5pChrom'], start= 1, stop=1)
					anchor3pTargetCoods_dict.update(markersOnTarget_dict[geneDict['marker3pId']])

				elif geneDict['marker3pId'] == '.':
					print(' Missing 3prime anchor: using genomic sequence to the end of chromosome')
					geneDict['marker3pChrom'] = geneDict['marker5pChrom'].replace('chr','Chr')
					chromLength = targetFasta.get_reference_length(geneDict['marker3pChrom'])
					geneDict['marker3pStart'] = chromLength
					geneDict['marker3pStop'] = chromLength
					geneDict['marker3pId'] = geneDict['marker3pChrom']
					anchor3pTargetCoods_dict.update(chrom= geneDict['marker3pChrom'], start=geneDict['marker3pStart'], stop=geneDict['marker3pStop'])
					anchor5pTargetCoods_dict.update(markersOnTarget_dict[geneDict['marker5pId']])
				else:
					print(" Problem with the anchors: canno find 5prime and 3prime anchors for the gene {}".format(geneDict['geneId']))
					# no anchor could be set: the 3prime dict is still empty, so comparing
					# the chromosomes below would raise a KeyError. Skip this gene instead.
					continue
			else:
				# get mapping of the flanking anchors on the target genome
				anchor5pTargetCoods_dict.update(markersOnTarget_dict[geneDict['marker5pId']])
				anchor3pTargetCoods_dict.update(markersOnTarget_dict[geneDict['marker3pId']])

			print(" anchors 5prime:\n")
			print(anchor5pTargetCoods_dict)
			print(" anchors 3prime:\n")
			print(anchor3pTargetCoods_dict)

			if anchor5pTargetCoods_dict['chrom'] == anchor3pTargetCoods_dict['chrom']:
				print(" both anchors are mapped on the same chrom. we can proceed to sequence extraction and mapping\n")
				### deal with query sequence: extract the fasta seq to map
				queryGeneSeq=getFastaSeq(fasta=queryFasta,
					chrom=geneDict['geneChrom'],
					start=int(geneDict['geneStart']),
					stop=int(geneDict['geneStop']))
				querySeqRecord= SeqRecord(
									Seq(queryGeneSeq, IUPAC.ambiguous_dna),
									id=geneDict['geneId'],
									description='coords '+geneDict['geneChrom']+'_'+str(geneDict['geneStart'])+'-'+str(geneDict['geneStop'])+', flanking markers '+ geneDict['marker5pId']+'-'+geneDict['marker3pId'])
				queryFasta2blast = workingDir+'/query.fasta'
				SeqIO.write(querySeqRecord, queryFasta2blast, 'fasta')

				### deal with target sequence: extract the fasta seq to use as db
				# coordinates read from the bed file are strings whereas placeholder anchors
				# are integers: convert both before comparing, otherwise the comparison is
				# lexicographic ('9' > '10') or raises a TypeError (int vs str)
				regionStart = int(anchor5pTargetCoods_dict['start'])
				regionStop = int(anchor3pTargetCoods_dict['stop'])
				# first check orientation: invert start and stop when needed
				if regionStart > regionStop:
					regionStart, regionStop = regionStop, regionStart
				anchor5pTargetCoods_dict['start'] = regionStart
				anchor3pTargetCoods_dict['stop'] = regionStop

				targetGenomeSeq=getFastaSeq(fasta=targetFasta,
					chrom=anchor5pTargetCoods_dict['chrom'],
					start=regionStart,
					stop=regionStop)
				targetSeqRecord= SeqRecord(
									Seq(targetGenomeSeq, IUPAC.ambiguous_dna),
									id='target_'+anchor5pTargetCoods_dict['chrom']+'_'+str(regionStart)+'_'+str(regionStop))
				targetFasta2blast = workingDir+'/target.fasta'
				SeqIO.write(targetSeqRecord, targetFasta2blast, 'fasta')

				################### run Blat of the genomic sequence on the target region on new assembly
				blatResultFile=runBlat(db=targetFasta2blast,
					query=queryFasta2blast,
					blatResult=workingDir+'/blat.txt',
					outFormat='psl')

				############ check if the genomic sequence aligns perfectly on the new reference
				checkPerfectHit(blatresult=blatResultFile,
					workingDir=workingDir,
					maxhit=1)

				# the temporary fasta files are only needed by blat and checkPerfectHit
				os.remove(queryFasta2blast)
				os.remove(targetFasta2blast)

			else:
				# TODO
				sys.stderr.write(" 5prime anchor %s and 3prime anchor %s are not on the same chromsome\n" % (geneDict['marker5pId'],geneDict['marker3pId']))
				sys.stderr.write(" Cannot anchoring gene %s \n" % geneDict['geneId'])

def checkPerfectHit(blatresult, workingDir, maxhit):
	"""Evaluate the BLAT hits of one gene and collect extra alignments on failure.

	The file ``blatresult`` is read line by line and each line is handed to
	``parsePSL``. The loop stops at the first hit that can be anchored, or
	once more than ``maxhit`` hits have been examined.

	For every hit that cannot be anchored, nucmer, mummerplot, dnadiff and a
	second BLAT run (blast8 output) are launched on ``query.fasta`` and
	``target.fasta``, which must be present in ``workingDir``, and a
	``run_act.sh`` helper script is written next to the results.

	Parameters:
		blatresult: path of the BLAT result file to read.
		workingDir: directory of the current gene (inputs and outputs).
		maxhit: maximum number of hits to examine.

	Returns None.
	"""
	print(" blatResults %s" % blatresult)
	with open(blatresult) as blat:
		for hitIndex, hit in enumerate(blat, start=1):

			if hitIndex > maxhit:
				sys.stderr.write(" Tried more than {} hits: giving up for this gene".format(maxhit))
				break
			print("    -> #%i Blat hit %s " % (hitIndex,hit))

			if parsePSL(pslRecord=hit) :
				print(" we can anchor this gene")
				break

			print(" we cannot anchor this gene\n")
			print(" adding mummer/nucmer alignment info to this gene\n")

			### add more alignment info using mummer/nucmer
			# nucmer: perform the alignment
			nucmerPref=workingDir+'/nucmer'
			ref=workingDir+'/target.fasta'
			query=workingDir+'/query.fasta'
			os.system('nucmer -p {} {} {}'.format(nucmerPref,ref,query))

			# mummerplot: generate the dot plot in png format
			mummerplotPref=workingDir+'/mummerplot'
			os.system('mummerplot --SNP -t png -p {} {}'.format(mummerplotPref,nucmerPref+'.delta'))

			# dnadiff: show potential snps and gaps
			dnadiffPref=workingDir+'/dnadiff'
			os.system('dnadiff -p {} -d {}'.format(dnadiffPref,nucmerPref+'.delta'))

			### add another blat format output for visualisation with ACT
			blatfile=workingDir+'/blast_m8.tab'
			os.system(' blat -extendThroughN -out=blast8 {} {} {}'.format(ref,query,blatfile))

			# write the ACT launcher into run_act.sh; the previous shell redirection
			# targeted blast_m8.tab and thus overwrote the alignment just produced
			run_act_file=workingDir+'/run_act.sh'
			with open(run_act_file, 'w') as actScript:
				actScript.write('act query.fasta blast_m8.tab target.fasta\n')
				



def parsePSL(pslRecord):
	"""Decide from one tab-delimited PSL line whether the gene can be anchored.

	Columns used (0-based): 0 and 1 (reported as matches and mismatches),
	3 (added to the coverage as N count), 7 (bases inserted into target),
	9 (name printed in the messages) and 10 (query length used as coverage
	denominator).

	Coverage, mismatch and indel statistics are printed. Returns 0 when the
	coverage is below 100 % or when bases are inserted into the target, and 1
	otherwise (with a warning message when mismatches are present).
	"""
	fields = pslRecord.rstrip('\n').split('\t')

	##### coverage: (matches + mismatches + N) / query length * 100
	# BLAT takes N stretches into account with the '-extendThroughN' parameter
	alignedBases = int(fields[0]) + int(fields[1]) + int(fields[3])
	querySize = int(fields[10])
	coveragePct = (alignedBases / querySize) * 100
	print("     - Coverage: {}".format(coveragePct))

	##### mismatches
	nbMismatches = int(fields[1])
	mismatchPct = (nbMismatches / alignedBases) * 100
	print("     - missmatches: {} ({} %)".format(nbMismatches, mismatchPct))

	##### indels, based on the number of bases inserted into target
	insertedBases = int(fields[7])
	indelPct = (insertedBases / alignedBases) * 100
	print("     - indels: {} ({} %)".format(insertedBases, indelPct))

	##### decide if the gene can be anchored
	geneName = fields[9]
	fullyCovered = int(coveragePct) >= 100
	if not fullyCovered or insertedBases > 0:
		print(" Gene {} has indel or is not fully covered: CANNOT ANCHOR IT\n".format(geneName))
		return 0

	if nbMismatches == 0:
		verdict = " Gene {} is a perfect full match: anchoring can be done!\n"
	else:
		verdict = " Gene {} has missmatches warning in anchoring!\n"
	print(verdict.format(geneName))
	return 1

def runBlat(db,query,blatResult, outFormat):
	"""Run blat of ``query`` against ``db`` through the shell.

	The command line is printed, then executed with ``os.system``; results
	are written by blat to ``blatResult`` in the ``outFormat`` format.
	Returns ``blatResult`` unchanged, whatever the exit status of the command.
	"""
	blatArgs = ['blat', '-noHead', '-extendThroughN', '-out='+outFormat, db, query, blatResult]
	commandLine = ' '.join(blatArgs)
	print(' blat command to run: %s'% str(commandLine))
	os.system(commandLine)
	return blatResult

def getFastaSeq(fasta,chrom,start,stop):
	"""Fetch a region from an opened FASTA handle and report it.

	Calls ``fasta.fetch(reference=chrom, start=start, end=stop)``, prints the
	requested region and returns whatever ``fetch`` returned.
	"""
	region = fasta.fetch(reference=chrom, start=start, end=stop)
	message = " fasta sequence for chrom %s from %s to %s\n" % (chrom, start, stop)
	print(message)
	return region

def getCoordsFromBed (bedfile):
	"""Load marker coordinates from a tab-delimited BED file.

	Each record needs at least 4 columns (chrom, start, stop, name). Columns
	5 and 6 (stored as 'mapQ' and 'strand') are optional and default to '.'
	when absent. Blank lines are ignored.

	Returns a dict keyed by the marker name (column 4) whose values are dicts
	with the keys 'chrom', 'start', 'stop', 'mapQ' and 'strand'. All values
	are kept as the strings read from the file. When a name occurs several
	times the last record wins.

	Exits the program with status 1 when a record has fewer than 4 columns.
	"""
	checkFile(file=bedfile)
	bedDict = defaultdict()
	with open(bedfile) as file:
		for line in file:
			# tolerate empty lines (e.g. a trailing newline at the end of the file)
			if not line.strip():
				continue
			bed=line.rstrip("\r\n").split("\t")
			if len(bed) < 4:
				sys.stderr.write(" ERROR PARSING BEDFILE %s " % str(bedfile))
				sys.stderr.write(" File do ot have 4 columns or is not tab delimited\n")
				sys.stderr.write(" Please check the format\n")
				# non-zero status so that callers and pipelines see the failure
				sys.exit(1)
			# the check above accepts 4-column files: pad the optional score and
			# strand columns instead of failing with an IndexError
			bed.extend(['.'] * (6 - len(bed)))
			bedDict[str(bed[3])] ={'chrom': bed[0],'start':bed[1],'stop':bed[2],'mapQ': bed[4],'strand':bed[5]}

	numRecords=len(bedDict)
	sys.stdout.write(" Found %i records in Bed file %s " % (numRecords, bedfile))
	return bedDict

def checkFile (file):
	"""Make sure that ``file`` is an existing regular file.

	A confirmation is written to stdout when the file exists. Otherwise an
	error is written to stderr and the program exits with status 1, so that
	the failure is visible to the calling shell or pipeline.

	Returns None.
	"""
	if os.path.isfile(file):
		sys.stdout.write(" File %s found\n" % str(file))
	else:
		sys.stderr.write(" ERROR: Cannot find file %s \n" % str(file))
		# a bare sys.exit() would report success (status 0) for an error
		sys.exit(1)

def checkDir(directory):
	"""Tell whether ``directory`` is an existing directory.

	Returns 1 (and reports on stdout) when it exists, 0 (and reports on
	stderr) when it does not.
	"""
	if not os.path.isdir(directory):
		sys.stderr.write(" Cannot find directory %s \n" % str(directory))
		return 0

	sys.stdout.write(" Directory %s found\n" % str(directory))
	return 1


# Run the pipeline only when this file is executed as a script, not when it is imported.
if __name__== "__main__":
	main()