Graph2Tree and GTS can't handle test_batch_size > 1 #19

Open
liamjxu opened this issue May 6, 2022 · 2 comments

liamjxu commented May 6, 2022

When test_batch_size is set through the command line, e.g.,

python run_mwptoolkit.py --model=GTS --dataset=mawps --task_type=multi_equation --gpu_id=0 --equation_fix=prefix --test_batch_size=32

both Graph2Tree and GTS fail during the forward pass:

Traceback (most recent call last):
  File "run_mwptoolkit.py", line 63, in <module>
    run_toolkit(config)
  File "/data/MWPToolkit/mwptoolkit/quick_start.py", line 220, in run_toolkit
    train_with_train_valid_test_split(config)
  File "/data/MWPToolkit/mwptoolkit/quick_start.py", line 109, in train_with_train_valid_test_split
    trainer.fit()
  File "/data/MWPToolkit/mwptoolkit/trainer/supervised_trainer.py", line 583, in fit
    valid_equ_ac, valid_val_ac, valid_total, valid_time_cost = self.evaluate(DatasetType.Valid)
  File "/data/MWPToolkit/mwptoolkit/trainer/supervised_trainer.py", line 645, in evaluate
    batch_val_ac, batch_equ_ac = self._eval_batch(batch)
  File "/data/MWPToolkit/mwptoolkit/trainer/supervised_trainer.py", line 506, in _eval_batch
    test_out, target = self.model.model_test(batch)
  File "/data/MWPToolkit/mwptoolkit/model/Seq2Tree/gts.py", line 179, in model_test
    _, outputs, _ = self.forward(seq, seq_length, nums_stack, num_size, num_pos)
  File "/data/MWPToolkit/mwptoolkit/model/Seq2Tree/gts.py", line 128, in forward
    output_all_layers)
  File "/data/MWPToolkit/mwptoolkit/model/Seq2Tree/gts.py", line 350, in decoder_forward
    out_token = int(ti)
ValueError: only one element tensors can be converted to Python scalars
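
For reference, this is standard PyTorch behaviour: int() only accepts one-element tensors, which is presumably why the scalar conversion at gts.py line 350 breaks once the batch dimension is larger than 1. A minimal standalone reproduction, independent of the toolkit's code (the tensor values here are just illustrative):

import torch

# A one-element tensor converts to a Python scalar without error.
ti = torch.tensor([7])
print(int(ti))   # 7

# With test_batch_size > 1 the per-step token tensor holds several elements,
# and int() refuses the conversion.
ti = torch.tensor([7, 3])
try:
    int(ti)
except ValueError as err:
    print(err)   # only one element tensors can be converted to Python scalars

# Converting element-wise avoids the error.
print(ti.tolist())   # [7, 3]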

This issue makes the combined training and testing time exceptionally long, because the model_test logic is used during both validation and testing.
Would you consider adding support for test_batch_size > 1?

LYH-YF (Owner) commented May 8, 2022

Yes, I'm working on this. I will post an update once I have verified its correctness and confirmed that it significantly improves speed.

liamjxu (Author) commented May 8, 2022

Thanks for the reply! Looking forward to it.
