forked from ikostrikov/pytorch-a2c-ppo-acktr-gail
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model_tor.py
59 lines (48 loc) · 2.41 KB
/
model_tor.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.nn as nn
from distributions_tor import GaussianDistributionNetwork
from utils_tor import init_param_openaibaselines
class ActorCriticNetwork(nn.Module):
    """Actor-critic model with an action-distribution head and a value head.

    Observations are passed through two hidden networks held by
    ``self.hidden_net`` (``actor_hidden_net`` and ``critic_hidden_net``).
    The actor features go to ``self.actor_output_net``, which must return a
    distribution object exposing ``sample``, ``log_prob`` and ``entropy``
    (presumably a Gaussian, judging by ``GaussianDistributionNetwork`` --
    confirm in ``distributions_tor``). The critic features go to a linear
    layer that produces the state value.
    """

    def __init__(self, input_dim, hidden_dim, actor_output_dim, critic_output_dim):
        """Create the hidden networks, the actor head and the critic head.

        Args:
            input_dim: Size of the observation vector fed to both hidden nets.
            hidden_dim: Width of the hidden features consumed by both heads.
            actor_output_dim: Output size handed to the actor head.
            critic_output_dim: Output size of the linear value head.
        """
        super().__init__()
        # Submodules are created in this exact order so that parameter
        # ordering (and any RNG use during initialisation) stays stable.
        self.hidden_net = ActorCriticHiddenNetwork(input_dim, hidden_dim)
        self.actor_output_net = GaussianDistributionNetwork(
            hidden_dim, actor_output_dim
        )
        value_head = nn.Linear(hidden_dim, critic_output_dim)
        self.critic_output_net = init_param_openaibaselines(value_head)

    @staticmethod
    def _joint_log_prob(distrib, action):
        """Sum per-dimension log-probabilities over the last axis (kept)."""
        return distrib.log_prob(action).sum(dim=-1, keepdim=True)

    def act(self, observ):
        """Sample an action for ``observ``.

        Returns:
            Tuple ``(action, action_log_prob, state_value)`` where the
            log-probability is summed over the last axis with that axis kept.
        """
        value, actor_features = self._forward(observ)
        distrib = self.actor_output_net(actor_features)
        sampled = distrib.sample()
        return sampled, self._joint_log_prob(distrib, sampled), value

    def evaluate_actions(self, observ, action):
        """Score the given ``action`` under the current policy for ``observ``.

        Returns:
            Tuple ``(action_log_prob, entropy, state_value)``. The entropy is
            summed over the last axis and then averaged over everything that
            remains, so it is a single scalar tensor.
        """
        value, actor_features = self._forward(observ)
        distrib = self.actor_output_net(actor_features)
        log_prob = self._joint_log_prob(distrib, action)
        per_sample_entropy = distrib.entropy().sum(dim=-1)
        return log_prob, per_sample_entropy.mean(), value

    def predict_state_value(self, observ):
        """Return the critic's value estimate for ``observ``."""
        return self.critic_output_net(self.hidden_net.critic_hidden_net(observ))

    def forward(self, inputs, states, masks):
        """Not supported; call ``act`` / ``evaluate_actions`` instead."""
        raise NotImplementedError

    def _forward(self, observ):
        """Run both branches and return ``(state_value, actor_features)``."""
        # The actor branch is evaluated before the critic branch.
        actor_features = self.hidden_net.actor_hidden_net(observ)
        return self.predict_state_value(observ), actor_features
class ActorCriticHiddenNetwork(nn.Module):
    """Container for two separate hidden networks: one actor, one critic.

    Each network is ``Linear -> Tanh -> Linear -> Tanh``, mapping
    ``input_dim`` to ``hidden_dim``. No weights are shared between the two.
    Every linear layer is passed through ``init_param_openaibaselines``.
    Callers use ``actor_hidden_net`` and ``critic_hidden_net`` directly;
    this module's own ``forward`` is intentionally unimplemented.
    """

    def __init__(self, input_dim, hidden_dim):
        """Build the actor network first, then the critic network.

        Args:
            input_dim: Size of the input features for both networks.
            hidden_dim: Width of both linear layers' outputs.
        """
        super().__init__()
        self.actor_hidden_net = self._make_tanh_mlp(input_dim, hidden_dim)
        self.critic_hidden_net = self._make_tanh_mlp(input_dim, hidden_dim)

    @staticmethod
    def _make_tanh_mlp(input_dim, hidden_dim):
        """Return a two-layer tanh MLP with initialised linear layers."""
        # Layers are created and initialised input-side first, matching the
        # order in which they appear inside the Sequential.
        first = init_param_openaibaselines(nn.Linear(input_dim, hidden_dim))
        second = init_param_openaibaselines(nn.Linear(hidden_dim, hidden_dim))
        return nn.Sequential(first, nn.Tanh(), second, nn.Tanh())

    def forward(self, observ):
        """Not supported; use ``actor_hidden_net`` / ``critic_hidden_net``."""
        raise NotImplementedError