Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 55 additions & 2 deletions source/source_hsolver/test/diago_cg_float_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,33 @@ void lapackEigen(int &npw, std::vector<std::complex<float>> &hm, float *e, bool
delete[] work2;
}

// LAPACK reference for generalized eigenproblem when S is diagonal (complex<float>)
void lapackGeneralEigen(int &npw, std::vector<std::complex<float>> &hm, const std::vector<std::complex<float>> &sdiag, float *e, bool outtime = false)
{
// build transformed matrix tmp = S^{-1/2} H S^{-1/2}
std::vector<std::complex<float>> tmp(npw * npw);
for (int i = 0; i < npw; ++i) {
std::complex<float> si = std::sqrt(sdiag[i]);
for (int j = 0; j < npw; ++j) {
std::complex<float> sj = std::sqrt(sdiag[j]);
tmp[i * npw + j] = hm[i * npw + j] / (si * sj);
}
}
Comment on lines +62 to +72

// call cheev_ on transformed matrix
clock_t start = clock(), end;
int lwork = 2 * npw;
std::complex<float> *work2 = new std::complex<float>[lwork];
float *rwork = new float[3 * npw - 2];
int info = 0;
char tmp_c1 = 'V', tmp_c2 = 'U';
cheev_(&tmp_c1, &tmp_c2, &npw, tmp.data(), &npw, e, work2, &lwork, rwork, &info);
end = clock();
if (outtime) std::cout << "Lapack General Run time: " << (float)(end - start) / CLOCKS_PER_SEC << " S" << std::endl;
Comment on lines +81 to +83
delete[] rwork;
delete[] work2;
}

class DiagoCGPrepare
{
public:
Expand Down Expand Up @@ -85,8 +112,14 @@ class DiagoCGPrepare
float *e_lapack = new float[npw];
auto ev = DIAGOTEST::hmatrix_f;

if(mypnum == 0) { lapackEigen(npw, ev, e_lapack, false);
}
if(mypnum == 0) {
if (DIAGOTEST::sdiag_f.empty()) {
lapackEigen(npw, ev, e_lapack);
} else {
auto hm_copy = ev; // operate on a copy
lapackGeneralEigen(npw, hm_copy, DIAGOTEST::sdiag_f, e_lapack);
}
}

// initial guess of psi by perturbing lapack psi
std::vector<std::complex<float>> psiguess(nband * npw);
Expand Down Expand Up @@ -228,12 +261,32 @@ TEST_P(DiagoCGFloatTest, RandomHamilt)
//std::cout<<"eps "<<hsolver::DiagoIterAssist<std::complex<float>>::PW_DIAG_THR<<std::endl;
HPsi<std::complex<float>> hpsi(dcp.nband, dcp.npw, dcp.sparsity);
DIAGOTEST::hmatrix_f = hpsi.hamilt();
DIAGOTEST::sdiag_f.clear(); // ensure sdiag is empty for this test

DIAGOTEST::npw = dcp.npw;
// ModuleBase::ComplexMatrix psi = hpsi.psi();
dcp.CompareEigen(hpsi.precond());
}

TEST_P(DiagoCGFloatTest, RandomHamiltAndS)
{
DiagoCGPrepare dcp = GetParam();
hsolver::DiagoIterAssist<std::complex<float>>::PW_DIAG_NMAX = dcp.maxiter;
hsolver::DiagoIterAssist<std::complex<float>>::PW_DIAG_THR = dcp.eps;
HPsi<std::complex<float>> hpsi(dcp.nband, dcp.npw, dcp.sparsity);
DIAGOTEST::hmatrix_f = hpsi.hamilt();

DIAGOTEST::npw = dcp.npw;

// 生成正定对角 S(范围 0.5..1.5)
DIAGOTEST::sdiag_f.resize(dcp.npw);
std::default_random_engine eng(123);
std::uniform_real_distribution<float> ud(0.5f, 1.5f);
for (int i = 0; i < dcp.npw; ++i) DIAGOTEST::sdiag_f[i] = std::complex<float>(ud(eng), 0.0f);
Comment on lines +281 to +285

dcp.CompareEigen(hpsi.precond());
}

INSTANTIATE_TEST_SUITE_P(VerifyCG,
DiagoCGFloatTest,
::testing::Values(
Expand Down
42 changes: 30 additions & 12 deletions source/source_hsolver/test/diago_mock.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,12 @@ namespace DIAGOTEST
std::vector<std::complex<double>> hmatrix_local;
std::vector<std::complex<float>> hmatrix_f;
std::vector<std::complex<float>> hmatrix_local_f;

// diagonal representation of overlap (S) for simple mock of generalized eigenproblem
// if empty, sPsi will treat S as identity
std::vector<double> sdiag_d; // for double / complex<double>
std::vector<std::complex<float> > sdiag_f;
std::vector<std::complex<double>> sdiag;
Comment on lines +15 to +19
int h_nr;
int h_nc;
int npw;
Expand Down Expand Up @@ -409,11 +415,15 @@ void hamilt::HamiltPW<double, base_device::DEVICE_CPU>::sPsi(const double* psi_i
const int npw,
const int nbands) const
{
for (size_t i = 0; i < static_cast<size_t>(nbands * nrow); i++)
{
spsi[i] = psi_in[i];
if (DIAGOTEST::sdiag_d.size() < static_cast<size_t>(nrow)) {
DIAGOTEST::sdiag_d.assign(nrow, 1.0); // 默认单位 S
}
for (int v = 0; v < nbands; ++v) {
for (int i = 0; i < nrow; ++i) {
size_t idx = static_cast<size_t>(v) * nrow + i;
spsi[idx] = psi_in[idx] * DIAGOTEST::sdiag_d[i];
}
}
Comment on lines +418 to 426
return;
}
template <>
void hamilt::HamiltPW<std::complex<double>, base_device::DEVICE_CPU>::sPsi(const std::complex<double>* psi_in,
Expand All @@ -422,11 +432,15 @@ void hamilt::HamiltPW<std::complex<double>, base_device::DEVICE_CPU>::sPsi(const
const int npw,
const int nbands) const
{
for (size_t i = 0; i < static_cast<size_t>(nbands * nrow); i++)
{
spsi[i] = psi_in[i];
if (DIAGOTEST::sdiag_d.size() < static_cast<size_t>(nrow)) {
DIAGOTEST::sdiag_d.assign(nrow, 1.0);
}
for (int v = 0; v < nbands; ++v) {
for (int i = 0; i < nrow; ++i) {
size_t idx = static_cast<size_t>(v) * nrow + i;
spsi[idx] = psi_in[idx] * DIAGOTEST::sdiag_d[i];
}
}
return;
}
template <>
void hamilt::HamiltPW<std::complex<float>, base_device::DEVICE_CPU>::sPsi(const std::complex<float>* psi_in,
Expand All @@ -435,11 +449,15 @@ void hamilt::HamiltPW<std::complex<float>, base_device::DEVICE_CPU>::sPsi(const
const int npw,
const int nbands) const
{
for (size_t i = 0; i < static_cast<size_t>(nbands * nrow); i++)
{
spsi[i] = psi_in[i];
if (DIAGOTEST::sdiag_f.size() < static_cast<size_t>(nrow)) {
DIAGOTEST::sdiag_f.assign(nrow, 1.0f);
}
for (int v = 0; v < nbands; ++v) {
for (int i = 0; i < nrow; ++i) {
size_t idx = static_cast<size_t>(v) * nrow + i;
spsi[idx] = psi_in[idx] * DIAGOTEST::sdiag_f[i];
}
}
return;
}

//Mock function h_psi
Expand Down
Loading