-
Notifications
You must be signed in to change notification settings - Fork 6
/
cartpole_param_inference_rnn.py
39 lines (34 loc) · 1.51 KB
/
cartpole_param_inference_rnn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#%% Imports
from src.data.cartpole_data_generator import CartPoleDataGenerator
from src.models.rnn import RNN_feat
from src.utils.param_inference import *
import os
# Every path below is resolved against the directory the script is launched
# from, so the script is expected to be started from the repository root.
cur_root_dir = os.getcwd()
print("Current Directory: {}".format(cur_root_dir))

#%% Load Policy
# Location of the controller file handed to the data generator.
policy_file = os.path.join(
    cur_root_dir,
    "src",
    "models",
    "controllers",
    "PPO",
    "CartPole-v1.pkl",
)
g = CartPoleDataGenerator(policy_file=policy_file)

#%% Load data shape
# Call the generator once with an argument of 1 purely to inspect the shapes
# of the two arrays it returns.
params, stats = g.gen(1)
shapes = {
    "params": params.shape[1],  # size of axis 1 of the parameter array
    "data": stats.shape[1],  # size of axis 1 of the statistics array
}
print("Total data size: {}".format(params.shape[0]))

#%% Train model
# NOTE(review): this cell only constructs RNN_feat around the generator;
# whether construction itself triggers training is not visible here - confirm.
rnn = RNN_feat(simulator=g)
# log_mdn, inf_mdn = train(epochs=1000, batch_size=100, params_dim=2, stats_dim=12, num_sampled_points=1000,
# generator=g, model="MDRFF", n_components=10)
#
# #%% Plot Results for mass and length specific params
# true_obs = np.array([[1.0, 1.0]])
#
# get_results_from_true_obs(env_params=["length", "masspole"], true_obs=true_obs, generator=g, inf=inf_mdn, shapes=shapes,
# p_lower=[0.1, 0.1], p_upper=[2.0, 2.0])
#
# true_obs = np.array([[0.7, 1.3]])
#
# get_results_from_true_obs(env_params=["length", "masspole"], true_obs=true_obs, generator=g, inf=inf_mdn, shapes=shapes,
# p_lower=[0.1, 0.1], p_upper=[2.0, 2.0])
#
# true_obs = np.array([[1.75, 1.0]])
#
# get_results_from_true_obs(env_params=["length", "masspole"], true_obs=true_obs, generator=g, inf=inf_mdn, shapes=shapes,
# p_lower=[0.1, 0.1], p_upper=[2.0, 2.0])