Skip to content

Commit

Permalink
add size and replace arguments to deepmd.utils.random.choice (#3195)
Browse files Browse the repository at this point in the history
Fix
https://github.com/deepmodeling/deepmd-kit/security/code-scanning/2096

---------

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
njzjz and pre-commit-ci[bot] authored Jan 29, 2024
1 parent a8168b5 commit 5b64d5c
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 8 deletions.
4 changes: 2 additions & 2 deletions deepmd/pt/utils/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -879,7 +879,7 @@ def __len__(self):
def __getitem__(self, index=None):
"""Get a batch of frames from the selected system."""
if index is None:
index = dp_random.choice(np.arange(self.nsystems), self.probs)
index = dp_random.choice(np.arange(self.nsystems), p=self.probs)
b_data = self._data_systems[index].get_batch(self._batch_size)
b_data["natoms"] = torch.tensor(
self._natoms_vec[index], device=env.PREPROCESS_DEVICE
Expand All @@ -892,7 +892,7 @@ def __getitem__(self, index=None):
def get_training_batch(self, index=None):
"""Get a batch of frames from the selected system."""
if index is None:
index = dp_random.choice(np.arange(self.nsystems), self.probs)
index = dp_random.choice(np.arange(self.nsystems), p=self.probs)
b_data = self._data_systems[index].get_batch_for_train(self._batch_size)
b_data["natoms"] = torch.tensor(
self._natoms_vec[index], device=env.PREPROCESS_DEVICE
Expand Down
27 changes: 21 additions & 6 deletions deepmd/utils/random.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,44 @@
# SPDX-License-Identifier: LGPL-3.0-or-later
from typing import (
Optional,
Tuple,
Union,
)

import numpy as np

_RANDOM_GENERATOR = np.random.RandomState()


def choice(a: np.ndarray, p: Optional[np.ndarray] = None):
def choice(
a: Union[np.ndarray, int],
size: Optional[Union[int, Tuple[int, ...]]] = None,
replace: bool = True,
p: Optional[np.ndarray] = None,
):
"""Generates a random sample from a given 1-D array.
Parameters
----------
a : np.ndarray
A random sample is generated from its elements.
p : np.ndarray
The probabilities associated with each entry in a.
a : 1-D array-like or int
If an ndarray, a random sample is generated from its elements. If an int,
the random sample is generated as if it were np.arange(a)
size : int or tuple of ints, optional
Output shape. If the given shape is, e.g., (m, n, k), then m * n * k samples
are drawn. Default is None, in which case a single value is returned.
replace : boolean, optional
Whether the sample is with or without replacement. Default is True, meaning
that a value of a can be selected multiple times.
p : 1-D array-like, optional
The probabilities associated with each entry in a. If not given, the sample
assumes a uniform distribution over all entries in a.
Returns
-------
np.ndarray
arrays with results and their shapes
"""
return _RANDOM_GENERATOR.choice(a, p=p)
return _RANDOM_GENERATOR.choice(a, size=size, replace=replace, p=p)


def random(size=None):
Expand Down

0 comments on commit 5b64d5c

Please sign in to comment.