Commit 433b1da

fix: poll for jobs while tasks are running
1 parent 342a1d2 · commit 433b1da

2 files changed: 291 additions & 9 deletions

runpod/serverless/modules/rp_scale.py (14 additions, 9 deletions)

@@ -7,7 +7,7 @@
 import signal
 import sys
 import traceback
-from typing import Any, Dict
+from typing import Any, Dict, Set
 
 from ...http_client import AsyncClientSession, ClientSession, TooManyRequests
 from .rp_job import get_job, handle_job
@@ -219,30 +219,35 @@ async def run_jobs(self, session: ClientSession):
 
         Runs the block in an infinite loop while the worker is alive or jobs queue is not empty.
         """
-        tasks = []  # Store the tasks for concurrent job processing
+        tasks: Set[asyncio.Task] = set()
 
+        last_task_count = 0
         while self.is_alive() or not self.jobs_queue.empty():
             # Fetch as many jobs as the concurrency allows
             while len(tasks) < self.current_concurrency and not self.jobs_queue.empty():
                 job = await self.jobs_queue.get()
-
                 # Create a new task for each job and add it to the task list
                 task = asyncio.create_task(self.handle_job(session, job))
-                tasks.append(task)
+                tasks.add(task)
 
             # Wait for any job to finish
             if tasks:
-                log.info(f"Jobs in progress: {len(tasks)}")
+                current_task_count = len(tasks)
+                if current_task_count != last_task_count:
+                    log.info(f"Jobs in progress: {current_task_count}")
+                    last_task_count = current_task_count
 
                 done, pending = await asyncio.wait(
-                    tasks, return_when=asyncio.FIRST_COMPLETED
+                    tasks, return_when=asyncio.FIRST_COMPLETED, timeout=0.1
                 )
 
                 # Remove completed tasks from the list
-                tasks = [t for t in tasks if t not in done]
+                tasks.difference_update(done)
+
+            else:
+                # don't busy wait
+                await asyncio.sleep(0.1)
 
-            # Yield control back to the event loop
-            await asyncio.sleep(0)
 
         # Ensure all remaining tasks finish before stopping
         await asyncio.gather(*tasks)
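
The core of the fix is the bounded wait: asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED, timeout=0.1) now returns after at most 0.1 s even when nothing has finished, so the outer loop gets regular chances to pull newly queued jobs up to the concurrency limit instead of blocking until a running task completes, and the idle branch sleeps 0.1 s rather than spinning with sleep(0). Below is a minimal, self-contained sketch of that pattern; the queue, the fake_job coroutine, and the CONCURRENCY value are illustrative stand-ins, not the JobScaler API.

    import asyncio
    from typing import Set

    CONCURRENCY = 2  # illustrative cap, standing in for current_concurrency


    async def fake_job(job_id: str) -> None:
        # Stand-in for handle_job: just takes some time to "run" the job.
        await asyncio.sleep(0.3)
        print(f"finished {job_id}")


    async def run_jobs(queue: asyncio.Queue) -> None:
        tasks: Set[asyncio.Task] = set()
        for _ in range(10):  # bounded loop instead of "while alive", for the sketch
            # Fill up to the concurrency limit from whatever is queued right now.
            while len(tasks) < CONCURRENCY and not queue.empty():
                tasks.add(asyncio.create_task(fake_job(queue.get_nowait())))
            if tasks:
                # Wake up after 0.1 s even if nothing finished, so jobs queued in
                # the meantime are picked up while earlier ones are still running.
                done, _pending = await asyncio.wait(
                    tasks, return_when=asyncio.FIRST_COMPLETED, timeout=0.1
                )
                tasks.difference_update(done)
            else:
                await asyncio.sleep(0.1)  # idle: avoid busy-waiting
        await asyncio.gather(*tasks)


    async def main() -> None:
        queue: asyncio.Queue = asyncio.Queue()
        await queue.put("job-0")
        runner = asyncio.create_task(run_jobs(queue))

        await asyncio.sleep(0.15)
        # Arrives while job-0 is still running; the timed wait lets the loop
        # start it immediately instead of waiting for job-0 to finish first.
        await queue.put("job-1")

        await runner


    asyncio.run(main())

With the previous unbounded wait, job-1 in this sketch would not start until job-0 completed, which is the behavior the commit title describes fixing.
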
New file (277 additions, 0 deletions)

import asyncio
from unittest.mock import AsyncMock
from dataclasses import dataclass

import pytest

from runpod.serverless.modules import rp_scale


class DummyProgress:
    def __init__(self):
        self.count = 0

    def get_job_count(self):
        return self.count

    def add(self, _):
        self.count += 1

    def remove(self, _):
        self.count = max(0, self.count - 1)


@dataclass
class PatchScaler:
    scaler: rp_scale.JobScaler
    progress: DummyProgress


def generate_job(id: str):
    return {"id": id, "input": {"test": "data"}}


@pytest.fixture
def job_scaler(monkeypatch) -> PatchScaler:
    def dummy_jobs_fetcher(input_job_id: str):
        return {"id": input_job_id, "input": {"test": "data"}}

    async def dummy_jobs_handler(_session, _config, _job):
        await asyncio.sleep(0.05)
        return None

    dummy_progress = DummyProgress()
    monkeypatch.setattr(rp_scale, "JobsProgress", lambda: dummy_progress)

    job_scaler_config = {
        "handler": lambda *_: None,
        "jobs_fetcher": dummy_jobs_fetcher,
    }
    scaler = rp_scale.JobScaler(job_scaler_config)
    scaler.jobs_handler = dummy_jobs_handler
    patch_scaler = PatchScaler(scaler=scaler, progress=dummy_progress)
    return patch_scaler


@pytest.mark.asyncio
async def test_workers_take_single_job_off_queue(job_scaler: PatchScaler):
    scaler = job_scaler.scaler
    scaler.current_concurrency = 2
    _ = asyncio.create_task(scaler.run_jobs(None))

    await scaler.jobs_queue.put(generate_job("test-1"))

    assert scaler.jobs_queue.qsize() == 1
    await asyncio.sleep(0)
    assert scaler.jobs_queue.qsize() == 0

    scaler.kill_worker()


@pytest.mark.asyncio
async def test_workers_fully_drain_queue(job_scaler: PatchScaler):
    scaler = job_scaler.scaler
    scaler.current_concurrency = 2
    _ = asyncio.create_task(scaler.run_jobs(None))

    scaler.jobs_queue = asyncio.Queue(maxsize=2)
    for i in range(2):
        await scaler.jobs_queue.put(generate_job(f"test-{i}"))

    assert scaler.jobs_queue.qsize() == 2
    await asyncio.sleep(0)
    assert scaler.jobs_queue.qsize() == 0
    scaler.kill_worker()


@pytest.mark.asyncio
async def test_workers_only_take_n_jobs(job_scaler: PatchScaler):
    scaler = job_scaler.scaler
    scaler.current_concurrency = 2
    _ = asyncio.create_task(scaler.run_jobs(None))

    scaler.jobs_queue = asyncio.Queue(maxsize=3)
    for i in range(3):
        await scaler.jobs_queue.put(generate_job(f"test-{i}"))

    assert scaler.jobs_queue.qsize() == 3
    await asyncio.sleep(0)
    assert scaler.jobs_queue.qsize() == 1

    scaler.kill_worker()


@pytest.mark.asyncio
async def test_worker_take_concurrent_jobs_dynamically(job_scaler: PatchScaler):
    scaler = job_scaler.scaler
    scaler.current_concurrency = 3
    scaler.jobs_queue = asyncio.Queue(maxsize=3)
    _ = asyncio.create_task(scaler.run_jobs(None))

    for i in range(2):
        await scaler.jobs_queue.put(generate_job(f"test-{i}"))

    assert scaler.jobs_queue.qsize() == 2
    await asyncio.sleep(0)
    assert scaler.jobs_queue.qsize() == 0

    await scaler.jobs_queue.put(generate_job(f"test-{2}"))
    assert scaler.jobs_queue.qsize() == 1
    await asyncio.sleep(0.2)
    # workers should take additional job to fill concurrency space
    assert scaler.jobs_queue.qsize() == 0

    scaler.kill_worker()


@pytest.mark.asyncio
async def test_handle_job_completes_and_clears_state(job_scaler: PatchScaler):
    scaler = job_scaler.scaler
    finished = []

    async def handler(session, config, job):
        finished.append(job["id"])

    scaler.jobs_handler = handler
    job = generate_job("handle-success")
    await scaler.jobs_queue.put(job)
    job = await scaler.jobs_queue.get()
    job_scaler.progress.add(job)

    await scaler.handle_job(AsyncMock(), job)

    assert finished == ["handle-success"]
    assert scaler.jobs_queue.qsize() == 0
    assert job_scaler.progress.count == 0

    scaler.kill_worker()


@pytest.mark.asyncio
async def test_shutdown_waits_for_inflight_job(job_scaler: PatchScaler):
    scaler = job_scaler.scaler
    job_started = asyncio.Event()
    finish_job = asyncio.Event()

    async def handler(session, config, job):
        job_started.set()
        await finish_job.wait()

    scaler.jobs_handler = handler
    scaler.current_concurrency = 1
    scaler.jobs_queue = asyncio.Queue(maxsize=1)
    run_task = asyncio.create_task(scaler.run_jobs(None))

    job = {"id": "inflight"}
    await scaler.jobs_queue.put(job)

    await asyncio.wait_for(job_started.wait(), timeout=2)

    scaler.kill_worker()
    await asyncio.sleep(0)

    assert not run_task.done()

    finish_job.set()
    await asyncio.wait_for(run_task, timeout=2)

    assert job_scaler.progress.count == 0
    assert scaler.jobs_queue.qsize() == 0

    scaler.kill_worker()


@pytest.mark.asyncio
async def test_shutdown_drains_jobs_in_queue(job_scaler: PatchScaler):
    scaler = job_scaler.scaler
    finished = []
    block = asyncio.Event()

    async def handler(session, config, job):
        await block.wait()
        finished.append(job["id"])

    scaler.jobs_handler = handler
    scaler.current_concurrency = 2
    scaler.jobs_queue = asyncio.Queue(maxsize=2)

    session = AsyncMock()

    jobs = [{"id": f"job-{idx}"} for idx in range(2)]
    for job in jobs:
        await scaler.jobs_queue.put(job)

    run_task = asyncio.create_task(scaler.run_jobs(session))
    scaler.kill_worker()

    await asyncio.sleep(0)
    assert not run_task.done()

    block.set()
    await asyncio.wait_for(run_task, timeout=2)

    assert sorted(finished) == [job["id"] for job in jobs]
    assert scaler.jobs_queue.qsize() == 0

    scaler.kill_worker()


@pytest.mark.asyncio
async def test_workers_process_jobs(job_scaler: PatchScaler):
    scaler = job_scaler.scaler
    handled = []

    async def handler(_session, _config, job):
        handled.append(job["id"])

    scaler.jobs_handler = handler
    scaler.current_concurrency = 2
    await scaler.set_scale()
    for i in range(2):
        await scaler.jobs_queue.put(generate_job(f"job-{i}"))

    asyncio.create_task(scaler.run_jobs(None))

    await asyncio.sleep(0.1)  # let workers run once

    assert handled == ["job-0", "job-1"]
    assert scaler.jobs_queue.qsize() == 0
    assert job_scaler.progress.count == 0

    scaler.kill_worker()


@pytest.mark.asyncio
async def test_get_jobs_feeds_workers_end_to_end(job_scaler: PatchScaler):
    scaler = job_scaler.scaler
    handled = []
    job_processed = asyncio.Event()

    async def handler(_session, _config, job):
        handled.append(job["id"])
        job_processed.set()

    fetch_count = {"value": 0}

    async def fetcher(_session, jobs_needed):
        if fetch_count["value"]:
            return []
        fetch_count["value"] += 1
        return [generate_job(f"job-{idx}") for idx in range(jobs_needed)]

    scaler.jobs_handler = handler
    scaler.jobs_fetcher = fetcher
    scaler.current_concurrency = 1

    session = AsyncMock()
    get_task = asyncio.create_task(scaler.get_jobs(session))

    run_jobs_task = asyncio.create_task(scaler.run_jobs(None))
    await asyncio.wait_for(job_processed.wait(), timeout=5)

    scaler.kill_worker()
    await asyncio.wait_for(get_task, timeout=5)
    await asyncio.wait_for(run_jobs_task, timeout=5)

    assert handled == ["job-0"]
    assert scaler.jobs_queue.qsize() == 0
    assert job_scaler.progress.count == 0

    scaler.kill_worker()
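
Several of these tests rely on a bare await asyncio.sleep(0) to yield the event loop once so the already-scheduled run_jobs task can drain the queue before the assertion runs. A minimal sketch of that yielding pattern, independent of JobScaler (the queue and consumer below are illustrative, not from this commit):

    import asyncio


    async def consumer(queue: asyncio.Queue) -> None:
        # Drain whatever is queued right now, mirroring how run_jobs picks up jobs.
        while not queue.empty():
            queue.get_nowait()


    async def main() -> None:
        queue: asyncio.Queue = asyncio.Queue()
        await queue.put({"id": "job-0"})
        task = asyncio.create_task(consumer(queue))

        # sleep(0) yields control once, letting the scheduled consumer run.
        await asyncio.sleep(0)

        assert queue.empty()
        await task


    asyncio.run(main())

A single sleep(0) is enough here because the consumer drains the queue synchronously once it gets scheduled; tests that wait for real work to finish (such as the 0.05 s dummy handler above) use longer sleeps or events instead.
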
