From 266c0259b70e80a536836f083ae5efb6f03bd57d Mon Sep 17 00:00:00 2001 From: James Spencer Date: Thu, 9 Jan 2025 18:41:07 +0000 Subject: [PATCH] Fallback to initialising electrons about the origin if a per-atom initialisation is not found PiperOrigin-RevId: 713722201 Change-Id: Ib6b2e491218ca4d5e9c65ce7518377f669607031 --- ferminet/train.py | 39 ++++++++++++++++++++++++++++++--------- 1 file changed, 30 insertions(+), 9 deletions(-) diff --git a/ferminet/train.py b/ferminet/train.py index 7abd138..9a60664 100644 --- a/ferminet/train.py +++ b/ferminet/train.py @@ -62,6 +62,7 @@ def init_electrons( # pylint: disable=dangerous-default-value batch_size: int, init_width: float, core_electrons: Mapping[str, int] = {}, + max_iter: int = 10_000, ) -> Tuple[jnp.ndarray, jnp.ndarray]: """Initializes electron positions around each atom. @@ -75,6 +76,9 @@ def init_electrons( # pylint: disable=dangerous-default-value electron configurations. core_electrons: mapping of element symbol to number of core electrons included in the pseudopotential. + max_iter: maximum number of iterations to try to find a valid initial + electron configuration for each atom. If reached, all electrons are + initialised from a Gaussian distribution centred on the origin. Returns: array of (batch_size, (nalpha+nbeta)*ndim) of initial (random) electron @@ -83,6 +87,7 @@ def init_electrons( # pylint: disable=dangerous-default-value of spin configurations, where 1 and -1 indicate alpha and beta electrons respectively. """ + niter = 0 total_electrons = sum(atom.charge - core_electrons.get(atom.symbol, 0) for atom in molecule) if total_electrons != sum(electrons): @@ -98,19 +103,35 @@ def init_electrons( # pylint: disable=dangerous-default-value for atom in molecule ] assert sum(sum(x) for x in atomic_spin_configs) == sum(electrons) - while tuple(sum(x) for x in zip(*atomic_spin_configs)) != electrons: + while ( + tuple(sum(x) for x in zip(*atomic_spin_configs)) != electrons + and niter < max_iter + ): i = np.random.randint(len(atomic_spin_configs)) nalpha, nbeta = atomic_spin_configs[i] atomic_spin_configs[i] = nbeta, nalpha + niter += 1 + + if tuple(sum(x) for x in zip(*atomic_spin_configs)) == electrons: + # Assign each electron to an atom initially. + electron_positions = [] + for i in range(2): + for j in range(len(molecule)): + atom_position = jnp.asarray(molecule[j].coords) + electron_positions.append( + jnp.tile(atom_position, atomic_spin_configs[j][i])) + electron_positions = jnp.concatenate(electron_positions) + else: + logging.warning( + 'Failed to find a valid initial electron configuration after %i' + ' iterations. Initializing all electrons from a Gaussian distribution' + ' centred on the origin. This might require increasing the number of' + ' iterations used for pretraining and MCMC burn-in. Consider' + ' implementing a custom initialisation.', + niter, + ) + electron_positions = jnp.zeros(shape=(3*sum(electrons),)) - # Assign each electron to an atom initially. - electron_positions = [] - for i in range(2): - for j in range(len(molecule)): - atom_position = jnp.asarray(molecule[j].coords) - electron_positions.append( - jnp.tile(atom_position, atomic_spin_configs[j][i])) - electron_positions = jnp.concatenate(electron_positions) # Create a batch of configurations with a Gaussian distribution about each # atom. key, subkey = jax.random.split(key)