Commit cd6dd688 authored by Jean Mainguy

refactoring of aln2taxaffi. use of new_taxdump and prot.accession2taxid.FULL

parent f81cd2d3
#!/usr/bin/env python3
"""----------------------------------------------------------------------------
Script Name: aln2taxaffi.py
Description:
Input files: File with correspondence between accession ids and taxon ids, \
taxonomy directory and diamond output file (.m8)
Created By: Celine Noirot
Date: 2019-09-06
-------------------------------------------------------------------------------
"""
# Metadata
__author__ = 'Celine Noirot, Jean Mainguy - Plateforme bioinformatique Toulouse'
__copyright__ = 'Copyright (C) 2019 INRA'
__license__ = 'GNU General Public License'
__version__ = '0.1'
__email__ = 'support.bioinfo.genotoul@inra.fr'
__status__ = 'dev'
# Modules importation
import re
import os
import csv
import gzip  # needed by parse_accession2taxid to read gzipped mapping files
import operator
import logging
from collections import defaultdict, Counter, OrderedDict
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from matplotlib import pyplot
# Variables
# These are identities normalized with query coverage:
# the minimum score required to assign a hit at each rank.
RANKS_TO_MIN_SCORE = {'superkingdom': 0.4,
                      'phylum': 0.5,
                      'class': 0.6,
                      'order': 0.7,
                      'family': 0.8,
                      'genus': 0.9,
                      'species': 0.95}

# Fraction of weights needed to assign at a specific level,
# a measure of consensus at that level.
MIN_FRACTION = 0.9
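# Illustrative weight computation (values assumed): a hit with a combined
# identity*coverage score of 0.85 contributes, at the 'family' rank, a weight
# of (0.85 - 0.8) / (1.0 - 0.8) = 0.25, and nothing at 'genus' (0.85 < 0.9),
# so it can only support assignments down to the family rank.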
def parse_arguments():
    """Parse script arguments."""
    parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    parser.add_argument("-b", "--aln_input_file", required=True,
                        help="file with blast/diamond matches expected format m8 \
\nqueryId, subjectId, percIdentity, alnLength, mismatchCount, gapOpenCount, \
queryStart, queryEnd, subjectStart, subjectEnd, eVal, bitScore")
    parser.add_argument('-a', '--acc_taxaid_mapping_file', required=True,
                        help="gzipped mapping file from accession to taxid")
    parser.add_argument('-t', '--taxonomy', required=True,
                        help="path to the extracted new_taxdump directory")
    parser.add_argument('-o', '--output_file', type=str,
                        default="taxonomyassignation",
                        help="string specifying output file")
    parser.add_argument('-i', '--min_identity', type=float, default=60,
                        help="minimum percentage of identity")
    parser.add_argument('-c', '--min_coverage', type=float, default=70,
                        help="minimum percentage of query coverage")
    parser.add_argument('--top', type=float, default=10,
                        help="keep diamond alignments within this percentage range of the top alignment score")
    parser.add_argument('--keep_only_best_aln', action="store_true",
                        help="keep only diamond alignments with the top alignment score (overrides --top)")
    parser.add_argument("--query_length_file",
                        help="tab delimited file of query lengths")
    parser.add_argument("-v", "--verbose", action="store_true",
                        help="increase output verbosity")
    args = parser.parse_args()
    return args
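# Typical invocation (illustrative file names; flags as defined above):
#   python aln2taxaffi.py -b contigs_vs_nr.m8 -a prot.accession2taxid.FULL.gz \
#       -t new_taxdump/ -o taxonomyassignation --top 10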
################################################
# Functions for taxonomy
################################################

# Define taxonomy main levels
main_level = \
    ["superkingdom", "phylum", "class", "order", "family", "genus", "species"]

# SAME AS Script_renameVcn.py
prot_prefix = "CDS_"
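# load_taxonomy returns, for each selected taxid, its lineage restricted to the
# main ranks (as taxids), plus name and rank lookup tables. Illustrative shape,
# assuming Escherichia coli (taxid 562) is in the selection:
#   taxid2rankedlineage[562] = [2, 1224, 1236, 91347, 543, 561, 562]
#   taxid2name[562] = 'Escherichia coli'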
def load_taxonomy(taxdump_dir, main_ranks, taxids_selection):
    logging.info(f'Load taxonomy information for {len(taxids_selection)} taxids.')
    nodes_file = os.path.join(taxdump_dir, "nodes.dmp")
    taxidlineage_file = os.path.join(taxdump_dir, "taxidlineage.dmp")
    merged_file = os.path.join(taxdump_dir, "merged.dmp")
    names_file = os.path.join(taxdump_dir, "names.dmp")

    merged_taxid = replace_merged_taxid(taxids_selection, merged_file)
    logging.info(f'{len(merged_taxid)} taxids have been merged into other taxids: {merged_taxid}')

    taxid2lineage_whole_db = get_all_taxid_lineage(taxidlineage_file)
    all_taxids = {taxid for leaf_taxid, lineage in taxid2lineage_whole_db.items()
                  for taxid in lineage if leaf_taxid in taxids_selection}
    all_taxids.add(1)
    taxid2lineage = {taxid: lineage for taxid, lineage in taxid2lineage_whole_db.items()
                     if taxid in all_taxids}

    taxid2rank = get_taxid_rank(all_taxids, nodes_file)
    taxid2rankedlineage = {taxid: get_ranked_lineage(lineage, taxid2rank, main_ranks)
                           for taxid, lineage in taxid2lineage.items()}
    taxid2name = get_taxid2name(all_taxids, names_file)

    logging.info(f'Load taxid information done. {len(taxids_selection - set(taxid2rankedlineage))} taxids have not been found in taxdump files')
    return taxid2rankedlineage, taxid2name, taxid2rank, merged_taxid
def get_taxid2name(taxids, names_file):
    taxid2name = {"None": "None"}
    with open(names_file, "r") as name_file:
        for line in name_file:
            line = line.rstrip().replace("\t", "")
            tab = line.split("|")
            if int(tab[0]) in taxids and tab[3] == "scientific name":
                tax_id, name = int(tab[0]), tab[1]
                taxid2name[tax_id] = name
    return taxid2name
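# merged.dmp maps retired taxids to the taxid they were merged into, one pair
# per line in the form "<old_taxid>\t|\t<new_taxid>\t|"
# (illustrative line: "12\t|\t74109\t|").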
def replace_merged_taxid(taxids, merged_file):
    merged_taxid = set()
    with open(merged_file) as fl:
        for i, l in enumerate(fl):
            old_taxid, new_taxid = l.rstrip().replace('\t|', '').split('\t')
            if int(old_taxid) in taxids:
                taxids.remove(int(old_taxid))
                taxids.add(int(new_taxid))
                merged_taxid.add(int(old_taxid))
    return merged_taxid


def read_query_length_file(query_length_file):
    lengths = {}
    for line in open(query_length_file):
        (queryid, length) = line.rstrip().split("\t")
        lengths[queryid] = float(length)
    return lengths
def get_ranked_lineage(lineage, taxid2rank, ranks_to_keep):
    rank2taxid = {taxid2rank[taxid]: taxid for taxid in lineage}
    ranked_lineage = []
    for rank_to_keep in ranks_to_keep:
        try:
            ranked_lineage.append(rank2taxid[rank_to_keep])
        except KeyError:
            ranked_lineage.append('None')
    return ranked_lineage
def get_taxid_rank(taxids, nodes_file):
    taxid2rank = {}
    with open(nodes_file) as fl:
        for i, l in enumerate(fl):
            node_infos = l.rstrip().replace('\t|', ' ').split('\t')
            taxid = node_infos[0].strip()
            if int(taxid) in taxids:
                rank = node_infos[2].strip()
                taxid2rank[int(taxid)] = rank
    return taxid2rank
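# taxidlineage.dmp (specific to new_taxdump) gives, for every taxid, its full
# lineage as space-separated ancestor taxids, root first, in the form
# "<taxid>\t|\t<ancestor taxids> \t|" (illustrative line for E. coli:
# "562\t|\t131567 2 1224 1236 91347 543 561 \t|").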
def get_all_taxid_lineage(taxidlineage_file):
    taxid2lineage = {}
    with open(taxidlineage_file) as fl:
        for i, l in enumerate(fl):
            taxid, taxid_lineage_str = l.rstrip().replace('\t|', ' ').split('\t')
            taxid_lineage = [int(t) for t in taxid_lineage_str.strip().split(' ') if t]
            taxid_lineage.append(int(taxid))
            taxid2lineage[int(taxid)] = taxid_lineage
    return taxid2lineage
################################################
# END Functions for taxonomy
################################################
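# Illustrative effect of the --top filter (scores assumed): with top=10 and a
# best alignment score of 0.90 for a query, only alignments scoring at least
# 0.90 * (1 - 10/100) = 0.81 are kept for that query (--keep_only_best_aln
# presumably amounts to top=0, keeping only the best-scoring alignments).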
def read_blast_input(blastinputfile, min_identity, min_coverage, top_aln):
    logging.info(f'Parsing blast result file {blastinputfile}...')
    matches = defaultdict(list)
    min_score_per_query = defaultdict(int)
    with open(blastinputfile) as blast_handler:
        # The file is expected to carry a header line naming at least the
        # columns qseqid, sseqid, pident, qstart, qend and qlen.
        reader = csv.DictReader(blast_handler, delimiter='\t')
        for aln in reader:
            if aln['sseqid'].startswith("gi|"):
                m = re.search(r"gi\|.*?\|.*\|(.*)\|", aln['sseqid'])
                acc = m.group(1)
            else:
                acc = aln['sseqid']
            qlen = int(aln['qlen'])
            qseqid = aln['qseqid']
            aln_qlen = abs(int(aln['qend']) - int(aln['qstart'])) + 1
            qcov = 100 * float(aln_qlen) / qlen
            pident = float(aln['pident'])
            if pident >= min_identity and qcov >= min_coverage:
                score = qcov / 100 * pident / 100
                assert score <= 1.0
                # when top = 0, the min_score_per_query is the best score
                min_score_per_query[qseqid] = max(
                    min_score_per_query[qseqid], score * (1 - top_aln / 100))
                if min_score_per_query[qseqid] <= score:
                    matches[qseqid].append({'sseqid': acc, 'score': score})
    logging.info(f'Parsing blast result file {blastinputfile} is completed. {len(matches)} proteins found with diamond hits.')
    # min_score_per_query may have increased after a hit was stored: filter
    # out hits that fell below the final threshold of their query.
    matches_filtered = {}
    for qseqid, infos in matches.items():
        matches_filtered[qseqid] = [
            info for info in infos if info['score'] >= min_score_per_query[qseqid]]
    return matches_filtered
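# prot.accession2taxid.FULL has two tab-separated columns, accession.version
# and taxid, e.g. "XP_000001.1\t9606" (illustrative accession).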
def parse_accession2taxid(accs, mapping_file):
    logging.info(f'Parsing accession2taxid file {mapping_file} to retrieve taxid of {len(accs)} protein accessions.')
    counter = 0
    total_prot = len(accs)
    proper_open = gzip.open if mapping_file.endswith('.gz') else open
    mappings = {acc: None for acc in accs}
    with proper_open(mapping_file, 'rt') as mapping_fh:
        for line in mapping_fh:
            acc_ver, taxid = line.split("\t")
            # Only add taxids for the given accessions
            if acc_ver in mappings:
                counter += 1
                # log progress roughly every 10% of the accessions found
                if total_prot >= 10 and counter % (total_prot // 10) == 0:
                    logging.info(f'{100 * counter/total_prot:.0f}% DONE')
                mappings[acc_ver] = int(taxid)
    logging.info(f'Parsing accession2taxid is completed. {len([t for t in mappings.values() if not t])}/{len(accs)} accessions have no taxid associated.')
    return mappings
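# contig_pattern is expected to be a compiled regex whose first capturing
# group extracts the contig id from a protein id; for ids such as
# "c1.Prot_00001" (contig "c1"), something like re.compile(r'(.*)\.Prot')
# would do (assumption: the actual pattern is built outside this section).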
def group_by_contig(matches, contig_pattern):
    contig2matches = defaultdict(dict)
    for prot_id, hits in matches.items():
        contig = contig_pattern.match(prot_id).group(1)
        contig2matches[contig][prot_id] = hits
    return contig2matches
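# The returned collate table maps each main rank to a Counter of
# taxid -> accumulated weight, e.g.
# {'genus': Counter({561: 1.7}), 'species': Counter({562: 0.4}), ...}
# (taxids and weights shown for illustration).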
def collate_protein_hits(sorted_hits, main_ranks, taxid2lineage, accession2taxid):
    """Collate protein hits."""
    taxids_already_processed = set()
    collate_hits = {rank: Counter() for rank in main_ranks}
    # For each hit, retrieve taxon id and compute weight in lineage
    for hit in sorted_hits:
        logging.debug('====HIT====')
        protein_hit = hit['sseqid']
        score = hit['score']
        logging.debug(f'{protein_hit}, {score}')
        hit_taxid = accession2taxid[protein_hit]
        if hit_taxid is None:
            # protein hit has not been found in accession2taxid
            logging.debug(f' {protein_hit} accession has no corresponding taxid in accession2taxid file')
            continue
        if hit_taxid in taxids_already_processed:
            # Only add the best hit per species
            logging.debug(f' {hit_taxid} already processed')
            continue
        taxids_already_processed.add(hit_taxid)
        if hit_taxid in taxid2lineage:
            logging.debug(' Hit has a taxo')
            hit_taxonomy = taxid2lineage[hit_taxid]
            logging.debug(f' Protein taxo: {hit_taxonomy}')
            for rank, rank_taxid in zip(main_ranks, hit_taxonomy):
                if rank_taxid == "None":
                    continue
                weight = (score - RANKS_TO_MIN_SCORE[rank]) / (1.0 - RANKS_TO_MIN_SCORE[rank])
                weight = max(weight, 0.0)
                logging.debug(f" R={rank}, T={rank_taxid}, W={weight}")
                # could put a transform in here
                if weight > 0:
                    collate_hits[rank][rank_taxid] += weight
        else:
            logging.debug(f' Hit taxid {hit_taxid} is not found in taxo')
    return collate_hits
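# Illustrative consensus check (weights assumed): with genus-level weights
# Counter({561: 9.5, 620: 0.5}), taxid 561 holds 9.5 / 10.0 = 0.95 of the
# total weight, which exceeds MIN_FRACTION = 0.9, so 561 is returned as the
# consensus taxid at the deepest rank reaching consensus.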
def get_taxid_consensus(collate_table, main_ranks):
    # Walk ranks from species up to superkingdom and return the first taxid
    # holding more than MIN_FRACTION of the accumulated weight at that rank.
    for rank in main_ranks[::-1]:
        collate = collate_table[rank]
        if not collate:
            continue
        dWeight = sum(collate.values())
        sortCollate = sorted(list(collate.items()), key=operator.itemgetter(1), reverse=True)
        logging.debug(f"{rank}, {sortCollate}, sum score {dWeight}")
        best_taxid, best_taxid_score = sortCollate[0]
        if dWeight > 0.0:
            dP = float(best_taxid_score) / dWeight
            if dP > MIN_FRACTION:
                logging.debug(f'-->dP OK {best_taxid}')
                return best_taxid
    return 1  # root taxid when no taxonomy consensus can be found
def add_collate_hits(main_collate_hits, collate_hits_to_add):
    for rank in main_collate_hits:
        logging.debug(f'RANK {rank}')
        for taxid in collate_hits_to_add[rank]:
            logging.debug(f' {taxid} {collate_hits_to_add[rank][taxid]}')