Skip to content

Commit

Permalink
Fix dimension issue
Browse files Browse the repository at this point in the history
  • Loading branch information
sunqm committed Jan 27, 2025
1 parent 1cd7be4 commit 1e8540f
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions pyscf/df/df_jk.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,17 +258,7 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13):
vj = 0
vk = numpy.zeros_like(dms)

if with_j:
idx = numpy.arange(nao)
dmtril = lib.pack_tril(dms + dms.conj().transpose(0,2,1))
dmtril[:,idx*(idx+1)//2+idx] *= .5

if not with_k:
for eri1 in dfobj.loop():
# uses numpy.matmul
vj += dmtril.dot(eri1.T).dot(eri1)

elif dms.dtype != numpy.float64:
if dms.dtype != numpy.float64:
if with_j:
vj = numpy.zeros_like(dms)
max_memory = dfobj.max_memory - lib.current_memory()[0]
Expand All @@ -279,15 +269,31 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13):
naux, nao_pair = eri1.shape
eri1 = lib.unpack_tril(eri1, out=buf)
if with_j:
tmp = numpy.einsum('pij,nji->pn', eri1, dms)
vj += numpy.einsum('pn,pij->nij', tmp, eri1)
tmp = numpy.einsum('pij,nji->pn', eri1, dms.real)
vj.real += numpy.einsum('pn,pij->nij', tmp, eri1)
tmp = numpy.einsum('pij,nji->pn', eri1, dms.imag)
vj.imag += numpy.einsum('pn,pij->nij', tmp, eri1)
buf2 = numpy.ndarray((nao,naux,nao), buffer=buf1)
for k in range(nset):
buf2[:] = lib.einsum('pij,jk->ipk', eri1, dms[k].real)
vk[k].real += lib.einsum('ipk,pkj->ij', buf2, eri1)
buf2[:] = lib.einsum('pij,jk->ipk', eri1, dms[k].imag)
vk[k].imag += lib.einsum('ipk,pkj->ij', buf2, eri1)
t1 = log.timer_debug1('jk', *t1)
if with_j: vj = vj.reshape(dm_shape)
if with_k: vk = vk.reshape(dm_shape)
logger.timer(dfobj, 'df vj and vk', *t0)
return vj, vk

if with_j:
idx = numpy.arange(nao)
dmtril = lib.pack_tril(dms + dms.conj().transpose(0,2,1))
dmtril[:,idx*(idx+1)//2+idx] *= .5

if not with_k:
for eri1 in dfobj.loop():
# uses numpy.matmul
vj += dmtril.dot(eri1.T).dot(eri1)

elif getattr(dm, 'mo_coeff', None) is not None:
#TODO: test whether dm.mo_coeff matching dm
Expand Down Expand Up @@ -360,12 +366,8 @@ def get_jk(dfobj, dm, hermi=0, with_j=True, with_k=True, direct_scf_tol=1e-13):
vk[k] += lib.dot(buf1.reshape(-1,nao).T, buf2.reshape(-1,nao))
t1 = log.timer_debug1('jk', *t1)

if with_j:
if dms.dtype == numpy.float64:
vj = lib.unpack_tril(vj, 1)
vj = vj.reshape(dm_shape)
if with_k:
vk = vk.reshape(dm_shape)
if with_j: vj = lib.unpack_tril(vj, 1).reshape(dm_shape)
if with_k: vk = vk.reshape(dm_shape)
logger.timer(dfobj, 'df vj and vk', *t0)
return vj, vk

Expand All @@ -374,7 +376,6 @@ def get_j(dfobj, dm, hermi=0, direct_scf_tol=1e-13):
from pyscf.scf import jk
from pyscf.df import addons
t0 = t1 = (logger.process_clock(), logger.perf_counter())
assert dm.dtype == numpy.float64

mol = dfobj.mol
if dfobj._vjopt is None:
Expand Down Expand Up @@ -415,6 +416,7 @@ def get_j(dfobj, dm, hermi=0, direct_scf_tol=1e-13):
opt = dfobj._vjopt
fakemol = opt.fakemol
dm = numpy.asarray(dm, order='C')
assert dm.dtype == numpy.float64
dm_shape = dm.shape
nao = dm_shape[-1]
dm = dm.reshape(-1,nao,nao)
Expand Down

0 comments on commit 1e8540f

Please sign in to comment.