diff --git a/litgpt/config.py b/litgpt/config.py index ac6a84fcef..ebc74b4c3b 100644 --- a/litgpt/config.py +++ b/litgpt/config.py @@ -8,8 +8,6 @@ import torch import yaml from typing_extensions import Self - -import litgpt.model from litgpt.utils import find_multiple @@ -144,6 +142,7 @@ def from_checkpoint(cls, path: Path, **kwargs: Any) -> Self: @property def mlp_class(self) -> Type: # `self.mlp_class_name` cannot be the type to keep the config serializable + import litgpt.model return getattr(litgpt.model, self.mlp_class_name) @property diff --git a/litgpt/data/base.py b/litgpt/data/base.py index 668a4a4a07..f4ef68a818 100644 --- a/litgpt/data/base.py +++ b/litgpt/data/base.py @@ -8,7 +8,7 @@ from torch import Tensor from torch.utils.data import Dataset -from litgpt import Tokenizer +from litgpt.tokenizer import Tokenizer from litgpt.prompts import PromptStyle diff --git a/litgpt/data/deita.py b/litgpt/data/deita.py index 56bb54d8f8..bc93750014 100644 --- a/litgpt/data/deita.py +++ b/litgpt/data/deita.py @@ -7,7 +7,7 @@ import torch from torch.utils.data import DataLoader -from litgpt import PromptStyle +from litgpt.prompts import PromptStyle from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn from litgpt.tokenizer import Tokenizer diff --git a/litgpt/data/dolly.py b/litgpt/data/dolly.py index 0939157033..1e0789fae2 100644 --- a/litgpt/data/dolly.py +++ b/litgpt/data/dolly.py @@ -8,7 +8,7 @@ import torch from torch.utils.data import random_split -from litgpt import PromptStyle +from litgpt.prompts import PromptStyle from litgpt.data import Alpaca, SFTDataset _URL: str = "https://huggingface.co/datasets/databricks/databricks-dolly-15k/resolve/main/databricks-dolly-15k.jsonl" diff --git a/litgpt/data/flan.py b/litgpt/data/flan.py index df0a3e2cca..e06bfe86a7 100644 --- a/litgpt/data/flan.py +++ b/litgpt/data/flan.py @@ -8,7 +8,7 @@ import torch from torch.utils.data import DataLoader -from litgpt import PromptStyle +from litgpt.prompts import PromptStyle from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn from litgpt.data.alpaca import download_if_missing from litgpt.tokenizer import Tokenizer diff --git a/litgpt/data/json_data.py b/litgpt/data/json_data.py index fbcb42d0b8..3e9a51d409 100644 --- a/litgpt/data/json_data.py +++ b/litgpt/data/json_data.py @@ -8,7 +8,7 @@ import torch from torch.utils.data import DataLoader, random_split -from litgpt import PromptStyle +from litgpt.prompts import PromptStyle from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn from litgpt.tokenizer import Tokenizer diff --git a/litgpt/data/lima.py b/litgpt/data/lima.py index 581e957207..6eb4ef7aa4 100644 --- a/litgpt/data/lima.py +++ b/litgpt/data/lima.py @@ -7,7 +7,7 @@ import torch from torch.utils.data import DataLoader, random_split -from litgpt import PromptStyle +from litgpt.prompts import PromptStyle from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn from litgpt.tokenizer import Tokenizer diff --git a/litgpt/data/lit_data.py b/litgpt/data/lit_data.py index ddac413f28..15b800d6ba 100644 --- a/litgpt/data/lit_data.py +++ b/litgpt/data/lit_data.py @@ -6,7 +6,7 @@ from torch.utils.data import DataLoader -from litgpt import Tokenizer +from litgpt.tokenizer import Tokenizer from litgpt.data import DataModule diff --git a/litgpt/data/longform.py b/litgpt/data/longform.py index fb264496e9..27ee63ee21 100644 --- a/litgpt/data/longform.py +++ b/litgpt/data/longform.py @@ -8,7 +8,7 @@ import torch from torch.utils.data import DataLoader -from litgpt import PromptStyle +from litgpt.prompts import PromptStyle from litgpt.data import DataModule, SFTDataset, get_sft_collate_fn from litgpt.data.alpaca import download_if_missing from litgpt.tokenizer import Tokenizer diff --git a/litgpt/data/openwebtext.py b/litgpt/data/openwebtext.py index 74fe38b331..4bf0a64adc 100644 --- a/litgpt/data/openwebtext.py +++ b/litgpt/data/openwebtext.py @@ -7,7 +7,7 @@ from torch.utils.data import DataLoader -from litgpt import Tokenizer +from litgpt.tokenizer import Tokenizer from litgpt.data import DataModule diff --git a/litgpt/data/prepare_slimpajama.py b/litgpt/data/prepare_slimpajama.py index c1cad7af12..5eb6aaad33 100644 --- a/litgpt/data/prepare_slimpajama.py +++ b/litgpt/data/prepare_slimpajama.py @@ -5,7 +5,7 @@ import time from pathlib import Path -from litgpt import Tokenizer +from litgpt.tokenizer import Tokenizer from litgpt.data.prepare_starcoder import DataChunkRecipe from litgpt.utils import CLI, extend_checkpoint_dir diff --git a/litgpt/data/prepare_starcoder.py b/litgpt/data/prepare_starcoder.py index 2e2741d31b..4deb2581f3 100644 --- a/litgpt/data/prepare_starcoder.py +++ b/litgpt/data/prepare_starcoder.py @@ -7,7 +7,7 @@ from lightning_utilities.core.imports import RequirementCache -from litgpt import Tokenizer +from litgpt.tokenizer import Tokenizer from litgpt.utils import CLI, extend_checkpoint_dir _LITDATA_AVAILABLE = RequirementCache("litdata") diff --git a/litgpt/data/text_files.py b/litgpt/data/text_files.py index 4776e66afd..e584d2a006 100644 --- a/litgpt/data/text_files.py +++ b/litgpt/data/text_files.py @@ -8,7 +8,7 @@ from torch.utils.data import DataLoader -from litgpt import Tokenizer +from litgpt.tokenizer import Tokenizer from litgpt.data import DataModule diff --git a/litgpt/data/tinyllama.py b/litgpt/data/tinyllama.py index 44214333c9..73d204e710 100644 --- a/litgpt/data/tinyllama.py +++ b/litgpt/data/tinyllama.py @@ -5,7 +5,7 @@ from torch.utils.data import DataLoader -from litgpt import Tokenizer +from litgpt.tokenizer import Tokenizer from litgpt.data import DataModule diff --git a/litgpt/data/tinystories.py b/litgpt/data/tinystories.py index c6c54c9b38..54a1c83ae0 100644 --- a/litgpt/data/tinystories.py +++ b/litgpt/data/tinystories.py @@ -10,7 +10,7 @@ from torch.utils.data import DataLoader from tqdm import tqdm -from litgpt import Tokenizer +from litgpt.tokenizer import Tokenizer from litgpt.data import DataModule from litgpt.data.alpaca import download_if_missing from litgpt.data.text_files import validate_tokenizer