Skip to content

Commit

Permalink
Merge pull request #85 from tonnylou44853:main
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707084783
Change-Id: I81a60df8e1e5d3035ac96fa39f4f354f865c051a
  • Loading branch information
jsspencer committed Jan 9, 2025
2 parents 9a1deec + b689f3f commit 7f7a0c8
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions ferminet/pbc/hamiltonian.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,7 @@ def local_energy(
nspins: Sequence[int],
use_scan: bool = False,
complex_output: bool = False,
laplacian_method: str = 'default',
states: int = 0,
lattice: Optional[jnp.ndarray] = None,
heg: bool = True,
Expand All @@ -170,6 +171,9 @@ def local_energy(
nspins: Number of particles of each spin.
use_scan: Whether to use a `lax.scan` for computing the laplacian.
complex_output: If true, the output of f is complex-valued.
laplacian_method: Laplacian calculation method. One of:
'default': take jvp(grad), looping over inputs
'folx': use Microsoft's implementation of forward laplacian
states: Number of excited states to compute. Not implemented, only present
for consistency of calling convention.
lattice: Shape (ndim, ndim). Matrix of lattice vectors. Default: identity
Expand All @@ -188,8 +192,10 @@ def local_energy(
if lattice is None:
lattice = jnp.eye(3)

ke = hamiltonian.local_kinetic_energy(f, use_scan=use_scan,
complex_output=complex_output)
ke = hamiltonian.local_kinetic_energy(f,
use_scan=use_scan,
complex_output=complex_output,
laplacian_method=laplacian_method)

def _e_l(
params: networks.ParamTree, key: chex.PRNGKey, data: networks.FermiNetData
Expand Down

0 comments on commit 7f7a0c8

Please sign in to comment.