Skip to content

Commit

Permalink
Update seed feature (#44)
Browse files Browse the repository at this point in the history
  • Loading branch information
moonbings authored Nov 10, 2022
1 parent 95ad39f commit 9a1d6c0
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 12 deletions.
12 changes: 11 additions & 1 deletion synthtiger/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,24 @@

from synthtiger import components, layers, templates, utils
from synthtiger._version import __version__
from synthtiger.gen import generator, read_config, read_template
from synthtiger.gen import (
generator,
get_global_random_states,
read_config,
read_template,
set_global_random_seed,
set_global_random_states,
)

__all__ = [
"components",
"layers",
"templates",
"utils",
"generator",
"get_global_random_states",
"read_config",
"read_template",
"set_global_random_seed",
"set_global_random_states",
]
35 changes: 24 additions & 11 deletions synthtiger/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,27 @@ def generator(path, name, config=None, count=None, worker=0, seed=None, verbose=
yield task_idx, data


def get_global_random_states():
states = {
"random": random.getstate(),
"numpy": np.random.get_state(),
"imgaug": imgaug.random.get_global_rng().state,
}
return states


def set_global_random_states(states):
random.setstate(states["random"])
np.random.set_state(states["numpy"])
imgaug.random.get_global_rng().state = states["imgaug"]


def set_global_random_seed(seed):
random.seed(seed)
np.random.set_state(np.random.RandomState(np.random.MT19937(seed)).get_state())
imgaug.seed(seed)


def _run(func, args):
proc = Process(target=func, args=args)
proc.daemon = True
Expand Down Expand Up @@ -89,13 +110,8 @@ def _worker(path, name, config, task_queue, data_queue, verbose):


def _generate(template, seed, verbose):
temp_state = random.getstate()
temp_np_state = np.random.get_state()
temp_imgaug_state = imgaug.random.get_global_rng().state

random.seed(seed)
np.random.set_state(np.random.RandomState(np.random.MT19937(seed)).get_state())
imgaug.seed(seed)
states = get_global_random_states()
set_global_random_seed(seed)

while True:
try:
Expand All @@ -106,8 +122,5 @@ def _generate(template, seed, verbose):
continue
break

random.setstate(temp_state)
np.random.set_state(temp_np_state)
imgaug.random.get_global_rng().state = temp_imgaug_state

set_global_random_states(states)
return data
1 change: 1 addition & 0 deletions synthtiger/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def run(args):

pprint.pprint(config)

synthtiger.set_global_random_seed(args.seed)
template = synthtiger.read_template(args.script, args.name, config)
generator = synthtiger.generator(
args.script,
Expand Down

0 comments on commit 9a1d6c0

Please sign in to comment.