Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 61 additions & 19 deletions src/CLUBB_core/advance_xp2_xpyp_module.F90
Original file line number Diff line number Diff line change
Expand Up @@ -520,9 +520,11 @@ subroutine advance_xp2_xpyp( nzm, nzt, ngrdcol, sclr_dim, sclr_tol, gr, sclr_idx

type(torch_tensor), dimension(1) :: c14_tensor_in
type(torch_tensor), dimension(1) :: c14_tensor_out
real( kind = core_rknd ), dimension(5), target :: &
integer, parameter :: c14_input_size = 5
integer, parameter :: c14_output_size = 1
real( kind = core_rknd ), dimension(:,:), allocatable, save, target :: &
c14_ml_input
real( kind = core_rknd ), dimension(1), target :: &
real( kind = core_rknd ), dimension(:,:), allocatable, save, target :: &
c14_ml_output

!------------------------------ Begin Code ------------------------------
Expand Down Expand Up @@ -630,32 +632,72 @@ subroutine advance_xp2_xpyp( nzm, nzt, ngrdcol, sclr_dim, sclr_tol, gr, sclr_idx
! Once this is working we can move towards batching a column at a time or multiple
! columns at once.
if ( l_c14_ml ) then

! Allocate and check the input buffer
if ( .not. allocated(c14_ml_input) ) then
allocate( c14_ml_input(ngrdcol*nzm, c14_input_size) )
else if ( size(c14_ml_input, 1) /= ngrdcol * nzm ) then
! The size of the problem shoudn't change, but in case it does...
write(fstderr, *) err_info%err_header_global
write(fstderr,*) "The c14_ml_input array is not the correct shape."
write(fstderr,*) "This may indicate the number of columns or vertical levels has changed."
write(fstderr,*) "This should not happen..."
write(fstderr,*) "Original number of points: ", size(c14_ml_input, 1)
write(fstderr,*) "Current number of points: ", ngrdcol * nzm
err_info%err_code = clubb_fatal_error
return
end if

! Allocate and check the output buffer
if ( .not. allocated(c14_ml_output) ) then
allocate( c14_ml_output(ngrdcol*nzm, c14_output_size) )
else if ( size(c14_ml_output) /= ngrdcol * nzm ) then
! The size of the problem shoudn't change, but in case it does...
write(fstderr, *) err_info%err_header_global
write(fstderr,*) "The c14_ml_output array is not the correct shape."
write(fstderr,*) "This may indicate the number of columns or vertical levels has changed."
write(fstderr,*) "This should not happen..."
write(fstderr,*) "Original number of points: ", size(c14_ml_output)
write(fstderr,*) "Current number of points: ", ngrdcol * nzm
err_info%err_code = clubb_fatal_error
return
end if


call timer_start(C14_timer_total)
! Interpolate Lscales from thermal to momentum grid
Lscale_up_zm(:,:) = zt2zm_api( nzm, nzt, ngrdcol, gr, Lscale_up(:,:), zero_threshold )
Lscale_down_zm(:,:) = zt2zm_api( nzm, nzt, ngrdcol, gr, Lscale_down(:,:), zero_threshold )

do k = 1, nzm
do i = 1, ngrdcol
c14_ml_input(1) = up2(i,k) / em(i,k)
c14_ml_input(2) = vp2(i,k) / em(i,k)
c14_ml_input(3) = wp2(i,k) / em(i,k)
c14_ml_input(4) = Lscale_up_zm(i,k) / 1000.0_core_rknd ! Normalised by 1km per training
c14_ml_input(5) = Lscale_down_zm(i,k) / 1000.0_core_rknd ! Normalised by 1km per training

call torch_tensor_from_array(c14_tensor_in(1), c14_ml_input, torch_kCPU)
call torch_tensor_from_array(c14_tensor_out(1), c14_ml_output, torch_kCPU)
call torch_model_forward(C14_neural_net, c14_tensor_in, c14_tensor_out)

! This does not "delete" the Fortran `torch_tensor`s, but rather cleans up
! any pointers in C++ and Fortran now we are done with them before creating new
! ones with subsequent calls to `torch_tensor_from_array`
call torch_delete(c14_tensor_in)
call torch_delete(c14_tensor_out)

C14_1d(i,k) = one_third * c14_ml_output(1)
c14_ml_input((k-1) * ngrdcol + i, 1) = up2(i,k) / em(i,k)
c14_ml_input((k-1) * ngrdcol + i, 2) = vp2(i,k) / em(i,k)
c14_ml_input((k-1) * ngrdcol + i, 3) = wp2(i,k) / em(i,k)
c14_ml_input((k-1) * ngrdcol + i, 4) = Lscale_up_zm(i,k) / 1000.0_core_rknd ! Normalised by 1km per training
c14_ml_input((k-1) * ngrdcol + i, 5) = Lscale_down_zm(i,k) / 1000.0_core_rknd ! Normalised by 1km per training
end do
end do

! I establish a tensor with all points on the leftmost index
call torch_tensor_from_array(c14_tensor_in(1), c14_ml_input, torch_kCPU)

! I need to establish a corresponging output buffer
call torch_tensor_from_array(c14_tensor_out(1), c14_ml_output, torch_kCPU)

! Run inference
call torch_model_forward(C14_neural_net, c14_tensor_in, c14_tensor_out)

call torch_delete(c14_tensor_in)
call torch_delete(c14_tensor_out)

! Copy the output back to C14_1d
do k = 1, nzm
do i = 1, ngrdcol
C14_1d(i,k) = one_third * c14_ml_output((k-1) * ngrdcol + i, 1)
end do
end do

call timer_stop(C14_timer_total)
endif ! l_c14_ml

Expand Down