Add OPSD (On-Policy Distillation) training example #1002
Changes from 2 commits
14651ed
6b8f584
7580c28
0e1c004
d0000be
d3eda20
8bf134e
5d1d2e5
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,232 @@ | ||
| # On-Policy Distillation (OPSD) on DeepSpeed | ||
|
|
||
| A DeepSpeed-native port of [HJSang/OPSD_OnPolicyDistillation](https://github.com/HJSang/OPSD_OnPolicyDistillation), | ||
| removing the verl dependency and building directly on DeepSpeed primitives | ||
| (ZeRO-3, hybrid engine, `deepspeed.initialize`). | ||
|
|
||
| On-policy distillation trains a small **student** model to imitate a large | ||
| frozen **teacher** on the student's *own* generated rollouts. Each training | ||
| step has three phases: | ||
|
|
||
| ``` | ||
| ┌────────────┐ prompts ┌──────────────────┐ prompt+response ┌────────────┐ | ||
| │ Dataloader │ ──────────▶ │ Student rollout │ ──────────────────▶ │ Teacher │ | ||
| └────────────┘ │ (hybrid / vLLM) │ │ forward │ | ||
| └──────────────────┘ └─────┬──────┘ | ||
| │ logits → CPU cache | ||
| ▼ | ||
| ┌─────────────────────┐ | ||
| │ Student forward + │ | ||
| │ streamed KL / JSD + │ | ||
| │ backward / step │ | ||
| └─────────────────────┘ | ||
| ``` | ||
|
|
||
| Loss = per-token divergence (`forward_kl` | `reverse_kl` | `jsd`) between | ||
| student and teacher distributions on the student's generated tokens, chunked | ||
| over the sequence axis so the full `[B, T, V]` teacher tensor never | ||
| co-resides with the student logits on the training device. | ||
|
|
||
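The chunking idea can be illustrated without any tensor library. The sketch below is not the code in `opsd/losses.py`; it is a dependency-free illustration that uses plain Python lists in place of `[T, V]` tensors. The function names are hypothetical, and it assumes the usual distillation convention that `forward_kl` means KL(teacher || student) and `reverse_kl` means KL(student || teacher).

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable log-softmax over one token's vocabulary logits."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    log_z = peak + math.log(sum(math.exp(x - peak) for x in scaled))
    return [x - log_z for x in scaled]


def kl(log_p: Sequence[float], log_q: Sequence[float]) -> float:
    """KL(p || q) for a single token, given log-probabilities."""
    return sum(math.exp(lp) * (lp - lq) for lp, lq in zip(log_p, log_q))


def token_divergence(student, teacher, loss_type="reverse_kl", temperature=1.0) -> float:
    log_s = log_softmax(student, temperature)
    log_t = log_softmax(teacher, temperature)
    if loss_type == "forward_kl":  # assumed convention: KL(teacher || student)
        return kl(log_t, log_s)
    if loss_type == "reverse_kl":  # assumed convention: KL(student || teacher)
        return kl(log_s, log_t)
    if loss_type == "jsd":  # symmetric: mean of both KLs against the mixture
        log_m = [math.log(0.5 * (math.exp(a) + math.exp(b))) for a, b in zip(log_s, log_t)]
        return 0.5 * kl(log_s, log_m) + 0.5 * kl(log_t, log_m)
    raise ValueError(f"unknown loss_type: {loss_type}")


def streamed_divergence(student_seq, teacher_seq, response_mask,
                        chunk_size=512, loss_type="reverse_kl") -> float:
    """Mean per-token divergence over masked positions, one sequence slice at a time."""
    total, count = 0.0, 0
    for start in range(0, len(student_seq), chunk_size):
        stop = start + chunk_size
        # In a real run only this slice of the teacher logits would be moved
        # to the training device; the rest stays in the host-side cache.
        teacher_chunk = teacher_seq[start:stop]
        for s, t, m in zip(student_seq[start:stop], teacher_chunk, response_mask[start:stop]):
            if m:  # score generated (response) tokens only
                total += token_divergence(s, t, loss_type)
                count += 1
    return total / max(count, 1)
```

The result does not depend on `chunk_size`; chunking only bounds how much of the teacher's output is resident at once.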
| ## Layout | ||
|
|
||
| ``` | ||
| examples/opsd/ | ||
| ├── main.py # entry point (deepspeed launcher) | ||
| ├── opsd/ | ||
| │ ├── config.py # OPSDConfig dataclass + JSON loader | ||
| │ ├── losses.py # chunked / streamed KL & JSD | ||
| │ ├── teacher.py # frozen teacher + CPU logit cache | ||
| │ ├── trainer.py # three-phase training loop | ||
| │ ├── data.py # JSONL prompt dataset + left-pad collator | ||
| │ ├── utils.py # response-mask + shift helpers | ||
| │ └── rollout/ | ||
| │ ├── base.py # RolloutEngine ABC, request/batch dataclasses | ||
| │ ├── hybrid_engine.py # DeepSpeed hybrid-engine rollout | ||
| │ └── vllm.py # vLLM rollout on disjoint GPUs | ||
| ├── configs/ | ||
| │ ├── ds_zero3.json # base DeepSpeed ZeRO-3 + hybrid engine | ||
| │ ├── opsd_hybrid_engine.json # production-ish hybrid-engine OPSD config | ||
| │ ├── opsd_vllm_disjoint.json # vLLM rollout on a disjoint GPU group | ||
| │ ├── smoke_hybrid.json # 5-step smoke test with Qwen2.5-0.5B / 1.5B | ||
| │ ├── smoke_vllm.json # same but with vLLM rollout | ||
| │ └── smoke_ds_zero3.json # ZeRO-3 config tuned for smoke runs | ||
| ├── scripts/ | ||
| │ ├── train_opsd_hybrid.sh # launch hybrid-engine training | ||
| │ └── train_opsd_vllm.sh # launch vLLM training | ||
| └── tests/ # CPU-only unit tests (run with pytest) | ||
| ``` | ||
|
|
||
| ## Quick start | ||
|
|
||
| ### Install | ||
|
|
||
| ``` | ||
| pip install deepspeed transformers datasets accelerate | ||
| # Optional, only for the vLLM rollout backend: | ||
| pip install 'vllm>=0.6.4' | ||
| ``` | ||
|
|
||
| ### Hybrid-engine training (single-node, no vLLM) | ||
|
|
||
| ``` | ||
| cd examples/opsd | ||
| NUM_GPUS=8 bash scripts/train_opsd_hybrid.sh configs/opsd_hybrid_engine.json | ||
| ``` | ||
|
|
||
| The hybrid-engine path lives entirely within DeepSpeed: the student engine | ||
| both trains and generates, sharing weights without a copy step. It is the | ||
| easiest path to get running, but generation is slower than with vLLM. | ||
|
|
||
| ### vLLM training (disjoint GPU group) | ||
|
|
||
| ``` | ||
| cd examples/opsd | ||
| # Train on GPUs 0..5, run vLLM on 6,7 (matches default config) | ||
| NUM_TRAIN_GPUS=6 INCLUDE_GPUS=0,1,2,3,4,5 \ | ||
| bash scripts/train_opsd_vllm.sh configs/opsd_vllm_disjoint.json | ||
| ``` | ||
|
|
||
| vLLM gets dedicated GPUs via the `ROLLOUT_VISIBLE_DEVICE` environment | ||
| variable (comma-separated CUDA device indices, e.g. | ||
| `ROLLOUT_VISIBLE_DEVICE=6,7`). Training rank 0 spawns the vLLM server as | ||
| a subprocess with `CUDA_VISIBLE_DEVICES` set to those devices; other | ||
| training ranks receive generated token ids via NCCL broadcast. | ||
|
|
||
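As a rough illustration of the device split described above, the sketch below builds the environment for a rollout subprocess from `ROLLOUT_VISIBLE_DEVICE`. It is not the code in `opsd/rollout/vllm.py`: the helper names are hypothetical, and the overlap check against the training GPUs is an added assumption about what a careful launcher might want to verify.

```python
from typing import Dict, List, Mapping


def parse_device_list(value: str) -> List[int]:
    """Parse a comma-separated list of CUDA device indices, e.g. "6,7"."""
    devices = [int(tok) for tok in value.split(",") if tok.strip()]
    if not devices:
        raise ValueError("ROLLOUT_VISIBLE_DEVICE lists no devices")
    if len(set(devices)) != len(devices):
        raise ValueError("ROLLOUT_VISIBLE_DEVICE contains duplicates")
    return devices


def build_rollout_env(env: Mapping[str, str], train_gpus: List[int]) -> Dict[str, str]:
    """Return a copy of `env` whose CUDA_VISIBLE_DEVICES points at the rollout GPUs."""
    raw = env.get("ROLLOUT_VISIBLE_DEVICE")
    if raw is None:
        raise KeyError("ROLLOUT_VISIBLE_DEVICE is not set")
    rollout_gpus = parse_device_list(raw)
    overlap = sorted(set(rollout_gpus) & set(train_gpus))
    if overlap:
        # Assumption: a disjoint topology means no GPU serves both roles.
        raise ValueError(f"GPUs used for both training and rollout: {overlap}")
    child_env = dict(env)  # leave the trainer's own environment untouched
    child_env["CUDA_VISIBLE_DEVICES"] = ",".join(str(d) for d in rollout_gpus)
    return child_env
```

The returned mapping can be passed as `env=` to `subprocess.Popen`, so only the child process sees the rollout GPUs.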
| ### Smoke tests (5 steps, small models) | ||
|
|
||
| The `smoke_*.json` configs run in a few minutes with Qwen2.5-0.5B (student) | ||
| and Qwen2.5-1.5B (teacher) on 2 training GPUs (the vLLM variant needs 2 more | ||
| for rollout), so the full pipeline can be validated end-to-end before scaling up. | ||
|
|
||
| ``` | ||
| cd examples/opsd | ||
| deepspeed --num_gpus 2 main.py --config configs/smoke_hybrid.json | ||
| # For vLLM (uses GPUs 0,1 for training and 2,3 for vLLM): | ||
| NUM_TRAIN_GPUS=2 INCLUDE_GPUS=0,1 ROLLOUT_VISIBLE_DEVICE=2,3 \ | ||
| deepspeed --num_gpus 2 --include localhost:0,1 \ | ||
| main.py --config configs/smoke_vllm.json | ||
| ``` | ||
|
|
||
| ## Unit tests | ||
|
|
||
| The CPU-runnable test suite exercises the loss math, teacher caching, rollout | ||
| contract, and vLLM stitch logic. Run with: | ||
|
|
||
| ``` | ||
| cd examples/opsd | ||
| python -m pytest tests/ -v | ||
| ``` | ||
|
|
||
| ## Configuration | ||
|
|
||
| `OPSDConfig` is a plain dataclass loaded from JSON (no Hydra). The schema: | ||
|
|
||
| ```json | ||
| { | ||
| "student": { "model_name_or_path": "...", "dtype": "bfloat16", "arch": "qwen2" }, | ||
| "teacher": { "model_name_or_path": "...", "dtype": "bfloat16", "offload_to_cpu": true }, | ||
| "rollout": { "engine": "hybrid_engine | vllm", ... }, | ||
| "distillation": { "loss_type": "reverse_kl", "temperature": 1.0, "chunk_size": 512 }, | ||
| "training": { "train_batch_size": 8, "learning_rate": 1e-6, ... }, | ||
| "data": { "path": "data/prompts.jsonl", "prompt_field": "prompt" }, | ||
| "deepspeed_config": "configs/ds_zero3.json" | ||
| } | ||
| ``` | ||
|
|
||
| See `configs/opsd_hybrid_engine.json` and `configs/opsd_vllm_disjoint.json` | ||
| for fully-populated examples. | ||
|
|
||
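For readers unfamiliar with the dataclass-plus-JSON pattern, here is a minimal sketch of how such a loader might look. It is not the contents of `opsd/config.py`: the class and function names are hypothetical, only a subset of the schema is modelled, and the validation rules (a positive temperature, a known `loss_type`, a known rollout `engine`) are assumptions based on the values shown above.

```python
import json
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass
class DistillationSection:
    loss_type: str = "reverse_kl"
    temperature: float = 1.0
    chunk_size: int = 512

    def __post_init__(self) -> None:
        if self.loss_type not in {"forward_kl", "reverse_kl", "jsd"}:
            raise ValueError(f"unknown loss_type: {self.loss_type}")
        if self.temperature <= 0:
            # Assumption: logits are divided by the temperature, so it must be > 0.
            raise ValueError("distillation.temperature must be positive")
        if self.chunk_size <= 0:
            raise ValueError("distillation.chunk_size must be positive")


@dataclass
class MiniOPSDConfig:
    student: Dict[str, Any]
    teacher: Dict[str, Any]
    rollout: Dict[str, Any]
    distillation: DistillationSection = field(default_factory=DistillationSection)
    deepspeed_config: str = "configs/ds_zero3.json"


def load_config_text(text: str) -> MiniOPSDConfig:
    """Parse a JSON document into the sketch config, failing fast on bad values."""
    raw = json.loads(text)  # strict JSON: trailing commas raise json.JSONDecodeError
    engine = raw["rollout"].get("engine")
    if engine not in {"hybrid_engine", "vllm"}:
        raise ValueError(f"unknown rollout.engine: {engine}")
    return MiniOPSDConfig(
        student=raw["student"],
        teacher=raw["teacher"],
        rollout=raw["rollout"],
        distillation=DistillationSection(**raw.get("distillation", {})),
        deepspeed_config=raw.get("deepspeed_config", "configs/ds_zero3.json"),
    )
```

Validating in `__post_init__` surfaces a bad value at load time rather than as a NaN loss several minutes into a run.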
| **GPU placement for vLLM rollout:** The GPUs available to the vLLM server | ||
| are controlled by the `ROLLOUT_VISIBLE_DEVICE` environment variable | ||
| (comma-separated CUDA device indices, e.g. `ROLLOUT_VISIBLE_DEVICE=6,7`), | ||
| not by a field in the JSON config. This keeps the vLLM device assignment | ||
| decoupled from the DeepSpeed launcher's own `CUDA_VISIBLE_DEVICES` / | ||
| `--include` flags, which control only the training ranks. | ||
|
|
||
| ## Adding a new model architecture | ||
|
|
||
| No special steps are needed for new model architectures. vLLM's RLHF weight | ||
| transfer API handles TP slicing internally; the caller only needs to send full | ||
| tensors. | ||
|
|
||
| ## Design notes | ||
|
|
||
| * **Why CPU-cache the teacher logits?** Holding both student and teacher | ||
| `[B, T, V]` tensors on GPU at once doubles memory pressure. Staging the | ||
| teacher to host between the teacher forward and the student backward halves | ||
| the worst-case GPU footprint of the loss path. The streamed loss | ||
| (`losses.streamed_distillation_loss`) pulls teacher chunks back to GPU | ||
| one sequence slice at a time so the full tensor never re-materialises. | ||
|
|
||
| * **Why an abstract `RolloutEngine`?** The hybrid-engine and vLLM backends | ||
| have very different lifecycles (hybrid engine reads student weights live; | ||
| vLLM holds its own copy and must be synced) but the trainer should not | ||
| care. The ABC keeps the trainer engine-agnostic so additional backends | ||
| (e.g. a future colocated-vLLM-with-`sleep_mode`) drop in without touching | ||
| the loop. | ||
|
|
||
| * **vLLM topology = disjoint, not colocated (v1).** The disjoint topology is | ||
| simpler to debug — failures in vLLM don't take down training and vice | ||
| versa. A colocated topology using vLLM 0.6.4+'s `sleep_mode` is planned as | ||
| a follow-up. | ||
|
|
||
| * **Weight sync uses vLLM's RLHF API.** vLLM 0.22.0+ exposes | ||
| `/update_weights` which handles TP slicing internally. The trainer | ||
| sends full tensors and vLLM distributes them. | ||
|
|
||
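The engine-agnostic loop from the design notes above can be sketched as follows. This is not the interface in `opsd/rollout/base.py`: every class, method, and field name here is hypothetical, and the toy backends only echo prompt tokens so that the example runs without a model.

```python
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import List


@dataclass
class SketchRolloutBatch:
    prompt_ids: List[List[int]]
    response_ids: List[List[int]]


class SketchRolloutEngine(ABC):
    """What the training loop needs from any rollout backend."""

    @abstractmethod
    def generate(self, prompt_ids: List[List[int]], max_new_tokens: int) -> SketchRolloutBatch:
        ...

    def sync_weights(self, step: int) -> None:
        """No-op by default: a backend that reads student weights live has nothing to sync."""


class LiveWeightsRollout(SketchRolloutEngine):
    """Toy stand-in for a backend that shares weights with the trainer."""

    def generate(self, prompt_ids, max_new_tokens):
        # Echo the tail of each prompt as a fake "response".
        return SketchRolloutBatch(prompt_ids, [p[-max_new_tokens:] for p in prompt_ids])


class CopiedWeightsRollout(LiveWeightsRollout):
    """Toy stand-in for a backend that holds its own copy and must be synced."""

    def __init__(self, weight_sync_interval: int) -> None:
        self.weight_sync_interval = weight_sync_interval
        self.synced_at: List[int] = []

    def sync_weights(self, step: int) -> None:
        if step % self.weight_sync_interval == 0:
            self.synced_at.append(step)  # a real backend would push tensors here


def run_steps(engine: SketchRolloutEngine, prompts: List[List[int]], num_steps: int) -> List[int]:
    """The loop touches only the abstract interface, never a concrete backend."""
    response_tokens = []
    for step in range(num_steps):
        engine.sync_weights(step)
        batch = engine.generate(prompts, max_new_tokens=2)
        response_tokens.append(sum(len(r) for r in batch.response_ids))
    return response_tokens
```

Because `run_steps` never inspects the concrete type, a new backend only has to implement `generate` and, if it keeps its own weights, override `sync_weights`.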
| ## vLLM status | ||
|
|
||
| The vLLM rollout (`opsd/rollout/vllm.py`) is **written and unit-tested but | ||
| not yet usable under the DeepSpeed launcher**. During live validation on | ||
| 4× H200 we hit a blocking issue: | ||
|
|
||
| > vLLM's worker init calls `new_group(...)` on the global process group as | ||
| > a collective. Under `deepspeed --num_gpus N`, the world is all `N` | ||
| > training ranks but only rank 0 calls into vLLM, so the constructor hangs | ||
| > waiting on the other ranks. Reproduced with vllm 0.6.6 + deepspeed 0.15.4 + | ||
| > torch 2.5.1. Standalone vLLM (world size 1) works in seconds. | ||
|
|
||
| The fix requires running vLLM in a **separate top-level Python process** | ||
| with its own world, accessed over HTTP/RPC from the trainer — the pattern | ||
| used by TRL and OpenRLHF. That's a larger refactor than fits in this PR; | ||
| the current `VLLMRollout` will be the basis for it once landed. | ||
|
|
||
| What's verified for the vLLM path today: | ||
| * `tests/test_vllm_stitch.py` — prompt + response stitching (CPU unit test) | ||
| * `vllm.LLM` itself runs fine standalone on Qwen2.5-0.5B (validated) | ||
|
|
||
| What's **not** verified: | ||
| * End-to-end training loop with `rollout.engine = "vllm"` in `OPSDConfig` | ||
| * `LLM.collective_rpc("load_weights", ...)` weight sync at training time | ||
|
|
||
| The hybrid-engine path (`rollout.engine = "hybrid_engine"`) is validated | ||
| end-to-end on the same hardware. | ||
|
|
||
| ## Other known limitations (v1) | ||
|
|
||
| * **vLLM weight sync (when it works) goes through pickle** — | ||
| `LLM.collective_rpc("load_weights", args=((name, tensor_on_cpu),))`. | ||
| Expect several seconds per sync on a 7B model. A faster v2 would broadcast | ||
| tensors via NCCL on a shared trainer↔vLLM process group — see verl's | ||
| `bucketed_weight_transfer.py` for a reference design. | ||
| * **vLLM `tensor_parallel_size > 1` is untested.** The weight bridge's | ||
| slicing math is unit-tested but no live run exists. | ||
| * **Reward-weighted distillation** (OPSD's `opd.reward_beta` knob) is not | ||
| ported. Easy to add: scale `per_tok` by a reward weight in the loss path. | ||
| * **GRPO and other on-policy RL recipes** are out of scope. The | ||
| `RolloutEngine` / `WeightBridge` abstractions are reusable, but a GRPO | ||
| trainer would add its own advantage / KL-to-reference logic on top. | ||
| * **Qwen3-MoE** is not covered. Add `weight_bridge/qwen3_moe.py` when needed. | ||
| * **Hybrid engine on Qwen-family models uses a ZeRO-3 fallback** (no | ||
| hybrid-engine inference acceleration), since DeepSpeed's inference policy | ||
| list only covers GPT2/GPT-NeoX/OPT/BLOOM/LLAMA/LLAMA2/InternLM as of 0.15. | ||
| The fallback gathers params via `GatheredParameters` and calls the HF | ||
| model's `generate` directly — correct, just ~3-5x slower than the | ||
| accelerated path. | ||
|
|
||
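The reward-weighted variant mentioned in the list above amounts to scaling each sequence's per-token divergences before the masked mean. The sketch below shows only that scaling step on plain Python lists. The function name is hypothetical, and how a weight would be derived from a reward and `opd.reward_beta` is not specified here, so the weights are taken as given.

```python
from typing import Sequence


def reward_weighted_mean(
    per_tok: Sequence[Sequence[float]],      # per_tok[b][t]: divergence at token t of sequence b
    response_mask: Sequence[Sequence[int]],  # 1 for generated tokens, 0 for prompt or padding
    reward_weight: Sequence[float],          # one weight per sequence (assumed precomputed)
) -> float:
    """Masked mean of per-token divergences, each scaled by its sequence's weight."""
    total, count = 0.0, 0
    for toks, mask, weight in zip(per_tok, response_mask, reward_weight):
        for divergence, keep in zip(toks, mask):
            if keep:
                total += weight * divergence
                count += 1
    return total / max(count, 1)
```

With all weights equal to 1.0 this reduces to the unweighted masked mean, which makes the unweighted loss a convenient regression check.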
| ## References | ||
|
|
||
| * OPSD reference repo: <https://github.com/HJSang/OPSD_OnPolicyDistillation> | ||
| * DeepSpeed hybrid engine: `deepspeed/runtime/hybrid_engine.py` | ||
| * verl rollout / weight-sync design (used as a cross-check): | ||
| <https://github.com/volcengine/verl/tree/main/verl/workers/rollout/vllm_rollout> | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| { | ||
| "bf16": { | ||
| "enabled": true | ||
| }, | ||
| "zero_optimization": { | ||
| "stage": 3, | ||
| "overlap_comm": true, | ||
| "contiguous_gradients": true, | ||
| "reduce_bucket_size": 5e7, | ||
| "stage3_prefetch_bucket_size": 5e7, | ||
| "stage3_param_persistence_threshold": 1e6, | ||
| "stage3_max_live_parameters": 1e9, | ||
| "stage3_max_reuse_distance": 1e9, | ||
| "stage3_gather_16bit_weights_on_model_save": true | ||
| }, | ||
| "optimizer": { | ||
| "type": "AdamW", | ||
| "params": { | ||
| "lr": 1e-6, | ||
| "betas": [0.9, 0.95], | ||
| "eps": 1e-8, | ||
| "weight_decay": 0.0 | ||
| } | ||
| }, | ||
| "scheduler": { | ||
| "type": "WarmupLR", | ||
| "params": { | ||
| "warmup_min_lr": 0, | ||
| "warmup_max_lr": 1e-6, | ||
| "warmup_num_steps": 0 | ||
| } | ||
| }, | ||
| "gradient_clipping": 1.0, | ||
| "hybrid_engine": { | ||
| "enabled": true, | ||
| "max_out_tokens": 2048, | ||
| "inference_tp_size": 1, | ||
| "release_inference_cache": false, | ||
| "pin_parameters": true, | ||
| "tp_gather_partition_size": 8 | ||
| }, | ||
| "wall_clock_breakdown": false | ||
| } |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,48 @@ | ||
| { | ||
| "student": { | ||
| "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct", | ||
| "dtype": "bfloat16", | ||
| "trust_remote_code": false | ||
| }, | ||
| "teacher": { | ||
| "model_name_or_path": "Qwen/Qwen2.5-Math-7B-Instruct", | ||
| "dtype": "bfloat16", | ||
| "trust_remote_code": false, | ||
| "offload_to_cpu": true | ||
|
Contributor
Can
Contributor
Author
The offload flag decides whether the teacher needs to be offloaded to free up memory, so it belongs under the teacher configuration. |
||
| }, | ||
| "rollout": { | ||
| "engine": "hybrid_engine", | ||
| "max_prompt_length": 1024, | ||
| "max_response_length": 1024, | ||
| "temperature": 0, | ||
| "top_p": 1.0, | ||
| "top_k": -1, | ||
| "n_samples_per_prompt": 1, | ||
| "weight_sync_interval": 1 | ||
| }, | ||
| "distillation": { | ||
| "loss_type": "reverse_kl", | ||
| "temperature": 1.0, | ||
| "chunk_size": 512 | ||
| }, | ||
| "training": { | ||
| "train_batch_size": 1, | ||
| "micro_batch_size_per_gpu": 1, | ||
| "gradient_accumulation_steps": 1, | ||
| "learning_rate": 1e-6, | ||
| "weight_decay": 0.0, | ||
| "num_train_epochs": 1, | ||
| "max_steps": -1, | ||
| "warmup_steps": 0, | ||
| "save_steps": 500, | ||
| "logging_steps": 10, | ||
| "save_dir": "./opsd_ckpt_hybrid", | ||
| "seed": 42 | ||
| }, | ||
| "data": { | ||
| "path": "data/prompts.jsonl", | ||
| "prompt_field": "prompt", | ||
| "shuffle": true | ||
| }, | ||
| "deepspeed_config": "configs/ds_zero3.json" | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,53 @@ | ||
| { | ||
| "student": { | ||
| "model_name_or_path": "Qwen/Qwen2.5-0.5B-Instruct", | ||
| "dtype": "bfloat16", | ||
| "trust_remote_code": false | ||
| }, | ||
| "teacher": { | ||
| "model_name_or_path": "Qwen/Qwen2.5-Math-7B-Instruct", | ||
| "dtype": "bfloat16", | ||
| "trust_remote_code": false, | ||
| "offload_to_cpu": true | ||
| }, | ||
| "rollout": { | ||
| "engine": "vllm", | ||
| "max_prompt_length": 1024, | ||
| "max_response_length": 1024, | ||
| "temperature": 0, | ||
| "top_p": 1.0, | ||
| "top_k": -1, | ||
| "n_samples_per_prompt": 1, | ||
| "tensor_parallel_size": 2, | ||
| "gpu_memory_utilization": 0.85, | ||
| "engine_dtype": "bfloat16", | ||
| "weight_sync_interval": 4, | ||
| "vllm_min_version": "0.6.4", | ||
| "vllm_port": 8000 | ||
| }, | ||
| "distillation": { | ||
| "loss_type": "reverse_kl", | ||
| "temperature": 1.0, | ||
| "chunk_size": 512 | ||
| }, | ||
| "training": { | ||
| "train_batch_size": 1, | ||
| "micro_batch_size_per_gpu": 1, | ||
| "gradient_accumulation_steps": 1, | ||
| "learning_rate": 1e-6, | ||
| "weight_decay": 0.0, | ||
| "num_train_epochs": 1, | ||
| "max_steps": -1, | ||
| "warmup_steps": 0, | ||
| "save_steps": 500, | ||
| "logging_steps": 10, | ||
| "save_dir": "./opsd_ckpt_vllm", | ||
| "seed": 42 | ||
| }, | ||
| "data": { | ||
| "path": "data/prompts.jsonl", | ||
| "prompt_field": "prompt", | ||
| "shuffle": true | ||
| }, | ||
| "deepspeed_config": "configs/ds_zero3.json" | ||
| } |
It seems this folder structure exists in DS not DSE, right?