diff --git a/cookbooks/maestro_early_stopping.ipynb b/cookbooks/maestro_early_stopping.ipynb new file mode 100644 index 00000000..aefde2dd --- /dev/null +++ b/cookbooks/maestro_early_stopping.ipynb @@ -0,0 +1,229 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0341f3f7", + "metadata": {}, + "source": [ + "# Multimodal Maestro: Using Early Stopping for Efficient Training\n", + "\n", + "This notebook demonstrates how to use the early stopping feature with Multimodal Maestro models to reduce training time and prevent overfitting." + ] + }, + { + "cell_type": "markdown", + "id": "2008f170", + "metadata": {}, + "source": [ + "## Introduction\n", + "\n", + "Early stopping is a regularization technique that helps prevent overfitting in machine learning models. It works by monitoring a validation metric (typically the validation loss) and stopping training once performance on the validation set has not improved for a specified number of epochs.\n", + "\n", + "Benefits of early stopping:\n", + "1. Reduces training time\n", + "2. Prevents overfitting\n", + "3. Automatically determines optimal training duration\n", + "\n", + "In this notebook, we'll demonstrate how to enable early stopping when training a Florence-2 model." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7b0cd0d1", + "metadata": {}, + "outputs": [], + "source": [ + "# Install required packages\n", + "%pip install multimodal-maestro supervision --quiet" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e81adcdc", + "metadata": {}, + "outputs": [], + "source": [ + "# Import necessary libraries\n", + "import os\n", + "\n", + "from maestro.trainer.common.metrics import MeanAveragePrecisionMetric\n", + "from maestro.trainer.models.florence_2.core import Florence2Configuration, train" + ] + }, + { + "cell_type": "markdown", + "id": "fdc038f9", + "metadata": {}, + "source": [ + "## Downloading a Sample Dataset\n", + "\n", + "For this example, we'll use a small object detection dataset. You can replace this with your own dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e9e51e1", + "metadata": {}, + "outputs": [], + "source": [ + "# Download a sample dataset (chess pieces dataset)\n", + "%pip install roboflow --quiet\n", + "\n", + "from roboflow import Roboflow\n", + "\n", + "rf = Roboflow(api_key=\"YOUR_API_KEY\") # Replace with your Roboflow API key\n", + "project = rf.workspace(\"roboflow-100\").project(\"chess-pieces-detection\")\n", + "dataset = project.version(2).download(\"coco\")" + ] + }, + { + "cell_type": "markdown", + "id": "1888b2dc", + "metadata": {}, + "source": [ + "## Configuring Training with Early Stopping\n", + "\n", + "Now we'll set up the training configuration with early stopping enabled.",
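+ "\n", + "\n", + "For example, with `early_stopping_patience=3` and `early_stopping_threshold=0.01`, training stops once the validation loss has failed to improve by at least 0.01 for three consecutive validation epochs."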
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "da75158d", + "metadata": {}, + "outputs": [], + "source": [ + "# Get the dataset path (Roboflow reports an absolute download location)\n", + "dataset_path = dataset.location\n", + "\n", + "# Configure the training\n", + "config = Florence2Configuration(\n", + " dataset=dataset_path,\n", + " epochs=20, # Upper bound; early stopping may end training sooner\n", + " batch_size=2, # Use a small batch size for this example\n", + " lr=1e-5,\n", + " optimization_strategy=\"lora\",\n", + " metrics=[MeanAveragePrecisionMetric()],\n", + " # Early stopping configuration\n", + " early_stopping=True, # Enable early stopping\n", + " early_stopping_patience=3, # Stop after 3 epochs with no improvement\n", + " early_stopping_threshold=0.01, # Minimum change to count as an improvement\n", + " early_stopping_monitor=\"val_loss\", # Metric to monitor\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "705dc6a9", + "metadata": {}, + "source": [ + "## Training the Model with Early Stopping\n", + "\n", + "Now we'll start training the model. With early stopping enabled, training stops automatically once the validation loss fails to improve by at least the 0.01 threshold for 3 consecutive epochs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "84e11341", + "metadata": {}, + "outputs": [], + "source": [ + "# Train the model\n", + "train(config)" + ] + }, + { + "cell_type": "markdown", + "id": "438768a8", + "metadata": {}, + "source": [ + "## Visualizing Training Metrics\n", + "\n", + "After training completes, you can examine the training and validation metrics to see how early stopping worked." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8b259009", + "metadata": {}, + "outputs": [], + "source": [ + "import glob\n", + "\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "\n", + "# Find the most recent training run\n", + "runs = sorted(glob.glob(\"./training/florence_2/*\"))\n", + "latest_run = runs[-1] if runs else None\n", + "\n", + "if latest_run:\n", + " # Try to load the metrics\n", + " try:\n", + " metrics_dir = os.path.join(latest_run, \"metrics\")\n", + " train_loss = pd.read_csv(os.path.join(metrics_dir, \"train_loss.csv\"))\n", + " val_loss = pd.read_csv(os.path.join(metrics_dir, \"val_loss.csv\"))\n", + "\n", + " # Plot training and validation loss\n", + " plt.figure(figsize=(10, 5))\n", + " plt.plot(train_loss[\"epoch\"], train_loss[\"value\"], label=\"Training Loss\")\n", + " plt.plot(val_loss[\"epoch\"], val_loss[\"value\"], label=\"Validation Loss\")\n", + " plt.xlabel(\"Epoch\")\n", + " plt.ylabel(\"Loss\")\n", + " plt.legend()\n", + " plt.title(\"Training and Validation Loss (with Early Stopping)\")\n", + " plt.grid(True, linestyle=\"--\", alpha=0.7)\n", + " plt.show()\n", + "\n", + " # Show where early stopping occurred (map the row index back to its epoch)\n", + " best_epoch = val_loss.loc[val_loss[\"value\"].idxmin(), \"epoch\"]\n", + " print(f\"Best epoch: {best_epoch}\")\n", + " print(f\"Best validation loss: {val_loss['value'].min()}\")\n", + " print(f\"Training stopped at epoch: {val_loss['epoch'].max()}\")\n", + " except Exception as e:\n", + " print(f\"Could not load metrics: {e}\")\n", + "else:\n", + " print(\"No training runs found\")" + ] + }, + { + "cell_type": "markdown", + "id": "2b52ad9a", + "metadata": {}, + "source": [ + "## Conclusion\n", + "\n", + "In this notebook, we've demonstrated how to use early stopping with the Florence-2 model in Multimodal Maestro. 
The same approach can be applied to other models like PaliGemma-2 and Qwen2.5-VL.\n", + "\n", + "Early stopping is a valuable technique for efficient model training, as it:\n", + "\n", + "1. Saves training time and computational resources\n", + "2. Automatically determines the optimal number of training epochs\n", + "3. Helps prevent overfitting\n", + "\n", + "By adjusting the `early_stopping_patience`, `early_stopping_threshold`, and `early_stopping_monitor` parameters, you can fine-tune the early stopping behavior to suit your specific training needs." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.11.9" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/cookbooks/maestro_early_stopping_example.py b/cookbooks/maestro_early_stopping_example.py new file mode 100644 index 00000000..dd10273b --- /dev/null +++ b/cookbooks/maestro_early_stopping_example.py @@ -0,0 +1,64 @@ +""" +Example script demonstrating how to enable early stopping in Maestro models. +This helps prevent overfitting and reduces training time once model +performance on the validation set has stopped improving. +""" + +from maestro.trainer.models.florence_2.core import Florence2Configuration +from maestro.trainer.models.florence_2.core import train as train_florence +from maestro.trainer.models.paligemma_2.core import PaliGemma2Configuration +from maestro.trainer.models.paligemma_2.core import train as train_paligemma +from maestro.trainer.models.qwen_2_5_vl.core import Qwen25VLConfiguration +from maestro.trainer.models.qwen_2_5_vl.core import train as train_qwen + + +# Example with Florence-2 model +def train_florence_with_early_stopping(): + """Train a Florence-2 model with early stopping enabled""" + config = Florence2Configuration( + dataset="path/to/your/dataset", # Replace with your dataset path + epochs=20, # Upper bound; early stopping may end training sooner + early_stopping=True, # Enable early stopping + early_stopping_patience=3, # Stop after 3 epochs without improvement + early_stopping_threshold=0.01, # Minimum change to count as an improvement + early_stopping_monitor="val_loss", # Metric to monitor (default: val_loss) + ) + + train_florence(config) + + +# Example with PaliGemma-2 model +def train_paligemma_with_early_stopping(): + """Train a PaliGemma-2 model with early stopping enabled""" + config = PaliGemma2Configuration( + dataset="path/to/your/dataset", # Replace with your dataset path + epochs=20, # Upper bound; early stopping may end training sooner + early_stopping=True, # Enable early stopping + early_stopping_patience=5, # Stop after 5 epochs without improvement + early_stopping_threshold=0.001, # More sensitive to small improvements + early_stopping_monitor="val_loss", # Metric to monitor + ) + + train_paligemma(config) + + +# Example with Qwen2.5-VL model
def train_qwen_with_early_stopping(): + """Train a Qwen2.5-VL model with early stopping enabled""" + config = Qwen25VLConfiguration( + dataset="path/to/your/dataset", # Replace with your dataset path + epochs=20, # Upper bound; early stopping may end training sooner + early_stopping=True, # Enable early stopping + early_stopping_patience=3, # Stop after 3 epochs without improvement + early_stopping_threshold=0.01, # Minimum change to count as an improvement + early_stopping_monitor="val_loss", # Metric to monitor + ) + + train_qwen(config) + + +if __name__ == "__main__": + # Choose one of the training functions to run + 
train_florence_with_early_stopping() + # train_paligemma_with_early_stopping() + # train_qwen_with_early_stopping() diff --git a/maestro/trainer/common/callbacks.py b/maestro/trainer/common/callbacks.py index 60d869bf..ae1e8648 100644 --- a/maestro/trainer/common/callbacks.py +++ b/maestro/trainer/common/callbacks.py @@ -3,7 +3,7 @@ from typing import Callable import lightning -from lightning.pytorch.callbacks import Callback +from lightning.pytorch.callbacks import Callback, EarlyStopping from maestro.trainer.common.training import MaestroTrainer, TModel, TProcessor @@ -26,3 +26,36 @@ def on_train_epoch_end(self, trainer: lightning.Trainer, pl_module: MaestroTrain def on_train_end(self, trainer: lightning.Trainer, pl_module: MaestroTrainer): pass + + +class EarlyStoppingCallback(EarlyStopping): + """ + Early stopping callback for PyTorch Lightning trainers. + + A thin subclass of Lightning's EarlyStopping that stops training when a monitored metric has stopped improving. + + Args: + monitor (str): Quantity to be monitored. Default is 'val_loss'. + min_delta (float): Minimum change in monitored quantity to qualify as an improvement. + patience (int): Number of validation epochs with no improvement after which training will be stopped. + mode (str): One of 'min', 'max'. In 'min' mode, training will stop when the quantity monitored + has stopped decreasing; in 'max' mode it will stop when the quantity monitored + has stopped increasing. Default is 'min'. + verbose (bool): Whether to print progress messages. + """ + + def __init__( + self, + monitor: str = "val_loss", + min_delta: float = 0.0, + patience: int = 3, + verbose: bool = True, + mode: str = "min", + ): + super().__init__( + monitor=monitor, + min_delta=min_delta, + patience=patience, + verbose=verbose, + mode=mode, + ) diff --git a/maestro/trainer/models/florence_2/core.py b/maestro/trainer/models/florence_2/core.py index a3c61f5b..65fd0fc6 100644 --- a/maestro/trainer/models/florence_2/core.py +++ b/maestro/trainer/models/florence_2/core.py @@ -86,6 +86,23 @@ class Florence2Configuration: Random seed for ensuring reproducibility. If None, no seeding is applied. peft_advanced_params (Optional[dict]): Custom LoRA configuration . If None, default configuration is applied. + early_stopping (bool): + Whether to use early stopping. Default is False. + early_stopping_patience (int): + Number of epochs with no improvement after which training will be stopped. + Only applies if early_stopping is True. Default is 3. + early_stopping_threshold (float): + Minimum change in monitored quantity to qualify as an improvement. Default is 0.0. + early_stopping_monitor (str): + Quantity to be monitored for early stopping. Default is "val_loss". 
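+ + Example (illustrative; the dataset path is a placeholder): + >>> config = Florence2Configuration( + ... dataset="path/to/your/dataset", + ... early_stopping=True, + ... early_stopping_patience=3, + ... early_stopping_threshold=0.01, + ... )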
""" dataset: str @@ -106,6 +115,10 @@ class Florence2Configuration: max_new_tokens: int = 1024 random_seed: Optional[int] = None peft_advanced_params: Optional[dict] = None + early_stopping: bool = False + early_stopping_patience: int = 3 + early_stopping_threshold: float = 0.0 + early_stopping_monitor: str = "val_loss" def __post_init__(self): if self.val_batch_size is None: @@ -273,12 +286,29 @@ def train(config: Florence2Configuration | dict) -> None: ) save_checkpoints_path = os.path.join(config.output_dir, "checkpoints") save_checkpoint_callback = SaveCheckpoint(result_path=save_checkpoints_path, save_model_callback=save_model) + + callbacks = [save_checkpoint_callback] + + # Add early stopping if enabled + if config.early_stopping: + from maestro.trainer.common.callbacks import EarlyStoppingCallback + + early_stopping_callback = EarlyStoppingCallback( + monitor=config.early_stopping_monitor, + min_delta=config.early_stopping_threshold, + patience=config.early_stopping_patience, + verbose=True, + mode="min" if config.early_stopping_monitor == "val_loss" else "max", + ) + callbacks.append(early_stopping_callback) + logger.info(f"Early stopping enabled with patience {config.early_stopping_patience}") + trainer = lightning.Trainer( max_epochs=config.epochs, accumulate_grad_batches=config.accumulate_grad_batches, check_val_every_n_epoch=1, limit_val_batches=1, log_every_n_steps=10, - callbacks=[save_checkpoint_callback], + callbacks=callbacks, ) trainer.fit(pl_module) diff --git a/maestro/trainer/models/paligemma_2/core.py b/maestro/trainer/models/paligemma_2/core.py index e29e112f..6234923e 100644 --- a/maestro/trainer/models/paligemma_2/core.py +++ b/maestro/trainer/models/paligemma_2/core.py @@ -75,6 +75,15 @@ class PaliGemma2Configuration: Random seed for ensuring reproducibility. If None, no seeding is applied. peft_advanced_params (Optional[dict]): Custom LoRA configuration . If None, default configuration is applied. + early_stopping (bool): + Whether to use early stopping. Default is False. + early_stopping_patience (int): + Number of epochs with no improvement after which training will be stopped. + Only applies if early_stopping is True. Default is 3. + early_stopping_threshold (float): + Minimum change in monitored quantity to qualify as improvement. Default is 0.0. + early_stopping_monitor (str): + Quantity to be monitored for early stopping. Default is "val_loss". 
""" dataset: str @@ -95,6 +104,10 @@ class PaliGemma2Configuration: max_new_tokens: int = 512 random_seed: Optional[int] = None peft_advanced_params: Optional[dict] = None + early_stopping: bool = False + early_stopping_patience: int = 3 + early_stopping_threshold: float = 0.0 + early_stopping_monitor: str = "val_loss" def __post_init__(self): if self.val_batch_size is None: @@ -217,6 +230,7 @@ def train(config: PaliGemma2Configuration | dict) -> None: dataset_location = resolve_dataset_path(config.dataset) if dataset_location is None: return + train_loader, valid_loader, test_loader = create_data_loaders( dataset_location=dataset_location, train_batch_size=config.batch_size, @@ -236,12 +250,29 @@ def train(config: PaliGemma2Configuration | dict) -> None: ) save_checkpoints_path = os.path.join(config.output_dir, "checkpoints") save_checkpoint_callback = SaveCheckpoint(result_path=save_checkpoints_path, save_model_callback=save_model) + + callbacks = [save_checkpoint_callback] + + # Add early stopping if enabled + if config.early_stopping: + from maestro.trainer.common.callbacks import EarlyStoppingCallback + + early_stopping_callback = EarlyStoppingCallback( + monitor=config.early_stopping_monitor, + min_delta=config.early_stopping_threshold, + patience=config.early_stopping_patience, + verbose=True, + mode="min" if config.early_stopping_monitor == "val_loss" else "max", + ) + callbacks.append(early_stopping_callback) + logger.info(f"Early stopping enabled with patience {config.early_stopping_patience}") + trainer = lightning.Trainer( max_epochs=config.epochs, accumulate_grad_batches=config.accumulate_grad_batches, check_val_every_n_epoch=1, limit_val_batches=1, log_every_n_steps=10, - callbacks=[save_checkpoint_callback], + callbacks=callbacks, ) trainer.fit(pl_module) diff --git a/maestro/trainer/models/qwen_2_5_vl/core.py b/maestro/trainer/models/qwen_2_5_vl/core.py index 6918a4a7..0cea239d 100644 --- a/maestro/trainer/models/qwen_2_5_vl/core.py +++ b/maestro/trainer/models/qwen_2_5_vl/core.py @@ -67,6 +67,10 @@ class Qwen25VLConfiguration: max_new_tokens (int): Maximum number of new tokens generated during inference. random_seed (Optional[int]): Random seed for ensuring reproducibility. If `None`, no seed is set. peft_advanced_params (Optional[dict]): Custom LoRA configuration . If None, default configuration is applied. + early_stopping (bool): Whether to use early stopping. Default is False. + early_stopping_patience (int): Number of epochs with no improvement after which training will be stopped. + early_stopping_threshold (float): Minimum change in monitored quantity to qualify as improvement. + early_stopping_monitor (str): Quantity to be monitored for early stopping. 
""" dataset: str @@ -89,6 +93,10 @@ class Qwen25VLConfiguration: max_new_tokens: int = 1024 random_seed: Optional[int] = None peft_advanced_params: Optional[dict] = None + early_stopping: bool = False + early_stopping_patience: int = 3 + early_stopping_threshold: float = 0.0 + early_stopping_monitor: str = "val_loss" def __post_init__(self): if self.val_batch_size is None: @@ -261,6 +269,7 @@ def train(config: Qwen25VLConfiguration | dict) -> None: dataset_location = resolve_dataset_path(config.dataset) if dataset_location is None: return + train_loader, valid_loader, test_loader = create_data_loaders( dataset_location=dataset_location, train_batch_size=config.batch_size, @@ -284,12 +293,29 @@ def train(config: Qwen25VLConfiguration | dict) -> None: ) save_checkpoints_path = os.path.join(config.output_dir, "checkpoints") save_checkpoint_callback = SaveCheckpoint(result_path=save_checkpoints_path, save_model_callback=save_model) + + callbacks = [save_checkpoint_callback] + + # Add early stopping if enabled + if config.early_stopping: + from maestro.trainer.common.callbacks import EarlyStoppingCallback + + early_stopping_callback = EarlyStoppingCallback( + monitor=config.early_stopping_monitor, + min_delta=config.early_stopping_threshold, + patience=config.early_stopping_patience, + verbose=True, + mode="min" if config.early_stopping_monitor == "val_loss" else "max", + ) + callbacks.append(early_stopping_callback) + logger.info(f"Early stopping enabled with patience {config.early_stopping_patience}") + trainer = lightning.Trainer( max_epochs=config.epochs, accumulate_grad_batches=config.accumulate_grad_batches, check_val_every_n_epoch=1, limit_val_batches=1, log_every_n_steps=10, - callbacks=[save_checkpoint_callback], + callbacks=callbacks, ) trainer.fit(pl_module)