Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion examples/hybrid/plane/bubble_2d_invariant_rhoe.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,6 @@ function rhs_invariant!(dY, Y, _, t)
dρ .= 0 .* cρ

cw = If2c.(fw)
fuₕ = Ic2f.(cuₕ)
cuw = Geometry.Covariant13Vector.(cuₕ) .+ Geometry.Covariant13Vector.(cw)

ce = @. cρe / cρ
Expand Down
2 changes: 1 addition & 1 deletion src/ClimaCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ const VERSION = PkgVersion.@Version
import ClimaComms

include("DebugOnly/DebugOnly.jl")
include("Utilities/Utilities.jl")
include("interface.jl")
include("devices.jl")
include("Utilities/Utilities.jl")
include("RecursiveApply/RecursiveApply.jl")
include("DataLayouts/DataLayouts.jl")
include("Geometry/Geometry.jl")
Expand Down
6 changes: 4 additions & 2 deletions src/DataLayouts/DataLayouts.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ import MultiBroadcastFusion as MBF
import Adapt
using UnrolledUtilities

import ..Utilities.Unrolled:
unrolled_setindex, unrolled_insert, unrolled_map_with_inbounds
import ..Utilities:
PlusHalf, unionall_type, replace_type_parameter, fieldtype_vals
import ..DebugOnly: call_post_op_callback, post_op_callback
Expand Down Expand Up @@ -346,8 +348,8 @@ end
function replace_storage(data::AbstractData, ::Type{S}, ::Type{T}) where {S, T}
D = field_dim(singleton(data))
params = Base.tail(type_params(data))
new_array_size = Base.setindex(size(parent(data)), num_basetypes(T, S), D)
new_array = similar(parent(data), T, new_array_size...)
new_size = unrolled_setindex(size(parent(data)), num_basetypes(T, S), Val(D))
new_array = similar(parent(data), T, new_size...)
return union_all(singleton(data)){S, params...}(new_array)
end

Expand Down
6 changes: 3 additions & 3 deletions src/DataLayouts/non_extruded_broadcasted.jl
Original file line number Diff line number Diff line change
Expand Up @@ -140,11 +140,11 @@ Base.@propagate_inbounds function _broadcast_getindex(
end
@inline _broadcast_getindex_evalf(f::Tf, args::Vararg{Any, N}) where {Tf, N} =
f(args...) # not propagate_inbounds
Base.@propagate_inbounds function _getindex(args::Tuple, I)
unrolled_map(args) do arg
Base.@propagate_inbounds _getindex(args::Tuple, I) =
unrolled_map_with_inbounds(args) do arg
Base.@_propagate_inbounds_meta
_broadcast_getindex(arg, I)
end
end

@inline Base.axes(bc::NonExtrudedBroadcasted) = _axes(bc, bc.axes)
_axes(::NonExtrudedBroadcasted, axes::Tuple) = axes
Expand Down
10 changes: 5 additions & 5 deletions src/DataLayouts/struct_storage.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ can be specified through a `Val` that contains its index `D`.
num_D_indices = num_basetypes(eltype(array), fieldtype(S, F))
last_D_index = num_basetypes(eltype(array), Tuple{fieldtypes(S)[1:F]...})
D_indices = (last_D_index - num_D_indices + 1):last_D_index
all_indices = Base.setindex(axes(array), D_indices, D)
all_indices = unrolled_setindex(axes(array), D_indices, Val(D))
@boundscheck checkbounds(array, all_indices...)
return Base.unsafe_view(array, all_indices...)
end
Expand All @@ -119,8 +119,9 @@ end
index::CartesianIndex,
::Val{D} = Val(ndims(array)),
) where {num_indices, D}
start = CartesianIndex(Tuple(index)[1:(D - 1)]..., 1, Tuple(index)[D:end]...)
checkbounds(array, start:Base.setindex(start, num_indices, D))
start = CartesianIndex(unrolled_insert(Tuple(index), 1, Val(D)))
stop = CartesianIndex(unrolled_insert(Tuple(index), num_indices, Val(D)))
checkbounds(array, start:stop)
end

@inline struct_index(i, array) = i
Expand All @@ -135,8 +136,7 @@ end
array,
index::CartesianIndex,
::Val{D} = Val(ndims(array)),
) where {D} =
CartesianIndex(Tuple(index)[1:(D - 1)]..., i, Tuple(index)[D:end]...)
) where {D} = CartesianIndex(unrolled_insert(Tuple(index), i, Val(D)))

"""
set_struct!(array, value, [index], [Val(D)])
Expand Down
24 changes: 13 additions & 11 deletions src/Operators/spectralelement.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import UnrolledUtilities: unrolled_map
import ..Utilities.Unrolled: unrolled_map_with_inbounds

abstract type AbstractSpectralStyle <: Fields.AbstractFieldStyle end

Expand Down Expand Up @@ -233,25 +234,23 @@ Base.@propagate_inbounds function resolve_operator(
bc::SpectralBroadcasted{SlabBlockSpectralStyle},
slabidx,
)
args = _resolve_operator_args(slabidx, bc.args)
args = _resolve_operator(slabidx, bc.args)
apply_operator(bc.op, bc.axes, slabidx, args...)
end
Base.@propagate_inbounds function resolve_operator(
bc::Base.Broadcast.Broadcasted{SlabBlockSpectralStyle},
slabidx,
)
args = _resolve_operator_args(slabidx, bc.args)
args = _resolve_operator(slabidx, bc.args)
Base.Broadcast.Broadcasted{SlabBlockSpectralStyle}(bc.f, args, bc.axes)
end
@inline resolve_operator(x, slabidx) = x

"""
_resolve_operator_args(slabidx, args)

Calls `resolve_operator(arg, slabidx)` for each `arg` in `args`
"""
Base.@propagate_inbounds _resolve_operator_args(slabidx, args) =
unrolled_map(arg -> resolve_operator(arg, slabidx), args)
Base.@propagate_inbounds _resolve_operator(slabidx, args) =
unrolled_map_with_inbounds(args) do arg
Base.@_propagate_inbounds_meta
resolve_operator(arg, slabidx)
end

function strip_space(bc::SpectralBroadcasted{Style}, parent_space) where {Style}
current_space = axes(bc)
Expand Down Expand Up @@ -329,8 +328,11 @@ end
return slabidx.v + half <= Nv
end

Base.@propagate_inbounds _get_node(space, ij, slabidx, args::Tuple) =
unrolled_map(arg -> get_node(space, arg, ij, slabidx), args)
Base.@propagate_inbounds _get_node(space, ij, slabidx, args) =
unrolled_map_with_inbounds(args) do arg
Base.@_propagate_inbounds_meta
get_node(space, arg, ij, slabidx)
end

Base.@propagate_inbounds function get_node(space, scalar, ij, slabidx)
scalar[]
Expand Down
27 changes: 27 additions & 0 deletions src/Utilities/Utilities.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,33 @@ import UnrolledUtilities: unrolled_map
include("plushalf.jl")
include("cache.jl")

module Unrolled # TODO: Move all of these functions into UnrolledUtilities.jl

# Alternative to Base.setindex with guaranteed constant propagation
@inline unrolled_setindex(x::Tuple, value, ::Val{i}) where {i} =
ntuple(n -> n == i ? value : x[n], Val(length(x)))

# Analogue of insert! that follows the same pattern as unrolled_setindex
@inline unrolled_insert(x::Tuple, value, ::Val{i}) where {i} =
ntuple(n -> n == i ? value : x[n < i ? n : n - 1], Val(length(x) + 1))

# Same as UnrolledUtilities.unrolled_map, but annotated with @propagate_inbounds
@generated unrolled_map_with_inbounds(f, x::NTuple{N, Any}) where {N} = quote
Base.@_propagate_inbounds_meta
Comment thread
dennisYatunin marked this conversation as resolved.
return Base.Cartesian.@ntuple $N n -> f(x[n])
end

# Remove each function's recursion limit for better type inference on Julia 1.10
if hasfield(Method, :recursion_relation)
for f in (unrolled_setindex, unrolled_insert, unrolled_map_with_inbounds)
for m in methods(f)
m.recursion_relation = Returns(true)
end
end
end

end

"""
cart_ind(n::NTuple, i::Integer)

Expand Down
24 changes: 10 additions & 14 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# Toplevel interface functions for recurisve broadcast expressions
import UnrolledUtilities: unrolled_map
import ..Utilities.Unrolled: unrolled_map_with_inbounds

"""
slab(data::AbstractData, h::Integer)
Expand All @@ -13,13 +13,11 @@ function slab end
Base.@propagate_inbounds slab(x, inds...) = x
Base.@propagate_inbounds slab(tup::Tuple, inds...) = slab_args(tup, inds...)

# Recursively call slab() on broadcast arguments in a way that is statically reducible by the optimizer
# see Base.Broadcast.preprocess_args
Base.@propagate_inbounds function slab_args(args::Tuple, inds...)
unrolled_map(args) do arg
Base.@propagate_inbounds slab_args(args::Tuple, inds...) =
unrolled_map_with_inbounds(args) do arg
Base.@_propagate_inbounds_meta
Comment thread
dennisYatunin marked this conversation as resolved.
slab(arg, inds...)
end
end
Base.@propagate_inbounds slab_args(args::NamedTuple, inds...) =
NamedTuple{keys(args)}(slab_args(values(args), inds...))

Expand All @@ -35,21 +33,19 @@ function column end
Base.@propagate_inbounds column(x, inds...) = x
Base.@propagate_inbounds column(tup::Tuple, inds...) = column_args(tup, inds...)

# Recursively call column() on broadcast arguments in a way that is statically reducible by the optimizer
# see Base.Broadcast.preprocess_args
Base.@propagate_inbounds function column_args(args::Tuple, inds...)
unrolled_map(args) do arg
Base.@propagate_inbounds column_args(args::Tuple, inds...) =
unrolled_map_with_inbounds(args) do arg
Base.@_propagate_inbounds_meta
column(arg, inds...)
end
end
Base.@propagate_inbounds column_args(args::NamedTuple, inds...) =
NamedTuple{keys(args)}(column_args(values(args), inds...))

function level end

Base.@propagate_inbounds level(x, inds...) = x
Base.@propagate_inbounds function level_args(args::Tuple, inds...)
unrolled_map(args) do arg
Base.@propagate_inbounds level_args(args::Tuple, inds...) =
unrolled_map_with_inbounds(args) do arg
Base.@_propagate_inbounds_meta
level(arg, inds...)
end
end
Loading