Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion docs/src/nested_models.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,4 +30,43 @@ If we want to chain these `gcv` submodels together into a Hierarchical Gaussian
end
```

Note that in our invocations of `gcv`, we haven't specified the `y` argument of the Markov Blanket. This is what is being recognized as the missing interface and `GraphPPL` will assign `means[i]` to `y`.
Note that in our invocations of `gcv`, we haven't specified the `y` argument of the Markov Blanket. This is what is being recognized as the missing interface and `GraphPPL` will assign `means[i]` to `y`.

## Multi-output submodels

When a submodel produces multiple outputs — multiple interfaces left unspecified on the RHS — you can bind them all on the LHS using a tuple. There are two syntaxes:

**Positional:** list outer variables in the same order as the unspecified interfaces appear in the submodel definition.

``` @example nested-models
@model function linear_gaussian(x, y, z)
x ~ Normal(z, 1)
y ~ Normal(x, 1)
end

@model function outer_positional(c)
(a, b) ~ linear_gaussian(z = c) # a → interface x, b → interface y (by position)
end
```

**Named (kwarg-style):** explicitly map each outer variable to its interface name using `name = var` pairs. This is order-independent and recommended when submodel argument order may change.

``` @example nested-models
@model function outer_named(my_z)
(y = my_y, x = my_x) ~ linear_gaussian(z = my_z) # binds by name, regardless of order
obs ~ Normal(my_x, my_y)
end
```

Both syntaxes work with indexed variables in loops:

``` @example nested-models
@model function chain(z, n)
for i in 1:n
(x = xs[i], y = ys[i]) ~ linear_gaussian(z = z)
end
end
```

!!! note
If the same interface name appears on both LHS and RHS, `GraphPPL` raises an error at model-creation time. Similarly, providing a name on the LHS that does not match any of the submodel's interface names is caught with a descriptive error.
42 changes: 42 additions & 0 deletions src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1203,6 +1203,7 @@ struct StaticInterfaces{I} end

StaticInterfaces(I::Tuple) = StaticInterfaces{I}()
Base.getindex(::StaticInterfaces{I}, index) where {I} = I[index]
iface_names(::StaticInterfaces{I}) where {I} = I

function Base.convert(::Type{NamedTuple}, ::StaticInterfaces{I}, t::Tuple) where {I}
return NamedTuple{I}(t)
Expand Down Expand Up @@ -1870,6 +1871,37 @@ function prepare_interfaces(::StaticInterfaces{I}, fform::F, lhs_interface, rhs_
return NamedTuple{(missing_interface, keys(rhs_interfaces)...)}((lhs_interface, values(rhs_interfaces)...))
end

# Multi-output: lhs_interfaces is a Tuple of multiple interfaces (positional)
function prepare_interfaces(model::Model, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {F}
n_lhs = length(lhs_interfaces)
missing = missing_interfaces(model, fform, static(length(rhs_interfaces) + n_lhs), rhs_interfaces)
return prepare_interfaces_multi(missing, fform, lhs_interfaces, rhs_interfaces)
end

function prepare_interfaces_multi(::StaticInterfaces{I}, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {I, F}
all_keys = (I..., keys(rhs_interfaces)...)
all_vals = (lhs_interfaces..., values(rhs_interfaces)...)
return NamedTuple{all_keys}(all_vals)
end

# Named-output: lhs_interfaces is a NamedTuple (kwarg-style, e.g. (a = m_a, b = m_b) ~ sub(x = x))
function prepare_interfaces(model::Model, fform::F, lhs_interfaces::NamedTuple, rhs_interfaces::NamedTuple) where {F}
for k in keys(lhs_interfaces)
if k ∈ keys(rhs_interfaces)
error(lazy"Interface ':$(k)' of '$(fform)' is specified on both LHS and RHS.")
end
end
merged = merge(lhs_interfaces, rhs_interfaces)
all_ifaces = interfaces(model, fform, static(length(merged)))
valid_names = iface_names(all_ifaces)
for k in keys(lhs_interfaces)
if k ∉ valid_names
error(lazy"Interface ':$(k)' does not exist in '$(fform)'. Valid interfaces are: $(valid_names).")
end
end
return merged
end

function materialize_interface(model, context, interface)
return getifcreated(model, context, unroll(interface))
end
Expand Down Expand Up @@ -2019,6 +2051,16 @@ make_node!(materialize::True, node_type::Composite, behaviour::Stochastic, model
Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + 1)
)

# Multi-output: Tuple LHS for composite nodes (positional)
make_node!(materialize::True, node_type::Composite, behaviour::Stochastic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface::Tuple, rhs_interfaces::NamedTuple) where {F} = make_node!(
Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + length(lhs_interface))
)

# Named-output: NamedTuple LHS for composite nodes (kwarg-style)
make_node!(materialize::True, node_type::Composite, behaviour::Stochastic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface::NamedTuple, rhs_interfaces::NamedTuple) where {F} = make_node!(
Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + length(lhs_interface))
)

"""
make_node!

Expand Down
130 changes: 119 additions & 11 deletions src/model_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -399,15 +399,41 @@ A `quote` block with the modified expression.
"""
function add_get_or_create_expression(e::Expr)
if @capture(e, (lhs_ ~ rhs_ where {options__}))
@capture(lhs, (var_[index__]) | (var_))
return quote
$(generate_get_or_create(var, index, rhs))
$e
if lhs isa Expr && lhs.head == :tuple
if is_named_tuple_lhs(lhs)
# Named-output: (a = m_a, b = m_b) ~ sub(...) — get_or_create the outer (RHS) variables
creates = map(lhs.args) do elem
outer_var_expr = elem.args[2]
@capture(outer_var_expr, (var_[index__]) | (var_)) ||
error("Invalid named LHS value in $(elem). Value must be in a `var` or `var[index]` form.")
generate_get_or_create(var, index, rhs)
end
else
# Positional-output: (m_a, m_b) ~ sub(...)
creates = map(lhs.args) do elem
@capture(elem, (var_[index__]) | (var_)) ||
error("Invalid tuple element on LHS: $(elem). Must be in a `var` or `var[index]` form.")
generate_get_or_create(var, index, rhs)
end
end
return quote
$(creates...)
$e
end
else
@capture(lhs, (var_[index__]) | (var_))
return quote
$(generate_get_or_create(var, index, rhs))
$e
end
end
end
return e
end

is_named_tuple_lhs(lhs) = false
is_named_tuple_lhs(lhs::Expr) = lhs.head === :tuple && !isempty(lhs.args) && all(elem -> elem isa Expr && elem.head === :(=), lhs.args)

what_walk(::typeof(add_get_or_create_expression)) = not_created_by

"""
Expand Down Expand Up @@ -647,13 +673,51 @@ function convert_tilde_expression(e::Expr)
options = GraphPPL.options_vector_to_named_tuple(options)
nodesym = gensym(:node)
varsym = gensym(:var)
@capture(lhs, (var_[index__]) | (var_)) || error("Invalid left-hand side $(lhs). Must be in a `var` or `var[index]` form.")
return quote
begin
$nodesym, $varsym = GraphPPL.make_node!(
__model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $(generate_lhs_proxylabel(var, index)), $args
)
$varsym
if lhs isa Expr && lhs.head == :tuple
if is_named_tuple_lhs(lhs)
# Named-output: (a = m_a, b = m_b) ~ sub(...) — build a NamedTuple of proxy labels
proxy_pairs = map(lhs.args) do elem
iface_name = elem.args[1]
outer_var_expr = elem.args[2]
@capture(outer_var_expr, (var_[index__]) | (var_)) ||
error("Invalid named LHS value in $(elem). Value must be in a `var` or `var[index]` form.")
Expr(:(=), iface_name, generate_lhs_proxylabel(var, index))
end
lhs_named = Expr(:tuple, proxy_pairs...)
return quote
begin
$nodesym, $varsym = GraphPPL.make_node!(
__model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $lhs_named, $args
)
$varsym
end
end
else
# Positional-output: (m_a, m_b) ~ sub(...) — build a plain Tuple of proxy labels
proxy_labels = map(lhs.args) do elem
@capture(elem, (var_[index__]) | (var_)) ||
error("Invalid tuple element on LHS: $(elem). Must be in a `var` or `var[index]` form.")
generate_lhs_proxylabel(var, index)
end
lhs_tuple = Expr(:tuple, proxy_labels...)
return quote
begin
$nodesym, $varsym = GraphPPL.make_node!(
__model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $lhs_tuple, $args
)
$varsym
end
end
end
else
@capture(lhs, (var_[index__]) | (var_)) || error("Invalid left-hand side $(lhs). Must be in a `var` or `var[index]` form.")
return quote
begin
$nodesym, $varsym = GraphPPL.make_node!(
__model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $(generate_lhs_proxylabel(var, index)), $args
)
$varsym
end
end
end
elseif @capture(e, (lhs_ .~ fform_(args__; kwargs__) where {options__}) | (lhs_ .~ fform_(args__) where {options__}))
Expand Down Expand Up @@ -781,6 +845,50 @@ function get_make_node_function(model_specification, ms_body, ms_args, ms_name)
return __context__, __lhs_interface__
end

# Multi-output: Tuple LHS (positional)
function GraphPPL.make_node!(
::GraphPPL.Composite,
__model__::GraphPPL.Model,
__parent_context__::GraphPPL.Context,
__options__::GraphPPL.NodeCreationOptions,
::typeof($ms_name),
__lhs_interface__::Tuple,
__rhs_interfaces__::NamedTuple,
__n_interfaces__::GraphPPL.StaticInt{$(length(ms_args))}
)
__interfaces__ = GraphPPL.prepare_interfaces(__model__, $ms_name, __lhs_interface__, __rhs_interfaces__)
__context__ = GraphPPL.Context(__parent_context__, $ms_name)
GraphPPL.copy_markov_blanket_to_child_context(__context__, __interfaces__)
GraphPPL.add_composite_factor_node!(__model__, __parent_context__, __context__, $ms_name)
__returnval__ = GraphPPL.add_terminated_submodel!(
__model__, __context__, __options__, $ms_name, __interfaces__, __n_interfaces__
)
GraphPPL.returnval!(__context__, __returnval__)
return __context__, __lhs_interface__
end

# Named-output: NamedTuple LHS (kwarg-style)
function GraphPPL.make_node!(
::GraphPPL.Composite,
__model__::GraphPPL.Model,
__parent_context__::GraphPPL.Context,
__options__::GraphPPL.NodeCreationOptions,
::typeof($ms_name),
__lhs_interface__::NamedTuple,
__rhs_interfaces__::NamedTuple,
__n_interfaces__::GraphPPL.StaticInt{$(length(ms_args))}
)
__interfaces__ = GraphPPL.prepare_interfaces(__model__, $ms_name, __lhs_interface__, __rhs_interfaces__)
__context__ = GraphPPL.Context(__parent_context__, $ms_name)
GraphPPL.copy_markov_blanket_to_child_context(__context__, __interfaces__)
GraphPPL.add_composite_factor_node!(__model__, __parent_context__, __context__, $ms_name)
__returnval__ = GraphPPL.add_terminated_submodel!(
__model__, __context__, __options__, $ms_name, __interfaces__, __n_interfaces__
)
GraphPPL.returnval!(__context__, __returnval__)
return __context__, __lhs_interface__
end

function GraphPPL.add_terminated_submodel!(
__model__::GraphPPL.Model,
__context__::GraphPPL.Context,
Expand Down
Loading
Loading