diff --git a/src/protein_quest/alphafold/confidence.py b/src/protein_quest/alphafold/confidence.py index ffaa1b5..d4d8892 100644 --- a/src/protein_quest/alphafold/confidence.py +++ b/src/protein_quest/alphafold/confidence.py @@ -10,7 +10,7 @@ from protein_quest.converter import Percentage, PositiveInt, converter from protein_quest.io import read_structure, write_structure from protein_quest.ss import nr_of_residues_in_total -from protein_quest.utils import CopyMethod, copyfile +from protein_quest.utils import Cacher, PassthroughCacher, copyfile """ Methods to filter AlphaFoldDB structures on confidence scores. @@ -113,7 +113,7 @@ class ConfidenceFilterResult: def filter_file_on_residues( - file: Path, query: ConfidenceFilterQuery, filtered_dir: Path, copy_method: CopyMethod = "copy" + file: Path, query: ConfidenceFilterQuery, filtered_dir: Path, cacher: Cacher | None = None ) -> ConfidenceFilterResult: """Filter a single AlphaFoldDB structure file (*.pdb[.gz], *.cif[.gz]) based on confidence. @@ -121,12 +121,15 @@ def filter_file_on_residues( file: The path to the PDB file to filter. query: The confidence filter query. filtered_dir: The directory to save the filtered PDB file. - copy_method: How to copy when no residues have to be removed. + cacher: An optional cacher to use for caching written files. + And copy method when no residues have to be removed. Returns: result with filtered_file property set to Path where filtered PDB file is saved. or None if structure was filtered out. 
""" + if cacher is None: + cacher = PassthroughCacher() structure = read_structure(file) residues = set(find_high_confidence_residues(structure, query.confidence)) count = len(residues) @@ -138,16 +141,16 @@ def filter_file_on_residues( count=count, ) total_residues = nr_of_residues_in_total(structure) - filtered_file = filtered_dir / file.name + filtered_file = filtered_dir / f"c{query.confidence}_{file.name}" if count == total_residues: # if no residues have to be removed then copy instead of slower gemmi writing - copyfile(file, filtered_file, copy_method) + copyfile(file, filtered_file, cacher.copy_method) else: new_structure = filter_out_low_confidence_residues( structure, residues, ) - write_structure(new_structure, filtered_file) + write_structure(new_structure, filtered_file, cacher) return ConfidenceFilterResult( input_file=file.name, count=count, @@ -159,7 +162,7 @@ def filter_files_on_confidence( alphafold_pdb_files: list[Path], query: ConfidenceFilterQuery, filtered_dir: Path, - copy_method: CopyMethod = "copy", + cacher: Cacher | None = None, ) -> Generator[ConfidenceFilterResult]: """Filter AlphaFoldDB structures based on confidence. @@ -167,7 +170,8 @@ def filter_files_on_confidence( alphafold_pdb_files: List of mmcif/PDB files from AlphaFoldDB to filter. query: The confidence filter query containing the confidence thresholds. filtered_dir: Directory where the filtered mmcif/PDB files will be saved. - copy_method: How to copy when a direct copy is possible. + cacher: An optional cacher to use for caching downloaded files. + And copy method when no residues have to be removed. 
Yields: For each mmcif/PDB files yields whether it was filtered or not, @@ -177,4 +181,4 @@ def filter_files_on_confidence( # In ../filter.py:filter_files_on_residues() we filter on number of residues on a file level # here we filter on file level and inside file remove low confidence residues for pdb_file in alphafold_pdb_files: - yield filter_file_on_residues(pdb_file, query, filtered_dir, copy_method) + yield filter_file_on_residues(pdb_file, query, filtered_dir, cacher) diff --git a/src/protein_quest/cli.py b/src/protein_quest/cli.py index a6ceefb..ad704cf 100644 --- a/src/protein_quest/cli.py +++ b/src/protein_quest/cli.py @@ -446,7 +446,7 @@ def _add_filter_confidence_parser(subparsers: argparse._SubParsersAction): In CSV format with `,,,` columns. Use `-` for stdout."""), ) - _add_copy_method_arguments(parser) + _add_cacher_arguments(parser) def _add_filter_chain_parser(subparsers: argparse._SubParsersAction): @@ -486,7 +486,7 @@ def _add_filter_chain_parser(subparsers: argparse._SubParsersAction): If not provided, will create a local cluster. 
If set to `sequential` will run tasks sequentially."""), ) - _add_copy_method_arguments(parser) + _add_cacher_arguments(parser) def _add_filter_residue_parser(subparsers: argparse._SubParsersAction): @@ -864,7 +864,7 @@ def _handle_filter_confidence(args: argparse.Namespace): min_residues = args.min_residues max_residues = args.max_residues stats_file: TextIOWrapper | None = args.write_stats - copy_method: CopyMethod = structure(args.copy_method, CopyMethod) # pyright: ignore[reportArgumentType] + cacher = _initialize_cacher(args) output_dir.mkdir(parents=True, exist_ok=True) input_files = sorted(glob_structure_files(input_dir)) @@ -884,7 +884,7 @@ def _handle_filter_confidence(args: argparse.Namespace): passed_count = 0 for r in tqdm( - filter_files_on_confidence(input_files, query, output_dir, copy_method=copy_method), + filter_files_on_confidence(input_files, query, output_dir, cacher=cacher), total=len(input_files), unit="file", ): @@ -903,7 +903,7 @@ def _handle_filter_chain(args): output_dir = structure(args.output_dir, Path) pdb_id2chain_mapping_file = args.chains scheduler_address = structure(args.scheduler_address, str | None) # pyright: ignore[reportArgumentType] - copy_method: CopyMethod = structure(args.copy_method, CopyMethod) # pyright: ignore[reportArgumentType] + cacher = _initialize_cacher(args) # make sure files in input dir with entries in mapping file are the same # complain when files from mapping file are missing on disk @@ -928,9 +928,7 @@ def _handle_filter_chain(args): rprint("[red]No valid structure files found. 
Exiting.") sys.exit(1) - results = filter_files_on_chain( - file2chain, output_dir, scheduler_address=scheduler_address, copy_method=copy_method - ) + results = filter_files_on_chain(file2chain, output_dir, scheduler_address=scheduler_address, cacher=cacher) nr_written = len([r for r in results if r.passed]) diff --git a/src/protein_quest/filters.py b/src/protein_quest/filters.py index 431f6f8..761cf37 100644 --- a/src/protein_quest/filters.py +++ b/src/protein_quest/filters.py @@ -12,7 +12,7 @@ from protein_quest.parallel import configure_dask_scheduler, dask_map_with_progress from protein_quest.structure import nr_residues_in_chain, write_single_chain_structure_file -from protein_quest.utils import CopyMethod, copyfile +from protein_quest.utils import Cacher, CopyMethod, copyfile logger = logging.getLogger(__name__) @@ -30,13 +30,13 @@ def filter_file_on_chain( file_and_chain: tuple[Path, str], output_dir: Path, out_chain: str = "A", - copy_method: CopyMethod = "copy", + cacher: Cacher | None = None, ) -> ChainFilterStatistics: input_file, chain_id = file_and_chain logger.debug("Filtering %s on chain %s", input_file, chain_id) try: output_file = write_single_chain_structure_file( - input_file, chain_id, output_dir, out_chain=out_chain, copy_method=copy_method + input_file, chain_id, output_dir, out_chain=out_chain, cacher=cacher ) return ChainFilterStatistics( input_file=input_file, @@ -53,7 +53,7 @@ def filter_files_on_chain( output_dir: Path, out_chain: str = "A", scheduler_address: str | Cluster | Literal["sequential"] | None = None, - copy_method: CopyMethod = "copy", + cacher: Cacher | None = None, ) -> list[ChainFilterStatistics]: """Filter mmcif/PDB files by chain. @@ -65,7 +65,8 @@ def filter_files_on_chain( scheduler_address: The address of the Dask scheduler. If not provided, will create a local cluster. If set to `sequential` will run tasks sequentially. - copy_method: How to copy when a direct copy is possible. 
+ cacher: An optional cacher to use for caching written files. + And copy method when file is copied without changes. Returns: Result of the filtering process. @@ -74,7 +75,7 @@ def filter_files_on_chain( if scheduler_address == "sequential": def task(file_and_chain: tuple[Path, str]) -> ChainFilterStatistics: - return filter_file_on_chain(file_and_chain, output_dir, out_chain=out_chain, copy_method=copy_method) + return filter_file_on_chain(file_and_chain, output_dir, out_chain=out_chain, cacher=cacher) return list(map(task, file2chains)) @@ -93,7 +94,7 @@ def task(file_and_chain: tuple[Path, str]) -> ChainFilterStatistics: file2chains, output_dir=output_dir, out_chain=out_chain, - copy_method=copy_method, + cacher=cacher, ) diff --git a/src/protein_quest/io.py b/src/protein_quest/io.py index 4367c01..d59bbda 100644 --- a/src/protein_quest/io.py +++ b/src/protein_quest/io.py @@ -1,5 +1,6 @@ """Module for structure file input/output.""" +import asyncio import gzip import logging import shutil @@ -17,7 +18,7 @@ from mmcif.io.PdbxReader import PdbxReader from mmcif.io.PdbxWriter import PdbxWriter -from protein_quest.utils import CopyMethod, copyfile, user_cache_root_dir +from protein_quest.utils import Cacher, CopyMethod, PassthroughCacher, copyfile, user_cache_root_dir logger = logging.getLogger(__name__) @@ -33,7 +34,7 @@ """Set of valid structure file extensions.""" -def write_structure(structure: gemmi.Structure, path: Path): +def write_structure(structure: gemmi.Structure, path: Path, cacher: Cacher | None = None): """Write a gemmi structure to a file. Args: @@ -42,31 +43,37 @@ def write_structure(structure: gemmi.Structure, path: Path): The format depends on the file extension. See [StructureFileExtensions][protein_quest.io.StructureFileExtensions] for supported extensions. + cacher: An optional cacher to use for caching written files. Raises: ValueError: If the file extension is not supported. 
""" + if cacher is None: + cacher = PassthroughCacher() if path.name.endswith(".pdb") or path.name.endswith(".ent"): body: str = structure.make_pdb_string() - path.write_text(body) + asyncio.run(cacher.write_bytes(path, body.encode("utf-8"))) elif path.name.endswith(".pdb.gz") or path.name.endswith(".ent.gz"): body: str = structure.make_pdb_string() - with gzip.open(path, "wt") as f: - f.write(body) + gzbody: bytes = gzip.compress(body.encode("utf-8")) + asyncio.run(cacher.write_bytes(path, gzbody)) elif path.name.endswith(".cif"): # do not write chem_comp so it is viewable by molstar # see https://github.com/project-gemmi/gemmi/discussions/362 doc = structure.make_mmcif_document(gemmi.MmcifOutputGroups(True, chem_comp=False)) - doc.write_file(str(path)) + body = doc.as_string() + asyncio.run(cacher.write_bytes(path, body.encode("utf-8"))) elif path.name.endswith(".cif.gz"): doc = structure.make_mmcif_document(gemmi.MmcifOutputGroups(True, chem_comp=False)) cif_str = doc.as_string() - with gzip.open(path, "wt") as f: - f.write(cif_str) + gzbody: bytes = gzip.compress(cif_str.encode("utf-8")) + asyncio.run(cacher.write_bytes(path, gzbody)) elif path.name.endswith(".bcif"): structure2bcif(structure, path) + # TODO write to cacher elif path.name.endswith(".bcif.gz"): structure2bcifgz(structure, path) + # TODO write to cacher else: msg = f"Unsupported file extension in {path.name}. 
Supported extensions are: {valid_structure_file_extensions}" raise ValueError(msg) diff --git a/src/protein_quest/structure.py b/src/protein_quest/structure.py index 31625a6..fe3c7eb 100644 --- a/src/protein_quest/structure.py +++ b/src/protein_quest/structure.py @@ -9,7 +9,7 @@ from protein_quest.__version__ import __version__ from protein_quest.io import read_structure, split_name_and_extension, write_structure -from protein_quest.utils import CopyMethod, copyfile +from protein_quest.utils import Cacher, PassthroughCacher, copyfile logger = logging.getLogger(__name__) @@ -134,7 +134,7 @@ def write_single_chain_structure_file( chain2keep: str, output_dir: Path, out_chain: str = "A", - copy_method: CopyMethod = "copy", + cacher: Cacher | None = None, ) -> Path: """Write a single chain from a structure file to a new structure file. @@ -157,7 +157,8 @@ def write_single_chain_structure_file( chain2keep: The chain to keep. output_dir: Directory to save the output file. out_chain: The chain identifier for the output file. - copy_method: How to copy when no changes are needed to output file. + cacher: An optional cacher to use for caching written files. + Also how to copy when no changes are needed to output file Returns: Path to the output structure file @@ -166,6 +167,8 @@ def write_single_chain_structure_file( FileNotFoundError: If the input file does not exist. ChainNotFoundError: If the specified chain is not found in the input file. 
""" + if cacher is None: + cacher = PassthroughCacher() logger.debug(f"chain2keep: {chain2keep}, out_chain: {out_chain}") structure = read_structure(input_file) @@ -191,7 +194,7 @@ def write_single_chain_structure_file( out_chain, output_file, ) - copyfile(input_file, output_file, copy_method) + copyfile(input_file, output_file, cacher.copy_method) return output_file gemmi.Selection(chain_name).remove_not_selected(structure) @@ -203,6 +206,6 @@ def write_single_chain_structure_file( _dedup_sheets(structure, out_chain) _add_provenance_info(structure, chain_name, out_chain) - write_structure(structure, output_file) + write_structure(structure, output_file, cacher) return output_file diff --git a/src/protein_quest/utils.py b/src/protein_quest/utils.py index 1dcc884..de31933 100644 --- a/src/protein_quest/utils.py +++ b/src/protein_quest/utils.py @@ -57,6 +57,11 @@ def _cache_sub_dir(root_cache_dir: Path, filename: str, hash_length: int = 4) -> class Cacher(Protocol): """Protocol for a cacher.""" + @property + def copy_method(self) -> CopyMethod: + """The copy method used by the cacher.""" + ... + def __contains__(self, item: str | Path) -> bool: """Check if a file is in the cache. @@ -121,6 +126,11 @@ class PassthroughCacher(Cacher): On writes it just writes to the target path. """ + @property + def copy_method(self) -> CopyMethod: + """The copy method used by the cacher.""" + return "copy" + def __contains__(self, item: str | Path) -> bool: # We don't have anything cached ever return False