Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 31 additions & 4 deletions models/rfd3/src/rfd3/inference/input_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -750,16 +750,43 @@ def _apply_symmetry(self, atom_array, atom_array_input_annotated):
def _set_origin(self, atom_array):
"""Set origin token and initialize coordinates."""
if self.is_partial_diffusion:
# Partial diffusion: use COM, keep all coordinates
# Partial diffusion: keep all coordinates. Centering rules:
# 1. If symmetric: skip centering to preserve chain spacing.
# 2. If the user supplied `ori_token` or `infer_ori_strategy`,
# honor it (same path as regular diffusion). Previously this
# branch hard-coded `ori_token=None`, silently dropping the
# user's request.
# 3. Otherwise default to centering on the diffused-region COM
# (matches training `center_option=diffuse`). Centering on the
# joint target+binder COM places the binder far from origin in
# a frame the model never saw at training, biasing denoising
# to drag the binder toward the target's COM.
if exists(self.symmetry) and self.symmetry.id:
# For symmetric structures, avoid COM centering that would collapse chains
logger.info(
"Partial diffusion with symmetry: skipping COM centering to preserve chain spacing"
)
else:
elif exists(self.ori_token) or exists(self.infer_ori_strategy):
atom_array = set_com(
atom_array, ori_token=None, infer_ori_strategy="com"
atom_array,
ori_token=self.ori_token,
infer_ori_strategy=self.infer_ori_strategy,
)
else:
is_motif = atom_array.is_motif_atom_with_fixed_coord.astype(bool)
if is_motif.any() and (~is_motif).any():
diffused_coord = atom_array.coord[~is_motif]
finite = np.isfinite(diffused_coord).all(axis=-1)
center = np.nan_to_num(
np.mean(diffused_coord[finite], axis=0)
)
atom_array.coord = atom_array.coord - center
logger.info(
f"Partial diffusion: centering on diffused-region COM ({center})."
)
else:
atom_array = set_com(
atom_array, ori_token=None, infer_ori_strategy="com"
)
else:
# Standard: set ori token, zero out diffused atoms
atom_array = set_com(
Expand Down
69 changes: 69 additions & 0 deletions models/rfd3/tests/test_partial_diffusion.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import copy
import sys

import numpy as np
import pytest
from rfd3.inference.input_parsing import DesignInputSpecification
from rfd3.testing.testing_utils import (
TEST_JSON_DATA,
build_pipelines,
Expand All @@ -22,5 +25,71 @@ def test_partial_diffusion(example):
assert "partial_t" in aa.get_annotation_categories(), "partial_t not in atom_array"


def _build_pipeline_atom_array(args):
    """Parse ``args`` and return the resulting atom array.

    The keyword arguments in ``args`` are fed to
    ``DesignInputSpecification.safe_init``; the specification is then
    converted with ``to_pipeline_input`` (using the fixed example id
    ``"example"``) and only the ``"atom_array"`` entry is handed back.
    This is meant to follow the same parsing path that
    ``instantiate_example`` uses for inference, while exposing the atom
    array so tests can assert on coordinate centering.
    """
    return DesignInputSpecification.safe_init(**args).to_pipeline_input(
        example_id="example"
    )["atom_array"]


@pytest.mark.fast
def test_partial_diffusion_respects_ori_token():
    """A user-supplied ``ori_token`` must translate partial-diffusion inputs.

    Regression guard: ``_set_origin`` used to pass ``ori_token=None`` on the
    partial-diffusion path, so the user's origin request was silently lost.
    The test parses the same fixture twice, once without any origin settings
    and once with ``ori_token=[50, 0, 0]``, and compares the mean coordinate
    of the two parsed atom arrays.
    """
    spec_args = copy.deepcopy(TEST_JSON_DATA["partial_diffusion"])
    # Strip any origin settings shipped with the fixture so the baseline run
    # exercises the default centering path.
    for key in ("ori_token", "infer_ori_strategy"):
        spec_args.pop(key, None)

    aa_default = _build_pipeline_atom_array(copy.deepcopy(spec_args))

    shifted_args = copy.deepcopy(spec_args)
    shifted_args["ori_token"] = [50.0, 0.0, 0.0]
    aa_shift = _build_pipeline_atom_array(shifted_args)

    assert aa_default.array_length() == aa_shift.array_length()

    # Parsing subtracts ori_token from the coordinates, so the mean of the
    # whole structure is expected to drop by roughly 50 A along x only.
    delta = aa_shift.coord.mean(axis=0) - aa_default.coord.mean(axis=0)
    assert delta[0] == pytest.approx(-50.0, abs=1.5), (
        f"ori_token=[50,0,0] should shift coords by -50 in x; got delta={delta}"
    )
    off_axis_unchanged = abs(delta[1]) < 1.5 and abs(delta[2]) < 1.5
    assert off_axis_unchanged, (
        f"ori_token=[50,0,0] should not move y/z; got delta={delta}"
    )


@pytest.mark.fast
def test_partial_diffusion_defaults_to_diffused_region_com():
    """Default partial-diffusion centering must use the diffused-region COM.

    When neither ``ori_token`` nor ``infer_ori_strategy`` is supplied,
    partial diffusion must center on the diffused-region COM (matches the
    training convention ``center_option=diffuse``).

    Regression test: previously this branch centered on the joint
    target+diffused COM, biasing the model to drag the diffused region toward
    the motif's COM.

    The COM is computed over diffused atoms with finite coordinates only,
    because the centering code itself averages only the finite rows; a plain
    mean would turn into NaN (and fail spuriously) as soon as one diffused
    atom carries a non-finite coordinate.
    """
    base = copy.deepcopy(TEST_JSON_DATA["partial_diffusion"])
    base.pop("ori_token", None)
    base.pop("infer_ori_strategy", None)
    aa = _build_pipeline_atom_array(base)

    is_motif = aa.is_motif_atom_with_fixed_coord.astype(bool)
    if not is_motif.any() or not (~is_motif).any():
        pytest.skip("Test fixture has no separable motif/diffused split")

    diffused_coord = aa.coord[~is_motif]
    # Same row-wise finiteness mask the implementation applies before
    # averaging the diffused coordinates.
    finite = np.isfinite(diffused_coord).all(axis=-1)
    if not finite.any():
        # With no finite diffused atom there is no COM to check.
        pytest.skip("Test fixture has no finite diffused coordinates")

    diffused_com = diffused_coord[finite].mean(axis=0)
    assert np.allclose(diffused_com, 0, atol=1e-3), (
        f"diffused-region COM should be at origin after centering; got {diffused_com}"
    )


if __name__ == "__main__":
    # Forward pytest's exit code to the shell so that running this file
    # directly reports test failures through a non-zero process status.
    sys.exit(pytest.main(sys.argv))