forked from applied-ai-collective/Pacman-Deep-Q-Network
-
Notifications
You must be signed in to change notification settings - Fork 0
/
curriculumLearning.py
83 lines (79 loc) · 3.15 KB
/
curriculumLearning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from typing import List, Any, Union, Dict
from environment import Environment
class EnvStepRule(object):
    """
    Class for creating rules for stepping to the next environment

    The rule counts consecutive performance values that reach the threshold;
    any value that misses the threshold resets the count to zero.

    Args:
        performance_threshold: Threshold for performance
        patience: Number of consecutive times the performance threshold must be reached
        thresh_direction: Direction of the threshold, either 'above'
            (performance >= threshold counts) or 'below'
            (performance <= threshold counts)

    Raises:
        ValueError: If thresh_direction is neither 'above' nor 'below'
    """

    _VALID_DIRECTIONS = ('above', 'below')

    def __init__(self,
                 performance_threshold: float,
                 patience: int = 1,
                 thresh_direction: str = 'above'
                 ):
        # Fail fast: an unrecognised direction would otherwise produce a rule
        # that silently never fires.
        if thresh_direction not in self._VALID_DIRECTIONS:
            raise ValueError(
                "thresh_direction must be 'above' or 'below', got %r"
                % (thresh_direction,)
            )
        self.performance_threshold = performance_threshold
        self.progress_counter = 0
        self.patience = patience
        self.thresh_direction = thresh_direction

    def checkRule(self, performance: float) -> bool:
        """
        Register one performance value and report whether the rule is satisfied

        Args:
            performance: Latest performance measurement

        Returns:
            True once the threshold has been reached at least `patience`
            times in a row, False otherwise

        Raises:
            ValueError: If thresh_direction was changed to an unsupported value
        """
        if self.thresh_direction == 'above':
            reached = performance >= self.performance_threshold
        elif self.thresh_direction == 'below':
            reached = performance <= self.performance_threshold
        else:
            raise ValueError(
                "thresh_direction must be 'above' or 'below', got %r"
                % (self.thresh_direction,)
            )

        # A miss breaks the streak, so the counter only tracks consecutive hits.
        if reached:
            self.progress_counter += 1
        else:
            self.progress_counter = 0

        return self.progress_counter >= self.patience
class EnvScheduler(object):
    """
    Class for scheduling the environment used for curriculum learning

    The environments are sorted by their `difficulty` attribute and the
    scheduler keeps an index (`env_num`) into that sorted list. Reported
    performance values move the index forward (step rule) or backward
    (back rule).

    Args:
        env_list: List of environments
        env_step_rule: Rule (or dictionary of EnvStepRule keyword arguments)
            for stepping to the next environment
        env_back_rule: Optional rule (or dictionary of EnvStepRule keyword
            arguments) for stepping back to the previous environment
        env_start: Index of the starting environment in the sorted list
        quiet: If False, print a message whenever the environment changes
    """

    def __init__(self,
                 env_list: List[Environment],
                 env_step_rule: Union[Dict[str, Any], EnvStepRule],
                 env_back_rule: Union[Dict[str, Any], EnvStepRule, None] = None,
                 env_start: int = 0,
                 quiet: bool = True):
        self.env_num = env_start
        self.num_envs = len(env_list)
        # NOTE: misspelled name kept for backward compatibility; it is only
        # initialised here and not read anywhere in this class.
        self.performace = 0
        self.quiet = quiet
        self.env_list = sorted(env_list, key=lambda x: x.difficulty)
        self.env_step_rule = self._asRule(env_step_rule)
        if env_back_rule is None:
            self.env_back_rule = None
        else:
            self.env_back_rule = self._asRule(env_back_rule)

    @staticmethod
    def _asRule(rule):
        """Return `rule` unchanged if it is an EnvStepRule, else build one from its dict."""
        if isinstance(rule, EnvStepRule):
            return rule
        return EnvStepRule(**rule)

    def _resetRules(self):
        """Clear the streak counters so patience starts over after a rule fires."""
        self.env_step_rule.progress_counter = 0
        if self.env_back_rule is not None:
            self.env_back_rule.progress_counter = 0

    def registerPerformance(self, performance):
        """
        Report a performance value and move between environments if a rule fires

        The step rule is checked first; the back rule is only consulted when
        the step rule did not fire. After either rule fires, both rule
        counters are reset so the new environment has to satisfy `patience`
        from scratch.

        Args:
            performance: Latest performance measurement
        """
        if self.env_step_rule.checkRule(performance):
            # Do not count past "curriculum finished" (env_num == num_envs).
            if self.env_num < self.num_envs:
                if not self.quiet:
                    print("Stepping to next environment")
                self.env_num += 1
            self._resetRules()
        elif (self.env_back_rule is not None
              and self.env_back_rule.checkRule(performance)):
            # Never go below the first environment: a negative index would
            # wrap around to the end of env_list.
            if self.env_num > 0:
                if not self.quiet:
                    print("Stepping back to previous environment")
                self.env_num -= 1
            self._resetRules()

    def getEnv(self):
        """
        Return the current environment

        Returns:
            The environment at index `env_num` of the sorted list, or None
            once the index has moved past the last environment
        """
        if self.env_num < self.num_envs:
            return self.env_list[self.env_num]
        return None