diff --git a/src/graphon/model_runtime/README.md b/src/graphon/model_runtime/README.md
index b9d2c55..7af5ea7 100644
--- a/src/graphon/model_runtime/README.md
+++ b/src/graphon/model_runtime/README.md
@@ -1,6 +1,8 @@
 # Model Runtime
 
-This module provides the interface for invoking and authenticating various models, and offers Dify a unified information and credentials form rule for model providers.
+This module provides the interfaces for invoking and authenticating various
+models, and gives Dify a unified set of rules for provider information and
+credential forms.
 
 - On one hand, it decouples models from upstream and downstream processes, facilitating horizontal expansion for developers,
 - On the other hand, it allows for direct display of providers and models in the frontend interface by simply defining them in the backend, eliminating the need to modify frontend logic.
@@ -32,19 +34,35 @@ This module provides the interface for invoking and authenticating various model
 
 ## Structure
 
-Model Runtime is divided into three layers:
+Model Runtime is divided into protocol and implementation layers:
 
-- The outermost layer is the factory method
+- Provider/runtime protocols
 
-  It provides methods for obtaining all providers, all model lists, getting provider instances, and authenticating provider/model credentials.
+  Shared provider concerns live in `protocols/provider_runtime.py`, while each
+  model capability has its own protocol module such as
+  `protocols/llm_runtime.py`, `protocols/text_embedding_runtime.py`, and
+  `protocols/tts_runtime.py`. Downstream runtimes can implement only the
+  capabilities they need instead of satisfying a single monolithic interface.
 
-- The second layer is the provider layer
+- Aggregate runtime protocol
 
-  It provides the current provider's model list, model instance obtaining, provider credential authentication, and provider configuration rule information, **allowing horizontal expansion** to support different providers.
+  `protocols/runtime.py` composes the individual capability protocols into
+  `ModelRuntime` for adapters that intentionally implement the full surface
+  area.
 
-- The bottom layer is the model layer
+- Provider factory
 
-  It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types).
+  `model_providers/model_provider_factory.py` now depends only on
+  `ModelProviderRuntime`. It handles provider discovery, provider/model schema
+  lookup, credential validation, provider icon lookup, and provider-level model
+  list projection without assuming any invocation capability.
+
+- Model wrappers
+
+  Capability wrappers such as `LargeLanguageModel`, `TextEmbeddingModel`,
+  `RerankModel`, `Speech2TextModel`, `ModerationModel`, and `TTSModel` depend
+  only on their matching capability protocol. Instantiate those wrappers
+  directly when you need invocation behavior.
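The "Model wrappers" bullet above says a wrapper can be constructed directly against any object that structurally satisfies its capability protocol. A minimal sketch of that wiring, using a hypothetical keyword-based moderation adapter (the `KeywordModerationRuntime` class and its trivial method bodies are illustrative, not part of this change; the `ProviderEntity` and `ModerationModel` constructor shapes match the tests in this diff):

```python
from collections.abc import Sequence
from typing import Any

from graphon.model_runtime.entities.common_entities import I18nObject
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
from graphon.model_runtime.entities.provider_entities import ProviderEntity
from graphon.model_runtime.model_providers.base.moderation_model import ModerationModel

PROVIDER = ProviderEntity(
    provider="keyword-provider",
    label=I18nObject(en_US="Keyword Provider"),
    supported_model_types=[ModelType.MODERATION],
    configurate_methods=[],
)


class KeywordModerationRuntime:
    """Hypothetical adapter: satisfies ModerationModelRuntime structurally."""

    def fetch_model_providers(self) -> Sequence[ProviderEntity]:
        return (PROVIDER,)

    def get_provider_icon(
        self, *, provider: str, icon_type: str, lang: str
    ) -> tuple[bytes, str]:
        return b"", "image/svg+xml"

    def validate_provider_credentials(
        self, *, provider: str, credentials: dict[str, Any]
    ) -> None:
        return None

    def validate_model_credentials(
        self, *, provider: str, model_type: ModelType, model: str,
        credentials: dict[str, Any],
    ) -> None:
        return None

    def get_model_schema(
        self, *, provider: str, model_type: ModelType, model: str,
        credentials: dict[str, Any],
    ) -> AIModelEntity | None:
        return None

    def invoke_moderation(
        self, *, provider: str, model: str, credentials: dict[str, Any], text: str
    ) -> bool:
        # Illustrative policy: flag any text containing a blocked keyword.
        return "blocked" in text.lower()


moderation = ModerationModel(
    provider_schema=PROVIDER,
    model_runtime=KeywordModerationRuntime(),
)
flagged = moderation.invoke(model="keyword-filter", credentials={}, text="this is blocked")
assert flagged is True
```

No inheritance is involved: the adapter never imports the protocol, it only implements the required method set, which is the point of the structural split.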
## Documentation diff --git a/src/graphon/model_runtime/README_CN.md b/src/graphon/model_runtime/README_CN.md deleted file mode 100644 index 0a8b56b..0000000 --- a/src/graphon/model_runtime/README_CN.md +++ /dev/null @@ -1,64 +0,0 @@ -# Model Runtime - -该模块提供了各模型的调用、鉴权接口,并为 Dify 提供了统一的模型供应商的信息和凭据表单规则。 - -- 一方面将模型和上下游解耦,方便开发者对模型横向扩展, -- 另一方面提供了只需在后端定义供应商和模型,即可在前端页面直接展示,无需修改前端逻辑。 - -## 功能介绍 - -- 支持 6 种模型类型的能力调用 - - - `LLM` - LLM 文本补全、对话,预计算 tokens 能力 - - `Text Embedding Model` - 文本 Embedding,预计算 tokens 能力 - - `Rerank Model` - 分段 Rerank 能力 - - `Speech-to-text Model` - 语音转文本能力 - - `Text-to-speech Model` - 文本转语音能力 - - `Moderation` - Moderation 能力 - -- 模型供应商展示 - - 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。 - -- 可选择的模型列表展示 - - 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。 - - 除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。 - -- 供应商/模型凭据鉴权 - - 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。 - -## 结构 - -Model Runtime 分三层: - -- 最外层为工厂方法 - - 提供获取所有供应商、所有模型列表、获取供应商实例、供应商/模型凭据鉴权方法。 - -- 第二层为供应商层 - - 提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。 - - 对于供应商/模型凭据,有两种情况 - - - 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据 - - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 - - 当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。 - -- 最底层为模型层 - - 提供各种模型类型的直接调用、预定义模型配置信息、获取预定义/远程模型列表、模型凭据鉴权方法,不同模型额外提供了特殊方法,如 LLM 提供预计算 tokens 方法、获取费用信息方法等,**可横向扩展**同供应商下不同的模型(支持的模型类型下)。 - - 在这里我们需要先区分模型参数与模型凭据。 - - - 模型参数 (**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等,这些参数是由用户在前端页面上进行调整的,因此需要在后端定义参数的规则,以便前端页面进行展示和调整。在 DifyRuntime 中,他们的参数名一般为**model_parameters: dict[str, any]**。 - - - 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**,Provider 层的 credentials 会直接被传递到这一层,不需要再单独定义。 - -## 文档 - -有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。 diff --git a/src/graphon/model_runtime/__init__.py b/src/graphon/model_runtime/__init__.py index e69de29..8134b86 100644 --- a/src/graphon/model_runtime/__init__.py +++ b/src/graphon/model_runtime/__init__.py @@ -0,0 +1,21 @@ +from graphon.model_runtime.protocols import ( + LLMModelRuntime, + ModelProviderRuntime, + ModelRuntime, + ModerationModelRuntime, + RerankModelRuntime, + SpeechToTextModelRuntime, + TextEmbeddingModelRuntime, + TTSModelRuntime, +) + +__all__ = [ + "LLMModelRuntime", + "ModelProviderRuntime", + "ModelRuntime", + "ModerationModelRuntime", + "RerankModelRuntime", + "SpeechToTextModelRuntime", + "TTSModelRuntime", + "TextEmbeddingModelRuntime", +] diff --git a/src/graphon/model_runtime/model_providers/base/ai_model.py b/src/graphon/model_runtime/model_providers/base/ai_model.py index fa082dd..81a55de 100644 --- a/src/graphon/model_runtime/model_providers/base/ai_model.py +++ b/src/graphon/model_runtime/model_providers/base/ai_model.py @@ -20,10 +20,10 @@ InvokeRateLimitError, InvokeServerUnavailableError, ) -from graphon.model_runtime.runtime import ModelRuntime +from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime -class AIModel: +class AIModel[RuntimeT: ModelProviderRuntime]: """Runtime-facing base class for all model providers. 
This stays a regular Python class because instances hold live collaborators @@ -34,13 +34,13 @@ class attribute; the base class is not meant to be instantiated directly. model_type: ModelType provider_schema: ProviderEntity - model_runtime: ModelRuntime + model_runtime: RuntimeT started_at: float def __init__( self, provider_schema: ProviderEntity, - model_runtime: ModelRuntime, + model_runtime: RuntimeT, *, started_at: float = 0, ) -> None: diff --git a/src/graphon/model_runtime/model_providers/base/large_language_model.py b/src/graphon/model_runtime/model_providers/base/large_language_model.py index 2ffd50b..a32692e 100644 --- a/src/graphon/model_runtime/model_providers/base/large_language_model.py +++ b/src/graphon/model_runtime/model_providers/base/large_language_model.py @@ -22,6 +22,7 @@ PriceType, ) from graphon.model_runtime.model_providers.base.ai_model import AIModel +from graphon.model_runtime.protocols.llm_runtime import LLMModelRuntime logger = logging.getLogger(__name__) @@ -227,7 +228,7 @@ def _invoke_llm_via_runtime( ) -class LargeLanguageModel(AIModel): +class LargeLanguageModel(AIModel[LLMModelRuntime]): """Model class for large language model.""" model_type: ModelType = ModelType.LLM diff --git a/src/graphon/model_runtime/model_providers/base/moderation_model.py b/src/graphon/model_runtime/model_providers/base/moderation_model.py index ad26c42..afde654 100644 --- a/src/graphon/model_runtime/model_providers/base/moderation_model.py +++ b/src/graphon/model_runtime/model_providers/base/moderation_model.py @@ -2,9 +2,12 @@ from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.model_providers.base.ai_model import AIModel +from graphon.model_runtime.protocols.moderation_runtime import ( + ModerationModelRuntime, +) -class ModerationModel(AIModel): +class ModerationModel(AIModel[ModerationModelRuntime]): """Model class for moderation model.""" model_type: ModelType = ModelType.MODERATION diff --git a/src/graphon/model_runtime/model_providers/base/rerank_model.py b/src/graphon/model_runtime/model_providers/base/rerank_model.py index a36b9e7..f33df42 100644 --- a/src/graphon/model_runtime/model_providers/base/rerank_model.py +++ b/src/graphon/model_runtime/model_providers/base/rerank_model.py @@ -4,9 +4,10 @@ RerankResult, ) from graphon.model_runtime.model_providers.base.ai_model import AIModel +from graphon.model_runtime.protocols.rerank_runtime import RerankModelRuntime -class RerankModel(AIModel): +class RerankModel(AIModel[RerankModelRuntime]): """Base Model class for rerank model.""" model_type: ModelType = ModelType.RERANK diff --git a/src/graphon/model_runtime/model_providers/base/speech2text_model.py b/src/graphon/model_runtime/model_providers/base/speech2text_model.py index 6da90eb..2459755 100644 --- a/src/graphon/model_runtime/model_providers/base/speech2text_model.py +++ b/src/graphon/model_runtime/model_providers/base/speech2text_model.py @@ -2,9 +2,12 @@ from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.model_providers.base.ai_model import AIModel +from graphon.model_runtime.protocols.speech_to_text_runtime import ( + SpeechToTextModelRuntime, +) -class Speech2TextModel(AIModel): +class Speech2TextModel(AIModel[SpeechToTextModelRuntime]): """Model class for speech2text model.""" model_type: ModelType = ModelType.SPEECH2TEXT diff --git a/src/graphon/model_runtime/model_providers/base/text_embedding_model.py b/src/graphon/model_runtime/model_providers/base/text_embedding_model.py 
index 9ee9b84..7f842c7 100644 --- a/src/graphon/model_runtime/model_providers/base/text_embedding_model.py +++ b/src/graphon/model_runtime/model_providers/base/text_embedding_model.py @@ -6,9 +6,12 @@ EmbeddingResult, ) from graphon.model_runtime.model_providers.base.ai_model import AIModel +from graphon.model_runtime.protocols.text_embedding_runtime import ( + TextEmbeddingModelRuntime, +) -class TextEmbeddingModel(AIModel): +class TextEmbeddingModel(AIModel[TextEmbeddingModelRuntime]): """Model class for text embedding model.""" model_type: ModelType = ModelType.TEXT_EMBEDDING diff --git a/src/graphon/model_runtime/model_providers/base/tts_model.py b/src/graphon/model_runtime/model_providers/base/tts_model.py index 09c4286..43ffad6 100644 --- a/src/graphon/model_runtime/model_providers/base/tts_model.py +++ b/src/graphon/model_runtime/model_providers/base/tts_model.py @@ -4,11 +4,12 @@ from graphon.model_runtime.entities.model_entities import ModelType from graphon.model_runtime.model_providers.base.ai_model import AIModel +from graphon.model_runtime.protocols.tts_runtime import TTSModelRuntime logger = logging.getLogger(__name__) -class TTSModel(AIModel): +class TTSModel(AIModel[TTSModelRuntime]): """Model class for TTS model.""" model_type: ModelType = ModelType.TTS diff --git a/src/graphon/model_runtime/model_providers/model_provider_factory.py b/src/graphon/model_runtime/model_providers/model_provider_factory.py index 69a73a9..61032ec 100644 --- a/src/graphon/model_runtime/model_providers/model_provider_factory.py +++ b/src/graphon/model_runtime/model_providers/model_provider_factory.py @@ -7,22 +7,7 @@ ProviderEntity, SimpleProviderEntity, ) -from graphon.model_runtime.model_providers.base.ai_model import AIModel -from graphon.model_runtime.model_providers.base.large_language_model import ( - LargeLanguageModel, -) -from graphon.model_runtime.model_providers.base.moderation_model import ( - ModerationModel, -) -from graphon.model_runtime.model_providers.base.rerank_model import RerankModel -from graphon.model_runtime.model_providers.base.speech2text_model import ( - Speech2TextModel, -) -from graphon.model_runtime.model_providers.base.text_embedding_model import ( - TextEmbeddingModel, -) -from graphon.model_runtime.model_providers.base.tts_model import TTSModel -from graphon.model_runtime.runtime import ModelRuntime +from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime from graphon.model_runtime.schema_validators.model_credential_schema_validator import ( ModelCredentialSchemaValidator, ) @@ -31,26 +16,15 @@ ProviderCredentialSchemaValidator, ) -_MODEL_CLASS_BY_TYPE: dict[ModelType, type[AIModel]] = { - ModelType.LLM: LargeLanguageModel, - ModelType.TEXT_EMBEDDING: TextEmbeddingModel, - ModelType.RERANK: RerankModel, - ModelType.SPEECH2TEXT: Speech2TextModel, - ModelType.MODERATION: ModerationModel, - ModelType.TTS: TTSModel, -} - class ModelProviderFactory: - """Factory for provider schemas and model-type instances - backed by a runtime adapter. - """ + """Factory for provider schemas and credential flows backed by a runtime.""" - def __init__(self, model_runtime: ModelRuntime) -> None: - if model_runtime is None: - msg = "model_runtime is required." + def __init__(self, runtime: ModelProviderRuntime) -> None: + if runtime is None: + msg = "runtime is required." 
raise ValueError(msg) - self.model_runtime = model_runtime + self.runtime = runtime def get_providers(self) -> Sequence[ProviderEntity]: """Get all providers.""" @@ -58,7 +32,7 @@ def get_providers(self) -> Sequence[ProviderEntity]: def get_model_providers(self) -> Sequence[ProviderEntity]: """Get all model providers exposed by the runtime adapter.""" - return self.model_runtime.fetch_model_providers() + return self.runtime.fetch_model_providers() def get_provider_schema(self, provider: str) -> ProviderEntity: """Get provider schema.""" @@ -90,7 +64,7 @@ def provider_credentials_validate( validator = ProviderCredentialSchemaValidator(provider_credential_schema) filtered_credentials = validator.validate_and_filter(credentials) - self.model_runtime.validate_provider_credentials( + self.runtime.validate_provider_credentials( provider=provider_entity.provider, credentials=filtered_credentials, ) @@ -116,7 +90,7 @@ def model_credentials_validate( validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) filtered_credentials = validator.validate_and_filter(credentials) - self.model_runtime.validate_model_credentials( + self.runtime.validate_model_credentials( provider=provider_entity.provider, model_type=model_type, model=model, @@ -135,7 +109,7 @@ def get_model_schema( ) -> AIModelEntity | None: """Get model schema.""" provider_entity = self.get_model_provider(provider) - return self.model_runtime.get_model_schema( + return self.runtime.get_model_schema( provider=provider_entity.provider, model_type=model_type, model=model, @@ -168,18 +142,6 @@ def get_models( return providers - def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel: - """Get model type instance by provider name and model type.""" - provider_schema = self.get_model_provider(provider) - model_class = _MODEL_CLASS_BY_TYPE.get(model_type) - if model_class is None: - msg = f"Unsupported model type: {model_type}" - raise ValueError(msg) - return model_class( - provider_schema=provider_schema, - model_runtime=self.model_runtime, - ) - def get_provider_icon( self, provider: str, @@ -188,7 +150,7 @@ def get_provider_icon( ) -> tuple[bytes, str]: """Get provider icon.""" provider_entity = self.get_model_provider(provider) - return self.model_runtime.get_provider_icon( + return self.runtime.get_provider_icon( provider=provider_entity.provider, icon_type=icon_type, lang=lang, diff --git a/src/graphon/model_runtime/protocols/__init__.py b/src/graphon/model_runtime/protocols/__init__.py new file mode 100644 index 0000000..42613e9 --- /dev/null +++ b/src/graphon/model_runtime/protocols/__init__.py @@ -0,0 +1,25 @@ +from graphon.model_runtime.protocols.llm_runtime import LLMModelRuntime +from graphon.model_runtime.protocols.moderation_runtime import ( + ModerationModelRuntime, +) +from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime +from graphon.model_runtime.protocols.rerank_runtime import RerankModelRuntime +from graphon.model_runtime.protocols.runtime import ModelRuntime +from graphon.model_runtime.protocols.speech_to_text_runtime import ( + SpeechToTextModelRuntime, +) +from graphon.model_runtime.protocols.text_embedding_runtime import ( + TextEmbeddingModelRuntime, +) +from graphon.model_runtime.protocols.tts_runtime import TTSModelRuntime + +__all__ = [ + "LLMModelRuntime", + "ModelProviderRuntime", + "ModelRuntime", + "ModerationModelRuntime", + "RerankModelRuntime", + "SpeechToTextModelRuntime", + "TTSModelRuntime", + "TextEmbeddingModelRuntime", +] diff 
--git a/src/graphon/model_runtime/protocols/llm_runtime.py b/src/graphon/model_runtime/protocols/llm_runtime.py new file mode 100644 index 0000000..08b52d1 --- /dev/null +++ b/src/graphon/model_runtime/protocols/llm_runtime.py @@ -0,0 +1,118 @@ +from __future__ import annotations + +from collections.abc import Generator, Sequence +from typing import Any, Literal, Protocol, overload, runtime_checkable + +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from graphon.model_runtime.entities.message_entities import ( + PromptMessage, + PromptMessageTool, +) +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime + + +@runtime_checkable +class LLMModelRuntime(ModelProviderRuntime, Protocol): + """Runtime surface required by LLM-backed model wrappers.""" + + @overload + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResult: ... + + @overload + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunk, None, None]: ... + + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResult | Generator[LLMResultChunk, None, None]: ... + + @overload + def invoke_llm_with_structured_output( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + json_schema: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + stop: Sequence[str] | None, + stream: Literal[False], + ) -> LLMResultWithStructuredOutput: ... + + @overload + def invoke_llm_with_structured_output( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + json_schema: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + stop: Sequence[str] | None, + stream: Literal[True], + ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... + + def invoke_llm_with_structured_output( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + json_schema: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + stop: Sequence[str] | None, + stream: bool, + ) -> ( + LLMResultWithStructuredOutput + | Generator[LLMResultChunkWithStructuredOutput, None, None] + ): ... + + def get_llm_num_tokens( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: Sequence[PromptMessageTool] | None, + ) -> int: ... 
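Because each capability protocol is decorated with `@runtime_checkable`, `isinstance` can double as a cheap capability probe. A small sketch (the helper name is illustrative); note the standard caveat that `runtime_checkable` protocols only verify that the method names exist, not their signatures or overload behavior:

```python
from graphon.model_runtime.protocols import LLMModelRuntime, ModelProviderRuntime


def require_llm_capability(runtime: ModelProviderRuntime) -> LLMModelRuntime:
    """Narrow a provider-only runtime to the LLM surface, or fail loudly."""
    if not isinstance(runtime, LLMModelRuntime):
        msg = f"{type(runtime).__name__} does not expose the LLM capability."
        raise TypeError(msg)
    return runtime
```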
diff --git a/src/graphon/model_runtime/protocols/moderation_runtime.py b/src/graphon/model_runtime/protocols/moderation_runtime.py new file mode 100644 index 0000000..8cf6ae9 --- /dev/null +++ b/src/graphon/model_runtime/protocols/moderation_runtime.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from typing import Any, Protocol, runtime_checkable + +from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime + + +@runtime_checkable +class ModerationModelRuntime(ModelProviderRuntime, Protocol): + """Runtime surface required by moderation model wrappers.""" + + def invoke_moderation( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + text: str, + ) -> bool: ... diff --git a/src/graphon/model_runtime/protocols/provider_runtime.py b/src/graphon/model_runtime/protocols/provider_runtime.py new file mode 100644 index 0000000..83c35fe --- /dev/null +++ b/src/graphon/model_runtime/protocols/provider_runtime.py @@ -0,0 +1,47 @@ +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any, Protocol, runtime_checkable + +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.provider_entities import ProviderEntity + + +@runtime_checkable +class ModelProviderRuntime(Protocol): + """Shared provider discovery, credential validation, and schema lookup.""" + + def fetch_model_providers(self) -> Sequence[ProviderEntity]: ... + + def get_provider_icon( + self, + *, + provider: str, + icon_type: str, + lang: str, + ) -> tuple[bytes, str]: ... + + def validate_provider_credentials( + self, + *, + provider: str, + credentials: dict[str, Any], + ) -> None: ... + + def validate_model_credentials( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> None: ... + + def get_model_schema( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> AIModelEntity | None: ... diff --git a/src/graphon/model_runtime/protocols/rerank_runtime.py b/src/graphon/model_runtime/protocols/rerank_runtime.py new file mode 100644 index 0000000..aa3814a --- /dev/null +++ b/src/graphon/model_runtime/protocols/rerank_runtime.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from typing import Any, Protocol, runtime_checkable + +from graphon.model_runtime.entities.rerank_entities import ( + MultimodalRerankInput, + RerankResult, +) +from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime + + +@runtime_checkable +class RerankModelRuntime(ModelProviderRuntime, Protocol): + """Runtime surface required by rerank model wrappers.""" + + def invoke_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: str, + docs: list[str], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: ... + + def invoke_multimodal_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: ... 
diff --git a/src/graphon/model_runtime/protocols/runtime.py b/src/graphon/model_runtime/protocols/runtime.py new file mode 100644 index 0000000..b59d4aa --- /dev/null +++ b/src/graphon/model_runtime/protocols/runtime.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +from typing import Protocol, runtime_checkable + +from graphon.model_runtime.protocols.llm_runtime import LLMModelRuntime +from graphon.model_runtime.protocols.moderation_runtime import ( + ModerationModelRuntime, +) +from graphon.model_runtime.protocols.rerank_runtime import RerankModelRuntime +from graphon.model_runtime.protocols.speech_to_text_runtime import ( + SpeechToTextModelRuntime, +) +from graphon.model_runtime.protocols.text_embedding_runtime import ( + TextEmbeddingModelRuntime, +) +from graphon.model_runtime.protocols.tts_runtime import TTSModelRuntime + + +@runtime_checkable +class ModelRuntime( + LLMModelRuntime, + TextEmbeddingModelRuntime, + RerankModelRuntime, + SpeechToTextModelRuntime, + ModerationModelRuntime, + TTSModelRuntime, + Protocol, +): + """Aggregate runtime for adapters that implement every model capability.""" diff --git a/src/graphon/model_runtime/protocols/speech_to_text_runtime.py b/src/graphon/model_runtime/protocols/speech_to_text_runtime.py new file mode 100644 index 0000000..8f59a62 --- /dev/null +++ b/src/graphon/model_runtime/protocols/speech_to_text_runtime.py @@ -0,0 +1,19 @@ +from __future__ import annotations + +from typing import IO, Any, Protocol, runtime_checkable + +from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime + + +@runtime_checkable +class SpeechToTextModelRuntime(ModelProviderRuntime, Protocol): + """Runtime surface required by speech-to-text model wrappers.""" + + def invoke_speech_to_text( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + file: IO[bytes], + ) -> str: ... diff --git a/src/graphon/model_runtime/protocols/text_embedding_runtime.py b/src/graphon/model_runtime/protocols/text_embedding_runtime.py new file mode 100644 index 0000000..4938ccd --- /dev/null +++ b/src/graphon/model_runtime/protocols/text_embedding_runtime.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from typing import Any, Protocol, runtime_checkable + +from graphon.model_runtime.entities.text_embedding_entities import ( + EmbeddingInputType, + EmbeddingResult, +) +from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime + + +@runtime_checkable +class TextEmbeddingModelRuntime(ModelProviderRuntime, Protocol): + """Runtime surface required by text and multimodal embedding wrappers.""" + + def invoke_text_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: ... + + def invoke_multimodal_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + documents: list[dict[str, Any]], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: ... + + def get_text_embedding_num_tokens( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + ) -> list[int]: ... 
diff --git a/src/graphon/model_runtime/protocols/tts_runtime.py b/src/graphon/model_runtime/protocols/tts_runtime.py new file mode 100644 index 0000000..2b9129a --- /dev/null +++ b/src/graphon/model_runtime/protocols/tts_runtime.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from collections.abc import Iterable +from typing import Any, Protocol, runtime_checkable + +from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime + + +@runtime_checkable +class TTSModelRuntime(ModelProviderRuntime, Protocol): + """Runtime surface required by text-to-speech model wrappers.""" + + def invoke_tts( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + content_text: str, + voice: str, + ) -> Iterable[bytes]: ... + + def get_tts_model_voices( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + language: str | None, + ) -> Any: ... diff --git a/src/graphon/model_runtime/runtime.py b/src/graphon/model_runtime/runtime.py deleted file mode 100644 index 8a7f483..0000000 --- a/src/graphon/model_runtime/runtime.py +++ /dev/null @@ -1,256 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator, Iterable, Sequence -from typing import IO, Any, Literal, Protocol, overload, runtime_checkable - -from graphon.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, -) -from graphon.model_runtime.entities.message_entities import ( - PromptMessage, - PromptMessageTool, -) -from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType -from graphon.model_runtime.entities.provider_entities import ProviderEntity -from graphon.model_runtime.entities.rerank_entities import ( - MultimodalRerankInput, - RerankResult, -) -from graphon.model_runtime.entities.text_embedding_entities import ( - EmbeddingInputType, - EmbeddingResult, -) - - -@runtime_checkable -class ModelRuntime(Protocol): - """Port for provider discovery, schema lookup, and model execution. - - `provider` is the model runtime's canonical provider identifier. Adapters may - derive transport-specific details from it, but those details stay outside - this boundary. - """ - - def fetch_model_providers(self) -> Sequence[ProviderEntity]: ... - - def get_provider_icon( - self, - *, - provider: str, - icon_type: str, - lang: str, - ) -> tuple[bytes, str]: ... - - def validate_provider_credentials( - self, - *, - provider: str, - credentials: dict[str, Any], - ) -> None: ... - - def validate_model_credentials( - self, - *, - provider: str, - model_type: ModelType, - model: str, - credentials: dict[str, Any], - ) -> None: ... - - def get_model_schema( - self, - *, - provider: str, - model_type: ModelType, - model: str, - credentials: dict[str, Any], - ) -> AIModelEntity | None: ... - - @overload - def invoke_llm( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - model_parameters: dict[str, Any], - prompt_messages: Sequence[PromptMessage], - tools: list[PromptMessageTool] | None, - stop: Sequence[str] | None, - stream: Literal[False], - ) -> LLMResult: ... - - @overload - def invoke_llm( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - model_parameters: dict[str, Any], - prompt_messages: Sequence[PromptMessage], - tools: list[PromptMessageTool] | None, - stop: Sequence[str] | None, - stream: Literal[True], - ) -> Generator[LLMResultChunk, None, None]: ... 
- - def invoke_llm( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - model_parameters: dict[str, Any], - prompt_messages: Sequence[PromptMessage], - tools: list[PromptMessageTool] | None, - stop: Sequence[str] | None, - stream: bool, - ) -> LLMResult | Generator[LLMResultChunk, None, None]: ... - - @overload - def invoke_llm_with_structured_output( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - json_schema: dict[str, Any], - model_parameters: dict[str, Any], - prompt_messages: Sequence[PromptMessage], - stop: Sequence[str] | None, - stream: Literal[False], - ) -> LLMResultWithStructuredOutput: ... - - @overload - def invoke_llm_with_structured_output( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - json_schema: dict[str, Any], - model_parameters: dict[str, Any], - prompt_messages: Sequence[PromptMessage], - stop: Sequence[str] | None, - stream: Literal[True], - ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... - - def invoke_llm_with_structured_output( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - json_schema: dict[str, Any], - model_parameters: dict[str, Any], - prompt_messages: Sequence[PromptMessage], - stop: Sequence[str] | None, - stream: bool, - ) -> ( - LLMResultWithStructuredOutput - | Generator[LLMResultChunkWithStructuredOutput, None, None] - ): ... - - def get_llm_num_tokens( - self, - *, - provider: str, - model_type: ModelType, - model: str, - credentials: dict[str, Any], - prompt_messages: Sequence[PromptMessage], - tools: Sequence[PromptMessageTool] | None, - ) -> int: ... - - def invoke_text_embedding( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - texts: list[str], - input_type: EmbeddingInputType, - ) -> EmbeddingResult: ... - - def invoke_multimodal_embedding( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - documents: list[dict[str, Any]], - input_type: EmbeddingInputType, - ) -> EmbeddingResult: ... - - def get_text_embedding_num_tokens( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - texts: list[str], - ) -> list[int]: ... - - def invoke_rerank( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - query: str, - docs: list[str], - score_threshold: float | None, - top_n: int | None, - ) -> RerankResult: ... - - def invoke_multimodal_rerank( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - query: MultimodalRerankInput, - docs: list[MultimodalRerankInput], - score_threshold: float | None, - top_n: int | None, - ) -> RerankResult: ... - - def invoke_tts( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - content_text: str, - voice: str, - ) -> Iterable[bytes]: ... - - def get_tts_model_voices( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - language: str | None, - ) -> Any: ... - - def invoke_speech_to_text( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - file: IO[bytes], - ) -> str: ... - - def invoke_moderation( - self, - *, - provider: str, - model: str, - credentials: dict[str, Any], - text: str, - ) -> bool: ... 
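With the monolithic `runtime.py` deleted, a full-surface adapter is validated structurally against the composed `ModelRuntime` protocol rather than by inheriting from anything. A sketch of a defensive check (the helper and the method list used for the error message are illustrative; the names themselves come from the capability protocols above):

```python
from graphon.model_runtime.protocols import ModelRuntime

_CAPABILITY_METHODS = (
    "invoke_llm",
    "invoke_text_embedding",
    "invoke_rerank",
    "invoke_tts",
    "invoke_speech_to_text",
    "invoke_moderation",
)


def assert_full_runtime(adapter: object) -> ModelRuntime:
    """Fail with the missing method names when an adapter is not full-surface."""
    if isinstance(adapter, ModelRuntime):
        return adapter
    missing = [name for name in _CAPABILITY_METHODS if not hasattr(adapter, name)]
    msg = f"Adapter lacks the full runtime surface; missing: {missing or 'provider methods'}"
    raise TypeError(msg)
```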
diff --git a/src/graphon/model_runtime/slim/prepared_llm.py b/src/graphon/model_runtime/slim/prepared_llm.py index 3c5898a..0e49f33 100644 --- a/src/graphon/model_runtime/slim/prepared_llm.py +++ b/src/graphon/model_runtime/slim/prepared_llm.py @@ -14,9 +14,10 @@ PromptMessageTool, ) from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.protocols.llm_runtime import LLMModelRuntime from graphon.nodes.llm.runtime_protocols import PreparedLLMProtocol -from .runtime import SlimRuntime, SlimStructuredOutputParseError +from .runtime import SlimStructuredOutputParseError class SlimPreparedLLM(PreparedLLMProtocol): @@ -24,7 +25,7 @@ class SlimPreparedLLM(PreparedLLMProtocol): def __init__( self, *, - runtime: SlimRuntime, + runtime: LLMModelRuntime, provider: str, model_name: str, credentials: Mapping[str, Any], diff --git a/src/graphon/model_runtime/slim/runtime.py b/src/graphon/model_runtime/slim/runtime.py index 5d5f283..e082e89 100644 --- a/src/graphon/model_runtime/slim/runtime.py +++ b/src/graphon/model_runtime/slim/runtime.py @@ -58,7 +58,7 @@ from graphon.model_runtime.model_providers.base.large_language_model import ( merge_tool_call_deltas, ) -from graphon.model_runtime.runtime import ModelRuntime +from graphon.model_runtime.protocols.runtime import ModelRuntime from graphon.model_runtime.utils.encoders import jsonable_encoder from .config import SlimConfig diff --git a/src/graphon/protocols/__init__.py b/src/graphon/protocols/__init__.py index 9549873..2b6dccd 100644 --- a/src/graphon/protocols/__init__.py +++ b/src/graphon/protocols/__init__.py @@ -3,7 +3,20 @@ from graphon.graph.validation import GraphValidationRule from graphon.http.protocols import HttpClientProtocol, HttpResponseProtocol from graphon.model_runtime.memory.prompt_message_memory import PromptMessageMemory -from graphon.model_runtime.runtime import ModelRuntime +from graphon.model_runtime.protocols.llm_runtime import LLMModelRuntime +from graphon.model_runtime.protocols.moderation_runtime import ( + ModerationModelRuntime, +) +from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime +from graphon.model_runtime.protocols.rerank_runtime import RerankModelRuntime +from graphon.model_runtime.protocols.runtime import ModelRuntime +from graphon.model_runtime.protocols.speech_to_text_runtime import ( + SpeechToTextModelRuntime, +) +from graphon.model_runtime.protocols.text_embedding_runtime import ( + TextEmbeddingModelRuntime, +) +from graphon.model_runtime.protocols.tts_runtime import TTSModelRuntime from graphon.nodes.code.code_node import WorkflowCodeExecutor from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory from graphon.nodes.llm.runtime_protocols import ( @@ -36,15 +49,22 @@ "HttpResponseProtocol", "HumanInputFormStateProtocol", "HumanInputNodeRuntimeProtocol", + "LLMModelRuntime", "ModelFactory", + "ModelProviderRuntime", "ModelRuntime", + "ModerationModelRuntime", "NodeFactory", "PreparedLLMProtocol", "PromptMessageMemory", "PromptMessageSerializerProtocol", "ReadOnlyGraphRuntimeState", "ReadOnlyVariablePool", + "RerankModelRuntime", "RetrieverAttachmentLoaderProtocol", + "SpeechToTextModelRuntime", + "TTSModelRuntime", + "TextEmbeddingModelRuntime", "ToolFileManagerProtocol", "ToolNodeRuntimeProtocol", "VariableLoader", diff --git a/tests/model_runtime/test_model_dispatch.py b/tests/model_runtime/test_model_dispatch.py index 332dcfd..c88cb96 100644 --- a/tests/model_runtime/test_model_dispatch.py +++ 
b/tests/model_runtime/test_model_dispatch.py @@ -1,10 +1,40 @@ -from unittest.mock import MagicMock +from collections.abc import Generator, Sequence +from decimal import Decimal +from io import BytesIO +from typing import IO, Any, cast import pytest from graphon.model_runtime.entities.common_entities import I18nObject -from graphon.model_runtime.entities.model_entities import ModelType -from graphon.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.provider_entities import ( + FieldModelSchema, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) +from graphon.model_runtime.entities.rerank_entities import ( + MultimodalRerankInput, + RerankDocument, + RerankResult, +) +from graphon.model_runtime.entities.text_embedding_entities import ( + EmbeddingInputType, + EmbeddingResult, + EmbeddingUsage, +) from graphon.model_runtime.model_providers.base.large_language_model import ( LargeLanguageModel, ) @@ -22,6 +52,298 @@ from graphon.model_runtime.model_providers.model_provider_factory import ( ModelProviderFactory, ) +from graphon.model_runtime.protocols.llm_runtime import LLMModelRuntime + + +class _ProviderRuntimeStub: + def __init__( + self, + *, + providers: Sequence[ProviderEntity] = (), + provider_icon: tuple[bytes, str] = (b"", ""), + model_schema: AIModelEntity | None = None, + ) -> None: + self._providers = tuple(providers) + self._provider_icon = provider_icon + self._model_schema = model_schema + self.provider_credential_validations: list[dict[str, Any]] = [] + self.model_credential_validations: list[dict[str, Any]] = [] + self.provider_icon_requests: list[dict[str, str]] = [] + self.model_schema_requests: list[dict[str, Any]] = [] + + def fetch_model_providers(self) -> tuple[ProviderEntity, ...]: + return self._providers + + def get_provider_icon( + self, + *, + provider: str, + icon_type: str, + lang: str, + ) -> tuple[bytes, str]: + self.provider_icon_requests.append( + { + "provider": provider, + "icon_type": icon_type, + "lang": lang, + }, + ) + return self._provider_icon + + def validate_provider_credentials( + self, + *, + provider: str, + credentials: dict[str, Any], + ) -> None: + self.provider_credential_validations.append( + { + "provider": provider, + "credentials": credentials, + }, + ) + + def validate_model_credentials( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> None: + self.model_credential_validations.append( + { + "provider": provider, + "model_type": model_type, + "model": model, + "credentials": credentials, + }, + ) + + def get_model_schema( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> AIModelEntity | None: + self.model_schema_requests.append( + { + "provider": provider, + "model_type": model_type, + "model": model, + "credentials": credentials, + }, + ) + return self._model_schema + + +class _LLMRuntimeStub(_ProviderRuntimeStub): + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: 
Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> LLMResult: + _ = provider, credentials, model_parameters, tools, stop, stream + return LLMResult( + model=model, + prompt_messages=list(prompt_messages), + message=AssistantPromptMessage(content="ok"), + usage=LLMUsage.empty_usage(), + ) + + def get_llm_num_tokens( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: Sequence[PromptMessageTool] | None, + ) -> int: + _ = provider, model_type, model, credentials, prompt_messages, tools + return 7 + + def invoke_llm_with_structured_output( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + json_schema: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + stop: Sequence[str] | None, + stream: bool, + ) -> ( + LLMResultWithStructuredOutput + | Generator[LLMResultChunkWithStructuredOutput, None, None] + ): + _ = provider, credentials, json_schema, model_parameters, stop, stream + return LLMResultWithStructuredOutput( + model=model, + prompt_messages=list(prompt_messages), + message=AssistantPromptMessage(content="ok"), + usage=LLMUsage.empty_usage(), + structured_output={"ok": True}, + ) + + +class _EmbeddingRuntimeStub(_ProviderRuntimeStub): + def invoke_text_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: + _ = provider, model, credentials, texts, input_type + return EmbeddingResult( + model=model, + embeddings=[[0.1, 0.2]], + usage=EmbeddingUsage( + tokens=1, + total_tokens=1, + unit_price=Decimal(0), + price_unit=Decimal(0), + total_price=Decimal(0), + currency="USD", + latency=0.0, + ), + ) + + def invoke_multimodal_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + documents: list[dict[str, Any]], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: + _ = provider, model, credentials, documents, input_type + return EmbeddingResult( + model=model, + embeddings=[[0.1, 0.2]], + usage=EmbeddingUsage( + tokens=1, + total_tokens=1, + unit_price=Decimal(0), + price_unit=Decimal(0), + total_price=Decimal(0), + currency="USD", + latency=0.0, + ), + ) + + def get_text_embedding_num_tokens( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + ) -> list[int]: + _ = provider, model, credentials, texts + return [3] + + +class _TTSRuntimeStub(_ProviderRuntimeStub): + def invoke_tts( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + content_text: str, + voice: str, + ) -> list[bytes]: + _ = provider, model, credentials, content_text, voice + return [b"audio"] + + def get_tts_model_voices( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + language: str | None, + ) -> list[str]: + _ = provider, model, credentials, language + return ["nova"] + + +class _ModerationRuntimeStub(_ProviderRuntimeStub): + def invoke_moderation( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + text: str, + ) -> bool: + _ = provider, model, credentials, text + return True + + +class _RerankRuntimeStub(_ProviderRuntimeStub): + def invoke_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: str, + docs: list[str], + score_threshold: float | None, + top_n: int | None, + ) -> 
RerankResult: + _ = provider, credentials, query, score_threshold, top_n + return RerankResult( + model=model, + docs=[RerankDocument(index=0, text=docs[0], score=0.9)], + ) + + def invoke_multimodal_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: + _ = provider, credentials, query, score_threshold, top_n + return RerankResult( + model=model, + docs=[RerankDocument(index=0, text=docs[0]["content"], score=0.9)], + ) + + +class _SpeechToTextRuntimeStub(_ProviderRuntimeStub): + def invoke_speech_to_text( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + file: IO[bytes], + ) -> str: + _ = provider, model, credentials, file + return "transcript" @pytest.mark.parametrize( @@ -62,33 +384,205 @@ def test_model_type_to_origin_model_type_uses_model_map( assert model_type.to_origin_model_type() == expected_origin_model_type -@pytest.mark.parametrize( - ("model_type", "expected_model_class"), - [ - (ModelType.LLM, LargeLanguageModel), - (ModelType.TEXT_EMBEDDING, TextEmbeddingModel), - (ModelType.RERANK, RerankModel), - (ModelType.SPEECH2TEXT, Speech2TextModel), - (ModelType.MODERATION, ModerationModel), - (ModelType.TTS, TTSModel), - ], -) -def test_model_provider_factory_uses_model_class_map( - model_type: ModelType, - expected_model_class: type, -) -> None: +def test_model_provider_factory_accepts_provider_only_runtime_surface() -> None: provider = ProviderEntity( provider="test-provider", label=I18nObject(en_US="Test Provider"), - supported_model_types=[model_type], + supported_model_types=[ModelType.LLM], configurate_methods=[], + provider_credential_schema=ProviderCredentialSchema( + credential_form_schemas=[], + ), + model_credential_schema=ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model")), + credential_form_schemas=[], + ), ) - runtime = MagicMock() - runtime.fetch_model_providers.return_value = [provider] - factory = ModelProviderFactory(model_runtime=runtime) + runtime = _ProviderRuntimeStub( + providers=[provider], + provider_icon=(b"icon-bytes", "svg"), + ) + factory = ModelProviderFactory(runtime=runtime) + + assert list(factory.get_providers()) == [provider] + assert list(factory.get_model_providers()) == [provider] + assert factory.get_provider_schema("test-provider") is provider + assert factory.get_model_provider("test-provider") is provider + assert ( + factory.provider_credentials_validate( + provider="test-provider", + credentials={}, + ) + == {} + ) + assert ( + factory.model_credentials_validate( + provider="test-provider", + model_type=ModelType.LLM, + model="fake-chat", + credentials={}, + ) + == {} + ) + assert ( + factory.get_model_schema( + provider="test-provider", + model_type=ModelType.LLM, + model="fake-chat", + credentials={}, + ) + is None + ) + assert factory.get_models(model_type=ModelType.LLM) == [ + provider.to_simple_provider(), + ] + assert factory.get_provider_icon("test-provider", "icon", "en") == ( + b"icon-bytes", + "svg", + ) + assert runtime.provider_credential_validations == [ + { + "provider": "test-provider", + "credentials": {}, + }, + ] + assert runtime.model_credential_validations == [ + { + "provider": "test-provider", + "model_type": ModelType.LLM, + "model": "fake-chat", + "credentials": {}, + }, + ] - model = factory.get_model_type_instance("test-provider", model_type) - assert isinstance(model, 
expected_model_class) - assert model.provider_schema is provider - assert model.model_runtime is runtime +def test_large_language_model_accepts_llm_only_runtime_surface() -> None: + provider = ProviderEntity( + provider="test-provider", + label=I18nObject(en_US="Test Provider"), + supported_model_types=[ModelType.LLM], + configurate_methods=[], + ) + runtime = cast("LLMModelRuntime", _LLMRuntimeStub()) + model = LargeLanguageModel(provider_schema=provider, model_runtime=runtime) + + result = cast( + "LLMResult", + model.invoke( + model="fake-chat", + credentials={}, + prompt_messages=[UserPromptMessage(content="hello")], + stream=False, + ), + ) + + assert result.message.content == "ok" + assert ( + model.get_num_tokens( + model="fake-chat", + credentials={}, + prompt_messages=[UserPromptMessage(content="hello")], + ) + == 7 + ) + + +def test_text_embedding_model_accepts_embedding_only_runtime_surface() -> None: + provider = ProviderEntity( + provider="test-provider", + label=I18nObject(en_US="Test Provider"), + supported_model_types=[ModelType.TEXT_EMBEDDING], + configurate_methods=[], + ) + runtime = _EmbeddingRuntimeStub() + model = TextEmbeddingModel(provider_schema=provider, model_runtime=runtime) + + assert model.get_num_tokens( + model="embedding-model", + credentials={}, + texts=["hello"], + ) == [3] + + +def test_tts_model_accepts_tts_only_runtime_surface() -> None: + provider = ProviderEntity( + provider="test-provider", + label=I18nObject(en_US="Test Provider"), + supported_model_types=[ModelType.TTS], + configurate_methods=[], + ) + runtime = _TTSRuntimeStub() + model = TTSModel(provider_schema=provider, model_runtime=runtime) + + assert list( + model.invoke( + model="voice-model", + credentials={}, + content_text="hello", + voice="nova", + ), + ) == [b"audio"] + assert model.get_tts_model_voices( + model="voice-model", + credentials={}, + ) == ["nova"] + + +def test_moderation_model_accepts_moderation_only_runtime_surface() -> None: + provider = ProviderEntity( + provider="test-provider", + label=I18nObject(en_US="Test Provider"), + supported_model_types=[ModelType.MODERATION], + configurate_methods=[], + ) + runtime = _ModerationRuntimeStub() + model = ModerationModel(provider_schema=provider, model_runtime=runtime) + + assert ( + model.invoke( + model="moderation-model", + credentials={}, + text="hello", + ) + is True + ) + + +def test_rerank_model_accepts_rerank_only_runtime_surface() -> None: + provider = ProviderEntity( + provider="test-provider", + label=I18nObject(en_US="Test Provider"), + supported_model_types=[ModelType.RERANK], + configurate_methods=[], + ) + runtime = _RerankRuntimeStub() + model = RerankModel(provider_schema=provider, model_runtime=runtime) + + result = model.invoke( + model="rerank-model", + credentials={}, + query="hello", + docs=["doc-1"], + ) + + assert result.docs[0].text == "doc-1" + + +def test_speech_to_text_model_accepts_speech_only_runtime_surface() -> None: + provider = ProviderEntity( + provider="test-provider", + label=I18nObject(en_US="Test Provider"), + supported_model_types=[ModelType.SPEECH2TEXT], + configurate_methods=[], + ) + runtime = _SpeechToTextRuntimeStub() + model = Speech2TextModel(provider_schema=provider, model_runtime=runtime) + + assert ( + model.invoke( + model="stt-model", + credentials={}, + file=BytesIO(b"audio"), + ) + == "transcript" + ) diff --git a/tests/test_protocols_exports.py b/tests/test_protocols_exports.py index 72ffec4..999d52a 100644 --- a/tests/test_protocols_exports.py +++ 
b/tests/test_protocols_exports.py @@ -4,7 +4,20 @@ from graphon.graph.validation import GraphValidationRule from graphon.http.protocols import HttpClientProtocol, HttpResponseProtocol from graphon.model_runtime.memory.prompt_message_memory import PromptMessageMemory -from graphon.model_runtime.runtime import ModelRuntime +from graphon.model_runtime.protocols.llm_runtime import LLMModelRuntime +from graphon.model_runtime.protocols.moderation_runtime import ( + ModerationModelRuntime, +) +from graphon.model_runtime.protocols.provider_runtime import ModelProviderRuntime +from graphon.model_runtime.protocols.rerank_runtime import RerankModelRuntime +from graphon.model_runtime.protocols.runtime import ModelRuntime +from graphon.model_runtime.protocols.speech_to_text_runtime import ( + SpeechToTextModelRuntime, +) +from graphon.model_runtime.protocols.text_embedding_runtime import ( + TextEmbeddingModelRuntime, +) +from graphon.model_runtime.protocols.tts_runtime import TTSModelRuntime from graphon.nodes.code.code_node import WorkflowCodeExecutor from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory from graphon.nodes.llm.runtime_protocols import ( @@ -40,8 +53,11 @@ from graphon.protocols import ( HumanInputNodeRuntimeProtocol as PublicHumanInputNodeRuntimeProtocol, ) +from graphon.protocols import LLMModelRuntime as PublicLLMModelRuntime from graphon.protocols import ModelFactory as PublicModelFactory +from graphon.protocols import ModelProviderRuntime as PublicModelProviderRuntime from graphon.protocols import ModelRuntime as PublicModelRuntime +from graphon.protocols import ModerationModelRuntime as PublicModerationModelRuntime from graphon.protocols import NodeFactory as PublicNodeFactory from graphon.protocols import PreparedLLMProtocol as PublicPreparedLLMProtocol from graphon.protocols import PromptMessageMemory as PublicPromptMessageMemory @@ -52,11 +68,19 @@ ReadOnlyGraphRuntimeState as PublicReadOnlyGraphRuntimeState, ) from graphon.protocols import ReadOnlyVariablePool as PublicReadOnlyVariablePool +from graphon.protocols import RerankModelRuntime as PublicRerankModelRuntime from graphon.protocols import ( RetrieverAttachmentLoaderProtocol as PublicRetrieverAttachmentLoaderProtocol, ) +from graphon.protocols import ( + SpeechToTextModelRuntime as PublicSpeechToTextModelRuntime, +) +from graphon.protocols import ( + TextEmbeddingModelRuntime as PublicTextEmbeddingModelRuntime, +) from graphon.protocols import ToolFileManagerProtocol as PublicToolFileManagerProtocol from graphon.protocols import ToolNodeRuntimeProtocol as PublicToolNodeRuntimeProtocol +from graphon.protocols import TTSModelRuntime as PublicTTSModelRuntime from graphon.protocols import VariableLoader as PublicVariableLoader from graphon.protocols import WorkflowCodeExecutor as PublicWorkflowCodeExecutor from graphon.protocols import ( @@ -75,6 +99,13 @@ def test_public_protocol_exports_match_canonical_definitions() -> None: assert PublicWorkflowFileRuntimeProtocol is WorkflowFileRuntimeProtocol assert PublicNodeFactory is NodeFactory assert PublicGraphValidationRule is GraphValidationRule + assert PublicModelProviderRuntime is ModelProviderRuntime + assert PublicLLMModelRuntime is LLMModelRuntime + assert PublicTextEmbeddingModelRuntime is TextEmbeddingModelRuntime + assert PublicRerankModelRuntime is RerankModelRuntime + assert PublicSpeechToTextModelRuntime is SpeechToTextModelRuntime + assert PublicModerationModelRuntime is ModerationModelRuntime + assert PublicTTSModelRuntime is TTSModelRuntime 
assert PublicModelRuntime is ModelRuntime assert PublicPromptMessageMemory is PromptMessageMemory assert PublicWorkflowCodeExecutor is WorkflowCodeExecutor @@ -104,15 +135,22 @@ def test_public_protocol_package_exports_are_stable() -> None: "HttpResponseProtocol", "HumanInputFormStateProtocol", "HumanInputNodeRuntimeProtocol", + "LLMModelRuntime", "ModelFactory", + "ModelProviderRuntime", "ModelRuntime", + "ModerationModelRuntime", "NodeFactory", "PreparedLLMProtocol", "PromptMessageMemory", "PromptMessageSerializerProtocol", "ReadOnlyGraphRuntimeState", "ReadOnlyVariablePool", + "RerankModelRuntime", "RetrieverAttachmentLoaderProtocol", + "SpeechToTextModelRuntime", + "TTSModelRuntime", + "TextEmbeddingModelRuntime", "ToolFileManagerProtocol", "ToolNodeRuntimeProtocol", "VariableLoader",