Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/backends/default.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ struct DefaultBackend end
function GraphPPL.model_macro_interior_pipelines(::DefaultBackend)
return (
GraphPPL.check_reserved_variable_names_model,
GraphPPL.convert_zero_output_tilde,
GraphPPL.save_expression_in_tilde,
GraphPPL.convert_deterministic_statement,
GraphPPL.convert_local_statement,
Expand Down
57 changes: 57 additions & 0 deletions src/graph_engine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,14 @@ Base.show(io::IO, label::EdgeLabel) = print(io, to_symbol(label))
Base.:(==)(label1::EdgeLabel, label2::EdgeLabel) = label1.name == label2.name && label1.index == label2.index
Base.hash(label::EdgeLabel, h::UInt) = hash(label.name, hash(label.index, h))

"""
NothingInterface

A sentinel type used to represent a zero-output submodel invocation (no LHS variable).
When used as the `lhs_interface` in `make_node!`, all interfaces must be specified on the RHS.
"""
struct NothingInterface end

"""
Splat{T}

Expand Down Expand Up @@ -1863,6 +1871,41 @@ function prepare_interfaces(::StaticInterfaces{I}, fform::F, lhs_interface, rhs_
return NamedTuple{(missing_interface, keys(rhs_interfaces)...)}((lhs_interface, values(rhs_interfaces)...))
end

# Multi-output: lhs_interfaces is a Tuple of multiple interfaces
function prepare_interfaces(model::Model, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {F}
n_lhs = length(lhs_interfaces)
missing = missing_interfaces(model, fform, static(length(rhs_interfaces) + n_lhs), rhs_interfaces)
return prepare_interfaces_multi(missing, fform, lhs_interfaces, rhs_interfaces)
end

function prepare_interfaces_multi(::StaticInterfaces{I}, fform::F, lhs_interfaces::Tuple, rhs_interfaces::NamedTuple) where {I, F}
if length(I) != length(lhs_interfaces)
n = "\n"
error(
lazy"Node '$(fform)' has $(length(I)) missing interface(s) $(I) but $(length(lhs_interfaces)) were provided on the LHS.$(n)$(n)Currently specified interfaces are: $(keys(rhs_interfaces))."
)
end
all_keys = (I..., keys(rhs_interfaces)...)
all_vals = (lhs_interfaces..., values(rhs_interfaces)...)
return NamedTuple{all_keys}(all_vals)
end

# Zero-output: NothingInterface means all interfaces are on the RHS
function prepare_interfaces(model::Model, fform::F, ::NothingInterface, rhs_interfaces::NamedTuple) where {F}
missing = missing_interfaces(model, fform, static(length(rhs_interfaces)), rhs_interfaces)
return prepare_interfaces_zero(missing, fform, rhs_interfaces)
end

function prepare_interfaces_zero(::StaticInterfaces{I}, fform::F, rhs_interfaces::NamedTuple) where {I, F}
if length(I) != 0
n = "\n"
error(
lazy"Zero-output call to '$(fform)' but $(length(I)) interface(s) are still missing: $(I).$(n)$(n)Currently specified interfaces are: $(keys(rhs_interfaces)). All interfaces must be provided for a zero-output submodel call."
)
end
return rhs_interfaces
end

function materialize_interface(model, context, interface)
return getifcreated(model, context, unroll(interface))
end
Expand Down Expand Up @@ -2012,6 +2055,20 @@ make_node!(materialize::True, node_type::Composite, behaviour::Stochastic, model
Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + 1)
)

# Multi-output: Tuple LHS for composite nodes
make_node!(materialize::True, node_type::Composite, behaviour::Stochastic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface::Tuple, rhs_interfaces::NamedTuple) where {F} = make_node!(
Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces) + length(lhs_interface))
)

# Zero-output: NothingInterface LHS for composite nodes
make_node!(nodetype::Composite, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface::NothingInterface, rhs_interfaces) where {F} = make_node!(
True(), nodetype, Stochastic(), model, ctx, options, fform, lhs_interface, rhs_interfaces
)

make_node!(materialize::True, node_type::Composite, behaviour::Stochastic, model::Model, ctx::Context, options::NodeCreationOptions, fform::F, lhs_interface::NothingInterface, rhs_interfaces::NamedTuple) where {F} = make_node!(
Composite(), model, ctx, options, fform, lhs_interface, rhs_interfaces, static(length(rhs_interfaces))
)

"""
make_node!

Expand Down
126 changes: 114 additions & 12 deletions src/model_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ function check_reserved_variable_names_model(e::Expr)
:(__lhs_interface__),
:(__rhs_interfaces__),
:(__interfaces__),
:(__n_interfaces__)
:(__n_interfaces__),
:(__nothing__)
]
)
error("Variable name in $(prettify(e)) cannot be used as it is a reserved variable name in the model macro.")
Expand All @@ -166,6 +167,21 @@ end

what_walk(::typeof(check_incomplete_factorization_constraint)) = walk_until_occurrence((:(lhs_ = rhs_), :(lhs_::rhs_)))

"""
convert_zero_output_tilde(e::Expr)

Converts a zero-output submodel call using unary `~` syntax to a binary `~` with `__nothing__` as the LHS.
This transforms `~ submodel(x = val1, y = val2)` into `__nothing__ ~ submodel(x = val1, y = val2)`.
"""
function convert_zero_output_tilde(e::Expr)
# Unary ~ is parsed as Expr(:call, :~, rhs)
if e.head == :call && length(e.args) == 2 && e.args[1] == :~
rhs = e.args[2]
return Expr(:call, :~, :__nothing__, rhs)
end
return e
end

"""
save_expression_in_tilde(expr::Expr)

Expand Down Expand Up @@ -399,10 +415,25 @@ A `quote` block with the modified expression.
"""
function add_get_or_create_expression(e::Expr)
if @capture(e, (lhs_ ~ rhs_ where {options__}))
@capture(lhs, (var_[index__]) | (var_))
return quote
$(generate_get_or_create(var, index, rhs))
$e
if lhs === :__nothing__
# Zero-output: no variable to get or create
return e
elseif lhs isa Expr && lhs.head == :tuple
# Multi-output: get_or_create each element in the tuple
creates = map(lhs.args) do elem
@capture(elem, (var_[index__]) | (var_)) || error("Invalid tuple element on LHS: $(elem). Must be in a `var` or `var[index]` form.")
generate_get_or_create(var, index, rhs)
end
return quote
$(creates...)
$e
end
else
@capture(lhs, (var_[index__]) | (var_))
return quote
$(generate_get_or_create(var, index, rhs))
$e
end
end
end
return e
Expand Down Expand Up @@ -647,13 +678,40 @@ function convert_tilde_expression(e::Expr)
options = GraphPPL.options_vector_to_named_tuple(options)
nodesym = gensym(:node)
varsym = gensym(:var)
@capture(lhs, (var_[index__]) | (var_)) || error("Invalid left-hand side $(lhs). Must be in a `var` or `var[index]` form.")
return quote
begin
$nodesym, $varsym = GraphPPL.make_node!(
__model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $(generate_lhs_proxylabel(var, index)), $args
)
$varsym
if lhs === :__nothing__
# Zero-output: pass NothingInterface as lhs
return quote
begin
$nodesym, $varsym = GraphPPL.make_node!(
__model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, GraphPPL.NothingInterface(), $args
)
$varsym
end
end
elseif lhs isa Expr && lhs.head == :tuple
# Multi-output: generate a tuple of proxy labels for the LHS
proxy_labels = map(lhs.args) do elem
@capture(elem, (var_[index__]) | (var_)) || error("Invalid tuple element on LHS: $(elem). Must be in a `var` or `var[index]` form.")
generate_lhs_proxylabel(var, index)
end
lhs_tuple = Expr(:tuple, proxy_labels...)
return quote
begin
$nodesym, $varsym = GraphPPL.make_node!(
__model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $lhs_tuple, $args
)
$varsym
end
end
else
@capture(lhs, (var_[index__]) | (var_)) || error("Invalid left-hand side $(lhs). Must be in a `var` or `var[index]` form.")
return quote
begin
$nodesym, $varsym = GraphPPL.make_node!(
__model__, __context__, GraphPPL.NodeCreationOptions($(options)), $fform, $(generate_lhs_proxylabel(var, index)), $args
)
$varsym
end
end
end
elseif @capture(e, (lhs_ .~ fform_(args__; kwargs__) where {options__}) | (lhs_ .~ fform_(args__) where {options__}))
Expand Down Expand Up @@ -781,6 +839,50 @@ function get_make_node_function(model_specification, ms_body, ms_args, ms_name)
return __context__, __lhs_interface__
end

# Multi-output: Tuple LHS
function GraphPPL.make_node!(
::GraphPPL.Composite,
__model__::GraphPPL.Model,
__parent_context__::GraphPPL.Context,
__options__::GraphPPL.NodeCreationOptions,
::typeof($ms_name),
__lhs_interface__::Tuple,
__rhs_interfaces__::NamedTuple,
__n_interfaces__::GraphPPL.StaticInt{$(length(ms_args))}
)
__interfaces__ = GraphPPL.prepare_interfaces(__model__, $ms_name, __lhs_interface__, __rhs_interfaces__)
__context__ = GraphPPL.Context(__parent_context__, $ms_name)
GraphPPL.copy_markov_blanket_to_child_context(__context__, __interfaces__)
GraphPPL.add_composite_factor_node!(__model__, __parent_context__, __context__, $ms_name)
__returnval__ = GraphPPL.add_terminated_submodel!(
__model__, __context__, __options__, $ms_name, __interfaces__, __n_interfaces__
)
GraphPPL.returnval!(__context__, __returnval__)
return __context__, __lhs_interface__
end

# Zero-output: NothingInterface LHS
function GraphPPL.make_node!(
::GraphPPL.Composite,
__model__::GraphPPL.Model,
__parent_context__::GraphPPL.Context,
__options__::GraphPPL.NodeCreationOptions,
::typeof($ms_name),
__lhs_interface__::GraphPPL.NothingInterface,
__rhs_interfaces__::NamedTuple,
__n_interfaces__::GraphPPL.StaticInt{$(length(ms_args))}
)
__interfaces__ = GraphPPL.prepare_interfaces(__model__, $ms_name, __lhs_interface__, __rhs_interfaces__)
__context__ = GraphPPL.Context(__parent_context__, $ms_name)
GraphPPL.copy_markov_blanket_to_child_context(__context__, __interfaces__)
GraphPPL.add_composite_factor_node!(__model__, __parent_context__, __context__, $ms_name)
__returnval__ = GraphPPL.add_terminated_submodel!(
__model__, __context__, __options__, $ms_name, __interfaces__, __n_interfaces__
)
GraphPPL.returnval!(__context__, __returnval__)
return __context__, __lhs_interface__
end

function GraphPPL.add_terminated_submodel!(
__model__::GraphPPL.Model,
__context__::GraphPPL.Context,
Expand Down
Loading
Loading