Skip to content

Add sketch of new Tensor interface#2457

Open
dennisYatunin wants to merge 13 commits into
mainfrom
dy/tensors
Open

Add sketch of new Tensor interface#2457
dennisYatunin wants to merge 13 commits into
mainfrom
dy/tensors

Conversation

@dennisYatunin
Copy link
Copy Markdown
Member

This PR refactors the Geometry module to use a much simpler interface, so that all vector/tensor operations can expressed using standard math operations instead of custom API functions. This will allow us to remove a large amount of duplicate code, reduce compilation latency, and speed up GPU runs by optimizing the geometry data passed to each kernel.

  • Code follows the style guidelines OR N/A.
  • Unit tests are included OR N/A.
  • Code is exercised in an integration test OR N/A.
  • Documentation has been added/updated OR N/A.

@dennisYatunin dennisYatunin force-pushed the dy/tensors branch 23 times, most recently from d1c85a9 to dc71352 Compare February 21, 2026 03:40
@dennisYatunin dennisYatunin force-pushed the dy/tensors branch 6 times, most recently from 271e3cd to 1f8a383 Compare February 25, 2026 06:57
@nefrathenrici nefrathenrici marked this pull request as ready for review April 30, 2026 00:32
@nefrathenrici nefrathenrici force-pushed the dy/tensors branch 2 times, most recently from b6c3879 to b9f7f37 Compare April 30, 2026 19:40
@nefrathenrici nefrathenrici force-pushed the dy/tensors branch 7 times, most recently from 9c066c5 to e354450 Compare May 1, 2026 23:58
Copy link
Copy Markdown
Member Author

@dennisYatunin dennisYatunin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this looks mergeable, save for some minor nitpicks. If you want to check this with ClimaAtmos before merging, feel free to define some aliases like const AxisTensor = Tensor, which I used to ensure backward compatibility when refactoring RecursiveApply.

Comment thread docs/src/matrix_fields.md

- A `UniformScaling`, which contains a `Number`
- A `DiagonalMatrixRow`, which can contain either a `Number` or a tensor (represented as a `Geometry.Axis2Tensor`)
- A `DiagonalMatrixRow`, which can contain either a `Number` or a rank-2 tensor (represented as a `Geometry.Tensor{2}`)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this say AbstractTensor{2} instead of Tensor{2}? Same for the instance of Tensor{2} below.

x::T1,
y::T2,
) where {T1 <: Geometry.AdjointAxisVector, T2 <: Geometry.Axis2Tensor} = (x * y)'
) where {T1 <: Geometry.AbstractCovector, T2 <: Geometry.Tensor{2}} = (x * y)'
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
) where {T1 <: Geometry.AbstractCovector, T2 <: Geometry.Tensor{2}} = (x * y)'
) where {T1 <: Geometry.AbstractCovector, T2 <: Geometry.AbstractTensor{2}} = (x * y)'

Comment on lines +4 to +5
# every conversion is a single matvec — names outside `lg`'s geometry `I`
# ride the identity block of the padded matrix automatically. Same-type
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a bit unclear. The LocalGeometry tensors now cover all available dimension names, and there's no additional padding here.

Comment on lines +56 to +57
# 1D vector types can be constructed from a scalar + LocalGeometry.
# The LocalGeometry is ignored — the scalar is wrapped directly.
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need to be supported? It looks redundant.

lg.∂ξ∂x[2, 1]*v[1, 1]+lg.∂ξ∂x[2, 2]*v[2, 1] lg.∂ξ∂x[2, 1]*v[1, 2]+lg.∂ξ∂x[2, 2]*v[2, 2]
]
)
for (BT, VecType, fn) in (
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like BT should be a symbol here, and fn can be replaced with VecType.

Comment on lines +34 to +35
basis1(::Type{<:Tensor{2, <:Any, <:Tuple{B, Any}}}) where {B} = B
basis2(::Type{<:Tensor{2, <:Any, <:Tuple{Any, B}}}) where {B} = B
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
basis1(::Type{<:Tensor{2, <:Any, <:Tuple{B, Any}}}) where {B} = B
basis2(::Type{<:Tensor{2, <:Any, <:Tuple{Any, B}}}) where {B} = B
basis1(::Type{T}) where {T <: AbstractTensor{2}} = Geometry.tensor_bases(T)[1]
basis1(::Type{T}) where {T <: AbstractTensor{2}} = Geometry.tensor_bases(T)[2]

Some of the other Tensor types below could also be replaced with AbstractTensors.

Comment thread src/Geometry/tensors.jl
Comment on lines +658 to +660
@inline outer(x::AbstractVector, y::AbstractVector) = x * y'
@inline outer(x::AbstractVector, y::Number) = x * y
@inline outer(x::AbstractVector, y) = nested_broadcast(y -> x ⊗ y, y)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@inline outer(x::AbstractVector, y::AbstractVector) = x * y'
@inline outer(x::AbstractVector, y::Number) = x * y
@inline outer(x::AbstractVector, y) = nested_broadcast(y -> x y, y)
@inline outer(x, y) = x * y'

Both * and ' already call nested_broadcast, and you can call ' on a Number (that's just a no-op).

Comment on lines +235 to +236
- nested type (Tuple or NamedTuple), scalar type (Number, SMatrix, or
Tensor{2}/adjoint thereof), nested type (Tuple or NamedTuple)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also looks like a relic from before the RecursiveApply refactor.

Comment on lines +48 to +52
@inline dss_transform(
arg::Geometry.Covariant3Vector,
local_geometry::Geometry.LocalGeometry,
weight,
) = arg * weight
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This case is already covered by the method below.

(Tensor{2,FT,Tuple{Covariant3Axis,WAxis},SMatrix{1,1,FT,1}},LG_123_XYZ{FT}, 103),
]
end

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like testing the padded FLOP counts here doesn't serve much of a purpose. You can leave it in if you want, but we only need tests like this for sparse tensors, which won't be introduced until a future PR.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Geometry module refactor

2 participants