Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
205 changes: 194 additions & 11 deletions src/runpod_flash/cli/commands/build_utils/handler_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,49 @@ def handler(job):
runpod.serverless.start({{"handler": handler}})
'''

DEPLOYED_ASYNC_HANDLER_TEMPLATE = '''"""
Auto-generated deployed handler for resource: {resource_name}
Generated at: {timestamp}

Concurrent async endpoint handler: accepts plain JSON, no cloudpickle.
max_concurrency={max_concurrency}

This file is generated by the Flash build process. Do not edit manually.
"""

import importlib
import logging
import traceback

_logger = logging.getLogger(__name__)

# Import the function for this endpoint
{import_statement}


async def handler(job):
"""Async handler for concurrent QB endpoint. Accepts plain JSON kwargs."""
job_input = job.get("input", {{}})
try:
result = await {function_name}(**job_input)
return result
except Exception as e:
_logger.error(
"Deployed handler error for {function_name}: %s",
e,
exc_info=True,
)
return {{"error": str(e), "traceback": traceback.format_exc()}}


if __name__ == "__main__":
import runpod
runpod.serverless.start({{
"handler": handler,
"concurrency_modifier": lambda current: {max_concurrency},
}})
'''

DEPLOYED_CLASS_HANDLER_TEMPLATE = '''"""
Auto-generated deployed handler for class-based resource: {resource_name}
Generated at: {timestamp}
Expand Down Expand Up @@ -155,6 +198,84 @@ def handler(job):
'''


DEPLOYED_ASYNC_CLASS_HANDLER_TEMPLATE = '''"""
Auto-generated deployed handler for class-based resource: {resource_name}
Generated at: {timestamp}

Concurrent async class endpoint handler: accepts plain JSON, no cloudpickle.
max_concurrency={max_concurrency}

This file is generated by the Flash build process. Do not edit manually.
"""

import importlib
import logging
import traceback

_logger = logging.getLogger(__name__)

# import the class for this endpoint
{import_statement}

# instantiate once at cold start
_instance = {class_name}()

# public methods available for dispatch
_METHODS = {methods_dict}


async def handler(job):
"""Async handler for concurrent class-based QB endpoint.

Dispatches to a method on the singleton class instance.
If the class has exactly one public method, input is passed
directly as kwargs. If multiple methods exist, input must
include a "method" key to select the target.
"""
job_input = job.get("input", {{}})
try:
if len(_METHODS) == 1:
method_name = next(iter(_METHODS))
else:
method_name = job_input.pop("method", None)
if method_name is None:
return {{
"error": (
"class {class_name} has multiple methods: "
+ ", ".join(sorted(_METHODS))
+ ". include a \\"method\\" key in the input."
)
}}
if method_name not in _METHODS:
return {{
"error": (
f"unknown method '{{method_name}}' on {class_name}. "
f"available: {{', '.join(sorted(_METHODS))}}"
)
}}

method = getattr(_instance, method_name)
result = await method(**job_input)
return result
except Exception as e:
_logger.error(
"Deployed handler error for {class_name}.%s: %s",
job_input.get("method", "<default>"),
e,
exc_info=True,
)
return {{"error": str(e), "traceback": traceback.format_exc()}}


if __name__ == "__main__":
import runpod
runpod.serverless.start({{
"handler": handler,
"concurrency_modifier": lambda current: {max_concurrency},
}})
'''


class HandlerGenerator:
"""Generates handler_<name>.py files for each resource config."""

Expand Down Expand Up @@ -190,12 +311,7 @@ def generate_handlers(self) -> List[Path]:
return handler_paths

def _generate_handler(self, resource_name: str, resource_data: Any) -> Path:
"""Generate a handler file for a deployed QB endpoint.

Produces a plain-JSON handler that accepts raw dicts (no cloudpickle).
For classes, instantiates once at cold start and dispatches to methods.
For functions, calls the function directly with job input as kwargs.
"""
"""Generate a handler file for a deployed QB endpoint."""
handler_filename = f"handler_{resource_name}.py"
handler_path = self.build_dir / handler_filename

Expand All @@ -211,8 +327,14 @@ def _generate_handler(self, resource_name: str, resource_data: Any) -> Path:
else resource_data.get("functions", [])
)

max_concurrency = (
resource_data.max_concurrency
if hasattr(resource_data, "max_concurrency")
else resource_data.get("max_concurrency", 1)
)

handler_code = self._generate_deployed_handler_code(
resource_name, timestamp, functions
resource_name, timestamp, functions, max_concurrency
)

handler_path.write_text(handler_code)
Expand All @@ -225,11 +347,14 @@ def _generate_deployed_handler_code(
resource_name: str,
timestamp: str,
functions: List[Any],
max_concurrency: int = 1,
) -> str:
"""Generate deployed handler code for a QB endpoint.

Selects the class template for class-based workers and the
function template for plain functions.
Selects template based on max_concurrency and is_async:
- max_concurrency=1: current sync templates (unchanged behavior)
- max_concurrency>1 + async: new async templates with concurrency_modifier
- max_concurrency>1 + sync: sync templates with concurrency_modifier injected
"""
if not functions:
raise ValueError(
Expand All @@ -243,34 +368,92 @@ def _generate_deployed_handler_code(
is_class = (
func.is_class if hasattr(func, "is_class") else func.get("is_class", False)
)
is_async = (
func.is_async if hasattr(func, "is_async") else func.get("is_async", False)
)

import_statement = (
f"{name} = importlib.import_module('{module}').{name}"
if module and name
else "# No function to import"
)

if max_concurrency > 100:
logger.warning(
"High max_concurrency=%d for resource '%s'. Ensure your handler "
"and GPU can support this level of concurrent execution.",
max_concurrency,
resource_name,
)

if max_concurrency > 1 and not is_async:
logger.warning(
"max_concurrency=%d set on sync handler '%s'. "
"Only async handlers benefit from concurrent execution. "
"Consider making the handler async.",
max_concurrency,
resource_name,
)

if is_class:
class_methods = (
func.class_methods
if hasattr(func, "class_methods")
else func.get("class_methods", [])
)
methods_dict = {m: m for m in class_methods} if class_methods else {}
return DEPLOYED_CLASS_HANDLER_TEMPLATE.format(

if max_concurrency > 1 and is_async:
return DEPLOYED_ASYNC_CLASS_HANDLER_TEMPLATE.format(
resource_name=resource_name,
timestamp=timestamp,
import_statement=import_statement,
class_name=name or "None",
methods_dict=repr(methods_dict),
max_concurrency=max_concurrency,
)
Comment on lines 368 to +416
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is_async is used to decide whether to emit the new async class handler template, but the build scanner currently never marks class-based resources as async (it only sets is_async for non-class functions). This makes the max_concurrency > 1 and is_async async-class branch effectively unreachable in real builds, and will also trigger the “sync handler” warning even when the class methods are actually async def.

Suggested fix: either (a) update the scanner/manifest to set is_async=True for class resources when any public method is a coroutine, or (b) make handler generation detect/handle async class methods without relying on the manifest-level is_async flag (e.g., generate an async handler that awaits coroutine results conditionally).

Copilot uses AI. Check for mistakes.

code = DEPLOYED_CLASS_HANDLER_TEMPLATE.format(
resource_name=resource_name,
timestamp=timestamp,
import_statement=import_statement,
class_name=name or "None",
methods_dict=repr(methods_dict),
)
if max_concurrency > 1:
code = self._inject_concurrency_modifier(code, max_concurrency)
return code

return DEPLOYED_HANDLER_TEMPLATE.format(
# Function-based handler
if max_concurrency > 1 and is_async:
return DEPLOYED_ASYNC_HANDLER_TEMPLATE.format(
resource_name=resource_name,
timestamp=timestamp,
import_statement=import_statement,
function_name=name or "None",
max_concurrency=max_concurrency,
)

code = DEPLOYED_HANDLER_TEMPLATE.format(
resource_name=resource_name,
timestamp=timestamp,
import_statement=import_statement,
function_name=name or "None",
)
if max_concurrency > 1:
code = self._inject_concurrency_modifier(code, max_concurrency)
return code

@staticmethod
def _inject_concurrency_modifier(code: str, max_concurrency: int) -> str:
    """Inject a ``concurrency_modifier`` into generated sync handler code.

    Rewrites the default ``runpod.serverless.start({"handler": handler})``
    call emitted by the sync templates so it also passes a
    ``concurrency_modifier`` that returns ``max_concurrency``.

    Args:
        code: Rendered handler source containing exactly one default
            ``runpod.serverless.start`` call.
        max_concurrency: Concurrent jobs per worker to advertise.

    Returns:
        The handler source with the start call rewritten.

    Raises:
        ValueError: If the expected start call does not occur exactly once.
            Without this check a template refactor would make the string
            replacement silently no-op and drop the concurrency setting.
    """
    start_call = 'runpod.serverless.start({"handler": handler})'
    replacement = (
        "runpod.serverless.start({\n"
        '        "handler": handler,\n'
        f'        "concurrency_modifier": lambda current: {max_concurrency},\n'
        "    })"
    )
    # Fail loudly instead of silently emitting a handler with no
    # concurrency_modifier when the template's start call changes shape.
    if code.count(start_call) != 1:
        raise ValueError(
            "Unable to inject concurrency_modifier: expected exactly one "
            f"occurrence of {start_call!r} in generated handler code."
        )
    return code.replace(start_call, replacement, 1)

Comment on lines +458 to 465
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_inject_concurrency_modifier relies on a hard-coded string replacement of runpod.serverless.start({"handler": handler}). If the template formatting/spacing changes (or if the start call is refactored), this will silently fail and omit concurrency_modifier without any error.

Suggested fix: generate the runpod.serverless.start(...) block via template placeholders (or structured formatting) instead of str.replace, or at minimum assert that the replacement occurred (e.g., check the expected substring was present and raise if not).

Suggested change
return code.replace(
'runpod.serverless.start({"handler": handler})',
"runpod.serverless.start({\n"
' "handler": handler,\n'
f' "concurrency_modifier": lambda current: {max_concurrency},\n'
" })",
)
start_call = 'runpod.serverless.start({"handler": handler})'
replacement = (
"runpod.serverless.start({\n"
' "handler": handler,\n'
f' "concurrency_modifier": lambda current: {max_concurrency},\n'
" })"
)
if code.count(start_call) != 1:
raise ValueError(
"Unable to inject concurrency_modifier: expected exactly one "
f"occurrence of {start_call!r} in generated handler code."
)
return code.replace(start_call, replacement, 1)

Copilot uses AI. Check for mistakes.
def _validate_handler_imports(self, handler_path: Path) -> None:
"""Validate that generated handler has valid Python syntax.
Expand Down
21 changes: 21 additions & 0 deletions src/runpod_flash/cli/commands/build_utils/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,14 @@ def _extract_deployment_config(
if resource_config is None:
return config

# Extract max_concurrency from Endpoint facade before unwrapping.
# max_concurrency is a Flash concept that lives on Endpoint,
# not on the internal resource config (LiveServerless, etc.).
if hasattr(resource_config, "_max_concurrency"):
mc = resource_config._max_concurrency
if mc > 1:
config["max_concurrency"] = mc

Comment on lines +184 to +197
Copy link

Copilot AI Apr 8, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

max_concurrency extraction currently only works when resource_config is an Endpoint instance (e.g., referenced via a module-level variable and found via config_variable). For the common inline decorator form (@Endpoint(..., max_concurrency=...)), Endpoint.__call__() passes the internal resource config into @remote, so __remote_config__["resource_config"] is not an Endpoint and the ephemeral Endpoint instance is not reachable during manifest build. As a result, max_concurrency will be silently dropped from the manifest for inline-decorated QB endpoints.

Suggested fix: persist max_concurrency into the decorated object’s __remote_config__ (e.g., a top-level key like max_concurrency) during decoration, and have _extract_deployment_config() read it from remote_cfg when present, rather than relying on accessing a private field on the Endpoint facade.

Copilot uses AI. Check for mistakes.
# unwrap Endpoint facade to the internal resource config
if hasattr(resource_config, "_build_resource_config"):
resource_config = resource_config._build_resource_config()
Expand Down Expand Up @@ -439,6 +447,19 @@ def build(self) -> Dict[str, Any]:
**deployment_config, # Include imageName, templateId, gpuIds, workers config
}

# max_concurrency is QB-only; warn and remove for LB endpoints
if (
is_load_balanced
and resources_dict[resource_name].get("max_concurrency", 1) > 1
):
logger.warning(
"max_concurrency=%d on LB endpoint '%s' is ignored. "
"LB endpoints handle concurrency via uvicorn.",
resources_dict[resource_name]["max_concurrency"],
resource_name,
)
resources_dict[resource_name].pop("max_concurrency", None)

if not is_load_balanced:
resources_dict[resource_name]["handler_file"] = (
f"handler_{resource_name}.py"
Expand Down
5 changes: 5 additions & 0 deletions src/runpod_flash/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ def __init__(
scaler_value: int = 4,
template: Optional[PodTemplate] = None,
min_cuda_version: Optional[CudaVersion | str] = CudaVersion.V12_8,
max_concurrency: int = 1,
):
if gpu is not None and cpu is not None:
raise ValueError(
Expand All @@ -378,6 +379,10 @@ def __init__(
if name is None and id is None and image is not None:
raise ValueError("name or id is required when image= is set.")

if max_concurrency < 1:
raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}")
self._max_concurrency = max_concurrency

# name can be None here for QB decorator mode (@Endpoint(gpu=...)).
# it gets derived from the decorated function/class in __call__().
self.name = name
Expand Down
2 changes: 2 additions & 0 deletions src/runpod_flash/runtime/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class ResourceConfig:
False # LB endpoint (LoadBalancerSlsResource or LiveLoadBalancer)
)
is_live_resource: bool = False # LiveLoadBalancer/LiveServerless (local dev only)
max_concurrency: int = 1 # Concurrent jobs per worker (QB only)

@classmethod
def from_dict(cls, data: Dict[str, Any]) -> "ResourceConfig":
Expand All @@ -43,6 +44,7 @@ def from_dict(cls, data: Dict[str, Any]) -> "ResourceConfig":
makes_remote_calls=data.get("makes_remote_calls", True),
is_load_balanced=data.get("is_load_balanced", False),
is_live_resource=data.get("is_live_resource", False),
max_concurrency=data.get("max_concurrency", 1),
)


Expand Down
Loading
Loading