-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9d918ab
commit dabfe2a
Showing
8 changed files
with
472 additions
and
18 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,81 @@ | ||
#%% | ||
import os | ||
import json | ||
import torch | ||
import gymnasium as gym | ||
from dcrl_env_harl_partialobs import DCRL | ||
|
||
# Import HAPPO from here: /lustre/guillant/HARL/harl/algorithms/actors/happo.py | ||
# But I am working on /lustre/guillant/dc-rl | ||
|
||
import sys | ||
sys.path.append('/lustre/guillant/HARL/harl') | ||
|
||
import harl | ||
from algorithms.actors.happo import HAPPO | ||
|
||
|
||
#%% | ||
# Checkpoint and config path: | ||
checkpoint_path = os.path.join('/lustre/guillant/HARL/results/dcrl', 'CA/happo/ls_dc_bat/seed-00001-2024-05-28-23-23-34') | ||
config_path = os.path.join(checkpoint_path, 'config.json') | ||
|
||
# Read config_path | ||
# Read the config file | ||
with open(config_path, 'r') as f: | ||
config = json.load(f) | ||
|
||
env_config = config['env_args'] | ||
# Create the dcrl environment | ||
env = DCRL(env_config) | ||
|
||
# Obtain from env_config how many agents is active: ''agents': ['agent_ls', 'agent_dc', 'agent_bat']' | ||
# Obtain the number of active agents | ||
num_agents = len(env_config['agents']) | ||
agents = env_config['agents'] | ||
actors = {} | ||
for agent_id, agent in enumerate(env_config['agents']): | ||
checkpoint = torch.load(checkpoint_path + "/models/actor_agent" + str(agent_id) + ".pt") | ||
|
||
# load_state_dict from checkpoint | ||
model_args = config['algo_args']['model'] | ||
algo_args = config['algo_args']['algo'] | ||
agent = HAPPO({**algo_args["model"], **algo_args["algo"]}, | ||
self.envs.observation_space[agent_id], | ||
self.envs.action_space[agent_id], | ||
device=self.device, | ||
) | ||
actors[agent].load_state_dict(checkpoint['model_state_dict']) | ||
actors[agent].eval() | ||
|
||
#%% | ||
# Reset the environment | ||
obs = env.reset() | ||
done = False | ||
total_reward = 0 | ||
|
||
while not done: | ||
# Get the actions for each actor | ||
actions = {} | ||
for agent in agents: | ||
actor = actors[agent] | ||
action = actor(torch.tensor(obs[agent]).float().unsqueeze(0)).detach().numpy() | ||
actions[env_config['agents'][agent_id]] = action | ||
|
||
obs[agent_id], reward, done, _ = env.step(action[0]) | ||
action = model.predict(obs) | ||
|
||
# Take a step in the environment | ||
obs, reward, done, _ = env.step(action) | ||
|
||
# Accumulate the reward | ||
total_reward += reward | ||
|
||
# Print the total reward | ||
print("Total reward:", total_reward) | ||
|
||
# Path to the trained model | ||
model_path = '/path/to/trained/model.h5' | ||
|
||
# Evaluate the model | ||
evaluate_model(model_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,54 @@ | ||
#%% | ||
import os | ||
import pandas as pd | ||
|
||
import matplotlib.pyplot as plt | ||
|
||
#%% | ||
folder_path = '/lustre/guillant/dc-rl/data/CarbonIntensity' | ||
|
||
# Get all CSV files in the folder | ||
csv_files = [file for file in os.listdir(folder_path) if file.endswith('.csv')] | ||
|
||
# Read each CSV file and obtain the avg_CI column and save it along with the location name in a dictionary | ||
values = {} | ||
|
||
for file in csv_files: | ||
file_path = os.path.join(folder_path, file) | ||
df = pd.read_csv(file_path) | ||
values[file[:2]] = df['avg_CI'] | ||
|
||
#%% | ||
# Now plot the values and the legend should be the key of the values dictionary | ||
# I want to plot only the month 7. | ||
# Knowing that the values on the csv are 1 hour apart, I can get the index of the first day of the month and the last day of the month | ||
# and then plot only the values between those indexes | ||
# I can use the index to get the values between those indexes | ||
import numpy as np | ||
selected_month = 7 | ||
init_index = selected_month * 30 * 24 | ||
end_index = (selected_month + 1) * 30 * 24 | ||
|
||
x_range = np.arange(init_index, end_index)/24 | ||
plt.figure(figsize=(10, 5)) | ||
for key, value in values.items(): | ||
if key in ['IL', 'TX', 'NY', 'VA', 'GA', 'WA', 'AZ', 'CA']: | ||
# plt.plot(value[init_index:end_index]**3/200000, label=key, linestyle='-', linewidth=2, alpha=0.9) | ||
plt.plot(x_range, value[init_index:end_index], label=key, linestyle='-', linewidth=2, alpha=0.9) | ||
|
||
plt.ylabel('Carbon Intensity (gCO2/kWh)', fontsize=16) | ||
plt.xlabel('Day', fontsize=16) | ||
# plt.xlim(init_index/24, end_index/24) | ||
plt.title('Average Daily Carbon Intensity in Different Locations in July', fontsize=18) | ||
plt.grid('on', linestyle='--', alpha=0.5) | ||
|
||
plt.tick_params(axis='x', labelsize=12, rotation=45) # Set the font size of xticks | ||
plt.tick_params(axis='y', labelsize=12) # Set the font size of yticks | ||
plt.legend(fontsize=11.5, ncols=8) | ||
plt.xlim(210, 240) | ||
plt.ylim(-1) | ||
|
||
plt.savefig('plots/GreenDCC_ci_all_locations.pdf', bbox_inches='tight') | ||
plt.show() | ||
|
||
#%% |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.