Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 13 additions & 9 deletions src/protein_quest/alphafold/confidence.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from protein_quest.converter import Percentage, PositiveInt, converter
from protein_quest.io import read_structure, write_structure
from protein_quest.ss import nr_of_residues_in_total
from protein_quest.utils import CopyMethod, copyfile
from protein_quest.utils import Cacher, PassthroughCacher, copyfile

"""
Methods to filter AlphaFoldDB structures on confidence scores.
Expand Down Expand Up @@ -113,20 +113,23 @@ class ConfidenceFilterResult:


def filter_file_on_residues(
file: Path, query: ConfidenceFilterQuery, filtered_dir: Path, copy_method: CopyMethod = "copy"
file: Path, query: ConfidenceFilterQuery, filtered_dir: Path, cacher: Cacher | None = None
) -> ConfidenceFilterResult:
"""Filter a single AlphaFoldDB structure file (*.pdb[.gz], *.cif[.gz]) based on confidence.

Args:
file: The path to the PDB file to filter.
query: The confidence filter query.
filtered_dir: The directory to save the filtered PDB file.
copy_method: How to copy when no residues have to be removed.
cacher: An optional cacher to use for caching written files.
Also determines the copy method used when no residues have to be removed.

Returns:
result with the filtered_file property set to the Path where the filtered PDB file is saved,
or None if the structure was filtered out.
"""
if cacher is None:
cacher = PassthroughCacher()
structure = read_structure(file)
residues = set(find_high_confidence_residues(structure, query.confidence))
count = len(residues)
Expand All @@ -138,16 +141,16 @@ def filter_file_on_residues(
count=count,
)
total_residues = nr_of_residues_in_total(structure)
filtered_file = filtered_dir / file.name
filtered_file = filtered_dir / f"c{query.confidence}_{file.name}"
if count == total_residues:
# if no residues have to be removed then copy instead of slower gemmi writing
copyfile(file, filtered_file, copy_method)
copyfile(file, filtered_file, cacher.copy_method)
else:
new_structure = filter_out_low_confidence_residues(
structure,
residues,
)
write_structure(new_structure, filtered_file)
write_structure(new_structure, filtered_file, cacher)
return ConfidenceFilterResult(
input_file=file.name,
count=count,
Expand All @@ -159,15 +162,16 @@ def filter_files_on_confidence(
alphafold_pdb_files: list[Path],
query: ConfidenceFilterQuery,
filtered_dir: Path,
copy_method: CopyMethod = "copy",
cacher: Cacher | None = None,
) -> Generator[ConfidenceFilterResult]:
"""Filter AlphaFoldDB structures based on confidence.

Args:
alphafold_pdb_files: List of mmcif/PDB files from AlphaFoldDB to filter.
query: The confidence filter query containing the confidence thresholds.
filtered_dir: Directory where the filtered mmcif/PDB files will be saved.
copy_method: How to copy when a direct copy is possible.
cacher: An optional cacher to use for caching written files.
Also determines the copy method used when no residues have to be removed.

Yields:
For each mmcif/PDB files yields whether it was filtered or not,
Expand All @@ -177,4 +181,4 @@ def filter_files_on_confidence(
# In ../filter.py:filter_files_on_residues() we filter on number of residues on a file level
# here we filter on file level and inside file remove low confidence residues
for pdb_file in alphafold_pdb_files:
yield filter_file_on_residues(pdb_file, query, filtered_dir, copy_method)
yield filter_file_on_residues(pdb_file, query, filtered_dir, cacher)
14 changes: 6 additions & 8 deletions src/protein_quest/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def _add_filter_confidence_parser(subparsers: argparse._SubParsersAction):
In CSV format with `<input_file>,<residue_count>,<passed>,<output_file>` columns.
Use `-` for stdout."""),
)
_add_copy_method_arguments(parser)
_add_cacher_arguments(parser)


def _add_filter_chain_parser(subparsers: argparse._SubParsersAction):
Expand Down Expand Up @@ -486,7 +486,7 @@ def _add_filter_chain_parser(subparsers: argparse._SubParsersAction):
If not provided, will create a local cluster.
If set to `sequential` will run tasks sequentially."""),
)
_add_copy_method_arguments(parser)
_add_cacher_arguments(parser)


def _add_filter_residue_parser(subparsers: argparse._SubParsersAction):
Expand Down Expand Up @@ -864,7 +864,7 @@ def _handle_filter_confidence(args: argparse.Namespace):
min_residues = args.min_residues
max_residues = args.max_residues
stats_file: TextIOWrapper | None = args.write_stats
copy_method: CopyMethod = structure(args.copy_method, CopyMethod) # pyright: ignore[reportArgumentType]
cacher = _initialize_cacher(args)

output_dir.mkdir(parents=True, exist_ok=True)
input_files = sorted(glob_structure_files(input_dir))
Expand All @@ -884,7 +884,7 @@ def _handle_filter_confidence(args: argparse.Namespace):

passed_count = 0
for r in tqdm(
filter_files_on_confidence(input_files, query, output_dir, copy_method=copy_method),
filter_files_on_confidence(input_files, query, output_dir, cacher=cacher),
total=len(input_files),
unit="file",
):
Expand All @@ -903,7 +903,7 @@ def _handle_filter_chain(args):
output_dir = structure(args.output_dir, Path)
pdb_id2chain_mapping_file = args.chains
scheduler_address = structure(args.scheduler_address, str | None) # pyright: ignore[reportArgumentType]
copy_method: CopyMethod = structure(args.copy_method, CopyMethod) # pyright: ignore[reportArgumentType]
cacher = _initialize_cacher(args)

# make sure files in input dir with entries in mapping file are the same
# complain when files from mapping file are missing on disk
Expand All @@ -928,9 +928,7 @@ def _handle_filter_chain(args):
rprint("[red]No valid structure files found. Exiting.")
sys.exit(1)

results = filter_files_on_chain(
file2chain, output_dir, scheduler_address=scheduler_address, copy_method=copy_method
)
results = filter_files_on_chain(file2chain, output_dir, scheduler_address=scheduler_address, cacher=cacher)

nr_written = len([r for r in results if r.passed])

Expand Down
15 changes: 8 additions & 7 deletions src/protein_quest/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from protein_quest.parallel import configure_dask_scheduler, dask_map_with_progress
from protein_quest.structure import nr_residues_in_chain, write_single_chain_structure_file
from protein_quest.utils import CopyMethod, copyfile
from protein_quest.utils import Cacher, CopyMethod, copyfile

logger = logging.getLogger(__name__)

Expand All @@ -30,13 +30,13 @@ def filter_file_on_chain(
file_and_chain: tuple[Path, str],
output_dir: Path,
out_chain: str = "A",
copy_method: CopyMethod = "copy",
cacher: Cacher | None = None,
) -> ChainFilterStatistics:
input_file, chain_id = file_and_chain
logger.debug("Filtering %s on chain %s", input_file, chain_id)
try:
output_file = write_single_chain_structure_file(
input_file, chain_id, output_dir, out_chain=out_chain, copy_method=copy_method
input_file, chain_id, output_dir, out_chain=out_chain, cacher=cacher
)
return ChainFilterStatistics(
input_file=input_file,
Expand All @@ -53,7 +53,7 @@ def filter_files_on_chain(
output_dir: Path,
out_chain: str = "A",
scheduler_address: str | Cluster | Literal["sequential"] | None = None,
copy_method: CopyMethod = "copy",
cacher: Cacher | None = None,
) -> list[ChainFilterStatistics]:
"""Filter mmcif/PDB files by chain.

Expand All @@ -65,7 +65,8 @@ def filter_files_on_chain(
scheduler_address: The address of the Dask scheduler.
If not provided, will create a local cluster.
If set to `sequential` will run tasks sequentially.
copy_method: How to copy when a direct copy is possible.
cacher: An optional cacher to use for caching written files.
Also determines the copy method used when the file is copied without changes.

Returns:
Result of the filtering process.
Expand All @@ -74,7 +75,7 @@ def filter_files_on_chain(
if scheduler_address == "sequential":

def task(file_and_chain: tuple[Path, str]) -> ChainFilterStatistics:
return filter_file_on_chain(file_and_chain, output_dir, out_chain=out_chain, copy_method=copy_method)
return filter_file_on_chain(file_and_chain, output_dir, out_chain=out_chain, cacher=cacher)

return list(map(task, file2chains))

Expand All @@ -93,7 +94,7 @@ def task(file_and_chain: tuple[Path, str]) -> ChainFilterStatistics:
file2chains,
output_dir=output_dir,
out_chain=out_chain,
copy_method=copy_method,
cacher=cacher,
)


Expand Down
23 changes: 15 additions & 8 deletions src/protein_quest/io.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Module for structure file input/output."""

import asyncio
import gzip
import logging
import shutil
Expand All @@ -17,7 +18,7 @@
from mmcif.io.PdbxReader import PdbxReader
from mmcif.io.PdbxWriter import PdbxWriter

from protein_quest.utils import CopyMethod, copyfile, user_cache_root_dir
from protein_quest.utils import Cacher, CopyMethod, PassthroughCacher, copyfile, user_cache_root_dir

logger = logging.getLogger(__name__)

Expand All @@ -33,7 +34,7 @@
"""Set of valid structure file extensions."""


def write_structure(structure: gemmi.Structure, path: Path):
def write_structure(structure: gemmi.Structure, path: Path, cacher: Cacher | None = None):
"""Write a gemmi structure to a file.

Args:
Expand All @@ -42,31 +43,37 @@ def write_structure(structure: gemmi.Structure, path: Path):
The format depends on the file extension.
See [StructureFileExtensions][protein_quest.io.StructureFileExtensions]
for supported extensions.
cacher: An optional cacher to use for caching written files.

Raises:
ValueError: If the file extension is not supported.
"""
if cacher is None:
cacher = PassthroughCacher()
if path.name.endswith(".pdb") or path.name.endswith(".ent"):
body: str = structure.make_pdb_string()
path.write_text(body)
asyncio.run(cacher.write_bytes(path, body.encode("utf-8")))
elif path.name.endswith(".pdb.gz") or path.name.endswith(".ent.gz"):
body: str = structure.make_pdb_string()
with gzip.open(path, "wt") as f:
f.write(body)
gzbody: bytes = gzip.compress(body.encode("utf-8"))
asyncio.run(cacher.write_bytes(path, gzbody))
elif path.name.endswith(".cif"):
# do not write chem_comp so it is viewable by molstar
# see https://github.com/project-gemmi/gemmi/discussions/362
doc = structure.make_mmcif_document(gemmi.MmcifOutputGroups(True, chem_comp=False))
doc.write_file(str(path))
body = doc.as_string()
asyncio.run(cacher.write_bytes(path, body.encode("utf-8")))
elif path.name.endswith(".cif.gz"):
doc = structure.make_mmcif_document(gemmi.MmcifOutputGroups(True, chem_comp=False))
cif_str = doc.as_string()
with gzip.open(path, "wt") as f:
f.write(cif_str)
gzbody: bytes = gzip.compress(cif_str.encode("utf-8"))
asyncio.run(cacher.write_bytes(path, gzbody))
elif path.name.endswith(".bcif"):
structure2bcif(structure, path)
# TODO write to cacher
elif path.name.endswith(".bcif.gz"):
structure2bcifgz(structure, path)
# TODO write to cacher
else:
msg = f"Unsupported file extension in {path.name}. Supported extensions are: {valid_structure_file_extensions}"
raise ValueError(msg)
Expand Down
13 changes: 8 additions & 5 deletions src/protein_quest/structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from protein_quest.__version__ import __version__
from protein_quest.io import read_structure, split_name_and_extension, write_structure
from protein_quest.utils import CopyMethod, copyfile
from protein_quest.utils import Cacher, PassthroughCacher, copyfile

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -134,7 +134,7 @@ def write_single_chain_structure_file(
chain2keep: str,
output_dir: Path,
out_chain: str = "A",
copy_method: CopyMethod = "copy",
cacher: Cacher | None = None,
) -> Path:
"""Write a single chain from a structure file to a new structure file.

Expand All @@ -157,7 +157,8 @@ def write_single_chain_structure_file(
chain2keep: The chain to keep.
output_dir: Directory to save the output file.
out_chain: The chain identifier for the output file.
copy_method: How to copy when no changes are needed to output file.
cacher: An optional cacher to use for caching written files.
Also determines the copy method used when no changes are needed to the output file.

Returns:
Path to the output structure file
Expand All @@ -166,6 +167,8 @@ def write_single_chain_structure_file(
FileNotFoundError: If the input file does not exist.
ChainNotFoundError: If the specified chain is not found in the input file.
"""
if cacher is None:
cacher = PassthroughCacher()

logger.debug(f"chain2keep: {chain2keep}, out_chain: {out_chain}")
structure = read_structure(input_file)
Expand All @@ -191,7 +194,7 @@ def write_single_chain_structure_file(
out_chain,
output_file,
)
copyfile(input_file, output_file, copy_method)
copyfile(input_file, output_file, cacher.copy_method)
return output_file

gemmi.Selection(chain_name).remove_not_selected(structure)
Expand All @@ -203,6 +206,6 @@ def write_single_chain_structure_file(
_dedup_sheets(structure, out_chain)
_add_provenance_info(structure, chain_name, out_chain)

write_structure(structure, output_file)
write_structure(structure, output_file, cacher)

return output_file
10 changes: 10 additions & 0 deletions src/protein_quest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ def _cache_sub_dir(root_cache_dir: Path, filename: str, hash_length: int = 4) ->
class Cacher(Protocol):
"""Protocol for a cacher."""

@property
def copy_method(self) -> CopyMethod:
"""The copy method used by the cacher."""
...

def __contains__(self, item: str | Path) -> bool:
"""Check if a file is in the cache.

Expand Down Expand Up @@ -121,6 +126,11 @@ class PassthroughCacher(Cacher):
On writes it just writes to the target path.
"""

@property
def copy_method(self) -> CopyMethod:
"""The copy method used by the cacher."""
return "copy"

def __contains__(self, item: str | Path) -> bool:
# We don't have anything cached ever
return False
Expand Down
Loading