diff --git a/docs/src/nested_models.md b/docs/src/nested_models.md index 5794eb9e..5a614bed 100644 --- a/docs/src/nested_models.md +++ b/docs/src/nested_models.md @@ -30,4 +30,43 @@ If we want to chain these `gcv` submodels together into a Hierarchical Gaussian end ``` -Note that in our invocations of `gcv`, we haven't specified the `y` argument of the Markov Blanket. This is what is being recognized as the missing interface and `GraphPPL` will assign `means[i]` to `y`. \ No newline at end of file +Note that in our invocations of `gcv`, we haven't specified the `y` argument of the Markov Blanket. This is what is being recognized as the missing interface and `GraphPPL` will assign `means[i]` to `y`. + +## Multi-output submodels + +When a submodel produces multiple outputs — multiple interfaces left unspecified on the RHS — you can bind them all on the LHS using a tuple. There are two syntaxes: + +**Positional:** list outer variables in the same order as the unspecified interfaces appear in the submodel definition. + +``` @example nested-models +@model function linear_gaussian(x, y, z) + x ~ Normal(z, 1) + y ~ Normal(x, 1) +end + +@model function outer_positional(c) + (a, b) ~ linear_gaussian(z = c) # a → interface x, b → interface y (by position) +end +``` + +**Named (kwarg-style):** explicitly map each outer variable to its interface name using `name = var` pairs. This is order-independent and recommended when submodel argument order may change. + +``` @example nested-models +@model function outer_named(my_z) + (y = my_y, x = my_x) ~ linear_gaussian(z = my_z) # binds by name, regardless of order + obs ~ Normal(my_x, my_y) +end +``` + +Both syntaxes work with indexed variables in loops: + +``` @example nested-models +@model function chain(z, n) + for i in 1:n + (x = xs[i], y = ys[i]) ~ linear_gaussian(z = z) + end +end +``` + +!!! note + If the same interface name appears on both LHS and RHS, `GraphPPL` raises an error at model-creation time. Similarly, providing a name on the LHS that does not match any of the submodel's interface names is caught with a descriptive error. \ No newline at end of file diff --git a/src/graph_engine.jl b/src/graph_engine.jl index 101450f0..35ed5a95 100644 --- a/src/graph_engine.jl +++ b/src/graph_engine.jl @@ -1203,6 +1203,7 @@ struct StaticInterfaces{I} end StaticInterfaces(I::Tuple) = StaticInterfaces{I}() Base.getindex(::StaticInterfaces{I}, index) where {I} = I[index] +iface_names(::StaticInterfaces{I}) where {I} = I function Base.convert(::Type{NamedTuple}, ::StaticInterfaces{I}, t::Tuple) where {I} return NamedTuple{I}(t) @@ -1870,6 +1871,37 @@ function prepare_interfaces(::StaticInterfaces{I}, fform::F, lhs_interface, rhs_ return NamedTuple{(missing_interface, keys(rhs_interfaces)...)}((lhs_interface, values(rhs_interfaces)...)) end +# Multi-output: lhs_interfaces is a Tuple of multiple interfaces (positional) +function prepare_interfaces(model::Model, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {F} + n_lhs = length(lhs_interfaces) + missing = missing_interfaces(model, fform, static(length(rhs_interfaces) + n_lhs), rhs_interfaces) + return prepare_interfaces_multi(missing, fform, lhs_interfaces, rhs_interfaces) +end + +function prepare_interfaces_multi(::StaticInterfaces{I}, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {I, F} + all_keys = (I..., keys(rhs_interfaces)...) + all_vals = (lhs_interfaces..., values(rhs_interfaces)...) + return NamedTuple{all_keys}(all_vals) +end + +# Named-output: lhs_interfaces is a NamedTuple (kwarg-style, e.g. (a = m_a, b = m_b) ~ sub(x = x)) +function prepare_interfaces(model::Model, fform::F, lhs_interfaces::NamedTuple, rhs_interfaces::NamedTuple) where {F} + for k in keys(lhs_interfaces) + if k ∈ keys(rhs_interfaces) + error(lazy"Interface ':$(k)' of '$(fform)' is specified on both LHS and RHS.") + end + end + merged = merge(lhs_interfaces, rhs_interfaces) + all_ifaces = interfaces(model, fform, static(length(merged))) + valid_names = iface_names(all_ifaces) + for k in keys(lhs_interfaces) + if k ∉ valid_names + error(lazy"Interface ':$(k)' does not exist in '$(fform)'. Valid interfaces are: $(valid_names).") + end + end + return merged +end + function materialize_interface(model, context, interface) return getifcreated(model, context, unroll(interface)) end @@ -2019,6 +2051,16 @@ make_node!(materialize::True, node_type::Composite, behaviour::Stochastic, model Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + 1) ) +# Multi-output: Tuple LHS for composite nodes (positional) +make_node!(materialize::True, node_type::Composite, behaviour::Stochastic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface::Tuple, rhs_interfaces::NamedTuple) where {F} = make_node!( + Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + length(lhs_interface)) +) + +# Named-output: NamedTuple LHS for composite nodes (kwarg-style) +make_node!(materialize::True, node_type::Composite, behaviour::Stochastic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface::NamedTuple, rhs_interfaces::NamedTuple) where {F} = make_node!( + Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + length(lhs_interface)) +) + """ make_node! diff --git a/src/model_macro.jl b/src/model_macro.jl index 525b2dd1..5a1388f8 100644 --- a/src/model_macro.jl +++ b/src/model_macro.jl @@ -399,15 +399,41 @@ A `quote` block with the modified expression. """ function add_get_or_create_expression(e::Expr) if @capture(e, (lhs_ ~ rhs_ where {options__})) - @capture(lhs, (var_[index__]) | (var_)) - return quote - $(generate_get_or_create(var, index, rhs)) - $e + if lhs isa Expr && lhs.head == :tuple + if is_named_tuple_lhs(lhs) + # Named-output: (a = m_a, b = m_b) ~ sub(...) — get_or_create the outer (RHS) variables + creates = map(lhs.args) do elem + outer_var_expr = elem.args[2] + @capture(outer_var_expr, (var_[index__]) | (var_)) || + error("Invalid named LHS value in $(elem). Value must be in a `var` or `var[index]` form.") + generate_get_or_create(var, index, rhs) + end + else + # Positional-output: (m_a, m_b) ~ sub(...) + creates = map(lhs.args) do elem + @capture(elem, (var_[index__]) | (var_)) || + error("Invalid tuple element on LHS: $(elem). Must be in a `var` or `var[index]` form.") + generate_get_or_create(var, index, rhs) + end + end + return quote + $(creates...) + $e + end + else + @capture(lhs, (var_[index__]) | (var_)) + return quote + $(generate_get_or_create(var, index, rhs)) + $e + end end end return e end +is_named_tuple_lhs(lhs) = false +is_named_tuple_lhs(lhs::Expr) = lhs.head === :tuple && !isempty(lhs.args) && all(elem -> elem isa Expr && elem.head === :(=), lhs.args) + what_walk(::typeof(add_get_or_create_expression)) = not_created_by """ @@ -647,13 +673,51 @@ function convert_tilde_expression(e::Expr) options = GraphPPL.options_vector_to_named_tuple(options) nodesym = gensym(:node) varsym = gensym(:var) - @capture(lhs, (var_[index__]) | (var_)) || error("Invalid left-hand side $(lhs). Must be in a `var` or `var[index]` form.") - return quote - begin - $nodesym, $varsym = GraphPPL.make_node!( - __model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $(generate_lhs_proxylabel(var, index)), $args - ) - $varsym + if lhs isa Expr && lhs.head == :tuple + if is_named_tuple_lhs(lhs) + # Named-output: (a = m_a, b = m_b) ~ sub(...) — build a NamedTuple of proxy labels + proxy_pairs = map(lhs.args) do elem + iface_name = elem.args[1] + outer_var_expr = elem.args[2] + @capture(outer_var_expr, (var_[index__]) | (var_)) || + error("Invalid named LHS value in $(elem). Value must be in a `var` or `var[index]` form.") + Expr(:(=), iface_name, generate_lhs_proxylabel(var, index)) + end + lhs_named = Expr(:tuple, proxy_pairs...) + return quote + begin + $nodesym, $varsym = GraphPPL.make_node!( + __model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $lhs_named, $args + ) + $varsym + end + end + else + # Positional-output: (m_a, m_b) ~ sub(...) — build a plain Tuple of proxy labels + proxy_labels = map(lhs.args) do elem + @capture(elem, (var_[index__]) | (var_)) || + error("Invalid tuple element on LHS: $(elem). Must be in a `var` or `var[index]` form.") + generate_lhs_proxylabel(var, index) + end + lhs_tuple = Expr(:tuple, proxy_labels...) + return quote + begin + $nodesym, $varsym = GraphPPL.make_node!( + __model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $lhs_tuple, $args + ) + $varsym + end + end + end + else + @capture(lhs, (var_[index__]) | (var_)) || error("Invalid left-hand side $(lhs). Must be in a `var` or `var[index]` form.") + return quote + begin + $nodesym, $varsym = GraphPPL.make_node!( + __model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $(generate_lhs_proxylabel(var, index)), $args + ) + $varsym + end end end elseif @capture(e, (lhs_ .~ fform_(args__; kwargs__) where {options__}) | (lhs_ .~ fform_(args__) where {options__})) @@ -781,6 +845,50 @@ function get_make_node_function(model_specification, ms_body, ms_args, ms_name) return __context__, __lhs_interface__ end + # Multi-output: Tuple LHS (positional) + function GraphPPL.make_node!( + ::GraphPPL.Composite, + __model__::GraphPPL.Model, + __parent_context__::GraphPPL.Context, + __options__::GraphPPL.NodeCreationOptions, + ::typeof($ms_name), + __lhs_interface__::Tuple, + __rhs_interfaces__::NamedTuple, + __n_interfaces__::GraphPPL.StaticInt{$(length(ms_args))} + ) + __interfaces__ = GraphPPL.prepare_interfaces(__model__, $ms_name, __lhs_interface__, __rhs_interfaces__) + __context__ = GraphPPL.Context(__parent_context__, $ms_name) + GraphPPL.copy_markov_blanket_to_child_context(__context__, __interfaces__) + GraphPPL.add_composite_factor_node!(__model__, __parent_context__, __context__, $ms_name) + __returnval__ = GraphPPL.add_terminated_submodel!( + __model__, __context__, __options__, $ms_name, __interfaces__, __n_interfaces__ + ) + GraphPPL.returnval!(__context__, __returnval__) + return __context__, __lhs_interface__ + end + + # Named-output: NamedTuple LHS (kwarg-style) + function GraphPPL.make_node!( + ::GraphPPL.Composite, + __model__::GraphPPL.Model, + __parent_context__::GraphPPL.Context, + __options__::GraphPPL.NodeCreationOptions, + ::typeof($ms_name), + __lhs_interface__::NamedTuple, + __rhs_interfaces__::NamedTuple, + __n_interfaces__::GraphPPL.StaticInt{$(length(ms_args))} + ) + __interfaces__ = GraphPPL.prepare_interfaces(__model__, $ms_name, __lhs_interface__, __rhs_interfaces__) + __context__ = GraphPPL.Context(__parent_context__, $ms_name) + GraphPPL.copy_markov_blanket_to_child_context(__context__, __interfaces__) + GraphPPL.add_composite_factor_node!(__model__, __parent_context__, __context__, $ms_name) + __returnval__ = GraphPPL.add_terminated_submodel!( + __model__, __context__, __options__, $ms_name, __interfaces__, __n_interfaces__ + ) + GraphPPL.returnval!(__context__, __returnval__) + return __context__, __lhs_interface__ + end + function GraphPPL.add_terminated_submodel!( __model__::GraphPPL.Model, __context__::GraphPPL.Context, diff --git a/test/multi_and_zero_output_tests.jl b/test/multi_and_zero_output_tests.jl new file mode 100644 index 00000000..229a9729 --- /dev/null +++ b/test/multi_and_zero_output_tests.jl @@ -0,0 +1,285 @@ +@testitem "Multi-output submodel: two outputs" begin + using Distributions + import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions + + include("testutils.jl") + + @model function two_output_sub(a, b, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) + end + + @model function main_multi_two(x) + (a, b) ~ two_output_sub(x = x) + y ~ Normal(a, b) + end + + model = create_model(main_multi_two()) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) + return (x = x,) + end + + # The submodel has 1 Normal inside, plus the outer model has 1 Normal => 3 total Normal nodes + # (two_output_sub creates 2 Normals, main creates 1) + @test length(collect(filter(as_node(Normal), model))) === 3 + @test length(collect(filter(as_variable(:a), model))) === 1 + @test length(collect(filter(as_variable(:b), model))) === 1 + @test length(collect(filter(as_variable(:x), model))) === 1 +end + +@testitem "Multi-output submodel: three outputs" begin + using Distributions + import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions + + include("testutils.jl") + + @model function three_output_sub(a, b, c, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) + c ~ Normal(b, 1) + end + + @model function main_multi_three(x) + (a, b, c) ~ three_output_sub(x = x) + y ~ Normal(a + b, c) + end + + model = create_model(main_multi_three()) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) + return (x = x,) + end + + # three_output_sub: 3 Normals, main: 1 Normal => 4 Normals total + @test length(collect(filter(as_node(Normal), model))) === 4 + @test length(collect(filter(as_variable(:a), model))) === 1 + @test length(collect(filter(as_variable(:b), model))) === 1 + @test length(collect(filter(as_variable(:c), model))) === 1 +end + +@testitem "Multi-output submodel: wrong number of LHS variables" begin + using Distributions + import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions + + include("testutils.jl") + + @model function two_iface_sub(a, b, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) + end + + # Only one missing interface (a) but providing two on LHS + @model function main_mismatch(x) + b ~ Normal(0, 1) + (p, q) ~ two_iface_sub(x = x, b = b) + end + + @test_throws "no method matching make_node!" create_model(main_mismatch()) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) + return (x = x,) + end +end + +@testitem "Multi-output submodel in a loop" begin + using Distributions + import GraphPPL: create_model, getorcreate!, datalabel, NodeCreationOptions, VariableKindData + + include("testutils.jl") + + @model function pair_sub(a, b, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) + end + + @model function main_loop_multi(x, n) + for i in 1:n + (a[i], b[i]) ~ pair_sub(x = x) + end + end + + model = create_model(main_loop_multi(n = 3)) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) + return (x = x,) + end + + # Each iteration creates 2 Normals -> 3 * 2 = 6 + @test length(collect(filter(as_node(Normal), model))) === 6 + @test length(collect(filter(as_variable(:a), model))) === 3 + @test length(collect(filter(as_variable(:b), model))) === 3 +end + +@testitem "Multi-output submodel: named LHS basic" begin + using Distributions + import GraphPPL: create_model, datalabel, NodeCreationOptions + + include("testutils.jl") + + @model function named_two_sub(a, b, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) + end + + @model function main_named_two(x) + (a = out_a, b = out_b) ~ named_two_sub(x = x) + y ~ Normal(out_a, out_b) + end + + model = create_model(main_named_two()) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) + return (x = x,) + end + + @test length(collect(filter(as_node(Normal), model))) === 3 + @test length(collect(filter(as_variable(:out_a), model))) === 1 + @test length(collect(filter(as_variable(:out_b), model))) === 1 + @test length(collect(filter(as_variable(:x), model))) === 1 +end + +@testitem "Multi-output submodel: named LHS order-invariant" begin + using Distributions + import GraphPPL: create_model, datalabel, NodeCreationOptions + + include("testutils.jl") + + @model function ordered_sub(a, b, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) + end + + # Provide b before a on the LHS — must still bind correctly by name + @model function main_swapped(x) + (b = out_b, a = out_a) ~ ordered_sub(x = x) + y ~ Normal(out_a, out_b) + end + + model = create_model(main_swapped()) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) + return (x = x,) + end + + @test length(collect(filter(as_node(Normal), model))) === 3 + @test length(collect(filter(as_variable(:out_a), model))) === 1 + @test length(collect(filter(as_variable(:out_b), model))) === 1 +end + +@testitem "Multi-output submodel: named LHS in a loop" begin + using Distributions + import GraphPPL: create_model, datalabel, NodeCreationOptions + + include("testutils.jl") + + @model function loop_sub(a, b, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) + end + + @model function main_named_loop(x, n) + for i in 1:n + (a = out_a[i], b = out_b[i]) ~ loop_sub(x = x) + end + end + + model = create_model(main_named_loop(n = 3)) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) + return (x = x,) + end + + @test length(collect(filter(as_node(Normal), model))) === 6 + @test length(collect(filter(as_variable(:out_a), model))) === 3 + @test length(collect(filter(as_variable(:out_b), model))) === 3 +end + +@testitem "Multi-output submodel: named LHS invalid interface name" begin + using Distributions + import GraphPPL: create_model, datalabel, NodeCreationOptions + + include("testutils.jl") + + @model function valid_sub(a, b, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) + end + + @model function main_bad_name(x) + (z = out_z, b = out_b) ~ valid_sub(x = x) + end + + @test_throws "does not exist in" create_model(main_bad_name()) do model, ctx + x = datalabel(model, ctx, NodeCreationOptions(kind = :data), :x, 1.0) + return (x = x,) + end +end + +@testitem "Multi-output submodel: named LHS conflicts with RHS" begin + using Distributions + import GraphPPL: create_model, datalabel, NodeCreationOptions + + include("testutils.jl") + + @model function conflict_sub(a, b, x) + a ~ Normal(x, 1) + b ~ Normal(a, 1) + end + + # 'b' appears as a key on both LHS (meaning "bind out_b to interface b") and as a kwarg on RHS — must error + @model function main_conflict(some_var) + (b = out_b, a = out_a) ~ conflict_sub(b = some_var) + end + + @test_throws "is specified on both LHS and RHS" create_model(main_conflict()) do model, ctx + some_var = datalabel(model, ctx, NodeCreationOptions(kind = :data), :some_var, 1.0) + return (some_var = some_var,) + end +end + +@testitem "Multi-output: is_named_tuple_lhs helper" begin + import GraphPPL: is_named_tuple_lhs + + @test is_named_tuple_lhs(:(a)) === false + @test is_named_tuple_lhs(:((a, b))) === false + @test is_named_tuple_lhs(:((a = x, b = y))) === true + @test is_named_tuple_lhs(:((a = x,))) === true + # Mixed tuple (one named, one positional) is not pure named + @test is_named_tuple_lhs(Expr(:tuple, Expr(:(=), :a, :x), :b)) === false +end + +@testitem "Multi-output: add_get_or_create_expression pipeline step" begin + import GraphPPL: apply_pipeline, add_get_or_create_expression + + include("testutils.jl") + + # Positional tuple LHS produces get_or_create blocks for each element + input_positional = quote + (a, b) ~ sub(x = x) where {created_by = ((a, b) ~ sub(x = x))} + end + result = apply_pipeline(input_positional, add_get_or_create_expression) + @test result isa Expr + + # Named tuple LHS produces get_or_create blocks for the outer (value) variables + input_named = quote + (a = m_a, b = m_b) ~ sub(x = x) where {created_by = ((a = m_a, b = m_b) ~ sub(x = x))} + end + result = apply_pipeline(input_named, add_get_or_create_expression) + @test result isa Expr +end + +@testitem "Multi-output: convert_tilde_expression pipeline step" begin + import GraphPPL: apply_pipeline, convert_tilde_expression + + include("testutils.jl") + + # Positional tuple LHS emits a make_node! call + input_positional = quote + (a, b) ~ sub(x = x) where {created_by = ((a, b) ~ sub(x = x))} + end + result_positional = apply_pipeline(input_positional, convert_tilde_expression) + @test result_positional isa Expr + @test occursin("make_node!", string(result_positional)) + + # Named tuple LHS also emits a make_node! call + input_named = quote + (a = m_a, b = m_b) ~ sub(x = x) where {created_by = ((a = m_a, b = m_b) ~ sub(x = x))} + end + result_named = apply_pipeline(input_named, convert_tilde_expression) + @test result_named isa Expr + @test occursin("make_node!", string(result_named)) +end diff --git a/test/runtests.jl b/test/runtests.jl index 319b12f1..2bfb95ea 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,7 +6,7 @@ end nthreads, ncores = Hwloc.num_virtual_cores(), Hwloc.num_physical_cores() nthreads, ncores = max(nthreads, 1), max(ncores, 1) -nworker_threads = Int(nthreads / ncores) +nworker_threads = max(1, div(nthreads, ncores)) memory_threshold = 1.0 pkg_root = dirname(pathof(GraphPPL)) |> dirname