diff --git a/source/source_basis/module_pw/pw_basis.cpp b/source/source_basis/module_pw/pw_basis.cpp index 549fec8e5a4..6eac8852764 100644 --- a/source/source_basis/module_pw/pw_basis.cpp +++ b/source/source_basis/module_pw/pw_basis.cpp @@ -5,6 +5,7 @@ #include "source_base/timer.h" #include "source_base/global_function.h" +#include namespace ModulePW { @@ -13,6 +14,62 @@ PW_Basis::PW_Basis() classname="PW_Basis"; } +PW_Basis::PW_Basis(const PW_Basis& other) +{ + this->classname = other.classname; +#ifdef __MPI + this->pool_world = other.pool_world; +#endif + this->nst = other.nst; + this->nstnz = other.nstnz; + this->nstot = other.nstot; + this->npw = other.npw; + this->npwtot = other.npwtot; + this->nrxx = other.nrxx; + this->startz_current = other.startz_current; + this->nplane = other.nplane; + this->ig_gge0 = other.ig_gge0; + this->gamma_only = other.gamma_only; + this->full_pw = other.full_pw; + this->ggecut = other.ggecut; + this->gridecut_lat = other.gridecut_lat; + this->lat0 = other.lat0; + this->tpiba = other.tpiba; + this->tpiba2 = other.tpiba2; + this->latvec = other.latvec; + this->G = other.G; + this->GT = other.GT; + this->GGT = other.GGT; + this->omega = other.omega; + this->distribution_type = other.distribution_type; + this->full_pw_dim = other.full_pw_dim; + this->poolnproc = other.poolnproc; + this->poolrank = other.poolrank; + this->ngg = other.ngg; + this->fftnx = other.fftnx; + this->fftny = other.fftny; + this->fftnz = other.fftnz; + this->fftnxyz = other.fftnxyz; + this->fftnxy = other.fftnxy; + this->nx = other.nx; + this->ny = other.ny; + this->nz = other.nz; + this->nxyz = other.nxyz; + this->nxy = other.nxy; + this->liy = other.liy; + this->riy = other.riy; + this->lix = other.lix; + this->rix = other.rix; + this->xprime = other.xprime; + this->ng_xeq0 = other.ng_xeq0; + this->nmaxgr = other.nmaxgr; + this->device = other.device; + this->precision = other.precision; + this->double_data_ = other.double_data_; + this->float_data_ = other.float_data_; + this->fft_bundle.setfft(this->device, this->precision); +} + PW_Basis::PW_Basis(std::string device_, std::string precision_) : device(std::move(device_)), precision(std::move(precision_)) { classname="PW_Basis"; this->fft_bundle.setfft("cpu",this->precision); @@ -28,17 +85,13 @@ PW_Basis:: ~PW_Basis() delete[] fftixy2ip; delete[] nst_per; delete[] npw_per; - delete[] gdirect; - delete[] gcar; - delete[] gg; delete[] startz; delete[] numz; delete[] numg; delete[] numr; delete[] startg; delete[] startr; - delete[] ig2igg; - delete[] gg_uniq; + this->clear_owned_cache(); #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { @@ -48,6 +101,53 @@ PW_Basis:: ~PW_Basis() #endif } +void PW_Basis::clear_owned_cache() +{ + this->invalidate_cache(); + this->gg_cache_storage.reset(); + this->gdirect_cache_storage.reset(); + this->gcar_cache_storage.reset(); + this->ig2igg_cache_storage.reset(); + this->gg_uniq_cache_storage.reset(); + this->gg = nullptr; + this->gdirect = nullptr; + this->gcar = nullptr; + this->ig2igg = nullptr; + this->gg_uniq = nullptr; + this->ngg = 0; + this->ig_gge0 = -1; +} + +PW_Basis::CacheStats PW_Basis::get_cache_stats() const +{ + CacheStats stats; + stats.local_pw_hits = this->local_pw_cache_hits.load(); + stats.local_pw_misses = this->local_pw_cache_misses.load(); + stats.uniqgg_hits = this->uniqgg_cache_hits.load(); + stats.uniqgg_misses = this->uniqgg_cache_misses.load(); + const bool has_local_pw_cache = this->local_pw_cache_valid.load() && this->npw > 0; + const bool has_uniqgg_cache = this->uniqgg_cache_valid.load() && this->ngg > 0; + if (has_local_pw_cache) + { + stats.cache_bytes += sizeof(double) * this->npw; + stats.cache_bytes += sizeof(ModuleBase::Vector3) * this->npw * 2; + } + if (has_uniqgg_cache) + { + stats.cache_bytes += sizeof(int) * this->npw; + stats.cache_bytes += sizeof(double) * this->ngg; + } + return stats; +} + +void PW_Basis::reset_cache_stats() +{ + this->local_pw_cache_hits.store(0); + this->local_pw_cache_misses.store(0); + this->uniqgg_cache_hits.store(0); + this->uniqgg_cache_misses.store(0); +} + /// /// distribute plane wave basis and real-space grids to different processors /// set up maps for fft and create arrays for MPI_Alltoall @@ -138,10 +238,29 @@ void PW_Basis::collect_local_pw() { return; } + ModuleBase::timer::start(this->classname, "collect_local_pw"); + if (this->local_pw_cache_valid.load()) + { + this->local_pw_cache_hits.fetch_add(1); + ModuleBase::timer::end(this->classname, "collect_local_pw"); + return; + } + std::lock_guard guard(this->cache_mutex); + if (this->local_pw_cache_valid.load()) + { + this->local_pw_cache_hits.fetch_add(1); + ModuleBase::timer::end(this->classname, "collect_local_pw"); + return; + } + this->local_pw_cache_misses.fetch_add(1); this->ig_gge0 = -1; - delete[] this->gg; this->gg = new double[this->npw]; - delete[] this->gdirect; this->gdirect = new ModuleBase::Vector3[this->npw]; - delete[] this->gcar; this->gcar = new ModuleBase::Vector3[this->npw]; + this->gg_cache_storage.reset(new double[this->npw]); + this->gdirect_cache_storage.reset(new ModuleBase::Vector3[this->npw]); + this->gcar_cache_storage.reset(new ModuleBase::Vector3[this->npw]); + this->gg = this->gg_cache_storage.get(); + this->gdirect = this->gdirect_cache_storage.get(); + this->gcar = this->gcar_cache_storage.get(); + this->uniqgg_cache_valid.store(false); ModuleBase::Vector3 f; int gamma_num = 0; @@ -182,6 +301,8 @@ void PW_Basis::collect_local_pw() } } } + this->local_pw_cache_valid.store(true); + ModuleBase::timer::end(this->classname, "collect_local_pw"); return; } @@ -196,45 +317,75 @@ void PW_Basis::collect_uniqgg() { return; } + ModuleBase::timer::start(this->classname, "collect_uniqgg"); + if (this->uniqgg_cache_valid.load()) + { + this->uniqgg_cache_hits.fetch_add(1); + ModuleBase::timer::end(this->classname, "collect_uniqgg"); + return; + } + std::lock_guard guard(this->cache_mutex); + if (this->uniqgg_cache_valid.load()) + { + this->uniqgg_cache_hits.fetch_add(1); + ModuleBase::timer::end(this->classname, "collect_uniqgg"); + return; + } + this->uniqgg_cache_misses.fetch_add(1); this->ig_gge0 = -1; - delete[] this->ig2igg; this->ig2igg = new int [this->npw]; + this->ig2igg_cache_storage.reset(new int[this->npw]); + this->ig2igg = this->ig2igg_cache_storage.get(); - int *sortindex = new int [this->npw];//Reconstruct the mapping of the plane wave index ig according to the energy size of the plane waves - double *tmpgg = new double [this->npw];//Ranking the plane waves by energy size while ensuring that the same energy is preserved for each wave to correspond - double *tmpgg2 = new double [this->npw];//ranking the plane waves by energy size and removing the duplicates - ModuleBase::Vector3 f; - for(int ig = 0 ; ig < this-> npw ; ++ig) + std::vector sortindex(this->npw); // Reconstruct the plane-wave index mapping after sorting by energy. + std::vector tmpgg(this->npw); + std::vector tmpgg2(this->npw); + if (this->local_pw_cache_valid.load() && this->gg != nullptr) { - int isz = this->ig2isz[ig]; - int iz = isz % this->nz; - int is = isz / this->nz; - int ixy = this->is2fftixy[is]; - int ix = ixy / this->fftny; - int iy = ixy % this->fftny; - if (ix >= int(this->nx/2) + 1) - { - ix -= this->nx; - } - if (iy >= int(this->ny/2) + 1) - { - iy -= this->ny; - } - if (iz >= int(this->nz/2) + 1) + for(int ig = 0 ; ig < this-> npw ; ++ig) { - iz -= this->nz; + tmpgg[ig] = this->gg[ig]; + if(tmpgg[ig] < 1e-8) + { + this->ig_gge0 = ig; + } } - f.x = ix; - f.y = iy; - f.z = iz; - tmpgg[ig] = f * (this->GGT * f); - if(tmpgg[ig] < 1e-8) + } + else + { + ModuleBase::Vector3 f; + for(int ig = 0 ; ig < this-> npw ; ++ig) { - this->ig_gge0 = ig; + int isz = this->ig2isz[ig]; + int iz = isz % this->nz; + int is = isz / this->nz; + int ixy = this->is2fftixy[is]; + int ix = ixy / this->fftny; + int iy = ixy % this->fftny; + if (ix >= int(this->nx/2) + 1) + { + ix -= this->nx; + } + if (iy >= int(this->ny/2) + 1) + { + iy -= this->ny; + } + if (iz >= int(this->nz/2) + 1) + { + iz -= this->nz; + } + f.x = ix; + f.y = iy; + f.z = iz; + tmpgg[ig] = f * (this->GGT * f); + if(tmpgg[ig] < 1e-8) + { + this->ig_gge0 = ig; + } } } - ModuleBase::GlobalFunc::ZEROS(sortindex, this->npw); - ModuleBase::heapsort(this->npw, tmpgg, sortindex); + ModuleBase::GlobalFunc::ZEROS(sortindex.data(), this->npw); + ModuleBase::heapsort(this->npw, tmpgg.data(), sortindex.data()); int igg = 0; @@ -261,14 +412,14 @@ void PW_Basis::collect_uniqgg() } tmpgg2[igg] = avg_gg / double(avg_n); this->ngg = igg + 1; - delete[] this->gg_uniq; this->gg_uniq = new double [this->ngg]; + this->gg_uniq_cache_storage.reset(new double[this->ngg]); + this->gg_uniq = this->gg_uniq_cache_storage.get(); for(int igg = 0 ; igg < this->ngg ; ++igg) { gg_uniq[igg] = tmpgg2[igg]; } - delete[] sortindex; - delete[] tmpgg; - delete[] tmpgg2; + this->uniqgg_cache_valid.store(true); + ModuleBase::timer::end(this->classname, "collect_uniqgg"); } void PW_Basis::getfftixy2is(int * fftixy2is) const @@ -295,10 +446,12 @@ void PW_Basis::getfftixy2is(int * fftixy2is) const void PW_Basis::set_device(std::string device_) { this->device = std::move(device_); + this->invalidate_cache(); } void PW_Basis::set_precision(std::string precision_) { this->precision = std::move(precision_); + this->invalidate_cache(); } } diff --git a/source/source_basis/module_pw/pw_basis.h b/source/source_basis/module_pw/pw_basis.h index b834cb0e0f4..2c5fe00fca3 100644 --- a/source/source_basis/module_pw/pw_basis.h +++ b/source/source_basis/module_pw/pw_basis.h @@ -9,9 +9,14 @@ #include #include "source_base/module_fft/fft_bundle.h" #include +#include #ifdef __MPI #include "mpi.h" #endif +#include +#include +#include +#include namespace ModulePW { @@ -56,8 +61,18 @@ class PW_Basis { public: + struct CacheStats + { + std::uint64_t local_pw_hits = 0; + std::uint64_t local_pw_misses = 0; + std::uint64_t uniqgg_hits = 0; + std::uint64_t uniqgg_misses = 0; + std::size_t cache_bytes = 0; + }; + std::string classname; PW_Basis(); + PW_Basis(const PW_Basis& other); PW_Basis(std::string device_, std::string precision_); virtual ~PW_Basis(); //Init mpi parameters @@ -137,9 +152,33 @@ class PW_Basis //distribute plane waves and grids and set up fft void setuptransform(); + CacheStats get_cache_stats() const; + void reset_cache_stats(); + protected: int *startnsz_per=nullptr;//useless intermediate variable// startnsz_per[ip]: starting is * nz stick in the ip^th proc. + virtual void invalidate_cache() + { + this->local_pw_cache_valid.store(false); + this->uniqgg_cache_valid.store(false); + } + + void clear_owned_cache(); + + std::atomic local_pw_cache_valid{false}; + std::atomic uniqgg_cache_valid{false}; + mutable std::mutex cache_mutex; + std::unique_ptr gg_cache_storage; + std::unique_ptr[]> gdirect_cache_storage; + std::unique_ptr[]> gcar_cache_storage; + std::unique_ptr ig2igg_cache_storage; + std::unique_ptr gg_uniq_cache_storage; + std::atomic local_pw_cache_hits{0}; + std::atomic local_pw_cache_misses{0}; + std::atomic uniqgg_cache_hits{0}; + std::atomic uniqgg_cache_misses{0}; + //distribute plane waves to different processors void distribute_g(); diff --git a/source/source_basis/module_pw/pw_basis_k.cpp b/source/source_basis/module_pw/pw_basis_k.cpp index 2c2d02bf927..5907b69ed2d 100644 --- a/source/source_basis/module_pw/pw_basis_k.cpp +++ b/source/source_basis/module_pw/pw_basis_k.cpp @@ -6,6 +6,7 @@ #include "source_base/timer.h" #include +#include namespace ModulePW { @@ -21,7 +22,7 @@ PW_Basis_K::~PW_Basis_K() delete[] npwk; delete[] igl2isz_k; delete[] igl2ig_k; - delete[] gk2; + this->clear_k_cache_storage(); #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { @@ -46,6 +47,44 @@ PW_Basis_K::~PW_Basis_K() #endif } +void PW_Basis_K::clear_k_cache_storage() +{ + this->invalidate_cache(); + this->k_gcar_cache_storage.reset(); + this->k_gk2_cache_storage.reset(); + this->gcar = nullptr; + this->gk2 = nullptr; +} + +PW_Basis_K::KCacheStats PW_Basis_K::get_k_cache_stats() const +{ + KCacheStats stats; + const auto base_stats = PW_Basis::get_cache_stats(); + static_cast(stats) = base_stats; + stats.gcar_hits = this->gcar_cache_hits.load(); + stats.gcar_misses = this->gcar_cache_misses.load(); + stats.gk2_hits = this->gk2_cache_hits.load(); + stats.gk2_misses = this->gk2_cache_misses.load(); + if (this->gcar_cache_valid.load() && this->npwk_max > 0 && this->nks > 0) + { + stats.cache_bytes += sizeof(ModuleBase::Vector3) * this->npwk_max * this->nks; + } + if (this->gk_cache_valid.load() && this->npwk_max > 0 && this->nks > 0) + { + stats.cache_bytes += sizeof(double) * this->npwk_max * this->nks; + } + return stats; +} + +void PW_Basis_K::reset_k_cache_stats() +{ + PW_Basis::reset_cache_stats(); + this->gcar_cache_hits.store(0); + this->gcar_cache_misses.store(0); + this->gk2_cache_hits.store(0); + this->gk2_cache_misses.store(0); +} + void PW_Basis_K::initparameters(const bool gamma_only_in, const double gk_ecut_in, const int nks_in, // number of k points in this pool @@ -101,6 +140,7 @@ void PW_Basis_K::initparameters(const bool gamma_only_in, this->fftnxy = this->fftnx * this->fftny; this->fftnxyz = this->fftnxy * this->fftnz; this->distribution_type = distribution_type_in; + this->invalidate_cache(); #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { @@ -129,18 +169,23 @@ void PW_Basis_K::initparameters(const bool gamma_only_in, void PW_Basis_K::setupIndGk() { + ModuleBase::timer::start(this->classname, "setupIndGk"); + this->invalidate_cache(); // count npwk this->npwk_max = 0; delete[] this->npwk; this->npwk = new int[this->nks]; + std::vector> selected_ig(this->nks); for (int ik = 0; ik < this->nks; ik++) { int ng = 0; + selected_ig[ik].reserve(this->npw); for (int ig = 0; ig < this->npw; ig++) { const double gk2 = this->cal_GplusK_cartesian(ik, ig).norm2(); if (gk2 <= this->gk_ecut) { + selected_ig[ik].push_back(ig); ++ng; } } @@ -166,6 +211,7 @@ void PW_Basis_K::setupIndGk() // get igl2isz_k and igl2ig_k if (this->npwk_max <= 0) { + ModuleBase::timer::end(this->classname, "setupIndGk"); return; } @@ -176,15 +222,11 @@ void PW_Basis_K::setupIndGk() for (int ik = 0; ik < this->nks; ik++) { int igl = 0; - for (int ig = 0; ig < this->npw; ig++) + for (const int ig : selected_ig[ik]) { - const double gk2 = this->cal_GplusK_cartesian(ik, ig).norm2(); - if (gk2 <= this->gk_ecut) - { - this->igl2isz_k[ik * npwk_max + igl] = this->ig2isz[ig]; - this->igl2ig_k[ik * npwk_max + igl] = ig; - ++igl; - } + this->igl2isz_k[ik * npwk_max + igl] = this->ig2isz[ig]; + this->igl2ig_k[ik * npwk_max + igl] = ig; + ++igl; } } #if defined(__CUDA) || defined(__ROCM) @@ -195,6 +237,7 @@ void PW_Basis_K::setupIndGk() } #endif this->get_ig2ixyz_k(); + ModuleBase::timer::end(this->classname, "setupIndGk"); return; } @@ -249,19 +292,62 @@ void PW_Basis_K::setuptransform() void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_height_in, const double& erf_sigma_in) { - this->erf_ecut = erf_ecut_in; - this->erf_height = erf_height_in; - this->erf_sigma = erf_sigma_in; + ModuleBase::timer::start(this->classname, "collect_local_pw"); + const bool gcar_hit = this->gcar_cache_valid.load(); + const bool gk2_hit = this->gk_cache_valid.load() + && this->erf_ecut == erf_ecut_in + && this->erf_height == erf_height_in + && this->erf_sigma == erf_sigma_in; if (this->npwk_max <= 0) { + ModuleBase::timer::end(this->classname, "collect_local_pw"); + return; + } + if (gcar_hit && gk2_hit) + { + this->gcar_cache_hits.fetch_add(1); + this->gk2_cache_hits.fetch_add(1); + ModuleBase::timer::end(this->classname, "collect_local_pw"); + return; + } + std::lock_guard guard(this->cache_mutex); + const bool locked_gcar_hit = this->gcar_cache_valid.load(); + const bool locked_gk2_hit = this->gk_cache_valid.load() + && this->erf_ecut == erf_ecut_in + && this->erf_height == erf_height_in + && this->erf_sigma == erf_sigma_in; + if (locked_gcar_hit && locked_gk2_hit) + { + this->gcar_cache_hits.fetch_add(1); + this->gk2_cache_hits.fetch_add(1); + ModuleBase::timer::end(this->classname, "collect_local_pw"); return; } - delete[] gk2; - delete[] gcar; - this->gk2 = new double[this->npwk_max * this->nks]; - this->gcar = new ModuleBase::Vector3[this->npwk_max * this->nks]; - ModuleBase::Memory::record("PW_B_K::gk2", sizeof(double) * this->npwk_max * this->nks); - ModuleBase::Memory::record("PW_B_K::gcar", sizeof(ModuleBase::Vector3) * this->npwk_max * this->nks); + if (locked_gcar_hit) + { + this->gcar_cache_hits.fetch_add(1); + } + else + { + this->gcar_cache_misses.fetch_add(1); + this->k_gcar_cache_storage.reset(new ModuleBase::Vector3[this->npwk_max * this->nks]); + this->gcar = this->k_gcar_cache_storage.get(); + ModuleBase::Memory::record("PW_B_K::gcar", sizeof(ModuleBase::Vector3) * this->npwk_max * this->nks); + } + if (locked_gk2_hit) + { + this->gk2_cache_hits.fetch_add(1); + } + else + { + this->gk2_cache_misses.fetch_add(1); + this->k_gk2_cache_storage.reset(new double[this->npwk_max * this->nks]); + this->gk2 = this->k_gk2_cache_storage.get(); + ModuleBase::Memory::record("PW_B_K::gk2", sizeof(double) * this->npwk_max * this->nks); + } + this->erf_ecut = erf_ecut_in; + this->erf_height = erf_height_in; + this->erf_sigma = erf_sigma_in; ModuleBase::Vector3 f; for (int ik = 0; ik < this->nks; ++ik) @@ -291,36 +377,50 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h f.y = iy; f.z = iz; - this->gcar[ik * npwk_max + igl] = f * this->G; + if (!locked_gcar_hit) + { + this->gcar[ik * npwk_max + igl] = f * this->G; + } double temp_gk2 = (f + kv) * (this->GGT * (f + kv)); - if (erf_height > 0) + if (!locked_gk2_hit && erf_height > 0) { this->gk2[ik * npwk_max + igl] = temp_gk2 + erf_height / tpiba2 * (1.0 + std::erf((temp_gk2 * tpiba2 - erf_ecut) / erf_sigma)); } - else + else if (!locked_gk2_hit) { this->gk2[ik * npwk_max + igl] = temp_gk2; } } } + if (!locked_gcar_hit) + { + this->sync_gcar_device_cache(); + this->gcar_cache_valid.store(true); + } + if (!locked_gk2_hit) + { + this->sync_gk2_device_cache(); + this->gk_cache_valid.store(true); + } + ModuleBase::timer::end(this->classname, "collect_local_pw"); +} + +void PW_Basis_K::sync_gcar_device_cache() +{ #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { if (this->float_data_) { - resmem_sd_op()(this->s_gk2, this->npwk_max * this->nks); resmem_sd_op()(this->s_gcar, this->npwk_max * this->nks * 3); - castmem_d2s_h2d_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks); castmem_d2s_h2d_op()(this->s_gcar, reinterpret_cast(&this->gcar[0][0]), this->npwk_max * this->nks * 3); } if (this->double_data_) { - resmem_dd_op()(this->d_gk2, this->npwk_max * this->nks); resmem_dd_op()(this->d_gcar, this->npwk_max * this->nks * 3); - syncmem_d2d_h2d_op()(this->d_gk2, this->gk2, this->npwk_max * this->nks); syncmem_d2d_h2d_op()(this->d_gcar, reinterpret_cast(&this->gcar[0][0]), this->npwk_max * this->nks * 3); @@ -331,9 +431,7 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h #endif if (this->float_data_) { - resmem_sh_op()(this->s_gk2, this->npwk_max * this->nks, "PW_B_K::s_gk2"); resmem_sh_op()(this->s_gcar, this->npwk_max * this->nks * 3, "PW_B_K::s_gcar"); - castmem_d2s_h2h_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks); castmem_d2s_h2h_op()(this->s_gcar, reinterpret_cast(&this->gcar[0][0]), this->npwk_max * this->nks * 3); @@ -349,6 +447,39 @@ void PW_Basis_K::collect_local_pw(const double& erf_ecut_in, const double& erf_h #endif } +void PW_Basis_K::sync_gk2_device_cache() +{ +#if defined(__CUDA) || defined(__ROCM) + if (this->device == "gpu") + { + if (this->float_data_) + { + resmem_sd_op()(this->s_gk2, this->npwk_max * this->nks); + castmem_d2s_h2d_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks); + } + if (this->double_data_) + { + resmem_dd_op()(this->d_gk2, this->npwk_max * this->nks); + syncmem_d2d_h2d_op()(this->d_gk2, this->gk2, this->npwk_max * this->nks); + } + } + else + { +#endif + if (this->float_data_) + { + resmem_sh_op()(this->s_gk2, this->npwk_max * this->nks, "PW_B_K::s_gk2"); + castmem_d2s_h2h_op()(this->s_gk2, this->gk2, this->npwk_max * this->nks); + } + if (this->double_data_) + { + this->d_gk2 = this->gk2; + } +#if defined(__CUDA) || defined(__ROCM) + } +#endif +} + ModuleBase::Vector3 PW_Basis_K::cal_GplusK_cartesian(const int ik, const int ig) const { int isz = this->ig2isz[ig]; @@ -529,4 +660,4 @@ double* PW_Basis_K::get_gk2_data() const return this->d_gk2; } -} // namespace ModulePW \ No newline at end of file +} // namespace ModulePW diff --git a/source/source_basis/module_pw/pw_basis_k.h b/source/source_basis/module_pw/pw_basis_k.h index f633a30769d..0eb4b785b9b 100644 --- a/source/source_basis/module_pw/pw_basis_k.h +++ b/source/source_basis/module_pw/pw_basis_k.h @@ -56,6 +56,14 @@ class PW_Basis_K : public PW_Basis { public: + struct KCacheStats : public PW_Basis::CacheStats + { + std::uint64_t gcar_hits = 0; + std::uint64_t gcar_misses = 0; + std::uint64_t gk2_hits = 0; + std::uint64_t gk2_misses = 0; + }; + PW_Basis_K(); PW_Basis_K(std::string device_, std::string precision_) : PW_Basis(device_, precision_) {classname="PW_Basis_K";} ~PW_Basis_K(); @@ -99,7 +107,29 @@ class PW_Basis_K : public PW_Basis const double& erf_height_in = 0.0, const double& erf_sigma_in = 0.1); + KCacheStats get_k_cache_stats() const; + void reset_k_cache_stats(); + private: + void invalidate_cache() override + { + PW_Basis::invalidate_cache(); + this->gcar_cache_valid.store(false); + this->gk_cache_valid.store(false); + } + + void clear_k_cache_storage(); + void sync_gcar_device_cache(); + void sync_gk2_device_cache(); + + std::atomic gcar_cache_valid{false}; + std::atomic gk_cache_valid{false}; + std::unique_ptr[]> k_gcar_cache_storage; + std::unique_ptr k_gk2_cache_storage; + std::atomic gcar_cache_hits{0}; + std::atomic gcar_cache_misses{0}; + std::atomic gk2_cache_hits{0}; + std::atomic gk2_cache_misses{0}; float * s_gk2 = nullptr; double * d_gk2 = nullptr; // modulus (G+K)^2 of G vectors [npwk_max*nks] //create igl2isz_k map array for fft @@ -280,4 +310,3 @@ class PW_Basis_K : public PW_Basis #endif //PlaneWave_K class #include "./pw_basis_k_big.h" //temporary it will be removed - diff --git a/source/source_basis/module_pw/pw_distributeg.cpp b/source/source_basis/module_pw/pw_distributeg.cpp index a13fc57f9b9..93d9b07d294 100644 --- a/source/source_basis/module_pw/pw_distributeg.cpp +++ b/source/source_basis/module_pw/pw_distributeg.cpp @@ -161,6 +161,7 @@ void PW_Basis::get_ig2isz_is2fftixy( { delete[] this->ig2isz; this->ig2isz = nullptr; // map ig to the z coordinate of this planewave. delete[] this->is2fftixy; this->is2fftixy = nullptr; // map is (index of sticks) to ixy (iy + ix * fftny). + this->invalidate_cache(); #if defined(__CUDA) || defined(__ROCM) if (this->device == "gpu") { delmem_int_op()(this->d_is2fftixy); @@ -226,6 +227,7 @@ void PW_Basis::get_ig2isz_is2fftixy( syncmem_int_h2d_op()(ig2ixyz_gpu, ig2ixyz.data(), this->npw); } #endif + this->invalidate_cache(); return; } } // namespace ModulePW \ No newline at end of file diff --git a/source/source_basis/module_pw/pw_gatherscatter.h b/source/source_basis/module_pw/pw_gatherscatter.h index 207320f4268..60d2f9bb0d7 100644 --- a/source/source_basis/module_pw/pw_gatherscatter.h +++ b/source/source_basis/module_pw/pw_gatherscatter.h @@ -1,6 +1,5 @@ #include "pw_basis.h" #include "source_base/global_function.h" -#include "source_base/timer.h" #include namespace ModulePW @@ -15,8 +14,8 @@ namespace ModulePW template void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const { - - if(this->poolnproc == 1) //In this case nst=nstot, nz = nplane, + + if(this->poolnproc == 1) //In this case nst=nstot, nz = nplane, { const int nst_ = this->nst; const int nz_ = this->nz; @@ -112,7 +111,10 @@ void PW_Basis::gatherp_scatters(std::complex* in, std::complex* out) const template void PW_Basis::gathers_scatterp(std::complex* in, std::complex* out) const { - if(this->poolnproc == 1) //In this case nrxx=fftnx*fftny*nz, nst = nstot, + + + + if(this->poolnproc == 1) //In this case nrxx=fftnx*fftny*nz, nst = nstot, { const int nrxx_ = this->nrxx; const int nst_ = this->nst; diff --git a/source/source_basis/module_pw/pw_init.cpp b/source/source_basis/module_pw/pw_init.cpp index 08c676d39f3..628be0d5ae8 100644 --- a/source/source_basis/module_pw/pw_init.cpp +++ b/source/source_basis/module_pw/pw_init.cpp @@ -13,6 +13,7 @@ void PW_Basis:: initmpi( this->poolnproc = poolnproc_in; this->poolrank = poolrank_in; this->pool_world = pool_world_in; + this->invalidate_cache(); } #endif /// @@ -142,6 +143,7 @@ void PW_Basis:: initgrids( this->nz = ibox[2]; this->nxy =this->nx * this->ny; this->nxyz = this->nxy * this->nz; + this->invalidate_cache(); delete[] ibox; return; @@ -203,6 +205,7 @@ void PW_Basis:: initgrids( MPI_Allreduce(MPI_IN_PLACE, &this->gridecut_lat, 1, MPI_DOUBLE, MPI_MIN , this->pool_world); #endif this->gridecut_lat -= 1e-6; + this->invalidate_cache(); delete[] ibox; return; @@ -240,6 +243,7 @@ void PW_Basis:: initparameters( this->ggecut = this->gridecut_lat; } this->distribution_type = distribution_type_in; + this->invalidate_cache(); } // Set parameters about full planewave, used only in OFDFT for now. sunliang added 2022-08-30 @@ -251,5 +255,6 @@ void PW_Basis::setfullpw( this->full_pw = inpt_full_pw; this->full_pw_dim = inpt_full_pw_dim; if (!this->full_pw) this->full_pw_dim = 0; + this->invalidate_cache(); +} } -} \ No newline at end of file diff --git a/source/source_basis/module_pw/test/test1-1-1.cpp b/source/source_basis/module_pw/test/test1-1-1.cpp index 3eb9d8fd5e4..2e4ebe2f7aa 100644 --- a/source/source_basis/module_pw/test/test1-1-1.cpp +++ b/source/source_basis/module_pw/test/test1-1-1.cpp @@ -29,8 +29,28 @@ TEST_F(PWTEST,test1_1_1) pwtest.initgrids(lat0, latvec, wfcecut); pwtest.initparameters(gamma_only, wfcecut, distribution_type,xprime); pwtest.setuptransform(); + pwtest.reset_cache_stats(); pwtest.collect_local_pw(); pwtest.collect_uniqgg(); + auto stats_after_build = pwtest.get_cache_stats(); + EXPECT_EQ(stats_after_build.local_pw_misses, 1); + EXPECT_EQ(stats_after_build.uniqgg_misses, 1); + double* gg_ptr = pwtest.gg; + int* ig2igg_ptr = pwtest.ig2igg; + double* gguniq_ptr = pwtest.gg_uniq; + const int ngg_before = pwtest.ngg; + const double gg_sample = pwtest.gg[0]; + pwtest.collect_local_pw(); + pwtest.collect_uniqgg(); + EXPECT_EQ(pwtest.gg, gg_ptr); + EXPECT_EQ(pwtest.ig2igg, ig2igg_ptr); + EXPECT_EQ(pwtest.gg_uniq, gguniq_ptr); + EXPECT_EQ(pwtest.ngg, ngg_before); + EXPECT_DOUBLE_EQ(pwtest.gg[0], gg_sample); + auto stats_after_hit = pwtest.get_cache_stats(); + EXPECT_EQ(stats_after_hit.local_pw_hits, 1); + EXPECT_EQ(stats_after_hit.uniqgg_hits, 1); + EXPECT_GT(stats_after_hit.cache_bytes, 0); ModuleBase::Matrix3 GT,G,GGT; GT = latvec.Inverse(); G = GT.Transpose(); @@ -229,4 +249,4 @@ TEST_F(PWTEST,test1_1_1) delete[] irindex; -} \ No newline at end of file +} diff --git a/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp b/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp index 84932bae2ff..42ae1b3a73d 100644 --- a/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp +++ b/source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp @@ -183,9 +183,28 @@ TEST_F(PWBasisKTEST, CollectLocalPW) const bool xprime_in = true; basis_k.initparameters(gamma_only_in, gk_ecut_in, nks_in,kvec_d_in, distribution_type_in, xprime_in); EXPECT_NO_THROW(basis_k.setuptransform()); + basis_k.reset_k_cache_stats(); EXPECT_NO_THROW(basis_k.collect_local_pw()); + ASSERT_GT(basis_k.npwk[0], 0); + auto* gk2_ptr = basis_k.get_gk2_data(); + auto* gcar_ptr = basis_k.get_gcar_data(); + const double gk2_sample = basis_k.getgk2(0,0); + const auto stats_after_build = basis_k.get_k_cache_stats(); + EXPECT_EQ(stats_after_build.gcar_misses, 1); + EXPECT_EQ(stats_after_build.gk2_misses, 1); + EXPECT_NO_THROW(basis_k.collect_local_pw()); + EXPECT_EQ(basis_k.get_gk2_data(), gk2_ptr); + EXPECT_EQ(basis_k.get_gcar_data(), gcar_ptr); + EXPECT_DOUBLE_EQ(basis_k.getgk2(0,0), gk2_sample); + EXPECT_NO_THROW(basis_k.collect_local_pw(1.0, 0.5, 0.2)); + EXPECT_EQ(basis_k.get_gcar_data(), gcar_ptr); + const auto stats_after_hits = basis_k.get_k_cache_stats(); + EXPECT_EQ(stats_after_hits.gcar_hits, 2); + EXPECT_EQ(stats_after_hits.gcar_misses, 1); + EXPECT_EQ(stats_after_hits.gk2_hits, 1); + EXPECT_EQ(stats_after_hits.gk2_misses, 2); + EXPECT_GT(stats_after_hits.cache_bytes, 0); EXPECT_EQ(basis_k.npw,3695); EXPECT_EQ(basis_k.npwk_max,2721); } -