diff --git a/docs/Deployment_Architecture.md b/docs/Deployment_Architecture.md index a3c6140e..65147395 100644 --- a/docs/Deployment_Architecture.md +++ b/docs/Deployment_Architecture.md @@ -191,6 +191,12 @@ When `flash deploy` provisions endpoints: 3. The State Manager stores `{environment_id, resource_name} -> endpoint_id` 4. At runtime, the `ServiceRegistry` uses the manifest + State Manager to route calls +### Manifest credential handling + +- Runtime endpoint metadata (including API-returned `aiKey`) may be stored in the State Manager manifest for deployment reconciliation. +- Local `.flash/flash_manifest.json` is sanitized before it is written to disk and does not include `aiKey`. +- `RUNPOD_API_KEY` is sourced from environment/credential storage and injected into endpoint env when needed; it is not persisted in the local manifest. + See [Cross-Endpoint Routing](Cross_Endpoint_Routing.md) for the full runtime flow. ## Related Documentation diff --git a/src/runpod_flash/cli/docs/flash-deploy.md b/src/runpod_flash/cli/docs/flash-deploy.md index 15a1567c..2666a748 100644 --- a/src/runpod_flash/cli/docs/flash-deploy.md +++ b/src/runpod_flash/cli/docs/flash-deploy.md @@ -138,9 +138,17 @@ The deploy command combines building and deploying your Flash application in a s - Registers endpoints in environment tracking 4. **Post-Deployment**: - - Displays deployment URLs and available routes - - Shows authentication and testing guidance - - Cleans up temporary build directory + - Displays deployment URLs and available routes + - Shows authentication and testing guidance + - Cleans up temporary build directory + +## Manifest and Credential Handling + +During deploy, Flash updates manifest metadata with runtime endpoint details (for example `endpoint_id`, endpoint URLs, and `aiKey` when returned by the API). + +- The manifest stored in State Manager keeps runtime metadata used for reconciliation. 
+- The local `.flash/flash_manifest.json` is sanitized before writing to disk and does not persist `aiKey`. +- `RUNPOD_API_KEY` continues to be resolved from credentials/env at runtime and is not stored in the local manifest. ## Build Options diff --git a/src/runpod_flash/cli/docs/flash-logging.md b/src/runpod_flash/cli/docs/flash-logging.md index 417abcf6..a5565225 100644 --- a/src/runpod_flash/cli/docs/flash-logging.md +++ b/src/runpod_flash/cli/docs/flash-logging.md @@ -28,6 +28,12 @@ Logs are written in the same format as console output, so you can grep through t - **Graceful degradation**: Continues with stdout-only if file logging fails - **Zero configuration**: Works out of the box with sensible defaults +### QB request log polling during `Endpoint.run(...)` + +- For queue-based (QB) endpoints, Flash polls endpoint status/metrics while waiting and streams worker log lines to stdout when available. +- Polling is used for async `run(...)` flows (not `runsync(...)`), and is skipped for non-QB endpoint types. +- If endpoint `aiKey` is unavailable, Flash falls back to your configured `RUNPOD_API_KEY`; without a key, log streaming is skipped. 
+ ## Log Location By default, logs are written to: diff --git a/src/runpod_flash/cli/utils/deployment.py b/src/runpod_flash/cli/utils/deployment.py index 84aabc0f..62776b34 100644 --- a/src/runpod_flash/cli/utils/deployment.py +++ b/src/runpod_flash/cli/utils/deployment.py @@ -1,6 +1,7 @@ """Deployment environment management utilities.""" import asyncio +import copy import json import logging from typing import Dict, Any @@ -8,12 +9,46 @@ from pathlib import Path from runpod_flash.config import get_paths +from runpod_flash.core.resources.serverless import ServerlessResource from runpod_flash.core.resources.app import FlashApp from runpod_flash.core.resources.resource_manager import ResourceManager from runpod_flash.runtime.resource_provisioner import create_resource_from_manifest log = logging.getLogger(__name__) +RUNTIME_RESOURCE_FIELDS = set(ServerlessResource.RUNTIME_FIELDS) | { + "id", + "endpoint_id", +} + + +def _normalized_resource_attr(resource: Any, *names: str) -> str | None: + for name in names: + value = getattr(resource, name, None) + if isinstance(value, str) and value.strip(): + return value + return None + + +def _manifest_without_ai_keys(manifest: Dict[str, Any]) -> Dict[str, Any]: + sanitized_manifest = copy.deepcopy(manifest) + resources = sanitized_manifest.get("resources") + if not isinstance(resources, dict): + return sanitized_manifest + + for config in resources.values(): + if isinstance(config, dict): + config.pop("aiKey", None) + + return sanitized_manifest + + +def _resource_config_for_compare(config: Dict[str, Any]) -> Dict[str, Any]: + compare_config = copy.deepcopy(config) + for field in RUNTIME_RESOURCE_FIELDS: + compare_config.pop(field, None) + return compare_config + async def upload_build(app_name: str, build_path: str | Path): app = await FlashApp.from_name(app_name) @@ -147,6 +182,14 @@ async def provision_resources_for_build( resources_endpoints[resource_name] = endpoint_url + endpoint_id = 
_normalized_resource_attr(deployed_resource, "endpoint_id", "id") + if endpoint_id: + manifest["resources"][resource_name]["endpoint_id"] = endpoint_id + + ai_key = _normalized_resource_attr(deployed_resource, "aiKey", "ai_key") + if ai_key: + manifest["resources"][resource_name]["aiKey"] = ai_key + # Track load balancer URL for prominent logging if manifest["resources"][resource_name].get("is_load_balanced"): lb_endpoint_url = endpoint_url @@ -258,9 +301,15 @@ async def reconcile_and_provision_resources( local_config = local_manifest["resources"][resource_name] state_config = state_manifest.get("resources", {}).get(resource_name, {}) - # Simple hash comparison for config changes - local_json = json.dumps(local_config, sort_keys=True) - state_json = json.dumps(state_config, sort_keys=True) + # Compare only user-managed config fields (exclude runtime metadata) + local_json = json.dumps( + _resource_config_for_compare(local_config), + sort_keys=True, + ) + state_json = json.dumps( + _resource_config_for_compare(state_config), + sort_keys=True, + ) # Check if endpoint exists in state manifest has_endpoint = resource_name in state_manifest.get("resources_endpoints", {}) @@ -282,6 +331,10 @@ async def reconcile_and_provision_resources( local_manifest["resources"][resource_name]["endpoint_id"] = ( state_config["endpoint_id"] ) + if "aiKey" in state_config: + local_manifest["resources"][resource_name]["aiKey"] = state_config[ + "aiKey" + ] if resource_name in state_manifest.get("resources_endpoints", {}): local_manifest.setdefault("resources_endpoints", {})[resource_name] = ( state_manifest["resources_endpoints"][resource_name] @@ -315,13 +368,21 @@ async def reconcile_and_provision_resources( deployed_resource = provisioning_results[i] # Extract endpoint info - endpoint_id = getattr(deployed_resource, "endpoint_id", None) + endpoint_id = _normalized_resource_attr( + deployed_resource, "endpoint_id", "id" + ) endpoint_url = getattr(deployed_resource, "endpoint_url", None) 
- + if isinstance(endpoint_url, str): + endpoint_url = endpoint_url.strip() or None + else: + endpoint_url = None + ai_key = _normalized_resource_attr(deployed_resource, "aiKey", "ai_key") if endpoint_id: local_manifest["resources"][resource_name]["endpoint_id"] = endpoint_id if endpoint_url: local_manifest["resources_endpoints"][resource_name] = endpoint_url + if ai_key: + local_manifest["resources"][resource_name]["aiKey"] = ai_key log.debug( f"{'Provisioned' if action_type == 'provision' else 'Updated'}: " @@ -348,9 +409,11 @@ async def reconcile_and_provision_resources( f"Successfully provisioned: {provisioned}" ) + local_manifest_for_disk = _manifest_without_ai_keys(local_manifest) + # Write updated manifest back to local file manifest_path = Path.cwd() / ".flash" / "flash_manifest.json" - manifest_path.write_text(json.dumps(local_manifest, indent=2)) + manifest_path.write_text(json.dumps(local_manifest_for_disk, indent=2)) log.debug(f"Local manifest updated at {manifest_path.relative_to(Path.cwd())}") diff --git a/src/runpod_flash/core/api/runpod.py b/src/runpod_flash/core/api/runpod.py index 2870106c..04406554 100644 --- a/src/runpod_flash/core/api/runpod.py +++ b/src/runpod_flash/core/api/runpod.py @@ -396,6 +396,80 @@ async def get_gpu_types( result = await self._execute_graphql(query, variables) return result.get("gpuTypes", []) + async def get_gpu_lowest_price_stock_status( + self, + gpu_id: str, + gpu_count: int, + data_center_id: Optional[str] = None, + ) -> Optional[str]: + query = """ + query ServerlessGpuTypes($lowestPriceInput: GpuLowestPriceInput, $gpuTypesInput: GpuTypeFilter) { + gpuTypes(input: $gpuTypesInput) { + lowestPrice(input: $lowestPriceInput) { + stockStatus + } + } + } + """ + + variables = { + "gpuTypesInput": {"ids": [gpu_id]}, + "lowestPriceInput": { + "dataCenterId": data_center_id, + "gpuCount": gpu_count, + "secureCloud": True, + "includeAiApi": True, + "allowedCudaVersions": [], + "compliance": [], + }, + } + + result = await 
self._execute_graphql(query, variables) + gpu_types = result.get("gpuTypes") or [] + first = gpu_types[0] if gpu_types else {} + lowest = first.get("lowestPrice") if isinstance(first, dict) else {} + if not isinstance(lowest, dict): + return None + status = lowest.get("stockStatus") + if isinstance(status, str) and status.strip(): + return status.strip() + return None + + async def get_cpu_specific_stock_status( + self, + cpu_flavor_id: str, + instance_id: str, + data_center_id: str, + ) -> Optional[str]: + query = """ + query SecureCpuTypes($cpuFlavorInput: CpuFlavorInput, $specificsInput: SpecificsInput) { + cpuFlavors(input: $cpuFlavorInput) { + specifics(input: $specificsInput) { + stockStatus + } + } + } + """ + + variables = { + "cpuFlavorInput": {"id": cpu_flavor_id}, + "specificsInput": { + "dataCenterId": data_center_id, + "instanceId": instance_id, + }, + } + + result = await self._execute_graphql(query, variables) + cpu_flavors = result.get("cpuFlavors") or [] + first = cpu_flavors[0] if cpu_flavors else {} + specifics = first.get("specifics") if isinstance(first, dict) else {} + if not isinstance(specifics, dict): + return None + status = specifics.get("stockStatus") + if isinstance(status, str) and status.strip(): + return status.strip() + return None + async def get_endpoint(self, endpoint_id: str) -> Dict[str, Any]: """Get endpoint details.""" # Note: The schema doesn't show a specific endpoint query diff --git a/src/runpod_flash/core/resources/request_logs.py b/src/runpod_flash/core/resources/request_logs.py new file mode 100644 index 00000000..df7ce2ed --- /dev/null +++ b/src/runpod_flash/core/resources/request_logs.py @@ -0,0 +1,427 @@ +import logging +import os +import re +from dataclasses import dataclass +from dataclasses import field +from datetime import datetime, timezone +from enum import Enum +from typing import Any, List, Optional + +import httpx + +from runpod_flash.core.utils.http import get_authenticated_httpx_client + +log = 
logging.getLogger(__name__) + +API_BASE_URL = os.getenv("RUNPOD_API_BASE_URL", "https://api.runpod.ai").rstrip("/") +DEV_API_BASE_URL = "https://dev-api.runpod.ai" +HAPI_BASE_URL = "https://hapi.runpod.net" +DEV_HAPI_BASE_URL = "https://dev-hapi.runpod.net" +LOG_PREFIX_TIMESTAMP_RE = re.compile( + r"^(?P\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?Z)" +) + + +def _resolve_hapi_base_url() -> str: + runpod_env = os.getenv("RUNPOD_ENV", "").lower() + if runpod_env == "dev": + return DEV_HAPI_BASE_URL + + api_base = os.getenv("RUNPOD_API_BASE_URL", "") + if DEV_API_BASE_URL in api_base: + return DEV_HAPI_BASE_URL + + return HAPI_BASE_URL + + +class QBRequestLogPhase(str, Enum): + WAITING_FOR_WORKER = "WAITING_FOR_WORKER" + WAITING_FOR_WORKER_INITIALIZATION = "WAITING_FOR_WORKER_INITIALIZATION" + STREAMING = "STREAMING" + + +@dataclass +class QBRequestLogBatch: + lines: List[str] + matched_by_request_id: bool + worker_id: Optional[str] + phase: QBRequestLogPhase + worker_metrics: dict[str, int] = field(default_factory=dict) + ready_worker_ids: List[str] = field(default_factory=list) + + +class QBRequestLogFetcher: + def __init__( + self, + timeout_seconds: float = 4.0, + max_lines: int = 25, + lookback_seconds: int = 20, + start_time: Optional[datetime] = None, + ): + self.timeout_seconds = timeout_seconds + self.max_lines = max_lines + self.lookback_seconds = lookback_seconds + self.start_time = start_time or datetime.now(timezone.utc) + self.seen = set() + self.worker_id: Optional[str] = None + self.has_streamed_logs = False + self.has_primed_worker_logs = False + + async def fetch_logs( + self, + endpoint_id: str, + request_id: str, + status_api_key: str, + pod_logs_api_key: str, + status_api_key_fallback: Optional[str] = None, + ): + status_payload = await self._fetch_status_payload( + endpoint_id=endpoint_id, + request_id=request_id, + status_api_key=status_api_key, + status_api_key_fallback=status_api_key_fallback, + ) + assigned_worker_id = 
self._worker_id_from_status_payload(status_payload) + + metrics_payload = await self._fetch_metrics_payload( + endpoint_id=endpoint_id, + status_api_key=status_api_key, + status_api_key_fallback=status_api_key_fallback, + ) + running_worker_ids = self._ready_worker_ids_from_metrics(metrics_payload) + initializing_workers = self._initializing_worker_count(metrics_payload) + worker_metrics = self._worker_metrics_snapshot(metrics_payload) + ready_worker_ids = self._ready_worker_ids_from_metrics(metrics_payload) + + matched_by_request_id = False + if assigned_worker_id: + self._set_worker_id(assigned_worker_id) + matched_by_request_id = True + elif not self.worker_id and running_worker_ids: + self._set_worker_id(running_worker_ids[0]) + + if not self.worker_id: + phase = ( + QBRequestLogPhase.WAITING_FOR_WORKER_INITIALIZATION + if initializing_workers > 0 + else QBRequestLogPhase.WAITING_FOR_WORKER + ) + return QBRequestLogBatch( + lines=[], + matched_by_request_id=False, + worker_id=None, + phase=phase, + worker_metrics=worker_metrics, + ready_worker_ids=ready_worker_ids, + ) + + logs_payload = await self._fetch_pod_logs( + worker_id=self.worker_id, + runpod_api_key=pod_logs_api_key, + ) + if not logs_payload: + return QBRequestLogBatch( + lines=[], + matched_by_request_id=matched_by_request_id, + worker_id=self.worker_id, + phase=QBRequestLogPhase.WAITING_FOR_WORKER_INITIALIZATION, + worker_metrics=worker_metrics, + ready_worker_ids=ready_worker_ids, + ) + + if not self.has_primed_worker_logs: + lines = self._extract_initial_lines(logs_payload, request_id=request_id) + self.has_primed_worker_logs = True + if lines: + self.has_streamed_logs = True + return QBRequestLogBatch( + lines=lines[-self.max_lines :], + matched_by_request_id=matched_by_request_id, + worker_id=self.worker_id, + phase=( + QBRequestLogPhase.STREAMING + if self.has_streamed_logs + else QBRequestLogPhase.WAITING_FOR_WORKER_INITIALIZATION + ), + worker_metrics=worker_metrics, + 
ready_worker_ids=ready_worker_ids, + ) + + lines = self._extract_lines(logs_payload) + if lines: + self.has_streamed_logs = True + phase = ( + QBRequestLogPhase.STREAMING + if self.has_streamed_logs + else QBRequestLogPhase.WAITING_FOR_WORKER_INITIALIZATION + ) + return QBRequestLogBatch( + lines=lines[-self.max_lines :], + matched_by_request_id=matched_by_request_id, + worker_id=self.worker_id, + phase=phase, + worker_metrics=worker_metrics, + ready_worker_ids=ready_worker_ids, + ) + + async def _fetch_status_payload( + self, + endpoint_id: str, + request_id: str, + status_api_key: str, + status_api_key_fallback: Optional[str], + ) -> Optional[dict[str, Any]]: + url = f"{API_BASE_URL}/v2/{endpoint_id}/status/{request_id}" + auth_keys = self._auth_candidates(status_api_key, status_api_key_fallback) + + for auth_key in auth_keys: + try: + async with get_authenticated_httpx_client( + timeout=self.timeout_seconds, + api_key_override=auth_key, + ) as client: + response = await client.get(url) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as exc: + if exc.response is not None and exc.response.status_code == 401: + continue + log.debug( + "Failed to fetch worker for request %s: %s", + request_id, + exc, + ) + return None + except (httpx.HTTPError, ValueError) as exc: + log.debug("Failed to fetch worker for request %s: %s", request_id, exc) + return None + + return None + + async def _fetch_metrics_payload( + self, + endpoint_id: str, + status_api_key: str, + status_api_key_fallback: Optional[str], + ) -> Optional[dict[str, Any]]: + auth_keys = self._auth_candidates(status_api_key, status_api_key_fallback) + url = f"{API_BASE_URL}/v2/{endpoint_id}/metrics" + + for auth_key in auth_keys: + try: + async with get_authenticated_httpx_client( + timeout=self.timeout_seconds, + api_key_override=auth_key, + ) as client: + response = await client.get(url) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError 
as exc: + if exc.response is not None and exc.response.status_code == 401: + continue + log.debug( + "Failed to fetch endpoint metrics for %s via %s: %s", + endpoint_id, + url, + exc, + ) + return None + except (httpx.HTTPError, ValueError) as exc: + log.debug( + "Failed to fetch endpoint metrics for %s via %s: %s", + endpoint_id, + url, + exc, + ) + return None + + return None + + @staticmethod + def _worker_id_from_status_payload( + payload: Optional[dict[str, Any]], + ) -> Optional[str]: + if not payload: + return None + worker_id = payload.get("workerId") + if not worker_id: + return None + return str(worker_id) + + @staticmethod + def _ready_worker_ids_from_metrics(payload: Optional[dict[str, Any]]) -> List[str]: + if not payload: + return [] + ready_workers = payload.get("readyWorkers") + if not isinstance(ready_workers, list): + return [] + return [str(worker) for worker in ready_workers if worker] + + @staticmethod + def _worker_metrics_snapshot(payload: Optional[dict[str, Any]]) -> dict[str, int]: + base = { + "ready": 0, + "running": 0, + "idle": 0, + "initializing": 0, + "throttled": 0, + "unhealthy": 0, + } + if not payload: + return base + workers = payload.get("workers") + if not isinstance(workers, dict): + return base + for key in base: + value = workers.get(key) + if isinstance(value, int): + base[key] = value + return base + + @staticmethod + def _initializing_worker_count(payload: Optional[dict[str, Any]]) -> int: + if not payload: + return 0 + workers = payload.get("workers") + if not isinstance(workers, dict): + return 0 + initializing = workers.get("initializing", 0) + if isinstance(initializing, int): + return initializing + return 0 + + @staticmethod + def _auth_candidates( + primary_key: str, + fallback_key: Optional[str], + ) -> List[str]: + keys = [primary_key] + if fallback_key and fallback_key != primary_key: + keys.append(fallback_key) + return keys + + async def _fetch_pod_logs( + self, + worker_id: str, + runpod_api_key: str, + ) -> 
Optional[dict[str, Any]]: + url = f"{_resolve_hapi_base_url()}/v1/pod/{worker_id}/logs" + + try: + async with get_authenticated_httpx_client( + timeout=self.timeout_seconds, + api_key_override=runpod_api_key, + ) as client: + response = await client.get(url) + response.raise_for_status() + return response.json() + except httpx.HTTPStatusError as exc: + body_preview = "" + if exc.response is not None: + body_preview = (exc.response.text or "")[:500] + log.debug( + "Failed to fetch pod logs for %s: %s | response_body=%s", + worker_id, + exc, + body_preview, + ) + return None + except (httpx.HTTPError, ValueError) as exc: + log.debug("Failed to fetch pod logs for %s: %s", worker_id, exc) + return None + + def _extract_lines(self, payload: dict[str, Any]) -> List[str]: + records = self._collect_records(payload) + if not records: + return [] + + lines: List[str] = [] + + for record in records: + if not isinstance(record, str): + continue + + stripped = record.strip().replace("\\n", "") + if not stripped or stripped in self.seen: + continue + self.seen.add(stripped) + lines.append(stripped) + + return lines + + def _extract_initial_lines( + self, payload: dict[str, Any], request_id: str + ) -> List[str]: + records = self._collect_records(payload) + if not records: + return [] + + cutoff = self.start_time.timestamp() - self.lookback_seconds + lines: List[str] = [] + saw_recent_window_line = False + + for record in records: + if not isinstance(record, str): + continue + + stripped = record.strip().replace("\\n", "") + if not stripped: + continue + + if stripped in self.seen: + continue + self.seen.add(stripped) + + timestamp = self._parse_prefix_timestamp(stripped) + if timestamp is not None and timestamp.timestamp() < cutoff: + continue + + if timestamp is not None: + saw_recent_window_line = True + lines.append(stripped) + continue + + if request_id and request_id in stripped: + lines.append(stripped) + continue + + if saw_recent_window_line: + lines.append(stripped) + 
continue + + return lines + + def _set_worker_id(self, worker_id: str) -> None: + if self.worker_id == worker_id: + return + self.worker_id = worker_id + self.seen = set() + self.has_streamed_logs = False + self.has_primed_worker_logs = False + + @staticmethod + def _collect_records(payload: dict[str, Any]) -> List[Any]: + container_records = payload.get("container") + system_records = payload.get("system") + + records: list[Any] = [] + if isinstance(system_records, list): + records.extend(system_records) + if isinstance(container_records, list): + records.extend(container_records) + + return records + + @staticmethod + def _parse_prefix_timestamp(line: str) -> Optional[datetime]: + match = LOG_PREFIX_TIMESTAMP_RE.match(line) + if not match: + return None + + timestamp_text = match.group("timestamp") + normalized = timestamp_text.replace("Z", "+00:00") + + try: + return datetime.fromisoformat(normalized) + except ValueError: + return None diff --git a/src/runpod_flash/core/resources/serverless.py b/src/runpod_flash/core/resources/serverless.py index c095edeb..648742f8 100644 --- a/src/runpod_flash/core/resources/serverless.py +++ b/src/runpod_flash/core/resources/serverless.py @@ -2,6 +2,9 @@ import json import logging import os +import re +from collections import Counter +from datetime import datetime, timezone from enum import Enum from pathlib import Path from typing import Any, ClassVar, Dict, List, Optional, Set @@ -30,8 +33,11 @@ from .cpu import CpuInstanceType from .gpu import GpuGroup, GpuType from .network_volume import NetworkVolume, DataCenter, CPU_DATACENTERS +from .request_logs import QBRequestLogBatch, QBRequestLogFetcher, QBRequestLogPhase +from .worker_availability_diagnostic import WorkerAvailabilityDiagnostic from .template import KeyValuePair, PodTemplate from .resource_manager import ResourceManager +from ..credentials import get_api_key # Prefix applied to endpoint names during live provisioning @@ -42,6 +48,19 @@ log = 
logging.getLogger(__name__) +POD_LOG_PREFIX_RE = re.compile(r"^\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}(?:\.\d+)?Z\s+") + + +def _normalize_stream_log_line(line: str) -> str: + normalized = line.strip() + if not normalized: + return "" + + if normalized.lower().startswith("worker log:"): + normalized = normalized.split(":", 1)[1].strip() + + normalized = POD_LOG_PREFIX_RE.sub("", normalized, count=1) + return normalized class ServerlessScalerType(Enum): @@ -213,6 +232,39 @@ def endpoint_url(self) -> str: base_url = self.endpoint.rp_client.endpoint_url_base return f"{base_url}/{self.id}" + async def _emit_endpoint_logs( + self, + fetcher: QBRequestLogFetcher, + request_id: str, + ) -> Optional["QBRequestLogBatch"]: + if self.type != ServerlessType.QB: + return None + + if not self.id: + return None + + pod_logs_api_key = get_api_key() + if not pod_logs_api_key: + return None + + status_api_key = self.aiKey or pod_logs_api_key + + batch = await fetcher.fetch_logs( + endpoint_id=self.id, + request_id=request_id, + status_api_key=status_api_key, + pod_logs_api_key=pod_logs_api_key, + status_api_key_fallback=pod_logs_api_key, + ) + if not batch: + return None + + if batch.lines: + for line in batch.lines: + log.info("worker log: %s", line) + + return batch + @field_serializer("scalerType") def serialize_scaler_type( self, value: Optional[ServerlessScalerType] @@ -1241,8 +1293,6 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": job: Optional[Job] = None try: - # log.debug(f"[{self}] Payload: {payload}") - # Create a job using the endpoint log.info(f"{self} | API /run") job = await asyncio.to_thread(self.endpoint.run, request_input=payload) @@ -1255,10 +1305,25 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": attempt = 0 job_status = Status.UNKNOWN last_status = job_status + fetcher = QBRequestLogFetcher(start_time=datetime.now(timezone.utc)) + last_log_state: ( + tuple[ + QBRequestLogPhase, + bool, + Optional[str], + ] + | None + ) = None + 
assigned_streaming_announced_worker: Optional[str] = None + worker_availability_diagnostic = WorkerAvailabilityDiagnostic() + repeated_no_worker_message: Optional[str] = None + waiting_update_count = 0 + emitted_initial_wait_metrics = False # Poll for job status while True: await asyncio.sleep(current_pace) + emit_regular_update = False # Check job status job_status = await asyncio.to_thread(job.status) @@ -1266,21 +1331,156 @@ async def run(self, payload: Dict[str, Any]) -> "JobOutput": if last_status == job_status: # nothing changed, increase the gap attempt += 1 - indicator = "." * (attempt // 2) if attempt % 2 == 0 else "" - if indicator: - log.info(f"{log_subgroup} | {indicator}") + if job_status != "IN_PROGRESS" and attempt % 2 == 0: + emit_regular_update = True + log.info(f"{log_subgroup} | {'.' * (attempt // 2)}") else: # status changed, reset the gap log.info(f"{log_subgroup} | Status: {job_status}") attempt = 0 + batch = await self._emit_endpoint_logs( + fetcher=fetcher, + request_id=job.job_id, + ) + + if batch: + current_log_state = ( + batch.phase, + batch.matched_by_request_id, + batch.worker_id, + ) + state_changed = current_log_state != last_log_state + + if ( + batch.phase == QBRequestLogPhase.STREAMING + and batch.matched_by_request_id + and batch.worker_id + ): + repeated_no_worker_message = None + waiting_update_count = 0 + if assigned_streaming_announced_worker != batch.worker_id: + log.info( + f"{log_subgroup} | Request assigned to worker {batch.worker_id}, streaming pod logs" + ) + assigned_streaming_announced_worker = batch.worker_id + elif state_changed: + if batch.phase == QBRequestLogPhase.WAITING_FOR_WORKER: + diagnostic = await worker_availability_diagnostic.diagnose( + self, + worker_metrics=batch.worker_metrics, + ) + log.info(f"{log_subgroup} | {diagnostic.message}") + if diagnostic.reason in ( + "no_gpu_availability", + "workers_throttled", + ): + repeated_no_worker_message = diagnostic.message + else: + repeated_no_worker_message = 
None + waiting_update_count = 0 + if ( + job_status != "IN_PROGRESS" + and not emitted_initial_wait_metrics + ): + worker_state = ( + batch.worker_id if batch.worker_id else "None" + ) + worker_metrics = batch.worker_metrics or {} + assignment_state = ( + "assigned" + if batch.matched_by_request_id + else "unassigned" + ) + log.info( + f"{log_subgroup} | Waiting for request: endpoint metrics: worker={worker_state}, assignment={assignment_state}, status={job_status}, workers={{ready:{worker_metrics.get('ready', 0)}, running:{worker_metrics.get('running', 0)}, idle:{worker_metrics.get('idle', 0)}, initializing:{worker_metrics.get('initializing', 0)}, throttled:{worker_metrics.get('throttled', 0)}, unhealthy:{worker_metrics.get('unhealthy', 0)}}}, readyWorkers={batch.ready_worker_ids}" + ) + emitted_initial_wait_metrics = True + elif ( + batch.phase + == QBRequestLogPhase.WAITING_FOR_WORKER_INITIALIZATION + ): + repeated_no_worker_message = None + waiting_update_count = 0 + emitted_initial_wait_metrics = False + if batch.matched_by_request_id and batch.worker_id: + log.info( + f"{log_subgroup} | Request assigned to worker {batch.worker_id}, waiting for worker initialization/image pull logs" + ) + elif batch.worker_id: + log.info( + f"{log_subgroup} | Worker capacity detected, waiting for request assignment and worker initialization/image pull" + ) + else: + log.info( + f"{log_subgroup} | Waiting for worker initialization/image pull" + ) + elif batch.phase == QBRequestLogPhase.STREAMING: + repeated_no_worker_message = None + waiting_update_count = 0 + emitted_initial_wait_metrics = False + log.info( + f"{log_subgroup} | Streaming endpoint startup logs while waiting for request assignment" + ) + + last_log_state = current_log_state + + if emit_regular_update: + waiting_update_count += 1 + if waiting_update_count % 5 == 0: + worker_state = ( + batch.worker_id if batch and batch.worker_id else "None" + ) + worker_metrics = (batch.worker_metrics if batch else {}) or {} + 
assignment_state = ( + "assigned" + if batch and batch.matched_by_request_id + else "unassigned" + ) + log.info( + f"{log_subgroup} | Waiting for request: endpoint metrics: worker={worker_state}, assignment={assignment_state}, status={job_status}, workers={{ready:{worker_metrics.get('ready', 0)}, running:{worker_metrics.get('running', 0)}, idle:{worker_metrics.get('idle', 0)}, initializing:{worker_metrics.get('initializing', 0)}, throttled:{worker_metrics.get('throttled', 0)}, unhealthy:{worker_metrics.get('unhealthy', 0)}}}, readyWorkers={batch.ready_worker_ids if batch else []}" + ) + if repeated_no_worker_message: + log.info(f"{log_subgroup} | {repeated_no_worker_message}") + last_status = job_status # Adjust polling pace appropriately - current_pace = get_backoff_delay(attempt) + current_pace = get_backoff_delay(attempt, max_seconds=5) if job_status in ("COMPLETED", "FAILED", "CANCELLED"): + for _ in range(2): + await self._emit_endpoint_logs( + fetcher=fetcher, + request_id=job.job_id, + ) response = await asyncio.to_thread(job._fetch_job) + output = response.get("output") + if isinstance(output, dict): + stdout = output.get("stdout") + should_dedupe_stdout = ( + self.type == ServerlessType.QB + and fetcher.has_streamed_logs + and bool(fetcher.seen) + ) + if should_dedupe_stdout and isinstance(stdout, str): + seen_normalized_counts = Counter( + normalized + for line in fetcher.seen + if (normalized := _normalize_stream_log_line(line)) + ) + kept = [] + for raw_line in stdout.splitlines(keepends=True): + normalized_raw = _normalize_stream_log_line(raw_line) + if ( + normalized_raw + and seen_normalized_counts.get(normalized_raw, 0) + > 0 + ): + seen_normalized_counts[normalized_raw] -= 1 + continue + kept.append(raw_line) + output["stdout"] = "".join(kept) return JobOutput(**response) except Exception as e: diff --git a/src/runpod_flash/core/resources/worker_availability_diagnostic.py b/src/runpod_flash/core/resources/worker_availability_diagnostic.py new file 
mode 100644 index 00000000..825eab1a --- /dev/null +++ b/src/runpod_flash/core/resources/worker_availability_diagnostic.py @@ -0,0 +1,242 @@ +import logging +from dataclasses import dataclass +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING + +from ..api.runpod import RunpodGraphQLClient + +if TYPE_CHECKING: + from .serverless import ServerlessResource + + +log = logging.getLogger(__name__) +AVAILABLE_STOCK_STATUSES = {"LOW", "MEDIUM", "HIGH"} + + +@dataclass +class WorkerAvailabilityResult: + message: str + has_availability: Optional[bool] + reason: str + + +class WorkerAvailabilityDiagnostic: + async def diagnose( + self, + resource: "ServerlessResource", + worker_metrics: Optional[Dict[str, int]] = None, + ) -> WorkerAvailabilityResult: + if (resource.workersMax or 0) == 0: + return WorkerAvailabilityResult( + message="No compute available for your chosen configuration: your max workers are currently set to 0.", + has_availability=False, + reason="workers_max_zero", + ) + + compute_kind, compute_choice = self._selected_compute(resource) + if not compute_choice: + return WorkerAvailabilityResult( + message="No compute available for your chosen configuration.", + has_availability=None, + reason="no_compute_selected", + ) + + throttled_workers = (worker_metrics or {}).get("throttled", 0) + if throttled_workers > 0: + compute_label = "gpu type" if compute_kind == "gpu" else "cpu type" + return WorkerAvailabilityResult( + message=( + f"Workers are currently throttled on endpoint for selected {compute_kind} {compute_choice}. " + f"Consider raising max workers or changing {compute_label}." 
+ ), + has_availability=True, + reason="workers_throttled", + ) + + locations = self._selected_locations(resource) + + if compute_kind == "gpu": + availability_by_location = await self._gpu_availability( + gpu_id=compute_choice, + gpu_count=resource.gpuCount or 1, + locations=locations, + ) + return self._build_message( + compute_kind="gpu", + compute_choice=compute_choice, + locations=locations, + availability_by_location=availability_by_location, + include_available_signal=True, + ) + + if compute_kind == "cpu": + availability_by_location = await self._cpu_availability( + instance_id=compute_choice, + locations=locations, + ) + return self._build_message( + compute_kind="cpu", + compute_choice=compute_choice, + locations=locations, + availability_by_location=availability_by_location, + include_available_signal=False, + ) + + return WorkerAvailabilityResult( + message="No compute available for your chosen configuration.", + has_availability=None, + reason="unknown", + ) + + def _build_message( + self, + compute_kind: str, + compute_choice: str, + locations: List[str], + availability_by_location: Dict[str, Optional[str]], + include_available_signal: bool, + ) -> WorkerAvailabilityResult: + has_availability = any( + self._is_available_stock_status(status) + for status in availability_by_location.values() + ) + + if not has_availability: + selected_locations = ", ".join(locations) if locations else "all locations" + return WorkerAvailabilityResult( + message=( + f"No workers available on endpoint: no {compute_kind} availability for {compute_kind} type {compute_choice} " + f"in selected locations ({selected_locations})." + ), + has_availability=False, + reason=f"no_{compute_kind}_availability", + ) + + if include_available_signal: + signal = self._summarize_stock_signal(availability_by_location) + return WorkerAvailabilityResult( + message=( + f"No workers available right now. Current availability signal " + f"for selected gpu {compute_choice}: {signal}." 
+ ), + has_availability=True, + reason="gpu_has_availability", + ) + + return WorkerAvailabilityResult( + message=( + f"No workers available right now for selected {compute_kind} " + f"{compute_choice}." + ), + has_availability=True, + reason=f"{compute_kind}_has_availability", + ) + + async def _gpu_availability( + self, + gpu_id: str, + gpu_count: int, + locations: List[str], + ) -> Dict[str, Optional[str]]: + location_inputs = locations or [None] + availability_by_location: Dict[str, Optional[str]] = {} + + async with RunpodGraphQLClient() as client: + for location in location_inputs: + key = location or "global" + try: + status = await client.get_gpu_lowest_price_stock_status( + gpu_id=gpu_id, + gpu_count=gpu_count, + data_center_id=location, + ) + availability_by_location[key] = status + except Exception as exc: + log.debug("GPU availability query failed for %s: %s", key, exc) + availability_by_location[key] = None + + return availability_by_location + + async def _cpu_availability( + self, + instance_id: str, + locations: List[str], + ) -> Dict[str, Optional[str]]: + flavor_id = self._cpu_flavor_id(instance_id) + if not flavor_id: + return {loc: None for loc in (locations or ["global"])} + + location_inputs = locations or [""] + availability_by_location: Dict[str, Optional[str]] = {} + + async with RunpodGraphQLClient() as client: + for location in location_inputs: + key = location or "global" + try: + status = await client.get_cpu_specific_stock_status( + cpu_flavor_id=flavor_id, + instance_id=instance_id, + data_center_id=location, + ) + availability_by_location[key] = status + except Exception as exc: + log.debug("CPU availability query failed for %s: %s", key, exc) + availability_by_location[key] = None + + return availability_by_location + + @staticmethod + def _selected_compute(resource: "ServerlessResource") -> Tuple[str, Optional[str]]: + if resource.instanceIds: + first_instance = resource.instanceIds[0] + choice = ( + first_instance.value + if 
hasattr(first_instance, "value") + else str(first_instance) + ) + return "cpu", choice + + gpu_ids = [ + part.strip() for part in (resource.gpuIds or "").split(",") if part.strip() + ] + if gpu_ids: + return "gpu", gpu_ids[0] + + return "unknown", None + + @staticmethod + def _selected_locations(resource: "ServerlessResource") -> List[str]: + return [ + part.strip() + for part in (resource.locations or "").split(",") + if part.strip() + ] + + @staticmethod + def _cpu_flavor_id(instance_id: str) -> Optional[str]: + if "-" not in instance_id: + return None + return instance_id.split("-", 1)[0] + + @staticmethod + def _summarize_stock_signal( + availability_by_location: Dict[str, Optional[str]], + ) -> str: + non_empty = [status for status in availability_by_location.values() if status] + if not non_empty: + return "unknown" + + priority = {"HIGH": 3, "MEDIUM": 2, "LOW": 1} + + def score(value: str) -> int: + normalized = value.strip().upper().replace("-", "_").replace(" ", "_") + return priority.get(normalized, 0) + + best = max(non_empty, key=score) + return best + + @staticmethod + def _is_available_stock_status(status: Optional[str]) -> bool: + if not isinstance(status, str): + return False + normalized = status.strip().upper().replace("-", "_").replace(" ", "_") + return normalized in AVAILABLE_STOCK_STATUSES diff --git a/tests/unit/cli/utils/test_deployment.py b/tests/unit/cli/utils/test_deployment.py index 93e99c29..7e20656a 100644 --- a/tests/unit/cli/utils/test_deployment.py +++ b/tests/unit/cli/utils/test_deployment.py @@ -483,3 +483,152 @@ async def test_deploy_succeeds_without_api_key_when_no_remote_calls(tmp_path): await reconcile_and_provision_resources( app, "build-123", "dev", local_manifest, show_progress=False ) + + +@pytest.mark.asyncio +async def test_provision_resources_persists_ai_key_to_manifest(mock_flash_app): + manifest = { + "resources": { + "cpu": {"resource_type": "ServerlessResource"}, + } + } + 
mock_flash_app.get_build_manifest.return_value = manifest + + deployed = MagicMock() + deployed.endpoint_url = "https://example.com/endpoint" + deployed.id = "endpoint-123" + deployed.aiKey = "ai-key-123" + + with ( + patch("runpod_flash.cli.utils.deployment.ResourceManager") as mock_manager_cls, + patch( + "runpod_flash.cli.utils.deployment.create_resource_from_manifest" + ) as mock_create_resource, + ): + mock_manager = MagicMock() + mock_manager.get_or_deploy_resource = AsyncMock(return_value=deployed) + mock_manager_cls.return_value = mock_manager + mock_create_resource.return_value = MagicMock() + + await provision_resources_for_build( + mock_flash_app, + "build-123", + "dev", + show_progress=False, + ) + + call_args = mock_flash_app.update_build_manifest.call_args + updated_manifest = call_args[0][1] + assert updated_manifest["resources"]["cpu"]["endpoint_id"] == "endpoint-123" + assert updated_manifest["resources"]["cpu"]["aiKey"] == "ai-key-123" + + +@pytest.mark.asyncio +async def test_reconciliation_copies_ai_key_from_state_manifest(tmp_path): + import json + + flash_dir = tmp_path / ".flash" + flash_dir.mkdir() + + local_manifest = { + "resources": { + "worker": { + "resource_type": "LiveServerless", + "config": "same", + "endpoint_id": "endpoint-123", + "aiKey": "ai-key-123", + }, + }, + "resources_endpoints": {}, + } + (flash_dir / "flash_manifest.json").write_text(json.dumps(local_manifest)) + + state_manifest = { + "resources": { + "worker": { + "resource_type": "LiveServerless", + "config": "same", + "endpoint_id": "endpoint-123", + "aiKey": "ai-key-123", + }, + }, + "resources_endpoints": { + "worker": "https://worker.api.runpod.ai", + }, + } + + app = AsyncMock() + app.get_build_manifest = AsyncMock(return_value=state_manifest) + app.update_build_manifest = AsyncMock() + + with ( + patch("pathlib.Path.cwd", return_value=tmp_path), + patch("runpod_flash.cli.utils.deployment.ResourceManager") as mock_manager_cls, + ): + mock_manager = MagicMock() + 
mock_manager.get_or_deploy_resource = AsyncMock() + mock_manager_cls.return_value = mock_manager + + await reconcile_and_provision_resources(app, "build-123", "dev", local_manifest) + + updated_manifest = app.update_build_manifest.call_args[0][1] + assert updated_manifest["resources"]["worker"]["endpoint_id"] == "endpoint-123" + assert updated_manifest["resources"]["worker"]["aiKey"] == "ai-key-123" + assert ( + updated_manifest["resources_endpoints"]["worker"] + == "https://worker.api.runpod.ai" + ) + + with open(flash_dir / "flash_manifest.json") as f: + persisted_manifest = json.load(f) + assert "aiKey" not in persisted_manifest["resources"]["worker"] + + +@pytest.mark.asyncio +async def test_reconciliation_ignores_runtime_fields_in_config_comparison(tmp_path): + import json + + flash_dir = tmp_path / ".flash" + flash_dir.mkdir() + + local_manifest = { + "resources": { + "worker": { + "resource_type": "LiveServerless", + "config": "same", + }, + }, + "resources_endpoints": {}, + } + (flash_dir / "flash_manifest.json").write_text(json.dumps(local_manifest)) + + state_manifest = { + "resources": { + "worker": { + "resource_type": "LiveServerless", + "config": "same", + "aiKey": "ai-key-123", + "endpoint_id": "endpoint-123", + "templateId": "template-123", + }, + }, + "resources_endpoints": { + "worker": "https://worker.api.runpod.ai", + }, + } + + app = AsyncMock() + app.get_build_manifest = AsyncMock(return_value=state_manifest) + app.update_build_manifest = AsyncMock() + + with ( + patch("pathlib.Path.cwd", return_value=tmp_path), + patch("runpod_flash.cli.utils.deployment.ResourceManager") as mock_manager_cls, + ): + mock_manager = MagicMock() + mock_manager.get_or_deploy_resource = AsyncMock() + mock_manager_cls.return_value = mock_manager + + await reconcile_and_provision_resources(app, "build-123", "dev", local_manifest) + + mock_manager.get_or_deploy_resource.assert_not_called() diff --git a/tests/unit/resources/test_request_logs.py 
b/tests/unit/resources/test_request_logs.py new file mode 100644 index 00000000..9619d3bf --- /dev/null +++ b/tests/unit/resources/test_request_logs.py @@ -0,0 +1,247 @@ +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import httpx +import pytest + +from runpod_flash.core.resources.request_logs import ( + QBRequestLogFetcher, + QBRequestLogPhase, +) + + +def _make_async_client(mock_client: MagicMock) -> MagicMock: + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=mock_client) + cm.__aexit__ = AsyncMock(return_value=None) + return cm + + +@pytest.mark.asyncio +async def test_waiting_for_workers_when_none_running_or_initializing(): + fetcher = QBRequestLogFetcher(start_time=datetime(2026, 1, 1, tzinfo=timezone.utc)) + + status_response = MagicMock() + status_response.raise_for_status = MagicMock() + status_response.json.return_value = {"status": "IN_QUEUE"} + + metrics_response = MagicMock() + metrics_response.raise_for_status = MagicMock() + metrics_response.json.return_value = { + "workers": {"initializing": 0}, + "readyWorkers": [], + } + + mock_client = MagicMock() + mock_client.get = AsyncMock(side_effect=[status_response, metrics_response]) + + with patch( + "runpod_flash.core.resources.request_logs.get_authenticated_httpx_client", + return_value=_make_async_client(mock_client), + ): + batch = await fetcher.fetch_logs( + endpoint_id="endpoint-1", + request_id="request-1", + status_api_key="status-key", + pod_logs_api_key="runpod-key", + ) + + assert batch.phase == QBRequestLogPhase.WAITING_FOR_WORKER + assert batch.worker_id is None + assert batch.lines == [] + + +@pytest.mark.asyncio +async def test_waiting_for_worker_initialization_when_workers_initializing(): + fetcher = QBRequestLogFetcher() + + status_response = MagicMock() + status_response.raise_for_status = MagicMock() + status_response.json.return_value = {"status": "IN_QUEUE"} + + metrics_response = MagicMock() + 
metrics_response.raise_for_status = MagicMock() + metrics_response.json.return_value = { + "workers": {"initializing": 1}, + "readyWorkers": [], + } + + mock_client = MagicMock() + mock_client.get = AsyncMock(side_effect=[status_response, metrics_response]) + + with patch( + "runpod_flash.core.resources.request_logs.get_authenticated_httpx_client", + return_value=_make_async_client(mock_client), + ): + batch = await fetcher.fetch_logs( + endpoint_id="endpoint-1", + request_id="request-1", + status_api_key="status-key", + pod_logs_api_key="runpod-key", + ) + + assert batch.phase == QBRequestLogPhase.WAITING_FOR_WORKER_INITIALIZATION + assert batch.worker_id is None + assert batch.lines == [] + + +@pytest.mark.asyncio +async def test_primes_existing_worker_logs_then_streams_new_lines(): + fetcher = QBRequestLogFetcher( + start_time=datetime(2026, 4, 2, 17, 14, 7, tzinfo=timezone.utc), + lookback_seconds=20, + ) + + status_1 = MagicMock() + status_1.raise_for_status = MagicMock() + status_1.json.return_value = {"status": "IN_QUEUE"} + + metrics_1 = MagicMock() + metrics_1.raise_for_status = MagicMock() + metrics_1.json.return_value = { + "workers": {"initializing": 0, "running": 1}, + "readyWorkers": ["worker-running-1"], + } + + old_logs = MagicMock() + old_logs.raise_for_status = MagicMock() + old_logs.json.return_value = { + "container": ["2026-04-02T17:14:05Z create container"], + "system": [ + "2026-04-02T16:38:18Z very old line", + '{"requestId": "request-1", "message": "Started.", "level": "INFO"}', + "ae1225 smoke: worker started", + ], + } + + status_2 = MagicMock() + status_2.raise_for_status = MagicMock() + status_2.json.return_value = {"status": "IN_QUEUE"} + + metrics_2 = MagicMock() + metrics_2.raise_for_status = MagicMock() + metrics_2.json.return_value = { + "workers": {"initializing": 0, "running": 1}, + "readyWorkers": ["worker-running-1"], + } + + new_logs = MagicMock() + new_logs.raise_for_status = MagicMock() + new_logs.json.return_value = { + 
"container": ["2026-04-02T17:14:08Z start container"], + "system": ["2026-04-02T17:14:05Z create container"], + } + + mock_client = MagicMock() + mock_client.get = AsyncMock( + side_effect=[status_1, metrics_1, old_logs, status_2, metrics_2, new_logs] + ) + + with patch( + "runpod_flash.core.resources.request_logs.get_authenticated_httpx_client", + return_value=_make_async_client(mock_client), + ): + first_batch = await fetcher.fetch_logs( + endpoint_id="endpoint-1", + request_id="request-1", + status_api_key="endpoint-ai-key", + pod_logs_api_key="runpod-key", + ) + second_batch = await fetcher.fetch_logs( + endpoint_id="endpoint-1", + request_id="request-1", + status_api_key="endpoint-ai-key", + pod_logs_api_key="runpod-key", + ) + + assert first_batch.worker_id == "worker-running-1" + assert first_batch.phase == QBRequestLogPhase.STREAMING + assert first_batch.lines == [ + '{"requestId": "request-1", "message": "Started.", "level": "INFO"}', + "2026-04-02T17:14:05Z create container", + ] + + assert second_batch.worker_id == "worker-running-1" + assert second_batch.phase == QBRequestLogPhase.STREAMING + assert second_batch.lines == ["2026-04-02T17:14:08Z start container"] + assert second_batch.matched_by_request_id is False + + +@pytest.mark.asyncio +async def test_status_uses_fallback_key_on_401(): + fetcher = QBRequestLogFetcher() + + unauthorized = httpx.Response( + status_code=401, + request=httpx.Request( + "GET", "https://api.runpod.ai/v2/endpoint-1/status/request-1" + ), + ) + + status_response = MagicMock() + status_response.raise_for_status = MagicMock() + status_response.json.return_value = {"workerId": "worker-123"} + + metrics_response = MagicMock() + metrics_response.raise_for_status = MagicMock() + metrics_response.json.return_value = { + "workers": {"initializing": 0, "running": 1}, + "readyWorkers": ["worker-123"], + } + + pod_logs_response = MagicMock() + pod_logs_response.raise_for_status = MagicMock() + pod_logs_response.json.return_value = 
{"container": ["old"], "system": []} + + status_response_2 = MagicMock() + status_response_2.raise_for_status = MagicMock() + status_response_2.json.return_value = {"workerId": "worker-123"} + + metrics_response_2 = MagicMock() + metrics_response_2.raise_for_status = MagicMock() + metrics_response_2.json.return_value = { + "workers": {"initializing": 0, "running": 1}, + "readyWorkers": ["worker-123"], + } + + pod_logs_response_2 = MagicMock() + pod_logs_response_2.raise_for_status = MagicMock() + pod_logs_response_2.json.return_value = {"container": ["new"], "system": []} + + mock_client = MagicMock() + mock_client.get = AsyncMock( + side_effect=[ + httpx.HTTPStatusError( + "unauthorized", request=unauthorized.request, response=unauthorized + ), + status_response, + metrics_response, + pod_logs_response, + status_response_2, + metrics_response_2, + pod_logs_response_2, + ] + ) + + with patch( + "runpod_flash.core.resources.request_logs.get_authenticated_httpx_client", + return_value=_make_async_client(mock_client), + ): + await fetcher.fetch_logs( + endpoint_id="endpoint-1", + request_id="request-1", + status_api_key="endpoint-ai-key", + pod_logs_api_key="runpod-key", + status_api_key_fallback="runpod-key", + ) + second_batch = await fetcher.fetch_logs( + endpoint_id="endpoint-1", + request_id="request-1", + status_api_key="endpoint-ai-key", + pod_logs_api_key="runpod-key", + status_api_key_fallback="runpod-key", + ) + + assert second_batch.worker_id == "worker-123" + assert second_batch.phase == QBRequestLogPhase.STREAMING + assert second_batch.lines == ["new"] diff --git a/tests/unit/resources/test_serverless.py b/tests/unit/resources/test_serverless.py index 3f44a9aa..54833e45 100644 --- a/tests/unit/resources/test_serverless.py +++ b/tests/unit/resources/test_serverless.py @@ -11,6 +11,7 @@ ServerlessResource, ServerlessEndpoint, ServerlessScalerType, + ServerlessType, CudaVersion, JobOutput, WorkersHealth, @@ -22,6 +23,13 @@ from 
runpod_flash.core.resources.gpu import GpuGroup from runpod_flash.core.resources.cpu import CpuInstanceType from runpod_flash.core.resources.network_volume import NetworkVolume, DataCenter +from runpod_flash.core.resources.request_logs import ( + QBRequestLogBatch, + QBRequestLogPhase, +) +from runpod_flash.core.resources.worker_availability_diagnostic import ( + WorkerAvailabilityResult, +) from runpod_flash.core.resources.template import KeyValuePair, PodTemplate @@ -1205,6 +1213,479 @@ async def test_run_async_success(self): assert result.id == "job-123" assert result.status == "COMPLETED" + @pytest.mark.asyncio + async def test_run_async_dedupes_stdout_against_streamed_pod_logs(self): + serverless = ServerlessResource(name="test") + serverless.id = "endpoint-123" + serverless.type = ServerlessType.QB + serverless.aiKey = "endpoint-ai-key" + + mock_job = MagicMock() + mock_job.job_id = "job-123" + mock_job.status.side_effect = ["IN_QUEUE", "COMPLETED"] + mock_job._fetch_job.return_value = { + "id": "job-123", + "workerId": "worker-456", + "status": "COMPLETED", + "delayTime": 1000, + "executionTime": 2000, + "output": { + "stdout": "2026-04-02T18:18:10.165152015Z 2026-04-02 18:18:10,164 | DEBUG | aiohttp_retry | client.py:110 | Attempt 1 out of 3\n" + "2026-04-02 18:18:10,164 | DEBUG | aiohttp_retry | client.py:110 | Attempt 1 out of 3\n" + "unique stdout line" + }, + } + + mock_endpoint = MagicMock() + mock_endpoint.run.return_value = mock_job + + async def fake_emit(*, fetcher, request_id): + fetcher.has_streamed_logs = True + fetcher.seen.add( + "2026-04-02T18:18:10.165152015Z 2026-04-02 18:18:10,164 | DEBUG | aiohttp_retry | client.py:110 | Attempt 1 out of 3" + ) + return None + + with patch.object( + type(serverless), + "endpoint", + new_callable=lambda: property(lambda self: mock_endpoint), + ): + with patch("asyncio.sleep"): + with patch.object( + ServerlessResource, + "_emit_endpoint_logs", + new=AsyncMock(side_effect=fake_emit), + ): + with patch( + 
"runpod_flash.core.resources.serverless.get_api_key", + return_value="runpod-key-123", + ): + result = await serverless.run({"input": "test"}) + + assert isinstance(result, JobOutput) + assert result.output["stdout"] == ( + "2026-04-02 18:18:10,164 | DEBUG | aiohttp_retry | client.py:110 | Attempt 1 out of 3\n" + "unique stdout line" + ) + + @pytest.mark.asyncio + async def test_run_async_keeps_stdout_unchanged_when_no_streamed_logs(self): + serverless = ServerlessResource(name="test") + serverless.id = "endpoint-123" + serverless.type = ServerlessType.QB + serverless.aiKey = "endpoint-ai-key" + + original_stdout = "dup line\ndup line\n\n spaced line" + mock_job = MagicMock() + mock_job.job_id = "job-123" + mock_job.status.side_effect = ["IN_QUEUE", "COMPLETED"] + mock_job._fetch_job.return_value = { + "id": "job-123", + "workerId": "worker-456", + "status": "COMPLETED", + "delayTime": 1000, + "executionTime": 2000, + "output": {"stdout": original_stdout}, + } + + mock_endpoint = MagicMock() + mock_endpoint.run.return_value = mock_job + + async def fake_emit(*, fetcher, request_id): + fetcher.seen.add("dup line") + return None + + with patch.object( + type(serverless), + "endpoint", + new_callable=lambda: property(lambda self: mock_endpoint), + ): + with patch("asyncio.sleep"): + with patch.object( + ServerlessResource, + "_emit_endpoint_logs", + new=AsyncMock(side_effect=fake_emit), + ): + result = await serverless.run({"input": "test"}) + + assert isinstance(result, JobOutput) + assert result.output["stdout"] == original_stdout + + @pytest.mark.asyncio + async def test_run_async_fetches_endpoint_logs_while_polling(self): + """Test run async polls endpoint logs every cycle until completion.""" + serverless = ServerlessResource(name="test") + serverless.id = "endpoint-123" + serverless.type = ServerlessType.QB + serverless.aiKey = "ai-key-123" + + mock_job = MagicMock() + mock_job.job_id = "job-123" + mock_job.status.side_effect = [ + "IN_QUEUE", + "IN_PROGRESS", + 
"IN_PROGRESS", + "COMPLETED", + ] + mock_job._fetch_job.return_value = { + "id": "job-123", + "workerId": "worker-456", + "status": "COMPLETED", + "delayTime": 1000, + "executionTime": 2000, + "output": {"result": "success"}, + } + + mock_endpoint = MagicMock() + mock_endpoint.run.return_value = mock_job + + with patch.object( + type(serverless), + "endpoint", + new_callable=lambda: property(lambda self: mock_endpoint), + ): + with patch("asyncio.sleep"): + with patch.object( + ServerlessResource, + "_emit_endpoint_logs", + new=AsyncMock(), + ) as mock_emit_logs: + await serverless.run({"input": "test"}) + + assert mock_emit_logs.await_count == 6 + fetchers = [call.kwargs["fetcher"] for call in mock_emit_logs.await_args_list] + assert len({id(fetcher) for fetcher in fetchers}) == 1 + request_ids = [ + call.kwargs["request_id"] for call in mock_emit_logs.await_args_list + ] + assert all(request_id == "job-123" for request_id in request_ids) + + @pytest.mark.asyncio + async def test_run_async_announces_assigned_worker_streaming_once(self): + serverless = ServerlessResource(name="test") + serverless.id = "endpoint-123" + serverless.type = ServerlessType.QB + serverless.aiKey = "ai-key-123" + + mock_job = MagicMock() + mock_job.job_id = "job-123" + mock_job.status.side_effect = [ + "IN_PROGRESS", + "IN_PROGRESS", + "IN_PROGRESS", + "COMPLETED", + ] + mock_job._fetch_job.return_value = { + "id": "job-123", + "workerId": "worker-456", + "status": "COMPLETED", + "delayTime": 1000, + "executionTime": 2000, + "output": {"result": "success"}, + } + + mock_endpoint = MagicMock() + mock_endpoint.run.return_value = mock_job + + assigned_batch = QBRequestLogBatch( + worker_id="worker-456", + lines=[], + matched_by_request_id=True, + phase=QBRequestLogPhase.STREAMING, + ) + + with patch.object( + type(serverless), + "endpoint", + new_callable=lambda: property(lambda self: mock_endpoint), + ): + with patch("asyncio.sleep"): + with patch.object( + ServerlessResource, + 
"_emit_endpoint_logs", + new=AsyncMock( + side_effect=[ + assigned_batch, + assigned_batch, + assigned_batch, + assigned_batch, + assigned_batch, + assigned_batch, + ] + ), + ): + with patch( + "runpod_flash.core.resources.serverless.log.info" + ) as mock_log_info: + await serverless.run({"input": "test"}) + + assigned_messages = [ + str(call.args[0]) + for call in mock_log_info.call_args_list + if call.args + and "Request assigned to worker worker-456, streaming pod logs" + in str(call.args[0]) + ] + assert len(assigned_messages) == 1 + + @pytest.mark.asyncio + async def test_run_async_repeats_no_gpu_availability_message_every_five_updates( + self, + ): + serverless = ServerlessResource(name="test") + serverless.id = "endpoint-123" + serverless.type = ServerlessType.QB + serverless.aiKey = "ai-key-123" + + mock_job = MagicMock() + mock_job.job_id = "job-123" + mock_job.status.side_effect = [ + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "COMPLETED", + ] + mock_job._fetch_job.return_value = { + "id": "job-123", + "workerId": "worker-456", + "status": "COMPLETED", + "delayTime": 1000, + "executionTime": 2000, + "output": {"result": "success"}, + } + + waiting_batch = QBRequestLogBatch( + worker_id=None, + lines=[], + matched_by_request_id=False, + phase=QBRequestLogPhase.WAITING_FOR_WORKER, + ) + + async def emit_waiting_batch(*, fetcher, request_id): + return waiting_batch + + mock_endpoint = MagicMock() + mock_endpoint.run.return_value = mock_job + + with patch.object( + type(serverless), + "endpoint", + new_callable=lambda: property(lambda self: mock_endpoint), + ): + with patch("asyncio.sleep"): + with patch.object( + ServerlessResource, + "_emit_endpoint_logs", + new=AsyncMock(side_effect=emit_waiting_batch), + ): + with patch( + "runpod_flash.core.resources.serverless.WorkerAvailabilityDiagnostic.diagnose", + new=AsyncMock( + 
return_value=WorkerAvailabilityResult( + message=( + "No workers available on endpoint: no gpu availability for gpu type NVIDIA GeForce RTX 4090" + ), + has_availability=False, + reason="no_gpu_availability", + ) + ), + ): + with patch( + "runpod_flash.core.resources.serverless.log.info" + ) as mock_log_info: + await serverless.run({"input": "test"}) + + no_worker_messages = [ + str(call.args[0]) + for call in mock_log_info.call_args_list + if call.args + and "No workers available on endpoint: no gpu availability for gpu type" + in str(call.args[0]) + ] + assert len(no_worker_messages) == 2 + + @pytest.mark.asyncio + async def test_run_async_stops_waiting_metrics_logs_after_in_progress(self): + serverless = ServerlessResource(name="test") + serverless.id = "endpoint-123" + serverless.type = ServerlessType.QB + serverless.aiKey = "ai-key-123" + + mock_job = MagicMock() + mock_job.job_id = "job-123" + mock_job.status.side_effect = [ + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_QUEUE", + "IN_PROGRESS", + "IN_PROGRESS", + "IN_PROGRESS", + "IN_PROGRESS", + "COMPLETED", + ] + mock_job._fetch_job.return_value = { + "id": "job-123", + "workerId": "worker-456", + "status": "COMPLETED", + "delayTime": 1000, + "executionTime": 2000, + "output": {"result": "success"}, + } + + waiting_batch = QBRequestLogBatch( + worker_id=None, + lines=[], + matched_by_request_id=False, + phase=QBRequestLogPhase.WAITING_FOR_WORKER, + worker_metrics={ + "ready": 0, + "running": 0, + "idle": 0, + "initializing": 0, + "throttled": 2, + "unhealthy": 0, + }, + ) + mock_endpoint = MagicMock() + mock_endpoint.run.return_value = mock_job + + with patch.object( + type(serverless), + "endpoint", + new_callable=lambda: property(lambda self: mock_endpoint), + ): + with patch("asyncio.sleep"): + with patch.object( + ServerlessResource, + "_emit_endpoint_logs", + 
new=AsyncMock(return_value=waiting_batch), + ): + with patch( + "runpod_flash.core.resources.serverless.WorkerAvailabilityDiagnostic.diagnose", + new=AsyncMock( + return_value=WorkerAvailabilityResult( + message="Workers are currently throttled on endpoint for selected gpu NVIDIA GeForce RTX 4090. Consider raising max workers or changing gpu type.", + has_availability=True, + reason="workers_throttled", + ) + ), + ): + with patch( + "runpod_flash.core.resources.serverless.log.info" + ) as mock_log_info: + await serverless.run({"input": "test"}) + + metrics_logs = [ + str(call.args[0]) + for call in mock_log_info.call_args_list + if call.args + and "Waiting for request: endpoint metrics:" in str(call.args[0]) + ] + assert metrics_logs + assert not any("status=IN_PROGRESS" in line for line in metrics_logs) + + @pytest.mark.asyncio + async def test_emit_endpoint_logs_uses_logger_for_worker_lines(self): + """Endpoint log emission logs each worker line through logger.""" + serverless = ServerlessResource(name="test") + serverless.id = "endpoint-123" + serverless.type = ServerlessType.QB + serverless.aiKey = "endpoint-ai-key" + + mock_fetcher = MagicMock() + mock_fetcher.fetch_logs = AsyncMock( + return_value=QBRequestLogBatch( + worker_id=None, + lines=["line-a", "line-b"], + matched_by_request_id=False, + phase=QBRequestLogPhase.STREAMING, + ) + ) + + with patch( + "runpod_flash.core.resources.serverless.get_api_key", + return_value="runpod-key-123", + ): + with patch("runpod_flash.core.resources.serverless.log.info") as mock_info: + batch = await serverless._emit_endpoint_logs( + fetcher=mock_fetcher, + request_id="job-123", + ) + + mock_fetcher.fetch_logs.assert_awaited_once_with( + endpoint_id="endpoint-123", + request_id="job-123", + status_api_key="endpoint-ai-key", + pod_logs_api_key="runpod-key-123", + status_api_key_fallback="runpod-key-123", + ) + assert batch is not None + assert batch.phase == QBRequestLogPhase.STREAMING + mock_info.assert_any_call("worker 
log: %s", "line-a") + mock_info.assert_any_call("worker log: %s", "line-b") + + @pytest.mark.asyncio + async def test_emit_endpoint_logs_skips_when_missing_required_fields(self): + """Endpoint log fetch is skipped unless QB endpoint has id and API key.""" + serverless = ServerlessResource(name="test") + mock_fetcher = MagicMock() + mock_fetcher.fetch_logs = AsyncMock(return_value=None) + + serverless.type = ServerlessType.QB + serverless.id = None + serverless.aiKey = "endpoint-ai-key" + with patch( + "runpod_flash.core.resources.serverless.get_api_key", + return_value="runpod-key-123", + ): + await serverless._emit_endpoint_logs( + fetcher=mock_fetcher, + request_id="job-123", + ) + + serverless.id = "endpoint-123" + serverless.aiKey = None + with patch( + "runpod_flash.core.resources.serverless.get_api_key", + return_value=None, + ): + await serverless._emit_endpoint_logs( + fetcher=mock_fetcher, + request_id="job-123", + ) + + serverless.type = ServerlessType.LB + serverless.aiKey = "endpoint-ai-key" + with patch( + "runpod_flash.core.resources.serverless.get_api_key", + return_value="runpod-key-123", + ): + await serverless._emit_endpoint_logs( + fetcher=mock_fetcher, + request_id="job-123", + ) + + mock_fetcher.fetch_logs.assert_not_awaited() + @pytest.mark.asyncio async def test_run_async_failure_cancels_job(self): """Test run async cancels job on exception.""" diff --git a/tests/unit/resources/test_worker_availability_diagnostic.py b/tests/unit/resources/test_worker_availability_diagnostic.py new file mode 100644 index 00000000..0807fe04 --- /dev/null +++ b/tests/unit/resources/test_worker_availability_diagnostic.py @@ -0,0 +1,151 @@ +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from runpod_flash.core.resources.cpu import CpuInstanceType +from runpod_flash.core.resources.serverless import ServerlessResource +from runpod_flash.core.resources.worker_availability_diagnostic import ( + WorkerAvailabilityDiagnostic, +) + + +def 
def _make_client_context(mock_client: MagicMock) -> MagicMock:
    """Build a mock async context manager whose ``__aenter__`` yields *mock_client*."""
    context = MagicMock()
    context.__aenter__ = AsyncMock(return_value=mock_client)
    context.__aexit__ = AsyncMock(return_value=None)
    return context


# Dotted path patched by every test that exercises the GraphQL stock lookups.
_CLIENT_PATH = (
    "runpod_flash.core.resources.worker_availability_diagnostic.RunpodGraphQLClient"
)


@pytest.mark.asyncio
async def test_diagnose_returns_workers_max_zero_message():
    """An endpoint pinned to zero max workers is reported as unavailable."""
    resource = ServerlessResource(name="test", workersMax=0)

    outcome = await WorkerAvailabilityDiagnostic().diagnose(resource)

    assert outcome.has_availability is False
    assert outcome.reason == "workers_max_zero"
    assert "max workers are currently set to 0" in outcome.message


@pytest.mark.asyncio
async def test_diagnose_gpu_no_availability_includes_selected_locations():
    """With no GPU stock in any location, the message names the selected locations."""
    resource = ServerlessResource(name="test")
    resource.gpuIds = "NVIDIA GeForce RTX 4090"
    resource.locations = "EU-RO-1,US-GA-2"

    client = MagicMock()
    # One stock lookup per configured location; both come back empty.
    client.get_gpu_lowest_price_stock_status = AsyncMock(side_effect=[None, None])

    with patch(_CLIENT_PATH, return_value=_make_client_context(client)):
        outcome = await WorkerAvailabilityDiagnostic().diagnose(resource)

    assert outcome.has_availability is False
    assert outcome.reason == "no_gpu_availability"
    assert (
        "No workers available on endpoint: no gpu availability for gpu type NVIDIA GeForce RTX 4090"
        in outcome.message
    )
    assert "EU-RO-1, US-GA-2" in outcome.message


@pytest.mark.asyncio
async def test_diagnose_gpu_availability_shows_signal_without_locations():
    """A positive stock signal is surfaced without echoing the location list."""
    resource = ServerlessResource(name="test")
    resource.gpuIds = "NVIDIA GeForce RTX 4090"
    resource.locations = "EU-RO-1,US-GA-2"

    client = MagicMock()
    # First location has no stock; the second reports a "Low" signal.
    client.get_gpu_lowest_price_stock_status = AsyncMock(side_effect=[None, "Low"])

    with patch(_CLIENT_PATH, return_value=_make_client_context(client)):
        outcome = await WorkerAvailabilityDiagnostic().diagnose(resource)

    assert outcome.has_availability is True
    assert outcome.reason == "gpu_has_availability"
    assert (
        "Current availability signal for selected gpu NVIDIA GeForce RTX 4090: Low"
        in outcome.message
    )
    assert "EU-RO-1" not in outcome.message
    assert "US-GA-2" not in outcome.message


@pytest.mark.asyncio
async def test_diagnose_cpu_no_availability_message():
    """With no CPU stock anywhere, the message names the CPU type and locations."""
    resource = ServerlessResource(name="test")
    resource.instanceIds = [CpuInstanceType.CPU3G_2_8]
    resource.locations = "EU-RO-1,US-GA-2"

    client = MagicMock()
    client.get_cpu_specific_stock_status = AsyncMock(side_effect=[None, None])

    with patch(_CLIENT_PATH, return_value=_make_client_context(client)):
        outcome = await WorkerAvailabilityDiagnostic().diagnose(resource)

    assert outcome.has_availability is False
    assert outcome.reason == "no_cpu_availability"
    assert (
        "No workers available on endpoint: no cpu availability for cpu type cpu3g-2-8"
        in outcome.message
    )
    assert "EU-RO-1, US-GA-2" in outcome.message


@pytest.mark.asyncio
async def test_diagnose_prefers_throttled_reason_over_no_availability():
    """Throttled worker metrics win over a stock lookup: no GraphQL patching needed."""
    resource = ServerlessResource(name="test")
    resource.gpuIds = "NVIDIA GeForce RTX 4090"

    outcome = await WorkerAvailabilityDiagnostic().diagnose(
        resource,
        worker_metrics={"throttled": 3},
    )

    assert outcome.has_availability is True
    assert outcome.reason == "workers_throttled"
    assert "Workers are currently throttled on endpoint" in outcome.message
    assert "Consider raising max workers or changing gpu type" in outcome.message


@pytest.mark.asyncio
async def test_diagnose_cpu_throttled_message_references_cpu_type():
    """For CPU endpoints the throttled hint suggests changing the cpu type."""
    resource = ServerlessResource(name="test")
    resource.instanceIds = [CpuInstanceType.CPU3G_2_8]

    outcome = await WorkerAvailabilityDiagnostic().diagnose(
        resource,
        worker_metrics={"throttled": 2},
    )

    assert outcome.reason == "workers_throttled"
    assert "changing cpu type" in outcome.message


@pytest.mark.asyncio
async def test_diagnose_treats_out_of_stock_as_unavailable():
    """An explicit OUT_OF_STOCK signal counts as no availability."""
    resource = ServerlessResource(name="test")
    resource.gpuIds = "NVIDIA GeForce RTX 4090"

    client = MagicMock()
    client.get_gpu_lowest_price_stock_status = AsyncMock(
        side_effect=["OUT_OF_STOCK"]
    )

    with patch(_CLIENT_PATH, return_value=_make_client_context(client)):
        outcome = await WorkerAvailabilityDiagnostic().diagnose(resource)

    assert outcome.has_availability is False
    assert outcome.reason == "no_gpu_availability"