Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changes/11613.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add a three-stage PROVISIONING sub-status pipeline (PENDING → STARTING → WARMING_UP, culminating in RUNNING) for model service routes, introducing the ReplicaID typed identifier.
6 changes: 6 additions & 0 deletions src/ai/backend/common/identifier/replica.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Typed identifier for model-service route replicas.

`ReplicaID` is a `typing.NewType` over `UUID`: it has zero runtime cost
(calling ``ReplicaID(u)`` returns ``u`` unchanged), but static type checkers
treat it as distinct from a bare ``UUID``, preventing accidental mixing of
replica/route IDs with other UUID-based identifiers.
"""

from typing import NewType
from uuid import UUID

__all__ = ("ReplicaID",)

ReplicaID = NewType("ReplicaID", UUID)
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
SessionHistoryNode,
)
from ai.backend.common.dto.manager.v2.scheduling_history.types import SubStepResultInfo
from ai.backend.common.identifier.replica import ReplicaID
from ai.backend.manager.api.adapter_options.pagination.pagination import PaginationSpec
from ai.backend.manager.api.adapters.base import BaseAdapter
from ai.backend.manager.data.deployment.types import DeploymentHistoryData, RouteHistoryData
Expand Down Expand Up @@ -537,7 +538,7 @@ async def route_scoped_search(
input: AdminSearchRouteHistoriesInput,
) -> AdminSearchRouteHistoriesPayload:
"""Search route histories scoped to a route."""
scope = RouteHistorySearchScope(route_id=route_id)
scope = RouteHistorySearchScope(route_id=ReplicaID(route_id))
querier = self._build_route_querier(input)
action_result = (
await self._processors.scheduling_history.search_route_scoped_history.wait_for_complete(
Expand Down
4 changes: 1 addition & 3 deletions src/ai/backend/manager/data/deployment/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ class RouteHandlerCategory(enum.StrEnum):

LIFECYCLE = "lifecycle"
HEALTH = "health"
SYNC = "sync"


class DeploymentHandlerCategory(enum.StrEnum):
Expand Down Expand Up @@ -291,9 +292,6 @@ class RouteTransitionTarget:
class RouteStatusTransitions:
"""Status transitions for route handlers.

Route handlers have success/failure/stale outcomes (no expired/give_up).
Each outcome can change lifecycle status, health status, or both.

Attributes:
success: Target state when handler succeeds, None means no change
failure: Target state when handler fails, None means no change
Expand Down
27 changes: 2 additions & 25 deletions src/ai/backend/manager/models/routing/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sqlalchemy.orm import Mapped, mapped_column, relationship, selectinload

from ai.backend.common.identifier.deployment import DeploymentID
from ai.backend.common.identifier.replica import ReplicaID
from ai.backend.common.types import SessionId
from ai.backend.logging import BraceStyleAdapter
from ai.backend.manager.data.deployment.types import (
Expand Down Expand Up @@ -53,7 +54,7 @@ class RoutingRow(Base): # type: ignore[misc]
sa.UniqueConstraint("endpoint", "session", name="uq_routings_endpoint_session"),
)

id: Mapped[uuid.UUID] = mapped_column(
id: Mapped[ReplicaID] = mapped_column(
"id", GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")
)
endpoint: Mapped[DeploymentID] = mapped_column(
Expand Down Expand Up @@ -237,30 +238,6 @@ async def get(
raise NoResultFound
return row

def __init__(
self,
id: uuid.UUID,
endpoint: DeploymentID,
session: uuid.UUID | None,
session_owner: uuid.UUID,
domain: str,
project: uuid.UUID,
revision: uuid.UUID,
status: RouteStatus = RouteStatus.PROVISIONING,
traffic_ratio: float = 1.0,
traffic_status: RouteTrafficStatus = RouteTrafficStatus.ACTIVE,
) -> None:
self.id = id
self.endpoint = endpoint
self.session = session
self.session_owner = session_owner
self.domain = domain
self.project = project
self.status = status
self.traffic_ratio = traffic_ratio
self.revision = revision
self.traffic_status = traffic_status

def delegate_ownership(self, user_uuid: uuid.UUID) -> None:
self.session_owner = user_uuid

Expand Down
3 changes: 2 additions & 1 deletion src/ai/backend/manager/models/scheduling_history/row.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from sqlalchemy.orm import Mapped, mapped_column

from ai.backend.common.data.model_deployment.types import ModelDeploymentStatus
from ai.backend.common.identifier.replica import ReplicaID
from ai.backend.common.types import KernelId, SessionId
from ai.backend.manager.data.deployment.types import (
DeploymentHandlerCategory,
Expand Down Expand Up @@ -248,7 +249,7 @@ class RouteHistoryRow(Base): # type: ignore[misc]
id: Mapped[uuid.UUID] = mapped_column(
"id", GUID, primary_key=True, server_default=sa.text("uuid_generate_v4()")
)
route_id: Mapped[uuid.UUID] = mapped_column("route_id", GUID, nullable=False, index=True)
route_id: Mapped[ReplicaID] = mapped_column("route_id", GUID, nullable=False, index=True)
deployment_id: Mapped[uuid.UUID] = mapped_column(
"deployment_id", GUID, nullable=False, index=True
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,13 @@ class RouteCreatorSpec(CreatorSpec[RoutingRow]):
@override
def build_row(self) -> RoutingRow:
return RoutingRow(
id=uuid.uuid4(),
endpoint=self.deployment_id,
session=None,
session_owner=self.session_owner_id,
domain=self.domain,
project=self.project_id,
status=RouteStatus.PROVISIONING,
sub_status=RouteSubStatus.PENDING,
traffic_ratio=self.traffic_ratio,
revision=self.revision_id,
traffic_status=self.traffic_status,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from ai.backend.common.identifier.deployment_preset import DeploymentPresetID
from ai.backend.common.identifier.deployment_revision import DeploymentRevisionID
from ai.backend.common.identifier.image import ImageID
from ai.backend.common.identifier.replica import ReplicaID
from ai.backend.common.identifier.resource_group import ResourceGroupName
from ai.backend.common.identifier.runtime_variant import RuntimeVariantID
from ai.backend.common.identifier.vfolder import VFolderUUID
Expand Down Expand Up @@ -179,6 +180,8 @@
ProjectDeploymentSearchScope,
RouteData,
RouteServiceDiscoveryInfo,
RouteSessionInfo,
RouteSessionKernelInfo,
)
from ai.backend.manager.repositories.scheduler.types.session_creation import (
ContainerUserContext,
Expand Down Expand Up @@ -1208,7 +1211,7 @@ async def get_endpoint_id_by_session(

async def fetch_route_service_discovery_info(
self,
route_ids: set[uuid.UUID],
route_ids: set[ReplicaID],
) -> list[RouteServiceDiscoveryInfo]:
"""Fetch service discovery information for routes.

Expand Down Expand Up @@ -1770,9 +1773,9 @@ async def update_route_status_bulk_with_history(
async def _get_last_route_histories_by_category(
self,
db_sess: SASession,
route_ids: list[uuid.UUID],
route_ids: list[ReplicaID],
category: RouteHandlerCategory,
) -> dict[uuid.UUID, RouteHistoryRow]:
) -> dict[ReplicaID, RouteHistoryRow]:
"""Get last history records per route filtered by handler category."""
if not route_ids:
return {}
Expand All @@ -1796,8 +1799,8 @@ async def _get_last_route_histories_by_category(
async def _get_last_route_histories_bulk(
self,
db_sess: SASession,
route_ids: list[uuid.UUID],
) -> dict[uuid.UUID, RouteHistoryRow]:
route_ids: list[ReplicaID],
) -> dict[ReplicaID, RouteHistoryRow]:
"""Get last history records for multiple routes efficiently."""
if not route_ids:
return {}
Expand Down Expand Up @@ -1927,15 +1930,15 @@ async def fetch_kernel_connection_info(

async def update_route_replica_info(
self,
updates: dict[uuid.UUID, tuple[str, int]],
updates: dict[ReplicaID, RouteSessionKernelInfo],
) -> None:
"""Update replica_host and replica_port for routes."""
async with self._begin_session_read_committed() as db_sess:
for route_id, (host, port) in updates.items():
for route_id, kernel in updates.items():
query = (
sa.update(RoutingRow)
.where(RoutingRow.id == route_id)
.values(replica_host=host, replica_port=port)
.values(replica_host=kernel.replica_host, replica_port=kernel.replica_port)
)
await db_sess.execute(query)

Expand Down Expand Up @@ -2129,8 +2132,8 @@ async def fetch_deployment_context(

async def fetch_session_statuses_by_route_ids(
self,
route_ids: set[uuid.UUID],
) -> Mapping[uuid.UUID, SessionStatus | None]:
route_ids: set[ReplicaID],
) -> Mapping[ReplicaID, SessionStatus | None]:
"""Fetch session statuses for multiple routes.

Args:
Expand Down Expand Up @@ -2158,12 +2161,83 @@ async def fetch_session_statuses_by_route_ids(
rows = result.all()

# Convert the result rows into a mapping
status_map: dict[uuid.UUID, SessionStatus | None] = {}
status_map: dict[ReplicaID, SessionStatus | None] = {}
for route_id, session_status in rows:
status_map[route_id] = session_status
status_map[ReplicaID(route_id)] = session_status

return status_map

async def fetch_route_session_kernel_infos(
    self,
    route_ids: set[ReplicaID],
) -> Mapping[ReplicaID, RouteSessionInfo | None]:
    """Fetch session status and main-kernel connection info for the given routes.

    Args:
        route_ids: Route identifiers to look up.

    Returns:
        Mapping of route_id to RouteSessionInfo:
            - None → route has no session linked
            - RouteSessionInfo(status=TERMINAL, kernel=None) → session terminated
            - RouteSessionInfo(status=RUNNING, kernel=RouteSessionKernelInfo(host, port)) → ready
            - RouteSessionInfo(status=PREPARING, kernel=None) → not yet running
    """
    if not route_ids:
        return {}

    def _extract_kernel(row) -> RouteSessionKernelInfo | None:
        # Connection info is only meaningful when the joined main kernel
        # exposes an inference service port; otherwise report no kernel.
        if not (row.kernel_host and row.service_ports):
            return None
        for port_info in row.service_ports:
            if not port_info.get("is_inference", False):
                continue
            # Only the first inference entry is considered, matching the
            # single published inference port per main kernel.
            host_ports = port_info.get("host_ports", [])
            if host_ports:
                return RouteSessionKernelInfo(
                    replica_host=row.kernel_host,
                    replica_port=host_ports[0],
                )
            break
        return None

    async with self._begin_readonly_session_read_committed() as db_sess:
        stmt = (
            sa.select(
                RoutingRow.id,
                SessionRow.status,
                KernelRow.kernel_host,
                KernelRow.service_ports,
            )
            .select_from(RoutingRow)
            # Outer joins keep routes visible even when no session (or no
            # main kernel) is attached yet.
            .outerjoin(SessionRow, RoutingRow.session == SessionRow.id)
            .outerjoin(
                KernelRow,
                sa.and_(
                    KernelRow.session_id == RoutingRow.session,
                    KernelRow.cluster_role == "main",
                ),
            )
            .where(RoutingRow.id.in_(route_ids))
        )
        fetched = (await db_sess.execute(stmt)).all()

    info_map: dict[ReplicaID, RouteSessionInfo | None] = {}
    for row in fetched:
        rid = ReplicaID(row.id)
        if row.status is None:
            # No session row joined → the route has no session linked.
            info_map[rid] = None
        else:
            info_map[rid] = RouteSessionInfo(
                status=row.status,
                kernel=_extract_kernel(row),
            )

    return info_map

async def fetch_route_connection_infos(
self,
*,
Expand Down
33 changes: 28 additions & 5 deletions src/ai/backend/manager/repositories/deployment/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from ai.backend.common.identifier.deployment_preset import DeploymentPresetID
from ai.backend.common.identifier.deployment_revision import DeploymentRevisionID
from ai.backend.common.identifier.image import ImageID
from ai.backend.common.identifier.replica import ReplicaID
from ai.backend.common.identifier.resource_group import ResourceGroupName
from ai.backend.common.identifier.runtime_variant import RuntimeVariantID
from ai.backend.common.identifier.vfolder import VFolderUUID
Expand Down Expand Up @@ -104,7 +105,13 @@

from .db_source import DeploymentDBSource
from .storage_source import DeploymentStorageSource
from .types import ProjectDeploymentSearchScope, RouteData, RouteServiceDiscoveryInfo
from .types import (
ProjectDeploymentSearchScope,
RouteData,
RouteServiceDiscoveryInfo,
RouteSessionInfo,
RouteSessionKernelInfo,
)

log = BraceStyleAdapter(logging.getLogger(__name__))

Expand Down Expand Up @@ -758,7 +765,7 @@ async def fetch_kernel_connection_info(
@deployment_repository_resilience.apply()
async def update_route_replica_info(
self,
updates: dict[uuid.UUID, tuple[str, int]],
updates: dict[ReplicaID, RouteSessionKernelInfo],
) -> None:
"""Update replica_host and replica_port for routes."""
await self._db_source.update_route_replica_info(updates)
Expand Down Expand Up @@ -1105,11 +1112,27 @@ async def calculate_desired_replicas_for_deployment(
@deployment_repository_resilience.apply()
async def fetch_session_statuses_by_route_ids(
self,
route_ids: set[uuid.UUID],
) -> Mapping[uuid.UUID, SessionStatus | None]:
route_ids: set[ReplicaID],
) -> Mapping[ReplicaID, SessionStatus | None]:
"""Fetch session IDs for multiple routes."""
return await self._db_source.fetch_session_statuses_by_route_ids(route_ids)

@deployment_repository_resilience.apply()
async def fetch_route_session_kernel_infos(
    self,
    route_ids: set[ReplicaID],
) -> Mapping[ReplicaID, RouteSessionInfo | None]:
    """Fetch session status and kernel connection info for multiple routes.

    Thin delegator to the DB source, wrapped with repository resilience.

    Returns:
        Mapping of route_id to RouteSessionInfo:
            - None → route has no session linked
            - RouteSessionInfo(status=TERMINAL, kernel=None) → session terminated
            - RouteSessionInfo(status=RUNNING, kernel=RouteSessionKernelInfo(host, port)) → ready
            - RouteSessionInfo(status=PREPARING, kernel=None) → not yet running
    """
    infos = await self._db_source.fetch_route_session_kernel_infos(route_ids)
    return infos

@deployment_repository_resilience.apply()
async def fetch_route_connection_infos(
self,
Expand Down Expand Up @@ -1150,7 +1173,7 @@ async def get_endpoint_id_by_session(
@deployment_repository_resilience.apply()
async def fetch_route_service_discovery_info(
self,
route_ids: set[uuid.UUID],
route_ids: set[ReplicaID],
) -> list[RouteServiceDiscoveryInfo]:
"""Fetch service discovery information for routes.

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
ProjectDeploymentSearchScope,
RouteData,
RouteServiceDiscoveryInfo,
RouteSessionInfo,
RouteSessionKernelInfo,
)

__all__ = [
Expand All @@ -16,4 +18,6 @@
"ProjectDeploymentSearchScope",
"RouteData",
"RouteServiceDiscoveryInfo",
"RouteSessionInfo",
"RouteSessionKernelInfo",
]
Loading
Loading