diff --git a/makemore.py b/makemore.py
index db0fe4c..446b116 100644
--- a/makemore.py
+++ b/makemore.py
@@ -542,7 +542,9 @@ def __getitem__(self, idx):
         y[len(ix)+1:] = -1 # index -1 will mask the loss at the inactive locations
         return x, y
 
-def create_datasets(input_file):
+def create_datasets(input_file, max_word_length=None):
+    # max_word_length: optional int override for the word length the datasets are
+    # built with; None (the default) infers it from the longest word in input_file
 
     # preprocessing of the input text file
     with open(input_file, 'r') as f:
@@ -551,7 +553,12 @@ def create_datasets(input_file):
     words = [w.strip() for w in words] # get rid of any leading or trailing white space
     words = [w for w in words if w] # get rid of any empty strings
     chars = sorted(list(set(''.join(words)))) # all the possible characters
-    max_word_length = max(len(w) for w in words)
+    longest_word = max(len(w) for w in words)
+    if max_word_length is None:
+        max_word_length = longest_word # not given: infer it from the data
+    elif max_word_length < longest_word:
+        # reject a too-small cap (this also covers 0 and negatives) here, with a clear message
+        raise ValueError(f"max_word_length={max_word_length} is shorter than the longest word in the dataset ({longest_word})")
     print(f"number of examples in the dataset: {len(words)}")
     print(f"max word length: {max_word_length}")
     print(f"number of unique characters in the vocabulary: {len(chars)}")
@@ -604,6 +611,7 @@ def next(self):
     parser.add_argument('--max-steps', type=int, default=-1, help="max number of optimization steps to run for, or -1 for infinite.")
     parser.add_argument('--device', type=str, default='cpu', help="device to use for compute, examples: cpu|cuda|cuda:2|mps")
     parser.add_argument('--seed', type=int, default=3407, help="seed")
+    parser.add_argument('--max-word-length', type=int, default=None, help="max word length to use for the dataset, must be >= the longest word; default infers it from the data")
     # sampling
     parser.add_argument('--top-k', type=int, default=-1, help="top-k for sampling, -1 means no top-k")
     # model
@@ -626,7 +634,7 @@ def next(self):
     writer = SummaryWriter(log_dir=args.work_dir)
 
     # init datasets
-    train_dataset, test_dataset = create_datasets(args.input_file)
+    train_dataset, test_dataset = create_datasets(args.input_file, args.max_word_length)
     vocab_size = train_dataset.get_vocab_size()
     block_size = train_dataset.get_output_length()
     print(f"dataset determined that: {vocab_size=}, {block_size=}")