-
Notifications
You must be signed in to change notification settings - Fork 75
/
tiktoken.go
128 lines (110 loc) · 3.26 KB
/
tiktoken.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
package tiktoken
import (
"fmt"
"regexp"
"strings"
"github.com/dlclark/regexp2"
)
var (
	// bpeLoader is the package-level BpeLoader, initialised from
	// NewDefaultBpeLoader() at package load time.
	// NOTE(review): presumably consulted when encoding data is loaded
	// (getEncoding is not visible here) — confirm.
	bpeLoader BpeLoader = NewDefaultBpeLoader()
)

// SetBpeLoader replaces the package-level BpeLoader with loader.
//
// The replacement is a plain, unsynchronized assignment to a package
// variable, so it should happen before any concurrent use of the package.
func SetBpeLoader(loader BpeLoader) {
	bpeLoader = loader
}
// GetEncoding builds a ready-to-use *Tiktoken for the named encoding.
//
// It resolves the encoding definition via getEncoding and constructs a
// CoreBPE from the definition's MergeableRanks, SpecialTokens and PatStr
// fields. It also records every special-token string in a set, so that
// Encode can expand the "all" shorthand. Errors from either step are
// returned unchanged and the *Tiktoken is nil in that case.
func GetEncoding(encodingName string) (*Tiktoken, error) {
	enc, err := getEncoding(encodingName)
	if err != nil {
		return nil, err
	}
	pbe, err := NewCoreBPE(enc.MergeableRanks, enc.SpecialTokens, enc.PatStr)
	if err != nil {
		return nil, err
	}
	// Sized up front: the set holds exactly one entry per special token.
	specialTokensSet := make(map[string]any, len(enc.SpecialTokens))
	for k := range enc.SpecialTokens {
		specialTokensSet[k] = true
	}
	return NewTiktoken(pbe, enc, specialTokensSet), nil
}
// EncodingForModel returns the tokenizer used by the named model.
//
// An exact entry in MODEL_TO_ENCODING takes precedence. Otherwise the
// longest key of MODEL_PREFIX_TO_ENCODING that is a prefix of modelName
// selects the encoding. Choosing the longest match is deliberate: Go map
// iteration order is randomized, so returning the first match found would
// pick an arbitrary encoding whenever several prefixes overlap.
// An error is returned when no exact name or prefix matches.
func EncodingForModel(modelName string) (*Tiktoken, error) {
	if encodingName, ok := MODEL_TO_ENCODING[modelName]; ok {
		return GetEncoding(encodingName)
	}

	// Two distinct prefixes of the same length cannot both prefix modelName,
	// so the longest match is unique and the result deterministic.
	matched, matchedLen := "", -1
	for prefix, encodingName := range MODEL_PREFIX_TO_ENCODING {
		if len(prefix) > matchedLen && strings.HasPrefix(modelName, prefix) {
			matched, matchedLen = encodingName, len(prefix)
		}
	}
	if matchedLen >= 0 {
		return GetEncoding(matched)
	}

	return nil, fmt.Errorf("no encoding for model %s", modelName)
}
// Tiktoken is a tokenizer bound to a single encoding. Instances are
// produced by GetEncoding (and therefore EncodingForModel) or built
// directly with NewTiktoken.
type Tiktoken struct {
// bpe is the CoreBPE that Encode, EncodeOrdinary and Decode delegate to.
bpe *CoreBPE
// pbeEncoding is the encoding definition this tokenizer was built from.
// NOTE(review): the name looks like a transposition of "bpe"; not renamed
// because other files in the package may reference it.
pbeEncoding *Encoding
// specialTokensSet has every special-token string of the encoding as a
// key. Within this file only key presence is inspected.
specialTokensSet map[string]any
}
// Encode tokenizes text, treating the special tokens named in
// allowedSpecial as single tokens.
//
// allowedSpecial is either a list of special-token strings (an empty list
// allows none) or exactly []string{"all"}, which allows every special token
// of this encoding.
//
// disallowedSpecial lists special-token strings that must not occur in
// text. Exactly []string{"all"} means "every special token that is not
// allowed", and an empty list disables the check.
//
// Encode panics if text contains a disallowed special token.
func (t *Tiktoken) Encode(text string, allowedSpecial []string, disallowedSpecial []string) []int {
	// Default to the full special-token set; replace it unless the caller
	// passed the single "all" sentinel.
	allowedSpecialSet := t.specialTokensSet
	if !(len(allowedSpecial) == 1 && allowedSpecial[0] == "all") {
		allowedSpecialSet = make(map[string]any, len(allowedSpecial))
		for _, tok := range allowedSpecial {
			allowedSpecialSet[tok] = nil
		}
	}

	var disallowedSpecialSet map[string]any
	if len(disallowedSpecial) == 1 && disallowedSpecial[0] == "all" {
		disallowedSpecialSet = difference(t.specialTokensSet, allowedSpecialSet)
	} else {
		disallowedSpecialSet = make(map[string]any, len(disallowedSpecial))
		for _, tok := range disallowedSpecial {
			disallowedSpecialSet[tok] = nil
		}
	}

	// Skip the scan entirely when nothing is disallowed; an empty pattern
	// would otherwise be compiled for no reason.
	if len(disallowedSpecialSet) != 0 {
		if hit := findRegex2StringMatch(text, t.SpecialTokenRegex(disallowedSpecialSet)); hit != "" {
			panic(fmt.Sprintf("text contains disallowed special token %s", hit))
		}
	}

	// Only the token slice is used; encodeNative's second result is discarded.
	tokens, _ := t.bpe.encodeNative(text, allowedSpecialSet)
	return tokens
}
// EncodeOrdinary tokenizes text with no special-token handling, delegating
// directly to the underlying CoreBPE.
func (t *Tiktoken) EncodeOrdinary(text string) []int {
	tokens := t.bpe.encodeOrdinaryNative(text)
	return tokens
}
// Decode turns token IDs back into text. The underlying CoreBPE produces
// the decoded data, which is then converted to a string.
// NOTE(review): decodeNative presumably returns raw bytes, in which case
// invalid UTF-8 passes through unchanged — confirm.
func (t *Tiktoken) Decode(tokens []int) string {
	raw := t.bpe.decodeNative(tokens)
	return string(raw)
}
// SpecialTokenRegex compiles a regexp2 pattern that matches any key of
// disallowedSpecialSet literally.
//
// Each key is escaped with regexp.QuoteMeta and the results are joined as
// alternatives. The alternative order follows map iteration order, which
// Go does not specify. An empty set yields an empty pattern, which matches
// the empty string. The function panics (via regexp2.MustCompile) if the
// pattern cannot be compiled.
func (t *Tiktoken) SpecialTokenRegex(disallowedSpecialSet map[string]any) *regexp2.Regexp {
	var pattern strings.Builder
	first := true
	for tok := range disallowedSpecialSet {
		if !first {
			pattern.WriteByte('|')
		}
		first = false
		pattern.WriteString(regexp.QuoteMeta(tok))
	}
	return regexp2.MustCompile(pattern.String(), regexp2.None)
}
// findRegex2StringMatch returns the text of the first match of reg in
// text, or "" when there is none.
//
// The error from FindStringMatch is discarded, so a failed search is also
// reported as "".
func findRegex2StringMatch(text string, reg *regexp2.Regexp) string {
	if match, _ := reg.FindStringMatch(text); match != nil {
		return match.String()
	}
	return ""
}
// difference returns a new set containing every key of setA that is not a
// key of setB (set subtraction A \ B).
//
// Values in the inputs are ignored, and every value in the result is true.
// The result is always a non-nil map, even when it is empty.
func difference(setA, setB map[string]any) map[string]any {
	out := map[string]any{}
	for key := range setA {
		if _, excluded := setB[key]; excluded {
			continue
		}
		out[key] = true
	}
	return out
}
// NewTiktoken can be used to create a *Tiktoken with custom parameters.
//
// The arguments are stored as-is; no copies are made. Later changes to
// specialTokensSet are therefore visible to the returned Tiktoken, which
// uses that map when Encode expands the "all" shorthand.
func NewTiktoken(bpe *CoreBPE, encoding *Encoding, specialTokensSet map[string]any) *Tiktoken {
	tk := new(Tiktoken)
	tk.bpe = bpe
	tk.pbeEncoding = encoding
	tk.specialTokensSet = specialTokensSet
	return tk
}