Skip to content

Commit

Permalink
wokring the stats
Browse files Browse the repository at this point in the history
  • Loading branch information
ggsavin committed Feb 3, 2024
1 parent 73819fb commit 16717c3
Showing 1 changed file with 251 additions and 22 deletions.
273 changes: 251 additions & 22 deletions notebooks/oracle.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 5,
"metadata": {},
"outputs": [
{
Expand All @@ -53,7 +53,7 @@
"<All keys matched successfully>"
]
},
"execution_count": 3,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
Expand All @@ -76,28 +76,117 @@
},
{
"cell_type": "code",
"execution_count": 19,
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"@torch.no_grad()\n",
"def play_recurrent_game_with_oracle(env, wolf_policy, villager_agent, num_times=10, hidden_state_size=None, voting_type=None, static_villager_policy=\"oracle\", vtc_other_villagers=0, p=0.5):\n",
"\n",
" wins = 0\n",
" game_replays = []\n",
" oracle_ids = []\n",
"\n",
" for _ in range(num_times):\n",
" next_observations, _, _, _, _ = env.reset()\n",
" # init recurrent stuff for actor and critic to 0 as well\n",
" static_villager = set([random.choice(list(set(env.agents) & set(env.world_state[\"villagers\"])))])\n",
" oracle_ids.append(list(static_villager)[-1])\n",
"\n",
" if static_villager_policy==\"oracle\":\n",
" env.set_oracle(int(list(static_villager)[0].split(\"_\")[-1]))\n",
"\n",
" magent_obs = {agent: {'obs': [], \n",
" # obs size, and 1,1,64 as we pass batch first\n",
" 'hcxs': [(torch.zeros((1,1,hidden_state_size), dtype=torch.float32), \n",
" torch.zeros((1,1,hidden_state_size), dtype=torch.float32))],\n",
" } for agent in env.agents if not env.agent_roles[agent]}\n",
"\n",
" wolf_action = None\n",
"\n",
" while env.agents:\n",
" observations = copy.deepcopy(next_observations)\n",
" actions = {}\n",
"\n",
" villagers = set(env.agents) & set(env.world_state[\"villagers\"]) - static_villager\n",
" wolves = set(env.agents) & set(env.world_state[\"werewolves\"])\n",
"\n",
" ## VILLAGER LOGIC ##\n",
" v_obs = torch.cat([torch.unsqueeze(torch.tensor(env.convert_obs(observations[villager]['observation']), dtype=torch.float), 0) for villager in villagers])\n",
"\n",
" # TODO: maybe this can be sped up? \n",
" hxs, cxs = zip(*[(hxs, cxs) for hxs, cxs in [magent_obs[villager][\"hcxs\"][-1] for villager in villagers]])\n",
" hxs = torch.swapaxes(torch.cat(hxs),0,1)\n",
" cxs = torch.swapaxes(torch.cat(cxs),0,1)\n",
"\n",
" policies, _ , cells = villager_agent(v_obs, (hxs, cxs))\n",
" v_actions = torch.stack([p.sample() for p in policies], dim=1)\n",
"\n",
" hxs_new, cxs_new = cells\n",
" hxs_new = torch.swapaxes(hxs_new,1,0)\n",
" cxs_new = torch.swapaxes(cxs_new,1,0)\n",
"\n",
" for i, villager in enumerate(villagers):\n",
" if voting_type == \"plurality\":\n",
" actions[villager] = v_actions[i].item()\n",
" elif voting_type == \"approval\":\n",
" actions[villager] = (v_actions[i] - 1).tolist()\n",
" magent_obs[villager]['hcxs'].append((torch.unsqueeze(hxs_new[i], 0), torch.unsqueeze(cxs_new[i], 0)))\n",
"\n",
" # if oracle is still alive\n",
" if set(env.agents) & static_villager:\n",
" static_villager_id = list(static_villager)[0]\n",
"\n",
" # get mode of the villager votes\n",
" # max(lst, key=lst.count)\n",
"\n",
" if static_villager_policy == \"oracle\":\n",
"\n",
" # vtc_other_villagers : what should the oracle do towards other villagers.\n",
" actions[static_villager_id] = [(-1 if random.random() < p else 0) if player in env.world_state['werewolves'] else vtc_other_villagers for player in env.possible_agents]\n",
" #actions[static_villager_id] = [-1 if player in env.world_state['werewolves'] else 0 for player in env.possible_agents]\n",
" else: \n",
" actions[static_villager_id] = random_agent(env, static_villager_id, action=None)\n",
"\n",
" ## WOLF LOGIC ## \n",
" phase = env.world_state['phase']\n",
" for wolf in wolves:\n",
" wolf_action = wolf_policy(env, wolf, action=wolf_action)\n",
" actions[wolf] = wolf_action\n",
"\n",
" next_observations, _, _, _, _ = env.step(actions)\n",
"\n",
" ## UPDATED WOLF VARIABLE FOR WOLVES THAT COORDINATE ## \n",
" if env.world_state['phase'] == Phase.NIGHT:\n",
" wolf_action = None\n",
" \n",
" if env.world_state['phase'] == Phase.ACCUSATION and phase == Phase.NIGHT:\n",
" wolf_action = None\n",
" \n",
" ## Fill bigger buffer, keeping in mind sequence\n",
" winner = env.world_state['winners']\n",
" if winner == Roles.VILLAGER:\n",
" wins += 1\n",
"\n",
" game_replays.append(copy.deepcopy(env.history))\n",
" \n",
" return wins, game_replays, oracle_ids"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Without Oracle wins : 0.58\n",
"Oracle wins with 1 for villagers and a chance of voting for werewolves at 0.0: 0.45\n",
"Oracle wins with 1 for villagers and a chance of voting for werewolves at 0.25: 0.47\n",
"Oracle wins with 1 for villagers and a chance of voting for werewolves at 0.5: 0.41\n",
"Oracle wins with 1 for villagers and a chance of voting for werewolves at 0.75: 0.46\n",
"Oracle wins with 1 for villagers and a chance of voting for werewolves at 1.0: 0.43\n",
"\n",
"\n",
"Oracle wins with 0 for villagers and a chance of voting for werewolves at 0.0: 0.55\n",
"Oracle wins with 0 for villagers and a chance of voting for werewolves at 0.25: 0.54\n",
"Oracle wins with 0 for villagers and a chance of voting for werewolves at 0.5: 0.56\n",
"Oracle wins with 0 for villagers and a chance of voting for werewolves at 0.75: 0.52\n",
"Oracle wins with 0 for villagers and a chance of voting for werewolves at 1.0: 0.52\n",
"\n",
"\n"
"ename": "NameError",
"evalue": "name 'torch' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[1], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;129m@torch\u001b[39m\u001b[38;5;241m.\u001b[39mno_grad()\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mplay_recurrent_game_with_oracle\u001b[39m(env, wolf_policy, villager_agent, num_times\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m10\u001b[39m, hidden_state_size\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, voting_type\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mNone\u001b[39;00m, static_villager_policy\u001b[38;5;241m=\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124moracle\u001b[39m\u001b[38;5;124m\"\u001b[39m, vtc_other_villagers\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0\u001b[39m, p\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m0.5\u001b[39m):\n\u001b[1;32m 4\u001b[0m wins \u001b[38;5;241m=\u001b[39m \u001b[38;5;241m0\u001b[39m\n\u001b[1;32m 5\u001b[0m game_replays \u001b[38;5;241m=\u001b[39m []\n",
"\u001b[0;31mNameError\u001b[0m: name 'torch' is not defined"
]
}
],
Expand All @@ -107,11 +196,14 @@
"\n",
" wins = 0\n",
" game_replays = []\n",
" oracle_ids = []\n",
"\n",
" for _ in range(num_times):\n",
" next_observations, _, _, _, _ = env.reset()\n",
" # init recurrent stuff for actor and critic to 0 as well\n",
" static_villager = set([random.choice(list(set(env.agents) & set(env.world_state[\"villagers\"])))])\n",
" oracle_ids.append(list(static_villager)[-1])\n",
"\n",
" if static_villager_policy==\"oracle\":\n",
" env.set_oracle(int(list(static_villager)[0].split(\"_\")[-1]))\n",
"\n",
Expand Down Expand Up @@ -189,7 +281,7 @@
"\n",
" game_replays.append(copy.deepcopy(env.history))\n",
" \n",
" return wins, game_replays\n",
" return wins, game_replays, oracle_ids\n",
"\n",
"num_times = 1000\n",
"wins, replays = play_recurrent_game(env, random_approval_wolf, trained_approval_agent, num_times=num_times, hidden_state_size=256, voting_type=\"approval\")\n",
Expand Down Expand Up @@ -223,6 +315,143 @@
"# wins, replays = play_recurrent_game_with_oracle(env, random_approval_wolf, trained_approval_agent, num_times=num_times, hidden_state_size=256, voting_type=\"approval\", static_villager_policy=\"random\")\n",
"# print(f'With other random villager wins : {wins/float(num_times):.2f}')\n"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"# likelihood that oracle gets killed round 1, 2, 3, 4\n",
"# villager gets killed round 1, 2, 3, 4\n",
"\n",
"wins, replays, o_ids = play_recurrent_game_with_oracle(env, \n",
" random_approval_wolf, \n",
" trained_approval_agent, \n",
" num_times=100, \n",
" hidden_state_size=256, \n",
" voting_type=\"approval\", \n",
" static_villager_policy=\"oracle\",\n",
" vtc_other_villagers=1,\n",
" p=0.0)\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"100"
]
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"len(replays)"
]
},
{
"cell_type": "code",
"execution_count": 45,
"metadata": {},
"outputs": [],
"source": [
"def _when_did_oracle_get_killed(game, id):\n",
" '''\n",
" What day and was he killed by votes or by werewolves\n",
" '''\n",
" just_votes = []\n",
" just_kills = []\n",
" for step in game:\n",
" if step['phase'] == Phase.VOTING:\n",
" if len(step[\"executed\"]) == 1:\n",
" if step['executed'][0] == id:\n",
" return step['day'], 0\n",
" else:\n",
" who_was_killed = list(set(step['executed']) - set(just_votes[-1]['executed']))[0]\n",
" if who_was_killed == id:\n",
" return step['day'], 0\n",
" \n",
" just_votes.append(step)\n",
" \n",
" if step['phase'] == Phase.NIGHT:\n",
"\n",
" if len(step[\"killed\"]) == 1:\n",
" if step['killed'][0] == id:\n",
" return step['day'], 1\n",
" else:\n",
" who_was_killed = list(set(step['killed']) - set(just_kills[-1]['killed']))[0]\n",
" if who_was_killed == id:\n",
" return step['day'], 1\n",
"\n",
" just_kills.append(step) \n",
"\n",
" return -1, -1\n",
"\n",
"# day, ty = [(_when_did_oracle_get_killed(replay, o_id)) for replay, o_id in zip(replays,o_ids)]\n",
"day, ty = list(zip(*[(_when_did_oracle_get_killed(replay, o_id)) for replay, o_id in zip(replays,o_ids)]))\n",
"game_winner = [replay[-1]['winners'] for replay in replays]\n",
"\n",
"#print(f'Day - {day}, Type - {ty}, Winner - {replays[0][-1][\"winners\"]}')"
]
},
{
"cell_type": "code",
"execution_count": 49,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-1 35\n",
" 3 21\n",
" 1 19\n",
" 2 17\n",
" 4 8\n",
"Name: day_oracle_killed, dtype: int64"
]
},
"execution_count": 49,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import pandas as pd\n",
"\n",
"df = pd.DataFrame({\"day_oracle_killed\": list(day), \"way_oracle_killed\": list(ty), \"game_result\": game_winner})\n",
"df.head()\n",
"df['day_oracle_killed'].value_counts()"
]
},
{
"cell_type": "code",
"execution_count": 50,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"-1 35\n",
" 0 33\n",
" 1 32\n",
"Name: way_oracle_killed, dtype: int64"
]
},
"execution_count": 50,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df['way_oracle_killed'].value_counts()"
]
}
],
"metadata": {
Expand Down

0 comments on commit 16717c3

Please sign in to comment.