|
| 1 | +"""Launch a multi-step Qwen3.5 Megatron yes-no-maybe run on SkyPilot.""" |
| 2 | + |
| 3 | +import argparse |
| 4 | +import os |
| 5 | +import textwrap |
| 6 | + |
| 7 | +from dotenv import load_dotenv |
| 8 | +import sky |
| 9 | +from sky import ClusterStatus |
| 10 | + |
# Read KEY=VALUE pairs from a local .env file into the process environment
# (e.g. CLUSTER_PREFIX, consumed below).
load_dotenv()

# Default container image for the SkyPilot task: CUDA 12.8.1 devel on Ubuntu 22.04.
DEFAULT_IMAGE_ID = "docker:nvidia/cuda:12.8.1-devel-ubuntu22.04"
| 14 | + |
| 15 | + |
| 16 | +def _format_env_bool(value: bool) -> str: |
| 17 | + return "true" if value else "false" |
| 18 | + |
| 19 | + |
| 20 | +def _format_int_list(values: list[int]) -> str: |
| 21 | + return ",".join(str(value) for value in values) |
| 22 | + |
| 23 | + |
# Command-line interface.  Defaults target a single 2-GPU node: GPU 0 for the
# trainer, GPU 1 for inference (see --trainer-gpu-ids / --inference-gpu-ids).
parser = argparse.ArgumentParser(
    description="Launch a Qwen3.5 Megatron yes-no-maybe convergence run."
)
# Forwarded to sky.launch(fast=...) below.
parser.add_argument("--fast", action="store_true")
parser.add_argument("--base-model", type=str, default="Qwen/Qwen3.5-35B-A3B")
# SkyPilot accelerator spec, "<type>:<count>".
parser.add_argument("--accelerator", type=str, default="H200:2")
parser.add_argument(
    "--cluster-name", type=str, default="art-qwen35-megatron-yes-no-maybe"
)
parser.add_argument("--image-id", type=str, default=DEFAULT_IMAGE_ID)
parser.add_argument("--project", type=str, default="qwen35-megatron-ynm")
# Sizing/training knobs; each is forwarded to the remote training script as an
# environment variable (see the env list below).
parser.add_argument("--gpu-memory-utilization", type=float, default=0.65)
parser.add_argument("--max-model-len", type=int, default=1024)
parser.add_argument("--max-seq-length", type=int, default=1024)
parser.add_argument("--max-num-seqs", type=int, default=8)
parser.add_argument("--num-steps", type=int, default=10)
parser.add_argument("--rollouts-per-prompt", type=int, default=8)
parser.add_argument("--eval-prompts", type=int, default=24)
parser.add_argument("--max-tokens", type=int, default=5)
parser.add_argument("--learning-rate", type=float, default=5e-5)
# BooleanOptionalAction generates paired --load-in-4bit / --no-load-in-4bit flags.
parser.add_argument(
    "--load-in-4bit", action=argparse.BooleanOptionalAction, default=False
)
parser.add_argument(
    "--load-in-16bit", action=argparse.BooleanOptionalAction, default=True
)
parser.add_argument("--trainer-gpu-ids", type=int, nargs="+", default=[0])
parser.add_argument("--inference-gpu-ids", type=int, nargs="+", default=[1])
args = parser.parse_args()
| 53 | + |
# Optionally namespace the cluster under a per-user/per-env CLUSTER_PREFIX
# taken from the environment (loaded from .env above).
cluster_prefix = os.environ.get("CLUSTER_PREFIX")
cluster_name = (
    f"{cluster_prefix}-{args.cluster_name}" if cluster_prefix else args.cluster_name
)
| 58 | + |
# One-time provisioning commands: base Python tooling plus the `uv` installer.
_SETUP_COMMANDS = (
    "echo 'Setting up environment...'",
    "apt-get update",
    "apt-get install -y python3 python3-pip python-is-python3 git curl",
    "curl -LsSf https://astral.sh/uv/install.sh | sh",
    "source $HOME/.local/bin/env",
)
setup_script = "\n".join(_SETUP_COMMANDS) + "\n"
| 66 | + |
# Environment variables prefixed to the remote training command as KEY=VALUE
# pairs joined with shell line continuations.  Dict insertion order fixes the
# emission order.
_env_values = {
    "PROJECT": args.project,
    "MODEL_NAME": "qwen35-megatron-ynm-$(date +%Y%m%d-%H%M%S)",
    "BASE_MODEL": args.base_model,
    "GPU_MEMORY_UTILIZATION": args.gpu_memory_utilization,
    "MAX_MODEL_LEN": args.max_model_len,
    "MAX_SEQ_LENGTH": args.max_seq_length,
    "MAX_NUM_SEQS": args.max_num_seqs,
    "LOAD_IN_4BIT": _format_env_bool(args.load_in_4bit),
    "LOAD_IN_16BIT": _format_env_bool(args.load_in_16bit),
    "NUM_STEPS": args.num_steps,
    "ROLLOUTS_PER_PROMPT": args.rollouts_per_prompt,
    "EVAL_PROMPTS": args.eval_prompts,
    "MAX_TOKENS": args.max_tokens,
    "LEARNING_RATE": args.learning_rate,
    "TRAINER_GPU_IDS": _format_int_list(args.trainer_gpu_ids),
    "INFERENCE_GPU_IDS": _format_int_list(args.inference_gpu_ids),
    "ROLLOUT_WEIGHTS_MODE": "merged",
}
env = [f"{key}={value}" for key, value in _env_values.items()]
env_block = " \\\n ".join(env)
| 87 | + |
# Remote run command: enter uv's environment, run the Megatron setup script,
# then launch training with the KEY=VALUE env assignments prefixed.
# NOTE(review): textwrap.dedent computes the common leading whitespace over
# *all* lines, including the multi-line interpolated {env_block}; if those
# continuation lines are indented less than the literal's four spaces, the
# dedent amount shrinks accordingly — confirm the rendered script is intended.
run_script = textwrap.dedent(
    f"""\
    source $HOME/.local/bin/env
    cd ~/sky_workdir
    bash src/art/megatron/setup.sh
    {env_block} \\
    ~/.local/bin/uv run dev/yes-no-maybe-megatron.py
"""
)
| 97 | + |
# Assemble the SkyPilot task: setup + run scripts, the current directory as
# the synced workdir, and a Kubernetes-backed GPU resource spec.
resources = sky.Resources(
    cloud=sky.clouds.Kubernetes(),
    accelerators=args.accelerator,
    image_id=args.image_id,
)
task = sky.Task(
    name="qwen3.5-megatron-yes-no-maybe",
    setup=setup_script,
    run=run_script,
    workdir=".",
)
task.set_resources(resources)
# Ship the local .env so the remote job sees the same configuration/secrets.
task.set_file_mounts({"~/sky_workdir/.env": ".env"})
| 112 | + |
# Echo the effective launch configuration before submitting.
print(f"Launching on cluster: {cluster_name}")
_config_summary = [
    ("base_model", args.base_model),
    ("project", args.project),
    ("accelerator", args.accelerator),
    ("image_id", args.image_id),
    ("gpu_memory_utilization", args.gpu_memory_utilization),
    ("max_model_len", args.max_model_len),
    ("max_seq_length", args.max_seq_length),
    ("max_num_seqs", args.max_num_seqs),
    ("num_steps", args.num_steps),
    ("rollouts_per_prompt", args.rollouts_per_prompt),
    ("eval_prompts", args.eval_prompts),
    ("max_tokens", args.max_tokens),
    ("learning_rate", args.learning_rate),
    ("load_in_4bit", args.load_in_4bit),
    ("load_in_16bit", args.load_in_16bit),
    ("trainer_gpu_ids", args.trainer_gpu_ids),
    ("inference_gpu_ids", args.inference_gpu_ids),
]
for _label, _value in _config_summary:
    print(f"  {_label}: {_value}")
| 131 | + |
# If the target cluster is already UP, cancel any jobs still running on it so
# the new submission does not queue behind them.
cluster_status = sky.stream_and_get(sky.status(cluster_names=[cluster_name]))
if cluster_status and cluster_status[0]["status"] == ClusterStatus.UP:
    print(f"Cluster {cluster_name} is UP. Canceling any active jobs...")
    sky.stream_and_get(sky.cancel(cluster_name, all=True))

# Provision (or, with --fast, reuse) the cluster and submit the task.  The
# cluster autostops after 60 idle minutes; with down=True SkyPilot tears it
# down rather than merely stopping it (per the sky.launch API).
job_id, _ = sky.stream_and_get(
    sky.launch(
        task,
        cluster_name=cluster_name,
        retry_until_up=True,
        idle_minutes_to_autostop=60,
        down=True,
        fast=args.fast,
    )
)

# Follow the job's logs to completion and report its exit code.
print(f"Job submitted (ID: {job_id}). Streaming logs...")
exit_code = sky.tail_logs(cluster_name=cluster_name, job_id=job_id, follow=True)
print(f"Job {job_id} finished with exit code {exit_code}.")
0 commit comments