diff --git a/src/CLUBB_core/advance_xp2_xpyp_module.F90 b/src/CLUBB_core/advance_xp2_xpyp_module.F90 index b3068d7661..476735f120 100644 --- a/src/CLUBB_core/advance_xp2_xpyp_module.F90 +++ b/src/CLUBB_core/advance_xp2_xpyp_module.F90 @@ -520,9 +520,11 @@ subroutine advance_xp2_xpyp( nzm, nzt, ngrdcol, sclr_dim, sclr_tol, gr, sclr_idx type(torch_tensor), dimension(1) :: c14_tensor_in type(torch_tensor), dimension(1) :: c14_tensor_out - real( kind = core_rknd ), dimension(5), target :: & + integer, parameter :: c14_input_size = 5 + integer, parameter :: c14_output_size = 1 + real( kind = core_rknd ), dimension(:,:), allocatable, save, target :: & c14_ml_input - real( kind = core_rknd ), dimension(1), target :: & + real( kind = core_rknd ), dimension(:,:), allocatable, save, target :: & c14_ml_output !------------------------------ Begin Code ------------------------------ @@ -630,6 +632,38 @@ subroutine advance_xp2_xpyp( nzm, nzt, ngrdcol, sclr_dim, sclr_tol, gr, sclr_idx ! Once this is working we can move towards batching a column at a time or multiple ! columns at once. if ( l_c14_ml ) then + + ! Allocate and check the input buffer + if ( .not. allocated(c14_ml_input) ) then + allocate( c14_ml_input(ngrdcol*nzm, c14_input_size) ) + else if ( size(c14_ml_input, 1) /= ngrdcol * nzm ) then + ! The size of the problem shoudn't change, but in case it does... + write(fstderr, *) err_info%err_header_global + write(fstderr,*) "The c14_ml_input array is not the correct shape." + write(fstderr,*) "This may indicate the number of columns or vertical levels has changed." + write(fstderr,*) "This should not happen..." + write(fstderr,*) "Original number of points: ", size(c14_ml_input, 1) + write(fstderr,*) "Current number of points: ", ngrdcol * nzm + err_info%err_code = clubb_fatal_error + return + end if + + ! Allocate and check the output buffer + if ( .not. allocated(c14_ml_output) ) then + allocate( c14_ml_output(ngrdcol*nzm, c14_output_size) ) + else if ( size(c14_ml_output) /= ngrdcol * nzm ) then + ! The size of the problem shoudn't change, but in case it does... + write(fstderr, *) err_info%err_header_global + write(fstderr,*) "The c14_ml_output array is not the correct shape." + write(fstderr,*) "This may indicate the number of columns or vertical levels has changed." + write(fstderr,*) "This should not happen..." + write(fstderr,*) "Original number of points: ", size(c14_ml_output) + write(fstderr,*) "Current number of points: ", ngrdcol * nzm + err_info%err_code = clubb_fatal_error + return + end if + + call timer_start(C14_timer_total) ! Interpolate Lscales from thermal to momentum grid Lscale_up_zm(:,:) = zt2zm_api( nzm, nzt, ngrdcol, gr, Lscale_up(:,:), zero_threshold ) @@ -637,25 +671,33 @@ subroutine advance_xp2_xpyp( nzm, nzt, ngrdcol, sclr_dim, sclr_tol, gr, sclr_idx do k = 1, nzm do i = 1, ngrdcol - c14_ml_input(1) = up2(i,k) / em(i,k) - c14_ml_input(2) = vp2(i,k) / em(i,k) - c14_ml_input(3) = wp2(i,k) / em(i,k) - c14_ml_input(4) = Lscale_up_zm(i,k) / 1000.0_core_rknd ! Normalised by 1km per training - c14_ml_input(5) = Lscale_down_zm(i,k) / 1000.0_core_rknd ! Normalised by 1km per training - - call torch_tensor_from_array(c14_tensor_in(1), c14_ml_input, torch_kCPU) - call torch_tensor_from_array(c14_tensor_out(1), c14_ml_output, torch_kCPU) - call torch_model_forward(C14_neural_net, c14_tensor_in, c14_tensor_out) - - ! This does not "delete" the Fortran `torch_tensor`s, but rather cleans up - ! any pointers in C++ and Fortran now we are done with them before creating new - ! ones with subsequent calls to `torch_tensor_from_array` - call torch_delete(c14_tensor_in) - call torch_delete(c14_tensor_out) - - C14_1d(i,k) = one_third * c14_ml_output(1) + c14_ml_input((k-1) * ngrdcol + i, 1) = up2(i,k) / em(i,k) + c14_ml_input((k-1) * ngrdcol + i, 2) = vp2(i,k) / em(i,k) + c14_ml_input((k-1) * ngrdcol + i, 3) = wp2(i,k) / em(i,k) + c14_ml_input((k-1) * ngrdcol + i, 4) = Lscale_up_zm(i,k) / 1000.0_core_rknd ! Normalised by 1km per training + c14_ml_input((k-1) * ngrdcol + i, 5) = Lscale_down_zm(i,k) / 1000.0_core_rknd ! Normalised by 1km per training end do end do + + ! I establish a tensor with all points on the leftmost index + call torch_tensor_from_array(c14_tensor_in(1), c14_ml_input, torch_kCPU) + + ! I need to establish a corresponging output buffer + call torch_tensor_from_array(c14_tensor_out(1), c14_ml_output, torch_kCPU) + + ! Run inference + call torch_model_forward(C14_neural_net, c14_tensor_in, c14_tensor_out) + + call torch_delete(c14_tensor_in) + call torch_delete(c14_tensor_out) + + ! Copy the output back to C14_1d + do k = 1, nzm + do i = 1, ngrdcol + C14_1d(i,k) = one_third * c14_ml_output((k-1) * ngrdcol + i, 1) + end do + end do + call timer_stop(C14_timer_total) endif ! l_c14_ml