diff --git a/deepmd/pt/utils/dataset.py b/deepmd/pt/utils/dataset.py
index 24daa6e37e..3920499d3a 100644
--- a/deepmd/pt/utils/dataset.py
+++ b/deepmd/pt/utils/dataset.py
@@ -879,7 +879,7 @@ def __len__(self):
     def __getitem__(self, index=None):
         """Get a batch of frames from the selected system."""
         if index is None:
-            index = dp_random.choice(np.arange(self.nsystems), self.probs)
+            index = dp_random.choice(np.arange(self.nsystems), p=self.probs)
         b_data = self._data_systems[index].get_batch(self._batch_size)
         b_data["natoms"] = torch.tensor(
             self._natoms_vec[index], device=env.PREPROCESS_DEVICE
@@ -892,7 +892,7 @@ def __getitem__(self, index=None):
     def get_training_batch(self, index=None):
         """Get a batch of frames from the selected system."""
         if index is None:
-            index = dp_random.choice(np.arange(self.nsystems), self.probs)
+            index = dp_random.choice(np.arange(self.nsystems), p=self.probs)
         b_data = self._data_systems[index].get_batch_for_train(self._batch_size)
         b_data["natoms"] = torch.tensor(
             self._natoms_vec[index], device=env.PREPROCESS_DEVICE
diff --git a/deepmd/utils/random.py b/deepmd/utils/random.py
index 8944419412..44ea6a1dac 100644
--- a/deepmd/utils/random.py
+++ b/deepmd/utils/random.py
@@ -1,6 +1,8 @@
 # SPDX-License-Identifier: LGPL-3.0-or-later
 from typing import (
     Optional,
+    Tuple,
+    Union,
 )
 
 import numpy as np
@@ -8,22 +10,35 @@
 _RANDOM_GENERATOR = np.random.RandomState()
 
 
-def choice(a: np.ndarray, p: Optional[np.ndarray] = None):
+def choice(
+    a: Union[np.ndarray, int],
+    size: Optional[Union[int, Tuple[int, ...]]] = None,
+    replace: bool = True,
+    p: Optional[np.ndarray] = None,
+):
     """Generates a random sample from a given 1-D array.
 
     Parameters
     ----------
-    a : np.ndarray
-        A random sample is generated from its elements.
-    p : np.ndarray
-        The probabilities associated with each entry in a.
+    a : 1-D array-like or int
+        If an ndarray, a random sample is generated from its elements. If an int,
+        the random sample is generated as if it were np.arange(a)
+    size : int or tuple of ints, optional
+        Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples
+        are drawn. Default is None, in which case a single value is returned.
+    replace : boolean, optional
+        Whether the sample is with or without replacement. Default is True, meaning
+        that a value of a can be selected multiple times.
+    p : 1-D array-like, optional
+        The probabilities associated with each entry in a. If not given, the sample
+        assumes a uniform distribution over all entries in a.
 
     Returns
     -------
     np.ndarray
         arrays with results and their shapes
     """
-    return _RANDOM_GENERATOR.choice(a, p=p)
+    return _RANDOM_GENERATOR.choice(a, size=size, replace=replace, p=p)
 
 
 def random(size=None):