Question about the output of the decision transformer #27916

Closed
Pulsar110 opened this issue Dec 8, 2023 · 4 comments

Comments

@Pulsar110

From the code here: https://github.com/huggingface/transformers/blob/v4.35.2/src/transformers/models/decision_transformer/modeling_decision_transformer.py#L920-L927

        # reshape x so that the second dimension corresponds to the original
        # returns (0), states (1), or actions (2); i.e. x[:,1,t] is the token for s_t
        x = x.reshape(batch_size, seq_length, 3, self.hidden_size).permute(0, 2, 1, 3)

        # get predictions
        return_preds = self.predict_return(x[:, 2])  # predict next return given state and action
        state_preds = self.predict_state(x[:, 2])  # predict next state given state and action
        action_preds = self.predict_action(x[:, 1])  # predict next action given state

I'm not sure I understand why self.predict_return(x[:, 2]) and self.predict_state(x[:, 2]) are described as predicting the next return and the next state "given state and action". According to the comment at the top, x[:, 2] corresponds only to the action tokens. Am I missing something?
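To make the indexing in the quoted snippet concrete, here is a minimal NumPy sketch (NumPy rather than PyTorch, with transpose standing in for permute). It assumes the model interleaves the return, state, and action embeddings as (R_0, s_0, a_0, R_1, s_1, a_1, ...) before the transformer, and it treats the transformer as the identity so that positions can be traced. The toy sizes and random arrays are illustrative only. It checks the snippet's comment: after the reshape, x[:, 2, t] is the output at the a_t position.

```python
import numpy as np

# Toy sizes chosen for illustration; they are not the model's defaults.
batch_size, seq_length, hidden_size = 2, 5, 3

rng = np.random.default_rng(0)
returns = rng.normal(size=(batch_size, seq_length, hidden_size))  # stand-in for R_t embeddings
states = rng.normal(size=(batch_size, seq_length, hidden_size))   # stand-in for s_t embeddings
actions = rng.normal(size=(batch_size, seq_length, hidden_size))  # stand-in for a_t embeddings

# Assumed interleaving: (R_0, s_0, a_0, R_1, s_1, a_1, ...) along the sequence axis.
stacked = (
    np.stack((returns, states, actions), axis=1)        # (B, 3, T, H)
    .transpose(0, 2, 1, 3)                              # (B, T, 3, H)
    .reshape(batch_size, 3 * seq_length, hidden_size)   # (B, 3T, H)
)

# Treat the transformer as the identity, then undo the interleaving the same way
# the quoted code does (transpose here plays the role of torch's permute).
x = stacked.reshape(batch_size, seq_length, 3, hidden_size).transpose(0, 2, 1, 3)  # (B, 3, T, H)

assert np.array_equal(x[:, 0], returns)  # x[:, 0, t] comes from the R_t position
assert np.array_equal(x[:, 1], states)   # x[:, 1, t] comes from the s_t position
assert np.array_equal(x[:, 2], actions)  # x[:, 2, t] comes from the a_t position
```

The flat position of the token for modality k at timestep t is 3 * t + k, so the reshape is a pure re-indexing. On its own it says nothing about what each position has attended to inside the transformer.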

And if this code is correct, what is x[:, 0] used for?
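Whether x[:, 2] can be read as "given state and action" depends on what each position attends to, not only on which token sits there. The sketch below assumes the GPT-2 backbone applies a standard causal attention mask, so the output at a position can depend only on that token and earlier ones. It also assumes the (R_0, s_0, a_0, R_1, ...) token order; interleaved_tokens and visible_tokens are hypothetical helpers for illustration, and timesteps are 0-indexed to match x[:, 1, t] in the quoted comment.

```python
# Assumed token order from the interleaving: (R_0, s_0, a_0, R_1, s_1, a_1, ...).
# Index k in MODALITIES matches the second axis of x after the reshape: x[:, k, t].
MODALITIES = ("R", "s", "a")  # 0 = return, 1 = state, 2 = action


def interleaved_tokens(seq_length: int) -> list[str]:
    """Labels for the flattened sequence fed to the transformer (hypothetical helper)."""
    return [f"{m}_{t}" for t in range(seq_length) for m in MODALITIES]


def visible_tokens(seq_length: int, k: int, t: int) -> list[str]:
    """Tokens the output x[:, k, t] can depend on, assuming a causal attention mask."""
    position = 3 * t + k  # flat position of modality k at timestep t
    # A causal mask lets each position attend to itself and every earlier position.
    return interleaved_tokens(seq_length)[: position + 1]


seq_length, t = 3, 1
for k in range(len(MODALITIES)):
    print(f"x[:, {k}, {t}] sees {visible_tokens(seq_length, k, t)}")
# x[:, 0, 1] sees ['R_0', 's_0', 'a_0', 'R_1']
# x[:, 1, 1] sees ['R_0', 's_0', 'a_0', 'R_1', 's_1']
# x[:, 2, 1] sees ['R_0', 's_0', 'a_0', 'R_1', 's_1', 'a_1']
```

Under this assumption, the output at the a_t position (x[:, 2, t]) has seen both s_t and a_t, the output at the s_t position (x[:, 1, t]) has seen s_t but not a_t, and the output at the R_t position (x[:, 0, t]) has seen neither. This is one possible reading of the comments, not a confirmation from the authors; in the quoted snippet, x[:, 0] is not passed to any prediction head.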

@ArthurZucker
Collaborator

Hey 🤗 thanks for opening an issue! We try to keep GitHub issues for bugs and feature requests.
Could you ask your question on the forum instead? I'm sure the community will be of help!

Thanks!

@Pulsar110
Author

Thank you. I have created a post here: https://discuss.huggingface.co/t/question-about-the-output-of-the-decision-transformer/65384
So far no one has replied there. I'm not sure whether there is a bug in the code or I'm misunderstanding it, which is also why I wanted to post here.

@ArthurZucker
Collaborator

I don't know this model at all, so pinging @edbeeching, the author of the PR!

@edbeeching
Contributor

Hi @Pulsar110, thanks for your question. It would probably be best to reach out to the original authors with this question, as our implementation aims to match their codebase: https://github.com/kzl/decision-transformer/blob/e2d82e68f330c00f763507b3b01d774740bee53f/gym/decision_transformer/models/decision_transformer.py#L97

If I had to hazard a guess, I would say there is a mistake in their implementation and that we should be indexing entry 0 at some point.

Let us know what they say, and perhaps we can update our implementation with any changes they suggest. I will close the issue for now, but feel free to reopen it with more questions or if you hear back from them.
