Skip to content

Commit

Permalink
Use 10k samples by default
Browse files Browse the repository at this point in the history
  • Loading branch information
KevinMusgrave committed Jan 30, 2024
1 parent a8669f2 commit f81cbf1
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions blog/llm-finetuning/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,13 @@ def add_length_column(dataset):
return df


def filter_by_total_length(df, difficulty, number_of_samples):
    """Select rows whose ``total_length`` falls in the band for *difficulty*.

    Bands are inclusive: easy = 10-100, medium = 101-200, hard = 201-800.
    At most *number_of_samples* rows (the first ones, in the frame's current
    order) are returned.

    Args:
        df: pandas DataFrame with a numeric ``total_length`` column.
        difficulty: one of ``"easy"``, ``"medium"``, ``"hard"``.
        number_of_samples: maximum number of rows to keep.

    Returns:
        A DataFrame slice limited to the matching band and row count.

    Raises:
        ValueError: if *difficulty* is not a recognized level (previously the
            function silently returned ``None`` here, which only failed later).
    """
    # Inclusive (low, high) bounds per difficulty level; the bands are
    # disjoint so a sample belongs to exactly one level.
    bounds = {
        "easy": (10, 100),
        "medium": (101, 200),
        "hard": (201, 800),
    }
    try:
        low, high = bounds[difficulty]
    except KeyError:
        raise ValueError(
            f"unknown difficulty {difficulty!r}; expected one of {sorted(bounds)}"
        ) from None
    return df[df["total_length"].between(low, high)].iloc[:number_of_samples]


def get_dataset_subset_name(difficulty):
Expand Down Expand Up @@ -56,13 +56,13 @@ def load_dataset(difficulty):
return datasets.load_from_disk(get_dataset_subset_name(difficulty))


def load_or_create_dataset(difficulty, num_samples=10000):
    """Return the cached dataset subset for *difficulty*, building it on a miss.

    First tries to load a previously saved subset from disk. If none exists,
    downloads the Clinton/Text-to-sql-v1 dataset, drops unused columns,
    annotates each row with its total length, filters to the requested
    difficulty band (capped at *num_samples* rows), and saves the result.

    Args:
        difficulty: difficulty level passed through to the filtering and
            naming helpers (e.g. "easy", "medium", "hard").
        num_samples: maximum number of rows kept after filtering.

    Returns:
        The dataset object produced by ``load_dataset`` or
        ``create_and_save_datasets``.
    """
    # EAFP: attempt the disk cache first and fall through on a miss.
    try:
        return load_dataset(difficulty)
    except FileNotFoundError:
        pass

    # Cache miss: rebuild the subset from the raw HuggingFace dataset.
    raw = datasets.load_dataset("Clinton/Text-to-sql-v1")
    train_split = raw["train"].remove_columns(["text", "source"])
    filtered = filter_by_total_length(
        add_length_column(train_split), difficulty, num_samples
    )
    return create_and_save_datasets(filtered, difficulty)

0 comments on commit f81cbf1

Please sign in to comment.