Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 26 additions & 8 deletions src/graphon/model_runtime/README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Model Runtime

This module provides the interface for invoking and authenticating various models, and offers Dify a unified information and credentials form rule for model providers.
This module provides the interfaces for invoking and authenticating various
models, and offers Dify a unified information and credentials form rule for
model providers.

- On one hand, it decouples models from upstream and downstream processes, facilitating horizontal expansion for developers,
- On the other hand, it allows for direct display of providers and models in the frontend interface by simply defining them in the backend, eliminating the need to modify frontend logic.
Expand Down Expand Up @@ -32,19 +34,35 @@ This module provides the interface for invoking and authenticating various model

## Structure

Model Runtime is divided into three layers:
Model Runtime is divided into protocol and implementation layers:

- The outermost layer is the factory method
- Provider/runtime protocols

It provides methods for obtaining all providers, all model lists, getting provider instances, and authenticating provider/model credentials.
Shared provider concerns live in `protocols/provider_runtime.py`, while each
model capability has its own protocol module such as
`protocols/llm_runtime.py`, `protocols/text_embedding_runtime.py`, and
`protocols/tts_runtime.py`. Downstream runtimes can implement only the
capabilities they need instead of satisfying a single monolithic interface.

- The second layer is the provider layer
- Aggregate runtime protocol

It provides the current provider's model list, model instance obtaining, provider credential authentication, and provider configuration rule information, **allowing horizontal expansion** to support different providers.
`protocols/runtime.py` composes the individual capability protocols into
`ModelRuntime` for adapters that intentionally implement the full surface
area.

- The bottom layer is the model layer
- Provider factory

It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types).
`model_providers/model_provider_factory.py` now depends only on
`ModelProviderRuntime`. It handles provider discovery, provider/model schema
lookup, credential validation, provider icon lookup, and provider-level model
list projection without assuming any invocation capability.

- Model wrappers

Capability wrappers such as `LargeLanguageModel`, `TextEmbeddingModel`,
`RerankModel`, `Speech2TextModel`, `ModerationModel`, and `TTSModel` depend
only on their matching capability protocol. Instantiate those wrappers
directly when you need invocation behavior.

## Documentation

Expand Down
64 changes: 0 additions & 64 deletions src/graphon/model_runtime/README_CN.md

This file was deleted.

21 changes: 21 additions & 0 deletions src/graphon/model_runtime/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
"""Public package surface for ``graphon.model_runtime``.

Re-exports the capability runtime protocols (LLM, moderation, rerank,
speech-to-text, text embedding, TTS), the provider-level protocol
``ModelProviderRuntime``, and the aggregate ``ModelRuntime`` so callers can
import them directly from the package root instead of the ``protocols``
subpackage.
"""

from graphon.model_runtime.protocols import (
    LLMModelRuntime,
    ModelProviderRuntime,
    ModelRuntime,
    ModerationModelRuntime,
    RerankModelRuntime,
    SpeechToTextModelRuntime,
    TextEmbeddingModelRuntime,
    TTSModelRuntime,
)

# Explicit public API. Names are in case-sensitive (ASCII) sort order, which
# is why "TTSModelRuntime" precedes "TextEmbeddingModelRuntime".
__all__ = [
    "LLMModelRuntime",
    "ModelProviderRuntime",
    "ModelRuntime",
    "ModerationModelRuntime",
    "RerankModelRuntime",
    "SpeechToTextModelRuntime",
    "TTSModelRuntime",
    "TextEmbeddingModelRuntime",
]
8 changes: 4 additions & 4 deletions src/graphon/model_runtime/model_providers/base/ai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@
InvokeRateLimitError,
InvokeServerUnavailableError,
)
from graphon.model_runtime.runtime import ModelRuntime
from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime


class AIModel:
class AIModel[RuntimeT: ModelProviderRuntime]:
"""Runtime-facing base class for all model providers.

This stays a regular Python class because instances hold live collaborators
Expand All @@ -34,13 +34,13 @@ class attribute; the base class is not meant to be instantiated directly.

model_type: ModelType
provider_schema: ProviderEntity
model_runtime: ModelRuntime
model_runtime: RuntimeT
started_at: float

def __init__(
self,
provider_schema: ProviderEntity,
model_runtime: ModelRuntime,
model_runtime: RuntimeT,
*,
started_at: float = 0,
) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
PriceType,
)
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.protocols.llm_runtime import LLMModelRuntime

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -227,7 +228,7 @@ def _invoke_llm_via_runtime(
)


class LargeLanguageModel(AIModel):
class LargeLanguageModel(AIModel[LLMModelRuntime]):
"""Model class for large language model."""

model_type: ModelType = ModelType.LLM
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.protocols.moderation_runtime import (
ModerationModelRuntime,
)


class ModerationModel(AIModel):
class ModerationModel(AIModel[ModerationModelRuntime]):
"""Model class for moderation model."""

model_type: ModelType = ModelType.MODERATION
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
RerankResult,
)
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.protocols.rerank_runtime import RerankModelRuntime


class RerankModel(AIModel):
class RerankModel(AIModel[RerankModelRuntime]):
"""Base Model class for rerank model."""

model_type: ModelType = ModelType.RERANK
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@

from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.protocols.speech_to_text_runtime import (
SpeechToTextModelRuntime,
)


class Speech2TextModel(AIModel):
class Speech2TextModel(AIModel[SpeechToTextModelRuntime]):
"""Model class for speech2text model."""

model_type: ModelType = ModelType.SPEECH2TEXT
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
EmbeddingResult,
)
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.protocols.text_embedding_runtime import (
TextEmbeddingModelRuntime,
)


class TextEmbeddingModel(AIModel):
class TextEmbeddingModel(AIModel[TextEmbeddingModelRuntime]):
"""Model class for text embedding model."""

model_type: ModelType = ModelType.TEXT_EMBEDDING
Expand Down
3 changes: 2 additions & 1 deletion src/graphon/model_runtime/model_providers/base/tts_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@

from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.protocols.tts_runtime import TTSModelRuntime

logger = logging.getLogger(__name__)


class TTSModel(AIModel):
class TTSModel(AIModel[TTSModelRuntime]):
"""Model class for TTS model."""

model_type: ModelType = ModelType.TTS
Expand Down
60 changes: 11 additions & 49 deletions src/graphon/model_runtime/model_providers/model_provider_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,7 @@
ProviderEntity,
SimpleProviderEntity,
)
from graphon.model_runtime.model_providers.base.ai_model import AIModel
from graphon.model_runtime.model_providers.base.large_language_model import (
LargeLanguageModel,
)
from graphon.model_runtime.model_providers.base.moderation_model import (
ModerationModel,
)
from graphon.model_runtime.model_providers.base.rerank_model import RerankModel
from graphon.model_runtime.model_providers.base.speech2text_model import (
Speech2TextModel,
)
from graphon.model_runtime.model_providers.base.text_embedding_model import (
TextEmbeddingModel,
)
from graphon.model_runtime.model_providers.base.tts_model import TTSModel
from graphon.model_runtime.runtime import ModelRuntime
from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime
from graphon.model_runtime.schema_validators.model_credential_schema_validator import (
ModelCredentialSchemaValidator,
)
Expand All @@ -31,34 +16,23 @@
ProviderCredentialSchemaValidator,
)

_MODEL_CLASS_BY_TYPE: dict[ModelType, type[AIModel]] = {
ModelType.LLM: LargeLanguageModel,
ModelType.TEXT_EMBEDDING: TextEmbeddingModel,
ModelType.RERANK: RerankModel,
ModelType.SPEECH2TEXT: Speech2TextModel,
ModelType.MODERATION: ModerationModel,
ModelType.TTS: TTSModel,
}


class ModelProviderFactory:
"""Factory for provider schemas and model-type instances
backed by a runtime adapter.
"""
"""Factory for provider schemas and credential flows backed by a runtime."""

def __init__(self, model_runtime: ModelRuntime) -> None:
if model_runtime is None:
msg = "model_runtime is required."
def __init__(self, runtime: ModelProviderRuntime) -> None:
if runtime is None:
msg = "runtime is required."
raise ValueError(msg)
self.model_runtime = model_runtime
self.runtime = runtime

def get_providers(self) -> Sequence[ProviderEntity]:
"""Get all providers."""
return list(self.get_model_providers())

def get_model_providers(self) -> Sequence[ProviderEntity]:
"""Get all model providers exposed by the runtime adapter."""
return self.model_runtime.fetch_model_providers()
return self.runtime.fetch_model_providers()

def get_provider_schema(self, provider: str) -> ProviderEntity:
"""Get provider schema."""
Expand Down Expand Up @@ -90,7 +64,7 @@ def provider_credentials_validate(
validator = ProviderCredentialSchemaValidator(provider_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials)

self.model_runtime.validate_provider_credentials(
self.runtime.validate_provider_credentials(
provider=provider_entity.provider,
credentials=filtered_credentials,
)
Expand All @@ -116,7 +90,7 @@ def model_credentials_validate(
validator = ModelCredentialSchemaValidator(model_type, model_credential_schema)
filtered_credentials = validator.validate_and_filter(credentials)

self.model_runtime.validate_model_credentials(
self.runtime.validate_model_credentials(
provider=provider_entity.provider,
model_type=model_type,
model=model,
Expand All @@ -135,7 +109,7 @@ def get_model_schema(
) -> AIModelEntity | None:
"""Get model schema."""
provider_entity = self.get_model_provider(provider)
return self.model_runtime.get_model_schema(
return self.runtime.get_model_schema(
provider=provider_entity.provider,
model_type=model_type,
model=model,
Expand Down Expand Up @@ -168,18 +142,6 @@ def get_models(

return providers

def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel:
"""Get model type instance by provider name and model type."""
provider_schema = self.get_model_provider(provider)
model_class = _MODEL_CLASS_BY_TYPE.get(model_type)
if model_class is None:
msg = f"Unsupported model type: {model_type}"
raise ValueError(msg)
return model_class(
provider_schema=provider_schema,
model_runtime=self.model_runtime,
)

def get_provider_icon(
self,
provider: str,
Expand All @@ -188,7 +150,7 @@ def get_provider_icon(
) -> tuple[bytes, str]:
"""Get provider icon."""
provider_entity = self.get_model_provider(provider)
return self.model_runtime.get_provider_icon(
return self.runtime.get_provider_icon(
provider=provider_entity.provider,
icon_type=icon_type,
lang=lang,
Expand Down
25 changes: 25 additions & 0 deletions src/graphon/model_runtime/protocols/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Protocol re-exports for ``graphon.model_runtime.protocols``.

Each model capability has its own protocol module (LLM, moderation, rerank,
speech-to-text, text embedding, TTS); this package root gathers them —
together with the provider-level ``ModelProviderRuntime`` and the aggregate
``ModelRuntime`` — into a single import location.
"""

from graphon.model_runtime.protocols.llm_runtime import LLMModelRuntime
from graphon.model_runtime.protocols.moderation_runtime import (
    ModerationModelRuntime,
)
from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime
from graphon.model_runtime.protocols.rerank_runtime import RerankModelRuntime
from graphon.model_runtime.protocols.runtime import ModelRuntime
from graphon.model_runtime.protocols.speech_to_text_runtime import (
    SpeechToTextModelRuntime,
)
from graphon.model_runtime.protocols.text_embedding_runtime import (
    TextEmbeddingModelRuntime,
)
from graphon.model_runtime.protocols.tts_runtime import TTSModelRuntime

# Explicit public API. Names are in case-sensitive (ASCII) sort order, which
# is why "TTSModelRuntime" precedes "TextEmbeddingModelRuntime".
__all__ = [
    "LLMModelRuntime",
    "ModelProviderRuntime",
    "ModelRuntime",
    "ModerationModelRuntime",
    "RerankModelRuntime",
    "SpeechToTextModelRuntime",
    "TTSModelRuntime",
    "TextEmbeddingModelRuntime",
]
Loading
Loading