Skip to content

Commit

Permalink
Fix reformer CI (#21254)
Browse files Browse the repository at this point in the history
* fix ReformerForSequenceClassification doc example

* fix ReformerForMaskedLM doc example

Co-authored-by: ydshieh <[email protected]>
  • Loading branch information
ydshieh and ydshieh authored Jan 23, 2023
1 parent eaace0c commit cb6b568
Showing 1 changed file with 6 additions and 8 deletions.
14 changes: 6 additions & 8 deletions src/transformers/models/reformer/modeling_reformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2377,15 +2377,17 @@ def forward(
>>> tokenizer.add_special_tokens({"mask_token": "[MASK]"}) # doctest: +IGNORE_RESULT
>>> inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
>>> # resize model's embedding matrix
>>> model.resize_token_embeddings(new_num_tokens=model.config.vocab_size + 1) # doctest: +IGNORE_RESULT
>>> with torch.no_grad():
... logits = model(**inputs).logits
>>> # retrieve index of [MASK]
>>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
>>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
>>> tokenizer.decode(predicted_token_id)
'it'
>>> predicted_token = tokenizer.decode(predicted_token_id)
```
```python
Expand All @@ -2396,8 +2398,7 @@ def forward(
... )
>>> outputs = model(**inputs, labels=labels)
>>> round(outputs.loss.item(), 2)
7.09
>>> loss = round(outputs.loss.item(), 2)
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
Expand Down Expand Up @@ -2494,8 +2495,7 @@ def forward(
... logits = model(**inputs).logits
>>> predicted_class_id = logits.argmax().item()
>>> model.config.id2label[predicted_class_id]
'LABEL_0'
>>> label = model.config.id2label[predicted_class_id]
```
```python
Expand All @@ -2507,8 +2507,6 @@ def forward(
>>> labels = torch.tensor(1)
>>> loss = model(**inputs, labels=labels).loss
>>> round(loss.item(), 2)
0.68
```
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
Expand Down

0 comments on commit cb6b568

Please sign in to comment.