From 2e4a5a51d39caa532b8bf0b874ce56e7905784d8 Mon Sep 17 00:00:00 2001 From: Alex <119050107+ofSingularMind@users.noreply.github.com> Date: Sun, 19 Apr 2026 14:03:27 +0200 Subject: [PATCH 1/4] feat: add multioutput_submodels. First draft. --- MULTI_AND_ZERO_OUTPUT_PLAN.md | 379 ++++++++++++++++++++++++++++ src/backends/default.jl | 1 + src/graph_engine.jl | 57 +++++ src/model_macro.jl | 126 ++++++++- test/multi_and_zero_output_tests.jl | 233 +++++++++++++++++ 5 files changed, 784 insertions(+), 12 deletions(-) create mode 100644 MULTI_AND_ZERO_OUTPUT_PLAN.md create mode 100644 test/multi_and_zero_output_tests.jl diff --git a/MULTI_AND_ZERO_OUTPUT_PLAN.md b/MULTI_AND_ZERO_OUTPUT_PLAN.md new file mode 100644 index 00000000..a1784f17 --- /dev/null +++ b/MULTI_AND_ZERO_OUTPUT_PLAN.md @@ -0,0 +1,379 @@ +# Implementation Plan: Multi-Output and Zero-Output Submodel Interfaces + +## Overview + +Two new features for GraphPPL.jl's nested model specification: + +**(a) Multi-output LHS** — Allow multiple return interfaces on the left-hand side of `~`: +```julia +(a, b) ~ my_submodel(c = c_val, d = d_val) +``` +This means 2 interfaces are "missing" from the RHS and are assigned to `a` and `b` in declaration order. + +**(b) Zero-output (bare call)** — Allow submodel invocation without `~`: +```julia +my_submodel(x = x_val, y = y_val, z = z_val) +``` +All interfaces are explicitly provided on the RHS; there is no LHS variable. + +--- + +## Current Architecture (Summary) + +The `@model` macro body is transformed through a pipeline of expression-rewriting functions (defined in `src/backends/default.jl`): + +1. `check_reserved_variable_names_model` +2. `save_expression_in_tilde` — injects `created_by` option +3. `convert_deterministic_statement` — `:=` → `~` with `is_deterministic` +4. `convert_local_statement` — `local x ~ ...` +5. `convert_to_kwargs_expression` — positional kw-like args → kwargs +6. `add_get_or_create_expression` — ensure LHS var exists in model +7. `convert_anonymous_variables` — nested `f(g(...))` → anonymous + tilde +8. `replace_begin_end` — `begin`/`end` → `FunctionalIndex` +9. `convert_tilde_expression` — final transform to `make_node!` call + +### Where the single-output assumption is enforced: + +| Location | Constraint | +|---|---| +| `model_macro.jl:~651` — `convert_tilde_expression` | `@capture(lhs, (var_[index__]) \| (var_))` only matches single symbol or indexed symbol | +| `model_macro.jl:~606` — `generate_lhs_proxylabel` | Takes a single `var` symbol | +| `model_macro.jl:~406` — `add_get_or_create_expression` | Calls `@capture(lhs, (var_[index__]) \| (var_))` — single var | +| `graph_engine.jl:~1846` — `prepare_interfaces` | `static(length(rhs_interfaces)) + static(1)` — hardcoded +1 | +| `graph_engine.jl:~1850-1863` — `prepare_interfaces` dispatch | Errors if missing ≠ 1 | +| `graph_engine.jl:~generated make_node!` | `__lhs_interface__::Union{NodeLabel, ProxyLabel, VariableRef}` — single type | +| `model_macro.jl:~780` — `get_make_node_function` | `make_node!` signature expects single `__lhs_interface__` | + +### For zero-output (bare function call): +Without `~`, the expression never enters the tilde pipeline at all — it's treated as a regular Julia function call. No pipeline step captures bare `my_submodel(x=..., y=..., z=...)`. + +--- + +## Implementation Plan + +### Feature (a): Multi-Output LHS `(a, b) ~ my_submodel(...)` + +#### 1. `save_expression_in_tilde` (model_macro.jl ~line 178) +**No change needed.** The `@capture(e, (lhs_ ~ rhs_ ...))` already captures any LHS expression including tuples, since `lhs_` is a generic pattern. + +#### 2. `add_get_or_create_expression` (model_macro.jl ~line 400) +**Change:** Add a new branch that detects when `lhs` is a tuple expression. For each element in the tuple, call `generate_get_or_create` separately. + +```julia +function add_get_or_create_expression(e::Expr) + if @capture(e, (lhs_ ~ rhs_ where {options__})) + if lhs isa Expr && lhs.head == :tuple + # Multi-output: get_or_create each element + creates = map(lhs.args) do elem + @capture(elem, (var_[index__]) | (var_)) + generate_get_or_create(var, index, rhs) + end + return quote + $(creates...) + $e + end + else + @capture(lhs, (var_[index__]) | (var_)) + return quote + $(generate_get_or_create(var, index, rhs)) + $e + end + end + end + return e +end +``` + +#### 3. `convert_tilde_expression` (model_macro.jl ~line 641) +**Change:** Add a new branch for tuple LHS. Generate multiple `ProxyLabel`s and pack them into a tuple, then call `make_node!` with the tuple. + +```julia +function convert_tilde_expression(e::Expr) + if @capture(e, (lhs_ ~ fform_(args__; kwargs__) where {options__}) | ...) + args = GraphPPL.proxy_args(combine_args(args, kwargs)) + options = GraphPPL.options_vector_to_named_tuple(options) + nodesym = gensym(:node) + varsym = gensym(:var) + + if lhs isa Expr && lhs.head == :tuple + # Multi-output LHS: (a, b) ~ submodel(...) + proxy_labels = map(lhs.args) do elem + @capture(elem, (var_[index__]) | (var_)) || error("Invalid tuple element $(elem)") + generate_lhs_proxylabel(var, index) + end + lhs_tuple = Expr(:tuple, proxy_labels...) + return quote + begin + $nodesym, $varsym = GraphPPL.make_node!( + __model__, __context__, GraphPPL.NodeCreationOptions($(options)), + $fform, $lhs_tuple, $args + ) + $varsym + end + end + else + # Existing single-output path + @capture(lhs, (var_[index__]) | (var_)) || error(...) + ... + end + end +end +``` + +#### 4. `generate_lhs_proxylabel` — tuple variant (model_macro.jl ~line 606) +**Add:** A new method to generate a tuple of proxy labels. Actually handled inline in step 3 above. + +#### 5. `prepare_interfaces` (graph_engine.jl ~line 1845) +**Change:** Add a new dispatch for multi-output. When `lhs_interface` is a `Tuple`, compute `length(rhs) + length(lhs_tuple)` as total interface count, then match multiple missing interfaces to the tuple elements. + +```julia +# New: multi-output prepare_interfaces +function prepare_interfaces(model::Model, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {F} + n_total = static(length(rhs_interfaces)) + static(length(lhs_interfaces)) + missing = missing_interfaces(model, fform, n_total, rhs_interfaces) + return prepare_interfaces_multi(missing, fform, lhs_interfaces, rhs_interfaces) +end + +function prepare_interfaces_multi(::StaticInterfaces{I}, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {I, F} + if length(I) != length(lhs_interfaces) + error("Node '$(fform)' has $(length(I)) missing interfaces but $(length(lhs_interfaces)) were provided on the LHS.") + end + # Build NamedTuple: missing interfaces mapped to lhs elements, then rhs + all_keys = (I..., keys(rhs_interfaces)...) + all_vals = (lhs_interfaces..., values(rhs_interfaces)...) + return NamedTuple{all_keys}(all_vals) +end +``` + +#### 6. `make_node!` generated function (model_macro.jl ~line 760) +**Change:** Add a second `make_node!` method in `get_make_node_function` that accepts a `Tuple` for `__lhs_interface__`. This method calls `prepare_interfaces` with the tuple, then proceeds as normal but returns the tuple of variables. + +Also add a dispatch in `graph_engine.jl` for the `Composite` path that routes `Tuple` lhs through to the generated function. + +#### 7. `make_node!` dispatch chain (graph_engine.jl ~line 1930+) +**Change:** Add dispatches to route `lhs_interface::Tuple` through the composite node path. + +```julia +# Multi-output composite dispatch +make_node!(::True, ::Composite, ::Stochastic, model, ctx, options, fform, + lhs_interface::Tuple, rhs_interfaces::NamedTuple) = + make_node!(Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, + static(length(rhs_interfaces) + length(lhs_interface))) +``` + +--- + +### Feature (b): Zero-Output Bare Call `my_submodel(x=..., y=..., z=...)` + +#### 1. New pipeline step: `convert_bare_submodel_call` (model_macro.jl) +**Add:** A new pipeline function that detects bare function calls with all keyword arguments and converts them to a `~` expression with no LHS. + +```julia +function convert_bare_submodel_call(e::Expr) + if @capture(e, fform_(; kwargs__)) || @capture(e, fform_(args__)) + if kwargs !== nothing || (args !== nothing && is_kwargs_expression(args)) + # Check if fform is a known Composite node — we can't know at macro time, + # so we generate a runtime check + return :(GraphPPL.make_node_no_output!(__model__, __context__, + GraphPPL.NodeCreationOptions(), $fform, $(keyword_expressions_to_named_tuple(args_or_kwargs)))) + end + end + return e +end +``` + +Actually, a simpler approach: convert the bare call to a tilde expression with a special sentinel LHS, like `nothing ~ submodel(...)`. This requires less invasive changes. + +**Better approach:** Add a new pipeline step that runs before `save_expression_in_tilde`. It detects function calls where ALL arguments are keyword arguments AND the function is annotated with `@model`. Since we can't know at macro-expansion time whether a function is `@model`-defined, we use a **runtime check**: wrap the call in `GraphPPL.__maybe_submodel_call(...)` which checks `NodeType` at runtime. + +**Simplest approach (chosen):** We define a new pipeline step `convert_zero_output_submodel` that transforms: +```julia +my_submodel(x = val1, y = val2, z = val3) +``` +into: +```julia +nothing ~ my_submodel(x = val1, y = val2, z = val3) +``` + +Then modify the tilde pipeline to handle `nothing` on the LHS by: +- In `add_get_or_create_expression`: skip get_or_create when lhs is `nothing` (literal) +- In `convert_tilde_expression`: when lhs is `:nothing`, pass a sentinel (e.g., `GraphPPL.NothingInterface()`) +- In `prepare_interfaces`: add a dispatch for `NothingInterface` that asserts 0 missing interfaces +- In `make_node!`: add dispatch for `NothingInterface` LHS that skips the lhs_interface assignment + +**Issue:** At macro time, we cannot distinguish `my_submodel(x=1, y=2)` (which should become a submodel call) from `some_function(x=1, y=2)` (which is a regular Julia function call). The pipeline would incorrectly transform all keyword-only function calls. + +**Resolution:** We require explicit opt-in syntax. Two options: +1. `@submodel my_submodel(x=1, y=2, z=3)` — a macro annotation +2. `~ my_submodel(x=1, y=2, z=3)` — bare tilde with no LHS + +**Option 2 is cleanest** and consistent with the existing `~` operator. The syntax `~ my_submodel(x=1, y=2, z=3)` is unambiguous and easy to detect. + +Actually, even simpler: Julia already parses `~ expr` as a unary `~` call: `Expr(:call, :~, expr)`. But `lhs ~ rhs` is `Expr(:call, :~, lhs, rhs)`. So we need to handle the unary case. + +Wait — actually in Julia, `x ~ y` is not valid syntax by default. GraphPPL uses MacroTools `@capture` with the `~` pattern inside the `@model` macro which does AST rewriting. Let me reconsider. + +Actually, looking at the code more carefully, inside `@model`, the `~` is handled via `@capture(e, lhs_ ~ rhs_)` which matches the infix `~` operator. Julia does parse `a ~ b` as `Expr(:call, :~, a, b)`. And `~ b` would be `Expr(:call, :~, b)`. + +So `~ my_submodel(x=1, y=2)` would be parsed as `Expr(:call, :~, :(my_submodel(x=1, y=2)))` — a unary call. This won't match the existing binary `~` patterns, so we can add a new pipeline step or modify `save_expression_in_tilde` and `convert_tilde_expression` to handle it. + +**Final Design for Zero-Output:** + +Syntax: `~ my_submodel(x = val1, y = val2, z = val3)` + +This is parsed as `Expr(:call, :~, :(my_submodel(x=val1, y=val2, z=val3)))`. + +1. Add `convert_zero_output_tilde` pipeline step (early, before `save_expression_in_tilde`): + ```julia + function convert_zero_output_tilde(e::Expr) + if e.head == :call && length(e.args) == 2 && e.args[1] == :~ + rhs = e.args[2] + # Transform: ~ submodel(...) → __nothing__ ~ submodel(...) + return Expr(:call, :~, :__nothing__, rhs) + end + return e + end + ``` + + We use `__nothing__` as a synthetic symbol that will flow through the pipeline. + +2. In `add_get_or_create_expression`: detect `__nothing__` and skip variable creation. + +3. In `convert_tilde_expression`: detect `__nothing__` and pass `GraphPPL.NothingInterface()` as lhs. + +4. Add `NothingInterface` sentinel type in `graph_engine.jl`. + +5. In `prepare_interfaces`: new dispatch for `NothingInterface` that expects 0 missing interfaces. + +6. In `make_node!` generated function and dispatch: accept `NothingInterface`, skip lhs assignment. + +--- + +## Detailed Change List + +### File: `src/graph_engine.jl` + +1. **Add `NothingInterface` type** (near other interface types): + ```julia + struct NothingInterface end + ``` + +2. **Add `prepare_interfaces` dispatch for zero-output** (after existing `prepare_interfaces`): + ```julia + function prepare_interfaces(model::Model, fform::F, ::NothingInterface, rhs_interfaces::NamedTuple) where {F} + missing = missing_interfaces(model, fform, static(length(rhs_interfaces)), rhs_interfaces) + return prepare_interfaces_zero(missing, fform, rhs_interfaces) + end + + function prepare_interfaces_zero(::StaticInterfaces{I}, fform::F, rhs_interfaces::NamedTuple) where {I, F} + if length(I) != 0 + error("Zero-output call to '$(fform)' but $(length(I)) interfaces are still missing: $(I)") + end + return rhs_interfaces + end + ``` + +3. **Add `prepare_interfaces` dispatch for multi-output** (after existing `prepare_interfaces`): + ```julia + function prepare_interfaces(model::Model, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {F} + n_lhs = length(lhs_interfaces) + missing = missing_interfaces(model, fform, static(length(rhs_interfaces) + n_lhs), rhs_interfaces) + return prepare_interfaces_multi(missing, fform, lhs_interfaces, rhs_interfaces) + end + + function prepare_interfaces_multi(::StaticInterfaces{I}, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {I, F} + if length(I) != length(lhs_interfaces) + error("Node '$(fform)': $(length(I)) missing interfaces $(I) but $(length(lhs_interfaces)) provided on LHS.") + end + keys_all = (I..., keys(rhs_interfaces)...) + vals_all = (lhs_interfaces..., values(rhs_interfaces)...) + return NamedTuple{keys_all}(vals_all) + end + ``` + +4. **Add `make_node!` dispatches for Tuple and NothingInterface LHS** (in the Composite path): + ```julia + # Multi-output: Tuple LHS + make_node!(::True, ::Composite, ::Stochastic, model, ctx, options, fform, + lhs_interface::Tuple, rhs_interfaces::NamedTuple) = + make_node!(Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, + static(length(rhs_interfaces) + length(lhs_interface))) + + # Zero-output: NothingInterface LHS + make_node!(nodetype::Composite, model, ctx, options, fform, lhs_interface::NothingInterface, rhs_interfaces) = + make_node!(True(), nodetype, Stochastic(), model, ctx, options, fform, lhs_interface, rhs_interfaces) + + make_node!(::True, ::Composite, ::Stochastic, model, ctx, options, fform, + lhs_interface::NothingInterface, rhs_interfaces::NamedTuple) = + make_node!(Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, + static(length(rhs_interfaces))) + ``` + +5. **In `get_make_node_function`:** Add two more generated `make_node!` methods for `Tuple` and `NothingInterface` LHS types. + +### File: `src/model_macro.jl` + +6. **Add `generate_lhs_proxylabel` for tuples** (near line 606): + Already handled by generating individual proxy labels and packing into a tuple in `convert_tilde_expression`. + +7. **Modify `add_get_or_create_expression`** (line 400): Handle tuple LHS and `__nothing__`. + +8. **Modify `convert_tilde_expression`** (line 641): Handle tuple LHS and `__nothing__`. + +9. **Add `convert_zero_output_tilde` pipeline function**. + +### File: `src/backends/default.jl` + +10. **Add `convert_zero_output_tilde` to the pipeline** (before `save_expression_in_tilde`). + +### File: `test/graph_construction_tests.jl` + +11. **Add tests for multi-output submodel invocation.** +12. **Add tests for zero-output submodel invocation.** + +--- + +## Test Plan + +### Multi-Output Tests + +```julia +@model function two_output_submodel(a, b, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) +end + +@model function main_multi_output(x) + (a, b) ~ two_output_submodel(x = x) + y ~ Normal(a + b, 1) +end +``` + +Verify: +- Model creates correctly +- Both `a` and `b` are valid variable references +- The submodel's internal structure is correct +- Interface mapping: `a` → first missing, `b` → second missing (in declaration order) + +### Zero-Output Tests + +```julia +@model function closed_submodel(x, y) + z ~ Normal(x, 1) + y ~ Normal(z, 1) +end + +@model function main_zero_output(x, y) + ~ closed_submodel(x = x, y = y) +end +``` + +Verify: +- Model creates correctly +- All interfaces are provided explicitly +- Error when not all interfaces are specified + +### Error Cases +- `(a, b) ~ submodel(...)` when only 1 interface is missing → error +- `~ submodel(x=...)` when 1 interface is still missing → error +- `(a,) ~ submodel(...)` — single-element tuple should work (1 missing) diff --git a/src/backends/default.jl b/src/backends/default.jl index cc3bd039..f6993d1c 100644 --- a/src/backends/default.jl +++ b/src/backends/default.jl @@ -8,6 +8,7 @@ struct DefaultBackend end function GraphPPL.model_macro_interior_pipelines(::DefaultBackend) return ( GraphPPL.check_reserved_variable_names_model, + GraphPPL.convert_zero_output_tilde, GraphPPL.save_expression_in_tilde, GraphPPL.convert_deterministic_statement, GraphPPL.convert_local_statement, diff --git a/src/graph_engine.jl b/src/graph_engine.jl index f9961354..1d877805 100644 --- a/src/graph_engine.jl +++ b/src/graph_engine.jl @@ -290,6 +290,14 @@ Base.show(io::IO, label::EdgeLabel) = print(io, to_symbol(label)) Base.:(==)(label1::EdgeLabel, label2::EdgeLabel) = label1.name == label2.name && label1.index == label2.index Base.hash(label::EdgeLabel, h::UInt) = hash(label.name, hash(label.index, h)) +""" + NothingInterface + +A sentinel type used to represent a zero-output submodel invocation (no LHS variable). +When used as the `lhs_interface` in `make_node!`, all interfaces must be specified on the RHS. +""" +struct NothingInterface end + """ Splat{T} @@ -1863,6 +1871,41 @@ function prepare_interfaces(::StaticInterfaces{I}, fform::F, lhs_interface, rhs_ return NamedTuple{(missing_interface, keys(rhs_interfaces)...)}((lhs_interface, values(rhs_interfaces)...)) end +# Multi-output: lhs_interfaces is a Tuple of multiple interfaces +function prepare_interfaces(model::Model, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {F} + n_lhs = length(lhs_interfaces) + missing = missing_interfaces(model, fform, static(length(rhs_interfaces) + n_lhs), rhs_interfaces) + return prepare_interfaces_multi(missing, fform, lhs_interfaces, rhs_interfaces) +end + +function prepare_interfaces_multi(::StaticInterfaces{I}, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {I, F} + if length(I) != length(lhs_interfaces) + n = "\n" + error( + lazy"Node '$(fform)' has $(length(I)) missing interface(s) $(I) but $(length(lhs_interfaces)) were provided on the LHS.$(n)$(n)Currently specified interfaces are: $(keys(rhs_interfaces))." + ) + end + all_keys = (I..., keys(rhs_interfaces)...) + all_vals = (lhs_interfaces..., values(rhs_interfaces)...) + return NamedTuple{all_keys}(all_vals) +end + +# Zero-output: NothingInterface means all interfaces are on the RHS +function prepare_interfaces(model::Model, fform::F, ::NothingInterface, rhs_interfaces::NamedTuple) where {F} + missing = missing_interfaces(model, fform, static(length(rhs_interfaces)), rhs_interfaces) + return prepare_interfaces_zero(missing, fform, rhs_interfaces) +end + +function prepare_interfaces_zero(::StaticInterfaces{I}, fform::F, rhs_interfaces::NamedTuple) where {I, F} + if length(I) != 0 + n = "\n" + error( + lazy"Zero-output call to '$(fform)' but $(length(I)) interface(s) are still missing: $(I).$(n)$(n)Currently specified interfaces are: $(keys(rhs_interfaces)). All interfaces must be provided for a zero-output submodel call." + ) + end + return rhs_interfaces +end + function materialize_interface(model, context, interface) return getifcreated(model, context, unroll(interface)) end @@ -2012,6 +2055,20 @@ make_node!(materialize::True, node_type::Composite, behaviour::Stochastic, model Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + 1) ) +# Multi-output: Tuple LHS for composite nodes +make_node!(materialize::True, node_type::Composite, behaviour::Stochastic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface::Tuple, rhs_interfaces::NamedTuple) where {F} = make_node!( + Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + length(lhs_interface)) +) + +# Zero-output: NothingInterface LHS for composite nodes +make_node!(nodetype::Composite, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface::NothingInterface, rhs_interfaces) where {F} = make_node!( + True(), nodetype, Stochastic(), model, ctx, options, fform, lhs_interface, rhs_interfaces +) + +make_node!(materialize::True, node_type::Composite, behaviour::Stochastic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface::NothingInterface, rhs_interfaces::NamedTuple) where {F} = make_node!( + Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces)) +) + """ make_node! diff --git a/src/model_macro.jl b/src/model_macro.jl index 69607905..a50e55e0 100644 --- a/src/model_macro.jl +++ b/src/model_macro.jl @@ -149,7 +149,8 @@ function check_reserved_variable_names_model(e::Expr) :(__lhs_interface__), :(__rhs_interfaces__), :(__interfaces__), - :(__n_interfaces__) + :(__n_interfaces__), + :(__nothing__) ] ) error("Variable name in $(prettify(e)) cannot be used as it is a reserved variable name in the model macro.") @@ -166,6 +167,21 @@ end what_walk(::typeof(check_incomplete_factorization_constraint)) = walk_until_occurrence((:(lhs_ = rhs_), :(lhs_::rhs_))) +""" + convert_zero_output_tilde(e::Expr) + +Converts a zero-output submodel call using unary `~` syntax to a binary `~` with `__nothing__` as the LHS. +This transforms `~ submodel(x = val1, y = val2)` into `__nothing__ ~ submodel(x = val1, y = val2)`. +""" +function convert_zero_output_tilde(e::Expr) + # Unary ~ is parsed as Expr(:call, :~, rhs) + if e.head == :call && length(e.args) == 2 && e.args[1] == :~ + rhs = e.args[2] + return Expr(:call, :~, :__nothing__, rhs) + end + return e +end + """ save_expression_in_tilde(expr::Expr) @@ -399,10 +415,25 @@ A `quote` block with the modified expression. """ function add_get_or_create_expression(e::Expr) if @capture(e, (lhs_ ~ rhs_ where {options__})) - @capture(lhs, (var_[index__]) | (var_)) - return quote - $(generate_get_or_create(var, index, rhs)) - $e + if lhs === :__nothing__ + # Zero-output: no variable to get or create + return e + elseif lhs isa Expr && lhs.head == :tuple + # Multi-output: get_or_create each element in the tuple + creates = map(lhs.args) do elem + @capture(elem, (var_[index__]) | (var_)) || error("Invalid tuple element on LHS: $(elem). Must be in a `var` or `var[index]` form.") + generate_get_or_create(var, index, rhs) + end + return quote + $(creates...) + $e + end + else + @capture(lhs, (var_[index__]) | (var_)) + return quote + $(generate_get_or_create(var, index, rhs)) + $e + end end end return e @@ -647,13 +678,40 @@ function convert_tilde_expression(e::Expr) options = GraphPPL.options_vector_to_named_tuple(options) nodesym = gensym(:node) varsym = gensym(:var) - @capture(lhs, (var_[index__]) | (var_)) || error("Invalid left-hand side $(lhs). Must be in a `var` or `var[index]` form.") - return quote - begin - $nodesym, $varsym = GraphPPL.make_node!( - __model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $(generate_lhs_proxylabel(var, index)), $args - ) - $varsym + if lhs === :__nothing__ + # Zero-output: pass NothingInterface as lhs + return quote + begin + $nodesym, $varsym = GraphPPL.make_node!( + __model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, GraphPPL.NothingInterface(), $args + ) + $varsym + end + end + elseif lhs isa Expr && lhs.head == :tuple + # Multi-output: generate a tuple of proxy labels for the LHS + proxy_labels = map(lhs.args) do elem + @capture(elem, (var_[index__]) | (var_)) || error("Invalid tuple element on LHS: $(elem). Must be in a `var` or `var[index]` form.") + generate_lhs_proxylabel(var, index) + end + lhs_tuple = Expr(:tuple, proxy_labels...) + return quote + begin + $nodesym, $varsym = GraphPPL.make_node!( + __model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $lhs_tuple, $args + ) + $varsym + end + end + else + @capture(lhs, (var_[index__]) | (var_)) || error("Invalid left-hand side $(lhs). Must be in a `var` or `var[index]` form.") + return quote + begin + $nodesym, $varsym = GraphPPL.make_node!( + __model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $(generate_lhs_proxylabel(var, index)), $args + ) + $varsym + end end end elseif @capture(e, (lhs_ .~ fform_(args__; kwargs__) where {options__}) | (lhs_ .~ fform_(args__) where {options__})) @@ -781,6 +839,50 @@ function get_make_node_function(model_specification, ms_body, ms_args, ms_name) return __context__, __lhs_interface__ end + # Multi-output: Tuple LHS + function GraphPPL.make_node!( + ::GraphPPL.Composite, + __model__::GraphPPL.Model, + __parent_context__::GraphPPL.Context, + __options__::GraphPPL.NodeCreationOptions, + ::typeof($ms_name), + __lhs_interface__::Tuple, + __rhs_interfaces__::NamedTuple, + __n_interfaces__::GraphPPL.StaticInt{$(length(ms_args))} + ) + __interfaces__ = GraphPPL.prepare_interfaces(__model__, $ms_name, __lhs_interface__, __rhs_interfaces__) + __context__ = GraphPPL.Context(__parent_context__, $ms_name) + GraphPPL.copy_markov_blanket_to_child_context(__context__, __interfaces__) + GraphPPL.add_composite_factor_node!(__model__, __parent_context__, __context__, $ms_name) + __returnval__ = GraphPPL.add_terminated_submodel!( + __model__, __context__, __options__, $ms_name, __interfaces__, __n_interfaces__ + ) + GraphPPL.returnval!(__context__, __returnval__) + return __context__, __lhs_interface__ + end + + # Zero-output: NothingInterface LHS + function GraphPPL.make_node!( + ::GraphPPL.Composite, + __model__::GraphPPL.Model, + __parent_context__::GraphPPL.Context, + __options__::GraphPPL.NodeCreationOptions, + ::typeof($ms_name), + __lhs_interface__::GraphPPL.NothingInterface, + __rhs_interfaces__::NamedTuple, + __n_interfaces__::GraphPPL.StaticInt{$(length(ms_args))} + ) + __interfaces__ = GraphPPL.prepare_interfaces(__model__, $ms_name, __lhs_interface__, __rhs_interfaces__) + __context__ = GraphPPL.Context(__parent_context__, $ms_name) + GraphPPL.copy_markov_blanket_to_child_context(__context__, __interfaces__) + GraphPPL.add_composite_factor_node!(__model__, __parent_context__, __context__, $ms_name) + __returnval__ = GraphPPL.add_terminated_submodel!( + __model__, __context__, __options__, $ms_name, __interfaces__, __n_interfaces__ + ) + GraphPPL.returnval!(__context__, __returnval__) + return __context__, __lhs_interface__ + end + function GraphPPL.add_terminated_submodel!( __model__::GraphPPL.Model, __context__::GraphPPL.Context, diff --git a/test/multi_and_zero_output_tests.jl b/test/multi_and_zero_output_tests.jl new file mode 100644 index 00000000..a9917320 --- /dev/null +++ b/test/multi_and_zero_output_tests.jl @@ -0,0 +1,233 @@ +@testitem "Multi-output submodel: two outputs" begin + using Distributions + import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions + + include("testutils.jl") + + @model function two_output_sub(a, b, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) + end + + @model function main_multi_two(x) + (a, b) ~ two_output_sub(x = x) + y ~ Normal(a, b) + end + + model = create_model(main_multi_two()) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) + return (x = x,) + end + + # The submodel has 1 Normal inside, plus the outer model has 1 Normal => 3 total Normal nodes + # (two_output_sub creates 2 Normals, main creates 1) + @test length(collect(filter(as_node(Normal), model))) === 3 + @test length(collect(filter(as_variable(:a), model))) >= 1 + @test length(collect(filter(as_variable(:b), model))) >= 1 + @test length(collect(filter(as_variable(:x), model))) === 1 +end + +@testitem "Multi-output submodel: three outputs" begin + using Distributions + import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions + + include("testutils.jl") + + @model function three_output_sub(a, b, c, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) + c ~ Normal(b, 1) + end + + @model function main_multi_three(x) + (a, b, c) ~ three_output_sub(x = x) + y ~ Normal(a + b, c) + end + + model = create_model(main_multi_three()) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) + return (x = x,) + end + + # three_output_sub: 3 Normals, main: 1 Normal => 4 Normals total + @test length(collect(filter(as_node(Normal), model))) === 4 + @test length(collect(filter(as_variable(:a), model))) >= 1 + @test length(collect(filter(as_variable(:b), model))) >= 1 + @test length(collect(filter(as_variable(:c), model))) >= 1 +end + +@testitem "Multi-output submodel: wrong number of LHS variables" begin + using Distributions + import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions + + include("testutils.jl") + + @model function two_iface_sub(a, b, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) + end + + # Only one missing interface (a) but providing two on LHS + @model function main_wrong_count(x) + (a, b) ~ two_iface_sub(x = x, b = b_val) where {b_val = 1.0} + end + + # Actually let's define a simpler test that errors: + # two_iface_sub has 3 interfaces: a, b, x + # If we provide x and b on RHS, only a is missing + # But we try to capture (p, q) on LHS (2 outputs for 1 missing) -> error + @model function main_mismatch(x) + b ~ Normal(0, 1) + (p, q) ~ two_iface_sub(x = x, b = b) + end + + @test_throws Exception create_model(main_mismatch()) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) + return (x = x,) + end +end + +@testitem "Zero-output submodel: all interfaces provided" begin + using Distributions + import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions + + include("testutils.jl") + + @model function closed_sub(x, y) + z ~ Normal(x, 1) + y ~ Normal(z, 1) + end + + @model function main_zero_output(x, y) + ~ closed_sub(x = x, y = y) + end + + model = create_model(main_zero_output()) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) + y = datalabel(model, ctx, NodeCreationOptions(kind = :data), :y, 2.0) + return (x = x, y = y) + end + + @test length(collect(filter(as_node(Normal), model))) === 2 + @test length(collect(filter(as_variable(:x), model))) === 1 + @test length(collect(filter(as_variable(:y), model))) === 1 +end + +@testitem "Zero-output submodel: missing interface should error" begin + using Distributions + import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions + + include("testutils.jl") + + @model function needs_output_sub(a, x) + a ~ Normal(x, 1) + end + + # Only providing x, but 'a' is still missing -> should error + @model function main_zero_missing(x) + ~ needs_output_sub(x = x) + end + + @test_throws Exception create_model(main_zero_missing()) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) + return (x = x,) + end +end + +@testitem "Zero-output submodel: with constants inside" begin + using Distributions + import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions + + include("testutils.jl") + + @model function fully_closed_sub(x) + y ~ Normal(x, 1) + z ~ Normal(y, 2) + end + + @model function main_fully_closed(x) + ~ fully_closed_sub(x = x) + end + + model = create_model(main_fully_closed()) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 5.0) + return (x = x,) + end + + @test length(collect(filter(as_node(Normal), model))) === 2 + @test length(collect(filter(as_variable(:x), model))) === 1 +end + +@testitem "Multi-output submodel in a loop" begin + using Distributions + import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions, VariableKindData + + include("testutils.jl") + + @model function pair_sub(a, b, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) + end + + @model function main_loop_multi(x, n) + for i in 1:n + (a[i], b[i]) ~ pair_sub(x = x) + end + end + + model = create_model(main_loop_multi(n = 3)) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) + return (x = x,) + end + + # Each iteration creates 2 Normals -> 3 * 2 = 6 + @test length(collect(filter(as_node(Normal), model))) === 6 + @test length(collect(filter(as_variable(:a), model))) === 3 + @test length(collect(filter(as_variable(:b), model))) === 3 +end + +@testitem "Zero-output submodel in a loop" begin + using Distributions + import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions, VariableKindData + + include("testutils.jl") + + @model function observe_sub(x, y) + y ~ Normal(x, 1) + end + + @model function main_loop_zero(x, y, n) + for i in 1:n + ~ observe_sub(x = x[i], y = y[i]) + end + end + + model = create_model(main_loop_zero(n = 4)) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = VariableKindData), :x) + y = datalabel(model, ctx, NodeCreationOptions(kind = VariableKindData), :y) + return (x = x, y = y) + end + + @test length(collect(filter(as_node(Normal), model))) === 4 + @test length(collect(filter(as_variable(:x), model))) === 4 + @test length(collect(filter(as_variable(:y), model))) === 4 +end + +@testitem "convert_zero_output_tilde pipeline step" begin + import GraphPPL: apply_pipeline, convert_zero_output_tilde + + include("testutils.jl") + + # Test that unary ~ is converted to binary with __nothing__ + input = Expr(:call, :~, :(my_submodel(x = 1, y = 2))) + output = apply_pipeline(input, convert_zero_output_tilde) + @test output.head == :call + @test output.args[1] == :~ + @test output.args[2] == :__nothing__ + @test output.args[3] == :(my_submodel(x = 1, y = 2)) + + # Test that binary ~ is left unchanged + input2 = :(x ~ Normal(0, 1)) + output2 = apply_pipeline(input2, convert_zero_output_tilde) + @test output2 == input2 +end From 5cd3a0f078164c7b124bf5bb651c9ea9c44c1c4c Mon Sep 17 00:00:00 2001 From: Alex <119050107+ofSingularMind@users.noreply.github.com> Date: Sun, 19 Apr 2026 15:11:34 +0200 Subject: [PATCH 2/4] cleanup --- test/multi_and_zero_output_tests.jl | 8 -------- 1 file changed, 8 deletions(-) diff --git a/test/multi_and_zero_output_tests.jl b/test/multi_and_zero_output_tests.jl index a9917320..c300cbd1 100644 --- a/test/multi_and_zero_output_tests.jl +++ b/test/multi_and_zero_output_tests.jl @@ -68,14 +68,6 @@ end end # Only one missing interface (a) but providing two on LHS - @model function main_wrong_count(x) - (a, b) ~ two_iface_sub(x = x, b = b_val) where {b_val = 1.0} - end - - # Actually let's define a simpler test that errors: - # two_iface_sub has 3 interfaces: a, b, x - # If we provide x and b on RHS, only a is missing - # But we try to capture (p, q) on LHS (2 outputs for 1 missing) -> error @model function main_mismatch(x) b ~ Normal(0, 1) (p, q) ~ two_iface_sub(x = x, b = b) From e001f0d5614a3f4d71413b260ad7d3df42fdcd65 Mon Sep 17 00:00:00 2001 From: Alex <119050107+ofSingularMind@users.noreply.github.com> Date: Sun, 19 Apr 2026 15:13:55 +0200 Subject: [PATCH 3/4] untrack md --- MULTI_AND_ZERO_OUTPUT_PLAN.md | 379 ---------------------------------- 1 file changed, 379 deletions(-) delete mode 100644 MULTI_AND_ZERO_OUTPUT_PLAN.md diff --git a/MULTI_AND_ZERO_OUTPUT_PLAN.md b/MULTI_AND_ZERO_OUTPUT_PLAN.md deleted file mode 100644 index a1784f17..00000000 --- a/MULTI_AND_ZERO_OUTPUT_PLAN.md +++ /dev/null @@ -1,379 +0,0 @@ -# Implementation Plan: Multi-Output and Zero-Output Submodel Interfaces - -## Overview - -Two new features for GraphPPL.jl's nested model specification: - -**(a) Multi-output LHS** — Allow multiple return interfaces on the left-hand side of `~`: -```julia -(a, b) ~ my_submodel(c = c_val, d = d_val) -``` -This means 2 interfaces are "missing" from the RHS and are assigned to `a` and `b` in declaration order. - -**(b) Zero-output (bare call)** — Allow submodel invocation without `~`: -```julia -my_submodel(x = x_val, y = y_val, z = z_val) -``` -All interfaces are explicitly provided on the RHS; there is no LHS variable. - ---- - -## Current Architecture (Summary) - -The `@model` macro body is transformed through a pipeline of expression-rewriting functions (defined in `src/backends/default.jl`): - -1. `check_reserved_variable_names_model` -2. `save_expression_in_tilde` — injects `created_by` option -3. `convert_deterministic_statement` — `:=` → `~` with `is_deterministic` -4. `convert_local_statement` — `local x ~ ...` -5. `convert_to_kwargs_expression` — positional kw-like args → kwargs -6. `add_get_or_create_expression` — ensure LHS var exists in model -7. `convert_anonymous_variables` — nested `f(g(...))` → anonymous + tilde -8. `replace_begin_end` — `begin`/`end` → `FunctionalIndex` -9. `convert_tilde_expression` — final transform to `make_node!` call - -### Where the single-output assumption is enforced: - -| Location | Constraint | -|---|---| -| `model_macro.jl:~651` — `convert_tilde_expression` | `@capture(lhs, (var_[index__]) \| (var_))` only matches single symbol or indexed symbol | -| `model_macro.jl:~606` — `generate_lhs_proxylabel` | Takes a single `var` symbol | -| `model_macro.jl:~406` — `add_get_or_create_expression` | Calls `@capture(lhs, (var_[index__]) \| (var_))` — single var | -| `graph_engine.jl:~1846` — `prepare_interfaces` | `static(length(rhs_interfaces)) + static(1)` — hardcoded +1 | -| `graph_engine.jl:~1850-1863` — `prepare_interfaces` dispatch | Errors if missing ≠ 1 | -| `graph_engine.jl:~generated make_node!` | `__lhs_interface__::Union{NodeLabel, ProxyLabel, VariableRef}` — single type | -| `model_macro.jl:~780` — `get_make_node_function` | `make_node!` signature expects single `__lhs_interface__` | - -### For zero-output (bare function call): -Without `~`, the expression never enters the tilde pipeline at all — it's treated as a regular Julia function call. No pipeline step captures bare `my_submodel(x=..., y=..., z=...)`. - ---- - -## Implementation Plan - -### Feature (a): Multi-Output LHS `(a, b) ~ my_submodel(...)` - -#### 1. `save_expression_in_tilde` (model_macro.jl ~line 178) -**No change needed.** The `@capture(e, (lhs_ ~ rhs_ ...))` already captures any LHS expression including tuples, since `lhs_` is a generic pattern. - -#### 2. `add_get_or_create_expression` (model_macro.jl ~line 400) -**Change:** Add a new branch that detects when `lhs` is a tuple expression. For each element in the tuple, call `generate_get_or_create` separately. - -```julia -function add_get_or_create_expression(e::Expr) - if @capture(e, (lhs_ ~ rhs_ where {options__})) - if lhs isa Expr && lhs.head == :tuple - # Multi-output: get_or_create each element - creates = map(lhs.args) do elem - @capture(elem, (var_[index__]) | (var_)) - generate_get_or_create(var, index, rhs) - end - return quote - $(creates...) - $e - end - else - @capture(lhs, (var_[index__]) | (var_)) - return quote - $(generate_get_or_create(var, index, rhs)) - $e - end - end - end - return e -end -``` - -#### 3. `convert_tilde_expression` (model_macro.jl ~line 641) -**Change:** Add a new branch for tuple LHS. Generate multiple `ProxyLabel`s and pack them into a tuple, then call `make_node!` with the tuple. - -```julia -function convert_tilde_expression(e::Expr) - if @capture(e, (lhs_ ~ fform_(args__; kwargs__) where {options__}) | ...) - args = GraphPPL.proxy_args(combine_args(args, kwargs)) - options = GraphPPL.options_vector_to_named_tuple(options) - nodesym = gensym(:node) - varsym = gensym(:var) - - if lhs isa Expr && lhs.head == :tuple - # Multi-output LHS: (a, b) ~ submodel(...) - proxy_labels = map(lhs.args) do elem - @capture(elem, (var_[index__]) | (var_)) || error("Invalid tuple element $(elem)") - generate_lhs_proxylabel(var, index) - end - lhs_tuple = Expr(:tuple, proxy_labels...) - return quote - begin - $nodesym, $varsym = GraphPPL.make_node!( - __model__, __context__, GraphPPL.NodeCreationOptions($(options)), - $fform, $lhs_tuple, $args - ) - $varsym - end - end - else - # Existing single-output path - @capture(lhs, (var_[index__]) | (var_)) || error(...) - ... - end - end -end -``` - -#### 4. `generate_lhs_proxylabel` — tuple variant (model_macro.jl ~line 606) -**Add:** A new method to generate a tuple of proxy labels. Actually handled inline in step 3 above. - -#### 5. `prepare_interfaces` (graph_engine.jl ~line 1845) -**Change:** Add a new dispatch for multi-output. When `lhs_interface` is a `Tuple`, compute `length(rhs) + length(lhs_tuple)` as total interface count, then match multiple missing interfaces to the tuple elements. - -```julia -# New: multi-output prepare_interfaces -function prepare_interfaces(model::Model, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {F} - n_total = static(length(rhs_interfaces)) + static(length(lhs_interfaces)) - missing = missing_interfaces(model, fform, n_total, rhs_interfaces) - return prepare_interfaces_multi(missing, fform, lhs_interfaces, rhs_interfaces) -end - -function prepare_interfaces_multi(::StaticInterfaces{I}, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {I, F} - if length(I) != length(lhs_interfaces) - error("Node '$(fform)' has $(length(I)) missing interfaces but $(length(lhs_interfaces)) were provided on the LHS.") - end - # Build NamedTuple: missing interfaces mapped to lhs elements, then rhs - all_keys = (I..., keys(rhs_interfaces)...) - all_vals = (lhs_interfaces..., values(rhs_interfaces)...) - return NamedTuple{all_keys}(all_vals) -end -``` - -#### 6. `make_node!` generated function (model_macro.jl ~line 760) -**Change:** Add a second `make_node!` method in `get_make_node_function` that accepts a `Tuple` for `__lhs_interface__`. This method calls `prepare_interfaces` with the tuple, then proceeds as normal but returns the tuple of variables. - -Also add a dispatch in `graph_engine.jl` for the `Composite` path that routes `Tuple` lhs through to the generated function. - -#### 7. `make_node!` dispatch chain (graph_engine.jl ~line 1930+) -**Change:** Add dispatches to route `lhs_interface::Tuple` through the composite node path. - -```julia -# Multi-output composite dispatch -make_node!(::True, ::Composite, ::Stochastic, model, ctx, options, fform, - lhs_interface::Tuple, rhs_interfaces::NamedTuple) = - make_node!(Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, - static(length(rhs_interfaces) + length(lhs_interface))) -``` - ---- - -### Feature (b): Zero-Output Bare Call `my_submodel(x=..., y=..., z=...)` - -#### 1. New pipeline step: `convert_bare_submodel_call` (model_macro.jl) -**Add:** A new pipeline function that detects bare function calls with all keyword arguments and converts them to a `~` expression with no LHS. - -```julia -function convert_bare_submodel_call(e::Expr) - if @capture(e, fform_(; kwargs__)) || @capture(e, fform_(args__)) - if kwargs !== nothing || (args !== nothing && is_kwargs_expression(args)) - # Check if fform is a known Composite node — we can't know at macro time, - # so we generate a runtime check - return :(GraphPPL.make_node_no_output!(__model__, __context__, - GraphPPL.NodeCreationOptions(), $fform, $(keyword_expressions_to_named_tuple(args_or_kwargs)))) - end - end - return e -end -``` - -Actually, a simpler approach: convert the bare call to a tilde expression with a special sentinel LHS, like `nothing ~ submodel(...)`. This requires less invasive changes. - -**Better approach:** Add a new pipeline step that runs before `save_expression_in_tilde`. It detects function calls where ALL arguments are keyword arguments AND the function is annotated with `@model`. Since we can't know at macro-expansion time whether a function is `@model`-defined, we use a **runtime check**: wrap the call in `GraphPPL.__maybe_submodel_call(...)` which checks `NodeType` at runtime. - -**Simplest approach (chosen):** We define a new pipeline step `convert_zero_output_submodel` that transforms: -```julia -my_submodel(x = val1, y = val2, z = val3) -``` -into: -```julia -nothing ~ my_submodel(x = val1, y = val2, z = val3) -``` - -Then modify the tilde pipeline to handle `nothing` on the LHS by: -- In `add_get_or_create_expression`: skip get_or_create when lhs is `nothing` (literal) -- In `convert_tilde_expression`: when lhs is `:nothing`, pass a sentinel (e.g., `GraphPPL.NothingInterface()`) -- In `prepare_interfaces`: add a dispatch for `NothingInterface` that asserts 0 missing interfaces -- In `make_node!`: add dispatch for `NothingInterface` LHS that skips the lhs_interface assignment - -**Issue:** At macro time, we cannot distinguish `my_submodel(x=1, y=2)` (which should become a submodel call) from `some_function(x=1, y=2)` (which is a regular Julia function call). The pipeline would incorrectly transform all keyword-only function calls. - -**Resolution:** We require explicit opt-in syntax. Two options: -1. `@submodel my_submodel(x=1, y=2, z=3)` — a macro annotation -2. `~ my_submodel(x=1, y=2, z=3)` — bare tilde with no LHS - -**Option 2 is cleanest** and consistent with the existing `~` operator. The syntax `~ my_submodel(x=1, y=2, z=3)` is unambiguous and easy to detect. - -Actually, even simpler: Julia already parses `~ expr` as a unary `~` call: `Expr(:call, :~, expr)`. But `lhs ~ rhs` is `Expr(:call, :~, lhs, rhs)`. So we need to handle the unary case. - -Wait — actually in Julia, `x ~ y` is not valid syntax by default. GraphPPL uses MacroTools `@capture` with the `~` pattern inside the `@model` macro which does AST rewriting. Let me reconsider. - -Actually, looking at the code more carefully, inside `@model`, the `~` is handled via `@capture(e, lhs_ ~ rhs_)` which matches the infix `~` operator. Julia does parse `a ~ b` as `Expr(:call, :~, a, b)`. And `~ b` would be `Expr(:call, :~, b)`. - -So `~ my_submodel(x=1, y=2)` would be parsed as `Expr(:call, :~, :(my_submodel(x=1, y=2)))` — a unary call. This won't match the existing binary `~` patterns, so we can add a new pipeline step or modify `save_expression_in_tilde` and `convert_tilde_expression` to handle it. - -**Final Design for Zero-Output:** - -Syntax: `~ my_submodel(x = val1, y = val2, z = val3)` - -This is parsed as `Expr(:call, :~, :(my_submodel(x=val1, y=val2, z=val3)))`. - -1. Add `convert_zero_output_tilde` pipeline step (early, before `save_expression_in_tilde`): - ```julia - function convert_zero_output_tilde(e::Expr) - if e.head == :call && length(e.args) == 2 && e.args[1] == :~ - rhs = e.args[2] - # Transform: ~ submodel(...) → __nothing__ ~ submodel(...) - return Expr(:call, :~, :__nothing__, rhs) - end - return e - end - ``` - - We use `__nothing__` as a synthetic symbol that will flow through the pipeline. - -2. In `add_get_or_create_expression`: detect `__nothing__` and skip variable creation. - -3. In `convert_tilde_expression`: detect `__nothing__` and pass `GraphPPL.NothingInterface()` as lhs. - -4. Add `NothingInterface` sentinel type in `graph_engine.jl`. - -5. In `prepare_interfaces`: new dispatch for `NothingInterface` that expects 0 missing interfaces. - -6. In `make_node!` generated function and dispatch: accept `NothingInterface`, skip lhs assignment. - ---- - -## Detailed Change List - -### File: `src/graph_engine.jl` - -1. **Add `NothingInterface` type** (near other interface types): - ```julia - struct NothingInterface end - ``` - -2. **Add `prepare_interfaces` dispatch for zero-output** (after existing `prepare_interfaces`): - ```julia - function prepare_interfaces(model::Model, fform::F, ::NothingInterface, rhs_interfaces::NamedTuple) where {F} - missing = missing_interfaces(model, fform, static(length(rhs_interfaces)), rhs_interfaces) - return prepare_interfaces_zero(missing, fform, rhs_interfaces) - end - - function prepare_interfaces_zero(::StaticInterfaces{I}, fform::F, rhs_interfaces::NamedTuple) where {I, F} - if length(I) != 0 - error("Zero-output call to '$(fform)' but $(length(I)) interfaces are still missing: $(I)") - end - return rhs_interfaces - end - ``` - -3. **Add `prepare_interfaces` dispatch for multi-output** (after existing `prepare_interfaces`): - ```julia - function prepare_interfaces(model::Model, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {F} - n_lhs = length(lhs_interfaces) - missing = missing_interfaces(model, fform, static(length(rhs_interfaces) + n_lhs), rhs_interfaces) - return prepare_interfaces_multi(missing, fform, lhs_interfaces, rhs_interfaces) - end - - function prepare_interfaces_multi(::StaticInterfaces{I}, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {I, F} - if length(I) != length(lhs_interfaces) - error("Node '$(fform)': $(length(I)) missing interfaces $(I) but $(length(lhs_interfaces)) provided on LHS.") - end - keys_all = (I..., keys(rhs_interfaces)...) - vals_all = (lhs_interfaces..., values(rhs_interfaces)...) - return NamedTuple{keys_all}(vals_all) - end - ``` - -4. **Add `make_node!` dispatches for Tuple and NothingInterface LHS** (in the Composite path): - ```julia - # Multi-output: Tuple LHS - make_node!(::True, ::Composite, ::Stochastic, model, ctx, options, fform, - lhs_interface::Tuple, rhs_interfaces::NamedTuple) = - make_node!(Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, - static(length(rhs_interfaces) + length(lhs_interface))) - - # Zero-output: NothingInterface LHS - make_node!(nodetype::Composite, model, ctx, options, fform, lhs_interface::NothingInterface, rhs_interfaces) = - make_node!(True(), nodetype, Stochastic(), model, ctx, options, fform, lhs_interface, rhs_interfaces) - - make_node!(::True, ::Composite, ::Stochastic, model, ctx, options, fform, - lhs_interface::NothingInterface, rhs_interfaces::NamedTuple) = - make_node!(Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, - static(length(rhs_interfaces))) - ``` - -5. **In `get_make_node_function`:** Add two more generated `make_node!` methods for `Tuple` and `NothingInterface` LHS types. - -### File: `src/model_macro.jl` - -6. **Add `generate_lhs_proxylabel` for tuples** (near line 606): - Already handled by generating individual proxy labels and packing into a tuple in `convert_tilde_expression`. - -7. **Modify `add_get_or_create_expression`** (line 400): Handle tuple LHS and `__nothing__`. - -8. **Modify `convert_tilde_expression`** (line 641): Handle tuple LHS and `__nothing__`. - -9. **Add `convert_zero_output_tilde` pipeline function**. - -### File: `src/backends/default.jl` - -10. **Add `convert_zero_output_tilde` to the pipeline** (before `save_expression_in_tilde`). - -### File: `test/graph_construction_tests.jl` - -11. **Add tests for multi-output submodel invocation.** -12. **Add tests for zero-output submodel invocation.** - ---- - -## Test Plan - -### Multi-Output Tests - -```julia -@model function two_output_submodel(a, b, x) - a ~ Normal(x, 1) - b ~ Normal(a, 1) -end - -@model function main_multi_output(x) - (a, b) ~ two_output_submodel(x = x) - y ~ Normal(a + b, 1) -end -``` - -Verify: -- Model creates correctly -- Both `a` and `b` are valid variable references -- The submodel's internal structure is correct -- Interface mapping: `a` → first missing, `b` → second missing (in declaration order) - -### Zero-Output Tests - -```julia -@model function closed_submodel(x, y) - z ~ Normal(x, 1) - y ~ Normal(z, 1) -end - -@model function main_zero_output(x, y) - ~ closed_submodel(x = x, y = y) -end -``` - -Verify: -- Model creates correctly -- All interfaces are provided explicitly -- Error when not all interfaces are specified - -### Error Cases -- `(a, b) ~ submodel(...)` when only 1 interface is missing → error -- `~ submodel(x=...)` when 1 interface is still missing → error -- `(a,) ~ submodel(...)` — single-element tuple should work (1 missing) From 36872846c240d6126711d6c2e2a3ab6cca4f4f60 Mon Sep 17 00:00:00 2001 From: Alex <119050107+ofSingularMind@users.noreply.github.com> Date: Tue, 26 May 2026 16:24:30 +0200 Subject: [PATCH 4/4] fix: remove zero-output submodel calling fix: improve test coverage and exactness some docs: document multi-output submodel calls --- docs/src/nested_models.md | 41 +++++- src/backends/default.jl | 1 - src/graph_engine.jl | 57 +++----- src/model_macro.jl | 106 +++++++------- test/multi_and_zero_output_tests.jl | 210 ++++++++++++++++++---------- test/runtests.jl | 2 +- 6 files changed, 253 insertions(+), 164 deletions(-) diff --git a/docs/src/nested_models.md b/docs/src/nested_models.md index 5794eb9e..5a614bed 100644 --- a/docs/src/nested_models.md +++ b/docs/src/nested_models.md @@ -30,4 +30,43 @@ If we want to chain these `gcv` submodels together into a Hierarchical Gaussian end ``` -Note that in our invocations of `gcv`, we haven't specified the `y` argument of the Markov Blanket. This is what is being recognized as the missing interface and `GraphPPL` will assign `means[i]` to `y`. \ No newline at end of file +Note that in our invocations of `gcv`, we haven't specified the `y` argument of the Markov Blanket. This is what is being recognized as the missing interface and `GraphPPL` will assign `means[i]` to `y`. + +## Multi-output submodels + +When a submodel produces multiple outputs — multiple interfaces left unspecified on the RHS — you can bind them all on the LHS using a tuple. There are two syntaxes: + +**Positional:** list outer variables in the same order as the unspecified interfaces appear in the submodel definition. + +``` @example nested-models +@model function linear_gaussian(x, y, z) + x ~ Normal(z, 1) + y ~ Normal(x, 1) +end + +@model function outer_positional(c) + (a, b) ~ linear_gaussian(z = c) # a → interface x, b → interface y (by position) +end +``` + +**Named (kwarg-style):** explicitly map each outer variable to its interface name using `name = var` pairs. This is order-independent and recommended when submodel argument order may change. + +``` @example nested-models +@model function outer_named(my_z) + (y = my_y, x = my_x) ~ linear_gaussian(z = my_z) # binds by name, regardless of order + obs ~ Normal(my_x, my_y) +end +``` + +Both syntaxes work with indexed variables in loops: + +``` @example nested-models +@model function chain(z, n) + for i in 1:n + (x = xs[i], y = ys[i]) ~ linear_gaussian(z = z) + end +end +``` + +!!! note + If the same interface name appears on both LHS and RHS, `GraphPPL` raises an error at model-creation time. Similarly, providing a name on the LHS that does not match any of the submodel's interface names is caught with a descriptive error. \ No newline at end of file diff --git a/src/backends/default.jl b/src/backends/default.jl index f6993d1c..cc3bd039 100644 --- a/src/backends/default.jl +++ b/src/backends/default.jl @@ -8,7 +8,6 @@ struct DefaultBackend end function GraphPPL.model_macro_interior_pipelines(::DefaultBackend) return ( GraphPPL.check_reserved_variable_names_model, - GraphPPL.convert_zero_output_tilde, GraphPPL.save_expression_in_tilde, GraphPPL.convert_deterministic_statement, GraphPPL.convert_local_statement, diff --git a/src/graph_engine.jl b/src/graph_engine.jl index 61a150b9..35ed5a95 100644 --- a/src/graph_engine.jl +++ b/src/graph_engine.jl @@ -290,14 +290,6 @@ Base.show(io::IO, label::EdgeLabel) = print(io, to_symbol(label)) Base.:(==)(label1::EdgeLabel, label2::EdgeLabel) = label1.name == label2.name && label1.index == label2.index Base.hash(label::EdgeLabel, h::UInt) = hash(label.name, hash(label.index, h)) -""" - NothingInterface - -A sentinel type used to represent a zero-output submodel invocation (no LHS variable). -When used as the `lhs_interface` in `make_node!`, all interfaces must be specified on the RHS. -""" -struct NothingInterface end - """ Splat{T} @@ -1211,6 +1203,7 @@ struct StaticInterfaces{I} end StaticInterfaces(I::Tuple) = StaticInterfaces{I}() Base.getindex(::StaticInterfaces{I}, index) where {I} = I[index] +iface_names(::StaticInterfaces{I}) where {I} = I function Base.convert(::Type{NamedTuple}, ::StaticInterfaces{I}, t::Tuple) where {I} return NamedTuple{I}(t) @@ -1878,7 +1871,7 @@ function prepare_interfaces(::StaticInterfaces{I}, fform::F, lhs_interface, rhs_ return NamedTuple{(missing_interface, keys(rhs_interfaces)...)}((lhs_interface, values(rhs_interfaces)...)) end -# Multi-output: lhs_interfaces is a Tuple of multiple interfaces +# Multi-output: lhs_interfaces is a Tuple of multiple interfaces (positional) function prepare_interfaces(model::Model, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {F} n_lhs = length(lhs_interfaces) missing = missing_interfaces(model, fform, static(length(rhs_interfaces) + n_lhs), rhs_interfaces) @@ -1886,31 +1879,27 @@ function prepare_interfaces(model::Model, fform::F, lhs_interfaces::Tuple, rhs_i end function prepare_interfaces_multi(::StaticInterfaces{I}, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {I, F} - if length(I) != length(lhs_interfaces) - n = "\n" - error( - lazy"Node '$(fform)' has $(length(I)) missing interface(s) $(I) but $(length(lhs_interfaces)) were provided on the LHS.$(n)$(n)Currently specified interfaces are: $(keys(rhs_interfaces))." - ) - end all_keys = (I..., keys(rhs_interfaces)...) all_vals = (lhs_interfaces..., values(rhs_interfaces)...) return NamedTuple{all_keys}(all_vals) end -# Zero-output: NothingInterface means all interfaces are on the RHS -function prepare_interfaces(model::Model, fform::F, ::NothingInterface, rhs_interfaces::NamedTuple) where {F} - missing = missing_interfaces(model, fform, static(length(rhs_interfaces)), rhs_interfaces) - return prepare_interfaces_zero(missing, fform, rhs_interfaces) -end - -function prepare_interfaces_zero(::StaticInterfaces{I}, fform::F, rhs_interfaces::NamedTuple) where {I, F} - if length(I) != 0 - n = "\n" - error( - lazy"Zero-output call to '$(fform)' but $(length(I)) interface(s) are still missing: $(I).$(n)$(n)Currently specified interfaces are: $(keys(rhs_interfaces)). All interfaces must be provided for a zero-output submodel call." - ) +# Named-output: lhs_interfaces is a NamedTuple (kwarg-style, e.g. (a = m_a, b = m_b) ~ sub(x = x)) +function prepare_interfaces(model::Model, fform::F, lhs_interfaces::NamedTuple, rhs_interfaces::NamedTuple) where {F} + for k in keys(lhs_interfaces) + if k ∈ keys(rhs_interfaces) + error(lazy"Interface ':$(k)' of '$(fform)' is specified on both LHS and RHS.") + end + end + merged = merge(lhs_interfaces, rhs_interfaces) + all_ifaces = interfaces(model, fform, static(length(merged))) + valid_names = iface_names(all_ifaces) + for k in keys(lhs_interfaces) + if k ∉ valid_names + error(lazy"Interface ':$(k)' does not exist in '$(fform)'. Valid interfaces are: $(valid_names).") + end end - return rhs_interfaces + return merged end function materialize_interface(model, context, interface) @@ -2062,18 +2051,14 @@ make_node!(materialize::True, node_type::Composite, behaviour::Stochastic, model Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + 1) ) -# Multi-output: Tuple LHS for composite nodes +# Multi-output: Tuple LHS for composite nodes (positional) make_node!(materialize::True, node_type::Composite, behaviour::Stochastic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface::Tuple, rhs_interfaces::NamedTuple) where {F} = make_node!( Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + length(lhs_interface)) ) -# Zero-output: NothingInterface LHS for composite nodes -make_node!(nodetype::Composite, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface::NothingInterface, rhs_interfaces) where {F} = make_node!( - True(), nodetype, Stochastic(), model, ctx, options, fform, lhs_interface, rhs_interfaces -) - -make_node!(materialize::True, node_type::Composite, behaviour::Stochastic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface::NothingInterface, rhs_interfaces::NamedTuple) where {F} = make_node!( - Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces)) +# Named-output: NamedTuple LHS for composite nodes (kwarg-style) +make_node!(materialize::True, node_type::Composite, behaviour::Stochastic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface::NamedTuple, rhs_interfaces::NamedTuple) where {F} = make_node!( + Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + length(lhs_interface)) ) """ diff --git a/src/model_macro.jl b/src/model_macro.jl index 5873350b..5a1388f8 100644 --- a/src/model_macro.jl +++ b/src/model_macro.jl @@ -149,8 +149,7 @@ function check_reserved_variable_names_model(e::Expr) :(__lhs_interface__), :(__rhs_interfaces__), :(__interfaces__), - :(__n_interfaces__), - :(__nothing__) + :(__n_interfaces__) ] ) error("Variable name in $(prettify(e)) cannot be used as it is a reserved variable name in the model macro.") @@ -167,21 +166,6 @@ end what_walk(::typeof(check_incomplete_factorization_constraint)) = walk_until_occurrence((:(lhs_ = rhs_), :(lhs_::rhs_))) -""" - convert_zero_output_tilde(e::Expr) - -Converts a zero-output submodel call using unary `~` syntax to a binary `~` with `__nothing__` as the LHS. -This transforms `~ submodel(x = val1, y = val2)` into `__nothing__ ~ submodel(x = val1, y = val2)`. -""" -function convert_zero_output_tilde(e::Expr) - # Unary ~ is parsed as Expr(:call, :~, rhs) - if e.head == :call && length(e.args) == 2 && e.args[1] == :~ - rhs = e.args[2] - return Expr(:call, :~, :__nothing__, rhs) - end - return e -end - """ save_expression_in_tilde(expr::Expr) @@ -415,14 +399,22 @@ A `quote` block with the modified expression. """ function add_get_or_create_expression(e::Expr) if @capture(e, (lhs_ ~ rhs_ where {options__})) - if lhs === :__nothing__ - # Zero-output: no variable to get or create - return e - elseif lhs isa Expr && lhs.head == :tuple - # Multi-output: get_or_create each element in the tuple - creates = map(lhs.args) do elem - @capture(elem, (var_[index__]) | (var_)) || error("Invalid tuple element on LHS: $(elem). Must be in a `var` or `var[index]` form.") - generate_get_or_create(var, index, rhs) + if lhs isa Expr && lhs.head == :tuple + if is_named_tuple_lhs(lhs) + # Named-output: (a = m_a, b = m_b) ~ sub(...) — get_or_create the outer (RHS) variables + creates = map(lhs.args) do elem + outer_var_expr = elem.args[2] + @capture(outer_var_expr, (var_[index__]) | (var_)) || + error("Invalid named LHS value in $(elem). Value must be in a `var` or `var[index]` form.") + generate_get_or_create(var, index, rhs) + end + else + # Positional-output: (m_a, m_b) ~ sub(...) + creates = map(lhs.args) do elem + @capture(elem, (var_[index__]) | (var_)) || + error("Invalid tuple element on LHS: $(elem). Must be in a `var` or `var[index]` form.") + generate_get_or_create(var, index, rhs) + end end return quote $(creates...) @@ -439,6 +431,9 @@ function add_get_or_create_expression(e::Expr) return e end +is_named_tuple_lhs(lhs) = false +is_named_tuple_lhs(lhs::Expr) = lhs.head === :tuple && !isempty(lhs.args) && all(elem -> elem isa Expr && elem.head === :(=), lhs.args) + what_walk(::typeof(add_get_or_create_expression)) = not_created_by """ @@ -678,29 +673,40 @@ function convert_tilde_expression(e::Expr) options = GraphPPL.options_vector_to_named_tuple(options) nodesym = gensym(:node) varsym = gensym(:var) - if lhs === :__nothing__ - # Zero-output: pass NothingInterface as lhs - return quote - begin - $nodesym, $varsym = GraphPPL.make_node!( - __model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, GraphPPL.NothingInterface(), $args - ) - $varsym + if lhs isa Expr && lhs.head == :tuple + if is_named_tuple_lhs(lhs) + # Named-output: (a = m_a, b = m_b) ~ sub(...) — build a NamedTuple of proxy labels + proxy_pairs = map(lhs.args) do elem + iface_name = elem.args[1] + outer_var_expr = elem.args[2] + @capture(outer_var_expr, (var_[index__]) | (var_)) || + error("Invalid named LHS value in $(elem). Value must be in a `var` or `var[index]` form.") + Expr(:(=), iface_name, generate_lhs_proxylabel(var, index)) end - end - elseif lhs isa Expr && lhs.head == :tuple - # Multi-output: generate a tuple of proxy labels for the LHS - proxy_labels = map(lhs.args) do elem - @capture(elem, (var_[index__]) | (var_)) || error("Invalid tuple element on LHS: $(elem). Must be in a `var` or `var[index]` form.") - generate_lhs_proxylabel(var, index) - end - lhs_tuple = Expr(:tuple, proxy_labels...) - return quote - begin - $nodesym, $varsym = GraphPPL.make_node!( - __model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $lhs_tuple, $args - ) - $varsym + lhs_named = Expr(:tuple, proxy_pairs...) + return quote + begin + $nodesym, $varsym = GraphPPL.make_node!( + __model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $lhs_named, $args + ) + $varsym + end + end + else + # Positional-output: (m_a, m_b) ~ sub(...) — build a plain Tuple of proxy labels + proxy_labels = map(lhs.args) do elem + @capture(elem, (var_[index__]) | (var_)) || + error("Invalid tuple element on LHS: $(elem). Must be in a `var` or `var[index]` form.") + generate_lhs_proxylabel(var, index) + end + lhs_tuple = Expr(:tuple, proxy_labels...) + return quote + begin + $nodesym, $varsym = GraphPPL.make_node!( + __model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $lhs_tuple, $args + ) + $varsym + end end end else @@ -839,7 +845,7 @@ function get_make_node_function(model_specification, ms_body, ms_args, ms_name) return __context__, __lhs_interface__ end - # Multi-output: Tuple LHS + # Multi-output: Tuple LHS (positional) function GraphPPL.make_node!( ::GraphPPL.Composite, __model__::GraphPPL.Model, @@ -861,14 +867,14 @@ function get_make_node_function(model_specification, ms_body, ms_args, ms_name) return __context__, __lhs_interface__ end - # Zero-output: NothingInterface LHS + # Named-output: NamedTuple LHS (kwarg-style) function GraphPPL.make_node!( ::GraphPPL.Composite, __model__::GraphPPL.Model, __parent_context__::GraphPPL.Context, __options__::GraphPPL.NodeCreationOptions, ::typeof($ms_name), - __lhs_interface__::GraphPPL.NothingInterface, + __lhs_interface__::NamedTuple, __rhs_interfaces__::NamedTuple, __n_interfaces__::GraphPPL.StaticInt{$(length(ms_args))} ) diff --git a/test/multi_and_zero_output_tests.jl b/test/multi_and_zero_output_tests.jl index c300cbd1..229a9729 100644 --- a/test/multi_and_zero_output_tests.jl +++ b/test/multi_and_zero_output_tests.jl @@ -22,8 +22,8 @@ # The submodel has 1 Normal inside, plus the outer model has 1 Normal => 3 total Normal nodes # (two_output_sub creates 2 Normals, main creates 1) @test length(collect(filter(as_node(Normal), model))) === 3 - @test length(collect(filter(as_variable(:a), model))) >= 1 - @test length(collect(filter(as_variable(:b), model))) >= 1 + @test length(collect(filter(as_variable(:a), model))) === 1 + @test length(collect(filter(as_variable(:b), model))) === 1 @test length(collect(filter(as_variable(:x), model))) === 1 end @@ -51,9 +51,9 @@ end # three_output_sub: 3 Normals, main: 1 Normal => 4 Normals total @test length(collect(filter(as_node(Normal), model))) === 4 - @test length(collect(filter(as_variable(:a), model))) >= 1 - @test length(collect(filter(as_variable(:b), model))) >= 1 - @test length(collect(filter(as_variable(:c), model))) >= 1 + @test length(collect(filter(as_variable(:a), model))) === 1 + @test length(collect(filter(as_variable(:b), model))) === 1 + @test length(collect(filter(as_variable(:c), model))) === 1 end @testitem "Multi-output submodel: wrong number of LHS variables" begin @@ -73,153 +73,213 @@ end (p, q) ~ two_iface_sub(x = x, b = b) end - @test_throws Exception create_model(main_mismatch()) do model, ctx + @test_throws "no method matching make_node!" create_model(main_mismatch()) do model, ctx x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) return (x = x,) end end -@testitem "Zero-output submodel: all interfaces provided" begin +@testitem "Multi-output submodel in a loop" begin using Distributions - import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions + import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions, VariableKindData include("testutils.jl") - @model function closed_sub(x, y) - z ~ Normal(x, 1) - y ~ Normal(z, 1) + @model function pair_sub(a, b, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) end - @model function main_zero_output(x, y) - ~ closed_sub(x = x, y = y) + @model function main_loop_multi(x, n) + for i in 1:n + (a[i], b[i]) ~ pair_sub(x = x) + end end - model = create_model(main_zero_output()) do model, ctx + model = create_model(main_loop_multi(n = 3)) do model, ctx x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) - y = datalabel(model, ctx, NodeCreationOptions(kind = :data), :y, 2.0) - return (x = x, y = y) + return (x = x,) end - @test length(collect(filter(as_node(Normal), model))) === 2 - @test length(collect(filter(as_variable(:x), model))) === 1 - @test length(collect(filter(as_variable(:y), model))) === 1 + # Each iteration creates 2 Normals -> 3 * 2 = 6 + @test length(collect(filter(as_node(Normal), model))) === 6 + @test length(collect(filter(as_variable(:a), model))) === 3 + @test length(collect(filter(as_variable(:b), model))) === 3 end -@testitem "Zero-output submodel: missing interface should error" begin +@testitem "Multi-output submodel: named LHS basic" begin using Distributions - import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions + import GraphPPL: create_model, datalabel, NodeCreationOptions include("testutils.jl") - @model function needs_output_sub(a, x) + @model function named_two_sub(a, b, x) a ~ Normal(x, 1) + b ~ Normal(a, 1) end - # Only providing x, but 'a' is still missing -> should error - @model function main_zero_missing(x) - ~ needs_output_sub(x = x) + @model function main_named_two(x) + (a = out_a, b = out_b) ~ named_two_sub(x = x) + y ~ Normal(out_a, out_b) end - @test_throws Exception create_model(main_zero_missing()) do model, ctx + model = create_model(main_named_two()) do model, ctx x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) return (x = x,) end + + @test length(collect(filter(as_node(Normal), model))) === 3 + @test length(collect(filter(as_variable(:out_a), model))) === 1 + @test length(collect(filter(as_variable(:out_b), model))) === 1 + @test length(collect(filter(as_variable(:x), model))) === 1 end -@testitem "Zero-output submodel: with constants inside" begin +@testitem "Multi-output submodel: named LHS order-invariant" begin using Distributions - import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions + import GraphPPL: create_model, datalabel, NodeCreationOptions include("testutils.jl") - @model function fully_closed_sub(x) - y ~ Normal(x, 1) - z ~ Normal(y, 2) + @model function ordered_sub(a, b, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) end - @model function main_fully_closed(x) - ~ fully_closed_sub(x = x) + # Provide b before a on the LHS — must still bind correctly by name + @model function main_swapped(x) + (b = out_b, a = out_a) ~ ordered_sub(x = x) + y ~ Normal(out_a, out_b) end - model = create_model(main_fully_closed()) do model, ctx - x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 5.0) + model = create_model(main_swapped()) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) return (x = x,) end - @test length(collect(filter(as_node(Normal), model))) === 2 - @test length(collect(filter(as_variable(:x), model))) === 1 + @test length(collect(filter(as_node(Normal), model))) === 3 + @test length(collect(filter(as_variable(:out_a), model))) === 1 + @test length(collect(filter(as_variable(:out_b), model))) === 1 end -@testitem "Multi-output submodel in a loop" begin +@testitem "Multi-output submodel: named LHS in a loop" begin using Distributions - import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions, VariableKindData + import GraphPPL: create_model, datalabel, NodeCreationOptions include("testutils.jl") - @model function pair_sub(a, b, x) + @model function loop_sub(a, b, x) a ~ Normal(x, 1) b ~ Normal(a, 1) end - @model function main_loop_multi(x, n) + @model function main_named_loop(x, n) for i in 1:n - (a[i], b[i]) ~ pair_sub(x = x) + (a = out_a[i], b = out_b[i]) ~ loop_sub(x = x) end end - model = create_model(main_loop_multi(n = 3)) do model, ctx + model = create_model(main_named_loop(n = 3)) do model, ctx x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) return (x = x,) end - # Each iteration creates 2 Normals -> 3 * 2 = 6 @test length(collect(filter(as_node(Normal), model))) === 6 - @test length(collect(filter(as_variable(:a), model))) === 3 - @test length(collect(filter(as_variable(:b), model))) === 3 + @test length(collect(filter(as_variable(:out_a), model))) === 3 + @test length(collect(filter(as_variable(:out_b), model))) === 3 end -@testitem "Zero-output submodel in a loop" begin +@testitem "Multi-output submodel: named LHS invalid interface name" begin using Distributions - import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions, VariableKindData + import GraphPPL: create_model, datalabel, NodeCreationOptions include("testutils.jl") - @model function observe_sub(x, y) - y ~ Normal(x, 1) + @model function valid_sub(a, b, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) end - @model function main_loop_zero(x, y, n) - for i in 1:n - ~ observe_sub(x = x[i], y = y[i]) - end + @model function main_bad_name(x) + (z = out_z, b = out_b) ~ valid_sub(x = x) + end + + @test_throws "does not exist in" create_model(main_bad_name()) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) + return (x = x,) end +end - model = create_model(main_loop_zero(n = 4)) do model, ctx - x = datalabel(model, ctx, NodeCreationOptions(kind = VariableKindData), :x) - y = datalabel(model, ctx, NodeCreationOptions(kind = VariableKindData), :y) - return (x = x, y = y) +@testitem "Multi-output submodel: named LHS conflicts with RHS" begin + using Distributions + import GraphPPL: create_model, datalabel, NodeCreationOptions + + include("testutils.jl") + + @model function conflict_sub(a, b, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) end - @test length(collect(filter(as_node(Normal), model))) === 4 - @test length(collect(filter(as_variable(:x), model))) === 4 - @test length(collect(filter(as_variable(:y), model))) === 4 + # 'b' appears as a key on both LHS (meaning "bind out_b to interface b") and as a kwarg on RHS — must error + @model function main_conflict(some_var) + (b = out_b, a = out_a) ~ conflict_sub(b = some_var) + end + + @test_throws "is specified on both LHS and RHS" create_model(main_conflict()) do model, ctx + some_var = datalabel(model, ctx, NodeCreationOptions(kind = :data), :some_var, 1.0) + return (some_var = some_var,) + end end -@testitem "convert_zero_output_tilde pipeline step" begin - import GraphPPL: apply_pipeline, convert_zero_output_tilde +@testitem "Multi-output: is_named_tuple_lhs helper" begin + import GraphPPL: is_named_tuple_lhs + + @test is_named_tuple_lhs(:(a)) === false + @test is_named_tuple_lhs(:((a, b))) === false + @test is_named_tuple_lhs(:((a = x, b = y))) === true + @test is_named_tuple_lhs(:((a = x,))) === true + # Mixed tuple (one named, one positional) is not pure named + @test is_named_tuple_lhs(Expr(:tuple, Expr(:(=), :a, :x), :b)) === false +end + +@testitem "Multi-output: add_get_or_create_expression pipeline step" begin + import GraphPPL: apply_pipeline, add_get_or_create_expression include("testutils.jl") - # Test that unary ~ is converted to binary with __nothing__ - input = Expr(:call, :~, :(my_submodel(x = 1, y = 2))) - output = apply_pipeline(input, convert_zero_output_tilde) - @test output.head == :call - @test output.args[1] == :~ - @test output.args[2] == :__nothing__ - @test output.args[3] == :(my_submodel(x = 1, y = 2)) - - # Test that binary ~ is left unchanged - input2 = :(x ~ Normal(0, 1)) - output2 = apply_pipeline(input2, convert_zero_output_tilde) - @test output2 == input2 + # Positional tuple LHS produces get_or_create blocks for each element + input_positional = quote + (a, b) ~ sub(x = x) where {created_by = ((a, b) ~ sub(x = x))} + end + result = apply_pipeline(input_positional, add_get_or_create_expression) + @test result isa Expr + + # Named tuple LHS produces get_or_create blocks for the outer (value) variables + input_named = quote + (a = m_a, b = m_b) ~ sub(x = x) where {created_by = ((a = m_a, b = m_b) ~ sub(x = x))} + end + result = apply_pipeline(input_named, add_get_or_create_expression) + @test result isa Expr +end + +@testitem "Multi-output: convert_tilde_expression pipeline step" begin + import GraphPPL: apply_pipeline, convert_tilde_expression + + include("testutils.jl") + + # Positional tuple LHS emits a make_node! call + input_positional = quote + (a, b) ~ sub(x = x) where {created_by = ((a, b) ~ sub(x = x))} + end + result_positional = apply_pipeline(input_positional, convert_tilde_expression) + @test result_positional isa Expr + @test occursin("make_node!", string(result_positional)) + + # Named tuple LHS also emits a make_node! call + input_named = quote + (a = m_a, b = m_b) ~ sub(x = x) where {created_by = ((a = m_a, b = m_b) ~ sub(x = x))} + end + result_named = apply_pipeline(input_named, convert_tilde_expression) + @test result_named isa Expr + @test occursin("make_node!", string(result_named)) end diff --git a/test/runtests.jl b/test/runtests.jl index 319b12f1..2bfb95ea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,7 +6,7 @@ end nthreads, ncores = Hwloc.num_virtual_cores(), Hwloc.num_physical_cores() nthreads, ncores = max(nthreads, 1), max(ncores, 1) -nworker_threads = Int(nthreads / ncores) +nworker_threads = max(1, div(nthreads, ncores)) memory_threshold = 1.0 pkg_root = dirname(pathof(GraphPPL)) |> dirname