Commit

remove html tags
markus583 committed Dec 27, 2023
1 parent 250a449 commit 176e25e
Showing 13 changed files with 121 additions and 79 deletions.
3 changes: 2 additions & 1 deletion .gitignore
@@ -10,4 +10,5 @@ external
notebooks
final_outputs
.cache*
data_subset/**
data_subset/**
*.pth
9 changes: 8 additions & 1 deletion test.py
@@ -8,28 +8,35 @@ def test_split_ort():
splits = wtp.split("This is a test sentence This is another test sentence.", threshold=0.005)
assert splits == ["This is a test sentence ", "This is another test sentence."]


def test_split_torch():
wtp = WtP("benjamin/wtp-bert-mini", hub_prefix=None)

splits = wtp.split("This is a test sentence This is another test sentence.", threshold=0.005)
assert splits == ["This is a test sentence ", "This is another test sentence."]


def test_split_torch_canine():
wtp = WtP("benjamin/wtp-canine-s-1l", hub_prefix=None)

splits = wtp.split("This is a test sentence. This is another test sentence.", lang_code="en")
assert splits == ["This is a test sentence. ", "This is another test sentence."]


def test_move_device():
wtp = WtP("benjamin/wtp-bert-mini", hub_prefix=None)
wtp.half().to("cpu")


def test_strip_whitespace():
wtp = WtP("benjamin/wtp-bert-mini", hub_prefix=None)

splits = wtp.split("This is a test sentence This is another test sentence. ", strip_whitespace=True, threshold=0.005)
splits = wtp.split(
"This is a test sentence This is another test sentence. ", strip_whitespace=True, threshold=0.005
)
assert splits == ["This is a test sentence", "This is another test sentence."]


def test_split_long():
prefix = "x" * 2000

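These tests cover the public WtP API touched by the reformatting: loading a model by Hugging Face identifier, threshold-based splitting, and whitespace stripping. A minimal usage sketch along the same lines (model identifier and arguments taken from the tests above; the exact splits depend on the checkpoint):

from wtpsplit import WtP

# character-level model used throughout the tests
wtp = WtP("benjamin/wtp-bert-mini", hub_prefix=None)

# a low threshold makes splitting more aggressive, as in test_split_torch
splits = wtp.split(
    "This is a test sentence This is another test sentence.",
    threshold=0.005,
    strip_whitespace=True,  # drop trailing whitespace from each returned segment
)
print(splits)  # e.g. ["This is a test sentence", "This is another test sentence."]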
2 changes: 1 addition & 1 deletion tpu_starter.sh
@@ -1,4 +1,4 @@
for var in "$@"
do
until gcloud compute tpus tpu-vm create $var --zone=europe-west4-a --accelerator-type=v3-8 --version=tpu-vm-base; do sleep 5; done
until gcloud compute tpus tpu-vm create $var --zone=europe-west4-a --accelerator-type=v3-8 --version=tpu-vm-pt-1.13; do sleep 3; done
done
20 changes: 12 additions & 8 deletions utils/remove_unks.py
@@ -3,6 +3,7 @@
from tokenizers import AddedToken
from wtpsplit.utils import Constants, LabelArgs


def get_subword_label_dict(label_args, tokenizer):
label_dict = {}

@@ -36,30 +37,33 @@ def get_subword_label_dict(label_args, tokenizer):
label_dict = get_subword_label_dict(LabelArgs(), tokenizer)
print(len(label_dict))


def write_punctuation_file():
with open(os.path.join(Constants.ROOT_DIR, "punctuation_xlmr.txt"), 'w', encoding='utf-8') as file:
with open(os.path.join(Constants.ROOT_DIR, "punctuation_xlmr.txt"), "w", encoding="utf-8") as file:
for char in Constants.PUNCTUATION_CHARS:
token_id = tokenizer.convert_tokens_to_ids(char)
if token_id != tokenizer.unk_token_id:
file.write(char + '\n')

file.write(char + "\n")


def write_punctuation_file_unk():
added_unk = False
with open(os.path.join(Constants.ROOT_DIR, "punctuation_xlmr_unk.txt"), 'w', encoding='utf-8') as file:
with open(os.path.join(Constants.ROOT_DIR, "punctuation_xlmr_unk.txt"), "w", encoding="utf-8") as file:
for char in Constants.PUNCTUATION_CHARS:
token_id = tokenizer.convert_tokens_to_ids(char)
if token_id != tokenizer.unk_token_id:
file.write(char + '\n')
file.write(char + "\n")
elif not added_unk:
print("added unk")
file.write('<unk>\n')
file.write("<unk>\n")
added_unk = True


write_punctuation_file()
write_punctuation_file_unk()

label_args_default = LabelArgs()
print(Constants.PUNCTUATION_CHARS, len(Constants.PUNCTUATION_CHARS))

label_args_custom = LabelArgs(custom_punctuation_file='punctuation_xlmr.txt')
print(Constants.PUNCTUATION_CHARS, len(Constants.PUNCTUATION_CHARS))
label_args_custom = LabelArgs(custom_punctuation_file="punctuation_xlmr.txt")
print(Constants.PUNCTUATION_CHARS, len(Constants.PUNCTUATION_CHARS))
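The script keeps only those punctuation characters that the tokenizer maps to a real vocabulary entry and writes <unk> once for everything else. A condensed sketch of that filter, assuming the XLM-R tokenizer that the output file names (punctuation_xlmr*.txt) suggest is instantiated earlier in the file:

from transformers import AutoTokenizer

from wtpsplit.utils import Constants

tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

# keep punctuation characters with a real vocabulary id (i.e. not mapped to <unk>)
known = [
    char
    for char in Constants.PUNCTUATION_CHARS
    if tokenizer.convert_tokens_to_ids(char) != tokenizer.unk_token_id
]
print(len(Constants.PUNCTUATION_CHARS), "->", len(known))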
4 changes: 3 additions & 1 deletion wtpsplit/configs.py
@@ -37,13 +37,15 @@ def __init__(

self.num_hash_buckets = num_hash_buckets
self.num_hash_functions = num_hash_functions



class SubwordXLMConfig(XLMRobertaConfig):
"""Config for XLM-R and XLM-V models. Used for token-level training.
Args:
XLMRobertaConfig: Base class.
"""

model_type = "xlm-token"
mixture_name = "xlm-token"

4 changes: 2 additions & 2 deletions wtpsplit/evaluation/__init__.py
@@ -230,14 +230,14 @@ def ersatz_sentencize(
):
if lang_code not in ERSATZ_LANGUAGES:
raise LanguageError(f"ersatz does not support {lang_code}")

# check if infile parent dir exists, if not, create it
if not os.path.exists(os.path.dirname(infile)):
os.makedirs(os.path.dirname(infile))
# check if outfile parent dir exists, if not, create it
if not os.path.exists(os.path.dirname(outfile)):
os.makedirs(os.path.dirname(outfile))

open(infile, "w").write(text)

subprocess.check_output(
8 changes: 4 additions & 4 deletions wtpsplit/evaluation/intrinsic.py
@@ -10,7 +10,7 @@
from tqdm.auto import tqdm
from transformers import AutoModelForTokenClassification, HfArgumentParser

import wtpsplit.models # noqa: F401
import wtpsplit.models # noqa: F401
from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture
from wtpsplit.extract import PyTorchWrapper, extract
from wtpsplit.utils import Constants
@@ -27,13 +27,13 @@ class Args:
# "meta": {
# "train_data": ["train sentence 1", "train sentence 2"]
# },
# "data": ["test sentence 1", "test sentence 2"]
# "data": ["test sentence 1", "test sentence 2"]
# }
# }
# }
# }
eval_data_path: str = "data/eval.pth"
valid_text_path: str = None#"data/sentence/valid.parquet"
valid_text_path: str = None # "data/sentence/valid.parquet"
device: str = "cpu"
block_size: int = 512
stride: int = 64
@@ -131,7 +131,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
valid_data = load_dataset("parquet", data_files=args.valid_text_path, split="train")
else:
valid_data = None

model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(args.model_path).to(args.device))

# first, logits for everything.
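The otherwise unused "import wtpsplit.models  # noqa: F401" is kept for its side effect: importing the module runs the AutoModel / AutoModelForTokenClassification register() calls (see wtpsplit/models.py below), so the custom architectures can be loaded through the Auto* factories. A minimal sketch of that pattern; the checkpoint path is a placeholder:

import wtpsplit.models  # noqa: F401  (side effect: registers the custom configs/models)
from transformers import AutoModelForTokenClassification

from wtpsplit.extract import PyTorchWrapper

# hypothetical local checkpoint; intrinsic.py gets this path from its Args.model_path
model = PyTorchWrapper(
    AutoModelForTokenClassification.from_pretrained("path/to/checkpoint").to("cpu")
)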
26 changes: 13 additions & 13 deletions wtpsplit/extract.py
@@ -11,6 +11,7 @@

logger = logging.getLogger(__name__)


class ORTWrapper:
def __init__(self, config, ort_session):
self.config = config
@@ -24,7 +25,7 @@ def __call__(self, hashed_ids, attention_mask):
logits = self.ort_session.run(
["logits"],
{
"attention_mask": attention_mask.astype(np.float16), # ORT expects fp16 mask
"attention_mask": attention_mask.astype(np.float16), # ORT expects fp16 mask
"hashed_ids": hashed_ids,
},
)[0]
@@ -63,6 +64,7 @@ def __call__(self, input_ids, hashed_ids, attention_mask, language_ids=None):

return {"logits": logits}


def extract(
batch_of_texts,
model,
@@ -104,11 +106,11 @@ def extract(
# make sure block_size is a multiple of downsampling rate
downsampling_rate = getattr(model.config, "downsampling_rate", 1)
block_size = math.ceil(block_size / downsampling_rate) * downsampling_rate
actual_block_size = block_size - 2 if use_subwords else block_size # account for CLS and SEP tokens
actual_block_size = block_size - 2 if use_subwords else block_size # account for CLS and SEP tokens

# total number of forward passes
num_chunks = sum(math.ceil(max(length - actual_block_size, 0) / stride) + 1 for length in text_lengths)

# preallocate a buffer for all input hashes & attention masks
if not use_subwords:
input_hashes = np.zeros((num_chunks, block_size, model.config.num_hash_functions), dtype=np.int64)
@@ -121,18 +123,17 @@
locs = np.zeros((num_chunks, 3), dtype=np.int32)

if not use_subwords:
# this is equivalent to (but faster than) np.array([ord(c) for c in "".join(batch_of_texts)])
# this is equivalent to (but faster than) np.array([ord(c) for c in "".join(batch_of_texts)])
codec = "utf-32-le" if sys.byteorder == "little" else "utf-32-be"
ordinals = np.frombuffer(bytearray("".join(batch_of_texts), encoding=codec), dtype=np.int32)
# hash encode all ids
flat_hashed_ids = hash_encode(ordinals,
num_hashes=model.config.num_hash_functions,
num_buckets=model.config.num_hash_buckets)
flat_hashed_ids = hash_encode(
ordinals, num_hashes=model.config.num_hash_functions, num_buckets=model.config.num_hash_buckets
)
# note that ordinals and flat_hashed_ids have the same length
offset = 0
current_chunk = 0



# create chunks
for i in range(len(batch_of_texts)):
for j in range(0, text_lengths[i], stride):
@@ -150,9 +151,9 @@ def extract(
attention_mask[current_chunk, : end - start] = 1
else:
chunk = [cls_token_id] + batch_of_texts[i][start:end] + [sep_token_id]
input_ids[current_chunk, :len(chunk)] = chunk
attention_mask[current_chunk, :len(chunk)] = 1
input_ids[current_chunk, : len(chunk)] = chunk
attention_mask[current_chunk, : len(chunk)] = 1

locs[current_chunk, :] = [i, start, end]
current_chunk += 1

@@ -212,7 +213,6 @@ def extract(
# Pad with the specific pad_token_id for the tokenizer
batch_input_ids = np.pad(batch_input_ids, ((0, n_missing), (0, 0)), constant_values=pad_token_id)
batch_attention_mask = np.pad(batch_attention_mask, ((0, n_missing), (0, 0)))


kwargs = {"language_ids": language_ids[: len(batch_attention_mask)]} if uses_lang_adapters else {}

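The reformatted block above is the sliding-window chunking that determines how many forward passes a batch needs: block_size is rounded up to a multiple of the model's downsampling rate, two positions are reserved for CLS and SEP in the subword case, and each text contributes ceil(max(len - actual_block_size, 0) / stride) + 1 chunks. A small worked illustration of just that arithmetic (the text lengths are made up):

import math

block_size, stride, downsampling_rate, use_subwords = 512, 64, 1, True

# round block_size up to a multiple of the downsampling rate
block_size = math.ceil(block_size / downsampling_rate) * downsampling_rate
actual_block_size = block_size - 2 if use_subwords else block_size  # CLS and SEP tokens

text_lengths = [300, 512, 2000]  # hypothetical tokenized lengths
num_chunks = sum(
    math.ceil(max(length - actual_block_size, 0) / stride) + 1 for length in text_lengths
)
print(num_chunks)  # 1 + 2 + 25 = 28 forward-pass windows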
21 changes: 11 additions & 10 deletions wtpsplit/models.py
@@ -958,16 +958,17 @@ def forward(
return_dict,
)


class SubwordXLMForTokenClassification(XLMRobertaForTokenClassification):
config_class = SubwordXLMConfig

_keys_to_ignore_on_load_unexpected = [r"pooler"]
_keys_to_ignore_on_load_missing = [r"position_ids"]

def __init__(self, config):
super().__init__(config)
self.num_labels = config.num_labels

self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
classifier_dropout = (
config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
@@ -977,7 +978,7 @@ def __init__(self, config):

# Initialize weights and apply final processing
self.post_init()

def forward(
self,
input_ids: Optional[torch.LongTensor] = None,
@@ -1005,9 +1006,8 @@ def forward(
output_hidden_states,
return_dict,
)





AutoModel.register(LACanineConfig, LACanineModel)
AutoModelForTokenClassification.register(LACanineConfig, LACanineForTokenClassification)

@@ -1020,23 +1020,24 @@ def forward(
if __name__ == "__main__":
# test XLM
from transformers import AutoConfig, AutoTokenizer

model_str = "xlm-roberta-base"
config = AutoConfig.from_pretrained(model_str)
config.num_labels = 4
config.num_hidden_layers = 9
backbone = SubwordXLMForTokenClassification.from_pretrained(model_str, config=config)
print(summary(backbone, depth=4))

# some sample input
text = "This is a test\n sentence \n\n"
tokenizer = AutoTokenizer.from_pretrained(model_str)

tokens = tokenizer(text, return_tensors="pt", add_special_tokens=False)
from tokenizers import AddedToken

tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
print(tokenizer.tokenize(text))
print(tokenizer.encode(text))
print(tokens)
# forward pass
print(backbone(**tokens))

10 changes: 6 additions & 4 deletions wtpsplit/train/evaluate.py
@@ -48,6 +48,7 @@ def get_metrics(labels, preds):

return metrics, info


def get_token_spans(tokenizer: object, offsets_mapping: list, tokens: list):
token_spans = []
for idx, token in enumerate(tokens):
@@ -62,17 +63,19 @@ def get_token_spans(tokenizer: object, offsets_mapping: list, tokens: list):

return token_spans


def token_to_char_probs(text: str, tokens: list, token_probs: np.ndarray, tokenizer, offsets_mapping):
# some very low number since at non-ending position, predicting a newline is impossible
char_probs = np.zeros(len(text)) - 10000
token_spans = get_token_spans(tokenizer, offsets_mapping, tokens)

for i, ((start, end), prob, token) in enumerate(zip(token_spans, token_probs, tokens)):
# assign the token's prob to the last char of the token
char_probs[end - 1] = prob
char_probs[end - 1] = prob

return char_probs


def evaluate_sentence(
lang_code,
sentences,
@@ -104,11 +107,11 @@ def evaluate_sentence(
logits = logits[0]
if offsets_mapping is not None:
offsets_mapping = offsets_mapping[0]

true_end_indices = np.cumsum(np.array([len(s) for s in sentences])) + np.arange(len(sentences)) * len(separator)
newline_labels = np.zeros(len(text))
newline_labels[true_end_indices - 1] = 1

if "xlm" in model.config.model_type:
tokens = tokenizer.tokenize(text, verbose=False)
char_probs = token_to_char_probs(text, tokens, logits[:, positive_index], tokenizer, offsets_mapping)
@@ -127,4 +130,3 @@
info["newline_probs_pysbd"] = newline_probs_pysbd

return metrics["pr_auc"], info
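token_to_char_probs converts token-level newline probabilities into character-level ones: every character starts at a very low score, and each token's probability is written to its last character according to the offset mapping. A toy illustration of that assignment (spans and probabilities are invented):

import numpy as np

text = "Hi there"
token_spans = [(0, 2), (3, 8)]      # hypothetical (start, end) character offsets
token_probs = np.array([0.9, 0.1])  # hypothetical per-token newline probabilities

char_probs = np.zeros(len(text)) - 10000  # "impossible" everywhere by default
for (start, end), prob in zip(token_spans, token_probs):
    char_probs[end - 1] = prob  # only a token's last character carries its probability

print(char_probs)  # [-10000., 0.9, -10000., -10000., -10000., -10000., -10000., 0.1]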
