Skip to content
89 changes: 53 additions & 36 deletions source/source_basis/module_pw/pw_gatherscatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,41 @@

namespace ModulePW
{
namespace detail
{
/**
 * @brief Serial copy of `count` complex values from `in` to `out`.
 *
 * The buffers are walked as one interleaved real/imaginary scalar stream
 * (2 * count values of type T) so the compiler sees a single flat, contiguous
 * loop it can vectorize. This helper opens no OpenMP region of its own: the
 * gather/scatter loops call it from inside their existing parallel regions.
 *
 * @param in    source buffer holding at least `count` elements
 * @param out   destination buffer holding at least `count` elements; must not
 *              overlap `in` (the restrict qualifiers and ivdep promise this)
 * @param count number of complex elements to copy; values <= 0 copy nothing
 */
template <typename T>
inline void copy_complex_buffer(const std::complex<T>* in, std::complex<T>* out, const int count)
{
    // Viewing std::complex<T> storage as T[2] pairs relies on the standard's
    // array-oriented access guarantee for std::complex.
    // `__restrict` (rather than `__restrict__`) is the spelling accepted by
    // GCC, Clang and MSVC alike.
    const T* __restrict in_r = reinterpret_cast<const T*>(in);
    T* __restrict out_r = reinterpret_cast<T*>(out);
    // Widen before doubling: `2 * count` evaluated in int overflows (undefined
    // behaviour) as soon as count exceeds INT_MAX / 2.
    const long long nscalar = 2LL * count;
    // Clang also defines __GNUC__ but does not implement this pragma, so it is
    // restricted to GCC proper to avoid an unknown-pragma diagnostic.
#if defined(__GNUC__) && !defined(__clang__)
#pragma GCC ivdep
#endif
    for (long long i = 0; i < nscalar; ++i)
    {
        out_r[i] = in_r[i];
    }
}

/**
 * @brief OpenMP-parallel copy of `count` complex values from `in` to `out`.
 *
 * Same element-wise copy as copy_complex_buffer(), but this variant owns the
 * OpenMP parallel region (when compiled with OpenMP). It is meant for the
 * top-level transform copies, which are not already inside a parallel region.
 *
 * @param in    source buffer holding at least `count` elements
 * @param out   destination buffer holding at least `count` elements; must not
 *              overlap `in`
 * @param count number of complex elements to copy; values <= 0 copy nothing
 */
template <typename T>
inline void copy_complex_buffer_parallel(const std::complex<T>* in, std::complex<T>* out, const int count)
{
    const T* __restrict in_r = reinterpret_cast<const T*>(in);
    T* __restrict out_r = reinterpret_cast<T*>(out);
    // Widen before doubling so the loop bound cannot overflow int.
    const long long nscalar = 2LL * count;
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
    for (long long i = 0; i < nscalar; ++i)
    {
        out_r[i] = in_r[i];
    }
}
} // namespace detail

/**
* @brief gather planes and scatter sticks
* @param in: (nplane,fftny,fftnx)
Expand All @@ -27,12 +62,9 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
for(int is = 0 ; is < nst_ ; ++is)
{
int ixy = istot2ixy_[is];
std::complex<T> *outp = &out[is*nz_];
std::complex<T> *inp = &in[ixy*nz_];
for(int iz = 0 ; iz < nz_ ; ++iz)
{
outp[iz] = inp[iz];
}
std::complex<T>* outp = &out[is*nz_];
const std::complex<T>* inp = &in[ixy*nz_];
detail::copy_complex_buffer(inp, outp, nz_);
}
return;
}
Expand All @@ -50,12 +82,9 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
for (int istot = 0; istot < nstot_gps; ++istot)
{
int ixy = istot2ixy_gps[istot];
std::complex<T> *outp = &out[istot * nplane_gps];
std::complex<T> *inp = &in[ixy * nplane_gps];
for (int iz = 0; iz < nplane_gps; ++iz)
{
outp[iz] = inp[iz];
}
std::complex<T>* outp = &out[istot * nplane_gps];
const std::complex<T>* inp = &in[ixy * nplane_gps];
detail::copy_complex_buffer(inp, outp, nplane_gps);
}

//exchange data
Expand Down Expand Up @@ -90,12 +119,9 @@ void PW_Basis::gatherp_scatters(std::complex<T>* in, std::complex<T>* out) const
int nzip = numz_gps[ip];
std::complex<T> *outp0 = &out[startz_gps[ip]];
std::complex<T> *inp0 = &in[startg_gps[ip]];
std::complex<T> *outp = &outp0[is * nz_gps];
std::complex<T> *inp = &inp0[is * nzip ];
for (int izip = 0; izip < nzip; ++izip)
{
outp[izip] = inp[izip];
}
std::complex<T>* outp = &outp0[is * nz_gps];
const std::complex<T>* inp = &inp0[is * nzip ];
detail::copy_complex_buffer(inp, outp, nzip);
}
}
#endif
Expand Down Expand Up @@ -132,12 +158,9 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
for(int is = 0 ; is < nst_ ; ++is)
{
int ixy = istot2ixy_[is];
std::complex<T> *outp = &out[ixy*nz_];
std::complex<T> *inp = &in[is*nz_];
for(int iz = 0 ; iz < nz_ ; ++iz)
{
outp[iz] = inp[iz];
}
std::complex<T>* outp = &out[ixy*nz_];
const std::complex<T>* inp = &in[is*nz_];
detail::copy_complex_buffer(inp, outp, nz_);
}
return;
}
Expand All @@ -162,12 +185,9 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
int nzip = numz_[ip];
std::complex<T> *outp0 = &out[startg_[ip]];
std::complex<T> *inp0 = &in[startz_[ip]];
std::complex<T> *outp = &outp0[is * nzip];
std::complex<T> *inp = &inp0[is * nz_ ];
for (int izip = 0; izip < nzip; ++izip)
{
outp[izip] = inp[izip];
}
std::complex<T>* outp = &outp0[is * nzip];
const std::complex<T>* inp = &inp0[is * nz_ ];
detail::copy_complex_buffer(inp, outp, nzip);
}
}

Expand Down Expand Up @@ -205,12 +225,9 @@ void PW_Basis::gathers_scatterp(std::complex<T>* in, std::complex<T>* out) const
{
int ixy = istot2ixy[istot];
//int ixy = (ixy / fftny)*ny + ixy % fftny;
std::complex<T> *outp = &out[ixy * nplane];
std::complex<T> *inp = &in[istot * nplane];
for (int iz = 0; iz < nplane; ++iz)
{
outp[iz] = inp[iz];
}
std::complex<T>* outp = &out[ixy * nplane];
const std::complex<T>* inp = &in[istot * nplane];
detail::copy_complex_buffer(inp, outp, nplane);
}
#endif
return;
Expand Down
18 changes: 3 additions & 15 deletions source/source_basis/module_pw/pw_transform.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,7 @@ void PW_Basis::real2recip(const std::complex<FPTYPE>* in,
const int npw_ = this->npw;
const int nxyz_ = this->nxyz;
const int* ig2isz_ = this->ig2isz;
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int ir = 0; ir < nrxx_; ++ir)
{
this->fft_bundle.get_auxr_data<FPTYPE>()[ir] = in[ir];
}
detail::copy_complex_buffer_parallel(in, this->fft_bundle.get_auxr_data<FPTYPE>(), nrxx_);
this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());

this->gatherp_scatters(this->fft_bundle.get_auxr_data<FPTYPE>(), this->fft_bundle.get_auxg_data<FPTYPE>());
Expand Down Expand Up @@ -199,13 +193,7 @@ void PW_Basis::recip2real(const std::complex<FPTYPE>* in,
}
else
{
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int ir = 0; ir < nrxx_; ++ir)
{
out[ir] = this->fft_bundle.get_auxr_data<FPTYPE>()[ir];
}
detail::copy_complex_buffer_parallel(this->fft_bundle.get_auxr_data<FPTYPE>(), out, nrxx_);
}
ModuleBase::timer::end(this->classname, "recip2real");
}
Expand Down Expand Up @@ -340,4 +328,4 @@ template void PW_Basis::recip2real<double>(const std::complex<double>* in,
std::complex<double>* out,
const bool add,
const double factor) const;
} // namespace ModulePW
} // namespace ModulePW
16 changes: 2 additions & 14 deletions source/source_basis/module_pw/pw_transform_k.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,7 @@ void PW_Basis_K::real2recip(const std::complex<FPTYPE>* in,

assert(this->gamma_only == false);
auto* auxr = this->fft_bundle.get_auxr_data<FPTYPE>();
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int ir = 0; ir < this->nrxx; ++ir)
{
auxr[ir] = in[ir];
}
detail::copy_complex_buffer_parallel(in, auxr, this->nrxx);
this->fft_bundle.fftxyfor(fft_bundle.get_auxr_data<FPTYPE>(), fft_bundle.get_auxr_data<FPTYPE>());

this->gatherp_scatters(this->fft_bundle.get_auxr_data<FPTYPE>(), this->fft_bundle.get_auxg_data<FPTYPE>());
Expand Down Expand Up @@ -200,13 +194,7 @@ void PW_Basis_K::recip2real(const std::complex<FPTYPE>* in,
}
else
{
#ifdef _OPENMP
#pragma omp parallel for schedule(static)
#endif
for (int ir = 0; ir < this->nrxx; ++ir)
{
out[ir] = auxr[ir];
}
detail::copy_complex_buffer_parallel(auxr, out, this->nrxx);
}
ModuleBase::timer::end(this->classname, "recip2real");
}
Expand Down
40 changes: 40 additions & 0 deletions source/source_basis/module_pw/test_serial/pw_basis_k_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "source_base/global_function.h"
#include "source_base/constants.h"
#include "source_base/matrix3.h"
#include <vector>

/************************************************
* serial unit test of functions in pw_basis.cpp
Expand Down Expand Up @@ -188,4 +189,43 @@ TEST_F(PWBasisKTEST, CollectLocalPW)
EXPECT_EQ(basis_k.npwk_max,2721);
}

// Round-trip check for the complex transforms of PW_Basis_K: coefficients are
// sent through recip2real and then real2recip, and must come back unchanged
// to within 1e-10.
TEST_F(PWBasisKTEST, ComplexTransformRoundTrip)
{
    ModulePW::PW_Basis_K basis_k(device_flag, precision_double);
    double lat0 = 2.0;
    ModuleBase::Matrix3 latvec(1.0,0.0,1.0,
                               0.0,2.0,0.0,
                               0.0,0.0,2.0);
    double gridecut = 30.0;
    const bool gamma_only_in = false;
    const double gk_ecut_in = 20.0;
    const int nks_in = 1;
    const ModuleBase::Vector3<double> kvec_d_in[1] = { {0.0, 0.0, 0.0} };
    const int distribution_type_in = 2;
    const bool xprime_in = false;

    basis_k.initgrids(lat0, latvec, gridecut);
    basis_k.initparameters(gamma_only_in, gk_ecut_in, nks_in, kvec_d_in, distribution_type_in, xprime_in);
    ASSERT_NO_THROW(basis_k.setuptransform());

    // Guard against a vacuous pass: with zero plane waves both loops below
    // would run zero iterations and the test would succeed without checking
    // anything.
    const int npw = basis_k.npwk[0];
    ASSERT_GT(npw, 0);

    // Use reciprocal-space input because arbitrary real-space data is projected
    // by the plane-wave cutoff and is not exactly recoverable.
    std::vector<std::complex<double>> recip_in(npw);
    std::vector<std::complex<double>> real_space(basis_k.nrxx);
    std::vector<std::complex<double>> recip_out(npw);
    for (int ig = 0; ig < npw; ++ig)
    {
        // Deterministic, non-trivial pattern; the different moduli keep the
        // real and imaginary parts from repeating in lock-step.
        const double real_part = (ig % 17 - 8) / 11.0;
        const double imag_part = (ig % 19 - 9) / 13.0;
        recip_in[ig] = std::complex<double>(real_part, imag_part);
    }

    // NOTE(review): the trailing 0 is presumably the k-point index, matching
    // the npwk[0] sizing above; confirm against the PW_Basis_K declaration.
    basis_k.recip2real(recip_in.data(), real_space.data(), 0);
    basis_k.real2recip(real_space.data(), recip_out.data(), 0);

    for (int ig = 0; ig < npw; ++ig)
    {
        EXPECT_NEAR(recip_in[ig].real(), recip_out[ig].real(), 1e-10);
        EXPECT_NEAR(recip_in[ig].imag(), recip_out[ig].imag(), 1e-10);
    }
}
39 changes: 39 additions & 0 deletions source/source_basis/module_pw/test_serial/pw_basis_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "source_base/global_function.h"
#include "source_base/constants.h"
#include "source_base/matrix3.h"
#include <vector>

/************************************************
* serial unit test of functions in pw_basis.cpp
Expand Down Expand Up @@ -362,3 +363,41 @@ TEST_F(PWBasisTEST,CollectUniqgg)
pwb.collect_uniqgg();
EXPECT_EQ(pwb.ngg,78);
}

// Round-trip check for the complex transforms of PW_Basis: coefficients are
// sent through recip2real and then real2recip, and must come back unchanged
// to within 1e-10.
TEST_F(PWBasisTEST,ComplexTransformRoundTrip)
{
    double lat0 = 2.0;
    ModuleBase::Matrix3 latvec(1.0,0.0,1.0,
                               0.0,2.0,0.0,
                               0.0,0.0,2.0);
    double gridecut = 30.0;
    bool gamma_only_in = false;
    double pwecut_in = 20.0;
    int distribution_type_in = 2;
    bool xprime_in = false;

    pwb.initgrids(lat0, latvec, gridecut);
    pwb.initparameters(gamma_only_in, pwecut_in, distribution_type_in, xprime_in);
    ASSERT_NO_THROW(pwb.setuptransform());

    // Guard against a vacuous pass: with zero plane waves both loops below
    // would run zero iterations and the test would succeed without checking
    // anything.
    const int npw = pwb.npw;
    ASSERT_GT(npw, 0);

    // Use reciprocal-space input because arbitrary real-space data is projected
    // by the plane-wave cutoff and is not exactly recoverable.
    std::vector<std::complex<double>> recip_in(npw);
    std::vector<std::complex<double>> real_space(pwb.nrxx);
    std::vector<std::complex<double>> recip_out(npw);
    for (int ig = 0; ig < npw; ++ig)
    {
        // Deterministic, non-trivial pattern; the different moduli keep the
        // real and imaginary parts from repeating in lock-step.
        const double real_part = (ig % 11 - 5) / 7.0;
        const double imag_part = (ig % 13 - 6) / 9.0;
        recip_in[ig] = std::complex<double>(real_part, imag_part);
    }

    pwb.recip2real(recip_in.data(), real_space.data());
    pwb.real2recip(real_space.data(), recip_out.data());

    for (int ig = 0; ig < npw; ++ig)
    {
        EXPECT_NEAR(recip_in[ig].real(), recip_out[ig].real(), 1e-10);
        EXPECT_NEAR(recip_in[ig].imag(), recip_out[ig].imag(), 1e-10);
    }
}
Loading