Validation #898

Merged
84 commits
8cb6522
add validation script
xiaohanzhan-db Dec 23, 2023
c59c11f
update
xiaohanzhan-db Jan 3, 2024
66f34eb
change token count function
Jan 3, 2024
2cd387b
reorganize cells
Jan 5, 2024
3eac3bf
Add unit tests
xiaohanzhan-db Jan 5, 2024
d2d9767
Add a printout for CPT
xiaohanzhan-db Jan 6, 2024
be25591
update question
xiaohanzhan-db Jan 6, 2024
4651be7
Add questions
Jan 8, 2024
5cd6a94
Fix lints
xiaohanzhan-db Jan 8, 2024
8e2c1f4
Merge branch 'main' into validation
XiaohanZhangCMU Jan 8, 2024
e6e4a81
update format
xiaohanzhan-db Jan 8, 2024
34c5690
Merge branch 'validation' of github.com:XiaohanZhangCMU/llm-foundryX …
xiaohanzhan-db Jan 8, 2024
1668b9a
update
xiaohanzhan-db Jan 8, 2024
2219135
nb source
xiaohanzhan-db Jan 8, 2024
86c6e87
add validation script
xiaohanzhan-db Dec 23, 2023
678b376
update
xiaohanzhan-db Jan 3, 2024
297e057
change token count function
Jan 3, 2024
09d0ebb
reorganize cells
Jan 5, 2024
460df65
Add unit tests
xiaohanzhan-db Jan 5, 2024
3ffd200
Add a printout for CPT
xiaohanzhan-db Jan 6, 2024
9362886
update question
xiaohanzhan-db Jan 6, 2024
898e5ac
Add questions
Jan 8, 2024
a4bef71
Fix lints
xiaohanzhan-db Jan 8, 2024
4ca9cc6
update format
xiaohanzhan-db Jan 8, 2024
d636a0f
update
xiaohanzhan-db Jan 8, 2024
827d155
nb source
xiaohanzhan-db Jan 8, 2024
6bbf3fc
Remove license insert for validation notebook
xiaohanzhan-db Jan 8, 2024
4f6a4fb
Merge branch 'validation' of github.com:XiaohanZhangCMU/llm-foundryX …
xiaohanzhan-db Jan 8, 2024
5966b68
Add validation utils
xiaohanzhan-db Jan 11, 2024
da17813
Merge branch 'main' into validation
xiaohanzhan-db Jan 11, 2024
a7c36bc
Minor cleanups (#858)
mvpatel2000 Jan 11, 2024
55e4626
update utils/__init__.py to include extra validation functions
xiaohanzhan-db Jan 11, 2024
45544a1
update notebook
Jan 11, 2024
d2797b3
update
xiaohanzhan-db Jan 11, 2024
019da77
Merge branch 'validation' of github.com:XiaohanZhangCMU/llm-foundryX …
xiaohanzhan-db Jan 11, 2024
756fdae
update
xiaohanzhan-db Jan 11, 2024
6de8c37
Read UC delta table (#773)
XiaohanZhangCMU Jan 11, 2024
93b5a9f
Add download remote function to util
xiaohanzhan-db Jan 11, 2024
b47c878
update
xiaohanzhan-db Jan 11, 2024
fa8f3d9
remove fused layernorm (#859)
mvpatel2000 Jan 11, 2024
13fd34c
update
xiaohanzhan-db Jan 11, 2024
610f669
update
xiaohanzhan-db Jan 11, 2024
9f2e51b
update
xiaohanzhan-db Jan 11, 2024
ec68f10
update
xiaohanzhan-db Jan 11, 2024
1e76068
update
xiaohanzhan-db Jan 11, 2024
7a5c164
update
xiaohanzhan-db Jan 11, 2024
e76038f
Merge branch 'main' into validation
xiaohanzhan-db Jan 11, 2024
5b413f5
update
xiaohanzhan-db Jan 11, 2024
a1aa31f
update
xiaohanzhan-db Jan 11, 2024
d24fd5c
update
xiaohanzhan-db Jan 11, 2024
da3bea1
Remove hardcoded combined.jsonl with a flag (#861)
XiaohanZhangCMU Jan 12, 2024
936e3a1
bump (#828)
mvpatel2000 Jan 12, 2024
55fce37
Add dask and dataframe_to_mds
xiaohanzhan-db Jan 12, 2024
86e2412
update
xiaohanzhan-db Jan 12, 2024
bbfec65
update
xiaohanzhan-db Jan 12, 2024
b2e880d
update
xiaohanzhan-db Jan 12, 2024
596443a
update
xiaohanzhan-db Jan 12, 2024
ea65187
Add notebook
xiaohanzhan-db Jan 12, 2024
378a4e0
update
xiaohanzhan-db Jan 12, 2024
af6e9aa
update
Jan 12, 2024
4e286ec
remove script and tests, keep notebook
xiaohanzhan-db Jan 12, 2024
09c4892
update
xiaohanzhan-db Jan 12, 2024
c82da6c
update
xiaohanzhan-db Jan 12, 2024
e5f83cc
update
xiaohanzhan-db Jan 12, 2024
17d2b9f
update
xiaohanzhan-db Jan 12, 2024
6579d55
Merge branch 'main' into validation
xiaohanzhan-db Jan 12, 2024
56308ff
Merge branch 'byod/data_validation' into validation
XiaohanZhangCMU Jan 12, 2024
6517a30
Always initialize dist (#864)
mvpatel2000 Jan 12, 2024
4daa324
updated notebook
Jan 12, 2024
b809691
Merge branch 'main' into validation
xiaohanzhan-db Jan 12, 2024
8b75f94
remove scripts keep notebook
xiaohanzhan-db Jan 12, 2024
99bf2cd
merge with byod/data_validation
xiaohanzhan-db Jan 12, 2024
22014d6
update notebook. rephrase.
Jan 12, 2024
d9f28aa
merged
xiaohanzhan-db Jan 12, 2024
43c8ac9
update
xiaohanzhan-db Jan 12, 2024
b8ac771
Add response tokens
xiaohanzhan-db Jan 16, 2024
1b9681c
update
xiaohanzhan-db Jan 16, 2024
16883c2
merge
xiaohanzhan-db Jan 16, 2024
c7567f1
update
xiaohanzhan-db Jan 20, 2024
1764b72
Disable MDSWrite, return token counts
xiaohanzhan-db Jan 22, 2024
808ced5
Change plot settings
xiaohanzhan-db Jan 23, 2024
26ae516
Fix conflict
xiaohanzhan-db Jan 23, 2024
a212ee8
update notebook
Jan 23, 2024
d279817
update
xiaohanzhan-db Jan 23, 2024
63 changes: 45 additions & 18 deletions llmfoundry/utils/validation_utils.py
@@ -135,7 +135,9 @@ def get_num_samples_in_batch(batch: dict) -> int:

     response_tokens = len(batch['labels']) if 'labels' in batch else 0

-    return {'ntokens': input_ids_tokens + decoder_input_ids_tokens + response_tokens}
+    return {
+        'ntokens': input_ids_tokens + decoder_input_ids_tokens + response_tokens
+    }


 def token_counts(FT_API_args):
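
For orientation, here is a small self-contained sketch of the dict this hunk builds, run on a toy batch. The attention-mask summation for input_ids_tokens is an assumption (that part of the function is outside this hunk), and the tensors are hypothetical:

import torch

batch = {
    'input_ids': torch.ones(4, 128, dtype=torch.long),
    'attention_mask': torch.ones(4, 128, dtype=torch.long),  # no padding in this toy batch
    'labels': torch.ones(4, 128, dtype=torch.long),
}

# Assumed: non-padding prompt tokens are counted via the attention mask.
input_ids_tokens = int(batch['attention_mask'].sum())
decoder_input_ids_tokens = 0  # only nonzero for encoder-decoder batches

# Mirrors the hunk above; note len() on a 2-D tensor counts rows (4 here), not tokens.
response_tokens = len(batch['labels']) if 'labels' in batch else 0

print({'ntokens': input_ids_tokens + decoder_input_ids_tokens + response_tokens})
# {'ntokens': 516}
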
@@ -270,7 +272,7 @@ def count_shards(mds_root: str):
                                     merge_shard_groups)

 log = logging.getLogger(__name__)
-DONE_FILENAME = '.text_to_mds_conversion_done'
+DONE_FILENAME = '/Volumes/main/mosaic_hackathon/managed-volume/text_to_mds_conversion_done'


 def parse_args(tokenizer,
@@ -499,6 +501,8 @@ def download_and_convert(
         bos_text (str): Text to prepend to each example to separate concatenated samples
         no_wrap: (bool): Whether to let text examples wrap across multiple training examples
         compression (str): The compression algorithm to use for MDS writing
+    Returns:
+        (int): token count of the current group
     """
     object_store = maybe_create_object_store_from_uri(input_folder)

@@ -521,14 +525,18 @@ def download_and_convert(
         no_wrap=no_wrap,
     )

-    columns = {'tokens': 'bytes'}
+    token_count = sum([ 1 for _ in dataset])
+
+    # columns = {'tokens': 'bytes'}

-    log.info('Converting to MDS format...')
-    with MDSWriter(out=output_folder,
-                   columns=columns,
-                   compression=compression) as out:
-        for sample in tqdm(dataset):
-            out.write(sample)
+    # log.info('Converting to MDS format...')
+    # with MDSWriter(out=output_folder,
+    #                columns=columns,
+    #                compression=compression) as out:
+    #     for sample in tqdm(dataset):
+    #         out.write(sample)
+
+    return token_count


 def is_remote_path(path: str) -> bool:
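
The hunk above replaces MDS writing with a pure counting pass over the tokenized dataset. A minimal sketch of that pattern follows; the dataset stand-in and the 2048 sequence length are illustrative, on the assumption that ConcatTokensDataset yields one fixed-length concatenated sample per iteration:

from typing import Iterable, Iterator


def count_samples(dataset: Iterable) -> int:
    # A generator expression avoids materializing a list, unlike sum([1 for _ in ...]).
    return sum(1 for _ in dataset)


def fake_concat_dataset(n: int) -> Iterator[dict]:
    # Hypothetical stand-in for ConcatTokensDataset: one yielded item per training sample.
    for _ in range(n):
        yield {'tokens': b'\x00' * (2048 * 8)}  # 2048 int64 token ids as raw bytes


concat_tokens = 2048  # assumed fixed sequence length
print(count_samples(fake_concat_dataset(10)) * concat_tokens)  # 20480 tokens
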
@@ -616,7 +624,7 @@ def convert_text_to_mds(
     processes: int,
     args_str: str,
     reprocess: bool,
-):
+)->int:
     """Convert a folder of text files to MDS format.

     Args:
@@ -631,6 +639,8 @@ def convert_text_to_mds(
         processes (int): The number of processes to use.
         args_str (str): String representation of the arguments
         reprocess (bool): Whether to always reprocess the given folder of text files
+    Returns:
+        (int): total tokens of the dataset
     """
     is_remote_output = is_remote_path(output_folder)

@@ -658,12 +668,13 @@ def convert_text_to_mds(
             processes, tokenizer_name, concat_tokens, eos_text,
             bos_text, no_wrap, compression)
         with ProcessPoolExecutor(max_workers=processes) as executor:
-            list(executor.map(download_and_convert_starargs, args))
+            pool = list(executor.map(download_and_convert_starargs, args))

         # Merge the mds shards from each of the processes into a single folder
-        merge_shard_groups(local_output_folder)
+        # merge_shard_groups(local_output_folder)
+        total_tokens = sum(pool)
     else:
-        download_and_convert(object_names, local_output_folder, input_folder,
+        total_tokens = download_and_convert(object_names, local_output_folder, input_folder,
                              tokenizer_name, concat_tokens, eos_text, bos_text,
                              no_wrap, compression)

@@ -683,6 +694,8 @@ def convert_text_to_mds(
                 output_object_store.upload_object(
                     remote_path, os.path.join(local_output_folder, file))

+    return total_tokens
+

 def _args_str(original_args: Namespace) -> str:
     """Create a string from the args to determine whether to reprocess.
@@ -801,8 +814,8 @@ def plot_hist(data, save_plot_path=None):

     # Aesthetics
     plt.title('Histogram of Token Counts')
-    plt.xlabel('Token Count')
-    plt.ylabel('Frequency')
+    plt.xlabel('Number of Tokens per Sample')
+    plt.ylabel('Count of Frequency')

     # Grid and Layout
     plt.grid(axis='y', alpha=0.75)
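
For reference, a self-contained sketch that produces the same style of plot with synthetic data (plot_hist itself appears to take per-sample token counts plus an optional save path):

import matplotlib.pyplot as plt
import numpy as np

# Synthetic per-sample token counts standing in for real dataset statistics.
token_counts = np.random.lognormal(mean=6.0, sigma=0.5, size=1000).astype(int)

plt.hist(token_counts, bins=50, alpha=0.75)
plt.title('Histogram of Token Counts')
plt.xlabel('Number of Tokens per Sample')
plt.ylabel('Count of Frequency')  # labels as set in this diff
plt.grid(axis='y', alpha=0.75)
plt.show()
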
@@ -855,6 +868,19 @@ def pandas_processing_fn(df: pd.DataFrame,
     hf_dataset = hf_datasets.Dataset.from_pandas(df=df)
     tokenizer = AutoTokenizer.from_pretrained(args['tokenizer'])
     tokenizer.model_max_length = 5000000000  # Hack to prevent warnings from HuggingFace
+
+    if bos_text + eos_text == '':
+        test_tokens = tokenizer('test')
+        if test_tokens['input_ids'][
+                0] != tokenizer.bos_token_id and test_tokens['input_ids'][
+                    -1] != tokenizer.eos_token_id:
+            tok_error_msg = 'This tokenizer does not insert an EOS nor BOS token. '
+            tok_error_msg += 'Concatenating with this tokenizer will result in sequences being '
+            tok_error_msg += 'attached without a separating token. Please use another tokenizer, '
+            tok_error_msg += 'such as facebook/opt-125m, or specify EOS/BOS text with e.g. '
+            tok_error_msg += '--bos_text=<|endoftext|>.'
+            raise ValueError(tok_error_msg)
+
     dataset = ConcatTokensDataset(
         hf_dataset=hf_dataset,
         max_length=args.get('concat_tokens', None),
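
The new guard can be reproduced standalone. A quick sketch, using facebook/opt-125m since that is the example the error message itself suggests:

from transformers import AutoTokenizer

# Probe whether a tokenizer inserts special tokens around a dummy string,
# mirroring the guard above outside of pandas_processing_fn.
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-125m')
ids = tokenizer('test')['input_ids']

adds_bos = ids[0] == tokenizer.bos_token_id
adds_eos = ids[-1] == tokenizer.eos_token_id
print(adds_bos, adds_eos)  # opt-125m prepends its BOS token, so the guard passes

# A tokenizer that adds neither would silently glue concatenated samples
# together with no separator, which is what the ValueError protects against.
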
@@ -893,15 +919,16 @@ def pandas_processing_fn(df: pd.DataFrame,
 except ImportError as e:
     e.msg = get_import_exception_message(e.name,
                                          extra_deps='spark')  # pyright: ignore
-    raise e
+    #raise e

 try:
     from dask.dataframe import DataFrame as DaskDataFrame
     from dask.distributed import Client, LocalCluster
 except ImportError as e:
     e.msg = get_import_exception_message(e.name,
                                          extra_deps='dask')  # pyright: ignore
-    raise e
+    #raise e
+    DaskDataFrame = None

 try:
     from streaming import MDSWriter

 except ImportError as e:
     e.msg = get_import_exception_message(
         e.name, extra_deps='streaming')  # pyright: ignore
-    raise e
+    #raise e

 logger = logging.getLogger(__name__)

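Commenting out `raise e` and falling back to `DaskDataFrame = None` turns dask into a soft dependency: importing the module no longer hard-fails when dask is absent. A generic sketch of that optional-dependency pattern (the helper name is illustrative):

import logging

logger = logging.getLogger(__name__)

try:
    from dask.dataframe import DataFrame as DaskDataFrame
except ImportError:
    # dask is optional here; importing this module should not hard-fail.
    DaskDataFrame = None


def require_dask() -> None:
    # Hypothetical use-time check: fail only when a dask code path is hit,
    # so spark-only users keep a working import path.
    if DaskDataFrame is None:
        raise ImportError('dask is required for this code path; '
                          'install it with `pip install dask[distributed]`.')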