from math import isnan
import re

from collections import defaultdict

import numpy as np

from svrunner_utils import warn, vcf_to_pybed

# Genotype tuples; functions in this module compare them against a
# sample's "GT" value (see MaxGQ / NumGeno).
HOM_REF = (0, 0)
HET_VAR = (0, 1)
HOM_VAR = (1, 1)

# Per-sample field names holding the genotype likelihoods, one per
# supported genotyper (selected in setLikeliTag).
SVTYPER_GENO_LIKELI_TAG = "GL"
GENOMESTRIP_GENO_LIKELI_TAG = "GP"

# Field name currently read by getlikelihoods(); reassigned by setLikeliTag().
Geno_likeli_tag = SVTYPER_GENO_LIKELI_TAG


def getNonVariantSamples(sv, variants, samples):
    """Return the set of entries of ``samples`` that are not in ``variants``.

    ``sv`` is part of the signature but is not used by the computation.
    """
    remaining = set()
    for name in samples:
        if name not in variants:
            remaining.add(name)
    return remaining


def getVariantSamples(sv):
    """Return the set of variant samples listed in the record's INFO.

    The INFO key is 'GSSAMPLES' when ``sv.id`` contains "genomestrip",
    and 'VSAMPLES' otherwise.
    """
    from_genomestrip = re.search(r"genomestrip", sv.id) is not None
    key = 'GSSAMPLES' if from_genomestrip else 'VSAMPLES'
    return set(sv.record.info[key])


def passed_variant(record):
    """Tell whether ``record`` passed: no filter, empty filter, or "PASS"."""
    filters = record.filter
    if filters is None:
        return True
    if len(filters) == 0:
        return True
    return "PASS" in filters


def Called(record):
    """Count the samples of ``record`` whose genotype quality exceeds 13.

    A missing GQ is treated as 0 by getGQ and therefore never counts.
    """
    return sum(1 for name in record.samples
               if getGQ(record.samples[name]) > 13)


def getGQ(sample):
    """Return the sample's "GQ" value, or 0 when it is None."""
    quality = sample["GQ"]
    if quality is None:
        return 0
    return quality


def callRate(sv):
    """Return the fraction of samples of ``sv.record`` counted by Called."""
    rec = sv.record
    n_called = Called(rec)
    return float(n_called)/len(rec.samples)


def MaxGQ(record, geno):
    """Return the highest "GQ" among samples whose "GT" equals ``geno``.

    Samples with a missing (None) GQ are skipped, so they can never raise
    a TypeError when compared with the running maximum.

    Returns 0 when no sample has genotype ``geno`` with a positive GQ.
    """
    maxGQ = 0
    for sample in record.samples:
        call = record.samples[sample]
        if call["GT"] == geno:
            gq = call["GQ"]
            # A missing GQ carries no evidence; treat it like quality 0.
            if gq is not None and gq > maxGQ:
                maxGQ = gq
    return maxGQ


def Polymorph(sv):
    """
       At least two different genotypes

       Returns the sum of the pairwise products of the best GQ seen for
       the hom-ref, hom-var and het genotypes; the value is non-zero only
       when at least two of those maxima are non-zero.
    """
    rec = sv.record
    ref_gq, var_gq, het_gq = [MaxGQ(rec, genotype)
                              for genotype in (HOM_REF, HOM_VAR, HET_VAR)]
    return ref_gq*var_gq + ref_gq*het_gq + var_gq*het_gq


def NumGeno(record, geno):
    """Count the samples of ``record`` whose "GT" equals ``geno``."""
    return sum(1 for name in record.samples
               if record.samples[name]["GT"] == geno)


def InbreedingCoef(sv):
    """
       The inbreeding coefficient calculated as 1 - #observed het/#expected het

       ``sv`` is handed directly to NumGeno, so it must expose ``samples``.
       The "expected" count used here is #hom-var + #het/2.
       NOTE(review): this is not the usual Hardy-Weinberg expectation
       (2pqN) - confirm the intended formula.

       Returns NaN when the expected count is zero (no het / hom-var
       genotype), where the coefficient is undefined, instead of raising
       ZeroDivisionError.
    """
    obs_hetero = NumGeno(sv, HET_VAR)
    exp_hetero = NumGeno(sv, HOM_VAR) + obs_hetero/2
    if exp_hetero == 0:
        # Nothing to compare against: the ratio is undefined.
        return float('nan')
    return 1 - float(obs_hetero)/exp_hetero


# Redundancy annotator
def getlikelihoods(sample):
    """Return the sample's genotype log-likelihoods.

    The field read is the one named by the module-level Geno_likeli_tag.
    """
    tag = Geno_likeli_tag
    return sample[tag]


def probas(likelihoods):
    """Turn log-likelihoods into likelihoods with the natural exponential.

    NOTE(review): the VCF "GL" field is usually log10-scaled - confirm that
    a base-e exponential is intended here.
    """
    log_values = np.array(likelihoods)
    return np.exp(log_values)


def getprobas(sample):
    """Return the sample's genotype likelihoods normalised to sum to 1."""
    weights = probas(getlikelihoods(sample))
    total = np.sum(probas(getlikelihoods(sample)))
    return weights/total


def gstrength(u):
    """Sum the "GQ" of every sample of ``u`` (a missing GQ counts as 0).

    The sum of the phred-like genotype qualities is used as a measure of
    the combined genotype quality of the site.
    """
    qualities = []
    for call in u.samples.values():
        gq = call['GQ']
        qualities.append(0 if gq is None else gq)
    return np.sum(qualities)


def variantstrength(u):
    """Sum the "SQ" of every sample of ``u`` (a missing SQ counts as 0).

    SQ is the phred-scaled probability that the site is variant
    (non-reference) in a given sample; QUAL equals the sum of the SQ
    scores, see https://github.com/hall-lab/svtyper/issues/10
    """
    scores = []
    for call in u.samples.values():
        sq = call['SQ']
        scores.append(0 if sq is None else sq)
    return np.sum(scores)


def ondiagonal(s, u, v):
    """Probability that sample ``s`` has the same genotype in ``u`` and ``v``.

    Computed as P1(0/0)*P2(0/0) + ... + P1(1/1)*P2(1/1), pairing the two
    probability vectors position by position.
    """
    probs_u = getprobas(u.samples[s])
    probs_v = getprobas(v.samples[s])
    return sum(pu*pv for pu, pv in zip(probs_u, probs_v))


def offdiagonal(s, u, v):
    """Probability that sample ``s`` is genotyped differently in ``u`` and ``v``.

    Sums the products of the two probability vectors over every pair of
    positions that differ.
    """
    probs_u = getprobas(u.samples[s])
    probs_v = getprobas(v.samples[s])
    return sum(pu*pv
               for row, pu in enumerate(probs_u)
               for col, pv in enumerate(probs_v)
               if row != col)


def duplicatescore(u, v):
    """Log-ratio of concordant vs discordant genotype probability.

    Only samples with a non-missing GQ in both ``u`` and ``v`` are used.
    The score is log((1 - d) / d), where d is the largest discordance
    probability (offdiagonal) over those samples, i.e. the most
    discordant individual. NaN is returned unless 0 < d < 1.
    """
    shared = [name for name in u.samples
              if u.samples[name]['GQ'] is not None
              and v.samples[name]['GQ'] is not None]

    worst = 0
    for name in shared:
        discordance = offdiagonal(name, u, v)
        if discordance > worst:
            worst = discordance

    if not (0 < worst < 1):
        return float('NaN')
    return np.log((1-worst)/worst)


def getduplicates(u, v):
    """Pick the preferred record of a duplicate pair using gstrength.

    Returns (preferred, discarded, preferred strength, discarded strength);
    ``v`` is preferred when the two strengths are equal.
    """
    strength_u = gstrength(u)
    strength_v = gstrength(v)
    if strength_u > strength_v:
        return (u, v, strength_u, strength_v)
    return (v, u, strength_v, strength_u)


def getduplicatesV2(u, v):
    """Pick the preferred record of a duplicate pair using variantstrength.

    Returns (preferred, discarded, preferred strength, discarded strength);
    ``v`` is preferred when the two strengths are equal.
    """
    strength_u = variantstrength(u)
    strength_v = variantstrength(v)
    if strength_u > strength_v:
        return (u, v, strength_u, strength_v)
    return (v, u, strength_v, strength_u)


def GenotypeQuals(u):
    """Return the list of per-sample "GQ" values, with None replaced by 0."""
    qualities = []
    for name in u.samples:
        gq = u.samples[name]['GQ']
        qualities.append(gq if gq is not None else 0)
    return qualities


def maxGQ(u):
    """Return the largest genotype quality reported by GenotypeQuals(u)."""
    qualities = GenotypeQuals(u)
    return max(qualities)


def passed(u):
    """Tell whether ``u`` passed the previous filters.

    True when ``u.filter`` is empty or contains "PASS".
    """
    return len(u.filter) == 0 or "PASS" in u.filter


def ordered(a, b):
    """Return the pair as a tuple with the smaller element first."""
    if a < b:
        return (a, b)
    return (b, a)


def svlen(u):
    """Return the length of the SV, both ends included (VCF is 1-based)."""
    span = u.stop-u.pos
    return span+1


def getoverlap(u, osize):
    """Return the overlap of size ``osize`` as a percentage of svlen(u)."""
    scaled = 100*osize
    return scaled/svlen(u)


def setLikeliTag(genotyper):
    """Select the genotype-likelihood field matching ``genotyper``.

    Sets the module-level Geno_likeli_tag to the svtyper or genomestrip
    tag and emits a warning naming the chosen tag.

    For any other value, prints "Unknown genotyping software" and
    terminates the program with exit status 1 (raises SystemExit).
    """
    global Geno_likeli_tag

    # Local import: sys is only needed on the error path. sys.exit is used
    # rather than the site-provided exit() helper, which is meant for the
    # interactive interpreter and may be absent (e.g. python -S).
    import sys

    if genotyper == "svtyper":
        Geno_likeli_tag = SVTYPER_GENO_LIKELI_TAG
        warn("Assuming genotypes provided by svtyper software" +
             " hence tag is %s" % (Geno_likeli_tag))

    elif genotyper == "genomestrip":
        Geno_likeli_tag = GENOMESTRIP_GENO_LIKELI_TAG
        warn("Assuming genotypes provided by genomestrip software" +
             " hence tag is %s" % (Geno_likeli_tag))
    else:
        print("Unknown genotyping software")
        sys.exit(1)


def add_redundancy_infos_header(reader):
    """Declare the redundancy-related INFO fields on ``reader``.

    Registers, in this order: DUPLICATESCORE (equivalent to
    GSDUPLICATESCORE), DUPLICATEOVERLAP (equivalent to GSDUPLICATEOVERLAP),
    DUPLICATES and NONDUPLICATEOVERLAP.
    """
    definitions = (
        ("DUPLICATESCORE", 1, "Float",
         "LOD score that the events are distinct based on the "
         "genotypes of the most discordant sample"),
        ("DUPLICATEOVERLAP", 1, "Float",
         "Highest overlap with a duplicate event"),
        ("DUPLICATES", ".", "String",
         "List of duplicate events preferred to this one"),
        ("NONDUPLICATEOVERLAP", 1, "Float",
         "Amount of overlap with a non-duplicate"),
    )
    for key, number, kind, description in definitions:
        reader.addInfo(key, number, kind, description)


def GenomeSTRIPLikeRedundancyAnnotator(SVSet, reader,
                                       duplicatescore_threshold=-2,
                                       genotyper="svtyper"):
    """ Annotating duplicate candidates based on the genotype likelihoods
      - genotype likelihoods can be provided by svtyper or genomestrip

      SVSet: iterable of SV objects; each is indexed by its ``id`` and is
          passed to duplicatescore / getduplicates / getoverlap.
      reader: object on which the redundancy INFO headers are declared
          (see add_redundancy_infos_header).
      duplicatescore_threshold: a pair whose duplicatescore is strictly
          above this value is treated as a duplicate pair; every other
          pair (including a NaN score) is recorded as a plain overlap.
      genotyper: forwarded to setLikeliTag to choose the likelihood field.

      Nothing is returned; the SVs are annotated in place through
      add_duplicate_info_sv and add_overlap_info_sv.
    """

    add_redundancy_infos_header(reader)
    setLikeliTag(genotyper)

    # Index the SVs by id so that overlap rows can be mapped back to them.
    variants = defaultdict()
    for sv in SVSet:
        variants[sv.id] = sv

    # Self-intersection of the whole set.
    # NOTE(review): presumably a pybedtools BedTool, in which case
    # f=0.5, r=True requests a 50% reciprocal overlap - confirm against
    # vcf_to_pybed.
    pybed_variants = vcf_to_pybed(SVSet)
    self_overlap = pybed_variants.intersect(pybed_variants, f=0.5, r=True, wo=True)

    # seen: id pairs already processed
    # duplicates: discarded id -> list of (preferred id, score, overlap)
    # overlapping: (u, v) -> score and overlaps for non-duplicate pairs
    # reference: SVs that were preferred in at least one duplicate pair
    seen = defaultdict(tuple)
    duplicates = defaultdict(list)
    overlapping = defaultdict(tuple)
    reference = defaultdict()
    for o in self_overlap:
        # Columns 3 and 7 hold the ids of the two intersecting records;
        # skip a record intersecting itself.
        if o[3] == o[7]:
            continue
        # Each unordered pair is handled only once.
        a, b = ordered(o[3], o[7])
        if seen[(a, b)]:
            continue
        seen[(a, b)] = True
        u = variants[a]
        v = variants[b]
        score = duplicatescore(u, v)
        # print("Comparing %s and %s : %3.8f" % (u.id, v.id, score))
        if score > duplicatescore_threshold:
            ref, dupli, s1, s2 = getduplicates(u, v)
            # print("%s prefered to %s %3.8f" % (ref.id, dupli.id, score))
            reference[ref] = 1
            # The last column of the row is the size of the overlap.
            overlap_size = int(o[-1])
            overlap = getoverlap(dupli, overlap_size)
            if maxGQ(ref) > 0 and passed(dupli):
                # are dismissed
                # - reference variant with 0 genotype quality for all markers
                # - duplicate that are already tagged as filtered out
                # dupli.id is considered as a duplicate of ref.id
                duplicates[dupli.id].append((ref.id, score, overlap))
        else:
            # Not a duplicate pair: remember how much of each SV overlaps.
            overlap_size = int(o[-1])
            overlap_u = getoverlap(u, overlap_size)
            overlap_v = getoverlap(v, overlap_size)
            overlapping[(u, v)] = {'dupli_score': score,
                                   'overlap_left': overlap_u,
                                   'overlap_right': overlap_v,
                                   }

    # Annotate every discarded duplicate with the sorted ids of the SVs
    # preferred to it, and the highest score and overlap among them.
    for u in SVSet:
        if u.id in duplicates:
            # print("tagged as duplicate %s" % u.id)
            duplis = sorted([a for (a, s, o) in duplicates[u.id]])
            score = max([s for (a, s, o) in duplicates[u.id]])
            overlap = max([o for (a, s, o) in duplicates[u.id]])
            add_duplicate_info_sv(u, overlap, score, duplis)

    # Annotate the non-duplicate overlaps, except on SVs that were preferred
    # in a duplicate pair. This runs after the main loop so that
    # `reference` is complete. Note that add_overlap_info_sv also sets
    # DUPLICATESCORE on the SV.
    for o in overlapping:
        info = overlapping[o]
        u, v = o
        if u not in reference:
            add_overlap_info_sv(u, info['overlap_left'], info['dupli_score'])
        if v not in reference:
            add_overlap_info_sv(v, info['overlap_right'], info['dupli_score'])


def add_duplicate_info_sv(sv, duplicateoverlap, duplicatescore, duplicates):
    """
        Store the duplicate annotations in the sv infos:
        DUPLICATESCORE, DUPLICATEOVERLAP and DUPLICATES.
        A NaN score is stored as a plain float NaN.
    """
    info = sv.record.info
    score = float('nan') if isnan(duplicatescore) else duplicatescore
    info['DUPLICATESCORE'] = score
    info['DUPLICATEOVERLAP'] = duplicateoverlap
    info['DUPLICATES'] = duplicates


def add_overlap_info_sv(sv, overlap, duplicatescore):
    """
        Store the non-duplicate overlap annotations in the sv infos:
        DUPLICATESCORE and NONDUPLICATEOVERLAP.
        A NaN score is stored as a plain float NaN.
    """
    info = sv.record.info
    score = float('nan') if isnan(duplicatescore) else duplicatescore
    info['DUPLICATESCORE'] = score
    info['NONDUPLICATEOVERLAP'] = overlap


def add_filter_infos_header(reader):
    """Declare the CALLRATE, MONOMORPH, DUPLICATE and OVERLAP filters
    on ``reader``, in that order."""
    filters = (
        ("CALLRATE", "Call rate <0.75"),
        ("MONOMORPH", "All samples have the same genotype"),
        ("DUPLICATE", "GSDUPLICATESCORE>0"),
        ("OVERLAP", "NONDUPLICATEOVERLAP>0.7"),
    )
    for name, description in filters:
        reader.addFilter(name, description)


def GenomeSTRIPLikefiltering(SVSet, reader):
    """ Tag the candidate SVs of ``SVSet`` with filters, in place

        The filter headers are first declared on ``reader``; then each SV
        gets, independently of the others:
          - CALLRATE  when callRate(sv) < 0.75
          - MONOMORPH when Polymorph(sv) is zero
          - OVERLAP   when INFO has NONDUPLICATEOVERLAP > 0.7
          - DUPLICATE when INFO has DUPLICATESCORE > -2

        A NaN DUPLICATESCORE never triggers DUPLICATE (the comparison with
        NaN is false).
    """

    add_filter_infos_header(reader)

    for sv in SVSet:
        info = sv.record.info
        if callRate(sv) < 0.75:
            sv.filter.add("CALLRATE")
        if not Polymorph(sv):
            sv.filter.add("MONOMORPH")
        # NOTE(review): getoverlap() in this module returns
        # 100*osize/svlen, i.e. a percentage - confirm that 0.7 (rather
        # than 70) is the intended threshold.
        if 'NONDUPLICATEOVERLAP' in info and info['NONDUPLICATEOVERLAP'] > 0.7:
            sv.filter.add("OVERLAP")
        # Plain membership test: the previous `"..." in info is not None`
        # was a chained comparison that only worked by accident.
        if 'DUPLICATESCORE' in info and info['DUPLICATESCORE'] > -2:
            sv.filter.add("DUPLICATE")