Skip to content

Commit

Permalink
Housekeeping: remove unused code and fix tests (#1795)
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt authored Oct 21, 2024
1 parent e14e39f commit fe394c4
Show file tree
Hide file tree
Showing 7 changed files with 7 additions and 13 deletions.
1 change: 0 additions & 1 deletion litgpt/finetune/full.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from pathlib import Path
from pprint import pprint
from typing import Dict, List, Literal, Optional, Tuple, Union
import warnings

import lightning as L
import torch
Expand Down
3 changes: 1 addition & 2 deletions litgpt/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,11 +541,10 @@ def batched_index_copy_(t, dim, idx, val):
return t.index_copy_(dim, idx, val)

assert idx.dim() == 2, f"multiple batch dims not yet {idx.shape=}"
assert dim != 0, f"cannot index batch dim"
assert dim != 0, f"cannot index batch dim {dim=}"
batch_size, idx_size = idx.shape
assert batch_size == t.size(0)
assert batch_size == val.size(0)
t_indexed_dim = t.size(dim)

# if we can view the batch and indexed dimensions together, we could
# do index trickery. This is, sadly, not the case for kvcache so we
Expand Down
1 change: 0 additions & 1 deletion litgpt/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from functools import partial
from pathlib import Path
from typing import Optional, Tuple, Union, Dict
import warnings

import lightning as L
import torch
Expand Down
8 changes: 4 additions & 4 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def test_llm_load_hub_init(tmp_path):

text_2 = llm.generate("text", max_new_tokens=10, top_k=1, stream=True)
text_2 = "".join(list(text_2))
assert text_1 == text_2, (text1, text_2)
assert text_1 == text_2, (text_1, text_2)


def test_model_not_initialized(tmp_path):
Expand Down Expand Up @@ -174,14 +174,14 @@ def test_more_than_1_device_for_sequential_gpu(tmp_path):
model=model_name,
)

with pytest.raises(NotImplementedError, match=f"Support for multiple devices is currently only implemented for generate_strategy='sequential'|'tensor_parallel'."):
with pytest.raises(NotImplementedError, match="Support for multiple devices is currently only implemented for generate_strategy='sequential'|'tensor_parallel'."):
llm.distribute(devices=2)

llm.distribute(devices=2, generate_strategy="sequential")
assert isinstance(llm.generate("What do llamas eat?"), str)
assert str(llm.model.transformer.h[0].mlp.fc.weight.device) == "cuda:0"
last_layer_idx = len(llm.model.transformer.h) - 1
assert str(llm.model.transformer.h[last_layer_idx].mlp.fc.weight.device) == f"cuda:1"
assert str(llm.model.transformer.h[last_layer_idx].mlp.fc.weight.device) == "cuda:1"

# Also check with default (devices="auto") setting
llm.distribute(generate_strategy="sequential")
Expand Down Expand Up @@ -270,7 +270,7 @@ def test_fixed_kv_cache(tmp_path):

# Request too many tokens
with pytest.raises(NotImplementedError, match="max_seq_length 512 needs to be >= 9223372036854775809"):
output_text = llm.generate("hello world", max_new_tokens=2**63)
_ = llm.generate("hello world", max_new_tokens=2**63)


def test_invalid_accelerator(tmp_path):
Expand Down
5 changes: 2 additions & 3 deletions tests/test_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,13 @@
from io import StringIO
from itertools import repeat
from pathlib import Path
from unittest.mock import ANY, MagicMock, Mock, call, patch
from unittest.mock import MagicMock, Mock, patch
import sys
from typing import Iterable
from typing import Iterable, Iterator

import pytest
import torch
import yaml
from lightning.fabric import Fabric

import litgpt.chat.base as chat
import litgpt.generate.base as generate
Expand Down
1 change: 0 additions & 1 deletion tests/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
import os
from pathlib import Path
from unittest import mock
from unittest.mock import ANY, Mock, call

Expand Down
1 change: 0 additions & 1 deletion tests/test_generate_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from contextlib import redirect_stderr, redirect_stdout
from io import StringIO
import os
from pathlib import Path
from unittest.mock import ANY, Mock, call

import pytest
Expand Down

0 comments on commit fe394c4

Please sign in to comment.