Skip to content

Commit

Permalink
Update TEQ train dataloader (#1554)
Browse files Browse the repository at this point in the history
  • Loading branch information
changwangss authored Jan 19, 2024
1 parent 941fed3 commit d1e994b
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion neural_compressor/adaptor/torch_utils/teq.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,9 @@ def train(

while global_steps <= train_steps:
for inputs in dataloader:
if isinstance(inputs, dict):
if isinstance(inputs, torch.Tensor):
input_id = inputs
elif isinstance(inputs, dict):
input_id = inputs["input_ids"]
else:
input_id = inputs[0]
Expand Down

0 comments on commit d1e994b

Please sign in to comment.