Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 18 additions & 2 deletions src/graphon/model_runtime/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Model Runtime

This module provides the interface for invoking and authenticating various models, and offers Dify a unified information and credentials form rule for model providers.
This module provides the interfaces for invoking and authenticating various
models, and offers Dify a unified information and credentials form rule for
model providers.

- On one hand, it decouples models from upstream and downstream processes, facilitating horizontal expansion for developers;
- On the other hand, it allows for direct display of providers and models in the frontend interface by simply defining them in the backend, eliminating the need to modify frontend logic.
Expand Down Expand Up @@ -32,7 +34,21 @@ This module provides the interface for invoking and authenticating various model

## Structure

Model Runtime is divided into three layers:
Model Runtime is divided into protocol and implementation layers:

- Provider/runtime protocols

Shared provider concerns live in `protocols/provider_runtime.py`, while each
model capability has its own protocol module such as
`protocols/llm_runtime.py`, `protocols/text_embedding_runtime.py`, and
`protocols/tts_runtime.py`. Downstream runtimes can implement only the
capabilities they need instead of satisfying a single monolithic interface.

- Aggregate runtime protocol

`protocols/runtime.py` composes the individual capability protocols into
`ModelRuntime` for adapters that intentionally implement the full surface
area.

- The outermost layer is the factory method

Expand Down
21 changes: 21 additions & 0 deletions src/graphon/model_runtime/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Package root for ``graphon.model_runtime``.

Re-exports the runtime protocols so callers can depend on the package
root instead of reaching into ``graphon.model_runtime.protocols``.
"""

from graphon.model_runtime.protocols import (
    LLMModelRuntime,
    ModelProviderRuntime,
    ModelRuntime,
    ModerationModelRuntime,
    RerankModelRuntime,
    SpeechToTextModelRuntime,
    TextEmbeddingModelRuntime,
    TTSModelRuntime,
)

# Public API. Entries are ASCII-sorted (uppercase sorts before lowercase,
# which is why "TTSModelRuntime" precedes "TextEmbeddingModelRuntime").
__all__ = [
    "LLMModelRuntime",
    "ModelProviderRuntime",
    "ModelRuntime",
    "ModerationModelRuntime",
    "RerankModelRuntime",
    "SpeechToTextModelRuntime",
    "TTSModelRuntime",
    "TextEmbeddingModelRuntime",
]
8 changes: 4 additions & 4 deletions src/graphon/model_runtime/model_providers/base/ai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from graphon.model_runtime.runtime import ModelRuntime
from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime


class AIModel:
class AIModel[RuntimeT: ModelProviderRuntime]:
"""Runtime-facing base class for all model providers.

This stays a regular Python class because instances hold live collaborators
Expand All @@ -34,13 +34,13 @@ class attribute; the base class is not meant to be instantiated directly.

model_type: ModelType
provider_schema: ProviderEntity
model_runtime: ModelRuntime
model_runtime: RuntimeT
started_at: float

def __init__(
self,
provider_schema: ProviderEntity,
model_runtime: ModelRuntime,
model_runtime: RuntimeT,
*,
started_at: float = 0,
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
PriceType,
)
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.protocols.llm_runtime import LLMModelRuntime

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -227,7 +228,7 @@ def _invoke_llm_via_runtime(
)


class LargeLanguageModel(AIModel):
class LargeLanguageModel(AIModel[LLMModelRuntime]):
"""Model class for large language model."""

model_type: ModelType = ModelType.LLM
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.protocols.moderation_runtime import (
ModerationModelRuntime,
)


class ModerationModel(AIModel):
class ModerationModel(AIModel[ModerationModelRuntime]):
"""Model class for moderation model."""

model_type: ModelType = ModelType.MODERATION
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
RerankResult,
)
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.protocols.rerank_runtime import RerankModelRuntime


class RerankModel(AIModel):
class RerankModel(AIModel[RerankModelRuntime]):
"""Base Model class for rerank model."""

model_type: ModelType = ModelType.RERANK
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.protocols.speech_to_text_runtime import (
SpeechToTextModelRuntime,
)


class Speech2TextModel(AIModel):
class Speech2TextModel(AIModel[SpeechToTextModelRuntime]):
"""Model class for speech2text model."""

model_type: ModelType = ModelType.SPEECH2TEXT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
EmbeddingResult,
)
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.protocols.text_embedding_runtime import (
TextEmbeddingModelRuntime,
)


class TextEmbeddingModel(AIModel):
class TextEmbeddingModel(AIModel[TextEmbeddingModelRuntime]):
"""Model class for text embedding model."""

model_type: ModelType = ModelType.TEXT_EMBEDDING
Expand Down
3 changes: 2 additions & 1 deletion src/graphon/model_runtime/model_providers/base/tts_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.protocols.tts_runtime import TTSModelRuntime

logger = logging.getLogger(__name__)


class TTSModel(AIModel):
class TTSModel(AIModel[TTSModelRuntime]):
"""Model class for TTS model."""

model_type: ModelType = ModelType.TTS
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
TextEmbeddingModel,
)
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from graphon.model_runtime.runtime import ModelRuntime
from graphon.model_runtime.protocols.runtime import ModelRuntime
from graphon.model_runtime.schema_validators.model_credential_schema_validator import (
ModelCredentialSchemaValidator,
)
Expand Down
25 changes: 25 additions & 0 deletions src/graphon/model_runtime/protocols/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Aggregates the per-capability runtime protocols into one namespace.

Each capability (LLM, moderation, rerank, speech-to-text, text embedding,
TTS) has its own protocol module; ``runtime.ModelRuntime`` composes them
for adapters that implement the full surface area.
"""

from graphon.model_runtime.protocols.llm_runtime import LLMModelRuntime
from graphon.model_runtime.protocols.moderation_runtime import (
    ModerationModelRuntime,
)
from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime
from graphon.model_runtime.protocols.rerank_runtime import RerankModelRuntime
from graphon.model_runtime.protocols.runtime import ModelRuntime
from graphon.model_runtime.protocols.speech_to_text_runtime import (
    SpeechToTextModelRuntime,
)
from graphon.model_runtime.protocols.text_embedding_runtime import (
    TextEmbeddingModelRuntime,
)
from graphon.model_runtime.protocols.tts_runtime import TTSModelRuntime

# Public API. Entries are ASCII-sorted (uppercase sorts before lowercase,
# which is why "TTSModelRuntime" precedes "TextEmbeddingModelRuntime").
__all__ = [
    "LLMModelRuntime",
    "ModelProviderRuntime",
    "ModelRuntime",
    "ModerationModelRuntime",
    "RerankModelRuntime",
    "SpeechToTextModelRuntime",
    "TTSModelRuntime",
    "TextEmbeddingModelRuntime",
]
118 changes: 118 additions & 0 deletions src/graphon/model_runtime/protocols/llm_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import annotations

from collections.abc import Generator, Sequence
from typing import Any, Literal, Protocol, overload, runtime_checkable

from graphon.model_runtime.entities.llm_entities import (
LLMResult,
LLMResultChunk,
LLMResultChunkWithStructuredOutput,
LLMResultWithStructuredOutput,
)
from graphon.model_runtime.entities.message_entities import (
PromptMessage,
PromptMessageTool,
)
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime


@runtime_checkable
class LLMModelRuntime(ModelProviderRuntime, Protocol):
    """Runtime surface required by LLM-backed model wrappers.

    Extends the shared :class:`ModelProviderRuntime` concerns with LLM
    invocation (plain and structured-output) and token counting.  For each
    invoke method the ``stream`` flag selects the return shape, and the
    ``@overload`` pairs record that mapping for static type checkers:
    ``stream=False`` returns a single result object, ``stream=True``
    returns a generator of result chunks.  All parameters are keyword-only.
    """

    # Overload: stream=False -> one complete LLMResult.
    @overload
    def invoke_llm(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        model_parameters: dict[str, Any],
        prompt_messages: Sequence[PromptMessage],
        tools: list[PromptMessageTool] | None,
        stop: Sequence[str] | None,
        stream: Literal[False],
    ) -> LLMResult: ...

    # Overload: stream=True -> generator of incremental LLMResultChunk.
    @overload
    def invoke_llm(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        model_parameters: dict[str, Any],
        prompt_messages: Sequence[PromptMessage],
        tools: list[PromptMessageTool] | None,
        stop: Sequence[str] | None,
        stream: Literal[True],
    ) -> Generator[LLMResultChunk, None, None]: ...

    def invoke_llm(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        model_parameters: dict[str, Any],
        prompt_messages: Sequence[PromptMessage],
        tools: list[PromptMessageTool] | None,
        stop: Sequence[str] | None,
        stream: bool,
    ) -> LLMResult | Generator[LLMResultChunk, None, None]:
        """Invoke an LLM for the given provider/model.

        Args:
            provider: Provider identifier.
            model: Model identifier within the provider.
            credentials: Provider/model credentials.
            model_parameters: Model-specific invocation parameters.
            prompt_messages: Conversation/prompt messages to send.
            tools: Optional tool definitions exposed to the model.
            stop: Optional stop sequences — presumably strings that end
                generation; confirm against implementations.
            stream: ``True`` streams chunks, ``False`` returns one result.

        NOTE(review): ``tools`` is typed ``list[...] | None`` here but
        ``Sequence[...] | None`` in :meth:`get_llm_num_tokens` — consider
        unifying on ``Sequence`` in a follow-up (would widen what
        implementations must accept, so it needs coordination).
        """
        ...

    # Overload: stream=False -> one complete structured-output result.
    @overload
    def invoke_llm_with_structured_output(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        json_schema: dict[str, Any],
        model_parameters: dict[str, Any],
        prompt_messages: Sequence[PromptMessage],
        stop: Sequence[str] | None,
        stream: Literal[False],
    ) -> LLMResultWithStructuredOutput: ...

    # Overload: stream=True -> generator of structured-output chunks.
    @overload
    def invoke_llm_with_structured_output(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        json_schema: dict[str, Any],
        model_parameters: dict[str, Any],
        prompt_messages: Sequence[PromptMessage],
        stop: Sequence[str] | None,
        stream: Literal[True],
    ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...

    def invoke_llm_with_structured_output(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        json_schema: dict[str, Any],
        model_parameters: dict[str, Any],
        prompt_messages: Sequence[PromptMessage],
        stop: Sequence[str] | None,
        stream: bool,
    ) -> (
        LLMResultWithStructuredOutput
        | Generator[LLMResultChunkWithStructuredOutput, None, None]
    ):
        """Invoke an LLM constrained to produce structured output.

        Same contract as :meth:`invoke_llm`, minus ``tools``, plus
        ``json_schema``: a dict describing the expected output structure —
        presumably a JSON Schema document; confirm against implementations.
        """
        ...

    def get_llm_num_tokens(
        self,
        *,
        provider: str,
        model_type: ModelType,
        model: str,
        credentials: dict[str, Any],
        prompt_messages: Sequence[PromptMessage],
        tools: Sequence[PromptMessageTool] | None,
    ) -> int:
        """Return the token count for *prompt_messages* (plus *tools*).

        Unlike the invoke methods, this also takes ``model_type``
        explicitly — presumably because token counting is routed through
        a shared code path; confirm against implementations.
        """
        ...
19 changes: 19 additions & 0 deletions src/graphon/model_runtime/protocols/moderation_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from __future__ import annotations

from typing import Any, Protocol, runtime_checkable

from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime


@runtime_checkable
class ModerationModelRuntime(ModelProviderRuntime, Protocol):
    """Runtime surface required by moderation model wrappers.

    Extends the shared :class:`ModelProviderRuntime` concerns with a
    single moderation call.  All parameters are keyword-only.
    """

    def invoke_moderation(
        self,
        *,
        provider: str,
        model: str,
        credentials: dict[str, Any],
        text: str,
    ) -> bool:
        """Run *text* through the given provider's moderation model.

        Returns a boolean verdict.  NOTE(review): whether ``True`` means
        "flagged" or "safe" is not visible from this protocol — confirm
        against implementations before relying on the polarity.
        """
        ...
47 changes: 47 additions & 0 deletions src/graphon/model_runtime/protocols/provider_runtime.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

from collections.abc import Sequence
from typing import Any, Protocol, runtime_checkable

from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
from graphon.model_runtime.entities.provider_entities import ProviderEntity


@runtime_checkable
class ModelProviderRuntime(Protocol):
    """Shared provider discovery, credential validation, and schema lookup.

    Base protocol that each capability-specific runtime protocol (LLM,
    moderation, rerank, etc.) extends, so downstream runtimes implement
    only the capabilities they need.  All parameters are keyword-only.
    """

    def fetch_model_providers(self) -> Sequence[ProviderEntity]:
        """Return the provider entities known to this runtime."""
        ...

    def get_provider_icon(
        self,
        *,
        provider: str,
        icon_type: str,
        lang: str,
    ) -> tuple[bytes, str]:
        """Return a provider icon for *icon_type* and *lang*.

        Returns the raw icon bytes plus a string — presumably the MIME
        type; confirm against implementations.
        """
        ...

    def validate_provider_credentials(
        self,
        *,
        provider: str,
        credentials: dict[str, Any],
    ) -> None:
        """Validate provider-level credentials; raises on failure
        (exception type not visible from this protocol)."""
        ...

    def validate_model_credentials(
        self,
        *,
        provider: str,
        model_type: ModelType,
        model: str,
        credentials: dict[str, Any],
    ) -> None:
        """Validate credentials for one specific model; raises on failure
        (exception type not visible from this protocol)."""
        ...

    def get_model_schema(
        self,
        *,
        provider: str,
        model_type: ModelType,
        model: str,
        credentials: dict[str, Any],
    ) -> AIModelEntity | None:
        """Return the schema entity for a model, or ``None`` when the
        provider has no schema for it."""
        ...
Loading
Loading