-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtoy.py
85 lines (73 loc) · 2.27 KB
/
toy.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
'''
Toy dataset and tokenizer for quick prototyping.
ds = ToyDataset()
tn = ToyTokenizer()
print(ds[30])
print(tn.tokenize(ds[30]))
print(tn.detokenize(tn.tokenize(ds[30])))
'''
import torch
class ToyTokenizer:
    """Whitespace tokenizer for toy addition strings such as ``'3 + 0 = 3'``.

    Vocabulary: ``0`` is EOS, ``1`` is ``'+'``, ``2`` is ``'='`` and a
    non-negative integer ``n`` is encoded as ``3 + n``.
    """

    def __init__(self):
        self.eos = 0
        self.token_add = 1
        self.token_equal = 2
        self.zero = 3  # token id of the number 0; the number n maps to zero + n

    def tokenize(self, text):
        """Convert a whitespace-separated expression into a list of token ids.

        Args:
            text: Expression such as ``'3 + 0 = 3'``.

        Returns:
            A list of ``int`` token ids, one per symbol in ``text``.

        Raises:
            ValueError: If a symbol is not ``'+'``, ``'='`` or a non-negative
                integer.
        """
        symbols = {'+': self.token_add, '=': self.token_equal}
        result = []
        # split() without arguments already drops the empty pieces produced
        # by repeated or leading/trailing whitespace.
        for piece in text.split():
            token = symbols.get(piece)
            if token is None:
                num = int(piece)
                if num < 0:
                    # zero + num would collide with the eos/'+'/'=' ids.
                    raise ValueError(
                        f'negative numbers are not supported: {piece!r}'
                    )
                token = self.zero + num
            result.append(token)
        return result

    def token2char(self, token):
        """Convert a single token id back to its text symbol.

        Args:
            token: Token id, either a plain ``int`` or a single-element
                tensor.

        Returns:
            ``'+'``, ``'='``, ``''`` for EOS, or the decimal string of the
            encoded number.
        """
        # int() accepts both plain ints (as returned by tokenize) and 0-d
        # tensors, whereas .item() only exists on tensors.
        token = int(token)
        if token == self.token_add:
            return '+'
        if token == self.token_equal:
            return '='
        if token == self.eos:
            return ''
        return str(token - self.zero)

    def detokenize(self, tokens):
        """Convert an iterable of token ids back to a space-separated string.

        EOS tokens become empty strings, so every EOS in ``tokens`` still
        contributes a separating space to the result.
        """
        return ' '.join(self.token2char(token) for token in tokens)
class ToyDataset(torch.utils.data.Dataset):
    """Dataset of the 100 single-digit additions ``'x + y = z'``.

    Sample ``i`` (for ``0 <= i < 100``) uses ``x = i // 10`` and
    ``y = i % 10``. The 100 samples are repeated ``n_epochs`` times.

    Args:
        transform: Optional callable applied to each sample string; its
            return value is what ``__getitem__`` returns.
        n_epochs: Number of times the 100 samples are repeated.
    """

    def __init__(self, transform=None, n_epochs=1):
        super().__init__()
        self.transform = transform
        # Repeating the samples gives better training performance by
        # avoiding setting up a new pipeline for every epoch.
        self.repeat = n_epochs
        self.num_samples = 100

    def __len__(self):
        """Return the total number of samples, repetitions included."""
        return self.num_samples * self.repeat

    def __getitem__(self, idx):
        """Return the sample at ``idx``, passed through ``transform`` if set.

        Negative indices count from the end.

        Raises:
            IndexError: If ``idx`` is outside the dataset. This is what lets
                plain iteration over the dataset terminate.
        """
        size = self.num_samples * self.repeat
        if idx < 0:
            idx = idx + size
        if not 0 <= idx < size:
            raise IndexError(
                f'index out of range for ToyDataset of length {size}'
            )
        # Repeated copies map back onto the same base samples.
        x, y = divmod(idx % self.num_samples, 10)
        z = x + y
        result = f'{x} + {y} = {z}'
        if self.transform:
            result = self.transform(result)
        return result
class TokenizerTransform:
    """Callable that turns an expression string into a tensor of token ids.

    Sequences shorter than ``max_seq`` are right-padded with the tokenizer's
    EOS id up to ``max_seq``; longer sequences are returned at their own
    length, without padding or truncation.
    """

    def __init__(self, max_seq=10):
        self.max_seq = max_seq
        self.tokenizer = ToyTokenizer()

    def __call__(self, text):
        """Tokenize ``text`` and return the ids as a ``torch.long`` tensor."""
        ids = self.tokenizer.tokenize(text)
        # Multiplying a list by a non-positive count yields an empty list, so
        # no padding is added once the sequence reaches max_seq.
        pad = [self.tokenizer.eos] * (self.max_seq - len(ids))
        # OK to use int in MPS, but CUDA requires long.
        return torch.tensor(ids + pad, dtype=torch.long)