generated from BattlesnakeOfficial/starter-snake-javascript
-
Notifications
You must be signed in to change notification settings - Fork 0
/
monteCarloTreeSearch.js
322 lines (276 loc) · 9.52 KB
/
monteCarloTreeSearch.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
import { generateNewState, isTerminal } from "./generateState.js";
import { copyState, getLegalMoves, purgeSnakes } from "./helpers.js";
import { evaluation } from "./evaluation.js";
// Weight of the exploration term in the UCB1 formula.
const EXPLORATION_CONSTANT = Math.sqrt(2);
// Number of turns a rollout may advance beyond the turn of the node it starts from.
const MAX_SIMULATIONS_DEPTH = 24; // 12 turns for each player
// Time budget of one monteCarloTreeSearch call, compared against Date.now() deltas.
const MAX_TIME = 360; // ms
// Benchmark values: reset at the start of every search and logged when it finishes.
// Deepest node turn that was simulated from during the current search.
let maxDepthReached = 0;
// Number of getLegalMoves calls and the accumulated time (ms) spent in them.
let getLegalMovesCounter = 0;
let getLegalMovesTime = 0;
// Number of generated states and the accumulated time (ms) spent generating them.
let generateStateCounter = 0;
let generateStateTime = 0;
/**
 * One node of the search tree.
 *
 * Holds the game state it represents, a link to its parent, its children,
 * the visit count and cumulative score maintained by backpropagation, the
 * turn counter, and whether `isTerminal` reports the state as terminal.
 */
class Node {
  /**
   * @param {object} state - Game state represented by this node.
   * @param {number} turn - Turn counter of this node.
   * @param {Node|null} [parent=null] - Parent node, or null when there is none.
   */
  constructor(state, turn, parent = null) {
    Object.assign(this, {
      state,
      parent,
      children: [],
      visits: 0,
      score: 0,
      turn,
      isTerminal: isTerminal(state),
    });
  }

  /**
   * Appends a child node to this node.
   * @param {Node} child - Node to append.
   */
  addChild(child) {
    this.children.push(child);
  }
}
// Algorithm contains 4 major steps: https://www.youtube.com/watch?v=UXW2yZndl7U
// 1. Selection: Start from the root and successively choose children until a leaf node L is reached
// Nodes are chosen based on their UCB1 value
// 2. Expansion: If the node L has not been visited we jump directly to the simulation.
// If it has been visited before we first expand it (if it has not been expanded)
// We then find the child which maximizes UCB1 (go back to 1 with L as root)
// 3. Simulation (roll-out/playout): From the selected node C,
// Perform a rollout where random moves are chosen until an exit criterion is met or the game ends
// Evaluate this final state
// 4. Backpropagation: Use the result of the playout to update information in the nodes on the path from C to R
// Each node has a score (cumulative) and a number of visits
/**
 * Runs Monte Carlo Tree Search from `state` until MAX_TIME milliseconds have
 * elapsed and returns whatever `bestChild` reports for the root.
 *
 * Resets the module-level benchmark values first and logs them when done.
 *
 * @param {object} state - Game state to search from (becomes the root, turn 0).
 * @returns {*} The result of `bestChild(root)`.
 */
export const monteCarloTreeSearch = (state) => {
  // Start every search with fresh benchmark values
  getLegalMovesTime = 0;
  getLegalMovesCounter = 0;
  generateStateTime = 0;
  generateStateCounter = 0;
  maxDepthReached = 0;

  const startedAt = Date.now();
  const rootNode = new Node(state, 0, null);
  let simulationCount = 0;

  for (; Date.now() - startedAt < MAX_TIME; simulationCount++) {
    // 1. Selection: walk down to a leaf
    let leaf = select(rootNode);

    // 2. Expansion: only leaves that were already visited and are not terminal
    const isVisitedLeaf = leaf.visits > 0 && leaf.children.length === 0;
    if (isVisitedLeaf && !leaf.isTerminal) {
      expand(leaf);
      leaf = select(leaf); // continue with one of the new children
    }

    // 3. Simulation, then 4. Backpropagation of its result
    backpropagate(leaf, simulate(leaf, 0));

    // Benchmark: remember the deepest turn simulated from
    if (maxDepthReached < leaf.turn) {
      maxDepthReached = leaf.turn;
    }
  }

  console.log(
    "Average simulation time: ",
    (Date.now() - startedAt) / simulationCount
  );
  console.log("Performed simulations: ", simulationCount);
  console.log("Max depth reached: ", maxDepthReached / 2);
  console.log("Time spent on MCTS: ", Date.now() - startedAt);
  console.log(
    "Time spent on generating states: ",
    generateStateTime,
    " Calls: ",
    generateStateCounter,
    " Average time: ",
    generateStateTime / generateStateCounter
  );
  console.log(
    "Time spent on finding legal moves: ",
    getLegalMovesTime,
    " Calls: ",
    getLegalMovesCounter,
    " Average time: ",
    getLegalMovesTime / getLegalMovesCounter
  );

  return bestChild(rootNode);
};
// %%%%%%%%%% Util functions %%%%%%%%%%
/**
 * UCB1 value of a node: mean score plus an exploration bonus.
 *
 *   score / visits + EXPLORATION_CONSTANT * sqrt(ln(parent.visits) / visits)
 *
 * The exploitation term is the MEAN score (cumulative score divided by the
 * number of visits), matching how `bestChild` ranks children. Using the raw
 * cumulative score would let it grow with every visit and swamp the
 * exploration bonus.
 *
 * @param {Node} node - Node to rate; must have a parent unless it is unvisited.
 * @returns {number} Infinity for unvisited nodes, otherwise the UCB1 value.
 */
const UCB1 = (node) => {
  // Nodes that have not been visited are preferred over all others
  if (node.visits === 0) {
    return Infinity;
  }
  const meanScore = node.score / node.visits;
  const explorationBonus =
    EXPLORATION_CONSTANT *
    Math.sqrt(Math.log(node.parent.visits) / node.visits);
  return meanScore + explorationBonus;
};
/**
 * Picks the child of `node` with the highest UCB1 value.
 *
 * The first child rated Infinity (unvisited) is returned immediately.
 * Ties keep the earlier child.
 *
 * @param {Node} node - Node whose children are compared.
 * @returns {Node|null} The chosen child, or null when no child has a value
 *   greater than -Infinity (for example when there are no children).
 */
const bestUCT = (node) => {
  let winner = null;
  let winnerValue = -Infinity;
  for (const candidate of node.children) {
    const value = UCB1(candidate);
    // Nothing can beat an unvisited child, so stop searching right away
    if (value === Infinity) {
      return candidate;
    }
    if (value > winnerValue) {
      winnerValue = value;
      winner = candidate;
    }
  }
  return winner;
};
/**
 * Translates the child of `node` with the best mean score into a move.
 *
 * Children are ranked by `score / visits`; unvisited children yield NaN and
 * are never selected. The move is derived by comparing the head of our first
 * snake in `node` with its head in the winning child.
 *
 * Returns the default move "up" when no child can be selected (no children,
 * or none visited) instead of dereferencing null, and when we have no snake
 * left in the winning child.
 *
 * @param {Node} node - Node whose children are compared (the search root).
 * @returns {"up"|"down"|"left"|"right"} Move leading to the best child.
 */
const bestChild = (node) => {
  let winner = null;
  let winnerScore = -Infinity;
  for (const child of node.children) {
    const score = child.score / child.visits;
    console.log("Child score: ", score);
    if (score > winnerScore) {
      winnerScore = score;
      winner = child;
    }
  }

  // Nothing to choose from, or our snake does not survive: use the default move
  if (winner === null || winner.state.ourSnakes.length === 0) {
    return "up";
  }

  const ourPos = node.state.ourSnakes[0].head;
  const ourNextPos = winner.state.ourSnakes[0].head;
  if (ourPos.x === ourNextPos.x) {
    return ourPos.y < ourNextPos.y ? "up" : "down";
  }
  return ourPos.x < ourNextPos.x ? "right" : "left";
};
/**
 * 1. Selection: descends from `node` to a leaf, always following the child
 * chosen by `bestUCT`.
 *
 * @param {Node} node - Node to start the descent from.
 * @returns {Node} The first node on the path that has no children.
 */
const select = (node) => {
  let current = node;
  // Keep stepping down while the current node has been expanded
  while (current.children.length !== 0) {
    current = bestUCT(current);
  }
  return current;
};
/**
 * 2. Expansion: creates one child of `node` for every combination of legal
 * moves of the team whose turn it is.
 *
 * Even turns move `state.ourSnakes`, odd turns move `state.enemySnakes`.
 * A team of one snake yields one child per legal move. Otherwise the second
 * snake's legal moves are determined on each state produced by the first
 * snake's move, so they account for that move (as the rollout in `simulate`
 * does). After an odd turn `purgeSnakes` is applied to every new state.
 * Updates the legal-move and state-generation benchmark values.
 *
 * @param {Node} node - Node to expand; its `children` array is replaced.
 */
const expand = (node) => {
  // Select the right team
  const team =
    node.turn % 2 === 0 ? node.state.ourSnakes : node.state.enemySnakes;

  // Names of the moves flagged as legal for one snake in `state` (benchmarked)
  const legalMovesFor = (state, snakeId) => {
    const startedAt = Date.now();
    const movesObj = getLegalMoves(state, snakeId, node.turn);
    getLegalMovesCounter++;
    getLegalMovesTime += Date.now() - startedAt;
    return Object.keys(movesObj).filter((key) => movesObj[key]);
  };

  // One successor state per legal move of one snake in `state` (benchmarked)
  const successorsOf = (state, snakeId) => {
    const moves = legalMovesFor(state, snakeId);
    const startedAt = Date.now();
    const successors = moves.map((move) =>
      generateNewState(state, snakeId, move, node.turn)
    );
    generateStateCounter += successors.length;
    generateStateTime += Date.now() - startedAt;
    return successors;
  };

  // States after the first snake of the team has moved
  let states = successorsOf(node.state, team[0].id);

  // States after the second snake has moved as well, based on each
  // intermediate state rather than on the state before the first move
  if (team.length !== 1) {
    states = states.flatMap((state) => successorsOf(state, team[1].id));
  }

  // Purge colliding snakes once the odd turn has been played
  if (node.turn % 2 === 1) {
    states.forEach((state) => purgeSnakes(state));
  }

  // Create the nodes and attach them to the parent
  node.children = states.map((state) => new Node(state, node.turn + 1, node));
};
/**
 * Lets one snake play a uniformly random legal move and returns the state
 * produced by `generateNewState`, updating the benchmark values.
 *
 * NOTE(review): when no move is flagged legal, `move` is undefined and is
 * passed on as-is; confirm that generateNewState handles that case.
 */
const playRandomMove = (state, snakeId, turn) => {
  const legalMovesStart = Date.now();
  const movesObj = getLegalMoves(state, snakeId, turn);
  getLegalMovesCounter++;
  getLegalMovesTime += Date.now() - legalMovesStart;
  const moves = Object.keys(movesObj).filter((key) => movesObj[key]);
  const move = moves[Math.floor(Math.random() * moves.length)];

  const generateStateStart = Date.now();
  const nextState = generateNewState(state, snakeId, move, turn);
  generateStateCounter++;
  generateStateTime += Date.now() - generateStateStart;
  return nextState;
};

/**
 * 3. Simulation: plays random moves from a copy of `node.state` until a
 * terminal state is reached or MAX_SIMULATIONS_DEPTH turns (plus one when
 * starting on an odd turn) have been played, then evaluates the final state.
 *
 * Even turns move `state.ourSnakes`, odd turns move `state.enemySnakes`; up
 * to two snakes of the moving team are moved per turn. As in `expand`,
 * `purgeSnakes` is applied to the state produced by every odd turn of the
 * rollout, decided by the turn being played and not by the turn of the node
 * the rollout started from.
 *
 * @param {Node} node - Node to start the rollout from; it is not modified.
 * @param {number} depth - Unused; kept for existing callers.
 * @param {number} [startTime] - Unused; kept for existing callers.
 * @returns {*} Result of `evaluation` for the final state of the rollout.
 */
const simulate = (node, depth, startTime) => {
  // Simulate until the stop criterion is reached
  const lastTurn = node.turn + MAX_SIMULATIONS_DEPTH + (node.turn % 2);
  let tempNode = new Node(copyState(node.state), node.turn);

  while (tempNode.turn < lastTurn && !tempNode.isTerminal) {
    const { turn } = tempNode;
    const team =
      turn % 2 === 0 ? tempNode.state.ourSnakes : tempNode.state.enemySnakes;

    // First snake of the team moves
    let state = playRandomMove(tempNode.state, team[0].id, turn);

    // Second snake moves on the state the first move produced
    if (team.length !== 1) {
      state = playRandomMove(state, team[1].id, turn);
    }

    // Purge colliding snakes once the odd turn has been played
    if (turn % 2 === 1) {
      purgeSnakes(state);
    }

    tempNode = new Node(state, turn + 1);
  }

  return evaluation(
    tempNode.state.ourSnakes ?? [],
    tempNode.state.enemySnakes ?? []
  );
};
/**
 * 4. Backpropagation: walks from `node` up to the root, counting one more
 * visit on every node on the path and adding the rollout result to its
 * cumulative score.
 *
 * Nodes on an even turn receive the negated result, nodes on an odd turn the
 * result itself, so that the score of the first player is maximized.
 *
 * @param {Node|null} node - Node the rollout started from.
 * @param {number} result - Evaluation returned by the rollout.
 */
const backpropagate = (node, result) => {
  for (let current = node; current !== null; current = current.parent) {
    current.visits += 1;
    const isEvenTurn = current.turn % 2 === 0;
    current.score += isEvenTurn ? -result : result;
  }
};