Skip to content
27 changes: 27 additions & 0 deletions source/source_io/module_output/print_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -412,6 +412,33 @@ void print_kpar(const int &nks, const int &kpar_lcao)
"%%%%%%%%%%%%\n";
}
}

// 16) recommend the optimal kpar for k-point parallelism.
// kpar splits the processes into pools, each diagonalizing a subset of the nks
// k-points independently -> near-linear speedup of the k-point loop. The ceiling is
// the number of k-points, so the optimal kpar is the largest divisor of NPROC that
// does not exceed nks (more pools than k-points leaves pools idle). Perfect balance
// additionally wants nks % kpar == 0 (see the warning above); very large systems may
// prefer fewer pools to keep enough ranks per pool for the per-pool diagonalization.
// This is advisory only -- kpar fixes the MPI pool layout and cannot be changed here.
int kpar_opt = 1;
for (int d = GlobalV::NPROC; d >= 1; --d)
{
if (GlobalV::NPROC % d == 0 && d <= nks)
{
kpar_opt = d;
break;
}
}
Comment thread
A-006 marked this conversation as resolved.
Outdated
if (kpar_opt != kpar_lcao)
{
GlobalV::ofs_running << " kpar advisory: current kpar = " << kpar_lcao << " (NPROC = "
<< GlobalV::NPROC << ", nks = " << nks << "). Recommended kpar = "
<< kpar_opt << " (largest divisor of NPROC <= nks) to parallelize"
<< " the k-point loop.\n";
ModuleBase::WARNING("ModuleIO::print_kpar",
"kpar is not optimal; see running log for the recommended value.");
}
Comment thread
A-006 marked this conversation as resolved.
Outdated
}

} // namespace ModuleIO
83 changes: 83 additions & 0 deletions source/source_lcao/LCAO_init_basis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "source_io/module_parameter/parameter.h"
#include "source_base/parallel_comm.h"

#include <cmath>

namespace LCAO_domain
{

Expand Down Expand Up @@ -78,6 +80,87 @@ void init_basis_lcao(Parallel_Orbitals& pv,
try_nb = pv.set_nloc_wfc_Eij(PARAM.inp.nbands, GlobalV::ofs_running, GlobalV::ofs_warning);
}

// ---- nb2d (ScaLAPACK 2D block-cyclic block size) load-balance check ----
// ScaLAPACK diagonalizes the N x N matrix (N = nlocal) on a p x q process grid
// (p <= q) with square block size nb2d. With kpar k-point pools the diagonalization
// runs *per pool* on NPROC/kpar processes (Parallel_K2D::P2D_pool, whose block size
// is ParaV->get_block_size() == pv.nb), so the effective grid is the near-square
// factorization of NPROC/kpar -- exactly how Parallel_2D builds it. The long edge q
// governs load balance:
// B = N / (nb2d * q) blocks owned per process along q.
// B <= 1 : one process owns a whole panel while others idle -> a catastrophic
// load-imbalance "cliff" (nb2d too large; energy unaffected, time blows up).
// But nb2d also must not be too SMALL: blocks below ~16 lose BLAS/GEMM efficiency and
// explode block-cyclic communication (nb=1 = the slow "over-scatter" end). So the time-
// vs-nb2d curve is U-shaped and the healthy window is two-sided:
// nb_lo = min(16, N/(2q)) <= nb2d <= nb_hi = floor(N/(2q))
// recommended nb2d = min(64, nb_hi) (largest balanced block, capped for BLAS).
// This block-size U-curve is a property of the ScaLAPACK 2D block-cyclic *dense*
// diagonalization, so the check is restricted to ks_solver=scalapack_gvx. (genelpa/
// elpa do their own internal block tuning; lapack/cusolver/pexsi do not diagonalize
// on this distributed 2D grid, so the nb2d cliff does not apply to them.)
if (PARAM.inp.ks_solver == "scalapack_gvx")
{
const int kpar = (PARAM.globalv.kpar_lcao > 0) ? PARAM.globalv.kpar_lcao : 1;
const int np_total = pv.dim0 * pv.dim1; // diagonalization world size
const int np_pool = (kpar > 0) ? np_total / kpar : np_total; // processes per pool
if (np_pool > 1 && nlocal > 0)
Comment thread
A-006 marked this conversation as resolved.
{
// near-square factorization np_pool = p * q, p <= q (matches Parallel_2D)
int p_row = static_cast<int>(std::sqrt(np_pool + 0.5));
while (p_row > 1 && np_pool % p_row != 0) { --p_row; }
const int p_col = np_pool / p_row; // long edge q (>= p_row)

// Healthy block-size window is two-sided (the nb2d-vs-time curve is U-shaped):
// nb_hi = floor(N / (2*q)) upper bound -- keep >= 2 blocks per process
// (B >= 2); larger nb -> load imbalance "cliff".
// nb_lo = min(16, nb_hi) lower bound -- blocks below ~16 lose BLAS/GEMM
// efficiency (1x1 ops) and explode block-cyclic
// communication (panel count ~ N/nb); nb=1 is the
// slow "over-scatter" end. Capped by nb_hi because a
// tiny system on many processes cannot afford large
// blocks (then balance wins and nb_lo == nb_hi).
// recommended = the largest balanced block, capped at 64 for BLAS efficiency.
const int nb_hi = (nlocal >= 2 * p_col) ? nlocal / (2 * p_col) : 1;
const int nb_lo = (16 < nb_hi) ? 16 : nb_hi;
const int nb_opt = (nb_hi < 64) ? nb_hi : 64;
const int nb_cur = pv.nb;

const char* issue = nullptr;
if (nb_cur > nb_hi)
{
issue = "too large -> ScaLAPACK load imbalance (one process owns a whole panel)";
}
else if (nb_cur < nb_lo)
{
issue = "too small -> over-scatter (poor BLAS efficiency and heavy communication)";
}

// Two cases, both reported via ofs_warning:
// (1) user explicitly set nb2d (PARAM.inp.nb2d != 0): respect it -- do NOT
// change the value -- but warn so the issue is visible.
// (2) auto nb2d (PARAM.inp.nb2d == 0): correct it to nb_opt. pv.nb feeds both
// the kpar==1 path and Parallel_K2D (ParaV->get_block_size()), so
// resetting it here also fixes the per-pool diagonalization.
if (issue != nullptr)
{
if (PARAM.inp.nb2d != 0)
{
GlobalV::ofs_warning << "init_basis_lcao: user-set nb2d=" << nb_cur << " is " << issue
<< " for N=" << nlocal << ", kpar=" << kpar << " (per-pool grid " << p_row << "x"
<< p_col << "); recommended nb2d=" << nb_opt << " (user value kept, not changed).\n";
}
else
{
GlobalV::ofs_warning << "init_basis_lcao: auto nb2d=" << nb_cur << " is " << issue
<< " for N=" << nlocal << ", kpar=" << kpar << "; auto-adjusted to nb2d=" << nb_opt << ".\n";
pv.set(nlocal, nlocal, nb_opt, pv.blacs_ctxt);
pv.set_nloc_wfc_Eij(PARAM.inp.nbands, GlobalV::ofs_running, GlobalV::ofs_warning);
Comment thread
A-006 marked this conversation as resolved.
Outdated
}
}
}
}

// init blacs context for genelpa
pv.set_desc_wfc_Eij(nlocal, PARAM.inp.nbands, pv.nrow);

Expand Down
Loading