-
Notifications
You must be signed in to change notification settings - Fork 48
/
Copy pathga_repertoire.py
186 lines (142 loc) · 5.67 KB
/
ga_repertoire.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
"""Defines a repertoire for simple genetic algorithms."""
from __future__ import annotations
from functools import partial
from typing import Callable, Tuple
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree
from qdax.core.containers.repertoire import Repertoire
from qdax.types import Fitness, Genotype, RNGKey
class GARepertoire(Repertoire):
    """Class for a simple repertoire for a simple genetic algorithm.

    Args:
        genotypes: a PyTree containing the genotypes of the individuals in
            the population. Each leaf has the shape
            (population_size, num_features).
        fitnesses: an array containing the fitness of the individuals in the
            population, with shape (population_size, fitness_dim).

    The implementation of GARepertoire was designed for the case where
    fitness_dim equals 1, but the class can be inherited and its rules
    adapted for cases where fitness_dim is greater than 1.
    """

    # PyTree of genotypes; every leaf has leading dimension population_size.
    genotypes: Genotype
    # Fitnesses, shape (population_size, fitness_dim); -inf marks empty slots.
    fitnesses: Fitness

    @property
    def size(self) -> int:
        """Gives the size of the population.

        This is the leading dimension of the first leaf of the genotypes.
        """
        first_leaf = jax.tree_util.tree_leaves(self.genotypes)[0]
        return int(first_leaf.shape[0])

    def save(self, path: str = "./") -> None:
        """Saves the repertoire.

        Two files are written: ``genotypes.npy`` (one flattened genotype per
        row) and ``scores.npy`` (the fitnesses). ``GARepertoire.load`` reads
        them back.

        Args:
            path: prefix prepended to the file names. It is concatenated as
                a plain string, so it should end with a path separator.
                Defaults to "./".
        """

        def flatten_genotype(genotype: Genotype) -> jnp.ndarray:
            flat_genotype, _ = ravel_pytree(genotype)
            return flat_genotype

        # flatten all the genotypes: one row per individual
        flat_genotypes = jax.vmap(flatten_genotype)(self.genotypes)
        jnp.save(path + "genotypes.npy", flat_genotypes)
        # `load` relies on this exact file name.
        jnp.save(path + "scores.npy", self.fitnesses)

    @classmethod
    def load(cls, reconstruction_fn: Callable, path: str = "./") -> GARepertoire:
        """Loads a GA Repertoire previously written by `save`.

        Args:
            reconstruction_fn: Function to reconstruct a PyTree from a flat
                array (the inverse of the flattening done in `save`).
            path: prefix of the saved files, concatenated as a plain string.
                Defaults to "./".

        Returns:
            A GA Repertoire.
        """
        flat_genotypes = jnp.load(path + "genotypes.npy")
        genotypes = jax.vmap(reconstruction_fn)(flat_genotypes)

        # `save` stores the fitnesses in "scores.npy". Fall back to
        # "fitnesses.npy" (the name this method used to expect) so files
        # produced under that name can still be loaded.
        try:
            fitnesses = jnp.load(path + "scores.npy")
        except FileNotFoundError:
            fitnesses = jnp.load(path + "fitnesses.npy")

        return cls(
            genotypes=genotypes,
            fitnesses=fitnesses,
        )

    @partial(jax.jit, static_argnames=("num_samples",))
    def sample(
        self, random_key: RNGKey, num_samples: int
    ) -> Tuple[Genotype, RNGKey]:
        """Sample genotypes from the repertoire, without replacement.

        Individuals whose fitness is -inf in every dimension (the placeholders
        created by `init`) get probability zero. All other individuals are
        equally likely to be drawn.

        Args:
            random_key: a random key to handle stochasticity.
            num_samples: the number of genotypes to sample (static under jit).

        Returns:
            The sampled genotypes (each leaf with leading dimension
            num_samples) and an updated random key.
        """
        # uniform sampling probability over the valid individuals
        valid = jnp.any(self.fitnesses != -jnp.inf, axis=-1)
        # NOTE: if no individual is valid, this divides by zero (NaN p).
        p = valid / jnp.sum(valid)

        # Draw the indices once and reuse them for every leaf, so all the
        # parts of a sampled genotype come from the same individual.
        random_key, subkey = jax.random.split(random_key)
        indices = jax.random.choice(
            subkey, self.size, shape=(num_samples,), p=p, replace=False
        )
        samples = jax.tree_util.tree_map(lambda x: x[indices], self.genotypes)

        return samples, random_key

    @jax.jit
    def add(
        self, batch_of_genotypes: Genotype, batch_of_fitnesses: Fitness
    ) -> GARepertoire:
        """Implements the repertoire addition rules.

        Parents and offspring are gathered and only the population_size best
        are kept, ranked by the sum of their fitness dimensions. The others
        are killed.

        Args:
            batch_of_genotypes: new genotypes that we try to add.
            batch_of_fitnesses: fitness of those new genotypes.

        Returns:
            The updated repertoire.
        """
        # gather current individuals and the new batch into one pool
        candidates = jax.tree_util.tree_map(
            lambda x, y: jnp.concatenate((x, y), axis=0),
            self.genotypes,
            batch_of_genotypes,
        )
        candidates_fitnesses = jnp.concatenate(
            (self.fitnesses, batch_of_fitnesses), axis=0
        )

        # rank by summed fitness, best first (argsort is ascending)
        indices = jnp.argsort(jnp.sum(candidates_fitnesses, axis=1))[::-1]

        # keep only the population_size best candidates
        survivor_indices = indices[: self.size]
        new_candidates = jax.tree_util.tree_map(
            lambda x: x[survivor_indices], candidates
        )

        new_repertoire = self.replace(
            genotypes=new_candidates,
            fitnesses=candidates_fitnesses[survivor_indices],
        )
        return new_repertoire  # type: ignore

    @classmethod
    def init(  # type: ignore
        cls,
        genotypes: Genotype,
        fitnesses: Fitness,
        population_size: int,
    ) -> GARepertoire:
        """Initializes the repertoire.

        Starts with placeholder values (zero genotypes, -inf fitnesses) and
        adds a first batch of genotypes to the repertoire.

        Args:
            genotypes: first batch of genotypes.
            fitnesses: corresponding fitnesses, shape (batch, fitness_dim).
            population_size: size of the population we want to evolve.

        Returns:
            An initial repertoire.
        """
        # placeholder fitnesses: -inf so any real individual outranks them
        default_fitnesses = -jnp.inf * jnp.ones(
            shape=(population_size, fitnesses.shape[-1])
        )

        # Placeholder genotypes keep each leaf's dtype. Otherwise float32
        # zeros would promote integer/boolean/low-precision genotypes to
        # float32 when `add` concatenates them.
        default_genotypes = jax.tree_util.tree_map(
            lambda x: jnp.zeros(
                shape=(population_size,) + x.shape[1:], dtype=x.dtype
            ),
            genotypes,
        )

        # create an initial repertoire with those default values
        repertoire = cls(genotypes=default_genotypes, fitnesses=default_fitnesses)
        new_repertoire = repertoire.add(genotypes, fitnesses)

        return new_repertoire  # type: ignore