Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -90,8 +90,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- [[PR #446](https://github.com/nf-core/proteinfold/pulls/446)] - Fix warnings from Nextflow lint.
- [[PR #451](https://github.com/nf-core/proteinfold/pulls/451)] - Remove af2 multimer padding from msa plots.
- [[#417](https://github.com/nf-core/proteinfold/issues/417)] - Add `boltz_use_kernels` parameter to enable/disable using optimized Triton-based CUDA kernels for Boltz inference.
- [[#417](https://github.com/nf-core/proteinfold/issues/417)] - Handle incompatible CUDA kernel errors in Boltz by automatically retrying with `--no_kernels` false.
- [[PR #454](https://github.com/nf-core/proteinfold/pulls/454)] - Update publishdir patterns for alphafold2 modules
- [[#417](https://github.com/nf-core/proteinfold/issues/417)] - Handle incompatible CUDA kernel errors in Boltz by automatically retrying with `--use_kernels` false.
- [[PR #454](https://github.com/nf-core/proteinfold/pulls/454)] - Update publishdir patterns for alphafold2 modules.
- [[PR #458](https://github.com/nf-core/proteinfold/pulls/458)] - Update publishdir patterns for colabfold module.
- [[#313](https://github.com/nf-core/proteinfold/issues/313)] - Harmonize colabfold metrics extraction with other modes.
- [[#455](https://github.com/nf-core/proteinfold/issues/455)] - Fix colabfold monomer inheriting id from fasta header.
- [[#457](https://github.com/nf-core/proteinfold/issues/457)] - Fix colabfold multimer always downloading model weights.

### Parameters

Expand Down
54 changes: 35 additions & 19 deletions bin/extract_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,6 @@ def read_pkl(name, pkl_files):
if pkl_file.endswith("final_features.pkl"): # HelixFold3 - This one must be first
write_tsv(f"{name}_msa.tsv", format_msa_rows(data["feat"]["msa"]))
elif pkl_file.endswith("features.pkl"): # AlphaFold2.3
#data["msa"][:data["num_alignments"]]
try:
N = data["num_alignments"][0] #monomer
except:
Expand All @@ -164,21 +163,21 @@ def read_pkl(name, pkl_files):
if 'ptm' not in data.keys():
print(f"No pTM/iPTM output in {pkl_file}, it was likely a monomer calculation")
else:
#with open(f"{name}_{model_id}_ptm.tsv", 'w') as f:
# f.write(f"{np.round(data['ptm'],3)}\n")
#with open(f"{name}_{model_id}_iptm.tsv", 'w') as f:
# f.write(f"{np.round(data['iptm'],3)}\n")
ptm_data[f"{model_id}"] = f"{np.round(data['ptm'],3)}\n"
iptm_data[f"{model_id}"] = f"{np.round(data.get('iptm',0.),3)}\n"
ptm_data[model_id] = f"{np.round(data['ptm'],3)}\n"

if 'iptm' in data:
iptm_data[model_id] = f"{np.round(data['iptm'],3)}\n"
if ptm_data:
ptm_rows = [[k, v.strip()] for k, v in ptm_data.items()]
ptm_rows = sorted([[k, v.strip()] for k, v in ptm_data.items()], key=lambda x: x[0])
write_tsv(f"{name}_ptm.tsv", ptm_rows)

if iptm_data:
iptm_rows = [[k, v.strip()] for k, v in iptm_data.items()]
iptm_rows = sorted([[k, v.strip()] for k, v in iptm_data.items()], key=lambda x: x[0])
write_tsv(f"{name}_iptm.tsv", iptm_rows)


def read_paired_a3m(name, a3m_file):
    """Write the MSA held in a single a3m file to ``{name}_msa.tsv``.

    The file is converted with ``a3m_to_int``, the resulting rows are passed
    through ``format_msa_rows`` and the outcome is written as a TSV.

    Args:
        name: Prefix of the output file name.
        a3m_file: Path to the a3m file to convert.
    """
    output_path = f"{name}_msa.tsv"
    write_tsv(output_path, format_msa_rows(a3m_to_int(a3m_file)))

def read_a3m(name, a3m_files):
# RosettaFold-All-Atom
Expand Down Expand Up @@ -426,20 +425,35 @@ def read_pt(name, pt_files):
write_tsv(f"{name}_0_pae.tsv", format_pae_rows(np.squeeze(data["pae"].numpy())))
break

def read_colabfold_paes(name, colabfold_pae_fn):
    """Write the PAE stored in a ColabFold JSON file to ``{name}_0_pae.tsv``.

    Args:
        name: Prefix of the output file name.
        colabfold_pae_fn: Path to a JSON file holding a
            ``"predicted_aligned_error"`` entry.
    """
    with open(colabfold_pae_fn) as handle:
        pae_matrix = json.load(handle)["predicted_aligned_error"]
    write_tsv(f"{name}_0_pae.tsv", format_pae_rows(pae_matrix))
def read_colabfold_metrics(name, colabfold_metrics_fns):
    """Extract PAE, pTM and ipTM values from ColabFold score JSON files.

    Every file name must contain a ``rank_<N>`` token. The 1-based rank ``N``
    is turned into a 0-based rank id that labels the outputs.

    Args:
        name: Prefix of every output file name.
        colabfold_metrics_fns: Paths of the score JSON files to read.

    Writes:
        ``{name}_{rank_id}_pae.tsv`` for each file that has a ``"pae"`` key,
        and ``{name}_ptm.tsv`` / ``{name}_iptm.tsv`` (one ``rank_id, value``
        row per file, ordered by numeric rank) when at least one file has a
        ``"ptm"`` / ``"iptm"`` key.

    Raises:
        ValueError: If a file name contains no ``rank_<N>`` token.
    """
    # Local imports: the module-level import block is not guaranteed to
    # provide these names.
    import os
    import re

    ptm_by_rank = []
    iptm_by_rank = []
    for fn in colabfold_metrics_fns:
        # Look at the file name only, so a directory containing "rank_"
        # cannot confuse the parse. The last match is used because the job
        # name presumably prefixes the file name and may itself contain
        # "rank_" -- TODO confirm against ColabFold's naming scheme.
        rank_tokens = re.findall(r"rank_(\d+)", os.path.basename(fn))
        if not rank_tokens:
            raise ValueError(
                f"Cannot determine rank from ColabFold metrics file name: {fn}"
            )
        rank_id = int(rank_tokens[-1]) - 1

        with open(fn) as f:
            data = json.load(f)

        if "pae" in data:
            write_tsv(f"{name}_{rank_id}_pae.tsv", format_pae_rows(data["pae"]))
        if "ptm" in data:
            ptm_by_rank.append((rank_id, data["ptm"]))
        if "iptm" in data:
            iptm_by_rank.append((rank_id, data["iptm"]))

    # Sort on the integer rank: sorting the string form would place rank
    # "10" before rank "2" once more than ten predictions are present.
    if ptm_by_rank:
        ordered = sorted(ptm_by_rank, key=lambda row: row[0])
        write_tsv(f"{name}_ptm.tsv", [(f"{rank}", value) for rank, value in ordered])
    if iptm_by_rank:
        ordered = sorted(iptm_by_rank, key=lambda row: row[0])
        write_tsv(f"{name}_iptm.tsv", [(f"{rank}", value) for rank, value in ordered])

def main():
parser = argparse.ArgumentParser()
parser.add_argument("--pkls", dest="pkls", required=False, nargs="+") # For reading both HelixFold3 and AlphaFold2 MSA formats
parser.add_argument("--npzs", dest="npzs", required=False, nargs="+") # For reading the Boltz-1 PAE formats. TODO: Boltz-1 MSA not implemented (go straight to .a3m file), implement
parser.add_argument("--a3ms", dest="a3ms", required=False, nargs="+") # For reading the RosettaFold-All-Atom, ColabFold MSA formats
parser.add_argument("--a3ms", dest="a3ms", required=False, nargs="+") # For reading the RosettaFold-All-Atom MSA formats
parser.add_argument("--paired_a3m", dest="paired_a3m", required=False) # For reading the ColabFold MSA format
parser.add_argument("--csvs", dest="csvs", required=False, nargs="+") # For reading boltz csvs
parser.add_argument("--jsons", dest="jsons", required=False, nargs="+") # For reading the AF3 MSA & PAE, HF3 PAE
parser.add_argument("--colabfold_pae_fn", required=False)
parser.add_argument("--colabfold_metrics_fns", required=False, nargs="+")
parser.add_argument("--pts", dest="pts", required=False, nargs="+") # For read RFAA pytorch model to get PAE data
parser.add_argument("--structs", dest="structs", required=False, nargs="+")
parser.add_argument("--name", default="untitled", dest="name") # might need a --name $meta.id
Expand All @@ -449,6 +463,8 @@ def main():
read_pkl(args.name, args.pkls)
if args.a3ms:
read_a3m(args.name, args.a3ms)
if args.paired_a3m:
read_paired_a3m(args.name, args.paired_a3m)
if args.csvs:
read_csv(args.name, args.csvs)
if args.npzs:
Expand All @@ -459,8 +475,8 @@ def main():
read_pt(args.name, args.pts)
if args.structs:
extract_structs_plddt_to_tsv(args.name, args.structs)
if args.colabfold_pae_fn:
read_colabfold_paes(args.name, args.colabfold_pae_fn)
if args.colabfold_metrics_fns:
read_colabfold_metrics(args.name, args.colabfold_metrics_fns)

if __name__ == "__main__":
main()
8 changes: 2 additions & 6 deletions bin/generate_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def generate_pae_plot(pae_path, out_dir, name, save_image=False):

def generate_output_images(msa_path, plddt_data, name, out_dir, in_type, generate_tsv, pdb):
msa = []
if in_type.lower() != "colabfold" and not msa_path.endswith("NO_FILE"):
if not msa_path.endswith("NO_FILE"):
with open(msa_path, "r") as in_file:
for line in in_file:
msa.append([int(x) for x in line.strip().split()])
Expand Down Expand Up @@ -448,11 +448,7 @@ def pdb_to_lddt(struct_files, generate_tsv):
i += 1

if not args.msa.endswith("NO_FILE"):
image_path = (
f"{args.output_dir}/{args.msa}"
if args.in_type.lower() == "colabfold"
else f"{args.output_dir}/{args.name}_{args.in_type}_seq_coverage.png"
)
image_path = f"{args.output_dir}/{args.name}_{args.in_type}_seq_coverage.png"
with open(image_path, "rb") as in_file:
proteinfold_template = proteinfold_template.replace(
"seq_coverage.png",
Expand Down
18 changes: 14 additions & 4 deletions conf/modules_colabfold.config
Original file line number Diff line number Diff line change
Expand Up @@ -42,16 +42,26 @@ process {
].join(' ').trim()
publishDir = [
[
path: { "${params.outdir}/colabfold/" },
path: { "${params.outdir}/colabfold/${meta.id}/" },
mode: 'copy',
saveAs: { filename -> filename.equals('versions.yml') ? null : filename },
pattern: '*.*'
saveAs: { filename ->
if(filename.endsWith('_pae.tsv')){
"paes/$filename"
} else { filename }
},
pattern: '*.tsv'
],
[
enabled: params.save_intermediates,
path: { "${params.outdir}/colabfold/${meta.id}/" },
mode: 'copy',
pattern: 'raw/**'
],
[
path: { "${params.outdir}/colabfold/top_ranked_structures" },
mode: 'copy',
saveAs: { "${meta.id}.pdb" },
pattern: '*_relaxed_rank_001*.pdb'
pattern: '*_colabfold.pdb'
]
]
}
Expand Down
53 changes: 33 additions & 20 deletions modules/local/colabfold_batch/main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,16 @@ process COLABFOLD_BATCH {
val numRec

output:
tuple val(meta), path ("${meta.id}_colabfold.pdb"), emit: top_ranked_pdb
tuple val(meta), path ("*relaxed_rank_*.pdb") , emit: pdb
tuple val(meta), path ("*_coverage.png") , emit: msa
tuple val(meta), path ("*_mqc.png") , emit: multiqc
tuple val(meta), path ("${meta.id}_0_pae.tsv") , emit: pae
path "versions.yml" , emit: versions
path ("raw/**") , emit: raw
tuple val(meta), path ("${meta.id}_colabfold.pdb") , emit: top_ranked_pdb
tuple val(meta), path ("raw/*relaxed_rank_*.pdb") , emit: pdb
tuple val(meta), path ("${meta.id}_colabfold_msa.tsv") , emit: msa
tuple val(meta), path ("${meta.id}_plddt.tsv") , emit: multiqc
tuple val(meta), path ("${meta.id}_*_pae.tsv") , optional: true, emit: paes
tuple val(meta), path ("${meta.id}_0_pae.tsv") , optional: true, emit: pae
tuple val(meta), path ("${meta.id}_ptm.tsv") , optional: true, emit: ptms
tuple val(meta), path ("${meta.id}_iptm.tsv") , optional: true, emit: iptms
path "versions.yml" , emit: versions

when:
task.ext.when == null || task.ext.when
Expand All @@ -37,27 +41,33 @@ process COLABFOLD_BATCH {
fi

touch params/download_finished.txt
touch params/download_complexes_multimer_v3_finished.txt
touch params/download_complexes_multimer_v2_finished.txt
touch params/download_complexes_multimer_v1_finished.txt

colabfold_batch \\
$args \\
--num-recycle ${numRec} \\
--data \$PWD \\
--model-type ${colabfold_model_preset} \\
${fasta} \\
\$PWD
raw/

for i in `find *.png -maxdepth 0`; do cp \$i \${i%'.png'}_mqc.png; done
if [ ! -e `find *_relaxed_rank_001_*.pdb` ]; then
cp *_relaxed_rank_001*.pdb ${meta.id}_colabfold.pdb
if [ ! -e `find raw/*_relaxed_rank_001_*.pdb` ]; then
prefix=relaxed
cp raw/*_relaxed_rank_001*.pdb ${meta.id}_colabfold.pdb
else
cp *_unrelaxed_rank_001*.pdb ${meta.id}_colabfold.pdb
prefix=unrelaxed
cp raw/*_unrelaxed_rank_001*.pdb ${meta.id}_colabfold.pdb
fi

#Note: only multimer prefix is meta.id
extract_metrics.py --name ${meta.id} \\
--colabfold_pae *_predicted_aligned_error_v1.json
--colabfold_metrics_fns raw/*scores_rank*.json \\
--structs raw/*_\${prefix}_rank*.pdb \\
--paired_a3m raw/${meta.id}.a3m

mv *_coverage.png ${meta.id}_seq_coverage.png
cp raw/*_coverage.png ${meta.id}_seq_coverage.png
mv "${meta.id}_msa.tsv" "${meta.id}_colabfold_msa.tsv"

cat <<-END_VERSIONS > versions.yml
"${task.process}":
Expand All @@ -68,14 +78,17 @@ process COLABFOLD_BATCH {

stub:
"""
mkdir raw
touch ./"${meta.id}"_colabfold.pdb
touch ./"${meta.id}"_mqc.png
touch ./${meta.id}_relaxed_rank_01.pdb
touch ./${meta.id}_relaxed_rank_02.pdb
touch ./${meta.id}_relaxed_rank_03.pdb
touch ./${meta.id}_coverage.png
touch ./${meta.id}_scores_rank.json
touch ./raw/${meta.id}_relaxed_rank_001_model_1_seed_000.pdb
touch ./raw/${meta.id}_relaxed_rank_002_model_2_seed_000.pdb
touch ./raw/${meta.id}_relaxed_rank_003_model_3_seed_000.pdb
touch ./${meta.id}_seq_coverage.png
touch ./raw/${meta.id}_scores_rank.json
touch ./${meta.id}_0_pae.tsv
touch ./${meta.id}_ptm.tsv
touch ./${meta.id}_plddt.tsv
touch ./${meta.id}_colabfold_msa.tsv

cat <<-END_VERSIONS > versions.yml
"${task.process}":
Expand Down
2 changes: 0 additions & 2 deletions modules/local/generate_report/main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,6 @@ process GENERATE_REPORT {
--name ${meta.id} \\
$args \\

[ -f ${meta.id}_seq_coverage.png ] && mv ${meta.id}_seq_coverage.png ${meta.id}_colabfold_seq_coverage.png

cat <<-END_VERSIONS > versions.yml
"${task.process}":
python: \$(python3 --version | sed 's/Python //g')
Expand Down
2 changes: 1 addition & 1 deletion nextflow.config
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ params {
colabfold_use_amber = true
colabfold_db = null
colabfold_db_load_mode = 0
colabfold_use_templates = true
colabfold_use_templates = false
colabfold_create_index = false

// Colabfold links
Expand Down
2 changes: 1 addition & 1 deletion nextflow_schema.json
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@
},
"colabfold_use_templates": {
"type": "boolean",
"default": true,
"default": false,
"description": "Use PDB templates",
"fa_icon": "fas fa-paste"
},
Expand Down
64 changes: 33 additions & 31 deletions tests/colabfold_download.nf.test.snap
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
{
"-profile test_colabfold_download": {
"content": [
7,
9,
{
"ARIA2": {
"aria2": null
Expand All @@ -14,6 +14,9 @@
"python": "3.12.7",
"generate_report.py": "Python 3.12.7"
},
"MULTIFASTA_TO_CSV": {
"sed": 4.7
},
"Workflow": {
"nf-core/proteinfold": "v1.2.0dev"
}
Expand All @@ -23,25 +26,27 @@
"DBs/colabfold",
"DBs/colabfold/params",
"colabfold",
"colabfold/T1024_0_pae.tsv",
"colabfold/T1024_colabfold.pdb",
"colabfold/T1024_coverage.png",
"colabfold/T1024_mqc.png",
"colabfold/T1024_relaxed_rank_01.pdb",
"colabfold/T1024_relaxed_rank_02.pdb",
"colabfold/T1024_relaxed_rank_03.pdb",
"colabfold/T1026_0_pae.tsv",
"colabfold/T1026_colabfold.pdb",
"colabfold/T1026_coverage.png",
"colabfold/T1026_mqc.png",
"colabfold/T1026_relaxed_rank_01.pdb",
"colabfold/T1026_relaxed_rank_02.pdb",
"colabfold/T1026_relaxed_rank_03.pdb",
"colabfold/T1024",
"colabfold/T1024/T1024_colabfold_msa.tsv",
"colabfold/T1024/T1024_plddt.tsv",
"colabfold/T1024/T1024_ptm.tsv",
"colabfold/T1024/paes",
"colabfold/T1024/paes/T1024_0_pae.tsv",
"colabfold/T1026",
"colabfold/T1026/T1026_colabfold_msa.tsv",
"colabfold/T1026/T1026_plddt.tsv",
"colabfold/T1026/T1026_ptm.tsv",
"colabfold/T1026/paes",
"colabfold/T1026/paes/T1026_0_pae.tsv",
"colabfold/top_ranked_structures",
"colabfold/top_ranked_structures/T1024.pdb",
"colabfold/top_ranked_structures/T1026.pdb",
"generate",
"generate/test_LDDT.html",
"generate/test_alphafold2_report.html",
"generate/test_seq_coverage.png",
"multifasta",
"multifasta/input.csv",
"multiqc",
"multiqc/multiqc_data",
"multiqc/multiqc_plots",
Expand All @@ -53,29 +58,26 @@
[
"file.txt:md5,d41d8cd98f00b204e9800998ecf8427e"
],
"T1024_colabfold_msa.tsv:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1024_plddt.tsv:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1024_ptm.tsv:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1024_0_pae.tsv:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1024_colabfold.pdb:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1024_coverage.png:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1024_mqc.png:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1024_relaxed_rank_01.pdb:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1024_relaxed_rank_02.pdb:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1024_relaxed_rank_03.pdb:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1026_colabfold_msa.tsv:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1026_plddt.tsv:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1026_ptm.tsv:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1026_0_pae.tsv:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1026_colabfold.pdb:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1026_coverage.png:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1026_mqc.png:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1026_relaxed_rank_01.pdb:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1026_relaxed_rank_02.pdb:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1026_relaxed_rank_03.pdb:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1024.pdb:md5,d41d8cd98f00b204e9800998ecf8427e",
"T1026.pdb:md5,d41d8cd98f00b204e9800998ecf8427e",
"test_LDDT.html:md5,d41d8cd98f00b204e9800998ecf8427e",
"test_alphafold2_report.html:md5,d41d8cd98f00b204e9800998ecf8427e",
"test_seq_coverage.png:md5,d41d8cd98f00b204e9800998ecf8427e"
"test_seq_coverage.png:md5,d41d8cd98f00b204e9800998ecf8427e",
"input.csv:md5,d41d8cd98f00b204e9800998ecf8427e"
]
],
"meta": {
"nf-test": "0.9.3",
"nextflow": "25.10.2"
"nf-test": "0.9.2",
"nextflow": "25.10.3"
},
"timestamp": "2026-01-13T14:36:06.220362"
"timestamp": "2026-01-30T21:28:32.731283423"
}
}
Loading