From fe394c43d8d037fa387baf5f6df0ccbea633d503 Mon Sep 17 00:00:00 2001
From: Sebastian Raschka
Date: Mon, 21 Oct 2024 18:06:11 -0500
Subject: [PATCH] Housekeeping: remove unused code and fix tests (#1795)

---
 litgpt/finetune/full.py        | 1 -
 litgpt/model.py                | 3 +--
 litgpt/pretrain.py             | 1 -
 tests/test_api.py              | 8 ++++----
 tests/test_chat.py             | 5 ++---
 tests/test_generate.py         | 1 -
 tests/test_generate_adapter.py | 1 -
 7 files changed, 7 insertions(+), 13 deletions(-)

diff --git a/litgpt/finetune/full.py b/litgpt/finetune/full.py
index 388675fe57..fe32814f0a 100644
--- a/litgpt/finetune/full.py
+++ b/litgpt/finetune/full.py
@@ -6,7 +6,6 @@
 from pathlib import Path
 from pprint import pprint
 from typing import Dict, List, Literal, Optional, Tuple, Union
-import warnings
 
 import lightning as L
 import torch
diff --git a/litgpt/model.py b/litgpt/model.py
index c694f63dfd..b60b0506b6 100644
--- a/litgpt/model.py
+++ b/litgpt/model.py
@@ -541,11 +541,10 @@ def batched_index_copy_(t, dim, idx, val):
         return t.index_copy_(dim, idx, val)
 
     assert idx.dim() == 2, f"multiple batch dims not yet {idx.shape=}"
-    assert dim != 0, f"cannot index batch dim"
+    assert dim != 0, f"cannot index batch dim {dim=}"
     batch_size, idx_size = idx.shape
     assert batch_size == t.size(0)
     assert batch_size == val.size(0)
-    t_indexed_dim = t.size(dim)
 
     # if we can view the batch and indexed dimensions together, we could
     # do index trickery. This is, sadly, not the case for kvcache so we
diff --git a/litgpt/pretrain.py b/litgpt/pretrain.py
index 084693ba36..e10df56a5b 100644
--- a/litgpt/pretrain.py
+++ b/litgpt/pretrain.py
@@ -7,7 +7,6 @@
 from functools import partial
 from pathlib import Path
 from typing import Optional, Tuple, Union, Dict
-import warnings
 
 import lightning as L
 import torch
diff --git a/tests/test_api.py b/tests/test_api.py
index 9c484432f1..0064ae5400 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -130,7 +130,7 @@ def test_llm_load_hub_init(tmp_path):
     text_2 = llm.generate("text", max_new_tokens=10, top_k=1, stream=True)
     text_2 = "".join(list(text_2))
 
-    assert text_1 == text_2, (text1, text_2)
+    assert text_1 == text_2, (text_1, text_2)
 
 
 def test_model_not_initialized(tmp_path):
@@ -174,14 +174,14 @@ def test_more_than_1_device_for_sequential_gpu(tmp_path):
         model=model_name,
     )
 
-    with pytest.raises(NotImplementedError, match=f"Support for multiple devices is currently only implemented for generate_strategy='sequential'|'tensor_parallel'."):
+    with pytest.raises(NotImplementedError, match="Support for multiple devices is currently only implemented for generate_strategy='sequential'|'tensor_parallel'."):
         llm.distribute(devices=2)
 
     llm.distribute(devices=2, generate_strategy="sequential")
     assert isinstance(llm.generate("What do llamas eat?"), str)
     assert str(llm.model.transformer.h[0].mlp.fc.weight.device) == "cuda:0"
     last_layer_idx = len(llm.model.transformer.h) - 1
-    assert str(llm.model.transformer.h[last_layer_idx].mlp.fc.weight.device) == f"cuda:1"
+    assert str(llm.model.transformer.h[last_layer_idx].mlp.fc.weight.device) == "cuda:1"
 
     # Also check with default (devices="auto") setting
     llm.distribute(generate_strategy="sequential")
@@ -270,7 +270,7 @@ def test_fixed_kv_cache(tmp_path):
 
     # Request too many tokens
     with pytest.raises(NotImplementedError, match="max_seq_length 512 needs to be >= 9223372036854775809"):
-        output_text = llm.generate("hello world", max_new_tokens=2**63)
+        _ = llm.generate("hello world", max_new_tokens=2**63)
 
 
 def test_invalid_accelerator(tmp_path):
diff --git a/tests/test_chat.py b/tests/test_chat.py
index d76cbe5a62..5b7a68b556 100644
--- a/tests/test_chat.py
+++ b/tests/test_chat.py
@@ -6,14 +6,13 @@
 from io import StringIO
 from itertools import repeat
 from pathlib import Path
-from unittest.mock import ANY, MagicMock, Mock, call, patch
+from unittest.mock import MagicMock, Mock, patch
 import sys
-from typing import Iterable
+from typing import Iterable, Iterator
 
 import pytest
 import torch
 import yaml
-from lightning.fabric import Fabric
 
 import litgpt.chat.base as chat
 import litgpt.generate.base as generate
diff --git a/tests/test_generate.py b/tests/test_generate.py
index 5a497b5059..6fc561b945 100644
--- a/tests/test_generate.py
+++ b/tests/test_generate.py
@@ -6,7 +6,6 @@
 from contextlib import redirect_stderr, redirect_stdout
 from io import StringIO
 import os
-from pathlib import Path
 from unittest import mock
 from unittest.mock import ANY, Mock, call
 
diff --git a/tests/test_generate_adapter.py b/tests/test_generate_adapter.py
index fbce300120..6e57ff0c5e 100644
--- a/tests/test_generate_adapter.py
+++ b/tests/test_generate_adapter.py
@@ -6,7 +6,6 @@
 from contextlib import redirect_stderr, redirect_stdout
 from io import StringIO
 import os
-from pathlib import Path
 from unittest.mock import ANY, Mock, call
 
 import pytest