TABLE OF CONTENTS


ABINIT/cheb_oracle [ Functions ]

[ Top ] [ Functions ]

NAME

 cheb_oracle

FUNCTION

 Returns the number of necessary iterations to decrease residual by at least tol
 Here as in the rest of the code, the convention is that residuals are squared (||Ax-lx||^2)

INPUTS

 x= input variable
 a= left bound of the interval
 b= right bound of the interval
 tol= needed precision
 nmax= max number of iterations

OUTPUT

 n= number of iterations needed to decrease residual by tol

NOTES

SOURCE

pure function cheb_oracle(x, a, b, tol, nmax) result(n)

! Returns the smallest Chebyshev degree n such that 1/T_n(xred)**2 < tol, where xred is x
! mapped affinely so that [a,b] becomes [-1,1]. The search is capped: if no degree below
! nmax satisfies the criterion, nmax is returned.
!
!  x    = point at which the amplification is evaluated
!  a, b = bounds of the filtered interval; they must differ, since the mapping divides by (b-a)
!  tol  = requested reduction factor for the squared residual
!  nmax = value returned when the criterion is never met
!  n    = number of iterations (result)

 real(dp), intent(in) :: x, a, b
 real(dp), intent(in) :: tol
 integer, intent(in) :: nmax
 integer :: n

 integer :: i
 real(dp) :: y, xred, temp
 real(dp) :: yim1

! *************************************************************************

 xred = (x-(a+b)/2)/(b-a)*2
 y = xred
 yim1 = one

 ! Default answer when the tolerance is never reached
 n = nmax

 ! The criterion 1/y**2 < tol is tested as tol*y**2 > one so that a vanishing
 ! Chebyshev value (x at the centre of [a,b]) does not cause a division by zero.
 if(tol*y**2 > one) then
   n = 1
 else
   do i=2, nmax-1
     ! Three-term recurrence T_i = 2*xred*T_{i-1} - T_{i-2}
     temp = y
     y = 2*xred*y - yim1
     yim1 = temp
     if(tol*y**2 > one) then
       n = i
       exit
     end if
   end do
 end if

end function cheb_oracle

ABINIT/cheb_poly [ Functions ]

[ Top ] [ Functions ]

NAME

 cheb_poly

FUNCTION

 Computes the value of the Chebyshev polynomial of degree n on the interval [a,b] at x

INPUTS

 x= input variable
 n= degree
 a= left bound of the interval
 b= right bound of the interval

OUTPUT

 y= Tn(x)

NOTES

SOURCE

pure function cheb_poly(x, n, a, b) result(y)

! Evaluates the Chebyshev polynomial of the first kind T_n at x, after mapping [a,b]
! affinely onto [-1,1], using the three-term recurrence.
!
!  x    = evaluation point
!  n    = degree, expected to be >= 0
!  a, b = bounds of the interval; they must differ, since the mapping divides by (b-a)
!  y    = T_n(xred) (result)

 integer, intent(in) :: n
 real(dp), intent(in) :: x, a, b
 real(dp) :: y

 integer :: i
 real(dp) :: xred, temp
 real(dp) :: yim1

! *************************************************************************

 xred = (x-(a+b)/2)/(b-a)*2

 ! Degree 0: T_0 = 1. Without this branch the recurrence start value T_1 = xred
 ! would be returned for n = 0.
 if(n == 0) then
   y = one
   return
 end if

 ! Start from T_1 = xred, T_0 = one and climb up to degree n
 y = xred
 yim1 = one
 do i=2, n
   temp = y
   y = 2*xred*y - yim1
   yim1 = temp
 end do

end function cheb_poly

ABINIT/chebfi [ Functions ]

[ Top ] [ Functions ]

NAME

 chebfi

FUNCTION

 this routine updates the wave functions at a given k-point,
 using the ChebFi method (see paper by A. Levitt and M. Torrent)

INPUTS

  dtset <type(dataset_type)>=all input variables for this dataset
  gs_hamk <type(gs_hamiltonian_type)>=all data for the hamiltonian at k
  kinpw(npw)=(modified) kinetic energy for each plane wave (hartree)
  mpi_enreg=information about MPI parallelization
  nband=number of bands at this k point for that spin polarization
  npw=number of plane waves at this k point
  nspinor=number of spinorial components of the wavefunctions
  prtvol=control print volume and debugging output

OUTPUT

  eig(nband)=array for holding eigenvalues (hartree)
  resid(nband)=residuals for each state
  If gs_hamk%usepaw==1:
    gsc(2,*)=<g|s|c> matrix elements (s=overlap)
  If gs_hamk%usepaw==0
    enlx(nband)=contribution from each band to nonlocal psp + potential Fock ACE part of total energy, at this k-point

SIDE EFFECTS

  cg(2,*)=updated wavefunctions

NOTES

  -- TODO --
  Normev?
  Ecutsm
  nspinor 2
  spinors parallelisation
  fock
  -- Performance --
  Improve load balancing
  Don't diagonalize converged eigenvectors, just orthogonalize
  Maybe don't diagonalize so often (once every two outer iterations?)
  Benchmark diagonalizations, choose np_slk
  How to choose npfft?
  Implement MINRES for invovl
  -- LOBPCG --
  Improve stability (see paper by Lehoucq Sorensen, maybe use bunch-kaufman factorizations?)

SOURCE

subroutine chebfi(cg,dtset,eig,enlx,gs_hamk,gsc,kinpw,mpi_enreg,nband,npw,nspinor,prtvol,resid)

! Updates the wavefunctions at one k-point with the Chebyshev filtering (ChebFi) method.
! Steps, in order:
!   1. if dtset%paral_kgb == 1, transpose cg into the "_filter" layout (ndatarecv x bandpp);
!      otherwise the _filter pointers alias the input arrays directly;
!   2. apply the Hamiltonian once and compute Rayleigh quotients and squared residuals;
!   3. apply a Chebyshev polynomial filter of degree dtset%nline on [filter_low, dtset%ecut];
!   4. rescale each band by the value of the polynomial at its Rayleigh quotient;
!   5. transpose back, run the Rayleigh-Ritz step, and build the preconditioned residuals.
!
!  cg      (inout) wavefunctions, overwritten with the updated ones
!  dtset   (in)    paral_kgb, nline, ecut, tolwfr_diago and invovl_blksliced are read here
!  eig     (out)   eigenvalues
!  enlx    (out)   per-band <cg|gvnlxc>; only assigned when usepaw /= 1 or Fock is associated
!  gs_hamk (inout) Hamiltonian data (usepaw, istwf_k, natom, dimcprj, fockcommon are read here)
!  gsc     (inout) S|c> vectors, used when gs_hamk%usepaw == 1
!  kinpw   (in)    kinetic energy of each plane wave, used to build the residual preconditioner
!  mpi_enreg (inout) MPI data; bandpp is temporarily overwritten during the filter loop
!  resid   (out)   squared norm of the preconditioned residual of each band

!Arguments ------------------------------------
 type(gs_hamiltonian_type),intent(inout) :: gs_hamk
 type(dataset_type),intent(in) :: dtset
 type(mpi_type),intent(inout) :: mpi_enreg
 integer,intent(in) :: nband,npw,prtvol,nspinor
 real(dp),intent(inout), target :: cg(2,npw*nspinor*nband),gsc(2,npw*nspinor*nband)
 real(dp),intent(in) :: kinpw(npw)
 real(dp),intent(out) :: resid(nband)
 real(dp),intent(out) :: enlx(nband)
 real(dp),intent(out) :: eig(nband)

!Local variables-------------------------------
 real(dp) :: pcon(npw)
 real(dp) :: filter_low
 real(dp) :: filter_center, filter_radius
 real(dp), dimension(2, npw*nspinor*nband), target :: ghc, gvnlxc
 real(dp), allocatable, dimension(:,:) :: cg_filter_next, cg_filter_prev, gsm1hc_filter, gsc_filter_prev, gsc_filter_next
 real(dp), allocatable, dimension(:,:), target :: cg_alltoall1,gsc_alltoall1,ghc_alltoall1,gvnlxc_alltoall1
 real(dp), allocatable, dimension(:,:), target :: cg_alltoall2,gsc_alltoall2,ghc_alltoall2,gvnlxc_alltoall2
 real(dp), pointer, dimension(:,:) :: cg_filter, gsc_filter, ghc_filter, gvnlxc_filter
 real(dp) :: resid_vec(2, npw*nspinor)
 logical :: has_fock,paw
 integer :: shift, shift_cg_loadbalanced
 integer :: iband, iline, ispinor
 integer :: sij_opt, cpopt
 real(dp) :: eval, tsec(2)
 ! Declared as a parameter: an initialised local would silently acquire the SAVE attribute
 integer, parameter :: tim_getghc = 5
 integer :: ierr
 integer :: i
 integer, allocatable :: index_wavef_band(:)
 real(dp) :: maxeig, mineig
 real(dp), allocatable :: resids_filter(:), residvec_filter(:,:)
 integer, allocatable :: nline_bands(:)
 integer :: iactive, nactive
 real(dp) :: ampfactor
 integer :: nline_max, nline_decrease, nline_tolwfr
 ! real(dp) :: load_imbalance
 real(dp) :: dprod_r, dprod_i
 character(len=500) :: message
 integer :: rdisplsloc(mpi_enreg%nproc_band), recvcountsloc(mpi_enreg%nproc_band)
 integer :: sdisplsloc(mpi_enreg%nproc_band), sendcountsloc(mpi_enreg%nproc_band)
 integer :: ikpt_this_proc, npw_filter, nband_filter
 type(pawcprj_type), allocatable :: cwaveprj(:,:), cwaveprj_next(:,:), cwaveprj_prev(:,:)
 ! integer :: nline_total

 ! timers
 integer, parameter :: timer_chebfi = 1600, timer_alltoall = 1601, timer_apply_inv_ovl = 1602, timer_rotation = 1603
 integer, parameter :: timer_subdiago = 1604, timer_subham = 1605, timer_ortho = 1606, timer_getghc = 1607
 integer, parameter :: timer_residuals = 1608, timer_update_eigen = 1609, timer_sync = 1610

! *************************************************************************

 !======================================================================================================
 ! Initialize, transpose input cg if paral_kgb
 !======================================================================================================
 call timab(timer_chebfi,1,tsec)

 !Initializations
 paw = gs_hamk%usepaw == 1
 has_fock=(associated(gs_hamk%fockcommon))

 ! Init pcon: rational function of the kinetic energy, applied to the residuals at the end
 pcon = (27+kinpw*(18+kinpw*(12+8*kinpw))) / (27+kinpw*(18+kinpw*(12+8*kinpw)) + 16*kinpw**4)

 ghc=zero; gvnlxc=zero

 ! Initialize the _filter pointers. Depending on paral_kgb, they might point to the actual arrays or to _alltoall variables
 if (dtset%paral_kgb == 1) then
   ikpt_this_proc = bandfft_kpt_get_ikpt()
   npw_filter     = bandfft_kpt(ikpt_this_proc)%ndatarecv
   nband_filter   = mpi_enreg%bandpp

   ABI_MALLOC(cg_alltoall1,     (2, npw_filter*nspinor*nband_filter))
   ABI_MALLOC(gsc_alltoall1,    (2, npw_filter*nspinor*nband_filter))
   ABI_MALLOC(ghc_alltoall1,    (2, npw_filter*nspinor*nband_filter))
   ABI_MALLOC(gvnlxc_alltoall1, (2, npw_filter*nspinor*nband_filter))
   ABI_MALLOC(cg_alltoall2,     (2, npw_filter*nspinor*nband_filter))
   ABI_MALLOC(gsc_alltoall2,    (2, npw_filter*nspinor*nband_filter))
   ABI_MALLOC(ghc_alltoall2,    (2, npw_filter*nspinor*nband_filter))
   ABI_MALLOC(gvnlxc_alltoall2, (2, npw_filter*nspinor*nband_filter))

   ! Init transpose variables
   recvcountsloc = bandfft_kpt(ikpt_this_proc)%recvcounts * 2 * nspinor * mpi_enreg%bandpp
   rdisplsloc    = bandfft_kpt(ikpt_this_proc)%rdispls    * 2 * nspinor * mpi_enreg%bandpp
   sendcountsloc = bandfft_kpt(ikpt_this_proc)%sendcounts * 2 * nspinor
   sdisplsloc    = bandfft_kpt(ikpt_this_proc)%sdispls    * 2 * nspinor

   ! Load balancing, so that each processor has approximately the same number of converged and non-converged bands
   ! for two procs, rearrange 1 2 3 4 5 6 as 1 4 2 5 3 6
   !
   ! trick to save memory: ghc has the necessary size, and will be overwritten afterwards anyway
#define cg_loadbalanced ghc
   shift = 0
   do i=1, mpi_enreg%nproc_band
     do iband=1, mpi_enreg%bandpp
       shift_cg_loadbalanced = (i-1 + (iband-1)*mpi_enreg%nproc_band)*npw*nspinor
       cg_loadbalanced(:, shift+1:shift+npw*nspinor) = cg(:, shift_cg_loadbalanced+1:shift_cg_loadbalanced+npw*nspinor)
       shift = shift + npw*nspinor
     end do
   end do

   ! Transpose input cg into cg_alltoall1. cg_alltoall1 is now (npw_filter, nband_filter)
   call timab(timer_alltoall, 1, tsec)
   call xmpi_alltoallv(cg_loadbalanced,sendcountsloc,sdisplsloc,cg_alltoall1,&
&   recvcountsloc,rdisplsloc,mpi_enreg%comm_band,ierr)
   call timab(timer_alltoall, 2, tsec)
#undef cg_loadbalanced

   ! sort according to bandpp (from lobpcg, I don't fully understand what's going on but it works and it's fast)
   call prep_index_wavef_bandpp(mpi_enreg%nproc_band,mpi_enreg%bandpp,&
        &   nspinor,bandfft_kpt(ikpt_this_proc)%ndatarecv,&
        &   bandfft_kpt(ikpt_this_proc)%recvcounts,bandfft_kpt(ikpt_this_proc)%rdispls,&
        &   index_wavef_band)

   cg_alltoall2(:,:) = cg_alltoall1(:,index_wavef_band)

   cg_filter => cg_alltoall2
   gsc_filter => gsc_alltoall2
   ghc_filter => ghc_alltoall2
   gvnlxc_filter => gvnlxc_alltoall2
 else
   npw_filter = npw
   nband_filter = nband

   cg_filter => cg
   gsc_filter => gsc
   ghc_filter => ghc
   gvnlxc_filter => gvnlxc
 end if
 ! from here to the next alltoall, all computation is done on _filter variables, agnostic
 ! to whether it's nband x npw (paral_kgb == 0) or ndatarecv*bandpp (paral_kgb = 1)

 ! Allocate filter variables for the application of the Chebyshev polynomial
 ABI_MALLOC(cg_filter_next, (2, npw_filter*nspinor*nband_filter))
 ABI_MALLOC(cg_filter_prev, (2, npw_filter*nspinor*nband_filter))
 ABI_MALLOC(gsc_filter_prev, (2, npw_filter*nspinor*nband_filter))
 ABI_MALLOC(gsc_filter_next, (2, npw_filter*nspinor*nband_filter))
 ABI_MALLOC(gsm1hc_filter, (2, npw_filter*nspinor*nband_filter))

 ! PAW init
 if(paw) then
   ABI_MALLOC(cwaveprj, (gs_hamk%natom,nspinor*nband_filter))
   ABI_MALLOC(cwaveprj_next, (gs_hamk%natom,nspinor*nband_filter))
   ABI_MALLOC(cwaveprj_prev, (gs_hamk%natom,nspinor*nband_filter))
   call pawcprj_alloc(cwaveprj,0,gs_hamk%dimcprj)
   call pawcprj_alloc(cwaveprj_next,0,gs_hamk%dimcprj)
   call pawcprj_alloc(cwaveprj_prev,0,gs_hamk%dimcprj)

   sij_opt = 1 ! recompute S
   cpopt = 0 ! save cprojs
 else
   sij_opt = 0
   cpopt = -1
 end if



 !======================================================================================================
 ! Data in npfft x npband distribution. First getghc, update eigenvalues and residuals
 !======================================================================================================
 write(message, *) 'First getghc'
 call wrtout(std_out,message,'COLL')

 ! get_ghc on cg
 call timab(timer_getghc, 1, tsec)
 if (dtset%paral_kgb == 0) then
   call getghc(cpopt,cg_filter,cwaveprj,ghc_filter,gsc_filter,gs_hamk,gvnlxc_filter,&
&   eval,mpi_enreg,nband,prtvol,sij_opt,tim_getghc,0)
 else
   call prep_getghc(cg_filter,gs_hamk,gvnlxc_filter,ghc_filter,gsc_filter,eval,nband,mpi_enreg,&
&   prtvol,sij_opt,cpopt,cwaveprj,already_transposed=.true.)
 end if
 call timab(timer_getghc, 2, tsec)

 ! Debug barrier: should be invisible
 call timab(timer_sync, 1, tsec)
 call xmpi_barrier(mpi_enreg%comm_band)
 call timab(timer_sync, 2, tsec)

 write(message, *) 'Computing residuals'
 call wrtout(std_out,message,'COLL')
 ! update eigenvalues and residuals
 call timab(timer_update_eigen, 1, tsec)
 ABI_MALLOC(resids_filter, (nband_filter))
 ABI_MALLOC(residvec_filter, (2, npw_filter*nspinor))
 ABI_MALLOC(nline_bands, (nband_filter))
 do iband=1, nband_filter
   shift = npw_filter*nspinor*(iband-1)
   ! Rayleigh quotient <c|H|c>, divided by <c|S|c> in the PAW case
   call dotprod_g(eig(iband),dprod_i,gs_hamk%istwf_k,npw_filter*nspinor,1,ghc_filter(:, shift+1:shift+npw_filter*nspinor),&
&   cg_filter(:, shift+1:shift+npw_filter*nspinor),mpi_enreg%me_g0,mpi_enreg%comm_spinorfft)
   if(paw) then
     call dotprod_g(dprod_r,dprod_i,gs_hamk%istwf_k,npw_filter*nspinor,1,gsc_filter(:, shift+1:shift+npw_filter*nspinor),&
&     cg_filter(:, shift+1:shift+npw_filter*nspinor),mpi_enreg%me_g0,mpi_enreg%comm_spinorfft)
     eig(iband) = eig(iband)/dprod_r
   end if

   ! Residual vector H|c> - eig*S|c> (S = identity without PAW)
   if(paw) then
     residvec_filter = ghc_filter(:, shift+1 : shift+npw_filter*nspinor) &
&     - eig(iband)*gsc_filter(:, shift+1 : shift+npw_filter*nspinor)
   else
     residvec_filter = ghc_filter(:, shift+1 : shift+npw_filter*nspinor) &
&     - eig(iband)*cg_filter(:, shift+1 : shift+npw_filter*nspinor)
   end if
   resids_filter(iband) = SUM(residvec_filter**2)
 end do
 call xmpi_sum(resids_filter,mpi_enreg%comm_fft,ierr)
 call xmpi_max(MAXVAL(eig(1:nband_filter)),maxeig,mpi_enreg%comm_band,ierr)
 call xmpi_min(MINVAL(eig(1:nband_filter)),mineig,mpi_enreg%comm_band,ierr)
 ! The filter damps [filter_low, ecut]; its lower edge is the largest Rayleigh quotient
 filter_low = maxeig
 call timab(timer_update_eigen, 2, tsec)

 ! Decide how many iterations per band are needed
 ! don't go above this, or face bad conditioning of the Gram matrix.
 nline_max = cheb_oracle(mineig, filter_low, dtset%ecut, 1e-16_dp, 40)
 ! if(mpi_enreg%me == 0) write(0, *) nline_max
 do iband=1, nband_filter
   ! nline necessary to converge to tolwfr
   nline_tolwfr = cheb_oracle(eig(iband), filter_low, dtset%ecut, dtset%tolwfr_diago / resids_filter(iband), dtset%nline)
   ! nline necessary to decrease residual by a constant factor
   nline_decrease = cheb_oracle(eig(iband), filter_low, dtset%ecut, 0.1_dp, dtset%nline)

   nline_bands(iband) = MAX(MIN(nline_tolwfr, nline_decrease, nline_max, dtset%nline), 1)
   ! The estimate above is deliberately overridden: every band gets the full dtset%nline iterations
   nline_bands(iband) = dtset%nline ! fiddle with this to use locking
 end do


 !!!!! Uncomment for diagnostics
 ! nline_total = SUM(nline_bands)
 ! call xmpi_sum(nline_total, mpi_enreg%comm_band, ierr)
 ! load_imbalance = (SUM(nline_bands) - REAL(nline_total)/REAL(mpi_enreg%nproc_band)) / &
 ! &                (REAL(nline_total)/REAL(mpi_enreg%nproc_band))
 ! call xmax_mpi(load_imbalance, mpi_enreg%comm_band, ierr)

 ! write(message, *) 'Mean nline', REAL(nline_total)/REAL(nband), 'max imbalance (%)', load_imbalance*100
 ! call wrtout(std_out,message,'COLL')

 ABI_FREE(resids_filter)
 ABI_FREE(residvec_filter)

 !======================================================================================================
 ! Chebyshev polynomial application
 !======================================================================================================

 ! Define the filter position (does not change from one iteration to the next)
 filter_center = (dtset%ecut+filter_low)/2
 filter_radius = (dtset%ecut-filter_low)/2

 ! Filter by a chebyshev polynomial of degree nline
 do iline=1,dtset%nline
   ! Filter only on [iactive, iactive+nactive-1]
   iactive = nband_filter
   do iband = 1, nband_filter
     ! does iband need an iteration?
     if (nline_bands(iband) >= iline) then
       iactive = iband
       exit
     end if
   end do
   nactive = nband_filter - iactive + 1
   shift = npw_filter*nspinor*(iactive-1) + 1
   ! trick the legacy prep_getghc
   mpi_enreg%bandpp = nactive

   ! write(message, *) 'Applying invovl, iteration', iline
   ! call wrtout(std_out,message,'COLL')

   ! If paw, have to apply S^-1
   if(paw) then
     call timab(timer_apply_inv_ovl, 1, tsec)
     call apply_invovl(gs_hamk, ghc_filter(:,shift:), gsm1hc_filter(:,shift:), cwaveprj_next(:,iactive:), &
&     npw_filter, nactive, mpi_enreg, nspinor, dtset%invovl_blksliced)
     call timab(timer_apply_inv_ovl, 2, tsec)
   else
     gsm1hc_filter(:,shift:) = ghc_filter(:,shift:)
   end if

   ! Chebyshev iteration: UPDATE cg
   ! First step is the degree-1 polynomial; later steps use the three-term recurrence
   if(iline == 1) then
     cg_filter_next(:,shift:) = one/filter_radius * (gsm1hc_filter(:,shift:) - filter_center*cg_filter(:,shift:))
   else
     cg_filter_next(:,shift:) = two/filter_radius * (gsm1hc_filter(:,shift:) - filter_center*cg_filter(:,shift:)) &
&     - cg_filter_prev(:,shift:)
   end if
   ! Update gsc and cwaveprj
   if(paw) then
     if(iline == 1) then
       gsc_filter_next(:,shift:) = one/filter_radius * (ghc_filter(:,shift:) - filter_center*gsc_filter(:,shift:))
       !cwaveprj_next = one/filter_radius * (cwaveprj_next - filter_center*cwaveprj)
       call pawcprj_axpby(-filter_center/filter_radius, one/filter_radius,cwaveprj(:,iactive:),cwaveprj_next(:,iactive:))
     else
       gsc_filter_next(:,shift:) = two/filter_radius * (ghc_filter(:,shift:) - filter_center*gsc_filter(:,shift:))&
&       - gsc_filter_prev(:,shift:)
       !cwaveprj_next = two/filter_radius * (cwaveprj_next - filter_center*cwaveprj) - cwaveprj_prev
       call pawcprj_axpby(-two*filter_center/filter_radius, two/filter_radius,cwaveprj(:,iactive:),cwaveprj_next(:,iactive:))
       call pawcprj_axpby(-one, one,cwaveprj_prev(:,iactive:),cwaveprj_next(:,iactive:))
     end if
   end if

   ! Bookkeeping of the _prev variables
   cg_filter_prev(:,shift:) = cg_filter(:,shift:)
   cg_filter(:,shift:) = cg_filter_next(:,shift:)
   if(paw) then
     gsc_filter_prev(:,shift:) = gsc_filter(:,shift:)
     gsc_filter(:,shift:) = gsc_filter_next(:,shift:)

     call pawcprj_copy(cwaveprj(:,iactive:),cwaveprj_prev(:,iactive:))
     call pawcprj_copy(cwaveprj_next(:,iactive:),cwaveprj(:,iactive:))
   end if

   ! Update ghc
   if(paw) then
     !! DEBUG use this to remove the optimization and recompute gsc/cprojs
     ! sij_opt = 1
     ! cpopt = 0

     sij_opt = 0 ! gsc is already computed
     cpopt = 2 ! reuse cprojs
   else
     sij_opt = 0
     cpopt = -1
   end if

   write(message, *) 'Getghc, iteration', iline
   call wrtout(std_out,message,'COLL')

   call timab(timer_getghc, 1, tsec)
   ! NOTE(review): nband is passed below, presumably as the number of vectors, although only
   ! nactive bands start at shift. The two coincide as long as nline_bands is forced to
   ! dtset%nline (iactive == 1); revisit if locking is ever enabled.
   if (dtset%paral_kgb == 0) then
     call getghc(cpopt,cg_filter(:,shift:),cwaveprj(:,iactive:),ghc_filter(:,shift:),&
&     gsc_filter(:,shift:),gs_hamk,gvnlxc_filter(:,shift:),eval,mpi_enreg,&
&     nband,prtvol,sij_opt,tim_getghc,0)
   else
     call prep_getghc(cg_filter(:,shift:),gs_hamk,gvnlxc_filter(:,shift:),ghc_filter(:,shift:),&
&     gsc_filter(:,shift:),eval,nband,mpi_enreg,prtvol,sij_opt,cpopt,&
&     cwaveprj(:,iactive:),already_transposed=.true.)
   end if

   ! end of the trick
   mpi_enreg%bandpp = nband_filter

   call timab(timer_getghc, 2, tsec)
 end do ! end loop on nline

 ! normalize according to the previously computed rayleigh quotients (inaccurate, but cheap)
 do iband = 1, nband_filter
   ampfactor = cheb_poly(eig(iband), nline_bands(iband), filter_low, dtset%ecut)
   if(abs(ampfactor) < 1e-3_dp) ampfactor = 1e-3_dp ! just in case, avoid amplifying too much
   shift = npw_filter*nspinor*(iband-1)
   cg_filter(:, shift+1:shift+npw_filter*nspinor) = cg_filter(:, shift+1:shift+npw_filter*nspinor) / ampfactor
   ghc_filter(:, shift+1:shift+npw_filter*nspinor) = ghc_filter(:, shift+1:shift+npw_filter*nspinor) / ampfactor
   if(paw) then
     gsc_filter(:, shift+1:shift+npw_filter*nspinor) = gsc_filter(:, shift+1:shift+npw_filter*nspinor) / ampfactor
   endif
   if(.not.paw .or. has_fock)then
     gvnlxc_filter(:, shift+1:shift+npw_filter*nspinor) = gvnlxc_filter(:, shift+1:shift+npw_filter*nspinor) / ampfactor
   end if
 end do

 ! Cleanup
 if(paw) then
   call pawcprj_free(cwaveprj)
   call pawcprj_free(cwaveprj_next)
   call pawcprj_free(cwaveprj_prev)
   ABI_FREE(cwaveprj)
   ABI_FREE(cwaveprj_next)
   ABI_FREE(cwaveprj_prev)
 end if
 ABI_FREE(nline_bands)
 ABI_FREE(cg_filter_next)
 ABI_FREE(cg_filter_prev)
 ABI_FREE(gsc_filter_prev)
 ABI_FREE(gsc_filter_next)
 ABI_FREE(gsm1hc_filter)

 !======================================================================================================
 ! Filtering done, transpose back
 !======================================================================================================

 write(message, *) 'Filtering done, transposing back'
 call wrtout(std_out,message,'COLL')

 ! transpose back
 if(dtset%paral_kgb == 1) then
   ! Undo the bandpp sort applied after the forward transpose
   cg_alltoall1(:,index_wavef_band) = cg_alltoall2(:,:)
   ghc_alltoall1(:,index_wavef_band) = ghc_alltoall2(:,:)
   if(paw) then
     gsc_alltoall1(:,index_wavef_band) = gsc_alltoall2(:,:)
   else
     gvnlxc_alltoall1(:,index_wavef_band) = gvnlxc_alltoall2(:,:)
   end if

   ABI_FREE(index_wavef_band)

   call timab(timer_sync, 1, tsec)
   call xmpi_barrier(mpi_enreg%comm_band)
   call timab(timer_sync, 2, tsec)

   call timab(timer_alltoall, 1, tsec)

  ! Do we pack the arrays in the alltoall, saving latency, or do we do it separately, saving memory and copies?
   call xmpi_alltoallv(cg_alltoall1,recvcountsloc,rdisplsloc,cg,&
&   sendcountsloc,sdisplsloc,mpi_enreg%comm_band,ierr)
   call xmpi_alltoallv(ghc_alltoall1,recvcountsloc,rdisplsloc,ghc,&
&   sendcountsloc,sdisplsloc,mpi_enreg%comm_band,ierr)
   if(paw) then
     call xmpi_alltoallv(gsc_alltoall1,recvcountsloc,rdisplsloc,gsc,&
&     sendcountsloc,sdisplsloc,mpi_enreg%comm_band,ierr)
   else
     call xmpi_alltoallv(gvnlxc_alltoall1,recvcountsloc,rdisplsloc,gvnlxc,&
&     sendcountsloc,sdisplsloc,mpi_enreg%comm_band,ierr)
   end if
   call timab(timer_alltoall, 2, tsec)

   ! The buffers were allocated under dtset%paral_kgb == 1, so release them under the same condition
   ABI_FREE(cg_alltoall1)
   ABI_FREE(gsc_alltoall1)
   ABI_FREE(ghc_alltoall1)
   ABI_FREE(gvnlxc_alltoall1)
   ABI_FREE(cg_alltoall2)
   ABI_FREE(gsc_alltoall2)
   ABI_FREE(ghc_alltoall2)
   ABI_FREE(gvnlxc_alltoall2)
 else
   ! nothing to do, the _filter variables already point to the right ones
 end if



 !======================================================================================================
 ! Data in (npfft*npband) x 1 distribution. Rayleigh-Ritz step
 !======================================================================================================

 ! _subdiago might use less memory when using only one proc, should maybe call it, or just remove it
 ! and always call _distributed
#if defined HAVE_LINALG_SCALAPACK
 call rayleigh_ritz_distributed(cg,ghc,gsc,gvnlxc,eig,has_fock,gs_hamk%istwf_k,mpi_enreg,nband,npw,nspinor,gs_hamk%usepaw)
#else
 call rayleigh_ritz_subdiago(cg,ghc,gsc,gvnlxc,eig,has_fock,gs_hamk%istwf_k,mpi_enreg,nband,npw,nspinor,gs_hamk%usepaw)
#endif

 ! Build residuals
 call timab(timer_residuals, 1, tsec)
 do iband=1,nband
   shift = npw*nspinor*(iband-1)
   if(paw) then
     resid_vec = ghc(:, shift+1 : shift+npw*nspinor) - eig(iband)*gsc(:, shift+1 : shift+npw*nspinor)
   else
     resid_vec = ghc(:, shift+1 : shift+npw*nspinor) - eig(iband)*cg (:, shift+1 : shift+npw*nspinor)
   end if

   ! precondition resid_vec
   do ispinor = 1,nspinor
     resid_vec(1, npw*(ispinor-1)+1:npw*ispinor) = resid_vec(1, npw*(ispinor-1)+1:npw*ispinor) * pcon
     resid_vec(2, npw*(ispinor-1)+1:npw*ispinor) = resid_vec(2, npw*(ispinor-1)+1:npw*ispinor) * pcon
   end do

   call dotprod_g(resid(iband),dprod_i,gs_hamk%istwf_k,npw*nspinor,1,resid_vec,&
&   resid_vec,mpi_enreg%me_g0,mpi_enreg%comm_bandspinorfft)

   if(.not. paw .or. has_fock) then
     ! Data is back in the (npw, nband) layout here, so the band slice length is npw*nspinor
     ! (npw_filter differs from npw when paral_kgb == 1 and would overrun the last band)
     call dotprod_g(enlx(iband),dprod_i,gs_hamk%istwf_k,npw*nspinor,1,cg(:, shift+1:shift+npw*nspinor),&
&     gvnlxc(:, shift+1:shift+npw*nspinor),mpi_enreg%me_g0,mpi_enreg%comm_bandspinorfft)
   end if
 end do
 call timab(timer_residuals, 2, tsec)

 ! write(message, '(a,4e10.2)') 'Resids (1, N, min, max) ', resid(1), resid(nband), MINVAL(resid), MAXVAL(resid)
 ! call wrtout(std_out,message,'COLL')

 ! write(message,*)'Eigens(1,nocc,nband) ',eig(1), eig(ilastocc),eig(nband)
 ! call wrtout(std_out,message,'COLL')
 ! write(message,*)'Resids(1,nocc,nband) ',resid(1), resid(ilastocc),resid(nband)
 ! call wrtout(std_out,message,'COLL')

 call timab(timer_chebfi,2,tsec)

end subroutine chebfi

ABINIT/m_chebfi [ Modules ]

[ Top ] [ Modules ]

NAME

  m_chebfi

FUNCTION

COPYRIGHT

  Copyright (C) 2014-2024 ABINIT group (AL)
  This file is distributed under the terms of the
  GNU General Public License, see ~abinit/COPYING
  or http://www.gnu.org/copyleft/gpl.txt .

SOURCE

16 #if defined HAVE_CONFIG_H
17 #include "config.h"
18 #endif
19 
20 #include "abi_common.h"
21 
22 module m_chebfi
23 
24  use defs_basis
25  use m_errors
26  use m_xmpi
27  use m_abicore
28  use m_abi_linalg
29  use m_rayleigh_ritz
30  use m_invovl
31  use m_dtset
32 
33  use defs_abitypes, only : mpi_type
34  use m_time,          only : timab
35  use m_cgtools,       only : dotprod_g
36  use m_bandfft_kpt,   only : bandfft_kpt, bandfft_kpt_get_ikpt
37  use m_pawcprj,       only : pawcprj_type, pawcprj_alloc, pawcprj_free, pawcprj_axpby, pawcprj_copy
38  use m_hamiltonian,   only : gs_hamiltonian_type
39  use m_getghc,        only : getghc
40  use m_prep_kgb,      only : prep_getghc, prep_index_wavef_bandpp
41 
42  implicit none
43 
44  private