Skip to content
90 changes: 90 additions & 0 deletions source/source_lcao/LCAO_init_basis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "source_io/module_parameter/parameter.h"
#include "source_base/parallel_comm.h"

#include <cmath>

namespace LCAO_domain
{

Expand Down Expand Up @@ -78,6 +80,94 @@ void init_basis_lcao(Parallel_Orbitals& pv,
try_nb = pv.set_nloc_wfc_Eij(PARAM.inp.nbands, GlobalV::ofs_running, GlobalV::ofs_warning);
}

// ---- nb2d (ScaLAPACK 2D block-cyclic block size) load-balance check ----
// ScaLAPACK diagonalizes the N x N matrix (N = nlocal) on a p x q process grid
// (p <= q) with square block nb2d. The time-vs-nb2d curve is U-shaped: too large
// -> load-imbalance cliff (one process owns a whole panel); too small -> poor BLAS
// and heavy block-cyclic communication. Healthy window [nb_lo, nb_hi] below.
// Only scalapack_gvx diagonalizes on this 2D grid (genelpa/elpa tune internally;
// lapack/cusolver/pexsi do not), so the check is restricted to it.
if (PARAM.inp.ks_solver == "scalapack_gvx")
{
const int kpar = (PARAM.globalv.kpar_lcao > 0) ? PARAM.globalv.kpar_lcao : 1;
// Processes running one (per-pool) diagonalization:
// kpar == 1 : the grid is this ParaV grid (built on DIAG_WORLD) -> pv.dim0*pv.dim1.
// kpar > 1 : hsolver re-splits MPI_COMM_WORLD into kpar pools of NPROC/kpar ranks
// (not the DIAG_WORLD grid). Uneven pools (NPROC % kpar != 0) are skipped.
int np_pool = 0; // processes per pool (0 => skip the check)
if (kpar <= 1)
{
np_pool = pv.dim0 * pv.dim1;
}
else if (GlobalV::NPROC % kpar == 0)
{
np_pool = GlobalV::NPROC / kpar;
}
if (np_pool > 1 && nlocal > 0)
Comment thread
A-006 marked this conversation as resolved.
{
// near-square factorization np_pool = p * q, p <= q (matches Parallel_2D)
int p_row = static_cast<int>(std::sqrt(np_pool + 0.5));
while (p_row > 1 && np_pool % p_row != 0) { --p_row; }
const int p_col = np_pool / p_row; // long edge q (>= p_row)

// Two-sided window: nb_hi = floor(N/2q) keeps >= 2 blocks per process;
// nb_lo = min(16, nb_hi) avoids tiny blocks; recommended = min(64, nb_hi).
// nspin==4 carries 2-component spinors that must stay paired in one block
// (hence autoset/fallback use nb2d=2, not 1), so snap the window to a multiple
// of 2 -- an odd nb2d would break the spinor blocking and segfault.
const int nb_unit = (PARAM.inp.nspin == 4) ? 2 : 1;
auto snap = [nb_unit](int v) { v = v / nb_unit * nb_unit; return v < nb_unit ? nb_unit : v; };
const int nb_hi = snap((nlocal >= 2 * p_col) ? nlocal / (2 * p_col) : 1);
const int nb_lo = snap((16 < nb_hi) ? 16 : nb_hi);
const int nb_opt = snap((nb_hi < 64) ? nb_hi : 64);
const int nb_cur = pv.nb;

const char* issue = nullptr;
if (nb_cur > nb_hi)
{
issue = "too large -> ScaLAPACK load imbalance (one process owns a whole panel)";
}
else if (nb_cur < nb_lo)
{
issue = "too small -> over-scatter (poor BLAS efficiency and heavy communication)";
}

// user-set nb2d (!=0): keep the value, only warn. auto nb2d (==0): correct it
// to nb_opt (pv.nb feeds both the kpar==1 path and the per-pool Parallel_K2D).
if (issue != nullptr)
{
if (PARAM.inp.nb2d != 0)
{
GlobalV::ofs_warning << "init_basis_lcao: user-set nb2d=" << nb_cur << " is " << issue
<< " for N=" << nlocal << ", kpar=" << kpar << " (per-pool grid " << p_row << "x"
<< p_col << "); recommended nb2d=" << nb_opt << " (user value kept, not changed).\n";
}
else
{
// Validate nb_opt like the initial distribution: set_nloc_wfc_Eij
// returns non-zero if it is incompatible with the band/grid layout
// (ceil(nbands/nb_opt) < grid width). If so, revert to the validated
// nb_cur -- a half-updated pv would crash the later wavefunction setup.
int retry = pv.set(nlocal, nlocal, nb_opt, pv.blacs_ctxt);
retry += pv.set_nloc_wfc_Eij(PARAM.inp.nbands, GlobalV::ofs_running, GlobalV::ofs_warning);
if (retry != 0)
{
pv.set(nlocal, nlocal, nb_cur, pv.blacs_ctxt);
pv.set_nloc_wfc_Eij(PARAM.inp.nbands, GlobalV::ofs_running, GlobalV::ofs_warning);
GlobalV::ofs_warning << "init_basis_lcao: auto nb2d=" << nb_cur << " is " << issue
<< " for N=" << nlocal << ", kpar=" << kpar << "; recommended nb2d=" << nb_opt
<< " is incompatible with the band/grid layout, so nb2d=" << nb_cur << " is kept.\n";
}
else
{
GlobalV::ofs_warning << "init_basis_lcao: auto nb2d=" << nb_cur << " is " << issue
<< " for N=" << nlocal << ", kpar=" << kpar << "; auto-adjusted to nb2d=" << nb_opt << ".\n";
}
}
}
}
}

// init blacs context for genelpa
pv.set_desc_wfc_Eij(nlocal, PARAM.inp.nbands, pv.nrow);

Expand Down
Loading