import sys

from collections import defaultdict
from networkx import Graph, connected_components

import numpy as np
from math import isnan

from pysam import VariantFile

from svrunner_utils import eprint, warn, vcf_to_pybed
from svreader.vcfwrapper import VCFRecord, VCFReader, VCFWriter

# Genotype tuples as pysam exposes them through a sample's 'GT' field.
HOM_REF = (0, 0)
HET_VAR = (0, 1)
HOM_VAR = (1, 1)

# Genotypes considered callable; anything else (e.g. missing) is ignored.
VALID_GENOTYPES = [HOM_REF, HET_VAR,  HOM_VAR]

# FORMAT tags under which each supported genotyper stores genotype likelihoods.
SVTYPER_GENO_LIKELI_TAG = "GL"
GENOMESTRIP_GENO_LIKELI_TAG = "GP"

# Currently active likelihood tag; switched by setLikeliTag().
Geno_likeli_tag = SVTYPER_GENO_LIKELI_TAG


def coded_geno(geno):
    """Translate a genotype tuple into a readable code string."""
    if geno == HOM_VAR:
        return "HOMO_VAR"
    if geno == HET_VAR:
        return "HET_VAR"
    if geno == HOM_REF:
        return "HOMO_REF"
    return "UNDEF"


def Variant(sample):
    """Return True when the sample carries at least one alternate allele."""
    return sample.get('GT') in (HET_VAR, HOM_VAR)


def Heterozygote(sample):
    """Return True when the sample genotype is heterozygous (0/1)."""
    return sample.get('GT') == HET_VAR


class AnnotateRecord(VCFRecord):
    """
    A lightweight pysam VariantRecord wrapper used to annotate the final
    records (duplicate tagging, call rates, renaming).
    """

    def __init__(self, record):
        """
        A pysam VariantRecord wrapper

        record: the pysam VariantRecord to wrap.
        """
        super(AnnotateRecord, self).__init__(record)
        # New identifier assigned via setNewId() and applied by rename().
        self.new_id = None

    @property
    def start(self):
        # Position of the variant as reported by pysam.
        return self.record.pos

    @property
    def end(self):
        return self.record.stop

    @property
    def svtype(self):
        # NOTE(review): _sv_type is presumably set by the VCFRecord base
        # class — confirm.
        return self._sv_type

    @property
    def svlen(self):
        # NOTE(review): uses self.stop (presumably inherited from VCFRecord)
        # rather than the `end` property above — confirm both agree.
        return abs(self.stop - self.start)

    @property
    def passed(self):
        """True when the record carries the PASS filter."""
        return "PASS" in self.record.filter

    @property
    def num_samples(self):
        """Total number of samples genotyped on this record."""
        return len(self.record.samples.keys())

    def setNewId(self, new_id):
        """Stash the new identifier; it is applied later by rename()."""
        self.new_id = new_id

    def rename(self):
        """Store the current id in INFO/SOURCEID, then switch to new_id."""
        try:
            self.record.info['SOURCEID'] = self.id
        except KeyError:
            eprint("SOURCEID absent from record info keys,")
            sys.exit(1)
        self.id = self.new_id

    def num_variant_samples(self):
        """Number of samples carrying at least one alternate allele."""
        return sum([Variant(s) for s in self.samples.values()])

    def variant_read_support(self):
        """Maximum alternate-observation count (AO) over all samples.

        Returns 0 when no sample provides an AO value; the previous
        implementation raised ValueError (max of empty list) in that case.
        """
        support = []
        for s in self.samples.values():
            if s.get('AO') is not None:
                support.append(s.get('AO')[0])
        return max(support, default=0)

    def qual(self):
        """Sum of per-sample SQ values (phred-scaled site-is-variant)."""
        variant_qual = []
        for s in self.samples.values():
            if s.get('SQ') is not None:
                variant_qual.append(s.get('SQ'))
        return sum(variant_qual)

    def GQ_samples(self):
        """List of the genotype qualities (GQ) present on the samples."""
        genotype_qual = []
        for s in self.samples.values():
            if s.get('GQ') is not None:
                genotype_qual.append(s.get('GQ'))
        return genotype_qual

    def GQ_sum_score(self):
        """Sum of genotype qualities over all samples."""
        return sum(self.GQ_samples())

    def maxGQ(self):
        """Maximum genotype quality; 0 when no sample has a GQ value.

        The previous implementation raised ValueError on an empty list.
        """
        return max(self.GQ_samples(), default=0)

    def set_qual(self):
        """Overwrite the record QUAL with the summed SQ score."""
        self.record.qual = self.qual()

    def numdiffGenotypes(self):
        """Number of distinct valid genotypes observed across samples."""
        # A set is the natural structure here (the original used a dict
        # whose values were never read).
        genotypes = set()
        for s in self.samples.values():
            if s.get('GT') in VALID_GENOTYPES:
                genotypes.add(coded_geno(s.get('GT')))
        return len(genotypes)

    def polymorph(self):
        """True when at least two different genotypes are observed."""
        return self.numdiffGenotypes() > 1

    def add_supporting_infos(self):
        """Write MAX_SUPP_READS and NUM_SUPP_SAMPLES into INFO."""
        supp_reads = self.variant_read_support()
        num_supp_samples = self.num_variant_samples()
        try:
            self.record.info['MAX_SUPP_READS'] = supp_reads
            self.record.info['NUM_SUPP_SAMPLES'] = num_supp_samples
        except KeyError:
            eprint("SUPP_READS or NUM_SUPP_SAMPLES absent from record info keys")
            sys.exit(1)

    def call_rate(self, cutoff):
        """Fraction of samples whose GQ exceeds *cutoff*."""
        call_qual = []
        for s in self.samples.values():
            if s.get('GQ') is not None:
                call_qual.append(s.get('GQ'))
        num_qual_call = sum([(qual > cutoff) for qual in call_qual])
        return num_qual_call / self.num_samples

    def variant_call_rate(self, cutoff):
        """Fraction of variant samples whose GQ exceeds *cutoff*.

        Returns 0 when no sample is a variant (avoids division by zero).
        """
        samples = self.samples.values()
        num_qual_var = 0
        for s in samples:
            if s.get("GQ") is not None and s.get("GQ") > cutoff and Variant(s):
                num_qual_var += 1
        num_var_samples = self.num_variant_samples()
        var_call_rate = num_qual_var / num_var_samples if num_var_samples else 0
        return var_call_rate

    def unify_pass_filtertag(self):
        """
           All records passing the filters (PASS, .) ar now labelled PASS
        """
        record = self.record
        filters = [f for f in record.filter]
        # We make the assumption when a "." is present no other filter
        # are present
        if not filters or "." in filters:
            record.filter.clear()
            record.filter.add("PASS")


class AnnotateReader(VCFReader):
    """VCF reader that registers the annotation-specific header lines and
    yields AnnotateRecord objects."""

    def __init__(self, file_name, sv_to_report=None):
        super(AnnotateReader, self).__init__(file_name)
        self.filename = file_name
        self.sv_to_report = sv_to_report
        self.vcf_reader = VariantFile(file_name)

        self.add_annotation_metadata()

    def __iter__(self):
        return self

    def __next__(self):
        # Wrap each raw pysam record in our annotation wrapper.
        return AnnotateRecord(next(self.vcf_reader))

    def getHeader(self):
        """Expose the underlying pysam header."""
        return self.vcf_reader.header

    def add_annotation_metadata(self):
        """Declare the INFO and FILTER fields written by the annotation."""
        self.addInfo("SOURCEID", 1, "String",
                     "The source sv identifier")
        self.addInfo("MAX_SUPP_READS", 1, "Integer",
                     "Max number of supporting reads")
        self.addInfo("NUM_SUPP_SAMPLES", 1, "Integer",
                     "Number of supporting samples")
        self.addFilter("LOWSUPPORT",
                       "total supp reads < 5 or supp samples < 2")

    def getOrderedSamples(self):
        """Sample names in header order, stripped of any '.'-suffix."""
        header_samples = self.vcf_reader.header.samples
        return [name.rsplit('.')[0] for name in header_samples]

    def numSamples(self):
        """Number of samples declared in the VCF header."""
        return len(self.vcf_reader.header.samples)


class AnnotateWriter(VCFWriter):
    """Writer for the annotated VCF records; the output file is opened
    lazily by _open()."""

    def __init__(self, file_name,  template_reader, index=True):
        # NOTE(review): `index` is accepted but never used here — confirm
        # whether it should be forwarded to the base class.

        # NOTE(review): super(VCFWriter, self) skips VCFWriter.__init__ and
        # invokes the __init__ of VCFWriter's *parent* instead. If that is
        # intentional it deserves a comment; otherwise this should probably
        # read super(AnnotateWriter, self).__init__(...). Confirm against
        # svreader.vcfwrapper.
        super(VCFWriter, self).__init__(file_name,
                                        template_reader.tool_name,
                                        template_reader)

    def _open(self):
        # Open the output VCF with the (annotation-augmented) header of the
        # template reader.
        self.vcf_writer = VariantFile(self.filename, 'w',
                                      header=self.template_reader.getHeader())
        self._isopen = True

    def _write(self, record):
        # Unwrap the AnnotateRecord and write the raw pysam record.
        self.vcf_writer.write(record.record)

    def _close(self):
        if self._isopen:
            self.vcf_writer.close()
        else:   # nothing was written
            self._dumpemptyvcf()


def ordered(a, b):
    """Return the two values as a pair sorted in ascending order."""
    if a < b:
        return (a, b)
    return (b, a)


def setLikeliTag(genotyper):
    """Select the FORMAT tag holding genotype likelihoods.

    genotyper: "svtyper" (tag GL) or "genomestrip" (tag GP); any other
    value is a fatal error.

    Sets the module-level Geno_likeli_tag read by getlikelihoods().
    """
    global Geno_likeli_tag

    if genotyper == "svtyper":
        Geno_likeli_tag = SVTYPER_GENO_LIKELI_TAG
        warn("Assuming genotypes provided by svtyper software" +
             " hence tag is %s" % (Geno_likeli_tag))

    elif genotyper == "genomestrip":
        Geno_likeli_tag = GENOMESTRIP_GENO_LIKELI_TAG
        warn("Assuming genotypes provided by genomestrip software" +
             " hence tag is %s" % (Geno_likeli_tag))
    else:
        # Consistency fix: report the error on stderr and exit through
        # sys.exit, matching the error handling used elsewhere in this
        # module (previously: print to stdout + bare exit()).
        eprint("Unknown genotyping software")
        sys.exit(1)


def getlikelihoods(sample):
    """Return the genotype log-likelihoods stored on *sample*.

    Bug fix: the tag was hard-coded to 'GL', which silently ignored
    setLikeliTag() — genomestrip input (tag 'GP') was never read. Use the
    module-level Geno_likeli_tag instead.
    """
    return sample.get(Geno_likeli_tag)


def probas(likelihoods):
    """Turn log-likelihoods into (unnormalized) likelihoods via exp."""
    return np.exp(np.asarray(likelihoods))


def getprobas(sample):
    """Return the sample's genotype likelihoods normalized to probabilities."""
    raw = probas(getlikelihoods(sample))
    return raw / np.sum(raw)


def ondiagonal(u_s, v_s):
    """Probability that this individual is genotyped identically for the
    two SVs: P1(0/0)*P2(0/0) + P1(0/1)*P2(0/1) + P1(1/1)*P2(1/1)."""
    p = getprobas(u_s)
    q = getprobas(v_s)
    return sum(a * b for a, b in zip(p, q))


def offdiagonal(u_s, v_s):
    """Probability that this individual is genotyped differently for the
    two SVs — the complement of ondiagonal over all index pairs i != j."""
    p = getprobas(u_s)
    q = getprobas(v_s)
    return sum(a * b
               for i, a in enumerate(p)
               for j, b in enumerate(q)
               if i != j)


def duplicatescore(u, v):
    """Log-ratio score for u and v being the same event.

    Looks at the most discordant sample: with max_disc the highest
    off-diagonal (discordant) probability over samples genotyped in both
    records, returns ln((1 - max_disc) / max_disc). NaN when no sample
    qualifies or max_disc is degenerate (0 or >= 1).
    """
    # Samples with a GQ value in both records.
    shared = [s for s in u.samples
              if u.samples[s].get('GQ') is not None
              and v.samples[s].get('GQ') is not None]

    worst = 0
    for s in shared:
        # ondiagonal(u_s, v_s) would be the concordant counterpart; only
        # the discordant probability is needed here.
        disc = offdiagonal(u.samples[s], v.samples[s])
        if disc > worst:
            worst = disc

    if 0 < worst < 1:
        return np.log((1 - worst) / worst)
    return float('NaN')


def gstrength(u):
    """Combined genotype quality of the site: the sum of the per-sample
    phred-like genotype qualities (GQ)."""
    return u.GQ_sum_score()


def variantstrength(u):
    """Phred-scaled probability that the site is variant in any sample.

    QUAL = -10 * log(P(locus is reference in all samples)), equal to the
    sum of the per-sample SQ scores
    (see https://github.com/hall-lab/svtyper/issues/10).
    """
    return u.qual()

def getduplicates_GQ(u, v):
    """Pick the preferred duplicate by summed genotype quality (gstrength).

    Returns (preferred, discarded, strength_preferred, strength_discarded).
    """
    su = gstrength(u)
    sv = gstrength(v)
    if su > sv:
        return (u, v, su, sv)
    return (v, u, sv, su)


def getduplicates_QUAL(u, v):
    """Pick the preferred duplicate by site-is-variant strength
    (variantstrength).

    Returns (preferred, discarded, strength_preferred, strength_discarded).
    """
    qu = variantstrength(u)
    qv = variantstrength(v)
    if qu > qv:
        return (u, v, qu, qv)
    return (v, u, qv, qu)


def getoverlap(u, osize):
    """Express an overlap size as a percentage of u's SV length."""
    return osize * 100 / u.svlen


def add_redundancy_infos_header(reader):
    """Register the INFO fields written by the redundancy annotation
    (DUPLICATESCORE mirrors GSDUPLICATESCORE, DUPLICATEOVERLAP mirrors
    GSDUPLICATEOVERLAP)."""
    fields = (
        ("DUPLICATESCORE", 1, "Float",
         "LOD score that the events are distinct based on the "
         "genotypes of the most discordant sample"),
        ("DUPLICATEOVERLAP", 1, "Float",
         "Highest overlap with a duplicate event"),
        ("DUPLICATEOF", ".", "String",
         "List of duplicate events preferred to this one"),
        ("DUPLICATES", ".", "String",
         "List of duplicate events represented by the current sv"),
        ("NONDUPLICATEOVERLAP", 1, "Float",
         "Amount of overlap with a non-duplicate"),
        ("TOOLSUPPORT", ".", "String",
         "Tools supporting (detecting) the sv"),
    )
    for name, number, vcf_type, description in fields:
        reader.addInfo(name, number, vcf_type, description)


def redundancy_annotator(SVSet, reader,
                         overlap_cutoff,
                         duplicatescore_threshold=-2,
                         genotyper="svtyper"):
    """ Annotating duplicate candidates based on the genotype likelihoods
      - genotype likelihoods can be provided by svtyper or genomestrip

    SVSet: iterable of AnnotateRecord-like SVs (iterated twice).
    reader: AnnotateReader whose header receives the redundancy INFO fields.
    overlap_cutoff: reciprocal overlap fraction passed to bedtools intersect.
    duplicatescore_threshold: score above which two SVs are duplicates.
    genotyper: "svtyper" or "genomestrip" (selects the likelihood tag).
    """

    add_redundancy_infos_header(reader)
    setLikeliTag(genotyper)

    # Index the SVs by id for lookup from the bed intersection output.
    variants = defaultdict()
    for sv in SVSet:
        variants[sv.id] = sv

    # Self-intersection with reciprocal overlap >= overlap_cutoff;
    # -wo keeps the overlap size in the last column.
    pybed_variants = vcf_to_pybed(SVSet)
    self_overlap = pybed_variants.intersect(pybed_variants,
                                            f=overlap_cutoff, r=True, wo=True)

    seen = defaultdict(tuple)
    duplicates = defaultdict(list)
    overlapping = defaultdict(tuple)
    reference = defaultdict()
    for o in self_overlap:
        # o[3] and o[7] are the ids of the two intersected intervals;
        # skip self-hits.
        if o[3] == o[7]:
            continue
        a, b = ordered(o[3], o[7])
        # Process each unordered pair only once.
        if seen[(a, b)]:
            continue
        seen[(a, b)] = True
        u = variants[a]
        v = variants[b]
        score = duplicatescore(u, v)
        # print("Comparing %s and %s : %3.8f" % (u.id, v.id, score))
        # NOTE(review): a NaN score (no shared genotyped sample) compares
        # False and falls through to the else branch — confirm intended.
        if score > duplicatescore_threshold:
            ref, dupli, _, _ = getduplicates_GQ(u, v)
            # print("%s prefered to %s %3.8f" % (ref.id, dupli.id, score))
            reference[ref] = 1
            overlap_size = int(o[-1])
            overlap = getoverlap(dupli, overlap_size)
            if ref.maxGQ() > 0 and dupli.passed:
                # are dismissed
                # - reference variant with 0 genotype quality for all markers
                # - duplicate that are already tagged as filtered out
                # dupli.id is considered as a duplicate of ref.id
                duplicates[dupli.id].append((ref.id, score, overlap))
        else:
            overlap_size = int(o[-1])
            overlap_u = getoverlap(u, overlap_size)
            overlap_v = getoverlap(v, overlap_size)
            # NOTE(review): `overlapping` is filled but never consumed in
            # this function — presumably meant to feed add_overlap_info_sv;
            # confirm.
            overlapping[(u, v)] = {'dupli_score': score,
                                   'overlap_left': overlap_u,
                                   'overlap_right': overlap_v,
                                   }
    # Second pass: write the duplicate annotations onto the records.
    for u in SVSet:
        if u.id in duplicates:
            print("tagged as duplicate %s" % u.id)
            duplis = sorted([a for (a, s, o) in duplicates[u.id]])
            score = max([s for (a, s, o) in duplicates[u.id]])
            overlap = max([o for (a, s, o) in duplicates[u.id]])
            add_duplicate_info_sv(u, overlap, score, duplis)


def add_duplicate_info_sv(sv, duplicateoverlap, duplicatescore, duplicates):
    """
        Record the duplicate annotations in the sv INFO field.
    """
    score = float('nan') if isnan(duplicatescore) else duplicatescore
    sv.record.info['DUPLICATESCORE'] = score
    sv.record.info['DUPLICATEOVERLAP'] = duplicateoverlap
    sv.record.info['DUPLICATEOF'] = duplicates


def add_overlap_info_sv(sv, overlap, duplicatescore):
    """
        Record the non-duplicate overlap annotations in the sv INFO field.
    """
    score = float('nan') if isnan(duplicatescore) else duplicatescore
    sv.record.info['DUPLICATESCORE'] = score
    sv.record.info['NONDUPLICATEOVERLAP'] = overlap


def add_callrate_infos_header(reader):
    """Register the call-rate INFO fields on the reader header."""
    for name, description in (
            ("CALLRATE",
             "Percentage of samples called with a GQ>13"),
            ("VARIANTCALLRATE",
             "Percentage of variant samples called with a GQ>13")):
        reader.addInfo(name, 1, "Float", description)


def add_filter_infos_header(reader):
    """Register the FILTER header lines used by the filtration step."""
    filters = (
        ("CALLRATE", "Call rate <0.75"),
        ("VARIANTCALLRATE", "Variant Call rate <0.75"),
        ("MONOMORPH", "All samples have the same genotype"),
        ("DUPLICATE", "GSDUPLICATESCORE>0"),
        ("OVERLAP", "NONDUPLICATEOVERLAP>0.7"),
        ("ABFREQ", "AB frequency <0.3 for >50% heterosamples"),
    )
    for name, description in filters:
        reader.addFilter(name, description)


def variant_filtration(variant_set, reader, filter_monomorph=False,
                       filter_callrate=False):
    """ Filtering the candidate CNVs according to the following criteria
          - non duplicate sites
          - variant sites
          - call rate > 0.8
          - at least one variant (homozygous or heterozygote) has a genotype
            quality > 20
          - the variant is not everywhere heterozygote or homozygote
            (use NONVARIANTSCORE in both cases)

    NOTE(review): the VARIANTCALLRATE filter is declared in the header but
    never added here — confirm whether that is intentional.
    """

    add_callrate_infos_header(reader)
    add_filter_infos_header(reader)

    for sv in variant_set:
        info = sv.record.info
        # Hoisted: call_rate(13) was previously computed twice per record.
        call_rate = sv.call_rate(13)
        info['CALLRATE'] = call_rate
        info['VARIANTCALLRATE'] = sv.variant_call_rate(13)
        if filter_callrate and call_rate < 0.75:
            sv.filter.add("CALLRATE")
        if filter_monomorph and not sv.polymorph():
            sv.filter.add("MONOMORPH")
        if 'NONDUPLICATEOVERLAP' in info and info['NONDUPLICATEOVERLAP'] > 0.8:
            sv.filter.add("OVERLAP")
        # Bug fix: the original condition was
        #   "DUPLICATESCORE" in info is not None
        # a chained comparison evaluating as
        #   ("DUPLICATESCORE" in info) and (info is not None)
        # which only worked by accident; write the membership test plainly.
        if "DUPLICATESCORE" in info and info['DUPLICATESCORE'] > -2:
            sv.filter.add("DUPLICATE")


def AB_filtering(variant_set):
    """ Filtering the candidate CNVs according to the following criteria
          - more than 50% of variant samples should have AB freq > 0.3
    """

    for sv in variant_set:
        # One boolean per heterozygous sample: AB frequency above 0.3?
        het_pass = [s.get('AB')[0] > 0.3
                    for s in sv.record.samples.values()
                    if Heterozygote(s)]
        if het_pass and sum(het_pass) < len(het_pass) / 2:
            sv.filter.add("ABFREQ")


def get_connected_duplicates(variant_set):
    """
    Construct connected components of duplicates and rename the variants

    Builds an undirected graph with an edge from each DUPLICATE-tagged
    variant to every variant listed in its DUPLICATEOF field, then elects
    one representative (a non-duplicate member) per connected component.
    """
    undirected = Graph()
    variant_dict = defaultdict()
    representatives = defaultdict()
    for s in variant_set:
        variant_dict[s.id] = s
        if "DUPLICATE" in s.filter:
            for dupli_repr in s.record.info["DUPLICATEOF"]:
                undirected.add_edge(s.id, dupli_repr)
    for component in connected_components(undirected):
        for c in component:
            if "DUPLICATEOF" in variant_dict[c].record.info:
                # the current variant is a duplicate
                continue
            rep = c  # the representative of the equivalence class
            break
        # NOTE(review): if every member of a component carries DUPLICATEOF,
        # `rep` is never (re)assigned by the loop above — NameError on the
        # first component, or a stale representative afterwards. Confirm
        # that each component always contains a non-duplicate member.
        duplicates = component
        # connected_components yields a mutable set, so remove() is valid.
        duplicates.remove(rep)
        representatives[rep] = duplicates
    add_duplicate_infos(representatives, variant_dict)


def add_duplicate_infos(representatives, sv_dict):
    """Point each duplicate at its representative (DUPLICATEOF) and list
    all duplicates on the representative (DUPLICATES)."""
    for rep, members in representatives.items():
        for dup_id in members:
            sv_dict[dup_id].record.info['DUPLICATEOF'] = rep
        merged = list(members)
        rep_info = sv_dict[rep].record.info
        if 'DUPLICATES' in rep_info:
            print(rep_info['DUPLICATES'])
            merged.extend(rep_info['DUPLICATES'])
        if len(merged) > 0:
            rep_info['DUPLICATES'] = merged


def get_tool_name(sv_ident):
    """The tool name is the leading underscore-separated token of the id."""
    tool, _, _ = sv_ident.partition("_")
    return tool


def set_supporting_tools(variant_set):
    """Record in TOOLSUPPORT the set of tools detecting each sv: its own
    tool, the tools of its duplicates, and any pre-existing support."""
    for sv in variant_set:
        info = sv.record.info
        tools = {get_tool_name(sv.id)}
        if "DUPLICATES" in info:
            for dupli in info['DUPLICATES']:
                tools.add(get_tool_name(dupli))
        if 'TOOLSUPPORT' in info:
            tools |= set(info['TOOLSUPPORT'])
        info['TOOLSUPPORT'] = list(tools)


def static_vars(**kwargs):
    """Decorator attaching the given names/values as attributes on the
    decorated function (poor-man's static variables)."""
    def decorate(func):
        for name, value in kwargs.items():
            setattr(func, name, value)
        return func
    return decorate


@static_vars(counter=0)
def new_id_str(sv):
    """Build the next sequential id: cnvpipeline_<chrom>_<svtype>_<n>."""
    new_id_str.counter += 1
    parts = ["cnvpipeline", sv.chrom, sv.svtype, str(new_id_str.counter)]
    return "_".join(parts)


def rename_info_field(sv, key, sv_dict):
    """Replace every id stored under *key* with its renamed counterpart
    from sv_dict; no-op when the key is absent."""
    info = sv.record.info
    if key in info:
        info[key] = [sv_dict[old_id] for old_id in info[key]]


def rename_variants(SVSet):
    """Assign new sequential ids to every sv, then rewrite the id
    cross-references (DUPLICATEOF/DUPLICATES) and apply the rename."""
    # First pass: build the old-id -> new-id mapping.
    id_map = defaultdict()
    for sv in SVSet:
        fresh = new_id_str(sv)
        sv.setNewId(fresh)
        id_map[sv.id] = fresh
    # Second pass: update cross-references, then rename each record.
    for sv in SVSet:
        rename_info_field(sv, "DUPLICATEOF", id_map)
        rename_info_field(sv, "DUPLICATES", id_map)
        sv.record.info['SOURCEID'] = sv.id
        sv.rename()