Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion deepmd/pd/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1038,7 +1038,11 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
if JIT:
break

if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0):
if (
self.change_bias_after_training
and self.num_steps > self.start_step
and (self.rank == 0 or dist.get_rank() == 0)
):
if not self.multi_task:
self.model = model_change_out_bias(
self.model,
Expand Down
6 changes: 5 additions & 1 deletion deepmd/pt/train/training.py
Original file line number Diff line number Diff line change
Expand Up @@ -1745,7 +1745,11 @@ def log_loss_valid(_task_key: str = "Default") -> dict:
if JIT:
break

if self.change_bias_after_training and (self.rank == 0 or dist.get_rank() == 0):
if (
self.change_bias_after_training
and self.num_steps > self.start_step
and (self.rank == 0 or dist.get_rank() == 0)
):
if not self.multi_task:
self.model = model_change_out_bias(
self.model,
Expand Down
16 changes: 16 additions & 0 deletions source/tests/pd/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
)

import numpy as np
import paddle

from deepmd.pd.entrypoints.main import (
get_trainer,
Expand Down Expand Up @@ -163,6 +164,21 @@ def setUp(self) -> None:
self.config["training"]["save_freq"] = 1
enable_prim(True)

def test_zero_step_with_change_bias_saves_initial_checkpoint(self) -> None:
config = deepcopy(self.config)
config["training"]["numb_steps"] = 0
config["training"]["change_bias_after_training"] = True
trainer = get_trainer(config)
trainer.run()

self.assertEqual(Path("model.ckpt-0.pd"), trainer.latest_model)
self.assertTrue(Path("model.ckpt-0.pd").exists())
self.assertEqual(Path("model.ckpt-0.pd"), Path("checkpoint").read_text())
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
checkpoint = paddle.load("model.ckpt-0.pd")
train_infos = checkpoint["model"]["_extra_state"]["train_infos"]
self.assertEqual(0, train_infos["step"])
self.assertEqual(0.0, train_infos["lr"])

Comment thread
coderabbitai[bot] marked this conversation as resolved.
def tearDown(self) -> None:
DPTrainTest.tearDown(self)

Expand Down
17 changes: 17 additions & 0 deletions source/tests/pt/test_training.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,23 @@ def test_yaml_input(self) -> None:
)
self.assertTrue(Path("out.json").exists())

def test_zero_step_with_change_bias_saves_initial_checkpoint(self) -> None:
config = deepcopy(self.config)
config["training"]["numb_steps"] = 0
config["training"]["change_bias_after_training"] = True
trainer = get_trainer(config)
trainer.run()

self.assertEqual(Path("model.ckpt-0.pt"), trainer.latest_model)
self.assertTrue(Path("model.ckpt-0.pt").exists())
self.assertEqual(Path("model.ckpt-0.pt"), Path("checkpoint").read_text())
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
checkpoint = torch.load(
"model.ckpt-0.pt", map_location="cpu", weights_only=True
)
train_infos = checkpoint["model"]["_extra_state"]["train_infos"]
self.assertEqual(0, train_infos["step"])
self.assertEqual(0.0, train_infos["lr"])

Comment thread
coderabbitai[bot] marked this conversation as resolved.
def tearDown(self) -> None:
DPTrainTest.tearDown(self)
for ff in ["out.json", "input.yaml"]:
Expand Down
Loading