Skip to content

Commit 9873dcb

Browse files
committed
fix: address PR review feedback for concurrency support
- Fix inline @endpoint() max_concurrency loss: persist into __remote_config__ so manifest builder can read it when the Endpoint facade is not reachable as a module-level variable. - Fix scanner is_async for class-based workers: detect async public methods so the correct handler template is selected. - Harden _inject_concurrency_modifier: raise ValueError if the expected start call string is absent instead of silently producing a handler without concurrency_modifier. - Add rationale comment for the >100 concurrency warning threshold. - Add tests for all new behaviors.
1 parent f606b94 commit 9873dcb

7 files changed

Lines changed: 136 additions & 4 deletions

File tree

src/runpod_flash/cli/commands/build_utils/handler_generator.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,8 @@ def _generate_deployed_handler_code(
378378
else "# No function to import"
379379
)
380380

381+
# 100 is a soft upper bound: most GPU workloads saturate VRAM well
382+
# below this. The warning nudges users to verify resource capacity.
381383
if max_concurrency > 100:
382384
logger.warning(
383385
"High max_concurrency=%d for resource '%s'. Ensure your handler "
@@ -447,8 +449,14 @@ def _generate_deployed_handler_code(
447449
@staticmethod
448450
def _inject_concurrency_modifier(code: str, max_concurrency: int) -> str:
449451
"""Replace the default runpod.serverless.start call with one including concurrency_modifier."""
452+
start_call = 'runpod.serverless.start({"handler": handler})'
453+
if start_call not in code:
454+
raise ValueError(
455+
"Unable to inject concurrency_modifier: expected "
456+
f"{start_call!r} in generated handler code."
457+
)
450458
return code.replace(
451-
'runpod.serverless.start({"handler": handler})',
459+
start_call,
452460
"runpod.serverless.start({\n"
453461
' "handler": handler,\n'
454462
f' "concurrency_modifier": lambda current: {max_concurrency},\n'

src/runpod_flash/cli/commands/build_utils/manifest.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def _extract_deployment_config(
166166

167167
try:
168168
resource_config = None
169+
remote_cfg = None
169170

170171
if config_variable and hasattr(module, config_variable):
171172
resource_config = getattr(module, config_variable)
@@ -181,12 +182,18 @@ def _extract_deployment_config(
181182
return config
182183

183184
# Extract max_concurrency from Endpoint facade before unwrapping.
184-
# max_concurrency is a Flash concept that lives on Endpoint,
185-
# not on the internal resource config (LiveServerless, etc.).
185+
# For module-level Endpoint variables, read from the instance.
186+
# For inline @Endpoint() decorators, the facade is gone by the
187+
# time we reach here -- read from __remote_config__ instead.
186188
if hasattr(resource_config, "_max_concurrency"):
187189
mc = resource_config._max_concurrency
188190
if mc > 1:
189191
config["max_concurrency"] = mc
192+
elif (
193+
isinstance(remote_cfg, dict)
194+
and remote_cfg.get("max_concurrency", 1) > 1
195+
):
196+
config["max_concurrency"] = remote_cfg["max_concurrency"]
190197

191198
# unwrap Endpoint facade to the internal resource config
192199
if hasattr(resource_config, "_build_resource_config"):

src/runpod_flash/cli/commands/build_utils/scanner.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,15 @@ def _metadata_from_remote_config(
274274
is_async = inspect.iscoroutinefunction(original) or inspect.iscoroutinefunction(
275275
obj
276276
)
277+
elif is_class and target_class is not None:
278+
# A class is async if any of its public methods are coroutines.
279+
# This drives handler template selection for concurrent endpoints.
280+
is_async = any(
281+
inspect.iscoroutinefunction(m)
282+
for name, m in vars(target_class).items()
283+
if not name.startswith("_")
284+
and (inspect.isfunction(m) or inspect.iscoroutinefunction(m))
285+
)
277286

278287
# function/class name: for classes, use the original class name.
279288
# for functions, use __name__ from the unwrapped function.

src/runpod_flash/endpoint.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,14 +611,24 @@ async def process(data: dict) -> dict: ...
611611

612612
from .client import remote as remote_decorator
613613

614-
return remote_decorator(
614+
decorated = remote_decorator(
615615
resource_config=resource_config,
616616
dependencies=self.dependencies,
617617
system_dependencies=self.system_dependencies,
618618
accelerate_downloads=self.accelerate_downloads,
619619
_internal=True,
620620
)(func_or_class)
621621

622+
# Persist max_concurrency into __remote_config__ so the manifest
623+
# builder can read it for inline @Endpoint() decorators where the
624+
# Endpoint facade is not reachable via a module-level variable.
625+
if self._max_concurrency > 1:
626+
remote_cfg = getattr(decorated, "__remote_config__", None)
627+
if isinstance(remote_cfg, dict):
628+
remote_cfg["max_concurrency"] = self._max_concurrency
629+
630+
return decorated
631+
622632
# -- route decorators (lb mode) --
623633

624634
def _route(self, method: str, path: str):

tests/unit/cli/commands/build_utils/test_handler_generator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,3 +814,9 @@ def test_high_concurrency_logs_warning(caplog):
814814
generator = HandlerGenerator(manifest, build_dir)
815815
generator.generate_handlers()
816816
assert any("max_concurrency=150" in r.message for r in caplog.records)
817+
818+
819+
def test_inject_concurrency_modifier_raises_on_missing_start_call():
820+
"""_inject_concurrency_modifier raises if the start call string is absent."""
821+
with pytest.raises(ValueError, match="Unable to inject concurrency_modifier"):
822+
HandlerGenerator._inject_concurrency_modifier("some random code", 5)

tests/unit/cli/commands/build_utils/test_runtime_scanner.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -739,3 +739,71 @@ def test_discovers_in_subdirectory(self, tmp_path):
739739
functions = scanner.discover_remote_functions()
740740
assert len(functions) == 1
741741
assert functions[0].module_path == "workers.gpu_worker"
742+
743+
744+
class TestClassAsyncDetection:
745+
"""scanner detects is_async for class-based workers."""
746+
747+
def test_sync_class_not_async(self, tmp_path):
748+
_write_worker(
749+
tmp_path,
750+
"sync_cls.py",
751+
"""\
752+
from runpod_flash import remote, LiveServerless
753+
cfg = LiveServerless(name="sync-cls")
754+
755+
@remote(cfg)
756+
class SyncWorker:
757+
def predict(self, x):
758+
return x
759+
""",
760+
)
761+
scanner = RuntimeScanner(tmp_path)
762+
functions = scanner.discover_remote_functions()
763+
assert len(functions) == 1
764+
assert functions[0].is_class is True
765+
assert functions[0].is_async is False
766+
767+
def test_async_class_detected(self, tmp_path):
768+
_write_worker(
769+
tmp_path,
770+
"async_cls.py",
771+
"""\
772+
from runpod_flash import remote, LiveServerless
773+
cfg = LiveServerless(name="async-cls")
774+
775+
@remote(cfg)
776+
class AsyncWorker:
777+
async def predict(self, x):
778+
return x
779+
""",
780+
)
781+
scanner = RuntimeScanner(tmp_path)
782+
functions = scanner.discover_remote_functions()
783+
assert len(functions) == 1
784+
assert functions[0].is_class is True
785+
assert functions[0].is_async is True
786+
787+
def test_mixed_class_detected_as_async(self, tmp_path):
788+
"""Class with both sync and async methods is detected as async."""
789+
_write_worker(
790+
tmp_path,
791+
"mixed_cls.py",
792+
"""\
793+
from runpod_flash import remote, LiveServerless
794+
cfg = LiveServerless(name="mixed-cls")
795+
796+
@remote(cfg)
797+
class MixedWorker:
798+
def setup(self):
799+
pass
800+
801+
async def predict(self, x):
802+
return x
803+
""",
804+
)
805+
scanner = RuntimeScanner(tmp_path)
806+
functions = scanner.discover_remote_functions()
807+
assert len(functions) == 1
808+
assert functions[0].is_class is True
809+
assert functions[0].is_async is True

tests/unit/test_endpoint.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1036,3 +1036,27 @@ def test_rejects_zero(self):
10361036
def test_rejects_negative(self):
10371037
with pytest.raises(ValueError, match="max_concurrency must be >= 1"):
10381038
Endpoint(name="test", max_concurrency=-1)
1039+
1040+
def test_inline_decorator_persists_max_concurrency(self):
1041+
"""Inline @Endpoint() persists max_concurrency into __remote_config__."""
1042+
ep = Endpoint(name="test", gpu=GpuGroup.ANY, max_concurrency=5)
1043+
1044+
async def generate(prompt: str) -> str:
1045+
return prompt
1046+
1047+
decorated = ep(generate)
1048+
remote_cfg = getattr(decorated, "__remote_config__", None)
1049+
assert remote_cfg is not None
1050+
assert remote_cfg["max_concurrency"] == 5
1051+
1052+
def test_inline_decorator_omits_default_max_concurrency(self):
1053+
"""Inline @Endpoint() with default max_concurrency=1 does not add the key."""
1054+
ep = Endpoint(name="test", gpu=GpuGroup.ANY)
1055+
1056+
async def generate(prompt: str) -> str:
1057+
return prompt
1058+
1059+
decorated = ep(generate)
1060+
remote_cfg = getattr(decorated, "__remote_config__", None)
1061+
assert remote_cfg is not None
1062+
assert "max_concurrency" not in remote_cfg

0 commit comments

Comments
 (0)