diff --git a/humann/humann.py b/humann/humann.py index 63107e5e..91331992 100755 --- a/humann/humann.py +++ b/humann/humann.py @@ -937,23 +937,6 @@ def main(): # If id mapping is provided then process if args.id_mapping: alignments.process_id_mapping(args.id_mapping) - - # Load in the reactions database - reactions_database=None - if config.pathways_database_part1: - reactions_database=store.ReactionsDatabase(config.pathways_database_part1) - - message="Load pathways database part 1: " + config.pathways_database_part1 - logger.info(message) - - # Load in the pathways database - pathways_database=store.PathwaysDatabase(config.pathways_database_part2, reactions_database) - - if config.pathways_database_part1: - message="Load pathways database part 2: " + config.pathways_database_part2 - else: - message="Load pathways database: " + config.pathways_database_part2 - logger.info(message) # Start timer start_time=time.time() @@ -1019,12 +1002,13 @@ def main(): logger.debug("Custom database is empty") reduced_aligned_reads_file = "Empty" unaligned_reads_file_fasta=args.input - unaligned_reads_store=store.Reads(unaligned_reads_file_fasta, minimize_memory_use=minimize_memory_use) + unaligned_reads_store.add_from_fasta(unaligned_reads_file_fasta) # Do not run if set to bypass translated search in config file if not config.bypass_translated_search: # Run translated search on UniRef database if unaligned reads exit if unaligned_reads_store.count_reads()>0: + translated_alignment_file = translated.alignment(config.protein_database, unaligned_reads_file_fasta) @@ -1125,6 +1109,24 @@ def main(): # Clear all of the alignments data as they are no longer needed alignments.clear() + + # Load in the reactions database + reactions_database=None + if config.pathways_database_part1: + reactions_database=store.ReactionsDatabase(config.pathways_database_part1) + + message="Load pathways database part 1: " + config.pathways_database_part1 + logger.info(message) + + # Load in the pathways 
database + pathways_database=store.PathwaysDatabase(config.pathways_database_part2, reactions_database) + + if config.pathways_database_part1: + message="Load pathways database part 2: " + config.pathways_database_part2 + else: + message="Load pathways database: " + config.pathways_database_part2 + logger.info(message) + # Identify reactions and then pathways from the alignments message="Computing pathways abundance and coverage ..." logger.info(message) diff --git a/humann/search/nucleotide.py b/humann/search/nucleotide.py index d1b075a2..9115f457 100644 --- a/humann/search/nucleotide.py +++ b/humann/search/nucleotide.py @@ -287,6 +287,7 @@ def unaligned_reads(sam_alignment_file, alignments, unaligned_reads_store, keep_ file_handle_write_aligned.close() # process alignments to determine genes for filtering + unaligned_reads_store.start_bulk_write() allowed_genes = blastx_coverage.blastx_coverage(reduced_aligned_reads_file, config.nucleotide_subject_coverage_threshold, alignments, log_messages=True, apply_filter=True, nucleotide=True, query_coverage_threshold=config.nucleotide_query_coverage_threshold, @@ -297,8 +298,8 @@ def unaligned_reads(sam_alignment_file, alignments, unaligned_reads_store, keep_ # read through the file line by line # capture alignments and also write out unaligned reads for next step in processing + alignments.start_bulk_write() line = file_handle_read.readline() - query_ids=set() no_frames_found_count=0 small_identity_count=0 filtered_genes_count=0 @@ -308,7 +309,6 @@ def unaligned_reads(sam_alignment_file, alignments, unaligned_reads_store, keep_ unaligned_read=False if not re.search("^@",line): info=line.split(config.sam_delimiter) - query_ids.add(info[config.blast_query_index]) # check flag to determine if unaligned if int(info[config.sam_flag_index]) & config.sam_unmapped_flag != 0: unaligned_read=True @@ -378,12 +378,8 @@ def unaligned_reads(sam_alignment_file, alignments, unaligned_reads_store, keep_ file_handle_read.close() 
file_handle_write_unaligned.close() file_handle_write_aligned.close() - - # set the total number of queries - unaligned_reads_store.set_initial_read_count(len(query_ids)) - - # set the unaligned reads file to read sequences from - unaligned_reads_store.set_file(unaligned_reads_file_fasta) + alignments.end_bulk_write() + unaligned_reads_store.end_bulk_write() if write_picked_frames: file_handle_write_unaligned_frames.close() diff --git a/humann/search/translated.py b/humann/search/translated.py index cb9f7896..3f79c919 100644 --- a/humann/search/translated.py +++ b/humann/search/translated.py @@ -294,6 +294,8 @@ def unaligned_reads(unaligned_reads_store, alignment_file_tsv, alignments): # run through final filter of alignment by allowed proteins small_coverage_count=0 + alignments.start_bulk_write() + unaligned_reads_store.start_bulk_write() for alignment_info in utilities.get_filtered_translated_alignments(alignment_file_tsv, alignments, apply_filter=True, log_filter=True, identity_threshold=config.identity_threshold): (protein_name, gene_length, queryid, matches, bug, alignment_length, @@ -308,6 +310,8 @@ def unaligned_reads(unaligned_reads_store, alignment_file_tsv, alignments): unaligned_reads_store.remove_id(queryid) else: small_coverage_count+=1 + alignments.end_bulk_write() + unaligned_reads_store.end_bulk_write() logger.debug("Total translated alignments not included based on small subject coverage value: " + str(small_coverage_count)) diff --git a/humann/store.py b/humann/store.py index 21cb771e..f1150336 100644 --- a/humann/store.py +++ b/humann/store.py @@ -34,6 +34,7 @@ import sys import gzip import bz2 +import sqlite3 from . import config from . 
import utilities @@ -98,103 +99,87 @@ def normalized_gene_length(gene_length, read_length): return (abs(gene_length - read_length)+1)/1000.0 -class Alignments: - """ - Holds all of the alignments for all bugs - """ - - def __init__(self,minimize_memory_use=None): - self.__total_scores_by_query={} - self.__multiple_hits_queries={} - self.__hits_by_query={} - self.__scores_by_bug_gene={} - self.__gene_counts={} - self.__bug_counts={} - self.__id_mapping={} - - self.__temp_alignments_file=None - self.__temp_alignments_file_handle=None - self.__delimiter="\t" - - if minimize_memory_use: - self.__minimize_memory_use=True - logger.debug("Initialize Alignments class instance to minimize memory use") - else: - self.__minimize_memory_use=False - logger.debug("Initialize Alignments class instance to maximize memory use") - - def write_temp_alignments_file(self,query,bug,reference,score,normalized_reference_length): +class SqliteStore: + def __init__(self, minimize_memory_use = None): + self.__minimize_memory_use=minimize_memory_use + self.__dbpath = None + self.__conn = None + self.__is_within_transaction = False + self.__stateful_ops_in_bulk_write = None + + + def connect(self): """ - Write an alignment to the temp alignments file, first create if needed + Open the sqlite3 connection """ + if self.__conn: + return + + if not self.__dbpath: + store_name=type(self).__name__ + if self.__minimize_memory_use: + self.__dbpath = utilities.unnamed_temp_file(store_name + ".sqlite") + logger.debug("Initializing {0} store backed by a temporary file to minimize memory use".format(store_name)) + else: + self.__dbpath = ":memory:" + logger.debug("Initializing {0} store in-memory".format(store_name)) - if not self.__temp_alignments_file: - self.create_temp_alignments_file() - - line=self.__delimiter.join([query,bug,reference,str(score),str(normalized_reference_length)]) + self.__conn = sqlite3.connect(self.__dbpath, isolation_level=None) - try: - 
self.__temp_alignments_file_handle.write(line+"\n") - except EnvironmentError: - logger.warning("Unable to write to temp alignments file") - - def read_temp_alignments_file(self, queries): + def do(self, *args): """ - Read in those alignments which are included in queries + Run a stateful statement like add or delete + If within a transaction, commit and reopen every 100k operations """ - - # close and reopen the temp alignments file - line="" - try: - self.__temp_alignments_file_handle.close() - self.__temp_alignments_file_handle=open(self.__temp_alignments_file, "rt") - - line=self.__temp_alignments_file_handle.readline() - except (EnvironmentError, AttributeError): - pass - - while line: - # lines should be of the format query \t bug \t reference \t score \t length - (query,bug,reference,score,length)=line.rstrip().split(self.__delimiter) - if query in queries: - yield (query,bug,reference,float(score),float(length)) - - line=self.__temp_alignments_file_handle.readline() - - try: - self.__temp_alignments_file_handle.close() - except (EnvironmentError, AttributeError): - pass - - def create_temp_alignments_file(self): + self.__conn.execute(*args) + if self.__is_within_transaction: + self.__stateful_ops_in_bulk_write +=1 + if self.__stateful_ops_in_bulk_write % 100000 == 0: + self.__conn.execute("commit transaction") + self.__conn.execute("begin transaction") + + def query(self, *args): """ - Create and open a temp alignments file + Use the sqlite3 connection """ - - self.__temp_alignments_file=utilities.unnamed_temp_file("temp_alignments") - - try: - self.__temp_alignments_file_handle=open(self.__temp_alignments_file, "w") - except EnvironmentError: - sys.exit("CRITICAL ERROR: Unable to open temp alignments file") - - def delete_temp_alignments_file(self): + return self.__conn.execute(*args) + + def clear(self): """ - Delete the temp alignments file + Clear all of the stored data """ - try: - self.__temp_alignments_file_handle.close() - except EnvironmentError: - 
pass - - try: - os.unlink(self.__temp_alignments_file) - except EnvironmentError: - logger.warning("Unable to delete the temp alignments file") - - self.__temp_alignments_file=None - self.__temp_alignments_file_handle=None + self.__conn.close() + self.__conn = None + + def start_bulk_write(self): + self.__is_within_transaction = True + self.__stateful_ops_in_bulk_write = 0 + self.__conn.execute("begin transaction") + + def end_bulk_write(self): + self.__conn.execute("commit transaction") + self.__is_within_transaction = False + self.__stateful_ops_in_bulk_write = None + + +class Alignments(SqliteStore): + + """ + Holds all of the alignments for all bugs + """ + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.connect() + self.do('''create table alignment ( + query text not null, + bug text not null, + reference text not null, + score real not null, + length real not null + );''') + self.__id_mapping={} def process_id_mapping(self,file): """ @@ -307,134 +292,61 @@ def add(self, reference, reference_length, query, matches, bug, read_length=None except ValueError: logger.debug("Could not convert the number of matches to score: " + str(matches)) score=0.0 - - # Increase the counts for gene and bug - self.__bug_counts[bug]=self.__bug_counts.get(bug,0)+1 - self.__gene_counts[reference]=self.__gene_counts.get(reference,0)+1 - - # Add to the scores by query and store if query has multiple scores - if query in self.__total_scores_by_query: - current_query_total=self.__total_scores_by_query[query] - if query in self.__multiple_hits_queries: - self.__multiple_hits_queries[query]=self.__multiple_hits_queries[query]+[score] - else: - self.__multiple_hits_queries[query]=[current_query_total,score] - - self.__total_scores_by_query[query]=current_query_total+score - else: - self.__total_scores_by_query[query]=score - + # Store the scores by bug and gene normalized_reference_length=normalized_gene_length(reference_length, read_length) - 
normalized_score=1/normalized_reference_length - if bug in self.__scores_by_bug_gene: - self.__scores_by_bug_gene[bug][reference]=self.__scores_by_bug_gene[bug].get(reference,0)+normalized_score - else: - self.__scores_by_bug_gene[bug]={reference:normalized_score} - - # write the information for the hit to the temp alignments file - # or store in memory depending on the memory use setting - if self.__minimize_memory_use: - self.write_temp_alignments_file(query, bug, reference, score, normalized_reference_length) - else: - hit=(bug,reference,score,normalized_reference_length) - if query in self.__hits_by_query: - self.__hits_by_query[query].append(hit) - else: - self.__hits_by_query[query]=[hit] + # write the information for the hit + self.do('insert into alignment (query, bug, reference, score, length) values (?,?,?,?,?)', [query, bug, reference, score, normalized_reference_length]) + def count_bugs(self): """ Return total number of bugs """ - return len(self.__bug_counts) + + return self.query("select count (distinct bug) from alignment").fetchone()[0] def count_genes(self): """ Return total number of genes """ - return len(self.__gene_counts) + + return self.query("select count (distinct reference) from alignment").fetchone()[0] def counts_by_bug(self): """ Return each bug and the total number of hits """ - lines=[] - for bug in self.__bug_counts: - lines.append(bug + ": " + str(self.__bug_counts.get(bug,0)) + " hits") - - return "\n".join(lines) + + return "\n".join(["{0}: {1} hits".format(row[0], row[1]) for row in self.query('select bug, count(*) as c from alignment group by bug order by -c')]) def gene_list(self): """ Return a list of all of the gene families """ - return list(self.__gene_counts.keys()) + return [row[0] for row in self.query('select distinct reference from alignment')] def bug_list(self): """ Return a list of all of the bugs """ - - return list(self.__bug_counts.keys()) + + return [row[0] for row in self.query('select distinct bug from 
alignment')] def get_hit_list(self): """ Return a list of all of the hits """ - - # Add the query to the hits - list=[] - # if the hits are stored in memory use the dictionary - if self.__hits_by_query: - for query in self.__hits_by_query: - for (bug,reference,score,length) in self.__hits_by_query[query]: - list.append([query]+[bug,reference,score,length]) - else: - # else read through the temp file for the hits - for (query,bug,reference,score,length) in self.read_temp_alignments_file(self.__total_scores_by_query): - list.append([query,bug,reference,score,length]) - - return list + + return [row for row in self.query('select query, bug, reference, score, length from alignment')] def hits_for_gene(self,gene): """ Return a list of all of the hits for a specific gene """ - # Add the query to the hits - list=[] - # if the hits are stored in memory use the dictionary - if self.__hits_by_query: - for query in self.__hits_by_query: - for (bug,reference,score,length) in self.__hits_by_query[query]: - if reference==gene: - list.append([query]+[bug,reference,score,length]) - else: - # else read through the temp file for the hits - for (query,bug,reference,score,length) in self.read_temp_alignments_file(self.__total_scores_by_query): - if reference==gene: - list.append([query,bug,reference,score,length]) - - return list - - def add_query_normalization_to_alignment_score(self,query,bug,reference,score,length): - """ - Update the gene score added for the single alignment provided - This update adds the query normalization - """ - - # Normalize by query hits for all queries with multiple hits - # Hits where it is the only match per query will have scores of 1 - # as this is the result of normalizing (ie score/score) - - query_normalize=self.__total_scores_by_query[query] - - original_score=1/length - updated_score=score/query_normalize*original_score - self.__scores_by_bug_gene[bug][reference]=self.__scores_by_bug_gene[bug][reference]-original_score+updated_score - + return 
[row for row in self.query('select query, bug, reference, score, length from alignment where reference=?', [gene])] def convert_alignments_to_gene_scores(self,gene_scores_store): """ @@ -442,53 +354,62 @@ def convert_alignments_to_gene_scores(self,gene_scores_store): Add to the gene_scores store """ - # Normalize by query hits for all queries with multiple hits + # calculate the score per bug and gene + # for a single query result, it is 1/a.length - # process through the temp alignments file if the data is not stored in memory - if not self.__hits_by_query: - for (query,bug,reference,score,length) in self.read_temp_alignments_file(self.__multiple_hits_queries): - self.add_query_normalization_to_alignment_score(query,bug,reference,score,length) - # use the hits stored in memory - else: - for query in self.__multiple_hits_queries: - for (bug,reference,score,length) in self.__hits_by_query[query]: - self.add_query_normalization_to_alignment_score(query, bug, reference, score, length) - - # compute the scores for the genes - all_gene_scores={} - messages=[] - for bug in self.__scores_by_bug_gene: - # Add up all genes scores for each bug - for gene in self.__scores_by_bug_gene[bug]: - all_gene_scores[gene]=all_gene_scores.get(gene,0)+self.__scores_by_bug_gene[bug][gene] - # Add to the gene scores structure - gene_scores_store.add(self.__scores_by_bug_gene[bug],bug) - total_gene_families_for_bug=len(self.__scores_by_bug_gene[bug]) - messages.append(bug + " : " + str(total_gene_families_for_bug) + " gene families") - - # add all gene scores to structure - gene_scores_store.add(all_gene_scores,"all") - - # print messages if in verbose mode - message="\n".join(messages) - message="Total gene families : " +str(len(all_gene_scores))+"\n"+message + query=None + try: + # by default, bowtie2 and diamond are ran to report the best alignment + # try create a unique index and use a simpler query + self.do('create unique index query_uix on alignment(query)') + query= ''' + select 
a.bug, a.reference, sum(1/a.length) as score from alignment a + group by bug, reference + order by bug, reference + ''' + except sqlite3.IntegrityError: + # if a query matches multiple bugs and genes, its score is distributed by weighted average + # scores from multiple queries are added up per bug and gene + # see unit tests for examples + query=''' + select bug, reference, sum(normalized_score_partial) as score from ( + select + a.query, + a.bug, + a.reference, + sum(a.score / a.length ) / (total_score_for_query) as normalized_score_partial + from + alignment a join ( + select query, sum(score) as total_score_for_query + from alignment group by query + ) as m + where a.query = m.query + group by a.bug, a.reference, a.query + ) group by bug, reference + order by bug, reference + ''' + result={} + resultAll={} + for bug, gene, score in self.query(query): + if bug not in result: + result[bug]={} + result[bug][gene]=score + if gene not in resultAll: + resultAll[gene]=0 + resultAll[gene]+=score + + # Add to the store + for bug in result: + gene_scores_store.add(result[bug], bug) + gene_scores_store.add(resultAll, "all") + + # Log a summary, and print if in verbose mode + message="\n".join(["{0} : {1} gene families".format(bug, len(result[bug])) for bug in result]) + message="Total gene families : " +str(sum([len(result[bug]) for bug in result]))+"\n"+message if config.verbose: print(message) logger.info("\n"+message) - - def clear(self): - """ - Clear all of the stored data - """ - - self.__total_scores_by_query.clear() - self.__multiple_hits_queries.clear() - self.__hits_by_query.clear() - self.__scores_by_bug_gene.clear() - self.__gene_counts.clear() - self.__bug_counts.clear() - class GeneScores: """ Holds scores for all of the genes @@ -1234,10 +1155,17 @@ def get_database(self): config.pathways_database_delimiter.join(self.__pathways_to_reactions[pathway])) return "\n".join(data) -class Reads: +class Reads(SqliteStore): """ Holds all of the reads data to create a 
fasta file """ + def __init__(self, file=None, **kwargs): + super().__init__(**kwargs) + self.connect() + self.do('''create table read ( + id text not null primary key, + sequence text not null + );''') def add(self, id, sequence): """ @@ -1245,11 +1173,7 @@ def add(self, id, sequence): >id sequence """ - - if self.__minimize_memory_use: - self.__ids.add(id) - else: - self.__reads[id]=sequence + self.do('insert or ignore into read (id, sequence) values (?,?)', [id, sequence]) def process_file(self, file): """ @@ -1294,104 +1218,36 @@ def process_file(self, file): # Remove the temp fasta file if exists if temp_file: utilities.remove_file(temp_file) - - def __init__(self, file=None, minimize_memory_use=None): - """ - Create initial data structures and load if file name provided - """ - self.__reads={} - self.__ids=set() - self.__initial_read_count=0 - self.__file=file - - if minimize_memory_use: - self.__minimize_memory_use=True - logger.debug("Initialize Reads class instance to minimize memory use") - else: - self.__minimize_memory_use=False - logger.debug("Initialize Reads class instance to maximize memory use") - if self.__file: - for (id,sequence) in self.process_file(file): - self.add(id, sequence) - self.__initial_read_count+=1 - - def set_file(self, file): - """ - Set the file to read sequences from - """ - - self.__file=file + def add_from_fasta(self, fasta_path): + for (id,sequence) in self.process_file(fasta_path): + self.add(id, sequence) def remove_id(self, id): """ Remove the id and sequence from the read structure """ - if id in self.__reads: - del self.__reads[id] - elif id in self.__ids: - self.__ids.discard(id) - - def get_fasta(self, file=None): + self.do('delete from read where id = ? 
or id = ?', [id, utilities.remove_length_annotation(id)]) + + def get_fasta(self): """ - Return a string of the fasta file sequences stored or read from a file + Return a string of the fasta file sequences stored """ - - if not file: - file=self.__file - - # use the stored reads if present - if self.__reads: - for id, sequence in self.__reads.items(): - yield ">"+id+"\n"+sequence - else: - if file: - for id, sequence in self.process_file(file): - # check for the id or the id without the length annotation - if utilities.remove_length_annotation(id) in self.__ids or id in self.__ids: - yield ">"+id+"\n"+sequence - + for row in self.query('select id, sequence from read'): + yield ">{0}\n{1}".format(*row) + def id_list(self): """ Return a list of all of the fasta ids """ - - if self.__reads: - return list(self.__reads.keys()) - else: - return list(self.__ids) + return [row[0] for row in self.query('select id from read')] def count_reads(self): """ Return the total number of reads stored """ - if self.__reads: - return len(self.__reads.keys()) - else: - return len(self.__ids) - - def clear(self): - """ - Clear all of the stored reads and ids - """ - - self.__reads.clear() - self.__ids.clear() - - def set_initial_read_count(self,total): - """ - Set the total number of reads from the original input file - """ - - self.__initial_read_count=total - - def get_initial_read_count(self): - """ - Get the total number of reads from the original input file - """ - - return self.__initial_read_count + return self.query("select count(*) from read").fetchone()[0] class Names: """ diff --git a/humann/tests/advanced_tests_quantify_families.py b/humann/tests/advanced_tests_quantify_families.py index 824d3504..3e1811ba 100644 --- a/humann/tests/advanced_tests_quantify_families.py +++ b/humann/tests/advanced_tests_quantify_families.py @@ -2,7 +2,7 @@ import re import tempfile import os -import filecmp +import io import logging import cfg @@ -125,8 +125,9 @@ def 
test_gene_families_tsv_output(self): gene_families_file=families.gene_families(alignments,gene_scores,0) # check the gene families output is as expected - self.assertTrue(filecmp.cmp(gene_families_file, - cfg.gene_familes_file, shallow=False)) + self.assertListEqual( + list(io.open(gene_families_file)), + list(io.open(cfg.gene_familes_file))) # reset the mapping file config.gene_family_name_mapping_file=original_gene_family_mapping_file @@ -183,9 +184,10 @@ def test_gene_families_tsv_output_with_names(self): gene_families_file=families.gene_families(alignments,gene_scores,1) # check the gene families output is as expected - self.assertTrue(filecmp.cmp(gene_families_file, - cfg.gene_familes_uniref50_with_names_file, shallow=False)) - + self.assertListEqual( + list(io.open(gene_families_file)), + list(io.open(cfg.gene_familes_uniref50_with_names_file))) + # reset the mapping file config.gene_family_name_mapping_file=original_gene_family_mapping_file diff --git a/humann/tests/advanced_tests_store.py b/humann/tests/advanced_tests_store.py index 69ba2239..9031eca5 100644 --- a/humann/tests/advanced_tests_store.py +++ b/humann/tests/advanced_tests_store.py @@ -99,7 +99,7 @@ def test_Alignments_compute_gene_scores_single_gene_double_query(self): query1_sum=hit1_score+hit2_score gene_score=hit1_score/query1_sum/gene1_length - self.assertEqual(gene_scores_store.get_score("bug1","gene1"),gene_score) + self.assertAlmostEqual(gene_scores_store.get_score("bug1","gene1"),gene_score) def test_Alignments_compute_gene_scores_double_gene_double_query(self): """ @@ -231,257 +231,6 @@ def test_Alignments_id_mapping_half_hits(self): self.assertEqual(sorted(stored_lengths),sorted([1/1000.0,100/1000.0, 200/1000.0,1000/1000.0])) - def test_Alignments_compute_gene_scores_single_gene_single_query_with_temp_alignment_file(self): - """ - Test the compute_gene_scores function - Test one hit for gene with one hit for query - Test with the temp alignment file - """ - - # create a set of hits - 
matches1=41.0 - matches2=57.1 - matches3=61.0 - matches4=72.1 - - gene1_length=2 - gene2_length=3 - gene3_length=4 - - # Create a set of alignments - alignments_store=store.Alignments(minimize_memory_use=True) - alignments_store.add("gene1",gene1_length,"query1",matches1,"bug1") - alignments_store.add("gene2",gene2_length,"query1",matches2,"bug1") - alignments_store.add("gene2",gene2_length,"query2",matches3,"bug1") - alignments_store.add("gene3",gene3_length,"query3",matches4,"bug1") - - gene_scores_store=store.GeneScores() - - # compute gene scores - alignments_store.convert_alignments_to_gene_scores(gene_scores_store) - - # convert lengths to per kb - gene3_length=gene3_length/1000.0 - - # gene3 - hit4_score=math.pow(matches4, config.match_power) - query3_sum=hit4_score - expected_gene_score=hit4_score/query3_sum/gene3_length - - actual_gene_score=gene_scores_store.get_score("bug1","gene3") - - # delete the temp alignment file - alignments_store.delete_temp_alignments_file() - - self.assertEqual(actual_gene_score,expected_gene_score) - - def test_Alignments_compute_gene_scores_single_gene_double_query_with_temp_alignment_file(self): - """ - Test the compute_gene_scores function - Test one hit for gene with more than one hit per query - Test with the temp alignment file - """ - - # create a set of hits - # bug, reference, reference_length, query, matches = hit - - matches1=41.0 - matches2=57.1 - matches3=61.0 - matches4=72.1 - - gene1_length=2 - gene2_length=3 - gene3_length=4 - - # Create a set of alignments - alignments_store=store.Alignments(minimize_memory_use=True) - alignments_store.add("gene1",gene1_length,"query1",matches1,"bug1") - alignments_store.add("gene2",gene2_length,"query1",matches2,"bug1") - alignments_store.add("gene2",gene2_length,"query2",matches3,"bug1") - alignments_store.add("gene3",gene3_length,"query3",matches4,"bug1") - - gene_scores_store=store.GeneScores() - - # compute gene scores - 
alignments_store.convert_alignments_to_gene_scores(gene_scores_store) - - # convert lengths to per kb - gene1_length=gene1_length/1000.0 - - # gene1 - hit1_score=math.pow(matches1, config.match_power) - hit2_score=math.pow(matches2, config.match_power) - query1_sum=hit1_score+hit2_score - expected_gene_score=hit1_score/query1_sum/gene1_length - - actual_gene_score=gene_scores_store.get_score("bug1","gene1") - - # delete the temp alignment file - alignments_store.delete_temp_alignments_file() - - self.assertAlmostEqual(actual_gene_score,expected_gene_score) - - def test_Alignments_compute_gene_scores_double_gene_double_query_with_temp_alignment_file(self): - """ - Test the compute_gene_scores function - Test two hits to gene with more than one hit per query - Test with the temp alignment file - """ - - # create a set of hits - # bug, reference, reference_length, query, matches = hit - - matches1=41.0 - matches2=57.1 - matches3=61.0 - matches4=72.1 - - gene1_length=2 - gene2_length=3 - gene3_length=4 - - # Create a set of alignments - alignments_store=store.Alignments(minimize_memory_use=True) - alignments_store.add("gene1",gene1_length,"query1",matches1,"bug1") - alignments_store.add("gene2",gene2_length,"query1",matches2,"bug1") - alignments_store.add("gene2",gene2_length,"query2",matches3,"bug1") - alignments_store.add("gene3",gene3_length,"query3",matches4,"bug1") - - gene_scores_store=store.GeneScores() - - # compute gene scores - alignments_store.convert_alignments_to_gene_scores(gene_scores_store) - - # gene1 - hit1_score=math.pow(matches1, config.match_power) - hit2_score=math.pow(matches2, config.match_power) - query1_sum=hit1_score+hit2_score - - # convert lengths to per kb - gene2_length=gene2_length/1000.0 - - # gene2 - hit3_score=math.pow(matches3, config.match_power) - query2_sum=hit3_score - expected_gene_score=hit3_score/query2_sum/gene2_length + hit2_score/query1_sum/gene2_length - - actual_gene_score=gene_scores_store.get_score("bug1","gene2") - - # 
delete the temp alignment file - alignments_store.delete_temp_alignments_file() - - self.assertAlmostEqual(actual_gene_score,expected_gene_score,places=7) - - def test_Alignments_id_mapping_all_gene_list_with_temp_alignment_file(self): - """ - Test the store_id_mapping function - Test the add_annotated and process_reference_annotation with id mapping - Test the genes are mapped correctly - Test with the temp alignment file - """ - - alignments_store=store.Alignments(minimize_memory_use=True) - - # load in the id_mapping file - alignments_store.process_id_mapping(cfg.id_mapping_file) - - # store some alignments - alignments_store.add_annotated("query1",1,"ref1") - alignments_store.add_annotated("query2",1,"ref2") - alignments_store.add_annotated("query3",1,"ref3") - - gene_list=alignments_store.gene_list() - - # delete the temp alignment file - alignments_store.delete_temp_alignments_file() - - # test the genes are correct - self.assertEqual(sorted(gene_list),sorted(["gene1","gene2","gene3"])) - - def test_Alignments_id_mapping_all_bug_list_with_temp_alignment_file(self): - """ - Test the store_id_mapping function - Test the add_annotated and process_reference_annotation with id mapping - Test the bugs are mapped correctly - Test with the temp alignment file - """ - - alignments_store=store.Alignments(minimize_memory_use=True) - - # load in the id_mapping file - alignments_store.process_id_mapping(cfg.id_mapping_file) - - # store some alignments - alignments_store.add_annotated("query1",1,"ref1") - alignments_store.add_annotated("query2",1,"ref2") - alignments_store.add_annotated("query3",1,"ref3") - - bug_list=alignments_store.bug_list() - - # delete the temp alignment file - alignments_store.delete_temp_alignments_file() - - # test the bugs are correct - self.assertEqual(sorted(bug_list),sorted(["bug3","unclassified"])) - - def test_Alignments_id_mapping_all_hits_with_temp_alignment_file(self): - """ - Test the store_id_mapping function - Test the add_annotated 
and process_reference_annotation with id mapping - Test the lengths are mapped correctly - Test with the temp alignment file - """ - - alignments_store=store.Alignments(minimize_memory_use=True) - - # load in the id_mapping file - alignments_store.process_id_mapping(cfg.id_mapping_file) - - # store some alignments - alignments_store.add_annotated("query1",1,"ref1") - alignments_store.add_annotated("query2",1,"ref2") - alignments_store.add_annotated("query3",1,"ref3") - - hit_list=alignments_store.get_hit_list() - - # delete the temp alignment file - alignments_store.delete_temp_alignments_file() - - # test the lengths are correct - stored_lengths=[item[-1] for item in hit_list] - self.assertEqual(sorted(stored_lengths),sorted([1/1000.0,10/1000.0,1000/1000.0])) - - def test_Alignments_id_mapping_half_hits_with_temp_alignment_file(self): - """ - Test the store_id_mapping function - Test the add_annotated and process_reference_annotation with id mapping - Test the lengths are mapped correctly with only some references included - in those provided for id mapping - Test with the temp alignment file - """ - - alignments_store=store.Alignments(minimize_memory_use=True) - - # load in the id_mapping file - alignments_store.process_id_mapping(cfg.id_mapping_file) - - # store some alignments - alignments_store.add_annotated("query1",1,"ref1") - alignments_store.add_annotated("query2",1,"ref2") - alignments_store.add_annotated("query3",1,"ref1|100") - alignments_store.add_annotated("query3",1,"200|ref2") - - hit_list=alignments_store.get_hit_list() - - # delete the temp alignment file - alignments_store.delete_temp_alignments_file() - - # test the lengths are correct - stored_lengths=[item[-1] for item in hit_list] - self.assertEqual(sorted(stored_lengths),sorted([1/1000.0,100/1000.0, - 200/1000.0,1000/1000.0])) - def test_GeneScores_add_from_file_id_mapping_bug_list(self): """ GeneScores class: Test add_from_file bug list with id mapping diff --git 
a/humann/tests/basic_tests_store.py b/humann/tests/basic_tests_store.py index 31ab4ad3..3a177d6a 100644 --- a/humann/tests/basic_tests_store.py +++ b/humann/tests/basic_tests_store.py @@ -57,7 +57,8 @@ def test_Read_print_fasta_id_count(self): Test the total number of expected ids are loaded """ - reads_store=store.Reads(cfg.small_fasta_file) + reads_store=store.Reads() + reads_store.add_from_fasta(cfg.small_fasta_file) # Check that the total number of expected reads are loaded self.assertEqual(len(reads_store.id_list()), cfg.small_fasta_file_total_sequences) @@ -68,7 +69,8 @@ def test_Read_print_fasta_count_reads(self): Test the total number of expected reads counted """ - reads_store=store.Reads(cfg.small_fasta_file) + reads_store=store.Reads() + reads_store.add_from_fasta(cfg.small_fasta_file) # Check that the total number of expected reads are counted self.assertEqual(reads_store.count_reads(), cfg.small_fasta_file_total_sequences) @@ -80,7 +82,8 @@ def test_Read_print_fasta_count_reads_minimize_memory_use(self): Test with minimize memory use """ - reads_store=store.Reads(cfg.small_fasta_file, minimize_memory_use=True) + reads_store=store.Reads(minimize_memory_use=True) + reads_store.add_from_fasta(cfg.small_fasta_file) # Check that the total number of expected reads are counted self.assertEqual(reads_store.count_reads(), cfg.small_fasta_file_total_sequences) @@ -92,7 +95,8 @@ def test_Read_print_fasta_id_list(self): Test the expected ids are loaded """ - reads_store=store.Reads(cfg.small_fasta_file) + reads_store=store.Reads() + reads_store.add_from_fasta(cfg.small_fasta_file) # Check the reads are printed correctly stored_fasta=[] @@ -139,7 +143,8 @@ def test_Read_print_fasta_id_list_minimize_memory_use(self): Test with minimize memory use """ - reads_store=store.Reads(cfg.small_fasta_file, minimize_memory_use=True) + reads_store=store.Reads(minimize_memory_use=True) + reads_store.add_from_fasta(cfg.small_fasta_file) # Check the reads are printed correctly 
stored_fasta=[] @@ -185,7 +190,8 @@ def test_Read_print_fasta_sequence_list(self): Test the sequences are loaded """ - reads_store=store.Reads(cfg.small_fasta_file) + reads_store=store.Reads() + reads_store.add_from_fasta(cfg.small_fasta_file) # Check the reads are printed correctly stored_fasta=[] @@ -229,7 +235,8 @@ def test_Read_print_fasta_sequence_list_minimize_memory_use(self): Test with minimize memory use """ - reads_store=store.Reads(cfg.small_fasta_file, minimize_memory_use=True) + reads_store=store.Reads(minimize_memory_use=True) + reads_store.add_from_fasta(cfg.small_fasta_file) # Check the reads are printed correctly stored_fasta=[] @@ -271,7 +278,8 @@ def test_Read_delete_id(self): Read class: Test the deleting of ids """ - reads_store=store.Reads(cfg.small_fasta_file) + reads_store=store.Reads() + reads_store.add_from_fasta(cfg.small_fasta_file) # delete all but one of the reads and check structure is empty id_list=reads_store.id_list() @@ -288,7 +296,8 @@ def test_Read_delete_id_minimize_memory_use(self): Test with minimial memory use """ - reads_store=store.Reads(cfg.small_fasta_file, minimize_memory_use=True) + reads_store=store.Reads(minimize_memory_use=True) + reads_store.add_from_fasta(cfg.small_fasta_file) # delete all but one of the reads and check structure is empty id_list=reads_store.id_list() @@ -778,7 +787,7 @@ def test_PathwaysAndReactions_max_median_score_even_number_vary_pathways(self): # test median score for an even number of values self.assertEqual(pathways_and_reactions.max_median_score("bug"),2.5) - + def test_Alignments_add_bug_count(self): """ Alignments class: Test add function @@ -881,52 +890,6 @@ def test_Alignments_add_gene_lengths_with_read_length_normalization(self): # test the lengths are correct stored_lengths=[item[-1] for item in alignments_store.get_hit_list()] self.assertEqual(sorted(stored_lengths),sorted([1/1000.0,91/1000.0,901/1000.0,901/1000.0])) - - def 
test_Alignments_add_gene_lengths_with_temp_alignment_file(self): - """ - Alignments class: Test add function - Test the gene lengths - Test using the temp alignment file instead of storing data in memory - """ - - alignments_store=store.Alignments(minimize_memory_use=True) - - alignments_store.add("gene2", 10, "Q3", 0.01, "bug1",1) - alignments_store.add("gene1", 100, "Q1", 0.01, "bug2",1) - alignments_store.add("gene3", 1000, "Q2", 0.01, "bug3",1) - alignments_store.add("gene1", 0, "Q1", 0.01, "bug1",1) - - # test the lengths are correct - stored_lengths=[item[-1] for item in alignments_store.get_hit_list()] - - # delete the temp alignment file - alignments_store.delete_temp_alignments_file() - - self.assertEqual(sorted(stored_lengths),sorted([10/1000.0,100/1000.0,1000/1000.0,1000/1000.0])) - - def test_Alignments_add_gene_lengths_with_read_length_normalization(self): - """ - Alignments class: Test add function - Test the gene lengths with read length normalization - """ - - alignments_store=store.Alignments(minimize_memory_use=True) - - # set the average read length - average_read_length=100 - - alignments_store.add("gene2", 10, "Q3", 0.01, "bug1",average_read_length) - alignments_store.add("gene1", 100, "Q1", 0.01, "bug2",average_read_length) - alignments_store.add("gene3", 1000, "Q2", 0.01, "bug3",average_read_length) - alignments_store.add("gene1", 0, "Q1", 0.01, "bug1",average_read_length) - - # test the lengths are correct - stored_lengths=[item[-1] for item in alignments_store.get_hit_list()] - - # delete the temp alignment file - alignments_store.delete_temp_alignments_file() - - self.assertEqual(sorted(stored_lengths),sorted([1/1000.0,91/1000.0,901/1000.0,901/1000.0])) def test_Alignments_process_chocophlan_length(self): """ @@ -1168,3 +1131,48 @@ def test_GeneScores_add_from_file_scores(self): self.assertDictEqual(cfg.genetable_file_bug_scores[bug],gene_scores.scores_for_bug(bug)) + def test_Alignments_clear(self): + """ + Alignments class: test clear 
function + Test a cleared used store is like a cleared new store + """ + alignments_store=store.Alignments() + alignments_store.add("gene2", 1, "Q3", 0.01, "bug1",1) + + second_alignments_store=store.Alignments() + + self.assertNotEqual(vars(alignments_store), vars(second_alignments_store)) + + alignments_store.clear() + second_alignments_store.clear() + + self.assertEqual(vars(alignments_store), vars(second_alignments_store)) + + def test_Alignments_bulk(self): + with self.subTest(): + self._test_Alignments_bulk(minimize_memory_use = True) + self._test_Alignments_bulk(minimize_memory_use = False) + + def _test_Alignments_bulk(self, **kwargs): + """ + Alignments class: test transaction management + Try select within transaction, rollback, begin and end, and end without begin + """ + with self.subTest("Checking transactions", **kwargs): + alignments_store=store.Alignments(**kwargs) + alignments_store.add("gene1", 1, "Q1", 0.01, "bug1",1) + hits=alignments_store.get_hit_list() + + alignments_store.start_bulk_write() + alignments_store.add("gene2", 2, "Q2", 0.02, "bug2",2) + self.assertNotEqual(alignments_store.get_hit_list(), hits) + alignments_store.do("rollback transaction") + self.assertEqual(alignments_store.get_hit_list(), hits) + + second_alignments_store=store.Alignments(**kwargs) + second_alignments_store.start_bulk_write() + second_alignments_store.add("gene1", 1, "Q1", 0.01, "bug1",1) + second_alignments_store.end_bulk_write() + self.assertEqual(second_alignments_store.get_hit_list(), hits) + + self.assertRaises(BaseException, alignments_store.end_bulk_write) diff --git a/humann/utilities.py b/humann/utilities.py index 5e886137..36ff79f1 100644 --- a/humann/utilities.py +++ b/humann/utilities.py @@ -258,30 +258,12 @@ def gunzip_file(gzip_file): return new_file -def double_sort(pathways_dictionary): +def double_sort(d): """ - Return the keys to a dictionary sorted with top values first - then for duplicate values sorted alphabetically by key + Return the keys to a
dictionary of floats + sorted by decreasing rounded value, then alphabetically by key """ - - sorted_keys=[] - prior_value="" - store=[] - for pathway in sorted(pathways_dictionary, key=pathways_dictionary.get, reverse=True): - if prior_value == pathways_dictionary[pathway]: - if not store: - store.append(sorted_keys.pop()) - store.append(pathway) - else: - if store: - sorted_keys+=sorted(store) - store=[] - prior_value=pathways_dictionary[pathway] - sorted_keys.append(pathway) - - if store: - sorted_keys+=sorted(store) - return sorted_keys + return [t[1] for t in sorted([(-round(d[k],config.output_max_decimals), k) for k in d])] def unnamed_temp_file(prefix=None): """ @@ -872,13 +854,7 @@ def estimate_unaligned_reads_stored(input_fastq, unaligned_store): Calculate an estimate of the percent of reads unaligned and stored """ - # check if the total number of reads from the input file is stored - if not unaligned_store.get_initial_read_count(): - # check files exist and are readable - file_exists_readable(input_fastq) - unaligned_store.set_initial_read_count(count_reads(input_fastq)) - - percent=unaligned_store.count_reads()/float(unaligned_store.get_initial_read_count()) * 100 + percent=unaligned_store.count_reads()/float(count_reads(input_fastq)) * 100 return format_float_to_string(percent)