Skip to content

Commit

Permalink
JetMoeIntegrationTest
Browse files (browse the repository at this point in the history)
Committed by ydshieh on Aug 14, 2024
1 parent 9485289 commit d37904c
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions tests/models/jetmoe/test_modeling_jetmoe.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -478,7 +478,7 @@ class JetMoeIntegrationTest(unittest.TestCase):
@slow
def test_model_8b_logits(self):
input_ids = [1, 306, 4658, 278, 6593, 310, 2834, 338]
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
input_ids = torch.tensor([input_ids]).to(model.model.embed_tokens.weight.device)
with torch.no_grad():
out = model(input_ids).logits.cpu()
Expand All @@ -498,7 +498,7 @@ def test_model_8b_generation(self):
EXPECTED_TEXT_COMPLETION = """My favourite condiment is ....\nI love ketchup. I love"""
prompt = "My favourite condiment is "
tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.model.embed_tokens.weight.device)

# greedy generation outputs
Expand All @@ -521,7 +521,7 @@ def test_model_8b_batched_generation(self):
"My favourite ",
]
tokenizer = AutoTokenizer.from_pretrained("jetmoe/jetmoe-8b", use_fast=False)
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b", device_map="auto")
model = JetMoeForCausalLM.from_pretrained("jetmoe/jetmoe-8b")
input_ids = tokenizer(prompt, return_tensors="pt", padding=True).to(model.model.embed_tokens.weight.device)
print(input_ids)

Expand Down

0 comments on commit d37904c

Please sign in to comment.