diff --git a/models/rfd3/src/rfd3/inference/input_parsing.py b/models/rfd3/src/rfd3/inference/input_parsing.py index f8b90217..cdf714b2 100644 --- a/models/rfd3/src/rfd3/inference/input_parsing.py +++ b/models/rfd3/src/rfd3/inference/input_parsing.py @@ -750,16 +750,43 @@ def _apply_symmetry(self, atom_array, atom_array_input_annotated): def _set_origin(self, atom_array): """Set origin token and initialize coordinates.""" if self.is_partial_diffusion: - # Partial diffusion: use COM, keep all coordinates + # Partial diffusion: keep all coordinates. Centering rules: + # 1. If symmetric: skip centering to preserve chain spacing. + # 2. If the user supplied `ori_token` or `infer_ori_strategy`, + # honor it (same path as regular diffusion). Previously this + # branch hard-coded `ori_token=None`, silently dropping the + # user's request. + # 3. Otherwise default to centering on the diffused-region COM + # (matches training `center_option=diffuse`). Centering on the + # joint target+binder COM places the binder far from origin in + # a frame the model never saw at training, biasing denoising + # to drag the binder toward the target's COM. if exists(self.symmetry) and self.symmetry.id: - # For symmetric structures, avoid COM centering that would collapse chains logger.info( "Partial diffusion with symmetry: skipping COM centering to preserve chain spacing" ) - else: + elif exists(self.ori_token) or exists(self.infer_ori_strategy): atom_array = set_com( - atom_array, ori_token=None, infer_ori_strategy="com" + atom_array, + ori_token=self.ori_token, + infer_ori_strategy=self.infer_ori_strategy, ) + else: + is_motif = atom_array.is_motif_atom_with_fixed_coord.astype(bool) + if is_motif.any() and (~is_motif).any(): + diffused_coord = atom_array.coord[~is_motif] + finite = np.isfinite(diffused_coord).all(axis=-1) + center = np.nan_to_num( + np.mean(diffused_coord[finite], axis=0) + ) + atom_array.coord = atom_array.coord - center + logger.info( + f"Partial diffusion: centering on diffused-region COM ({center})." + ) + else: + atom_array = set_com( + atom_array, ori_token=None, infer_ori_strategy="com" + ) else: # Standard: set ori token, zero out diffused atoms atom_array = set_com( diff --git a/models/rfd3/tests/test_partial_diffusion.py b/models/rfd3/tests/test_partial_diffusion.py index 4cacd799..a6cb3161 100644 --- a/models/rfd3/tests/test_partial_diffusion.py +++ b/models/rfd3/tests/test_partial_diffusion.py @@ -1,6 +1,9 @@ +import copy import sys +import numpy as np import pytest +from rfd3.inference.input_parsing import DesignInputSpecification from rfd3.testing.testing_utils import ( TEST_JSON_DATA, build_pipelines, @@ -22,5 +25,71 @@ def test_partial_diffusion(example): assert "partial_t" in aa.get_annotation_categories(), "partial_t not in atom_array" +def _build_pipeline_atom_array(args): + """Run input parsing through `to_pipeline_input` and return the atom array. + + Mirrors what `instantiate_example` does for inference but returns the + parsed atom array directly so we can assert on coordinate centering. + """ + spec = DesignInputSpecification.safe_init(**args) + pipeline_input = spec.to_pipeline_input(example_id="example") + return pipeline_input["atom_array"] + + +@pytest.mark.fast +def test_partial_diffusion_respects_ori_token(): + """User-supplied `ori_token` must shift the structure during partial diffusion. + + Regression test: previously `_set_origin` hard-coded `ori_token=None` for + partial diffusion, silently dropping the user's request. + """ + base = copy.deepcopy(TEST_JSON_DATA["partial_diffusion"]) + base.pop("ori_token", None) + base.pop("infer_ori_strategy", None) + + aa_default = _build_pipeline_atom_array(copy.deepcopy(base)) + + shift = np.array([50.0, 0.0, 0.0], dtype=np.float32) + args_shift = copy.deepcopy(base) + args_shift["ori_token"] = shift.tolist() + aa_shift = _build_pipeline_atom_array(args_shift) + + assert aa_default.array_length() == aa_shift.array_length() + delta = aa_shift.coord.mean(axis=0) - aa_default.coord.mean(axis=0) + # Coords are translated by -ori_token at parse time, so the post-parse + # whole-structure mean must drop by ~50 Å on x. + assert delta[0] == pytest.approx(-50.0, abs=1.5), ( + f"ori_token=[50,0,0] should shift coords by -50 in x; got delta={delta}" + ) + assert abs(delta[1]) < 1.5 and abs(delta[2]) < 1.5, ( + f"ori_token=[50,0,0] should not move y/z; got delta={delta}" + ) + + +@pytest.mark.fast +def test_partial_diffusion_defaults_to_diffused_region_com(): + """When neither `ori_token` nor `infer_ori_strategy` is supplied, partial + diffusion must center on the diffused-region COM (matches training + convention `center_option=diffuse`). + + Regression test: previously this branch centered on the joint + target+diffused COM, biasing the model to drag the diffused region toward + the motif's COM. + """ + base = copy.deepcopy(TEST_JSON_DATA["partial_diffusion"]) + base.pop("ori_token", None) + base.pop("infer_ori_strategy", None) + aa = _build_pipeline_atom_array(base) + + is_motif = aa.is_motif_atom_with_fixed_coord.astype(bool) + if not is_motif.any() or not (~is_motif).any(): + pytest.skip("Test fixture has no separable motif/diffused split") + + diffused_com = aa.coord[~is_motif].mean(axis=0) + assert np.allclose(diffused_com, 0, atol=1e-3), ( + f"diffused-region COM should be at origin after centering; got {diffused_com}" + ) + + if __name__ == "__main__": pytest.main(sys.argv)