diff --git a/synthtiger/__init__.py b/synthtiger/__init__.py index c71fc86..2a382e4 100644 --- a/synthtiger/__init__.py +++ b/synthtiger/__init__.py @@ -6,7 +6,14 @@ from synthtiger import components, layers, templates, utils from synthtiger._version import __version__ -from synthtiger.gen import generator, read_config, read_template +from synthtiger.gen import ( + generator, + get_global_random_states, + read_config, + read_template, + set_global_random_seed, + set_global_random_states, +) __all__ = [ "components", @@ -14,6 +21,9 @@ "templates", "utils", "generator", + "get_global_random_states", "read_config", "read_template", + "set_global_random_seed", + "set_global_random_states", ] diff --git a/synthtiger/gen.py b/synthtiger/gen.py index 65baa37..aa849e8 100644 --- a/synthtiger/gen.py +++ b/synthtiger/gen.py @@ -62,6 +62,27 @@ def generator(path, name, config=None, count=None, worker=0, seed=None, verbose= yield task_idx, data +def get_global_random_states(): + states = { + "random": random.getstate(), + "numpy": np.random.get_state(), + "imgaug": imgaug.random.get_global_rng().state, + } + return states + + +def set_global_random_states(states): + random.setstate(states["random"]) + np.random.set_state(states["numpy"]) + imgaug.random.get_global_rng().state = states["imgaug"] + + +def set_global_random_seed(seed): + random.seed(seed) + np.random.set_state(np.random.RandomState(np.random.MT19937(seed)).get_state()) + imgaug.seed(seed) + + def _run(func, args): proc = Process(target=func, args=args) proc.daemon = True @@ -89,13 +110,8 @@ def _worker(path, name, config, task_queue, data_queue, verbose): def _generate(template, seed, verbose): - temp_state = random.getstate() - temp_np_state = np.random.get_state() - temp_imgaug_state = imgaug.random.get_global_rng().state - - random.seed(seed) - np.random.set_state(np.random.RandomState(np.random.MT19937(seed)).get_state()) - imgaug.seed(seed) + states = get_global_random_states() + set_global_random_seed(seed) while True: try: @@ -106,8 +122,5 @@ def _generate(template, seed, verbose): continue break - random.setstate(temp_state) - np.random.set_state(temp_np_state) - imgaug.random.get_global_rng().state = temp_imgaug_state - + set_global_random_states(states) return data diff --git a/synthtiger/main.py b/synthtiger/main.py index d752d72..d1620ce 100644 --- a/synthtiger/main.py +++ b/synthtiger/main.py @@ -17,6 +17,7 @@ def run(args): pprint.pprint(config) + synthtiger.set_global_random_seed(args.seed) template = synthtiger.read_template(args.script, args.name, config) generator = synthtiger.generator( args.script,