diff --git a/.editorconfig b/.editorconfig new file mode 100644 index 0000000..3e72d8b --- /dev/null +++ b/.editorconfig @@ -0,0 +1,18 @@ +root = true + +[*] +end_of_line = lf +insert_final_newline = true +trim_trailing_whitespace = true +charset = utf-8 + +[*.py] +indent_style = space +indent_size = 4 + +[*.toml] +indent_style = space +indent_size = 4 + +[*.md] +trim_trailing_whitespace = false diff --git a/.gitignore b/.gitignore index f7508cc..cfdefb8 100644 --- a/.gitignore +++ b/.gitignore @@ -16,3 +16,17 @@ dist/ # Plugin Daemon Slim .slim +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ diff --git a/pyproject.toml b/pyproject.toml index a9fe54a..4356929 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -80,6 +80,11 @@ ignore-one-line-docstrings = true 'S101', # Assert statements used for pytest. 'PLR2004', # Magic value used in test cases. 'PLC2701', # Allow import private members from graphon module. + 'PLC1901', # be more precise about assertion in tests. + + # permit calls of private methods until we separate + # public interface tests from implementation tests. + 'SLF001', # allow calls of private class methods in tests ] [tool.ty.environment] diff --git a/src/graphon/entities/pause_reason.py b/src/graphon/entities/pause_reason.py index 8ad5391..deeaaaf 100644 --- a/src/graphon/entities/pause_reason.py +++ b/src/graphon/entities/pause_reason.py @@ -4,7 +4,7 @@ from pydantic import BaseModel, Field -from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.entities import FormInputConfig, UserActionConfig class PauseReasonType(StrEnum): @@ -18,8 +18,8 @@ class HumanInputRequired(BaseModel): ) form_id: str form_content: str - inputs: list[FormInput] = Field(default_factory=list) - actions: list[UserAction] = Field(default_factory=list) + inputs: list[FormInputConfig] = Field(default_factory=list[FormInputConfig]) + actions: list[UserActionConfig] = Field(default_factory=list[UserActionConfig]) node_id: str node_title: str diff --git a/src/graphon/graph_events/node.py b/src/graphon/graph_events/node.py index 9fd507c..f4759c9 100644 --- a/src/graphon/graph_events/node.py +++ b/src/graphon/graph_events/node.py @@ -4,6 +4,7 @@ from pydantic import Field from graphon.entities.pause_reason import PauseReason +from graphon.variables.segments import Segment from graphon.variables.variables import Variable from .base import GraphNodeEventBase @@ -91,6 +92,10 @@ class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase): ..., description="Display text of the chosen action button.", ) + submitted_data: Mapping[str, Segment] = Field( + default_factory=dict, + description="Runtime submitted values keyed by form output variable name.", + ) class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase): diff --git a/src/graphon/node_events/node.py b/src/graphon/node_events/node.py index a733b4d..396567f 100644 --- a/src/graphon/node_events/node.py +++ b/src/graphon/node_events/node.py @@ -8,6 +8,7 @@ from graphon.file.models import File from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.node_events.base import NodeRunResult +from graphon.variables.segments import Segment from graphon.variables.variables import Variable from .base import NodeEventBase @@ -73,6 +74,11 @@ class HumanInputFormFilledEvent(NodeEventBase): action_id: str action_text: str + # submitted_data records the data user submitted in the form inputs. + # It is a mapping from FormInput.output_variable_name to + # their runtime values. + submitted_data: Mapping[str, Segment] = Field(default_factory=dict) + class HumanInputFormTimeoutEvent(NodeEventBase): """Event emitted when a human input form times out.""" diff --git a/src/graphon/nodes/base/node.py b/src/graphon/nodes/base/node.py index e867621..bfd6c5e 100644 --- a/src/graphon/nodes/base/node.py +++ b/src/graphon/nodes/base/node.py @@ -884,6 +884,7 @@ def _(self, event: HumanInputFormFilledEvent) -> NodeRunHumanInputFormFilledEven rendered_content=event.rendered_content, action_id=event.action_id, action_text=event.action_text, + submitted_data=event.submitted_data, ) @_dispatch.register diff --git a/src/graphon/nodes/human_input/_exc.py b/src/graphon/nodes/human_input/_exc.py new file mode 100644 index 0000000..6f47569 --- /dev/null +++ b/src/graphon/nodes/human_input/_exc.py @@ -0,0 +1,22 @@ +from graphon.file.enums import FileTransferMethod + + +class InvalidConfigError(Exception): + pass + + +class InvalidSubmittedDataError(Exception): + pass + + +class InvalidTransferMethodError(InvalidConfigError): + transfer_method: FileTransferMethod + + def __init__(self, transfer_method: FileTransferMethod) -> None: + self.transfer_method = transfer_method + super().__init__(f"invalid file transfer method: {transfer_method}") + + +class ExtensionsNotSetErrorValueError(InvalidConfigError): + def __init__(self) -> None: + super().__init__("allowed_file_extensions not set") diff --git a/src/graphon/nodes/human_input/entities.py b/src/graphon/nodes/human_input/entities.py index f2cc1de..ca8daab 100644 --- a/src/graphon/nodes/human_input/entities.py +++ b/src/graphon/nodes/human_input/entities.py @@ -5,26 +5,31 @@ outside `graphon`. """ +import abc import re from collections.abc import Mapping, Sequence from datetime import datetime, timedelta -from typing import Any, Self, assert_never +from typing import Annotated, Any, Literal, Self, assert_never -from pydantic import BaseModel, Field, field_validator, model_validator +from pydantic import BaseModel, Field, NonNegativeInt, field_validator, model_validator from graphon.entities.base_node_data import BaseNodeData from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.file.enums import FileTransferMethod, FileType from graphon.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.segments import Segment -from .enums import ButtonStyle, FormInputType, PlaceholderType, TimeoutUnit +from . import _exc as exc +from .enums import ButtonStyle, FormInputType, TimeoutUnit, ValueSourceType _OUTPUT_VARIABLE_PATTERN = re.compile( r"\{\{#\$output\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}", ) -class FormInputDefault(BaseModel): +class StringSource(BaseModel): """Default configuration for form inputs.""" # NOTE: Ideally, a discriminated union would be used to model @@ -32,7 +37,9 @@ class FormInputDefault(BaseModel): # value when switching between `VARIABLE` and `CONSTANT` types. This # necessitates retaining all fields, making a discriminated union unsuitable. - type: PlaceholderType + # NOTE: This class is renamed from FormInputDefault. + + type: ValueSourceType # The selector of default variable, used when `type` is `VARIABLE`. selector: Sequence[str] = Field(default_factory=tuple) @@ -43,7 +50,7 @@ class FormInputDefault(BaseModel): @model_validator(mode="after") def _validate_selector(self) -> Self: - if self.type == PlaceholderType.CONSTANT: + if self.type == ValueSourceType.CONSTANT: return self if len(self.selector) < SELECTORS_LENGTH: msg = ( @@ -54,18 +61,146 @@ def _validate_selector(self) -> Self: return self -class FormInput(BaseModel): - """Form input definition.""" +class StringListSource(BaseModel): + type: ValueSourceType + + # The selector of default variable, used when `type` is `VARIABLE`. + selector: Sequence[str] = Field(default_factory=tuple) + + # The value of the default, used when `type` is `CONSTANT`. + value: list[str] = Field(default_factory=list) + + +class BaseInputConfig(BaseModel): + """BaseInputConfig is the base class for all input field definitions. + One input corresponds to one output variable during form submission. + """ - type: FormInputType output_variable_name: str - default: FormInputDefault | None = None + + @abc.abstractmethod + def extract_variable_selectors(self) -> Sequence[Sequence[str]]: + """`extract_variable_selectors` extracts variable selectors + used by this input field. + """ + + @abc.abstractmethod + def resolve_default_value(self, pool: ReadOnlyVariablePool) -> Segment | None: + """`resolve_default_value` resolves the default value for form submission. + + If the form input does not specify a default value, or the default value does + not depend on the runtime variable, this method should return `None`. + """ + + +class ParagraphInputConfig(BaseInputConfig): + """Form input definition.""" + + # NOTE: This class is renamed from FormInput. + type: Literal[FormInputType.PARAGRAPH] = FormInputType.PARAGRAPH + default: StringSource | None = None + + def extract_variable_selectors(self) -> Sequence[Sequence[str]]: + default = self.default + if default is None: + return [] + if default.type == ValueSourceType.CONSTANT: + return [] + return [default.selector] + + def resolve_default_value(self, pool: ReadOnlyVariablePool) -> Segment | None: + default = self.default + if default is None: + return None + + if default.type == ValueSourceType.CONSTANT: + return None + + return pool.get(default.selector) + + +class SelectInputConfig(BaseInputConfig): + type: Literal[FormInputType.SELECT] = FormInputType.SELECT + option_source: StringListSource + + def extract_variable_selectors(self) -> Sequence[Sequence[NodeType]]: + if self.option_source.type == ValueSourceType.CONSTANT: + return [] + return [self.option_source.selector] + + def resolve_default_value(self, pool: ReadOnlyVariablePool) -> Segment | None: + _ = pool + return None + + +_ALLOWED_TRANSFER_METHOD = frozenset([ + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, +]) + + +class _FileInputCommonConfig(BaseModel): + allowed_file_types: Sequence[FileType] = Field(default_factory=list[FileType]) + allowed_file_extensions: Sequence[str] = Field(default_factory=list) + allowed_file_upload_methods: Sequence[FileTransferMethod] = Field( + default_factory=list[FileTransferMethod] + ) + + @field_validator("allowed_file_upload_methods", mode="after") + @classmethod + def _validate_upload_methods( + cls, transfer_methods: Sequence[FileTransferMethod] + ) -> Sequence[FileTransferMethod]: + validated_values: list[FileTransferMethod] = [] + for value in transfer_methods: + if value not in _ALLOWED_TRANSFER_METHOD: + raise exc.InvalidTransferMethodError(value) + validated_values.append(value) + + return validated_values + + @model_validator(mode="after") + def _validate_extensions(self) -> Self: + if self.allowed_file_types != FileType.CUSTOM: + return self + if not self.allowed_file_extensions: + raise exc.ExtensionsNotSetErrorValueError + return self + + +class FileInputConfig(_FileInputCommonConfig, BaseInputConfig): + type: Literal[FormInputType.FILE] = FormInputType.FILE + + def extract_variable_selectors(self) -> Sequence[Sequence[NodeType]]: + return [] + + def resolve_default_value(self, pool: ReadOnlyVariablePool) -> Segment | None: + _ = pool + return None + + +class FileListInputConfig(_FileInputCommonConfig, BaseInputConfig): + type: Literal[FormInputType.FILE_LIST] = FormInputType.FILE_LIST + number_limits: NonNegativeInt = 0 + + def extract_variable_selectors(self) -> Sequence[Sequence[NodeType]]: + return [] + + def resolve_default_value(self, pool: ReadOnlyVariablePool) -> Segment | None: + _ = pool + return None + + +type FormInputConfig = Annotated[ + ParagraphInputConfig | SelectInputConfig | FileInputConfig | FileListInputConfig, + Field(discriminator="type"), +] _IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") -class UserAction(BaseModel): +class UserActionConfig(BaseModel): """User action configuration.""" # id is the identifier for this action. @@ -94,14 +229,14 @@ class HumanInputNodeData(BaseNodeData): type: NodeType = BuiltinNodeTypes.HUMAN_INPUT form_content: str = "" - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) + inputs: list[FormInputConfig] = Field(default_factory=list[FormInputConfig]) + user_actions: list[UserActionConfig] = Field(default_factory=list[UserActionConfig]) timeout: int = 36 timeout_unit: TimeoutUnit = TimeoutUnit.HOUR @field_validator("inputs") @classmethod - def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]: + def _validate_inputs(cls, inputs: list[FormInputConfig]) -> list[FormInputConfig]: seen_names: set[str] = set() for form_input in inputs: name = form_input.output_variable_name @@ -113,7 +248,9 @@ def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]: @field_validator("user_actions") @classmethod - def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]: + def _validate_user_actions( + cls, user_actions: list[UserActionConfig] + ) -> list[UserActionConfig]: seen_ids: set[str] = set() for action in user_actions: action_id = action.id @@ -162,14 +299,11 @@ def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None: ]) for form_input in self.inputs: - default_value = form_input.default - if default_value is None: - continue - if default_value.type == PlaceholderType.CONSTANT: - continue - default_value_key = ".".join(default_value.selector) - qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#" - variable_mappings[qualified_variable_mapping_key] = default_value.selector + selectors = form_input.extract_variable_selectors() + for selector in selectors: + value_key = ".".join(selector) + qualified_variable_mapping_key = f"{node_id}.#{value_key}#" + variable_mappings[qualified_variable_mapping_key] = selector return variable_mappings @@ -201,8 +335,8 @@ def must_resolve_action_value(self, action_id: str) -> str: class FormDefinition(BaseModel): form_content: str - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) + inputs: list[FormInputConfig] = Field(default_factory=list[FormInputConfig]) + user_actions: list[UserActionConfig] = Field(default_factory=list[UserActionConfig]) rendered_content: str expiration_time: datetime @@ -222,8 +356,8 @@ class HumanInputSubmissionValidationError(ValueError): def validate_human_input_submission( *, - inputs: Sequence[FormInput], - user_actions: Sequence[UserAction], + inputs: Sequence[FormInputConfig], + user_actions: Sequence[UserActionConfig], selected_action_id: str, form_data: Mapping[str, Any], ) -> None: diff --git a/src/graphon/nodes/human_input/enums.py b/src/graphon/nodes/human_input/enums.py index 3fb0ab4..53a3bd0 100644 --- a/src/graphon/nodes/human_input/enums.py +++ b/src/graphon/nodes/human_input/enums.py @@ -42,14 +42,35 @@ class TimeoutUnit(enum.StrEnum): class FormInputType(enum.StrEnum): - """Form input types.""" + """Form input types. - TEXT_INPUT = enum.auto() - PARAGRAPH = enum.auto() + Name for this enumeration are intentionally keep the same as those for + `VariableEntityType`. + """ + # Both `TEXT_INPUT` and `PARAGRAPH` represent string input fields. + # The corresponding generated variable type is `SegmentType.STRING`. + PARAGRAPH = "paragraph" -class PlaceholderType(enum.StrEnum): - """Default value types for form inputs.""" + # A single-select input field (e.g., a dropdown or radio buttons). + # The corresponding generated variable type is `SegmentType.STRING`. + SELECT = "select" + # A file input field that accepts a single file. + # The corresponding generated variable type is `SegmentType.FILE`. + FILE = "file" + + # A file input field that accepts zero or more files. + # The corresponding generated variable type is `SegmentType.ARRAY_FILE`. + FILE_LIST = "file-list" + + +class ValueSourceType(enum.StrEnum): + """ValueSourceType records whether the value comes from a static setting + in form definiton, or a variable while the workflow is running. + """ + + # `VARIABLE` means that the value comes from a variable in workflow execution VARIABLE = enum.auto() + # `CONSTANT` measn that the value comes from a static setting in form definition. CONSTANT = enum.auto() diff --git a/src/graphon/nodes/human_input/human_input_node.py b/src/graphon/nodes/human_input/human_input_node.py index 59a9b32..985185d 100644 --- a/src/graphon/nodes/human_input/human_input_node.py +++ b/src/graphon/nodes/human_input/human_input_node.py @@ -4,7 +4,7 @@ import logging from collections.abc import Generator, Mapping, Sequence from datetime import UTC, datetime -from typing import Any, override +from typing import Any, assert_never, override from graphon.entities.graph_init_params import GraphInitParams from graphon.entities.pause_reason import HumanInputRequired @@ -21,6 +21,7 @@ StreamCompletedEvent, ) from graphon.nodes.base.node import Node +from graphon.nodes.protocols import FileReferenceFactoryProtocol from graphon.nodes.runtime import ( HumanInputFormStateProtocol, HumanInputNodeRuntimeProtocol, @@ -28,10 +29,20 @@ _normalize_human_input_runtime, ) from graphon.runtime.graph_runtime_state import GraphRuntimeState +from graphon.variables.factory import build_segment +from graphon.variables.segments import Segment from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter -from .entities import HumanInputNodeData -from .enums import HumanInputFormStatus, PlaceholderType +from . import _exc as exc +from .entities import ( + FileInputConfig, + FileListInputConfig, + FormInputConfig, + HumanInputNodeData, + ParagraphInputConfig, + SelectInputConfig, +) +from .enums import HumanInputFormStatus _SELECTED_BRANCH_KEY = "selected_branch" @@ -73,6 +84,7 @@ def __init__( # Make `runtime` optional once Graphon provides a default human-input # runtime adapter instead of requiring an embedding-specific implementation. runtime: _HumanInputRuntimeLike, + file_reference_factory: FileReferenceFactoryProtocol, form_repository: object | None = None, ) -> None: super().__init__( @@ -85,6 +97,7 @@ def __init__( runtime, form_repository=form_repository, ) + self._file_reference_factory = file_reference_factory @classmethod @override @@ -157,17 +170,13 @@ def resolve_default_values(self) -> Mapping[str, Any]: variable_pool = self.graph_runtime_state.variable_pool resolved_defaults: dict[str, Any] = {} for form_input in self._node_data.inputs: - if (default_value := form_input.default) is None: - continue - if default_value.type == PlaceholderType.CONSTANT: - continue - resolved_value = variable_pool.get(default_value.selector) - if resolved_value is None: + resolved_default = form_input.resolve_default_value(variable_pool) + if resolved_default is None: # Treat missing variable-backed defaults as absent defaults. continue resolved_defaults[form_input.output_variable_name] = ( WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive( - resolved_value.value, + resolved_default.value, ) ) @@ -236,10 +245,11 @@ def _run(self) -> Generator[NodeEventBase, None, None]: yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - self._OUTPUT_FIELD_ACTION_ID: "", - self._OUTPUT_FIELD_ACTION_VALUE: "", - }, + outputs=self._build_special_outputs( + action_id="", + action_value="", + rendered_content=form.rendered_content, + ), edge_source_handle=self._TIMEOUT_HANDLE, ), ) @@ -256,18 +266,35 @@ def _run(self) -> Generator[NodeEventBase, None, None]: f"form_id={form.id}" ) raise AssertionError(msg) - submitted_inputs = dict(form.submitted_data or {}) - outputs: dict[str, Any] = dict(submitted_inputs) - outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id - outputs[self._OUTPUT_FIELD_ACTION_VALUE] = ( - self._node_data.must_resolve_action_value(selected_action_id) + restored_submission_data = self._restore_submitted_data( + submitted_data=form.submitted_data or {}, + ) + inputs_by_name = { + form_input.output_variable_name: form_input + for form_input in self._node_data.inputs + } + submitted_data: dict[str, Segment] = {} + for name, value in restored_submission_data.items(): + if name not in inputs_by_name: + logger.error("unexpected form data in submitted data, key=%s", name) + continue + submitted_data[name] = value + selected_action_value = next( + ua.title + for ua in self._node_data.user_actions + if ua.id == selected_action_id ) rendered_content = self.render_form_content_with_outputs( form.rendered_content, - outputs, + submitted_data, self._node_data.outputs_field_names(), + self._node_data.inputs, + ) + outputs = dict(submitted_data) | self._build_special_outputs( + action_id=selected_action_id, + action_value=selected_action_value, + rendered_content=rendered_content, ) - outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content action_text = self._node_data.find_action_text(selected_action_id) @@ -276,12 +303,13 @@ def _run(self) -> Generator[NodeEventBase, None, None]: rendered_content=rendered_content, action_id=selected_action_id, action_text=action_text, + submitted_data=submitted_data, ) yield StreamCompletedEvent( node_run_result=NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=submitted_inputs, + inputs=submitted_data, outputs=outputs, edge_source_handle=selected_action_id, ), @@ -309,21 +337,166 @@ def render_form_content_with_outputs( form_content: str, outputs: Mapping[str, Any], field_names: Sequence[str], + form_inputs: Sequence[FormInputConfig] | None = None, ) -> str: - """Replace {{#$output.xxx#}} placeholders with submitted values.""" + """Replace {{#$output.xxx#}} placeholders with submitted values. + + Text inputs render their submitted value directly. File inputs render as + stable placeholders so the final content stays readable and does not + inline transport metadata. + + Returns: + the interplated form content + """ + inputs_by_name = {} + if form_inputs is not None: + inputs_by_name = { + form_input.output_variable_name: form_input + for form_input in form_inputs + } + rendered_content = form_content for field_name in field_names: placeholder = "{{#$output." + field_name + "#}}" - value = outputs.get(field_name) - if value is None: - replacement = "" - elif isinstance(value, (dict, list)): - replacement = json.dumps(value, ensure_ascii=False) - else: - replacement = str(value) + replacement = HumanInputNode._render_output_placeholder_value( + value=outputs.get(field_name), + form_input=inputs_by_name.get(field_name), + ) rendered_content = rendered_content.replace(placeholder, replacement) return rendered_content + @staticmethod + def _render_output_placeholder_value( + *, + value: Any, + form_input: FormInputConfig | None, + ) -> str: + if isinstance(value, Segment): + value = WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive( + value, + ) + + if value is None: + return "" + + if isinstance(form_input, FileInputConfig): + return "[file]" + + if isinstance(form_input, FileListInputConfig): + file_count = 0 + if isinstance(value, Sequence) and not isinstance(value, str | bytes): + file_count = len(value) + return f"[{file_count} files]" + + if isinstance(form_input, ParagraphInputConfig | SelectInputConfig): + return str(value) + + if isinstance(value, (dict, list)): + return json.dumps(value, ensure_ascii=False) + + return str(value) + + def _restore_submitted_data( + self, + *, + submitted_data: Mapping[str, Any], + ) -> dict[str, Segment]: + """_restore_submitted_data restruct python data types from + **validated form data**. + + Returns: + A mapping from input field names to their corresponding + graphon runtime values. + + Raises: + InvalidSubmittedDataError: if submission data are invalid. + """ + # NOTE: ideally this logic shoule be integrated into + # `HumanInputFormStateProtocol.submitted_data`. + + restored_data: dict[str, Segment] = {} + inputs_by_name = { + form_input.output_variable_name: form_input + for form_input in self._node_data.inputs + } + + for name, value in submitted_data.items(): + form_input = inputs_by_name.get(name) + if form_input is None: + restored_data[name] = build_segment(value) + continue + + match form_input: + case FileInputConfig(): + if not isinstance(value, Mapping): + msg = ( + "HumanInput file input expects a mapping payload, " + f"output_variable_name={name}, got={type(value).__name__}" + ) + raise exc.InvalidSubmittedDataError(msg) + restored_data[name] = build_segment( + self._restore_file_value( + output_variable_name=name, + value=value, + ) + ) + case FileListInputConfig(): + if not isinstance(value, list): + msg = ( + "HumanInput file list input expects a list payload, " + f"output_variable_name={name}, got={type(value).__name__}" + ) + raise exc.InvalidSubmittedDataError(msg) + if not all(isinstance(item, Mapping) for item in value): + msg = ( + "HumanInput file list input expects list items to be " + "mapping payloads, " + f"output_variable_name={name}" + ) + raise exc.InvalidSubmittedDataError(msg) + restored_data[name] = build_segment([ + self._restore_file_value( + output_variable_name=name, + value=item, + ) + for item in value + ]) + case ParagraphInputConfig() | SelectInputConfig(): + if not isinstance(value, str): + msg = ( + "HumanInput file list input expects a string, " + f"output_variable_name={name}, got={type(value).__name__}" + ) + raise exc.InvalidSubmittedDataError(msg) + restored_data[name] = build_segment(value) + case _: + assert_never(form_input) + + return restored_data + + def _restore_file_value( + self, + *, + output_variable_name: str, + value: Any, + ) -> Any: + _ = output_variable_name + return self._file_reference_factory.build_from_mapping(mapping=value) + + @classmethod + def _build_special_outputs( + cls, + *, + action_id: str, + action_value: str, + rendered_content: str, + ) -> dict[str, Segment]: + return { + cls._OUTPUT_FIELD_ACTION_ID: build_segment(action_id), + cls._OUTPUT_FIELD_RENDERED_CONTENT: build_segment(rendered_content), + cls._OUTPUT_FIELD_ACTION_VALUE: build_segment(action_value), + } + @classmethod @override def _extract_variable_selector_to_variable_mapping( diff --git a/src/graphon/nodes/protocols.py b/src/graphon/nodes/protocols.py index f13e1ea..3557792 100644 --- a/src/graphon/nodes/protocols.py +++ b/src/graphon/nodes/protocols.py @@ -25,6 +25,10 @@ def get_file_generator_by_tool_file_id( class FileReferenceFactoryProtocol(Protocol): + """FileReferenceFactoryProtocol recreates File object from serialized JSON + format. It enforces approriate permission filtering for the file. + """ + def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File: ... diff --git a/src/graphon/nodes/runtime.py b/src/graphon/nodes/runtime.py index f82bd0b..2843e4a 100644 --- a/src/graphon/nodes/runtime.py +++ b/src/graphon/nodes/runtime.py @@ -1,5 +1,6 @@ from __future__ import annotations +import abc from collections.abc import Generator, Mapping, Sequence from datetime import datetime from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable @@ -62,12 +63,14 @@ def build_file_reference(self, *, mapping: Mapping[str, Any]) -> Any: ... class HumanInputNodeRuntimeProtocol(Protocol): """Workflow-layer adapter for human-input runtime persistence and delivery.""" + @abc.abstractmethod def get_form( self, *, node_id: str, ) -> HumanInputFormStateProtocol | None: ... + @abc.abstractmethod def create_form( self, *, diff --git a/src/graphon/runtime/graph_runtime_state_protocol.py b/src/graphon/runtime/graph_runtime_state_protocol.py index 98fb121..7b5c9ce 100644 --- a/src/graphon/runtime/graph_runtime_state_protocol.py +++ b/src/graphon/runtime/graph_runtime_state_protocol.py @@ -12,11 +12,7 @@ def get(self, selector: Sequence[str], /) -> Segment | None: """Get a variable value (read-only).""" ... - def get_all_by_node(self, node_id: str) -> Mapping[str, object]: - """Get all variables for a node (read-only).""" - ... - - def get_by_prefix(self, prefix: str) -> Mapping[str, object]: + def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: """Get all variables stored under a given node prefix (read-only).""" ... diff --git a/src/graphon/runtime/read_only_wrappers.py b/src/graphon/runtime/read_only_wrappers.py index dbc1ef2..822c3ce 100644 --- a/src/graphon/runtime/read_only_wrappers.py +++ b/src/graphon/runtime/read_only_wrappers.py @@ -7,10 +7,11 @@ from graphon.variables.segments import Segment from .graph_runtime_state import GraphRuntimeState +from .graph_runtime_state_protocol import ReadOnlyVariablePool from .variable_pool import VariablePool -class ReadOnlyVariablePoolWrapper: +class ReadOnlyVariablePoolWrapper(ReadOnlyVariablePool): """Provide defensive, read-only access to ``VariablePool``.""" def __init__(self, variable_pool: VariablePool) -> None: @@ -21,17 +22,7 @@ def get(self, selector: Sequence[str], /) -> Segment | None: value = self._variable_pool.get(selector) return deepcopy(value) if value is not None else None - def get_all_by_node(self, node_id: str) -> Mapping[str, object]: - """Return a copy of all variables for the specified node.""" - variables: dict[str, object] = {} - if node_id in self._variable_pool.variable_dictionary: - for key, variable in self._variable_pool.variable_dictionary[ - node_id - ].items(): - variables[key] = deepcopy(variable.value) - return variables - - def get_by_prefix(self, prefix: str) -> Mapping[str, object]: + def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: """Return a copy of all variables stored under the given prefix.""" return self._variable_pool.get_by_prefix(prefix) diff --git a/src/graphon/runtime/variable_pool.py b/src/graphon/runtime/variable_pool.py index 157f28b..6fe8103 100644 --- a/src/graphon/runtime/variable_pool.py +++ b/src/graphon/runtime/variable_pool.py @@ -4,7 +4,7 @@ from collections import defaultdict from collections.abc import Mapping, Sequence from copy import deepcopy -from typing import Annotated, Any, Self +from typing import TYPE_CHECKING, Annotated, Any, Self from pydantic import BaseModel, ConfigDict, Field, model_validator @@ -43,6 +43,9 @@ class VariablePool(BaseModel): # It's the first-level key in the dictionary. # Other elements of the selector are keys in the second-level dictionary. # To get the key, we hash the elements of the selector except the first one. + # + # The `variable_dictionary` is the source of truth for the runtime + # value of variables. variable_dictionary: defaultdict[ str, Annotated[dict[str, Variable], Field(default_factory=dict)], @@ -368,3 +371,12 @@ def flatten(self, *, unprefixed_node_id: str | None = None) -> Mapping[str, obje def empty(cls) -> VariablePool: """Create an empty variable pool.""" return cls() + + +if TYPE_CHECKING: + # static assertion to ensure VariablePool implements the + # ReadOnlyVariablePool. + from .graph_runtime_state_protocol import ReadOnlyVariablePool + + def _assert_readonly_variable_pool(pool: VariablePool) -> ReadOnlyVariablePool: # pyright: ignore[reportUnusedFunction] + return pool diff --git a/tests/entities/test_pause_reason.py b/tests/entities/test_pause_reason.py index 8e76f57..0245f51 100644 --- a/tests/entities/test_pause_reason.py +++ b/tests/entities/test_pause_reason.py @@ -8,6 +8,11 @@ PauseReason, SchedulingPause, ) +from graphon.nodes.human_input.entities import ParagraphInputConfig +from graphon.nodes.human_input.enums import ( + FormInputType, + ValueSourceType, +) class _Holder(BaseModel): @@ -80,3 +85,74 @@ def test_model_validate_with_invalid_type(self) -> None: def test_unknown_type_fails_validation(self) -> None: with pytest.raises(ValidationError): _Holder.model_validate({"reason": {"TYPE": "UNKNOWN"}}) + + def test_human_input_required_model_validate_accepts_current_form_input_payload( + self, + ) -> None: + + form_inputs_json = [ + { + "type": "paragraph", + "output_variable_name": "name", + "default": { + "type": "constant", + "selector": [], + "value": "Alice", + }, + }, + { + "type": "paragraph", + "output_variable_name": "bio", + "default": { + "type": "variable", + "selector": ["start", "bio"], + "value": "", + }, + }, + ] + + actions_json = [ + { + "id": "approve", + "title": "Approve", + "button_style": "primary", + } + ] + payload = { + "reason": { + "TYPE": "human_input_required", + "form_id": "form_id", + "form_content": "form_content", + "inputs": form_inputs_json, + "actions": actions_json, + "node_id": "node_id", + "node_title": "node_title", + "resolved_default_values": {"name": "Alice"}, + } + } + + restored = _Holder.model_validate(payload) + restored_reason = restored.reason + assert isinstance(restored_reason, HumanInputRequired) + assert restored_reason.form_id == "form_id" + assert restored_reason.form_content == "form_content" + assert restored_reason.node_id == "node_id" + assert restored_reason.node_title == "node_title" + assert len(restored_reason.inputs) == 2 + + assert isinstance(restored_reason.inputs[0], ParagraphInputConfig) + assert restored_reason.inputs[0].type == FormInputType.PARAGRAPH + assert restored_reason.inputs[0].output_variable_name == "name" + assert restored_reason.inputs[0].default is not None + assert restored_reason.inputs[0].default.type == ValueSourceType.CONSTANT + assert restored_reason.inputs[0].default.value == "Alice" + + assert isinstance(restored_reason.inputs[1], ParagraphInputConfig) + assert restored_reason.inputs[1].type == FormInputType.PARAGRAPH + assert restored_reason.inputs[1].default is not None + assert restored_reason.inputs[1].default.type == ValueSourceType.VARIABLE + assert restored_reason.inputs[1].default.selector == ["start", "bio"] + assert restored_reason.inputs[1].default.value == "" + assert [action.id for action in restored_reason.actions] == ["approve"] + assert restored_reason.actions[0].button_style.value == "primary" + assert restored_reason.resolved_default_values == {"name": "Alice"} diff --git a/tests/nodes/document_extractor/test_dispatch.py b/tests/nodes/document_extractor/test_dispatch.py index 9c80d20..89803c8 100644 --- a/tests/nodes/document_extractor/test_dispatch.py +++ b/tests/nodes/document_extractor/test_dispatch.py @@ -1,5 +1,3 @@ -# ruff: noqa: SLF001 - import json from unittest.mock import MagicMock diff --git a/tests/nodes/human_input/test_entities.py b/tests/nodes/human_input/test_entities.py new file mode 100644 index 0000000..ef9b45f --- /dev/null +++ b/tests/nodes/human_input/test_entities.py @@ -0,0 +1,247 @@ +from datetime import UTC, datetime +from typing import Any + +from pydantic import BaseModel + +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.human_input.entities import ( + FormDefinition, + HumanInputNodeData, + ParagraphInputConfig, + StringSource, + UserActionConfig, +) +from graphon.nodes.human_input.enums import ( + ButtonStyle, + FormInputType, + TimeoutUnit, + ValueSourceType, +) + +_FORM_INPUTS_JSON_PAYLOAD = [ + { + "type": "paragraph", + "output_variable_name": "name", + "default": { + "type": "constant", + "selector": [], + "value": "Alice", + }, + }, + { + "type": "paragraph", + "output_variable_name": "bio", + "default": { + "type": "variable", + "selector": ["start", "bio"], + "value": "", + }, + }, +] + +_USER_ACTIONS_JSON_PAYLOAD = [ + { + "id": "approve", + "title": "Approve", + "button_style": "primary", + }, + { + "id": "reject", + "title": "Reject", + "button_style": "ghost", + }, +] + + +class _FormInputHolder(BaseModel): + form_input: ParagraphInputConfig + + +class TestHumanInputNodeDataDeserialization: + def test_model_validate_accepts_current_form_input_payload(self) -> None: + payload: dict[str, Any] = { + "type": "human-input", + "title": "Collect Input", + "form_content": "Name: {{#$output.name#}}", + "inputs": _FORM_INPUTS_JSON_PAYLOAD, + "user_actions": _USER_ACTIONS_JSON_PAYLOAD, + "timeout": 3, + "timeout_unit": "day", + } + + restored = HumanInputNodeData.model_validate(payload) + + assert restored.type == "human-input" + assert restored.title == "Collect Input" + assert restored.form_content == "Name: {{#$output.name#}}" + assert len(restored.inputs) == 2 + assert isinstance(restored.inputs[0], ParagraphInputConfig) + assert restored.inputs[0].type == FormInputType.PARAGRAPH + assert restored.inputs[0].output_variable_name == "name" + assert restored.inputs[0].default is not None + assert restored.inputs[0].default.type == ValueSourceType.CONSTANT + assert restored.inputs[0].default.selector == [] + assert restored.inputs[0].default.value == "Alice" + + assert isinstance(restored.inputs[1], ParagraphInputConfig) + + assert restored.inputs[1].type == FormInputType.PARAGRAPH + assert restored.inputs[1].default is not None + assert restored.inputs[1].default.type == ValueSourceType.VARIABLE + assert restored.inputs[1].default.selector == ["start", "bio"] + assert [action.id for action in restored.user_actions] == ["approve", "reject"] + assert [action.button_style.value for action in restored.user_actions] == [ + "primary", + "ghost", + ] + assert restored.timeout == 3 + assert restored.timeout_unit == TimeoutUnit.DAY + + +class TestFormDefinitionDeserialization: + def test_model_validate_accepts_current_form_input_payload(self) -> None: + payload: dict[str, Any] = { + "form_content": "Name: {{#$output.name#}}", + "inputs": _FORM_INPUTS_JSON_PAYLOAD, + "user_actions": _USER_ACTIONS_JSON_PAYLOAD, + "rendered_content": "Name: Alice", + "expiration_time": "2026-04-19T12:00:00Z", + "default_values": {"bio": "Graph runtime"}, + "node_title": "Collect Input", + "display_in_ui": True, + } + + restored = FormDefinition.model_validate(payload) + + assert restored.form_content == "Name: {{#$output.name#}}" + assert restored.rendered_content == "Name: Alice" + assert len(restored.inputs) == 2 + + assert isinstance(restored.inputs[0], ParagraphInputConfig) + assert restored.inputs[0].type == FormInputType.PARAGRAPH + assert restored.inputs[0].default is not None + assert restored.inputs[0].default.type == ValueSourceType.CONSTANT + assert restored.inputs[0].default.value == "Alice" + + assert isinstance(restored.inputs[1], ParagraphInputConfig) + assert restored.inputs[1].type == FormInputType.PARAGRAPH + assert restored.inputs[1].default is not None + assert restored.inputs[1].default.selector == ["start", "bio"] + assert [action.id for action in restored.user_actions] == ["approve", "reject"] + assert restored.default_values == {"bio": "Graph runtime"} + assert restored.node_title == "Collect Input" + assert restored.display_in_ui is True + assert restored.expiration_time == datetime(2026, 4, 19, 12, 0, tzinfo=UTC) + + +class TestFormInputRoundTrip: + def test_paragraph_roundtrip_in_wrapper_model(self) -> None: + original = _FormInputHolder( + form_input=ParagraphInputConfig( + type=FormInputType.PARAGRAPH, + output_variable_name="bio", + default=StringSource( + type=ValueSourceType.VARIABLE, + selector=("start", "bio"), + ), + ) + ) + + payload = original.model_dump(mode="json") + restored = _FormInputHolder.model_validate(payload) + + assert payload == { + "form_input": { + "type": "paragraph", + "output_variable_name": "bio", + "default": { + "type": "variable", + "selector": ["start", "bio"], + "value": "", + }, + } + } + assert restored.form_input.type == FormInputType.PARAGRAPH + assert restored.form_input.output_variable_name == "bio" + assert restored.form_input.default is not None + assert restored.form_input.default.type == ValueSourceType.VARIABLE + assert restored.form_input.default.selector == ["start", "bio"] + assert restored.form_input.default.value == "" + + +class TestHumanInputNodeDataVariableSelectorMapping: + def test_extract_variable_mapping_preserves_current_paragraph_input_behavior( + self, + ) -> None: + node_data = HumanInputNodeData( + title="Collect Input", + form_content=( + "Profile: {{#start.user.name#}} " + "Query: {{#sys.query#}} " + "Output: {{#$output.answer#}}" + ), + inputs=[ + ParagraphInputConfig( + output_variable_name="notes", + ), + ParagraphInputConfig( + output_variable_name="summary", + default=StringSource( + type=ValueSourceType.CONSTANT, + value="Pinned summary", + ), + ), + ParagraphInputConfig( + output_variable_name="bio", + default=StringSource( + type=ValueSourceType.VARIABLE, + selector=("input", "profile", "bio"), + ), + ), + ], + ) + + mapping = node_data.extract_variable_selector_to_variable_mapping("human-node") + + assert mapping == { + "human-node.#start.user#": ["start", "user"], + "human-node.#sys.query#": ["sys", "query"], + "human-node.#input.profile.bio#": ("input", "profile", "bio"), + } + + def test_extract_variable_mapping_ignores_short_template_selectors( + self, + monkeypatch: Any, + ) -> None: + def _extract_short_selector(_self: Any) -> list[VariableSelector]: + return [ + VariableSelector( + variable="#start#", + value_selector=["start"], + ) + ] + + monkeypatch.setattr( + "graphon.nodes.human_input.entities.VariableTemplateParser.extract_variable_selectors", + _extract_short_selector, + ) + + node_data = HumanInputNodeData( + title="Collect Input", + form_content="ignored", + inputs=[], + ) + + mapping = node_data.extract_variable_selector_to_variable_mapping("human-node") + + assert mapping == {} + + +def test_user_action_title_accepts_long_business_value() -> None: + action = UserActionConfig( + id="approve", + title="card_visa_enterprise_001_long_value", + button_style=ButtonStyle.DEFAULT, + ) + + assert action.title == "card_visa_enterprise_001_long_value" diff --git a/tests/nodes/human_input/test_enums.py b/tests/nodes/human_input/test_enums.py new file mode 100644 index 0000000..20ad9fb --- /dev/null +++ b/tests/nodes/human_input/test_enums.py @@ -0,0 +1,11 @@ +import pytest + +from graphon.nodes.human_input.enums import FormInputType +from graphon.variables.input_entities import VariableEntityType + + +@pytest.mark.parametrize("form_input_type", FormInputType, ids=lambda item: item.name) +def test_form_input_type_members_exist_in_variable_entity_type( + form_input_type: type[FormInputType], +) -> None: + assert form_input_type.name in VariableEntityType.__members__ diff --git a/tests/nodes/human_input/test_human_input_node.py b/tests/nodes/human_input/test_human_input_node.py index 9661557..a437d5b 100644 --- a/tests/nodes/human_input/test_human_input_node.py +++ b/tests/nodes/human_input/test_human_input_node.py @@ -1,45 +1,52 @@ -from __future__ import annotations - from collections.abc import Mapping -from dataclasses import dataclass from datetime import UTC, datetime, timedelta -from time import time +from time import perf_counter from typing import Any -from graphon.entities.graph_init_params import GraphInitParams -from graphon.node_events.node import HumanInputFormTimeoutEvent, StreamCompletedEvent -from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.file import File, FileTransferMethod, FileType +from graphon.graph_events.node import ( + NodeRunHumanInputFormFilledEvent, + NodeRunSucceededEvent, +) +from graphon.node_events import HumanInputFormTimeoutEvent, StreamCompletedEvent +from graphon.nodes.human_input.entities import ( + FileInputConfig, + FileListInputConfig, + FormInputConfig, + HumanInputNodeData, + ParagraphInputConfig, + StringSource, + UserActionConfig, +) from graphon.nodes.human_input.enums import ( - ButtonStyle, HumanInputFormStatus, - TimeoutUnit, + ValueSourceType, ) from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.protocols import FileReferenceFactoryProtocol +from graphon.nodes.runtime import ( + HumanInputFormStateProtocol, + HumanInputNodeRuntimeProtocol, +) from graphon.runtime.graph_runtime_state import GraphRuntimeState -from graphon.runtime.variable_pool import VariablePool - - -@dataclass -class _FakeForm: - id: str - rendered_content: str - expiration_time: datetime - status: HumanInputFormStatus - selected_action_id: str | None = None - submitted_data: dict[str, str] | None = None - - @property - def submitted(self) -> bool: - return self.status == HumanInputFormStatus.SUBMITTED +from graphon.variables.segments import ( + ArrayFileSegment, + FileSegment, + Segment, + StringSegment, +) +from ...helpers import build_graph_init_params, build_variable_pool -class _FakeHumanInputRuntime: - def __init__(self, form: _FakeForm) -> None: - self._form = form - def get_form(self, *, node_id: str) -> _FakeForm | None: - assert node_id == "human_input_node" - return self._form +class _RuntimeStub(HumanInputNodeRuntimeProtocol): + def get_form( + self, + *, + node_id: str, + ) -> HumanInputFormStateProtocol | None: + _ = node_id + return None def create_form( self, @@ -48,97 +55,482 @@ def create_form( node_data: HumanInputNodeData, rendered_content: str, resolved_default_values: Mapping[str, Any], - ) -> _FakeForm: - _ = (node_id, node_data, rendered_content, resolved_default_values) - msg = "create_form should not be called in these tests" + ) -> HumanInputFormStateProtocol: + _ = node_id, node_data, rendered_content, resolved_default_values + msg = "create_form should not be called in resolve_default_values tests" raise AssertionError(msg) -def _build_graph_init_params() -> GraphInitParams: - return GraphInitParams( - workflow_id="workflow", - graph_config={}, - run_context={}, - call_depth=0, +class _FileReferenceFactory(FileReferenceFactoryProtocol): + def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File: + return File( + file_id=mapping.get("id"), + file_type=FileType(mapping["type"]), + transfer_method=FileTransferMethod(mapping["transfer_method"]), + remote_url=mapping.get("remote_url"), + related_id=mapping.get("related_id"), + filename=mapping.get("filename"), + extension=mapping.get("extension"), + mime_type=mapping.get("mime_type"), + size=mapping.get("size", -1), + ) + + +def _build_node( + *, + inputs: list[FormInputConfig], + variables: tuple[tuple[tuple[str, ...], Any], ...] = (), +) -> HumanInputNode: + runtime_state = GraphRuntimeState( + variable_pool=build_variable_pool(variables=variables), + start_at=perf_counter(), + ) + return HumanInputNode( + node_id="human-node", + data=HumanInputNodeData( + title="Collect Input", + form_content="Profile", + inputs=inputs, + ), + graph_init_params=build_graph_init_params( + graph_config={"nodes": [], "edges": []}, + ), + graph_runtime_state=runtime_state, + runtime=_RuntimeStub(), + file_reference_factory=_FileReferenceFactory(), ) -def _build_node(*, form: _FakeForm) -> HumanInputNode: - node_data = HumanInputNodeData( - title="Approval", - type="human-input", - form_content="Selected ticket: {{#$output.ticket#}}", - user_actions=[ - UserAction( - id="approve", - title="card_visa_enterprise_001_long_value", - button_style=ButtonStyle.DEFAULT, +class TestHumanInputNodeResolveDefaultValues: + def test_resolve_default_values_skips_absent_constant_and_missing_defaults( + self, + ) -> None: + node = _build_node( + inputs=[ + ParagraphInputConfig(output_variable_name="without_default"), + ParagraphInputConfig( + output_variable_name="constant_default", + default=StringSource( + type=ValueSourceType.CONSTANT, + value="Pinned text", + ), + ), + ParagraphInputConfig( + output_variable_name="missing_default", + default=StringSource( + type=ValueSourceType.VARIABLE, + selector=("start", "missing"), + ), + ), + ParagraphInputConfig( + output_variable_name="resolved_default", + default=StringSource( + type=ValueSourceType.VARIABLE, + selector=("start", "profile"), + ), + ), + ], + variables=( + ( + ("start", "profile"), + { + "headline": "Graph runtime", + "tags": ["human-input", 3], + }, + ), ), - ], - timeout=3, - timeout_unit=TimeoutUnit.DAY, + ) + + resolved = node.resolve_default_values() + + assert resolved == { + "resolved_default": { + "headline": "Graph runtime", + "tags": ["human-input", 3], + } + } + + +class _SubmittedFormStub(HumanInputFormStateProtocol): + @property + def id(self) -> str: + return "form-1" + + @property + def rendered_content(self) -> str: + return "Attachment submitted" + + @property + def selected_action_id(self) -> str | None: + return "approve" + + @property + def submitted_data(self) -> Mapping[str, Any] | None: + return { + "attachment": { + "id": "file-1", + "type": FileType.DOCUMENT, + "transfer_method": FileTransferMethod.LOCAL_FILE, + "related_id": "upload-1", + "filename": "resume.pdf", + "extension": ".pdf", + "mime_type": "application/pdf", + "size": 128, + }, + "attachments": [ + { + "id": "file-2", + "type": FileType.DOCUMENT, + "transfer_method": FileTransferMethod.LOCAL_FILE, + "related_id": "upload-2", + "filename": "a.pdf", + "extension": ".pdf", + "mime_type": "application/pdf", + "size": 64, + }, + ], + } + + @property + def submitted(self) -> bool: + return True + + @property + def status(self) -> HumanInputFormStatus: + return HumanInputFormStatus.SUBMITTED + + @property + def expiration_time(self) -> datetime: + return (datetime.now(UTC) + timedelta(hours=1)).replace(tzinfo=None) + + +class _ResumeRuntimeStub(_RuntimeStub): + def get_form( + self, + *, + node_id: str, + ) -> HumanInputFormStateProtocol | None: + _ = node_id + return _SubmittedFormStub() + + +class _SubmittedTextFormStub(HumanInputFormStateProtocol): + @property + def id(self) -> str: + return "form-2" + + @property + def rendered_content(self) -> str: + return "Name: {{#$output.name#}}" + + @property + def selected_action_id(self) -> str | None: + return "approve" + + @property + def submitted_data(self) -> Mapping[str, Any] | None: + return { + "name": "Alice", + "unexpected": "discard from event", + } + + @property + def submitted(self) -> bool: + return True + + @property + def status(self) -> HumanInputFormStatus: + return HumanInputFormStatus.SUBMITTED + + @property + def expiration_time(self) -> datetime: + return (datetime.now(UTC) + timedelta(hours=1)).replace(tzinfo=None) + + +class _ResumeTextRuntimeStub(_RuntimeStub): + def get_form( + self, + *, + node_id: str, + ) -> HumanInputFormStateProtocol | None: + _ = node_id + return _SubmittedTextFormStub() + + +class _TimedOutFormStub(HumanInputFormStateProtocol): + @property + def id(self) -> str: + return "form-timeout" + + @property + def rendered_content(self) -> str: + return "Timed out content" + + @property + def selected_action_id(self) -> str | None: + return None + + @property + def submitted_data(self) -> Mapping[str, Any] | None: + return None + + @property + def submitted(self) -> bool: + return False + + @property + def status(self) -> HumanInputFormStatus: + return HumanInputFormStatus.TIMEOUT + + @property + def expiration_time(self) -> datetime: + return (datetime.now(UTC) - timedelta(hours=1)).replace(tzinfo=None) + + +class _TimeoutRuntimeStub(_RuntimeStub): + def get_form( + self, + *, + node_id: str, + ) -> HumanInputFormStateProtocol | None: + _ = node_id + return _TimedOutFormStub() + + +def test_human_input_resume_emits_runtime_file_segments() -> None: + runtime_state = GraphRuntimeState( + variable_pool=build_variable_pool(variables=()), + start_at=perf_counter(), ) - return HumanInputNode( - node_id="human_input_node", - data=node_data, - graph_init_params=_build_graph_init_params(), - graph_runtime_state=GraphRuntimeState( - variable_pool=VariablePool(), - start_at=time(), + node = HumanInputNode( + node_id="human-node", + data=HumanInputNodeData( + title="Collect Input", + form_content="Attachment submitted", + inputs=[ + FileInputConfig(output_variable_name="attachment"), + FileListInputConfig( + output_variable_name="attachments", + number_limits=1, + ), + ], + user_actions=[UserActionConfig(id="approve", title="Approve")], + ), + graph_init_params=build_graph_init_params( + graph_config={"nodes": [], "edges": []}, + ), + graph_runtime_state=runtime_state, + runtime=_ResumeRuntimeStub(), + file_reference_factory=_FileReferenceFactory(), + ) + + events = list(node.run()) + filled_event = next( + event for event in events if isinstance(event, NodeRunHumanInputFormFilledEvent) + ) + result = events[-1] + + assert isinstance(filled_event.submitted_data["attachment"], FileSegment) + assert isinstance(filled_event.submitted_data["attachments"], ArrayFileSegment) + assert isinstance(result, NodeRunSucceededEvent) + assert all( + isinstance(value, Segment) for value in result.node_run_result.outputs.values() + ) + assert isinstance(result.node_run_result.outputs["attachment"], FileSegment) + assert isinstance(result.node_run_result.outputs["attachments"], ArrayFileSegment) + assert isinstance(result.node_run_result.outputs["__action_id"], StringSegment) + assert isinstance( + result.node_run_result.outputs["__rendered_content"], + StringSegment, + ) + + +def test_human_input_resume_filters_unknown_fields_from_outputs() -> None: + runtime_state = GraphRuntimeState( + variable_pool=build_variable_pool(variables=()), + start_at=perf_counter(), + ) + node = HumanInputNode( + node_id="human-node", + data=HumanInputNodeData( + title="Collect Input", + form_content="Name: {{#$output.name#}}", + inputs=[ + ParagraphInputConfig(output_variable_name="name"), + ], + user_actions=[UserActionConfig(id="approve", title="Approve")], + ), + graph_init_params=build_graph_init_params( + graph_config={"nodes": [], "edges": []}, + ), + graph_runtime_state=runtime_state, + runtime=_ResumeTextRuntimeStub(), + file_reference_factory=_FileReferenceFactory(), + ) + + events = list(node.run()) + filled_event = next( + event for event in events if isinstance(event, NodeRunHumanInputFormFilledEvent) + ) + result = events[-1] + + assert isinstance(result, NodeRunSucceededEvent) + assert all( + isinstance(value, Segment) for value in result.node_run_result.outputs.values() + ) + assert set(result.node_run_result.outputs) == { + "name", + "__action_id", + "__action_value", + "__rendered_content", + } + assert isinstance(result.node_run_result.outputs["name"], StringSegment) + + assert set(filled_event.submitted_data) == {"name"} + assert filled_event.submitted_data["name"] == result.node_run_result.outputs["name"] + assert filled_event.rendered_content == "Name: Alice" + + +def test_human_input_resume_adds_special_outputs_separately() -> None: + runtime_state = GraphRuntimeState( + variable_pool=build_variable_pool(variables=()), + start_at=perf_counter(), + ) + node = HumanInputNode( + node_id="human-node", + data=HumanInputNodeData( + title="Collect Input", + form_content="Name: {{#$output.name#}}", + inputs=[ + ParagraphInputConfig(output_variable_name="name"), + ], + user_actions=[UserActionConfig(id="approve", title="Approve")], ), - runtime=_FakeHumanInputRuntime(form), + graph_init_params=build_graph_init_params( + graph_config={"nodes": [], "edges": []}, + ), + graph_runtime_state=runtime_state, + runtime=_ResumeTextRuntimeStub(), + file_reference_factory=_FileReferenceFactory(), ) + events = list(node.run()) + result = events[-1] -def _run_node_events(form: _FakeForm) -> list[object]: - return list(_build_node(form=form)._run()) # noqa: SLF001 + assert isinstance(result, NodeRunSucceededEvent) + assert result.node_run_result.outputs["__action_id"] == StringSegment( + value="approve", + ) + assert result.node_run_result.outputs["__rendered_content"] == StringSegment( + value="Name: Alice", + ) -def test_user_action_title_accepts_long_business_value() -> None: - action = UserAction( - id="approve", - title="card_visa_enterprise_001_long_value", - button_style=ButtonStyle.DEFAULT, +def test_human_input_timeout_adds_special_outputs_separately() -> None: + runtime_state = GraphRuntimeState( + variable_pool=build_variable_pool(variables=()), + start_at=perf_counter(), + ) + node = HumanInputNode( + node_id="human-node", + data=HumanInputNodeData( + title="Collect Input", + form_content="Name: {{#$output.name#}}", + inputs=[ + ParagraphInputConfig(output_variable_name="name"), + ], + user_actions=[UserActionConfig(id="approve", title="Approve")], + ), + graph_init_params=build_graph_init_params( + graph_config={"nodes": [], "edges": []}, + ), + graph_runtime_state=runtime_state, + runtime=_TimeoutRuntimeStub(), + file_reference_factory=_FileReferenceFactory(), ) - assert action.title == "card_visa_enterprise_001_long_value" + events = list(node.run()) + result = events[-1] + + assert isinstance(result, NodeRunSucceededEvent) + assert result.node_run_result.outputs == { + "__action_id": StringSegment(value=""), + "__action_value": StringSegment(value=""), + "__rendered_content": StringSegment(value="Timed out content"), + } def test_human_input_submission_emits_action_value_outputs() -> None: - form = _FakeForm( - id="form-1", - rendered_content="Selected ticket: {{#$output.ticket#}}", - expiration_time=datetime.now(UTC).replace(tzinfo=None) + timedelta(hours=1), - status=HumanInputFormStatus.SUBMITTED, - selected_action_id="approve", - submitted_data={"ticket": "TICKET-1"}, + + runtime_state = GraphRuntimeState( + variable_pool=build_variable_pool(variables=()), + start_at=perf_counter(), + ) + node = HumanInputNode( + node_id="human-node", + data=HumanInputNodeData( + title="Collect Input", + form_content="Name: {{#$output.name#}}", + inputs=[ + ParagraphInputConfig(output_variable_name="name"), + ], + user_actions=[UserActionConfig(id="approve", title="Approve")], + ), + graph_init_params=build_graph_init_params( + graph_config={"nodes": [], "edges": []}, + ), + graph_runtime_state=runtime_state, + runtime=_ResumeRuntimeStub(), + file_reference_factory=_FileReferenceFactory(), ) - events = _run_node_events(form) + events = list(node._run()) completed = next( event for event in events if isinstance(event, StreamCompletedEvent) ) - assert completed.node_run_result.outputs["__action_id"] == "approve" - assert ( - completed.node_run_result.outputs["__action_value"] - == "card_visa_enterprise_001_long_value" + assert completed.node_run_result.outputs["__action_id"] == StringSegment( + value="approve" + ) + assert completed.node_run_result.outputs["__action_value"] == StringSegment( + value="Approve" ) def test_human_input_timeout_emits_empty_action_value() -> None: - form = _FakeForm( - id="form-2", - rendered_content="Selected ticket: {{#$output.ticket#}}", - expiration_time=datetime.now(UTC).replace(tzinfo=None) - timedelta(minutes=1), - status=HumanInputFormStatus.TIMEOUT, - ) - events = _run_node_events(form) + runtime_state = GraphRuntimeState( + variable_pool=build_variable_pool(variables=()), + start_at=perf_counter(), + ) + node = HumanInputNode( + node_id="human-node", + data=HumanInputNodeData( + title="Collect Input", + form_content="Name: {{#$output.name#}}", + inputs=[ + ParagraphInputConfig(output_variable_name="name"), + ], + user_actions=[ + UserActionConfig( + id="approve", title="card_visa_enterprise_001_long_value" + ) + ], + ), + graph_init_params=build_graph_init_params( + graph_config={"nodes": [], "edges": []}, + ), + graph_runtime_state=runtime_state, + runtime=_TimeoutRuntimeStub(), + file_reference_factory=_FileReferenceFactory(), + ) + events = list(node._run()) assert any(isinstance(event, HumanInputFormTimeoutEvent) for event in events) completed = next( event for event in events if isinstance(event, StreamCompletedEvent) ) - assert not completed.node_run_result.outputs["__action_id"] - assert not completed.node_run_result.outputs["__action_value"] + assert completed.node_run_result.outputs["__action_id"] == StringSegment(value="") + assert completed.node_run_result.outputs["__action_value"] == StringSegment( + value="" + ) diff --git a/tests/nodes/human_input/test_human_input_node_internal.py b/tests/nodes/human_input/test_human_input_node_internal.py new file mode 100644 index 0000000..9b6a2fd --- /dev/null +++ b/tests/nodes/human_input/test_human_input_node_internal.py @@ -0,0 +1,242 @@ +from __future__ import annotations + +from collections.abc import Mapping +from time import perf_counter +from typing import Any + +import pytest + +from graphon.file import File, FileTransferMethod, FileType +from graphon.nodes.human_input import _exc as exc +from graphon.nodes.human_input.entities import ( + FileInputConfig, + FileListInputConfig, + FormInputConfig, + HumanInputNodeData, + ParagraphInputConfig, + SelectInputConfig, + StringListSource, +) +from graphon.nodes.human_input.enums import ValueSourceType +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.protocols import FileReferenceFactoryProtocol +from graphon.nodes.runtime import ( + HumanInputFormStateProtocol, + HumanInputNodeRuntimeProtocol, +) +from graphon.runtime.graph_runtime_state import GraphRuntimeState +from graphon.variables.segments import ArrayFileSegment, FileSegment, StringSegment + +from ...helpers import build_graph_init_params, build_variable_pool + + +class _RuntimeStub(HumanInputNodeRuntimeProtocol): + def get_form(self, *, node_id: str) -> HumanInputFormStateProtocol | None: + _ = node_id + msg = "not used in internal tests" + raise AssertionError(msg) + + def create_form( + self, + *, + node_id: str, + node_data: HumanInputNodeData, + rendered_content: str, + resolved_default_values: Mapping[str, Any], + ) -> HumanInputFormStateProtocol: + _ = (node_id, node_data, rendered_content, resolved_default_values) + msg = "not used in internal tests" + raise AssertionError(msg) + + +class _FileReferenceFactory(FileReferenceFactoryProtocol): + def __init__(self) -> None: + self.mappings: list[Mapping[str, Any]] = [] + + def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File: + self.mappings.append(mapping) + return File( + file_id=mapping.get("id"), + file_type=FileType(mapping["type"]), + transfer_method=FileTransferMethod(mapping["transfer_method"]), + remote_url=mapping.get("remote_url"), + related_id=mapping.get("related_id"), + filename=mapping.get("filename"), + extension=mapping.get("extension"), + mime_type=mapping.get("mime_type"), + size=mapping.get("size", -1), + ) + + +def _build_node( + *, + file_reference_factory: FileReferenceFactoryProtocol, + inputs: list[FormInputConfig] | None = None, +) -> HumanInputNode: + if inputs is None: + inputs = [ + FileInputConfig(output_variable_name="attachment"), + FileListInputConfig( + output_variable_name="attachments", + number_limits=2, + ), + ] + + return HumanInputNode( + node_id="human-node", + data=HumanInputNodeData( + title="Collect Input", + inputs=inputs, + ), + graph_init_params=build_graph_init_params( + graph_config={"nodes": [], "edges": []}, + ), + graph_runtime_state=GraphRuntimeState( + variable_pool=build_variable_pool(), + start_at=perf_counter(), + ), + runtime=_RuntimeStub(), + file_reference_factory=file_reference_factory, + ) + + +def test_restore_submitted_data_builds_segments_from_submitted_values() -> None: + factory = _FileReferenceFactory() + node = _build_node( + file_reference_factory=factory, + inputs=[ + FileInputConfig(output_variable_name="attachment"), + FileListInputConfig( + output_variable_name="attachments", + number_limits=2, + ), + ParagraphInputConfig(output_variable_name="name"), + SelectInputConfig( + output_variable_name="choice", + option_source=StringListSource( + type=ValueSourceType.CONSTANT, + value=["yes", "no"], + ), + ), + ], + ) + + restored = node._restore_submitted_data( + submitted_data={ + "attachment": { + "type": FileType.DOCUMENT, + "transfer_method": FileTransferMethod.LOCAL_FILE, + "related_id": "upload-1", + "filename": "resume.pdf", + "extension": ".pdf", + "mime_type": "application/pdf", + "size": 128, + }, + "attachments": [ + { + "type": FileType.DOCUMENT, + "transfer_method": FileTransferMethod.LOCAL_FILE, + "related_id": "upload-2", + "filename": "cover.pdf", + "extension": ".pdf", + "mime_type": "application/pdf", + "size": 64, + } + ], + "name": "Alice", + "choice": "yes", + "unexpected": "keep as segment", + }, + ) + + assert isinstance(restored["attachment"], FileSegment) + assert restored["attachment"].value.related_id == "upload-1" + assert isinstance(restored["attachments"], ArrayFileSegment) + assert len(restored["attachments"].value) == 1 + assert restored["attachments"].value[0].related_id == "upload-2" + assert restored["name"] == StringSegment(value="Alice") + assert restored["choice"] == StringSegment(value="yes") + assert restored["unexpected"] == StringSegment(value="keep as segment") + assert len(factory.mappings) == 2 + + +def test_restore_submitted_data_rejects_non_mapping_file_payload() -> None: + node = _build_node(file_reference_factory=_FileReferenceFactory()) + + with pytest.raises( + exc.InvalidSubmittedDataError, match="expects a mapping payload" + ): + node._restore_submitted_data( + submitted_data={ + "attachment": "upload-1", + }, + ) + + +def test_restore_submitted_data_rejects_non_list_file_list_payload() -> None: + node = _build_node(file_reference_factory=_FileReferenceFactory()) + + with pytest.raises(exc.InvalidSubmittedDataError, match="expects a list payload"): + node._restore_submitted_data( + submitted_data={ + "attachments": { + "type": FileType.DOCUMENT, + "transfer_method": FileTransferMethod.LOCAL_FILE, + "related_id": "upload-2", + }, + }, + ) + + +def test_restore_submitted_data_rejects_non_mapping_file_list_items() -> None: + node = _build_node(file_reference_factory=_FileReferenceFactory()) + + with pytest.raises( + exc.InvalidSubmittedDataError, + match="expects list items to be mapping payloads", + ): + node._restore_submitted_data( + submitted_data={ + "attachments": [ + { + "type": FileType.DOCUMENT, + "transfer_method": FileTransferMethod.LOCAL_FILE, + "related_id": "upload-2", + }, + "upload-3", + ], + }, + ) + + +@pytest.mark.parametrize( + ("field_name", "field_config"), + [ + ("name", ParagraphInputConfig(output_variable_name="name")), + ( + "choice", + SelectInputConfig( + output_variable_name="choice", + option_source=StringListSource( + type=ValueSourceType.CONSTANT, + value=["yes", "no"], + ), + ), + ), + ], +) +def test_restore_submitted_data_rejects_non_string_text_payload( + field_name: str, + field_config: FormInputConfig, +) -> None: + node = _build_node( + file_reference_factory=_FileReferenceFactory(), + inputs=[field_config], + ) + + with pytest.raises(exc.InvalidSubmittedDataError, match="expects a string"): + node._restore_submitted_data( + submitted_data={ + field_name: 123, + }, + ) diff --git a/tests/nodes/llm/test_node.py b/tests/nodes/llm/test_node.py index 1ccbca4..34c40cb 100644 --- a/tests/nodes/llm/test_node.py +++ b/tests/nodes/llm/test_node.py @@ -69,7 +69,7 @@ def test_run_emits_model_identity_in_node_result_inputs( ]), ) - events = list(node._run()) # noqa: SLF001 + events = list(node._run()) completed_event = next( event for event in events if isinstance(event, StreamCompletedEvent) ) diff --git a/tests/nodes/parameter_extractor/test_prompts.py b/tests/nodes/parameter_extractor/test_prompts.py index d3ee731..c19832e 100644 --- a/tests/nodes/parameter_extractor/test_prompts.py +++ b/tests/nodes/parameter_extractor/test_prompts.py @@ -176,7 +176,7 @@ def test_prepare_run_context_exposes_model_identity_in_inputs( ) monkeypatch.setattr(node, "_build_run_prompt", lambda **_: ([], [])) - run_context = node._prepare_run_context() # noqa: SLF001 + run_context = node._prepare_run_context() assert run_context.inputs["query"] == "weather in sf" assert run_context.inputs["model_provider"] == "test" @@ -210,7 +210,7 @@ def test_parameter_extractor_run_emits_model_identity_in_inputs( ) monkeypatch.setattr(node.model_instance, "invoke_llm", lambda **_: invoke_result) - result = node._run() # noqa: SLF001 + result = node._run() assert result.inputs["query"] == "weather in sf" assert result.inputs["model_provider"] == "test" diff --git a/tests/nodes/question_classifier/test_question_classifier_node.py b/tests/nodes/question_classifier/test_question_classifier_node.py index cd44630..3ab8333 100644 --- a/tests/nodes/question_classifier/test_question_classifier_node.py +++ b/tests/nodes/question_classifier/test_question_classifier_node.py @@ -173,7 +173,7 @@ def test_question_classifier_constructor_accepts_dependency_bundle( ]), ) - result = node._run() # noqa: SLF001 + result = node._run() assert result.process_data["prompts"] == ["serialized prompt"] @@ -295,7 +295,7 @@ def test_question_classifier_run_returns_custom_class_label( ]), ) - result = node._run() # noqa: SLF001 + result = node._run() assert result.outputs["class_name"] == "Questions about refunds" assert result.outputs["class_label"] == "Refund desk" @@ -367,7 +367,7 @@ def test_question_classifier_run_falls_back_to_canonical_class_label( ]), ) - result = node._run() # noqa: SLF001 + result = node._run() assert result.outputs["class_name"] == "Questions about refunds" assert result.outputs["class_label"] == "CLASS 2" diff --git a/tests/nodes/test_human_input_runtime_binding.py b/tests/nodes/test_human_input_runtime_binding.py index 5ae3870..70309f5 100644 --- a/tests/nodes/test_human_input_runtime_binding.py +++ b/tests/nodes/test_human_input_runtime_binding.py @@ -7,9 +7,11 @@ import pytest +from graphon.file import File, FileTransferMethod, FileType from graphon.nodes.human_input.entities import HumanInputNodeData from graphon.nodes.human_input.enums import HumanInputFormStatus from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.protocols import FileReferenceFactoryProtocol from graphon.nodes.runtime import ( HumanInputFormRepositoryBindableRuntimeProtocol, HumanInputFormStateProtocol, @@ -20,7 +22,7 @@ from ..helpers import build_graph_init_params, build_variable_pool -class _StubHumanInputRuntime: +class _StubHumanInputRuntime(HumanInputNodeRuntimeProtocol): def get_form(self, *, node_id: str) -> HumanInputFormStateProtocol | None: _ = node_id msg = "not used in this test" @@ -74,6 +76,21 @@ def expiration_time(self) -> datetime: return self._expiration_time +class _FileReferenceFactory(FileReferenceFactoryProtocol): + def build_from_mapping(self, *, mapping: Mapping[str, Any]) -> File: + return File( + file_id=mapping.get("id"), + file_type=FileType(mapping["type"]), + transfer_method=FileTransferMethod(mapping["transfer_method"]), + remote_url=mapping.get("remote_url"), + related_id=mapping.get("related_id"), + filename=mapping.get("filename"), + extension=mapping.get("extension"), + mime_type=mapping.get("mime_type"), + size=mapping.get("size", -1), + ) + + class _RunnableHumanInputRuntime(_StubHumanInputRuntime): def __init__(self) -> None: self.get_form_calls: list[str] = [] @@ -102,7 +119,9 @@ def __init__( *, bound_runtime: HumanInputNodeRuntimeProtocol | None = None, ) -> None: - self.bound_runtime = bound_runtime or _StubHumanInputRuntime() + self.bound_runtime: HumanInputNodeRuntimeProtocol = ( + bound_runtime or _StubHumanInputRuntime() + ) self.bound_repositories: list[object] = [] def with_form_repository( @@ -153,6 +172,7 @@ def _build_human_input_node( start_at=perf_counter(), ), runtime=runtime, + file_reference_factory=_FileReferenceFactory(), form_repository=form_repository, ) diff --git a/tests/runtime/test_graph_runtime_state.py b/tests/runtime/test_graph_runtime_state.py index 29c02c2..6cebe5b 100644 --- a/tests/runtime/test_graph_runtime_state.py +++ b/tests/runtime/test_graph_runtime_state.py @@ -4,12 +4,14 @@ import pytest +from graphon.file import File, FileTransferMethod, FileType from graphon.graph_engine.domain.graph_execution import GraphExecution from graphon.graph_engine.ready_queue.in_memory import InMemoryReadyQueue from graphon.model_runtime.entities.llm_entities import LLMUsage from graphon.runtime.graph_runtime_state import GraphRuntimeState from graphon.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper from graphon.runtime.variable_pool import VariablePool +from graphon.variables.segments import ArrayFileSegment, FileSegment from graphon.variables.variables import StringVariable CONVERSATION_VARIABLE_NODE_ID = "conversation" @@ -395,6 +397,32 @@ def test_snapshot_restore_preserves_updated_conversation_variable(self) -> None: assert restored_value is not None assert restored_value.value == "after" + def test_snapshot_restore_preserves_file_segments(self) -> None: + variable_pool = VariablePool() + file_value = File( + file_id="file-1", + file_type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/resume.pdf", + filename="resume.pdf", + extension=".pdf", + mime_type="application/pdf", + size=128, + ) + variable_pool.add(("node", "attachment"), FileSegment(value=file_value)) + variable_pool.add(("node", "attachments"), ArrayFileSegment(value=[file_value])) + + state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) + + restored = GraphRuntimeState.from_snapshot(state.dumps()) + + restored_file = restored.variable_pool.get(("node", "attachment")) + restored_files = restored.variable_pool.get(("node", "attachments")) + assert isinstance(restored_file, FileSegment) + assert restored_file.value.filename == "resume.pdf" + assert isinstance(restored_files, ArrayFileSegment) + assert restored_files.value[0].filename == "resume.pdf" + def test_snapshot_restore_preserves_file_variable_id(self) -> None: restored = GraphRuntimeState.from_snapshot( _HISTORICAL_FILE_SNAPSHOT_JSON_FROM_749751D_PARENT,