#!/usr/bin/env bash

# These files will be patched and must exist
files=("./scripts/configure.sh"
       "./src/SOURCES"
       "./src/mdmain_utils.mod.F90"
       "./src/control_utils.mod.F90"
       "./src/md_driver.mod.F90"
       "./src/system.mod.F90")

# Check, before anything is modified, that every target exists and is writable.
# On failure exit with a non-zero status so that callers (CI jobs, wrapper
# scripts) can tell that nothing was patched.
for file in "${files[@]}"
do
    if [ ! -w "$file" ]
    then
        echo "ERROR: $file does not exist or is not writable. Are we in the CPMD root directory?" >&2
        exit 1
    fi
done

# These files will be created and must not exist
new_files=("./src/plumed_wrapper.mod.F90")

# Refuse to run twice: an existing wrapper module means the tree has most
# likely been patched already. Exit non-zero so the failure is detectable.
for file in "${new_files[@]}"
do
    if [ -e "$file" ]
    then
        echo "ERROR: $file already exists. Has CPMD already been patched with PLUMED?" >&2
        exit 1
    fi
done

# Create patches
# Each patch.N below is an ed script that the apply loop at the end of this
# file passes to "patch -e". Its first line, "diff <target> <target>", is read
# by that loop to find the file the script must be applied to.
# patch.0: "272c" replaces line 272 of ./src/SOURCES with a line that also
# lists plumed_wrapper.mod.F90 (presumably the original line ended with
# mimic_wrapper.mod.F90 - confirm against the targeted CPMD release).
cat << EOF > patch.0
diff ./src/SOURCES ./src/SOURCES
272c
          dftd3_driver.mod.F90 mimic_wrapper.mod.F90 plumed_wrapper.mod.F90
.
EOF
# patch.1: hook PLUMED into ./src/mdmain_utils.mod.F90 - USE of the wrapper
# module, plumed_init, plumed_calc (given infi as the step) and
# plumed_finalize. Every hook is guarded by __PLUMED and cntl%plumed and runs
# only on paral%io_parent.
# Hunks go from the highest line number down, so an insertion never shifts
# the line numbers that the following hunks refer to.
# NOTE(review): "478c" *replaces* original line 478, while all other hooks
# append with "a"; confirm that dropping line 478 is intended.
cat << EOF > patch.1
diff ./src/mdmain_utils.mod.F90 ./src/mdmain_utils.mod.F90
1285a
#ifdef __PLUMED
    IF (cntl%plumed .and. paral%io_parent) call plumed_finalize
#endif
.
805a
#ifdef __PLUMED
       IF (cntl%plumed .AND. paral%io_parent) call plumed_calc(infi, tau0, fion, ener_com%etot)
#endif
.
478c
#ifdef __PLUMED
    IF (cntl%plumed .and. paral%io_parent) call plumed_init
#endif
.
128a
#ifdef __PLUMED
  USE plumed_wrapper,                  ONLY: plumed_init, &
                                             plumed_calc, &
                                             plumed_finalize
#endif
.
EOF
# patch.2: add a PLUMED keyword branch to the input parser in
# ./src/control_utils.mod.F90. When built with __PLUMED it sets cntl%plumed;
# otherwise the keyword is reported as an error ("not compiled with it").
cat << EOF > patch.2
diff ./src/control_utils.mod.F90 ./src/control_utils.mod.F90
4051a
             ELSEIF ( keyword_contains(line,'PLUMED') ) THEN
#ifdef __PLUMED
                cntl%plumed = .TRUE.
#else
                something_went_wrong = .true.
                error_message        = 'PLUMED IS REQUESTED BUT CPMD IS NOT COMPILED WITH IT!'
#endif
.
EOF
# patch.3: the same four hooks as patch.1, for ./src/md_driver.mod.F90
# (plumed_calc receives loopnfi as the step number). Hunks are again listed
# in descending line order, so earlier insertions do not shift later targets.
cat << EOF > patch.3
diff ./src/md_driver.mod.F90 ./src/md_driver.mod.F90
1590a
#ifdef __PLUMED
    if (cntl%plumed .and. paral%io_parent) call plumed_finalize
#endif
.
1125a
#ifdef __PLUMED
       if (cntl%plumed .AND. paral%io_parent) call plumed_calc(loopnfi, tau0, fion, ener_com%etot)
#endif
.
801a
#ifdef __PLUMED
    if (cntl%plumed .and. paral%io_parent) call plumed_init
#endif
.
147a
#ifdef __PLUMED
  USE plumed_wrapper,                  ONLY: plumed_init, &
                                             plumed_calc, &
                                             plumed_finalize
#endif
.
EOF
# patch.4: declare the logical flag tested as cntl%plumed (default .FALSE.)
# in ./src/system.mod.F90. Presumably line 566 lies inside the derived type
# behind cntl - confirm against the CPMD release this script targets.
cat << EOF > patch.4
diff ./src/system.mod.F90 ./src/system.mod.F90
566a
     LOGICAL :: plumed = .FALSE.
.
EOF
# patch.5: create ./src/plumed_wrapper.mod.F90. The first "0a" hunk inserts
# the module body; the second "0a" then inserts the module header (module
# statement and USE lines) at the very top, in front of that body.
cat << EOF > patch.5
diff ./src/plumed_wrapper.mod.F90 ./src/plumed_wrapper.mod.F90
0a
implicit none
integer, parameter :: min_api = 2
integer, parameter :: supported_api = 6
integer, parameter :: real_size = 8

real(real_8), parameter :: bohr2nm = 0.1_real_8 / fbohr
real(real_8), parameter :: hartree2kjmol = au_kjm
! fs to ps; divide by a real literal (1000_real_8 is an integer of kind real_8)
real(real_8), parameter :: au2ps = au_fs / 1000.0_real_8
real(real_8), parameter :: charge_unit = 1.0_real_8
real(real_8), parameter :: mass_unit = 1.0_real_8


character(len=*), parameter :: code_name = "CPMD" // char(0)

character(len=*), parameter :: cmd_version = "getApiVersion" // char(0)

character(len=*), parameter :: cmd_precision = "setRealPrecision" // char(0)
character(len=*), parameter :: cmd_ener_unit = "setMDEnergyUnits" // char(0)
character(len=*), parameter :: cmd_len_unit = "setMDLengthUnits" // char(0)
character(len=*), parameter :: cmd_time_unit = "setMDTimeUnits" // char(0)
character(len=*), parameter :: cmd_charge_unit = "setMDChargeUnits" // char(0)
character(len=*), parameter :: cmd_mass_unit = "setMDMassUnits" // char(0)

character(len=*), parameter :: cmd_dat_file = "setPlumedDat" // char(0)
character(len=*), parameter :: cmd_set_natoms = "setNatoms" // char(0)
character(len=*), parameter :: cmd_set_md = "setMDEngine" // char(0)
character(len=*), parameter :: cmd_set_log = "setLog" // char(0)
character(len=*), parameter :: cmd_set_logfile = "setLogFile" // char(0)
character(len=*), parameter :: cmd_set_timestep = "setTimestep" // char(0)
character(len=*), parameter :: cmd_set_kbt = "setKbT" // char(0)
character(len=*), parameter :: cmd_set_box = "setBox" // char(0)
character(len=*), parameter :: cmd_set_restart = "setRestart" // char(0)

character(len=*), parameter :: cmd_set_step = "setStep" // char(0)
character(len=*), parameter :: cmd_set_masses = "setMasses" // char(0)
character(len=*), parameter :: cmd_set_coords = "setPositions" // char(0)
character(len=*), parameter :: cmd_set_forces = "setForces" // char(0)
character(len=*), parameter :: cmd_set_virial = "setVirial" // char(0)
character(len=*), parameter :: cmd_set_energy = "setEnergy" // char(0)

character(len=*), parameter :: cmd_init = "init" // char(0)
character(len=*), parameter :: cmd_compute = "calc" // char(0)

character(len=32), save :: plumed
character(len=*), parameter :: plumed_input = "plumed.inp" // char(0)
character(len=*), parameter :: plumed_output = "plumed.log" // char(0)

real(real_8), allocatable, save :: temp_coords(:,:)
real(real_8), allocatable, save :: temp_masses(:)
real(real_8), allocatable, save :: temp_force(:,:)
! The virial exchanged through setVirial is always a 3x3 matrix, whatever
! the number of atoms; a 3 x natoms buffer is too small for fewer than 3 atoms.
real(real_8), save :: force_vir(3,3)
real(real_8), save :: temp_energy
real(real_8), save :: kbT

integer, save :: num_atoms
real(real_8), save :: timestep_int
real(real_8), save :: box_int(3,3)
integer, save :: dummy
integer, save :: restart
! real(real_8), allocatable :: charges(:)

contains

subroutine plumed_init
    implicit none

    integer :: ierr
    integer :: api_version
    integer :: is, ia, offset

    if (restart1%restart) then
        restart = 1
    else
        restart = 0
    endif

    num_atoms = ions1%nat
    timestep_int = dt_ions

    box_int(:,:) = 0.0_real_8
#ifdef __MIMIC
    if (cntl%mimic) then
        box_int = mimic_control%box
    else
        box_int(1,1) = cell_com%celldm(1)
        box_int(2,2) = cell_com%celldm(1) * cell_com%celldm(2)
        box_int(3,3) = cell_com%celldm(1) * cell_com%celldm(3)
    endif
#else
    box_int(1,1) = cell_com%celldm(1)
    box_int(2,2) = cell_com%celldm(1) * cell_com%celldm(2)
    box_int(3,3) = cell_com%celldm(1) * cell_com%celldm(3)
#endif
    kbT = kboltz * cntr%tempw

    allocate(temp_coords(3,num_atoms),stat=ierr)
    if(ierr/=0) call stopgm("plumed_init",'allocation problem: temp_coords',&
          __LINE__,__FILE__)
    allocate(temp_force(3,num_atoms),stat=ierr)
    if(ierr/=0) call stopgm("plumed_init",'allocation problem: temp_force',&
          __LINE__,__FILE__)
    allocate(temp_masses(num_atoms),stat=ierr)
    if(ierr/=0) call stopgm("plumed_init",'allocation problem: temp_masses',&
          __LINE__,__FILE__)

#ifdef __PLUMED
    call plumed_f_create(plumed)
    call plumed_f_cmd(plumed, cmd_version, api_version)
    write(6,'(/," ",12("PLUMED"))')
    write(6, *) "USING PLUMED API VERSION:", api_version
    write(6, *) "USING PLUMED INPUT FILE:", plumed_input
    write(6, *) "LOG WILL BE WRITTEN TO:", plumed_output
    write(6, *) "kbT is:", kbT
    if (restart1%restart) write(6, *) "RESTARTING PREVIOUS SIMULATIONS"
    if (api_version < min_api) then
        call stopgm("plumed_init",'Unsupported version of Plumed API',&
          __LINE__,__FILE__)
    else if (api_version > supported_api) then
        write(6, *) "WARNING: PLUMED API VERSION IS NEWER THAN THE SUPPORTED ONE"
        write(6, *) "WARNING: THIS SHOULD NOT BE A PROBLEM BUT JUST IN CASE..."
    endif
    if (api_version > 3) then
        call plumed_f_cmd(plumed, cmd_charge_unit, charge_unit)
        call plumed_f_cmd(plumed, cmd_mass_unit, mass_unit)
    endif
    call plumed_f_cmd(plumed, cmd_precision, real_size)
    call plumed_f_cmd(plumed, cmd_ener_unit, hartree2kjmol)
    call plumed_f_cmd(plumed, cmd_len_unit, bohr2nm)
    call plumed_f_cmd(plumed, cmd_time_unit, au2ps)
    call plumed_f_cmd(plumed, cmd_set_md, code_name)
    call plumed_f_cmd(plumed, cmd_dat_file, plumed_input)
    call plumed_f_cmd(plumed, cmd_set_logfile, plumed_output)
    call plumed_f_cmd(plumed, cmd_set_natoms, num_atoms)
    call plumed_f_cmd(plumed, cmd_set_timestep, timestep_int)
    call plumed_f_cmd(plumed, cmd_set_kbt, kbT)
    call plumed_f_cmd(plumed, cmd_set_restart, restart)
    call plumed_f_cmd(plumed, cmd_init, dummy)
#endif

    offset = 1
    do is = 1, ions1%nsp
        do ia = 1, ions0%na(is)
            temp_masses(offset) = rmass%pma0(is)
            offset = offset + 1
        end do
    end do

    write(6,'(/," ",12("PLUMED"))')
end subroutine plumed_init

subroutine plumed_calc(timestep, tau, fion, energy)
    implicit none
    integer :: timestep
    real(real_8), dimension(:,:,:), intent(in) :: tau
    real(real_8), dimension(:,:,:), intent(inout) :: fion
    real(real_8), intent(inout) :: energy

    integer :: is, ia, offset

    character(*), parameter :: procedureN = 'plumed_calc'
    integer :: isub

    call tiset(procedureN,isub)

    force_vir(:,:) = 0.0_real_8
    temp_force(:,:) = 0.0_real_8
    temp_energy = energy
    offset = 1
    do is = 1, ions1%nsp
        do ia = 1, ions0%na(is)
            temp_coords(:, offset) = tau(:, ia, is)
            offset = offset + 1
        end do
    end do

#ifdef __PLUMED
    call plumed_f_cmd(plumed, cmd_set_step, timestep)
    call plumed_f_cmd(plumed, cmd_set_box, box_int)
    call plumed_f_cmd(plumed, cmd_set_masses, temp_masses)
    call plumed_f_cmd(plumed, cmd_set_coords, temp_coords)
    call plumed_f_cmd(plumed, cmd_set_forces, temp_force)
    call plumed_f_cmd(plumed, cmd_set_virial, force_vir)
    call plumed_f_cmd(plumed, cmd_set_energy, temp_energy)
    call plumed_f_cmd(plumed, cmd_compute, dummy)
#endif

    offset = 1
    do is = 1, ions1%nsp
        do ia = 1, ions0%na(is)
            fion(:, ia, is) = fion(:, ia, is) + temp_force(:, offset)
            offset = offset + 1
        end do
    end do

    energy = temp_energy

    call tihalt(procedureN,isub)
end subroutine plumed_calc

subroutine plumed_finalize
#ifdef __PLUMED
    ! Finalize the object created by plumed_f_create: the handle is required
    ! (plumed_f_finalize has an implicit interface, so omitting it compiles).
    call plumed_f_finalize(plumed)
#endif
    if (allocated(temp_coords)) deallocate(temp_coords)
    if (allocated(temp_force)) deallocate(temp_force)
    if (allocated(temp_masses)) deallocate(temp_masses)
end subroutine plumed_finalize

end module plumed_wrapper
.
0a
module plumed_wrapper

use cell
use ions, only : ions0, ions1
USE kinds, ONLY: int_1,&
                 int_4,&
                 int_8,&
                 real_4,&
                 real_8
use readsr_utils, only: readsr, readsi, input_string_len, keyword_contains
use inscan_utils, only: inscan
use error_handling, only: stopgm
use mp_interface
use rmas
USE tpar, only: dt_ions
use system, only: cntr, cntl
use cnst
#ifdef __MIMIC
use mimic_wrapper, only: mimic_control
#endif
use timer, only: tihalt, tiset
use store_types, only: restart1
.
EOF
# patch.6: teach ./scripts/configure.sh a -plumed/-p option. When it is set,
# configure.sh adds -D__PLUMED to CPPFLAGS, makes the generated Makefile
# include $(CPMDROOT)/Plumed.inc, and emits a link rule that adds
# $(PLUMED_LOAD); the inserted "if ... else" / "fi" pairs (738a/744a and
# 707a/713a) presumably wrap the existing target rules as the else branch.
# Hunks are in descending line order so earlier insertions do not shift later
# targets.
# Escaping: inside this unquoted here-document "\$" yields a literal "$" in
# configure.sh, and "\\\$" yields "\$", so that configure.sh's own
# here-document in turn writes a literal "$" into the Makefile. Recipe lines
# must keep their leading TAB characters.
cat << EOF > patch.6
diff ./scripts/configure.sh ./scripts/configure.sh
744a
    fi
.
738a
    if [ \$plumed ]; then
cat << END >&3
\\\$(TARGET): \\\$(CPMD_LIB) timetag.o cpmd.o
	\\\$(LD) \\\$(FFLAGS) \\\$(PLUMED_LOAD) -o \\\$(TARGET) timetag.o cpmd.o \\\$(CPMD_LIB) \\\$(LFLAGS)
	@ls -l \\\$(TARGET)
	@echo "Compilation done."
END
    else
.
713a
    fi
.
707a
    if [ \$plumed ]; then
cat << END >&3
\\\$(TARGET): \\\$(CPMD_LIB) \\\$(GROMOS_LIB) \\\$(INTERFACE_LIB) timetag.o cpmd.o
	\\\$(LD) \\\$(FFLAGS) \\\$(PLUMED_LOAD) -o \\\$(TARGET) timetag.o cpmd.o \\\$(CPMD_LIB) \\\$(GROMOS_LIB) \\\$(INTERFACE_LIB) \\\$(LFLAGS)
	@ls -l \\\$(TARGET)
	@echo "Compilation done."
END
    else
.
351a
if [ \$plumed ]; then
cat << END >&3
include \\\$(CPMDROOT)/Plumed.inc
END
fi

.
267a
if [ \$plumed ]; then
    CPPFLAGS=\${CPPFLAGS}' -D__PLUMED'
fi
.
137a
    -plumed|-p)  #New flag if you want to patch cpmd with plumed
      plumed=1
      echo "** Enabling coupling with PLUMED. Make sure PLUMED is installed on your system!" >&2
      ;;
.
60a
   -plumed          Enable support for coupling to PLUMED
.
EOF

# Progress bar adapted from https://www.baeldung.com/linux/command-line-progress-bar
bar_size=40
bar_char_done="#"
bar_char_todo="-"
bar_percentage_scale=2

#######################################
# Draw a one-line progress bar on stdout, redrawn in place with '\r'.
# The percentage is printed with bar_percentage_scale decimals, truncated
# (not rounded). Only bash integer arithmetic is used, so no external
# calculator such as bc is needed, and all working variables are local, so
# the caller's variables are never overwritten.
# Globals:   bar_size, bar_char_done, bar_char_todo,
#            bar_percentage_scale (read)
# Arguments: $1 - number of completed steps
#            $2 - total number of steps
# Outputs:   the bar; additionally "DONE" on a new line when $1 reaches $2
# Returns:   0
#######################################
show_progress() {
    local -i current=$1 total=$2
    local -i unit scaled done_count todo_count
    local percent done_sub_bar todo_sub_bar

    # Nothing meaningful to draw, and it would divide by zero.
    if (( total <= 0 )); then
        return 0
    fi
    # Clamp so that out-of-range input cannot yield a negative bar width.
    if (( current < 0 )); then
        current=0
    fi
    if (( current > total )); then
        current=total
    fi

    # Fixed-point percentage: scaled = percent * 10^scale, truncated.
    unit=$(( 10 ** bar_percentage_scale ))
    scaled=$(( 100 * unit * current / total ))
    if (( bar_percentage_scale > 0 )); then
        printf -v percent '%d.%0*d' "$(( scaled / unit ))" \
            "$bar_percentage_scale" "$(( scaled % unit ))"
    else
        percent=$scaled
    fi

    # Number of done and todo characters.
    done_count=$(( bar_size * scaled / (100 * unit) ))
    todo_count=$(( bar_size - done_count ))

    # Build runs of spaces of the right width, then swap in the bar characters.
    printf -v done_sub_bar '%*s' "$done_count" ''
    printf -v todo_sub_bar '%*s' "$todo_count" ''
    done_sub_bar=${done_sub_bar// /"$bar_char_done"}
    todo_sub_bar=${todo_sub_bar// /"$bar_char_todo"}

    printf '\rProgress : [%s%s] %s%%' "$done_sub_bar" "$todo_sub_bar" "$percent"

    if (( current == total )); then
        printf '\nDONE\n'
    fi
}

echo "Patching CPMD"
echo "Adding PLUMED interface to CPMD"
# Apply patch.0 .. patch.<last_patch> in order. Stop at the first failure
# instead of silently carrying on, and leave the remaining patch files on disk
# so that the failure can be inspected.
last_patch=6
for (( i = 0; i <= last_patch; i++ ))
do
    # Line 1 of every patch file is "diff <target> <target>"; field 2 is the
    # file the ed script has to be applied to.
    filename=$(awk 'NR == 1 { print $2; exit }' "patch.$i")
    if ! patch -e -p0 "$filename" < "patch.$i"
    then
        echo "ERROR: applying patch.$i to $filename failed; the CPMD tree may now be partially patched." >&2
        exit 1
    fi
    rm -f -- "patch.$i"
    show_progress "$i" "$last_patch"
done
echo "PLUMED patches has been applied, you can proceed with 'plumed patch'"
