13 changes: 13 additions & 0 deletions src/instructlab/training/config.py
@@ -351,6 +351,19 @@ class TrainingArgs(BaseModel):
description="How often to evaluate validation loss (in training steps). Required when validation_split > 0.",
)

on_demand_checkpointing: bool = Field(
default=False,
description=(
"Enable on-demand full-state checkpointing triggered by Unix signals. "
"When enabled, the parent process intercepts termination signals "
"(SIGTERM, SIGINT, SIGUSR1, SIGUSR2, SIGXCPU, SIGHUP) and writes a "
"trigger file to /dev/shm. Worker processes check for this trigger "
"after each training step and collectively save a distributed "
"checkpoint before exiting gracefully. Designed for OpenShift AI / "
Kubeflow training jobs where preemption signals must be handled."
),
)

@model_validator(mode="after")
def validate_validation_config(self):
if not 0.0 <= self.validation_split < 1.0:
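
The new field's description pins down the whole mechanism: a signal caught in the parent becomes a trigger file under /dev/shm, and every worker polls for that file between steps. As a rough sketch of that handshake (the path and helper names below are assumptions; the actual on_demand_checkpoint module is referenced by this PR but not shown in this diff):

# Sketch only: the real helpers live in
# instructlab/training/on_demand_checkpoint.py, which this diff does not show.
import os

TRIGGER_PATH = "/dev/shm/instructlab_on_demand_checkpoint"  # assumed name


def request_checkpoint() -> None:
    """Parent process: drop the trigger file when a termination signal arrives."""
    with open(TRIGGER_PATH, "w", encoding="utf-8") as f:
        f.write(str(os.getpid()))


def check_checkpoint_requested() -> bool:
    """Workers: polled after each training step. Note that /dev/shm is
    node-local, so a multi-node job needs the trigger written (or the flag
    broadcast) on every node for the collective save to line up."""
    return os.path.exists(TRIGGER_PATH)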
102 changes: 95 additions & 7 deletions src/instructlab/training/main_ds.py
@@ -173,6 +173,7 @@ def train(
accelerator: Accelerator,
val_data_loader=None,
validation_frequency=None,
on_demand_checkpointing: bool = False,
):
model.train()

@@ -183,6 +184,15 @@
metric_logger = logging.getLogger("instructlab.training.metrics")
base_logger = logging.getLogger("instructlab.training")

# Import on-demand checkpointing utilities once if the feature is enabled
if on_demand_checkpointing:
from instructlab.training.on_demand_checkpoint import (
check_checkpoint_requested,
save_on_demand_checkpoint,
)

base_logger.info("On-demand checkpointing is enabled in worker process.")

# Mini_trainer approach: batch_size will be determined dynamically by data loader
# For save logic, use effective_batch_size since that's the target
samples_seen = 0
@@ -308,6 +318,22 @@
base_logger.debug("RANK (%d) waiting at post-save barrier.", local_rank)
dist.barrier()

# --- On-demand checkpointing: check if a signal triggered a save ---
if on_demand_checkpointing and check_checkpoint_requested():
save_on_demand_checkpoint(
args=args,
accelerator=accelerator,
model=model,
tokenizer=model.tokenizer,
samples_seen=samples_seen,
epoch=epoch,
is_lora=bool(args.lora_r),
)
base_logger.info(
"On-demand checkpoint saved. Exiting training gracefully."
)
return

global_step += 1
if local_rank == 0:
inner_pb.update(1)
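
save_on_demand_checkpoint is imported above but defined outside this diff. Purely as a sketch of what a collective full-state save at this point in the loop could look like with Hugging Face Accelerate (the directory layout, args.output_dir, and the tokenizer handling are assumptions, not the PR's implementation):

# Hypothetical sketch -- not the PR's actual save_on_demand_checkpoint.
import os

from accelerate import Accelerator


def save_on_demand_checkpoint(
    args,
    accelerator: Accelerator,
    model,
    tokenizer,
    samples_seen: int,
    epoch: int,
    is_lora: bool,
) -> None:
    out_dir = os.path.join(
        args.output_dir, f"on_demand_samples_{samples_seen}_epoch_{epoch}"
    )
    # save_state() is a collective call: every rank must reach it, which is
    # why check_checkpoint_requested() runs on all workers after each step.
    accelerator.save_state(out_dir)
    if accelerator.is_main_process:
        # A LoRA run (is_lora=True) would save adapter weights instead of
        # full state; omitted in this sketch.
        tokenizer.save_pretrained(out_dir)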
@@ -561,6 +587,7 @@ def main(args):
accelerator=accelerator,
val_data_loader=val_loader,
validation_frequency=validation_frequency,
on_demand_checkpointing=getattr(args, "on_demand_checkpointing", False),
)

dist.barrier()
@@ -791,7 +818,24 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
if train_args.keep_last_checkpoint_only:
command.append("--keep_last_checkpoint_only")

if train_args.on_demand_checkpointing:
command.append("--on_demand_checkpointing")

logger.info("Running training command as subprocess: %s", " ".join(command))

# --- On-demand checkpointing: install signal handlers in the parent ---
signal_handler = None
if train_args.on_demand_checkpointing:
# First Party
from instructlab.training.on_demand_checkpoint import ParentSignalHandler

signal_handler = ParentSignalHandler()
signal_handler.install()
logger.info(
"On-demand checkpointing is ENABLED. "
"Termination signals will trigger a full-state checkpoint before exit."
)

process = None
interrupt: KeyboardInterrupt | Exception | None = None
failure = False
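
ParentSignalHandler is likewise only used here, never defined in this diff. From its surface area (install(), uninstall(), and a signal_received attribute holding a signal.Signals member), a minimal compatible sketch might be:

# Hypothetical sketch inferred from how run_training() uses the class.
import signal

TRIGGER_PATH = "/dev/shm/instructlab_on_demand_checkpoint"  # assumed name

HANDLED_SIGNALS = (
    signal.SIGTERM,
    signal.SIGINT,
    signal.SIGUSR1,
    signal.SIGUSR2,
    signal.SIGXCPU,
    signal.SIGHUP,
)


class ParentSignalHandler:
    def __init__(self) -> None:
        self.signal_received: signal.Signals | None = None
        self._previous: dict[signal.Signals, object] = {}

    def _handle(self, signum, frame) -> None:
        # Record the signal and drop the trigger file; deliberately do not
        # re-raise or exit, so the parent stays alive while workers save.
        self.signal_received = signal.Signals(signum)
        with open(TRIGGER_PATH, "w", encoding="utf-8") as f:
            f.write(self.signal_received.name)

    def install(self) -> None:
        for sig in HANDLED_SIGNALS:
            self._previous[sig] = signal.signal(sig, self._handle)

    def uninstall(self) -> None:
        for sig, prev in self._previous.items():
            signal.signal(sig, prev)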
@@ -811,19 +855,49 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
interrupt = e
finally:
if "process" not in locals() or process is None:
if signal_handler is not None:
signal_handler.uninstall()
return

# If a signal was caught by the on-demand checkpoint handler, give
# the workers time to detect the trigger file and save a checkpoint
# before we start sending our own signals to the subprocess.
if signal_handler is not None and signal_handler.signal_received is not None:
logger.info(
"On-demand checkpoint: signal %s received. Waiting for workers to "
"save checkpoint before proceeding with shutdown...",
signal_handler.signal_received.name,
)
# Give workers generous time to complete the checkpoint save.
# The workers will exit on their own after saving.
try:
process.wait(timeout=300)
except subprocess.TimeoutExpired:
logger.warning(
"On-demand checkpoint: workers did not finish within 300s. "
"Proceeding with shutdown."
)

# wait for the process to exit so we can properly read the exit code
-        process.wait(timeout=60)
+        try:
+            process.wait(timeout=60)
+        except subprocess.TimeoutExpired:
+            pass
process_code = process.poll()
-        failure = process_code != 0
+        failure = process_code is not None and process_code != 0

-        if not failure:
-            logger.info("Operation completed successfully! 🎉")
+        if process_code is not None and not failure:
+            logger.info("Operation completed successfully!")
        else:
-            logger.error(
-                f"Training subprocess has not exited yet. Sending SIGTERM. Process code: {process_code}"
-            )
+            if process_code is None:
+                logger.error(
+                    "Training subprocess has not exited yet. Sending SIGTERM."
+                )
+            else:
+                logger.error(
+                    "Training subprocess exited with code %d. Sending SIGTERM.",
+                    process_code,
+                )

process.terminate()
try:
@@ -835,6 +909,9 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
)
process.kill()

if signal_handler is not None:
signal_handler.uninstall()

if interrupt:
raise interrupt
if failure:
@@ -1045,6 +1122,17 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
),
)

parser.add_argument(
"--on_demand_checkpointing",
action="store_true",
default=False,
help=(
"Enable on-demand full-state checkpointing triggered by Unix signals. "
"When enabled, workers check for a trigger file in /dev/shm after each "
"training step and collectively save a distributed checkpoint before "
exiting. Designed for OpenShift AI / Kubeflow preemption handling."
),
)
parser.add_argument(
"--use_liger",
action="store_true",
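
End to end, the new flag is exercised through the same public API as the rest of TrainingArgs. A hedged usage sketch (field values are placeholders, and the exact set of required TrainingArgs fields is not shown in this diff):

# Illustrative only. To trigger an on-demand checkpoint, send one of the
# handled signals to the process running run_training(), e.g. from a shell:
#     kill -USR1 <pid of this script>
from instructlab.training import TorchrunArgs, TrainingArgs, run_training

torch_args = TorchrunArgs(
    nnodes=1,
    nproc_per_node=8,
    node_rank=0,
    rdzv_id=123,
    rdzv_endpoint="127.0.0.1:29500",
)
train_args = TrainingArgs(
    model_path="ibm-granite/granite-7b-base",  # placeholder
    data_path="train.jsonl",                   # placeholder
    ckpt_output_dir="checkpoints",
    data_output_dir="data_out",
    max_seq_len=4096,
    max_batch_len=20000,
    num_epochs=1,
    effective_batch_size=128,
    learning_rate=2e-5,
    warmup_steps=25,
    on_demand_checkpointing=True,  # the flag added in this PR
)
run_training(torch_args, train_args)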