penalisation_braginskii_m.f90 Source File


Contents


Source Code

module penalisation_braginskii_m
    ! Module for calculating 'penalisation values' -- i.e. values which enforce a desired boundary condition -- for the 
    ! braginskii model. 
    use NETCDF
    use comm_handler_m, only : comm_handler_t
    use perf_m, only : perf_start, perf_stop
    use error_handling_grillix_m, only: handle_error, error_info_t, handle_error_netcdf
    use status_codes_grillix_m, only : GRILLIX_ERR_OTHER
    use descriptors_braginskii_m, only : BND_TYPE_DIRICHLET_ZERO, &
                                         BND_TYPE_NEUMANN, &
                                         BND_TYPE_NONE, &
                                         BND_BRAGTYPE_FLOATING_POTENTIAL_LOCAL, &
                                         BND_BRAGTYPE_PARALLEL_HEAT_TRANSMISSION, &
                                         BND_BRAGTYPE_SONIC_LOCAL, &
                                         BND_BRAGTYPE_BOHM_LOCAL, &
                                         BND_BRAGTYPE_SONIC_DRIFT_LOCAL, &
                                         BND_BRAGTYPE_BOHM_DRIFT_LOCAL, &
                                         BND_BRAGTYPE_JPAR_SMOOTHPENZERO
    use descriptors_m, only : DISTRICT_CORE, DISTRICT_WALL, DISTRICT_DOME, DISTRICT_OUT
    use equilibrium_m, only : equilibrium_t
    use mesh_cart_m, only : mesh_cart_t
    use helmholtz_solver_m, only : helmholtz_solver_t
    use parallel_map_m, only : parallel_map_t
    use penalisation_m, only : penalisation_t
    use variable_m, only: variable_t
    use precision_grillix_m, only : GP, GP_NAN, GP_EPS
    use screen_io_m, only :  get_stdout
    use inplane_operators_m, only : inplane_operators_t
    use helmholtz_solver_m, only : helmholtz_solver_t
    use helmholtz_solver_factory_m, only : parameters_helmholtz_solver_factory, helmholtz_solver_factory   
    use boundaries_perp_m, only : set_perpbnds, extrapolate_ghost_points
    use elementary_functions_m, only : step_hermite
    use equilibrium_storage_m, only : equilibrium_storage_t
    use helmholtz_netcdfio_m, only : write_netcdf_helmholtz
    ! Parameters
    use params_brag_model_m, only : &
        rhos, beta, tratio, mass_ratio_ei
    use params_brag_pardiss_model_m, only : &       
        chipar0_e, chipar0_i
    use params_brag_boundaries_perp_m, only : &    
        bnddescr_apar_core, bnddescr_apar_wall, bnddescr_apar_dome, bnddescr_apar_out
    use params_brag_boundaries_parpen_m, only : &         
        bnddescr_ne_pen, bndval_ne_pen, &
        bnddescr_te_pen, bndval_te_pen, sheath_heattransfac_e, &
        bnddescr_ti_pen, bndval_ti_pen, sheath_heattransfac_i, &        
        bnddescr_vort_pen, bndval_vort_pen, &
        bnddescr_upar_pen, bohm_frac_ti, &
        bnddescr_pot_pen, sheath_potential_lambda_sh, bnddescr_ohm_pen
    implicit none  
    private

    public :: logne_penalisation
    public :: logte_penalisation
    public :: logti_penalisation
    public :: upar_penalisation
    public :: vort_penalisation
    public :: pot_penalisation
    public :: psipar_penalisation

    private :: apply_dirichlet
    private :: apply_neumann
    private :: apply_heat_transmission_bc
    private :: add_bohm_drift_correction
    private :: compute_local_sound_speed
    
    contains

    subroutine logne_penalisation(equi, mesh_cano, map, penalisation_cano, &
                                  logne, logne_pen_vals)
        !! Computes the penalisation values for the logarithmic density.
        !!
        !! The boundary type is taken from bnddescr_ne_pen:
        !!   - BND_TYPE_DIRICHLET_ZERO: constant value log(bndval_ne_pen); bndval_ne_pen is given
        !!     in non-logarithmic units, hence the logarithm.
        !!   - BND_TYPE_NEUMANN: upstream value of logne along the parallel map (canonical mesh).
        !! Any other type is reported through handle_error.
        class(equilibrium_t), intent(in) :: equi
        !! Equilibrium (not referenced in this routine)
        type(mesh_cart_t), intent(in) :: mesh_cano
        !! Mesh (canonical)
        type(parallel_map_t), intent(in) :: map
        !! Parallel map
        type(penalisation_t), intent(in) :: penalisation_cano
        !! Penalisation (canonical)
        type(variable_t), intent(in) :: logne
        !! Logarithmic density (halos must be filled for the Neumann case)
        real(GP), intent(out), dimension(mesh_cano%get_n_points_inner()) :: logne_pen_vals
        !! Penalisation values for the logarithmic density, one per inner mesh point

        bc_type: select case (bnddescr_ne_pen)

        case (BND_TYPE_DIRICHLET_ZERO) bc_type
            call apply_dirichlet(mesh=mesh_cano, penalisation=penalisation_cano, &
                                 pen_bndval=log(bndval_ne_pen), &
                                 penalised_values=logne_pen_vals)

        case (BND_TYPE_NEUMANN) bc_type
            call apply_neumann(staggered=.false., mesh=mesh_cano, map=map, &
                               penalisation=penalisation_cano, values=logne, &
                               penalised_values=logne_pen_vals)

        case default bc_type
            call handle_error('Pen_bndtpy not valid', &
                              GRILLIX_ERR_OTHER, __LINE__, __FILE__, &
                              error_info_t('Boundary_condition_type: ', &
                                           [bnddescr_ne_pen]))

        end select bc_type

    end subroutine logne_penalisation

    subroutine logte_penalisation(equi, mesh_cano, map, penalisation_cano, &
                                  ne, te, logte, full_upar, logte_pen_vals)
        !! Computes the penalisation values for the logarithmic electron temperature.
        !!
        !! The boundary type is taken from bnddescr_te_pen:
        !!   - BND_TYPE_DIRICHLET_ZERO: constant value log(bndval_te_pen); bndval_te_pen is given
        !!     in non-logarithmic units, hence the logarithm.
        !!   - BND_TYPE_NEUMANN: upstream value of logte along the parallel map (canonical mesh).
        !!   - BND_BRAGTYPE_PARALLEL_HEAT_TRANSMISSION: sheath heat-transmission condition built from
        !!     chipar0_e and sheath_heattransfac_e (see apply_heat_transmission_bc).
        !! Any other type is reported through handle_error.
        class(equilibrium_t), intent(in) :: equi
        !! Equilibrium (not referenced in this routine)
        type(mesh_cart_t), intent(in) :: mesh_cano
        !! Mesh (canonical)
        type(parallel_map_t), intent(in) :: map
        !! Parallel map
        type(penalisation_t), intent(in) :: penalisation_cano
        !! Penalisation (canonical)
        type(variable_t), intent(in) :: ne, te, logte, full_upar
        !! Input fields, left to right:
        !!     non-logarithmic density, non-logarithmic electron temperature,
        !!     logarithmic electron temperature,
        !!     parallel ion velocity mapped onto the canonical grid
        real(GP), intent(out), dimension(mesh_cano%get_n_points_inner()) :: logte_pen_vals
        !! Penalisation values for the logarithmic electron temperature, one per inner mesh point

        bc_type: select case (bnddescr_te_pen)

        case (BND_TYPE_DIRICHLET_ZERO) bc_type
            call apply_dirichlet(mesh=mesh_cano, penalisation=penalisation_cano, &
                                 pen_bndval=log(bndval_te_pen), &
                                 penalised_values=logte_pen_vals)

        case (BND_TYPE_NEUMANN) bc_type
            call apply_neumann(staggered=.false., mesh=mesh_cano, map=map, &
                               penalisation=penalisation_cano, values=logte, &
                               penalised_values=logte_pen_vals)

        case (BND_BRAGTYPE_PARALLEL_HEAT_TRANSMISSION) bc_type
            call apply_heat_transmission_bc(mesh_cano=mesh_cano, map=map, &
                                            penalisation_cano=penalisation_cano, &
                                            temperature=te, log_temperature=logte, &
                                            ne=ne, full_upar=full_upar, &
                                            chi_par_0=chipar0_e, &
                                            sheath_transmission_factor=sheath_heattransfac_e, &
                                            penalised_values=logte_pen_vals)

        case default bc_type
            call handle_error('Pen_bndtpy not valid', &
                              GRILLIX_ERR_OTHER, __LINE__, __FILE__, &
                              error_info_t('Boundary_condition_type: ', &
                                           [bnddescr_te_pen]))

        end select bc_type

    end subroutine logte_penalisation

    subroutine logti_penalisation(equi, mesh_cano, map, penalisation_cano, &
                                  ne, ti, logti, full_upar, logti_pen_vals)
        !! Computes the penalisation values for the logarithmic ion temperature.
        !!
        !! The boundary type is taken from bnddescr_ti_pen:
        !!   - BND_TYPE_DIRICHLET_ZERO: constant value log(bndval_ti_pen); bndval_ti_pen is given
        !!     in non-logarithmic units, hence the logarithm.
        !!   - BND_TYPE_NEUMANN: upstream value of logti along the parallel map (canonical mesh).
        !!   - BND_BRAGTYPE_PARALLEL_HEAT_TRANSMISSION: sheath heat-transmission condition built from
        !!     chipar0_i and sheath_heattransfac_i (see apply_heat_transmission_bc).
        !! Any other type is reported through handle_error.
        class(equilibrium_t), intent(in) :: equi
        !! Equilibrium (not referenced in this routine)
        type(mesh_cart_t), intent(in) :: mesh_cano
        !! Mesh (canonical)
        type(parallel_map_t), intent(in) :: map
        !! Parallel map
        type(penalisation_t), intent(in) :: penalisation_cano
        !! Penalisation (canonical)
        type(variable_t), intent(in) :: ne, ti, logti, full_upar
        !! Input fields, left to right:
        !!     non-logarithmic density, non-logarithmic ion temperature,
        !!     logarithmic ion temperature,
        !!     parallel ion velocity mapped onto the canonical grid
        real(GP), intent(out), dimension(mesh_cano%get_n_points_inner()) :: logti_pen_vals
        !! Penalisation values for the logarithmic ion temperature, one per inner mesh point

        bc_type: select case (bnddescr_ti_pen)

        case (BND_TYPE_DIRICHLET_ZERO) bc_type
            call apply_dirichlet(mesh=mesh_cano, penalisation=penalisation_cano, &
                                 pen_bndval=log(bndval_ti_pen), &
                                 penalised_values=logti_pen_vals)

        case (BND_TYPE_NEUMANN) bc_type
            call apply_neumann(staggered=.false., mesh=mesh_cano, map=map, &
                               penalisation=penalisation_cano, values=logti, &
                               penalised_values=logti_pen_vals)

        case (BND_BRAGTYPE_PARALLEL_HEAT_TRANSMISSION) bc_type
            call apply_heat_transmission_bc(mesh_cano=mesh_cano, map=map, &
                                            penalisation_cano=penalisation_cano, &
                                            temperature=ti, log_temperature=logti, &
                                            ne=ne, full_upar=full_upar, &
                                            chi_par_0=chipar0_i, &
                                            sheath_transmission_factor=sheath_heattransfac_i, &
                                            penalised_values=logti_pen_vals)

        case default bc_type
            call handle_error('Pen_bndtpy not valid', &
                              GRILLIX_ERR_OTHER, __LINE__, __FILE__, &
                              error_info_t('Boundary_condition_type: ', &
                                           [bnddescr_ti_pen]))

        end select bc_type

    end subroutine logti_penalisation

    subroutine upar_penalisation(equi, equi_on_stag, mesh_stag, map, penalisation_stag, opsinplane_stag, &
                                 upar, stag_te, stag_ti, stag_pot, upar_pen_vals)
        !! Sets the parallel velocity penalisation values.
        !!
        !! Depending on bnddescr_upar_pen the values are built in up to three stages:
        !!   1. base values: Dirichlet zero, Neumann (upstream value), or the local sound speed
        !!      multiplied by the direction indicator function,
        !!   2. optionally the ExB drift correction (add_bohm_drift_correction) is added,
        !!   3. optionally the values are limited against the upstream (Neumann) parallel velocity:
        !!      max(...) where the direction indicator is >= 0, min(...) where it is < 0.
        !! Stages 1 (sound speed), 2 and 3 are enabled by the SONIC/BOHM(_DRIFT)_LOCAL types.
        class(equilibrium_t), intent(in) :: equi
        !! Equilibrium
        type(equilibrium_storage_t), intent(in) :: equi_on_stag
        !! Equilibrium quantites on staggered mesh
        type(mesh_cart_t), intent(in) :: mesh_stag
        !! Mesh (staggered)
        type(parallel_map_t), intent(in) :: map
        !! Parallel map
        type(penalisation_t), intent(in) :: penalisation_stag
        !! Penalisation (staggered)
        type(inplane_operators_t), intent(in) :: opsinplane_stag
        !! In-plane operators (staggered)
        type(variable_t), intent(in) :: upar
        !! Parallel velocity
        type(variable_t), intent(in) :: stag_te
        !! Electron temperature on staggered grid
        type(variable_t), intent(in) :: stag_ti
        !! Ion temperature on staggered grid
        type(variable_t), intent(in) :: stag_pot
        !! Electrostatic potential on staggered grid
        real(GP), intent(out), dimension(mesh_stag%get_n_points_inner()) :: upar_pen_vals
        !! Penalisation values for the parallel velocity

        logical :: set_sound_speed
        !! Base values are the local sound speed times the direction indicator
        logical :: apply_drift_correction
        !! Add the ExB drift correction on top of the base values
        logical :: set_largerless
        !! Limit the result against the upstream parallel velocity

        real(GP) :: pen_dirind
        integer :: ki, l

        real(GP), dimension(mesh_stag%get_n_points_inner()) :: upstream_upar

        ! Define all stage switches up front so that they have a value on every path through the
        ! select construct below, including the error branch.
        set_sound_speed         = .false.
        apply_drift_correction  = .false.
        set_largerless          = .false.

        select case(bnddescr_upar_pen)

            case (BND_TYPE_DIRICHLET_ZERO)
                ! Artificial case, penalising parallel velocity to zero
                call apply_dirichlet(mesh_stag, penalisation_stag, &
                                     0.0_GP, upar_pen_vals)

            case (BND_TYPE_NEUMANN)

                call apply_neumann(.true., mesh_stag, map, penalisation_stag, upar, upar_pen_vals)

            case (BND_BRAGTYPE_SONIC_LOCAL)

                set_sound_speed         = .true.

            case (BND_BRAGTYPE_BOHM_LOCAL)

                set_sound_speed         = .true.
                set_largerless          = .true.

            case (BND_BRAGTYPE_SONIC_DRIFT_LOCAL)

                set_sound_speed         = .true.
                apply_drift_correction  = .true.

            case (BND_BRAGTYPE_BOHM_DRIFT_LOCAL)

                set_sound_speed         = .true.
                apply_drift_correction  = .true.
                set_largerless          = .true.

            case default
                call handle_error('Pen_bndtpy not valid', &
                                  GRILLIX_ERR_OTHER, __LINE__, __FILE__, &
                                  error_info_t('Boundary_condition_type: ', &
                                               [bnddescr_upar_pen]))
        end select

        if (set_sound_speed) then

            ! Signed local sound speed: the direction indicator sets the flow direction
            !$omp parallel default(none) private(ki, l, pen_dirind) &
            !$omp           shared(mesh_stag, penalisation_stag, stag_te, stag_ti, upar_pen_vals)
            !$omp do
            do ki = 1, mesh_stag%get_n_points_inner()
                l = mesh_stag%inner_indices(ki)
                pen_dirind = penalisation_stag%get_dirindfun_val(ki)

                upar_pen_vals(ki) = pen_dirind * compute_local_sound_speed(stag_te, stag_ti, l)

            end do
            !$omp end do
            !$omp end parallel

        endif

        if (apply_drift_correction) then
            call add_bohm_drift_correction(equi, equi_on_stag, mesh_stag, penalisation_stag, opsinplane_stag, &
                                           stag_pot, upar_pen_vals)
        endif

        if (set_largerless) then

            call apply_neumann(.true., mesh_stag, map, penalisation_stag, upar, upstream_upar)
            ! NOTE(review): upstream_upar returned by apply_neumann already carries a pen_dirind
            ! factor; it is scaled by abs(pen_dirind) once more below -- confirm this is intended.
            !$omp parallel default(none) private(ki, pen_dirind) &
            !$omp          shared(mesh_stag, penalisation_stag, upstream_upar, upar_pen_vals)
            !$omp do
            do ki = 1, mesh_stag%get_n_points_inner()
                pen_dirind = penalisation_stag%get_dirindfun_val(ki)
                if (pen_dirind >= 0.0_GP) then
                    upar_pen_vals(ki) = max(upar_pen_vals(ki), upstream_upar(ki)*abs(pen_dirind))
                else
                    upar_pen_vals(ki) = min(upar_pen_vals(ki), upstream_upar(ki)*abs(pen_dirind))
                endif
            enddo
            !$omp end do
            !$omp end parallel

        endif

    end subroutine upar_penalisation
    
    subroutine vort_penalisation(equi, mesh_cano, map, penalisation_cano, vort, vort_pen_vals)
        !! Computes the penalisation values for the vorticity.
        !!
        !! The boundary type is taken from bnddescr_vort_pen:
        !!   - BND_TYPE_DIRICHLET_ZERO: constant value bndval_vort_pen (used as is, no logarithm).
        !!   - BND_TYPE_NEUMANN: upstream value of vort along the parallel map (canonical mesh).
        !! Any other type is reported through handle_error.
        class(equilibrium_t), intent(in) :: equi
        !! Equilibrium (not referenced in this routine)
        type(mesh_cart_t), intent(in) :: mesh_cano
        !! Mesh (canonical)
        type(parallel_map_t), intent(in) :: map
        !! Parallel map
        type(penalisation_t), intent(in) :: penalisation_cano
        !! Penalisation (canonical)
        type(variable_t), intent(in) :: vort
        !! Vorticity (halos must be filled for the Neumann case)
        real(GP), intent(out), dimension(mesh_cano%get_n_points_inner()) :: vort_pen_vals
        !! Penalisation values for the vorticity, one per inner mesh point

        bc_type: select case (bnddescr_vort_pen)

        case (BND_TYPE_DIRICHLET_ZERO) bc_type
            call apply_dirichlet(mesh=mesh_cano, penalisation=penalisation_cano, &
                                 pen_bndval=bndval_vort_pen, &
                                 penalised_values=vort_pen_vals)

        case (BND_TYPE_NEUMANN) bc_type
            call apply_neumann(staggered=.false., mesh=mesh_cano, map=map, &
                               penalisation=penalisation_cano, values=vort, &
                               penalised_values=vort_pen_vals)

        case default bc_type
            call handle_error('Pen_bndtpy not valid', &
                              GRILLIX_ERR_OTHER, __LINE__, __FILE__, &
                              error_info_t('Boundary_condition_type: ', &
                                           [bnddescr_vort_pen]))

        end select bc_type

    end subroutine vort_penalisation

    subroutine pot_penalisation(equi, mesh_cano, map, penalisation_cano, &
                                te, pot_pen_vals)
        !! Computes the penalisation values for the electrostatic potential.
        !!
        !! The boundary type is taken from bnddescr_pot_pen:
        !!   - BND_TYPE_DIRICHLET_ZERO or BND_TYPE_NONE: zero at every inner point.
        !!   - BND_BRAGTYPE_FLOATING_POTENTIAL_LOCAL: sheath_potential_lambda_sh times the local
        !!     electron temperature. This option should always be combined with a zero-current
        !!     boundary condition.
        !! Any other type is reported through handle_error.
        class(equilibrium_t), intent(in) :: equi
        !! Equilibrium (not referenced in this routine)
        type(mesh_cart_t), intent(in) :: mesh_cano
        !! Mesh (canonical)
        type(parallel_map_t), intent(in) :: map
        !! Parallel map (not referenced in this routine)
        type(penalisation_t), intent(in) :: penalisation_cano
        !! Penalisation (canonical)
        type(variable_t), intent(in) :: te
        !! Non-logarithmic electron temperature
        real(GP), intent(out), dimension(mesh_cano%get_n_points_inner()) :: pot_pen_vals
        !! Penalisation values for the potential, one per inner mesh point

        integer :: i_in, idx

        bc_type: select case (bnddescr_pot_pen)

        case (BND_TYPE_DIRICHLET_ZERO, BND_TYPE_NONE) bc_type
            call apply_dirichlet(mesh=mesh_cano, penalisation=penalisation_cano, &
                                 pen_bndval=0.0_GP, penalised_values=pot_pen_vals)

        case (BND_BRAGTYPE_FLOATING_POTENTIAL_LOCAL) bc_type
            ! Local floating potential: scale the electron temperature at every inner point
            !$omp parallel do default(none) private(i_in, idx) &
            !$omp             shared(mesh_cano, sheath_potential_lambda_sh, te, pot_pen_vals)
            do i_in = 1, mesh_cano%get_n_points_inner()
                idx = mesh_cano%inner_indices(i_in)
                pot_pen_vals(i_in) = sheath_potential_lambda_sh * te%vals(idx)
            end do
            !$omp end parallel do

        case default bc_type
            call handle_error('Pen_bndtpy not valid', &
                              GRILLIX_ERR_OTHER, __LINE__, __FILE__, &
                              error_info_t('Boundary_condition_type: ', &
                                           [bnddescr_pot_pen]))

        end select bc_type

    end subroutine pot_penalisation

    subroutine psipar_penalisation(comm_handler, equi, mesh_stag, hsolver_stag, map, penalisation_stag, &
                                   jpar_t_extrapolate, nevar_adv_stag, apar_pen, psipar_pen_vals, sinfo, res)
        !! Sets the penalisation values for the modified electromagnetic potential,
        !!     psipar = beta * apar + mass_ratio_ei / ne * jpar.
        !!
        !! The boundary type is taken from bnddescr_ohm_pen:
        !!   - BND_TYPE_DIRICHLET_ZERO: psipar is penalised directly to zero. No solve is performed,
        !!     and sinfo = 0, res = 0 are returned.
        !!   - BND_BRAGTYPE_JPAR_SMOOTHPENZERO: the target jpar is zero.
        !!   - BND_TYPE_NEUMANN: the target jpar is the upstream value of jpar_t_extrapolate.
        !! For the latter two, the target jpar is blended with jpar_t_extrapolate via the characteristic
        !! function. The matching apar is then obtained from a 2D Helmholtz solve on the staggered mesh,
        !! and psipar is assembled from both.
        !! If the solve fails (sinfo < 0), the solver problem (including the initial guess) is written to
        !! 'psiparpen_solve_failstate_planeXXX.nc', overwriting any existing file.
        type(comm_handler_t), intent(in) :: comm_handler
        !! Communicators
        class(equilibrium_t), intent(in) :: equi
        !! Equilibrium
        type(mesh_cart_t), intent(in) :: mesh_stag
        !! Mesh (staggered)
        class(helmholtz_solver_t), intent(inout) :: hsolver_stag
        !! Elliptic (2D) solver on staggered mesh
        type(parallel_map_t), intent(in) :: map
        !! Parallel map
        type(penalisation_t), intent(in) :: penalisation_stag
        !! Penalisation (staggered)
        type(variable_t), intent(in) :: jpar_t_extrapolate
        !! Guess for parallel current at timestep t+1 (e.g. extrapolated)
        real(GP), intent(in), dimension(mesh_stag%get_n_points()) :: nevar_adv_stag
        !! Density at timestep t+1 on staggered grid
        real(GP), intent(inout), dimension(mesh_stag%get_n_points()) :: apar_pen
        !! Penalisation values for the parallel electromagnetic potential
        !! Attention: dimension is n_points here for solver
        !! On input: initial guess; on output: solution (unchanged for BND_TYPE_DIRICHLET_ZERO)
        real(GP), intent(out), dimension(mesh_stag%get_n_points_inner()) :: psipar_pen_vals
        !! Penalisation values for the modified electromagnetic potential
        integer, intent(out) :: sinfo
        !! Info from solver (0 when no solve is performed)
        real(GP), intent(out) :: res
        !! Residual of penalisation solve (0 when no solve is performed)

        integer :: ki, kb, kg, l, nf90_stat, nf90_id
        real(GP) :: fac, phi_stag, x, y, pchar, jac

        real(GP), dimension(mesh_stag%get_n_points()) :: jpar_pen, co, apar_guess
        real(GP), dimension(mesh_stag%get_n_points_inner()) :: wrk_inner
        real(GP), dimension(mesh_stag%get_n_points_inner()) :: xi, lambda

        character(len=3) :: plane_c

        phi_stag = mesh_stag%get_phi()

        ! Compute desired values for penalisation of parallel current

        select case(bnddescr_ohm_pen)

            case (BND_BRAGTYPE_JPAR_SMOOTHPENZERO)

                !$omp parallel default(none) private(ki) shared(mesh_stag, wrk_inner)
                !$omp do
                do ki = 1, mesh_stag%get_n_points_inner()
                    wrk_inner(ki) = 0.0_GP
                enddo
                !$omp end do
                !$omp end parallel

            case (BND_TYPE_NEUMANN)

                call apply_neumann(.true., mesh_stag, map, penalisation_stag, jpar_t_extrapolate, wrk_inner)

            case (BND_TYPE_DIRICHLET_ZERO)
                ! Direct penalisation of psi_par to zero

                call apply_dirichlet(mesh_stag, penalisation_stag, &
                                        0.0_GP, psipar_pen_vals)

                ! No solve is performed: define the intent(out) solver diagnostics so the
                ! caller never reads undefined values (sinfo < 0 signals failure elsewhere).
                sinfo = 0
                res = 0.0_GP
                return

            case default

                call handle_error('Pen_bndtpy not valid', &
                                  GRILLIX_ERR_OTHER, __LINE__, __FILE__, &
                                  error_info_t('Boundary_condition_type: ', &
                                              [bnddescr_ohm_pen]))
        end select

        ! Smooth jpar_pen and perform solve for psipar_pen

        call perf_start('../../psiparpen_setup')
        fac = rhos**2
        !$omp parallel default(none) private(ki, kb, kg, l, pchar, x, y, jac) &
        !$omp          shared(equi, mesh_stag, phi_stag, penalisation_stag, fac, &
        !$omp                 jpar_pen, co, xi, lambda, jpar_t_extrapolate, wrk_inner)
        !$omp do
        do ki = 1, mesh_stag%get_n_points_inner()
            l = mesh_stag%inner_indices(ki)
            x = mesh_stag%get_x(l)
            y = mesh_stag%get_y(l)
            pchar = penalisation_stag%get_charfun_val(ki)
            ! Blend the target current (inside) with the extrapolated current (outside)
            jpar_pen(l) = wrk_inner(ki) * pchar + (1.0_GP-pchar)* jpar_t_extrapolate%vals(l)
            ! Evaluate the Jacobian once and reuse it for both solver coefficients
            jac         = equi%jacobian(x, y, phi_stag)
            co(l)       = jac
            xi(ki)      = fac / jac
            lambda(ki)  = 0.0_GP
        enddo
        !$omp end do
        !$omp do
        do kb = 1, mesh_stag%get_n_points_boundary()
            l = mesh_stag%boundary_indices(kb)
            x = mesh_stag%get_x(l)
            y = mesh_stag%get_y(l)
            jpar_pen(l) = 0.0_GP ! --> Is actually boundary condition for apar
            co(l) = equi%jacobian(x, y, phi_stag)
        enddo
        !$omp end do
        !$omp do
        do kg = 1, mesh_stag%get_n_points_ghost()
            l = mesh_stag%ghost_indices(kg)
            x = mesh_stag%get_x(l)
            y = mesh_stag%get_y(l)
            jpar_pen(l) = 0.0_GP ! --> Is actually boundary condition for apar
            co(l) = equi%jacobian(x, y, phi_stag)
        enddo
        !$omp end do
        !$omp end parallel
        call perf_stop('../../psiparpen_setup')

        call perf_start('../../psiparpen_solve_init')

        call hsolver_stag%update(co, lambda, xi, &
                                 bnddescr_apar_core, &
                                 bnddescr_apar_wall, &
                                 bnddescr_apar_dome, &
                                 bnddescr_apar_out)
        call perf_stop('../../psiparpen_solve_init')

        ! The solve overwrites apar_pen in place; keep the initial guess so a failed
        ! solve can be reproduced from the dump written below.
        apar_guess = apar_pen

        call perf_start('../../psiparpen_solve_gmres')
        call hsolver_stag%solve(jpar_pen, apar_pen, res, sinfo)
        call perf_stop('../../psiparpen_solve_gmres')

        if (sinfo < 0) then
            ! Write out state if solver failed
            write(plane_c,'(I3.3)')comm_handler%get_plane()

            ! Overwrites existing file
            nf90_stat = nf90_create('psiparpen_solve_failstate_plane'//plane_c//'.nc', &
                                    NF90_NETCDF4+NF90_CLOBBER, nf90_id)

            call handle_error_netcdf(nf90_stat, __LINE__, __FILE__)
            call write_netcdf_helmholtz(nf90_id, mesh_stag, &
                                        bnddescr_apar_core, &
                                        bnddescr_apar_wall, &
                                        bnddescr_apar_dome, &
                                        bnddescr_apar_out, &
                                        co, lambda, xi, &
                                        jpar_pen, &
                                        guess=apar_guess, &
                                        sol=apar_pen, &
                                        hcsr_write_on=.true.)
            nf90_stat = nf90_close(nf90_id)
            call handle_error_netcdf(nf90_stat, __LINE__, __FILE__)
        endif

        ! Assemble psipar = beta * apar + mass_ratio_ei / ne * jpar on the inner points
        !$omp parallel private(ki, l)
        !$omp do
        do ki = 1, mesh_stag%get_n_points_inner()
            l = mesh_stag%inner_indices(ki)
            psipar_pen_vals(ki) = beta * apar_pen(l) + mass_ratio_ei / nevar_adv_stag(l) * jpar_pen(l)
        enddo
        !$omp end do
        !$omp end parallel

    end subroutine psipar_penalisation

    subroutine apply_dirichlet(mesh, penalisation, pen_bndval, penalised_values)
        !! Fills every inner point with one constant value (zeroth-order Dirichlet condition, i.e. a
        !! direct value set).
        type(mesh_cart_t), intent(in) :: mesh
        !! Mesh (any); only its number of inner points is used
        type(penalisation_t), intent(in) :: penalisation
        !! Penalisation (any); not referenced here
        real(GP), intent(in) :: pen_bndval
        !! Constant assigned to every inner point
        real(GP), intent(out), dimension(mesh%get_n_points_inner()) :: penalised_values
        !! Output: pen_bndval at every inner point

        integer :: i_in

        !$omp parallel do default(none) private(i_in) &
        !$omp             shared(mesh, pen_bndval, penalised_values)
        do i_in = 1, mesh%get_n_points_inner()
            penalised_values(i_in) = pen_bndval
        end do
        !$omp end parallel do

    end subroutine apply_dirichlet

    subroutine apply_neumann(staggered, mesh, map, penalisation, values, penalised_values, gradients)
        !! Calculates penalisation values for a local Neumann condition by taking the upstream point value
        !! along the parallel map, optionally shifted by a prescribed parallel gradient.
        !!
        !! With d = penalisation%get_dirindfun_val(ki) (direction indicator), v_fwd/v_bwd the values mapped
        !! from the forward/backward plane, grad = gradients(ki) (zero if absent) and
        !! avg = (1 - |d|) * (v_bwd + v_fwd) / 2, the value at inner point ki is
        !!     d >  0:   d * (v_bwd + grad * dpar_bwd) + avg
        !!     d <= 0:  -d * (v_fwd - grad * dpar_fwd) + avg
        !! The avg term takes over where |d| < 1.
        logical, intent(in) :: staggered
        !! If true, the variable is defined on the staggered mesh; otherwise on the canonical mesh
        type(mesh_cart_t), intent(in) :: mesh
        !! Mesh (any, consistent with staggering)
        type(parallel_map_t), intent(in) :: map
        !! Parallel map
        type(penalisation_t), intent(in) :: penalisation
        !! Penalisation (any, consistent with staggering)
        type(variable_t), intent(in) :: values
        !! Variable with filled halos (its hfwd / hbwd halos are read)
        real(GP), intent(out), dimension(mesh%get_n_points_inner()) :: penalised_values
        !! Array with Neumann values, defined only on the inner points (indexed by ki)
        real(GP), intent(in), optional, dimension(mesh%get_n_points_inner()) :: gradients
        !! Point-wise prescribed parallel gradients on the inner points (indexed by ki); zero if absent

        integer :: ki, l
        real(GP) :: pen_dirind
        real(GP), dimension(mesh%get_n_points()) :: vals_fwd, vals_bwd
        real(GP) :: grad, fwd_bwd_average
        real(GP) :: map_dpar_fwd, map_dpar_bwd

        !$omp parallel default(none) &
        !$omp          private(ki, l, grad, pen_dirind, fwd_bwd_average, map_dpar_bwd, map_dpar_fwd) &
        !$omp          shared(mesh, map, penalisation, values, vals_fwd, vals_bwd, gradients, &
        !$omp                 penalised_values, staggered)

        ! Compute map values from forward and backward plane.
        ! NOTE(review): these calls sit inside the parallel region but outside a worksharing construct,
        ! so every thread executes them; presumably the map routines contain orphaned worksharing
        ! directives -- confirm before moving them out of the region.
        if (staggered) then
            call map%upstream_stag_from_stag_fwd(values%hfwd, vals_fwd)
            call map%upstream_stag_from_stag_bwd(values%hbwd, vals_bwd)
        else
            call map%upstream_cano_from_cano_fwd(values%hfwd, vals_fwd)
            call map%upstream_cano_from_cano_bwd(values%hbwd, vals_bwd)
        endif

        !$omp do
        do ki = 1, mesh%get_n_points_inner()
            l = mesh%inner_indices(ki)

            if (present(gradients)) then
                grad = gradients(ki)
            else
                grad = 0.0_GP
            endif

            if (staggered) then
                map_dpar_bwd = map%dpar_stag_stag_bwd(l)
                map_dpar_fwd = map%dpar_stag_stag_fwd(l)
            else
                map_dpar_bwd = map%dpar_cano_cano_bwd(l)
                map_dpar_fwd = map%dpar_cano_cano_fwd(l)
            endif

            pen_dirind = penalisation%get_dirindfun_val(ki)
            fwd_bwd_average = (1.0_GP - abs(pen_dirind)) * 0.5_GP * (vals_bwd(l) + vals_fwd(l))

            if (pen_dirind > 0.0_GP) then
                ! if magnetic field directed towards target use back map values
                penalised_values(ki) = + pen_dirind*(vals_bwd(l) + grad * map_dpar_bwd) &
                                       + fwd_bwd_average !Correction for the case where abs(pen_dirind) isn't 1.0
            else
                ! if magnetic field directed away from target use forward map values
                penalised_values(ki) = - pen_dirind*(vals_fwd(l) - grad * map_dpar_fwd) &
                                       + fwd_bwd_average !Correction for the case where abs(pen_dirind) isn't 1.0
            endif
        enddo
        !$omp end do
        !$omp end parallel

    end subroutine apply_neumann

    subroutine apply_heat_transmission_bc(mesh_cano, map, penalisation_cano, &
                                          temperature, log_temperature, ne, full_upar, &
                                          chi_par_0, sheath_transmission_factor, &
                                          penalised_values)
        !! Penalises the logarithmic temperature of species s (ions or electrons) with a
        !! sheath heat-transmission condition
        !!     \nabla_\parallel \log T_s = -\gamma_{sh} n_e u_\parallel / \chi_\parallel(T_s),
        !! where \chi_\parallel(T_s) = chi_par_0 * sqrt(T_s^5).
        !! The target gradient is evaluated at each inner point and passed to apply_neumann
        !! (canonical mesh) as the prescribed gradient.
        type(mesh_cart_t), intent(in) :: mesh_cano
        !! Mesh (canonical)
        type(parallel_map_t), intent(in) :: map
        !! Parallel map
        type(penalisation_t), intent(in) :: penalisation_cano
        !! Penalisation (canonical)
        type(variable_t), intent(in) :: temperature, log_temperature
        !! Temperature of the species being penalised (electrons or ions), non-logarithmic and logarithmic
        type(variable_t), intent(in) :: ne, full_upar
        !! Density (non-logarithmic) and parallel velocity on the canonical grid
        real(GP), intent(in) :: chi_par_0
        !! Dimensionless chipar heat conduction factor
        real(GP), intent(in) :: sheath_transmission_factor
        !! Sheath heat transmission factor for the species being penalised
        real(GP), intent(out), dimension(mesh_cano%get_n_points_inner()) :: penalised_values
        !! Penalisation values for the logarithmic temperature, one per inner mesh point

        real(GP), dimension(mesh_cano%get_n_points_inner()) :: target_grad
        real(GP) :: chi_local
        integer :: i_in, idx

        !$omp parallel do default(none) private(i_in, idx, chi_local) &
        !$omp             shared(mesh_cano, chi_par_0, sheath_transmission_factor, &
        !$omp                    temperature, target_grad, ne, full_upar)
        do i_in = 1, mesh_cano%get_n_points_inner()
            idx = mesh_cano%inner_indices(i_in)
            ! Local parallel heat conductivity chi_par_0 * T^(5/2)
            chi_local = chi_par_0 * sqrt(temperature%vals(idx) ** 5)
            target_grad(i_in) = -sheath_transmission_factor * ne%vals(idx) * full_upar%vals(idx) / chi_local
        end do
        !$omp end parallel do

        call apply_neumann(staggered=.false., mesh=mesh_cano, map=map, penalisation=penalisation_cano, &
                           values=log_temperature, penalised_values=penalised_values, &
                           gradients=target_grad)

    end subroutine apply_heat_transmission_bc

    real(GP) function compute_local_sound_speed(stag_te, stag_ti, l) result(sound_speed)
        !! Returns the local sound speed at staggered-mesh index l:
        !!   sound_speed = sqrt( Te(l) + bohm_frac_ti * tratio * Ti(l) )
        !! bohm_frac_ti and tratio are not arguments; they are taken from the enclosing scope
        !! (defined outside this procedure). bohm_frac_ti scales the ion-temperature contribution.
        !! NOTE(review): confirm how bohm_frac_ti relates to the sw_boundary_flow_ti switch.
        type(variable_t), intent(in) :: stag_te, stag_ti
        !! Electron and ion temperatures (non-logarithmic), mapped to the staggered grid
        integer, intent(in) :: l
        !! Mesh (staggered) index

        real(GP) :: te_here, ti_weighted

        te_here     = stag_te%vals(l)
        ti_weighted = bohm_frac_ti * tratio * stag_ti%vals(l)
        sound_speed = sqrt(te_here + ti_weighted)

    end function compute_local_sound_speed
    
    
    subroutine add_bohm_drift_correction(equi, equi_on_stag, mesh_stag, penalisation_stag, opsinplane_stag, &
                                         stag_pot, upar_pen) 
        !! Adds the ExB drift correction to the Bohm-Chodura boundary values on the staggered mesh.
        !!
        !! For every inner point a correction drift_corr ~ v_ExB * n / (n \cdot b) is formed as a blend of
        !!   - advection_penchi : n = \nabla penchi (meant for the gradient region of the penalisation function)
        !!   - advection_epol   : n = epol          (meant for the flat region of the penalisation function)
        !! weighted by a Hermite-step envelope in penchi, and then
        !!   upar_pen = upar_pen - drift_corr
        !! NOTE(review): confirm the sign convention; the correction is subtracted here.
        !! Where n \cdot b vanishes (|n \cdot b| <= tiny), the corresponding estimate is set to zero instead of
        !! dividing by zero, which would otherwise propagate Inf/NaN into upar_pen.
        class(equilibrium_t), intent(in) :: equi
        !! Equilibrium (not referenced in the body; kept for interface compatibility)
        type(equilibrium_storage_t), intent(in) :: equi_on_stag
        !! Equilibrium quantities on staggered mesh
        type(mesh_cart_t), intent(in) :: mesh_stag
        !! Mesh (staggered)
        type(penalisation_t), intent(in) :: penalisation_stag
        !! Penalisation (staggered)
        type(inplane_operators_t), intent(in) :: opsinplane_stag
        !! In-plane operators (staggered)
        type(variable_t), intent(in) :: stag_pot
        !! Electrostatic potential on staggered grid
        real(GP), intent(inout), dimension(mesh_stag%get_n_points_inner()) :: upar_pen
        !! Penalisation values of upar at inner points; on exit drift_corr has been subtracted

        integer :: ki, kb, l
        real(GP) :: bx, by, epolx, epoly, ndotb, ddx_penchi, ddy_penchi, ddx_pot, ddy_pot
        real(GP) :: advection_penchi, advection_epol, envelope, drift_corr
        real(GP), dimension(mesh_stag%get_n_points()) :: penchi
        integer, dimension(mesh_stag%get_n_points_boundary()) :: bnd_descrs

        !$omp parallel default(none) &
        !$omp          private(ki, kb, l, bx, by, epolx, epoly, ndotb, &
        !$omp                  ddx_penchi, ddy_penchi, ddx_pot, ddy_pot, &
        !$omp                  advection_penchi, advection_epol, envelope, drift_corr) &
        !$omp          shared(equi_on_stag, mesh_stag, penalisation_stag, opsinplane_stag, &
        !$omp                 bnd_descrs, penchi, stag_pot, upar_pen)

        ! Penalisation (characteristic) function on inner points ...
        !$omp do
        do ki = 1, mesh_stag%get_n_points_inner()
            l = mesh_stag%inner_indices(ki)
            penchi(l) = penalisation_stag%get_charfun_val(ki)
        enddo
        !$omp end do

        ! ... extended to boundary points with Neumann descriptors, and to ghost points by extrapolation,
        ! so that in-plane derivatives of penchi are defined on the whole mesh
        !$omp do
        do kb = 1, mesh_stag%get_n_points_boundary()
            bnd_descrs(kb) = BND_TYPE_NEUMANN
        enddo
        !$omp end do

        ! NOTE(review): called by every thread of the parallel region; presumably these routines contain
        ! orphaned worksharing directives -- confirm.
        call set_perpbnds(mesh_stag, bnd_descrs, penchi)
        call extrapolate_ghost_points(mesh_stag, penchi)

        !$omp do
        do ki = 1, mesh_stag%get_n_points_inner()
            l = mesh_stag%inner_indices(ki)

            ! In-plane components of B/|B|
            bx = equi_on_stag%bx(l) / equi_on_stag%absb(l)
            by = equi_on_stag%by(l) / equi_on_stag%absb(l)

            ! Estimate 1: v_E \cdot \nabla penchi / (\nabla penchi \cdot bhat),
            ! used only strictly inside the transition range of penchi
            advection_penchi = -1.0_GP / equi_on_stag%btor(l) * opsinplane_stag%arakawa(stag_pot%vals, penchi, l)

            ddx_penchi = opsinplane_stag%ddx(penchi, l)
            ddy_penchi = opsinplane_stag%ddy(penchi, l)
            ndotb = ddx_penchi * bx + ddy_penchi * by

            if ( (penchi(l) > 1.0E-3_GP) .and. (penchi(l) < 0.999_GP) .and. (abs(ndotb) > tiny(ndotb)) ) then
                advection_penchi = advection_penchi / ndotb
            else
                advection_penchi = 0.0_GP
            endif

            ! Estimate 2: v_E \cdot epol / (epol \cdot bhat)
            ddx_pot = opsinplane_stag%ddx(stag_pot%vals, l)
            ddy_pot = opsinplane_stag%ddy(stag_pot%vals, l)

            epolx = equi_on_stag%epol(l,1)
            epoly = equi_on_stag%epol(l,2)
            ndotb = epolx * bx + epoly * by

            if (abs(ndotb) > tiny(ndotb)) then
                advection_epol = -1.0_GP / equi_on_stag%btor(l) * (ddx_pot * epoly - ddy_pot * epolx) / ndotb
            else
                ! epol (numerically) perpendicular to bhat: a division here would yield Inf, and
                ! (1 - envelope) * Inf = NaN even where the envelope weight is exactly zero
                advection_epol = 0.0_GP
            endif

            ! Envelope: in the flat-chi region use advection_epol, in the gradient-chi region use advection_penchi
            envelope = step_hermite(0.11_GP, penchi(l), 0.2_GP) - step_hermite(0.89_GP, penchi(l), 0.2_GP)
            drift_corr = envelope * advection_penchi + (1.0_GP - envelope) * advection_epol
            upar_pen(ki) = upar_pen(ki) - drift_corr

        enddo
        !$omp end do
        !$omp end parallel

    end subroutine add_bohm_drift_correction

end module