From 6fea88aa02895981b1efcf04c78ba6abf805e307 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 6 Apr 2021 14:46:57 +0000 Subject: [PATCH] Fix random state generator --- python/cuml/datasets/utils.py | 17 +++++------------ 1 file changed, 5 insertions(+), 12 deletions(-) diff --git a/python/cuml/datasets/utils.py b/python/cuml/datasets/utils.py index 0a0b3c7b5e..d6e0c3c16d 100644 --- a/python/cuml/datasets/utils.py +++ b/python/cuml/datasets/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020, NVIDIA CORPORATION. +# Copyright (c) 2020-2021, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -25,16 +25,9 @@ def _create_rs_generator(random_state): The random_state from which the CuPy random state is generated """ - if hasattr(random_state, '__module__'): - rs_type = random_state.__module__ + '.' + type(random_state).__name__ - else: - rs_type = type(random_state).__name__ - - rs = None - if rs_type == "NoneType" or rs_type == "int": - rs = cp.random.RandomState(seed=random_state) - elif rs_type == "cupy.random.generator.RandomState": - rs = rs_type + if isinstance(random_state, (type(None), int)): + return cp.random.RandomState(seed=random_state) + elif isinstance(random_state, cp.random.RandomState): + return random_state else: raise ValueError('random_state type must be int or CuPy RandomState') - return rs