Skip to content

Commit

Permalink
Fix random state generator
Browse files Browse the repository at this point in the history
  • Loading branch information
viclafargue committed Apr 6, 2021
1 parent 9feecfb commit 6fea88a
Showing 1 changed file with 5 additions and 12 deletions.
17 changes: 5 additions & 12 deletions python/cuml/datasets/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright (c) 2020, NVIDIA CORPORATION.
# Copyright (c) 2020-2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -25,16 +25,9 @@ def _create_rs_generator(random_state):
The random_state from which the CuPy random state is generated
"""

if hasattr(random_state, '__module__'):
rs_type = random_state.__module__ + '.' + type(random_state).__name__
else:
rs_type = type(random_state).__name__

rs = None
if rs_type == "NoneType" or rs_type == "int":
rs = cp.random.RandomState(seed=random_state)
elif rs_type == "cupy.random.generator.RandomState":
rs = rs_type
if isinstance(random_state, (type(None), int)):
return cp.random.RandomState(seed=random_state)
elif isinstance(random_state, cp.random.RandomState):
return random_state
else:
raise ValueError('random_state type must be int or CuPy RandomState')
return rs

0 comments on commit 6fea88a

Please sign in to comment.