Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
de8c3b8
Add plane wave assignment notes
mystic-qaq May 7, 2026
e7fd8db
Add plane wave optimization plan
mystic-qaq May 7, 2026
0b7b146
Add plane wave project planning docs
mystic-qaq May 7, 2026
65412a8
docs: update plane wave project plan
mystic-qaq May 8, 2026
42ceda2
docs: update plane wave project plan
MiniYuanBot May 8, 2026
cedfb98
docs: update Gamma_only plan, and rewrite chapter 6 and 8
Aunixt May 8, 2026
c74147a
docs: update plane wave project plan
MiniYuanBot May 8, 2026
50e56db
modify part7
mystic-qaq May 8, 2026
3d488d2
fix: adjust spacing between Chinese and English text
MiniYuanBot May 8, 2026
fbf52cf
doc: slightly modified 5.5
Aunixt May 8, 2026
f8f9c92
Delete the former plan
mystic-qaq May 13, 2026
1f35e9e
doc:add file ModulePW.md and WorkflowC.md
Aunixt May 13, 2026
3afb208
docs: move all the homework .md files to the directory homework_docs
Aunixt May 13, 2026
4c68e59
doc: Add more docs. The original algorithm analysis docs are in homew…
Aunixt May 15, 2026
0abe25e
Merge branch 'deepmodeling:develop' into WorkflowC
mystic-qaq May 15, 2026
9e24805
Merge branch 'deepmodeling:develop' into WorkflowC
Aunixt May 20, 2026
65acba3
Merge branch 'deepmodeling:develop' into WorkflowC
Aunixt May 21, 2026
fae9fe9
add WorkflowC_baseline_report.md
Aunixt May 22, 2026
543bc96
add release test
Aunixt May 22, 2026
c4ed0be
revise WorkflowC_baseline_report.md and fix some bugs
Aunixt May 22, 2026
8ac7271
refine the doc
Aunixt May 22, 2026
eeccb95
Merge branch 'deepmodeling:develop' into WorkflowC
mystic-qaq May 23, 2026
1f9d6d2
Merge branch 'deepmodeling:develop' into feat/WorkflowC
mystic-qaq May 23, 2026
dc1d7e6
rerun the examples and find error
Aunixt May 23, 2026
9889c24
remove gaas_tiny(20Ry)
Aunixt May 23, 2026
01e56aa
add more test
Aunixt May 23, 2026
7e6896b
modify the format
Aunixt May 23, 2026
0448e16
add Experience of Using AI
Aunixt May 23, 2026
d3f04e9
Merge branch 'deepmodeling:develop' into WorkflowC
mystic-qaq May 24, 2026
dedfbfc
Merge branch 'deepmodeling:develop' into feat/WorkflowC
mystic-qaq May 24, 2026
2d0dad7
remove time test in file
Aunixt May 25, 2026
f95785d
Merge branch 'deepmodeling:develop' into feat/WorkflowC
Aunixt May 25, 2026
f214d37
Merge branch 'deepmodeling:develop' into WorkflowC
Aunixt May 25, 2026
783b9ef
Merge pull request #1 from mystic-qaq/feat/WorkflowC
Aunixt May 25, 2026
906db29
add cache-reuse
Aunixt May 27, 2026
5e4a6ff
Merge branch 'deepmodeling:develop' into feat/cache-reuse
Aunixt May 27, 2026
63ded64
Merge branch 'deepmodeling:develop' into feat/cache-reuse
Aunixt May 28, 2026
8274962
Refine cache reuse, and refine test script. Later I will try to test …
Aunixt May 28, 2026
561720a
Merge branch 'deepmodeling:develop' into feat/cache-reuse
mystic-qaq May 29, 2026
398bc17
Return the timer to its original settings
Aunixt May 29, 2026
434b12c
bench: add bench scripts
Aunixt May 30, 2026
c62939a
Merge branch 'deepmodeling:develop' into feat/cache-reuse
mystic-qaq May 30, 2026
87e1304
bench: complete cache-reuse bench
Aunixt May 30, 2026
7c821b0
merge baseline bench
Aunixt May 30, 2026
533613c
doc: complete task8 docs
Aunixt May 30, 2026
8868a1e
doc: refine task8 report
Aunixt May 30, 2026
f0fa5f3
delete homework_docs
Aunixt May 30, 2026
b9fb1d2
remove vec_*.txt
Aunixt May 30, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
237 changes: 195 additions & 42 deletions source/source_basis/module_pw/pw_basis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "source_base/timer.h"
#include "source_base/global_function.h"

#include <vector>

namespace ModulePW
{
Expand All @@ -13,6 +14,62 @@ PW_Basis::PW_Basis()
classname="PW_Basis";
}

PW_Basis::PW_Basis(const PW_Basis& other)
{
this->classname = other.classname;
Comment on lines +17 to +19
#ifdef __MPI
this->pool_world = other.pool_world;
#endif
this->nst = other.nst;
this->nstnz = other.nstnz;
this->nstot = other.nstot;
this->npw = other.npw;
this->npwtot = other.npwtot;
this->nrxx = other.nrxx;
this->startz_current = other.startz_current;
this->nplane = other.nplane;
this->ig_gge0 = other.ig_gge0;
this->gamma_only = other.gamma_only;
this->full_pw = other.full_pw;
this->ggecut = other.ggecut;
this->gridecut_lat = other.gridecut_lat;
this->lat0 = other.lat0;
this->tpiba = other.tpiba;
this->tpiba2 = other.tpiba2;
this->latvec = other.latvec;
this->G = other.G;
this->GT = other.GT;
this->GGT = other.GGT;
this->omega = other.omega;
this->distribution_type = other.distribution_type;
this->full_pw_dim = other.full_pw_dim;
this->poolnproc = other.poolnproc;
this->poolrank = other.poolrank;
this->ngg = other.ngg;
this->fftnx = other.fftnx;
this->fftny = other.fftny;
this->fftnz = other.fftnz;
this->fftnxyz = other.fftnxyz;
this->fftnxy = other.fftnxy;
this->nx = other.nx;
this->ny = other.ny;
this->nz = other.nz;
this->nxyz = other.nxyz;
this->nxy = other.nxy;
this->liy = other.liy;
this->riy = other.riy;
this->lix = other.lix;
this->rix = other.rix;
this->xprime = other.xprime;
this->ng_xeq0 = other.ng_xeq0;
this->nmaxgr = other.nmaxgr;
this->device = other.device;
this->precision = other.precision;
this->double_data_ = other.double_data_;
this->float_data_ = other.float_data_;
this->fft_bundle.setfft(this->device, this->precision);
}

PW_Basis::PW_Basis(std::string device_, std::string precision_) : device(std::move(device_)), precision(std::move(precision_)) {
classname="PW_Basis";
this->fft_bundle.setfft("cpu",this->precision);
Expand All @@ -28,17 +85,13 @@ PW_Basis:: ~PW_Basis()
delete[] fftixy2ip;
delete[] nst_per;
delete[] npw_per;
delete[] gdirect;
delete[] gcar;
delete[] gg;
delete[] startz;
delete[] numz;
delete[] numg;
delete[] numr;
delete[] startg;
delete[] startr;
delete[] ig2igg;
delete[] gg_uniq;
this->clear_owned_cache();
#if defined(__CUDA) || defined(__ROCM)
if (this->device == "gpu")
{
Expand All @@ -48,6 +101,53 @@ PW_Basis:: ~PW_Basis()
#endif
}

void PW_Basis::clear_owned_cache()
{
this->invalidate_cache();
this->gg_cache_storage.reset();
this->gdirect_cache_storage.reset();
this->gcar_cache_storage.reset();
this->ig2igg_cache_storage.reset();
this->gg_uniq_cache_storage.reset();
this->gg = nullptr;
this->gdirect = nullptr;
this->gcar = nullptr;
this->ig2igg = nullptr;
this->gg_uniq = nullptr;
this->ngg = 0;
this->ig_gge0 = -1;
}

PW_Basis::CacheStats PW_Basis::get_cache_stats() const
{
CacheStats stats;
stats.local_pw_hits = this->local_pw_cache_hits.load();
stats.local_pw_misses = this->local_pw_cache_misses.load();
stats.uniqgg_hits = this->uniqgg_cache_hits.load();
stats.uniqgg_misses = this->uniqgg_cache_misses.load();
const bool has_local_pw_cache = this->local_pw_cache_valid.load() && this->npw > 0;
const bool has_uniqgg_cache = this->uniqgg_cache_valid.load() && this->ngg > 0;
if (has_local_pw_cache)
{
stats.cache_bytes += sizeof(double) * this->npw;
stats.cache_bytes += sizeof(ModuleBase::Vector3<double>) * this->npw * 2;
}
if (has_uniqgg_cache)
{
stats.cache_bytes += sizeof(int) * this->npw;
stats.cache_bytes += sizeof(double) * this->ngg;
}
return stats;
}

void PW_Basis::reset_cache_stats()
{
this->local_pw_cache_hits.store(0);
this->local_pw_cache_misses.store(0);
this->uniqgg_cache_hits.store(0);
this->uniqgg_cache_misses.store(0);
}

///
/// distribute plane wave basis and real-space grids to different processors
/// set up maps for fft and create arrays for MPI_Alltoall
Expand Down Expand Up @@ -138,10 +238,29 @@ void PW_Basis::collect_local_pw()
{
return;
}
ModuleBase::timer::start(this->classname, "collect_local_pw");
if (this->local_pw_cache_valid.load())
{
this->local_pw_cache_hits.fetch_add(1);
ModuleBase::timer::end(this->classname, "collect_local_pw");
return;
}
std::lock_guard<std::mutex> guard(this->cache_mutex);
if (this->local_pw_cache_valid.load())
{
this->local_pw_cache_hits.fetch_add(1);
ModuleBase::timer::end(this->classname, "collect_local_pw");
return;
}
this->local_pw_cache_misses.fetch_add(1);
this->ig_gge0 = -1;
delete[] this->gg; this->gg = new double[this->npw];
delete[] this->gdirect; this->gdirect = new ModuleBase::Vector3<double>[this->npw];
delete[] this->gcar; this->gcar = new ModuleBase::Vector3<double>[this->npw];
this->gg_cache_storage.reset(new double[this->npw]);
this->gdirect_cache_storage.reset(new ModuleBase::Vector3<double>[this->npw]);
this->gcar_cache_storage.reset(new ModuleBase::Vector3<double>[this->npw]);
this->gg = this->gg_cache_storage.get();
this->gdirect = this->gdirect_cache_storage.get();
this->gcar = this->gcar_cache_storage.get();
this->uniqgg_cache_valid.store(false);

ModuleBase::Vector3<double> f;
int gamma_num = 0;
Expand Down Expand Up @@ -182,6 +301,8 @@ void PW_Basis::collect_local_pw()
}
}
}
this->local_pw_cache_valid.store(true);
ModuleBase::timer::end(this->classname, "collect_local_pw");
return;
}

Expand All @@ -196,45 +317,75 @@ void PW_Basis::collect_uniqgg()
{
return;
}
ModuleBase::timer::start(this->classname, "collect_uniqgg");
if (this->uniqgg_cache_valid.load())
{
this->uniqgg_cache_hits.fetch_add(1);
ModuleBase::timer::end(this->classname, "collect_uniqgg");
return;
}
std::lock_guard<std::mutex> guard(this->cache_mutex);
if (this->uniqgg_cache_valid.load())
{
this->uniqgg_cache_hits.fetch_add(1);
ModuleBase::timer::end(this->classname, "collect_uniqgg");
return;
}
this->uniqgg_cache_misses.fetch_add(1);
this->ig_gge0 = -1;
delete[] this->ig2igg; this->ig2igg = new int [this->npw];
this->ig2igg_cache_storage.reset(new int[this->npw]);
this->ig2igg = this->ig2igg_cache_storage.get();

int *sortindex = new int [this->npw];//Reconstruct the mapping of the plane wave index ig according to the energy size of the plane waves
double *tmpgg = new double [this->npw];//Ranking the plane waves by energy size while ensuring that the same energy is preserved for each wave to correspond
double *tmpgg2 = new double [this->npw];//ranking the plane waves by energy size and removing the duplicates
ModuleBase::Vector3<double> f;
for(int ig = 0 ; ig < this-> npw ; ++ig)
std::vector<int> sortindex(this->npw); // Reconstruct the plane-wave index mapping after sorting by energy.
std::vector<double> tmpgg(this->npw);
std::vector<double> tmpgg2(this->npw);
if (this->local_pw_cache_valid.load() && this->gg != nullptr)
{
int isz = this->ig2isz[ig];
int iz = isz % this->nz;
int is = isz / this->nz;
int ixy = this->is2fftixy[is];
int ix = ixy / this->fftny;
int iy = ixy % this->fftny;
if (ix >= int(this->nx/2) + 1)
{
ix -= this->nx;
}
if (iy >= int(this->ny/2) + 1)
{
iy -= this->ny;
}
if (iz >= int(this->nz/2) + 1)
for(int ig = 0 ; ig < this-> npw ; ++ig)
{
iz -= this->nz;
tmpgg[ig] = this->gg[ig];
if(tmpgg[ig] < 1e-8)
{
this->ig_gge0 = ig;
}
}
f.x = ix;
f.y = iy;
f.z = iz;
tmpgg[ig] = f * (this->GGT * f);
if(tmpgg[ig] < 1e-8)
}
else
{
ModuleBase::Vector3<double> f;
for(int ig = 0 ; ig < this-> npw ; ++ig)
{
this->ig_gge0 = ig;
int isz = this->ig2isz[ig];
int iz = isz % this->nz;
int is = isz / this->nz;
int ixy = this->is2fftixy[is];
int ix = ixy / this->fftny;
int iy = ixy % this->fftny;
if (ix >= int(this->nx/2) + 1)
{
ix -= this->nx;
}
if (iy >= int(this->ny/2) + 1)
{
iy -= this->ny;
}
if (iz >= int(this->nz/2) + 1)
{
iz -= this->nz;
}
f.x = ix;
f.y = iy;
f.z = iz;
tmpgg[ig] = f * (this->GGT * f);
if(tmpgg[ig] < 1e-8)
{
this->ig_gge0 = ig;
}
}
}

ModuleBase::GlobalFunc::ZEROS(sortindex, this->npw);
ModuleBase::heapsort(this->npw, tmpgg, sortindex);
ModuleBase::GlobalFunc::ZEROS(sortindex.data(), this->npw);
ModuleBase::heapsort(this->npw, tmpgg.data(), sortindex.data());


int igg = 0;
Expand All @@ -261,14 +412,14 @@ void PW_Basis::collect_uniqgg()
}
tmpgg2[igg] = avg_gg / double(avg_n);
this->ngg = igg + 1;
delete[] this->gg_uniq; this->gg_uniq = new double [this->ngg];
this->gg_uniq_cache_storage.reset(new double[this->ngg]);
this->gg_uniq = this->gg_uniq_cache_storage.get();
for(int igg = 0 ; igg < this->ngg ; ++igg)
{
gg_uniq[igg] = tmpgg2[igg];
}
delete[] sortindex;
delete[] tmpgg;
delete[] tmpgg2;
this->uniqgg_cache_valid.store(true);
ModuleBase::timer::end(this->classname, "collect_uniqgg");
}

void PW_Basis::getfftixy2is(int * fftixy2is) const
Expand All @@ -295,10 +446,12 @@ void PW_Basis::getfftixy2is(int * fftixy2is) const

void PW_Basis::set_device(std::string device_) {
this->device = std::move(device_);
this->invalidate_cache();
}

void PW_Basis::set_precision(std::string precision_) {
this->precision = std::move(precision_);
this->invalidate_cache();
}

}
39 changes: 39 additions & 0 deletions source/source_basis/module_pw/pw_basis.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@
#include <complex>
#include "source_base/module_fft/fft_bundle.h"
#include <cstring>
#include <atomic>
#ifdef __MPI
#include "mpi.h"
#endif
#include <cstddef>
#include <cstdint>
#include <memory>
#include <mutex>

namespace ModulePW
{
Expand Down Expand Up @@ -56,8 +61,18 @@ class PW_Basis
{

public:
struct CacheStats
{
std::uint64_t local_pw_hits = 0;
std::uint64_t local_pw_misses = 0;
std::uint64_t uniqgg_hits = 0;
std::uint64_t uniqgg_misses = 0;
std::size_t cache_bytes = 0;
};

std::string classname;
PW_Basis();
PW_Basis(const PW_Basis& other);
PW_Basis(std::string device_, std::string precision_);
virtual ~PW_Basis();
//Init mpi parameters
Expand Down Expand Up @@ -137,9 +152,33 @@ class PW_Basis
//distribute plane waves and grids and set up fft
void setuptransform();

CacheStats get_cache_stats() const;
void reset_cache_stats();

protected:
int *startnsz_per=nullptr;//useless intermediate variable// startnsz_per[ip]: starting is * nz stick in the ip^th proc.

virtual void invalidate_cache()
{
this->local_pw_cache_valid.store(false);
this->uniqgg_cache_valid.store(false);
}

void clear_owned_cache();

std::atomic<bool> local_pw_cache_valid{false};
std::atomic<bool> uniqgg_cache_valid{false};
mutable std::mutex cache_mutex;
std::unique_ptr<double[]> gg_cache_storage;
std::unique_ptr<ModuleBase::Vector3<double>[]> gdirect_cache_storage;
std::unique_ptr<ModuleBase::Vector3<double>[]> gcar_cache_storage;
std::unique_ptr<int[]> ig2igg_cache_storage;
std::unique_ptr<double[]> gg_uniq_cache_storage;
std::atomic<std::uint64_t> local_pw_cache_hits{0};
std::atomic<std::uint64_t> local_pw_cache_misses{0};
std::atomic<std::uint64_t> uniqgg_cache_hits{0};
std::atomic<std::uint64_t> uniqgg_cache_misses{0};

//distribute plane waves to different processors
void distribute_g();

Expand Down
Loading
Loading