text_perplexity_score.py
"""
Copyright 2024 Intel Corporation
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from .base import BaseLLMOperation, statistics_decorator
from ray.data import Dataset
from pyspark.sql import DataFrame
from pyrecdp.core.model_utils import get_model, prepare_model
from pyrecdp.primitives.operations.base import LLMOPERATORS
from pyrecdp.primitives.operations.utils import get_words_from_document


def text_bytesize(s):
    return len(s.encode('utf-8'))


class TextPerplexityScore(BaseLLMOperation):
    def __init__(self, text_key: str = 'text', language: str = 'en'):
        """
        Generate a perplexity score for each sample.

        :param text_key: Name of the field that holds the text. Default: 'text'.
        :param language: Language of the samples, used to select the tokenizer
            and language model. Default: 'en'. Supported: 'en', 'zh'.
        """
        settings = {'language': language, 'text_key': text_key}
        requirements = []
        super().__init__(settings, requirements)
        self.language = language
        self.text_key = text_key
        self.inplace = False
        # Prepare and load the SentencePiece tokenizer and the KenLM language
        # model for the requested language.
        self.sp_model_key = prepare_model(lang=language, model_type='sentencepiece')
        self.kl_model_key = prepare_model(lang=language, model_type='kenlm')
        self.tokenizer = get_model(self.sp_model_key, self.language, 'sentencepiece')
        self.kenlm_model = get_model(self.kl_model_key, self.language, 'kenlm')

    @statistics_decorator
    def process_rayds(self, ds: Dataset) -> Dataset:
        if self.inplace:
            raise NotImplementedError("In-place modification of the text is not supported for perplexity scoring")
        new_name = 'perplexity'
        compute_func = self.get_compute_func()
        ret = ds.map(lambda x: self.process_row(x, self.text_key, new_name, compute_func))
        if self.statistics_flag:
            self.statistics.max = ret.max(new_name)
            self.statistics.min = ret.min(new_name)
            self.statistics.mean = ret.mean(new_name)
            self.statistics.std = ret.std(new_name)
        else:
            self.statistics.max, self.statistics.min, self.statistics.mean, self.statistics.std = 0, 0, 0, 0
        return ret

    @statistics_decorator
    def process_spark(self, spark, spark_df: DataFrame) -> DataFrame:
        import pyspark.sql.functions as F
        from pyspark.sql import types as T

        # Register the perplexity computation as a UDF and append the result
        # as a new 'perplexity' column.
        perplexity_udf = F.udf(self.get_compute_func(), T.FloatType())
        ret = spark_df.withColumn("perplexity", perplexity_udf(F.col(self.text_key)))
        if self.statistics_flag:
            self.statistics.max = ret.select(F.max("perplexity")).collect()[0][0]
            self.statistics.min = ret.select(F.min("perplexity")).collect()[0][0]
            self.statistics.mean = ret.select(F.mean("perplexity")).collect()[0][0]
            self.statistics.std = ret.select(F.stddev("perplexity")).collect()[0][0]
        else:
            self.statistics.max, self.statistics.min, self.statistics.mean, self.statistics.std = 0, 0, 0, 0
        return ret

    def get_compute_func(self, *args, **kwargs):
        tokenizer = self.tokenizer
        kenlm_model = self.kenlm_model

        def compute(text):
            # Tokenize the text into SentencePiece pieces so it matches the
            # vocabulary the KenLM model was trained on.
            words = get_words_from_document(
                text,
                token_func=tokenizer.encode_as_pieces if tokenizer else None)
            join_text = ' '.join(words)
            # Compute perplexity: KenLM's score() returns a log10 probability,
            # and the +1 on the length accounts for the end-of-sentence token,
            # so ppl = 10 ** (-log10_prob / token_count).
            logits, length = 0, 0
            for line in join_text.splitlines():
                logits += kenlm_model.score(line)
                length += (len(line.split()) + 1)
            ppl = (10.0 ** (-logits / length)) if length != 0 else 0.0
            perplexity = round(ppl, 1)
            return perplexity

        return compute

    def summarize(self) -> str:
        statistics_save = {
            "min": self.statistics.min,
            "max": self.statistics.max,
            "mean": self.statistics.mean,
            "std": self.statistics.std,
        }
        return (statistics_save,
                f"A total of {self.statistics.total_in} rows of data were processed in {self.statistics.used_time} seconds. "
                f"Max perplexity: {self.statistics.max}, "
                f"min perplexity: {self.statistics.min}, "
                f"mean perplexity: {self.statistics.mean}, "
                f"std of perplexity: {self.statistics.std}")


LLMOPERATORS.register(TextPerplexityScore)
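
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal illustration of how
# TextPerplexityScore might be applied to a small in-memory Ray dataset. It
# assumes a local Ray runtime is available and that prepare_model() can fetch
# the SentencePiece and KenLM models for the requested language; the sample
# rows and the import path shown here are hypothetical.
#
#   import ray
#   from pyrecdp.primitives.operations.text_perplexity_score import TextPerplexityScore
#
#   ray.init(ignore_reinit_error=True)
#   ds = ray.data.from_items([
#       {"text": "The quick brown fox jumps over the lazy dog."},
#       {"text": "asdf qwer zxcv uiop"},
#   ])
#   op = TextPerplexityScore(text_key="text", language="en")
#   scored = op.process_rayds(ds)   # adds a 'perplexity' column to each row
#   scored.show()
# ---------------------------------------------------------------------------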