Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 25 additions & 0 deletions codelabs/gke/rl-sandbox-intro/Dockerfile.gpu_worker
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# ==============================================================================
# Base Image: Use the official vLLM production image.
# ==============================================================================
FROM vllm/vllm-openai:latest

USER root

# Install system dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
numactl \
libnuma-dev \
wget \
ca-certificates \
&& apt-get clean \
&& rm -rf /var/lib/apt/lists/*

# Install Ray, TRL, and Sandbox tools
# TRL does not require compiling flash_attn from source.
RUN pip install --no-cache-dir \
"ray[default]==2.55.1" \
"numpy<2.0" \
gymnasium>=0.28.1 \
k8s-agent-sandbox>=0.4.6 \
trl transformers packaging ninja cachetools accelerate datasets peft
49 changes: 49 additions & 0 deletions codelabs/gke/rl-sandbox-intro/Dockerfile.sandbox
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Use a stable Debian-based Miniconda image
FROM condaforge/miniforge3:latest

# 1. Install essential system libraries (including sqlite3 for Django tests)
RUN apt-get update && apt-get install -y \
git \
build-essential \
libsqlite3-dev \
&& rm -rf /var/lib/apt/lists/*
Comment on lines +5 to +9

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To keep the Docker image size as small as possible, it is recommended to use the --no-install-recommends flag with apt-get install.

RUN apt-get update && apt-get install -y --no-install-recommends \
    git \
    build-essential \
    libsqlite3-dev \
    && rm -rf /var/lib/apt/lists/*


# 2. Set up the /workspace directory and grant ownership to the pre-existing non-root 'ubuntu' user (UID 1000)
RUN mkdir -p /workspace \
&& chown -R 1000:1000 /workspace

# 3. Switch to the non-root user
USER ubuntu
WORKDIR /workspace

# 4. Pre-configure Git globally so the agent can run git commands
RUN git config --global user.email "agent@gke-sandbox.local" \
&& git config --global user.name "Agent"

# 5. Pre-clone the repository as the non-root user
RUN git clone https://github.com/django/django.git .

# 6. Pre-build Conda environments and pre-cache common dependencies
# We do NOT run "pip install -e ." here to avoid Python version conflicts with the main branch.
# Instead, we pre-install the heavy dependencies so that runtime installation is instantaneous.
RUN conda create -y -n django-py39 python=3.9 \
&& conda run -n django-py39 pip install --no-cache-dir asgiref sqlparse tzdata pytest pytest-django

RUN conda create -y -n django-py310 python=3.10 \
&& conda run -n django-py310 pip install --no-cache-dir asgiref sqlparse tzdata pytest pytest-django

# --- Add Agent Server ---
# We use a multi-stage build to copy the agent server from the official python-runtime-sandbox image
COPY --from=registry.k8s.io/agent-sandbox/python-runtime-sandbox:v0.1.0 /app /opt/sandbox-agent
USER root
RUN chown -R 1000:1000 /opt/sandbox-agent \
&& /opt/conda/bin/pip install --no-cache-dir -r /opt/sandbox-agent/requirements.txt \
&& sed -i 's|"/app"|"/workspace"|g' /opt/sandbox-agent/main.py
USER ubuntu
# ------------------------

# Prepend the django-py39 conda environment bin to PATH for commands executed inside the container
ENV PATH=/home/ubuntu/.conda/envs/django-py39/bin:$PATH

# Keep the container alive and run the agent server using the system Python
CMD ["/opt/conda/bin/python3", "-m", "uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8888", "--log-level", "trace", "--app-dir", "/opt/sandbox-agent"]
12 changes: 12 additions & 0 deletions codelabs/gke/rl-sandbox-intro/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# High-Performance Distributed RL Sandbox

This directory contains the code samples and configuration files for the Google Kubernetes Engine (GKE) codelab: **High-Performance Distributed RL Sandbox**.

## Purpose

These files provide a hands-on environment for setting up and running distributed Reinforcement Learning (RL) training workloads on GKE. The codelab demonstrates how to build a scalable and secure sandbox environment using Ray and GKE features like sandbox routers and warm pools.

## Codelab

To follow the complete step-by-step guide and learn how to use these files, please visit the full codelab:
[High-Performance Distributed RL Sandbox](https://codelabs.developers.google.com/codelabs/gke/high-performance-distributed-rl-sandbox)
17 changes: 17 additions & 0 deletions codelabs/gke/rl-sandbox-intro/network_policy.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
apiVersion: networking.k8s.io/v1
kind: NetworkPolicy
metadata:
name: block-metadata-egress
namespace: default
spec:
podSelector:
matchLabels:
sandbox.gke.io/runtime: gvisor
policyTypes:
- Egress
egress:
- to:
- ipBlock:
cidr: 0.0.0.0/0
except:
- 169.254.169.254/32
57 changes: 57 additions & 0 deletions codelabs/gke/rl-sandbox-intro/raycluster.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
apiVersion: ray.io/v1
kind: RayCluster
metadata:
name: grpo-cluster
namespace: default
spec:
rayVersion: "2.35.0"
headGroupSpec:
rayStartParams:
dashboard-host: "0.0.0.0"
template:
spec:
tolerations:
- key: "nvidia.com/gpu"
operator: "Exists"
effect: "NoSchedule"
containers:
- name: ray-head
image: us-west3-docker.pkg.dev/dx-supercomputer-testing/rl-sandbox-repo/ray-gpu-worker:v1
ports:
- containerPort: 6379
name: gcs-server
- containerPort: 8265
name: dashboard
- containerPort: 10001
name: client
resources:
limits:
cpu: "4"
memory: "16Gi"
requests:
cpu: "4"
memory: "16Gi"
workerGroupSpecs:
- groupName: gpu-group
replicas: 1
minReplicas: 1
maxReplicas: 1
rayStartParams: {}
template:
spec:
tolerations:
- key: "nvidia.com/gpu"
operator: "Exists"
effect: "NoSchedule"
containers:
- name: ray-worker
image: us-west3-docker.pkg.dev/dx-supercomputer-testing/rl-sandbox-repo/ray-gpu-worker:v1
resources:
limits:
cpu: "12"
memory: "120Gi"
nvidia.com/gpu: "1"
requests:
cpu: "12"
memory: "120Gi"
nvidia.com/gpu: "1"
65 changes: 65 additions & 0 deletions codelabs/gke/rl-sandbox-intro/sandbox_router.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
apiVersion: rbac.authorization.k8s.io/v1
kind: Role
metadata:
namespace: default
name: sandbox-claim-manager
rules:
- apiGroups: ["extensions.agents.x-k8s.io"]
resources: ["sandboxclaims"]
verbs: ["get", "list", "watch", "create", "update", "patch", "delete"]
- apiGroups: ["agents.x-k8s.io"]
resources: ["sandboxes"]
verbs: ["get", "list", "watch", "create", "update", "patch", "delete"]
---
apiVersion: rbac.authorization.k8s.io/v1
kind: RoleBinding
metadata:
name: sandbox-claim-manager-binding
namespace: default
subjects:
- kind: ServiceAccount
name: default
namespace: default
roleRef:
kind: Role
name: sandbox-claim-manager
apiGroup: rbac.authorization.k8s.io
---
apiVersion: v1
kind: Service
metadata:
name: sandbox-router
namespace: default
spec:
type: ClusterIP
selector:
app: sandbox-router
ports:
- name: http
protocol: TCP
port: 8080
targetPort: 8080
---
apiVersion: apps/v1
kind: Deployment
metadata:
name: sandbox-router-deployment
namespace: default
spec:
replicas: 2
selector:
matchLabels:
app: sandbox-router
template:
metadata:
labels:
app: sandbox-router
spec:
containers:
- name: router
image: us-central1-docker.pkg.dev/k8s-staging-images/agent-sandbox/sandbox-router:latest-main
env:
- name: ALLOW_UNAUTHENTICATED_ROUTER
value: "true"
ports:
- containerPort: 8080
25 changes: 25 additions & 0 deletions codelabs/gke/rl-sandbox-intro/sandbox_warmpool.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
apiVersion: extensions.agents.x-k8s.io/v1alpha1
kind: SandboxTemplate
metadata:
name: swe-bench-django
namespace: default
spec:
podTemplate:
spec:
containers:
- name: sandbox
image: us-west3-docker.pkg.dev/dx-supercomputer-testing/rl-sandbox-repo/django-sandbox:v1
resources:
requests:
cpu: "2"
memory: "4Gi"
---
apiVersion: extensions.agents.x-k8s.io/v1alpha1
kind: SandboxWarmPool
metadata:
name: swe-bench-django-warmpool
namespace: default
spec:
replicas: 10
sandboxTemplateRef:
name: swe-bench-django
142 changes: 142 additions & 0 deletions codelabs/gke/rl-sandbox-intro/train_trl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
import ray
from k8s_agent_sandbox import SandboxClient
from k8s_agent_sandbox.models import SandboxDirectConnectionConfig
from trl import GRPOConfig, GRPOTrainer
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import urllib.request
import re

ray.init(ignore_reinit_error=True)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Calling ray.init() at the module level is a Ray anti-pattern. When Ray workers import this module to execute tasks, they will run ray.init() again. Although ignore_reinit_error=True suppresses the error, it can still cause unexpected behavior or warnings. It is best practice to initialize Ray inside the main() function or under the if __name__ == "__main__": block.


# 1. Define the Ray remote evaluation function
@ray.remote(num_cpus=0.1)
def evaluate_rollout(code, prompt_data):
client = SandboxClient(connection_config=SandboxDirectConnectionConfig(api_url="http://sandbox-router.default.svc.cluster.local:8080"))

# Claim a pre-warmed sandbox instantly based on the repo
repo = prompt_data.get("repo")

# In a full system, you'd route to different warmpools based on repo
# Here we default to django for our single task
sandbox = client.create_sandbox(
template="swe-bench-django",
warmpool="swe-bench-django-warmpool",
sandbox_ready_timeout=600
)

try:
# Check if the code is correctly formatted
bash_match = re.search(r"```bash\n(.*?)\n```", code, re.DOTALL)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

LLMs frequently output sh instead of bash, or include trailing whitespace after the language identifier. The current regex is strict and will fail to match these variations, resulting in a 0.0 reward for otherwise valid completions. Using a more permissive regex like r"```(?:bash|sh)\\s*\\n(.*?)\\n```" is much more robust.

Suggested change
bash_match = re.search(r"```bash\n(.*?)\n```", code, re.DOTALL)
bash_match = re.search(r"```(?:bash|sh)\\s*\\n(.*?)\\n```", code, re.DOTALL)

if not bash_match:
return 0.0

script = bash_match.group(1)

# In a real environment, we would apply the base commit and install here
# For simplicity, we just execute the script
import shlex
script_cmd = f"bash -c {shlex.quote(script)}"
result = sandbox.commands.run(script_cmd, timeout=60)

# Calculate continuous reward based on test passage ratio
if result.exit_code == 0:
return 1.0

# Very simple heuristic reward
return 0.1

finally:
# Clean up and release the sandbox back to the pool
client.delete_sandbox(sandbox.claim_name)

# 2. Define the Reward Function for TRL
def sandbox_reward_func(prompts, completions, **kwargs):
# Dispatch evaluation to Ray cluster
futures = [
evaluate_rollout.remote(completion, {
"repo": kwargs.get('repo', [])[i] if 'repo' in kwargs else None,
"base_commit": kwargs.get('base_commit', [])[i] if 'base_commit' in kwargs else None
}) for i, completion in enumerate(completions)
]

# Block and wait for all sandbox evaluations to complete
rewards = ray.get(futures)
return rewards

# 3. Setup GRPO Trainer
@ray.remote(num_gpus=1, num_cpus=8)
def train():
# Load dataset
dataset = load_dataset("princeton-nlp/SWE-bench_Lite", split="test")
# Filter to our selected target issue
dataset = dataset.filter(lambda x: x["instance_id"] == "django__django-15388")

def format_dataset(example):
files = re.findall(r'^\+\+\+ b/(.+)$', example["patch"], re.MULTILINE)
target_file = files[0] if files else ""

file_content = ""
if target_file:
try:
github_repo = example["repo"]
url = f"https://raw.githubusercontent.com/{github_repo}/{example['base_commit']}/{target_file}"
with urllib.request.urlopen(url) as response:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The urllib.request.urlopen call does not specify a timeout. If the GitHub raw server is slow or unresponsive, this call can block indefinitely, hanging the dataset mapping process. It is recommended to set a reasonable timeout.

Suggested change
with urllib.request.urlopen(url) as response:
with urllib.request.urlopen(url, timeout=10) as response:

file_content = response.read().decode('utf-8')
except Exception as e:
pass

prompt = f"""You are an expert software engineer.
You are given a GitHub issue and the content of the file that contains the bug.
Write an executable bash script that will modify the target file to fix the bug (e.g. using cat << 'EOF' > {target_file} or inline python edits).
Wrap your bash script in ```bash ... ``` tags. Do not output raw python code directly.

Target File: {target_file}

Original File Content:
```python
{file_content}
```

Issue:
{example['problem_statement']}
"""
return {
"prompt": prompt,
"repo": example["repo"],
"instance_id": example["instance_id"],
"base_commit": example["base_commit"],
}

dataset = dataset.map(format_dataset)

model_name = "Qwen/Qwen2.5-Coder-1.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name)

training_args = GRPOConfig(
output_dir="outputs",
learning_rate=5e-6,
max_steps=10,
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
num_generations=8,
generation_batch_size=8,
)
Comment on lines +116 to +124

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Since the GPU worker base image is vllm/vllm-openai, you can significantly accelerate the generation phase of GRPO by enabling vLLM integration in GRPOConfig using use_vllm=True.

Suggested change
training_args = GRPOConfig(
output_dir="outputs",
learning_rate=5e-6,
max_steps=10,
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
num_generations=8,
generation_batch_size=8,
)
training_args = GRPOConfig(
output_dir="outputs",
learning_rate=5e-6,
max_steps=10,
per_device_train_batch_size=1,
gradient_accumulation_steps=4,
num_generations=8,
generation_batch_size=8,
use_vllm=True,
)


trainer = GRPOTrainer(
model=model_name,
processing_class=tokenizer,
reward_funcs=[sandbox_reward_func],
args=training_args,
train_dataset=dataset,
)

print("Starting GRPO training with GKE Agent Sandboxes...")
trainer.train()

def main():
print("Submitting training job to GPU worker...")
ray.get(train.remote())
Comment on lines +137 to +139

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Initialize Ray inside the main() function to ensure it only runs on the driver process and not on the Ray workers when they import this module.

Suggested change
def main():
print("Submitting training job to GPU worker...")
ray.get(train.remote())
def main():
ray.init(ignore_reinit_error=True)
print("Submitting training job to GPU worker...")
ray.get(train.remote())


if __name__ == "__main__":
main()
Loading