Skip to content

Commit 3ad7d06

Browse files
committed
feat: Add Qwen3.5 Megatron smoke runners
1 parent d712ce9 commit 3ad7d06

2 files changed

Lines changed: 439 additions & 51 deletions

File tree

Lines changed: 150 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,150 @@
1+
"""Launch a multi-step Qwen3.5 Megatron yes-no-maybe run on SkyPilot."""
2+
3+
import argparse
4+
import os
5+
import textwrap
6+
7+
from dotenv import load_dotenv
8+
import sky
9+
from sky import ClusterStatus
10+
11+
# Pull local configuration (e.g. SkyPilot / registry credentials) from .env
# before anything else reads the environment.
load_dotenv()

# Bare CUDA devel image used unless the caller overrides it via --image-id.
DEFAULT_IMAGE_ID = "docker:nvidia/cuda:12.8.1-devel-ubuntu22.04"
14+
15+
16+
def _format_env_bool(value: bool) -> str:
17+
return "true" if value else "false"
18+
19+
20+
def _format_int_list(values: list[int]) -> str:
21+
return ",".join(str(value) for value in values)
22+
23+
24+
# Command-line interface. Defaults target a 2-GPU node: one GPU for the
# Megatron trainer and one for inference (see --trainer/inference-gpu-ids).
cli = argparse.ArgumentParser(
    description="Launch a Qwen3.5 Megatron yes-no-maybe convergence run."
)
# Cluster / scheduling options.
cli.add_argument("--fast", action="store_true")
cli.add_argument("--accelerator", type=str, default="H200:2")
cli.add_argument("--cluster-name", type=str, default="art-qwen35-megatron-yes-no-maybe")
cli.add_argument("--image-id", type=str, default=DEFAULT_IMAGE_ID)
# Model / run identity.
cli.add_argument("--base-model", type=str, default="Qwen/Qwen3.5-35B-A3B")
cli.add_argument("--project", type=str, default="qwen35-megatron-ynm")
# Inference-engine sizing.
cli.add_argument("--gpu-memory-utilization", type=float, default=0.65)
cli.add_argument("--max-model-len", type=int, default=1024)
cli.add_argument("--max-seq-length", type=int, default=1024)
cli.add_argument("--max-num-seqs", type=int, default=8)
# Training-loop sizing.
cli.add_argument("--num-steps", type=int, default=10)
cli.add_argument("--rollouts-per-prompt", type=int, default=8)
cli.add_argument("--eval-prompts", type=int, default=24)
cli.add_argument("--max-tokens", type=int, default=5)
cli.add_argument("--learning-rate", type=float, default=5e-5)
# Quantization toggles (--load-in-4bit / --no-load-in-4bit, etc.).
cli.add_argument("--load-in-4bit", action=argparse.BooleanOptionalAction, default=False)
cli.add_argument("--load-in-16bit", action=argparse.BooleanOptionalAction, default=True)
# GPU placement on the remote node.
cli.add_argument("--trainer-gpu-ids", type=int, nargs="+", default=[0])
cli.add_argument("--inference-gpu-ids", type=int, nargs="+", default=[1])
args = cli.parse_args()
53+
54+
# Allow a shared CLUSTER_PREFIX env var (e.g. a username) to namespace the
# cluster so concurrent users don't collide.
cluster_name = args.cluster_name
if prefix := os.environ.get("CLUSTER_PREFIX"):
    cluster_name = f"{prefix}-{cluster_name}"
58+
59+
# One-time provisioning commands for the bare CUDA image: system Python,
# git/curl, and the uv package manager.
setup_script = textwrap.dedent("""\
    echo 'Setting up environment...'
    apt-get update
    apt-get install -y python3 python3-pip python-is-python3 git curl
    curl -LsSf https://astral.sh/uv/install.sh | sh
    source $HOME/.local/bin/env
    """)
66+
67+
# Environment-variable assignments prefixed onto the remote training command.
# MODEL_NAME embeds a timestamp ($(date ...) is expanded by the remote shell)
# so each launch produces a uniquely named run.
env = [
    f"PROJECT={args.project}",
    "MODEL_NAME=qwen35-megatron-ynm-$(date +%Y%m%d-%H%M%S)",
    f"BASE_MODEL={args.base_model}",
    f"GPU_MEMORY_UTILIZATION={args.gpu_memory_utilization}",
    f"MAX_MODEL_LEN={args.max_model_len}",
    f"MAX_SEQ_LENGTH={args.max_seq_length}",
    f"MAX_NUM_SEQS={args.max_num_seqs}",
    f"LOAD_IN_4BIT={_format_env_bool(args.load_in_4bit)}",
    f"LOAD_IN_16BIT={_format_env_bool(args.load_in_16bit)}",
    f"NUM_STEPS={args.num_steps}",
    f"ROLLOUTS_PER_PROMPT={args.rollouts_per_prompt}",
    f"EVAL_PROMPTS={args.eval_prompts}",
    f"MAX_TOKENS={args.max_tokens}",
    f"LEARNING_RATE={args.learning_rate}",
    f"TRAINER_GPU_IDS={_format_int_list(args.trainer_gpu_ids)}",
    f"INFERENCE_GPU_IDS={_format_int_list(args.inference_gpu_ids)}",
    "ROLLOUT_WEIGHTS_MODE=merged",
]
# Render as one shell prefix, one `VAR=value \` continuation per line.
env_block = " \\\n ".join(env)
87+
88+
# Remote run command: install project deps via the Megatron setup script,
# then execute the trainer with the env assignments prefixed onto the
# command line. NOTE(review): the literal is intentionally left unindented —
# env_block is interpolated before dedent() runs, so indenting the template
# would change what dedent strips.
run_script = textwrap.dedent(
    f"""\
source $HOME/.local/bin/env
cd ~/sky_workdir
bash src/art/megatron/setup.sh
{env_block} \\
~/.local/bin/uv run dev/yes-no-maybe-megatron.py
"""
)
97+
98+
# SkyPilot task definition: run on Kubernetes with the requested accelerator
# and container image; the current directory is synced as the workdir.
task = sky.Task(
    name="qwen3.5-megatron-yes-no-maybe",
    setup=setup_script,
    run=run_script,
    workdir=".",
)
resources = sky.Resources(
    accelerators=args.accelerator,
    cloud=sky.clouds.Kubernetes(),
    image_id=args.image_id,
)
task.set_resources(resources)
# Mount the local .env so the remote run sees the same secrets/config.
task.set_file_mounts({"~/sky_workdir/.env": ".env"})
112+
113+
# Echo the effective configuration before launching.
print(f"Launching on cluster: {cluster_name}")
for label, value in [
    ("base_model", args.base_model),
    ("project", args.project),
    ("accelerator", args.accelerator),
    ("image_id", args.image_id),
    ("gpu_memory_utilization", args.gpu_memory_utilization),
    ("max_model_len", args.max_model_len),
    ("max_seq_length", args.max_seq_length),
    ("max_num_seqs", args.max_num_seqs),
    ("num_steps", args.num_steps),
    ("rollouts_per_prompt", args.rollouts_per_prompt),
    ("eval_prompts", args.eval_prompts),
    ("max_tokens", args.max_tokens),
    ("learning_rate", args.learning_rate),
    ("load_in_4bit", args.load_in_4bit),
    ("load_in_16bit", args.load_in_16bit),
    ("trainer_gpu_ids", args.trainer_gpu_ids),
    ("inference_gpu_ids", args.inference_gpu_ids),
]:
    print(f" {label}: {value}")
131+
132+
# If the cluster already exists and is UP, cancel any active jobs so the new
# launch does not queue behind them.
status_records = sky.stream_and_get(sky.status(cluster_names=[cluster_name]))
if status_records and status_records[0]["status"] == ClusterStatus.UP:
    print(f"Cluster {cluster_name} is UP. Canceling any active jobs...")
    sky.stream_and_get(sky.cancel(cluster_name, all=True))

# Launch (or reuse, with --fast) the cluster. Autostop+down tears the cluster
# down after 60 idle minutes so smoke runs don't leak GPUs.
job_id, _ = sky.stream_and_get(
    sky.launch(
        task,
        cluster_name=cluster_name,
        retry_until_up=True,
        idle_minutes_to_autostop=60,
        down=True,
        fast=args.fast,
    )
)

# Stream logs until the remote job exits and surface its exit code.
print(f"Job submitted (ID: {job_id}). Streaming logs...")
exit_code = sky.tail_logs(cluster_name=cluster_name, job_id=job_id, follow=True)
print(f"Job {job_id} finished with exit code {exit_code}.")

0 commit comments

Comments
 (0)