-
Notifications
You must be signed in to change notification settings - Fork 0
/
deep_q_network.py
42 lines (33 loc) · 1.21 KB
/
deep_q_network.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import os
import time
import numpy as np
import tensorflow as tf
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import *
from tensorflow.keras.activations import *
from tensorflow.keras.models import *
from tensorflow.keras.optimizers import *
from tensorflow.keras.initializers import *
from tensorflow.keras.callbacks import *
class DQN(Model):
    """Small fully-connected Q-network wrapped around a compiled Keras model.

    The network is built with the functional API and stored in ``self.model``:
    ``Input(state_shape) -> Dense(20) -> ReLU -> Dense(20) -> ReLU ->
    Dense(num_actions)``. The output layer has no activation, so it emits one
    unbounded value per action. The model is compiled with mean-squared-error
    loss and the Adam optimizer.

    Args:
        state_shape: Shape passed to ``Input(shape=...)``.
            NOTE(review): presumably the shape of a single state without the
            batch dimension -- confirm against callers.
        num_actions: Number of units in the output layer.
        lr: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, state_shape, num_actions, lr):
        super().__init__()
        self.state_shape = state_shape
        self.num_actions = num_actions
        self.lr = lr

        input_state = Input(shape=state_shape)
        x = Dense(20)(input_state)
        x = Activation("relu")(x)
        x = Dense(20)(x)
        x = Activation("relu")(x)
        output_pred = Dense(self.num_actions)(x)

        self.model = Model(inputs=input_state, outputs=output_pred)
        # Use `learning_rate`, not the deprecated `lr` alias: newer Keras
        # releases no longer accept `lr`, while `learning_rate` works on
        # every TF 2.x / Keras version.
        self.model.compile(
            loss="mse",
            optimizer=Adam(learning_rate=self.lr),
        )

    def train(self, states, q_values):
        """Fit the wrapped model on ``states`` -> ``q_values`` silently.

        Calls ``fit`` with Keras defaults apart from ``verbose=0``.
        NOTE(review): ``q_values`` presumably has one column per action --
        confirm against callers.
        """
        self.model.fit(states, q_values, verbose=0)

    def predict(self, state):
        """Return the wrapped model's ``predict`` output for ``state``."""
        return self.model.predict(state)

    def load_model(self, path):
        """Load the wrapped model's weights from ``path``."""
        self.model.load_weights(path)

    def save_model(self, path):
        """Save the wrapped model's weights to ``path``."""
        self.model.save_weights(path)