
In which part do you implement policy decoupling #12

Closed
donutQQ opened this issue Jan 10, 2022 · 3 comments


donutQQ commented Jan 10, 2022

Hello, I am very interested in your work! I have studied the code, especially the class "TransformerAggregationAgent", but I have not found where you implement the policy decoupling. The only thing I have found is
q_agg = torch.mean(outputs, 1)
q = self.q_linear(q_agg)

I am confused that you calculate the mean along the action dimension and then map the result back to the actions. Could you please explain the motivation for this part? I really look forward to your reply.

Thanks!
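For reference, the two quoted lines form an aggregation-style head with no decoupling: the per-entity transformer outputs are averaged over the token dimension, and a single linear layer maps the pooled embedding to a fixed-size action space. Below is a minimal, self-contained sketch of that idea; the class name, embedding size, and action count are illustrative and not the repository's exact code.

    import torch
    import torch.nn as nn

    class AggregationHead(nn.Module):
        """Sketch of a transformer output head WITHOUT policy decoupling."""
        def __init__(self, emb: int, n_actions: int):
            super().__init__()
            self.q_linear = nn.Linear(emb, n_actions)

        def forward(self, outputs: torch.Tensor) -> torch.Tensor:
            # outputs: [batch, n_tokens, emb], one token per observed entity
            q_agg = torch.mean(outputs, 1)   # pool over the token dimension
            return self.q_linear(q_agg)      # fixed-size Q vector: [batch, n_actions]

    head = AggregationHead(emb=32, n_actions=12)
    q = head(torch.randn(4, 8, 32))
    print(q.shape)  # torch.Size([4, 12])

Because the output size is fixed by n_actions, this head cannot adapt to a varying number of enemies, which is what the decoupled forward shown later in this thread addresses.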

@hhhusiyi-monash (Collaborator)

Hi there,

Thanks for your interest.
TransformerAggregationAgent is a transformer-based agent without the policy decoupling strategy.
In Figure 4(a) of our paper, you can see that without policy decoupling, the transformer-based agent performs even worse than a classical GRU/LSTM, which demonstrates the effectiveness of this strategy.

Any further questions are welcome.


ouyangshixiong commented Feb 6, 2022

I am not the author of this algorithm, but I have read through all of the code carefully, and I think the code for "policy decoupling" is the following:

    def forward(self, inputs, hidden_state, task_enemy_num, task_ally_num):
        outputs, _ = self.transformer.forward(inputs, hidden_state, None)
        # the first output token produces Q-values for the 6 basic actions (no_op, stop, up, down, left, right)
        q_basic_actions = self.q_basic(outputs[:, 0, :])

        # the last output token is carried over as the hidden state
        h = outputs[:, -1:, :]

        q_enemies_list = []

        # each enemy token produces one attack Q-value (mean of the basic head's outputs)
        for i in range(task_enemy_num):
            q_enemy = self.q_basic(outputs[:, 1 + i, :])
            q_enemy_mean = torch.mean(q_enemy, 1, True)
            q_enemies_list.append(q_enemy_mean)

        # stack the per-enemy attack Q-values
        q_enemies = torch.stack(q_enemies_list, dim=1).squeeze()

        # concat basic action Q with enemy attack Q
        q = torch.cat((q_basic_actions, q_enemies), 1)

        return q, h

As the paper says, the Transformer is used to process the input (the observation), so "inputs" should be the obs.
"outputs" should then be all agents' raw values (in short, "R"), including the enemies' (Figure 7 in the paper).

  • The author decouples the main agent's "R" into q_basic_actions, size [ally_num, 6].
  • The author discards all allies' "R" (notice that the outputs tensor has size [ally_num, 1 + enemy_num + (ally_num - 1) + 1, embedding_dim],
    but only the main agent's token (the first one), the enemies' "R", and the hidden state (the last one) are processed. Further evidence is that the input parameter "task_ally_num" is never used).
  • The author decouples the enemies' "R" and simply calculates the mean of each into q_enemy_mean, giving size [ally_num, enemy_num] once stacked.
    The final q is the concatenation of the main agent's q_basic_actions and the enemies' R-means, size [ally_num, 6 + enemy_num] (see the shape check below).
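To make the shapes above concrete, here is a small self-contained check that mirrors the forward pass with random tensors. The numbers (5 allies, 6 enemies, embedding size 32) are illustrative, and q_basic here is just a stand-in nn.Linear for the real self.q_basic.

    import torch
    import torch.nn as nn

    ally_num, enemy_num, emb = 5, 6, 32
    q_basic = nn.Linear(emb, 6)  # 6 basic actions, stand-in for self.q_basic

    # stand-in for the transformer outputs:
    # [ally_num, 1 + enemy_num + (ally_num - 1) + 1, embedding_dim]
    outputs = torch.randn(ally_num, 1 + enemy_num + (ally_num - 1) + 1, emb)

    q_basic_actions = q_basic(outputs[:, 0, :])            # [ally_num, 6]
    q_enemies = torch.stack(
        [torch.mean(q_basic(outputs[:, 1 + i, :]), 1, True) for i in range(enemy_num)],
        dim=1,
    ).squeeze()                                            # [ally_num, enemy_num]
    q = torch.cat((q_basic_actions, q_enemies), 1)

    print(q.shape)  # torch.Size([5, 12]) == [ally_num, 6 + enemy_num]

Note that q grows with enemy_num while the parameters of q_basic stay the same size, which is what allows the same network to be reused across maps with different numbers of enemies.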

The author uses a heatmap to explain the relationship between the self-attention matrix and the final strategy (Figure 6 in the paper).
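As a side note, a self-attention matrix can be visualised with a few lines of matplotlib; this is only an illustrative sketch with random weights, not the authors' plotting code.

    import torch
    import matplotlib.pyplot as plt

    # random stand-in for one agent's self-attention weights over its entity tokens
    attn = torch.softmax(torch.randn(12, 12), dim=-1)

    plt.imshow(attn.numpy(), cmap="viridis")
    plt.colorbar()
    plt.xlabel("key (entity token)")
    plt.ylabel("query (entity token)")
    plt.title("Self-attention heatmap")
    plt.show()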

@hhhusiyi-monash (Collaborator)

Thanks for your detailed explanation. I have pinned this issue for people who have the same confusion.
