Mode AdvancedSpace is working, need to see at training time now
PonteIneptique committed Apr 12, 2022
1 parent ec9904c commit 72e8cd9
Showing 6 changed files with 321 additions and 139 deletions.
21 changes: 16 additions & 5 deletions boudams/cli.py
@@ -13,9 +13,11 @@

from boudams.tagger import BoudamsTagger, OptimizerParams
from boudams.trainer import Trainer, logger, ACCEPTABLE_MONITOR_METRICS
-from boudams.encoder import LabelEncoder, SimpleSpaceMode
+from boudams.encoder import LabelEncoder
+from boudams.modes import SimpleSpaceMode, AdvancedSpaceMode
from boudams.dataset import BoudamsDataset
from boudams.data_generation import base as dataset_base, plaintext, splitter as all_splitters
+from boudams.utils import parse_params


@click.group()
@@ -28,26 +30,33 @@ def dataset():
""" Dataset related functions """



def _get_mode(mode: str, mode_kwargs: str = "") -> SimpleSpaceMode:
if mode == "simple-space":
return SimpleSpaceMode()
elif mode == "advanced-space":
return AdvancedSpaceMode()


@dataset.command("convert")
@click.argument("splitter", type=click.Choice(['words', 'sentence']))
@click.argument("input_path", nargs=-1, type=click.Path(file_okay=True, dir_okay=False))
@click.argument("output_path", type=click.Path(file_okay=False))
@click.option("--mode", type=click.Choice(['simple-space']),
@click.option("--mode", type=click.Choice(['simple-space', 'advanced-space']),
default="simple-space", show_default=True,
help="Type of encoder you want to set-up")
@click.option("--splitter-regex", type=str, default=None, show_default=True,
help="Regular expression for some splitter")
@click.option("--min-chars", type=int, default=2, show_default=True,
help="Discard samples smaller than min-chars")
@click.option("--min_words", type=int, default=2, show_default=True,
help="Minimum of words to build a line [Word splitter only]")
@click.option("--max_words", type=int, default=10, show_default=True,
help="Maximum number of words to build a line [Word splitter only]")
-def convert(output_path, input_path, mode, splitter, splitter_regex, min_words, max_words):
+@click.option("--mode-ratios", type=str, default="", show_default=True,
+              help="Token ratios for modes at mask generation, e.g. `keep-space=.3&fake-space=.01` "
+                   "gives a 30% chance of keeping a space and a 1% chance of generating a fake "
+                   "space after each character")
+def convert(output_path, input_path, mode, splitter, splitter_regex, min_words, max_words, min_chars,
+            mode_ratios):
""" Build sequence training data using files with [METHOD] format in [INPUT_PATH] and saving the
converted format into [OUTPUT_PATH]
@@ -65,7 +74,9 @@ def convert(output_path, input_path, mode, splitter, splitter_regex, min_words,
)
plaintext.convert(
input_path, output_path,
-        splitter=splitter, mode=_get_mode(mode=mode)
+        splitter=splitter, mode=_get_mode(mode=mode),
+        min_chars=min_chars,
+        token_ratio=parse_params(mode_ratios)
)
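
A minimal sketch of how the pieces above combine (illustrative only, not part of
the commit; `_get_mode` is a private helper of cli.py and the values mirror the
defaults shown above):

    from boudams.cli import _get_mode
    from boudams.utils import parse_params

    mode = _get_mode("simple-space")                       # returns SimpleSpaceMode()
    ratios = parse_params("keep-space=.3&fake-space=.01")  # {'keep-space': 0.3, 'fake-space': 0.01}
    sample, mask = mode.generate_mask("un cheval", tokens_ratio=ratios)
    # SimpleSpaceMode ignores tokens_ratio; AdvancedSpaceMode samples from it.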


8 changes: 5 additions & 3 deletions boudams/data_generation/plaintext.py
@@ -6,8 +6,7 @@

from typing import Iterable, Union, Dict


-from boudams.encoder import SimpleSpaceMode
+from boudams.modes import SimpleSpaceMode
from boudams.data_generation.splitter import Splitter


@@ -20,6 +19,7 @@ def convert(
splitter: Splitter,
token_ratio: Dict[str, float] = None,
mode: SimpleSpaceMode = None,
+        min_chars: int = 5,
**kwargs
):
""" Build sequence to train data over using TSV or TAB files where either the first
@@ -55,7 +55,9 @@ def convert(
line = _SPACES.sub(" ", line)
for sequence in splitter.split(line.strip()):
if sequence.strip():
output_fio.write("\t".join(mode.generate_mask(sequence, tokens_ratio=token_ratio))+"\n")
sample, mask = mode.generate_mask(sequence, tokens_ratio=token_ratio)
if len(sample) >= min_chars:
output_fio.write("\t".join([sample, mask])+"\n")


if __name__ == "__main__":
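
Each line that convert() writes above is "sample<TAB>mask". With the new
SimpleSpaceMode tokens, a sequence such as "j'ai un cheval" would be stored as
the following line (a sketch, derived from the mode's doctests):

    "j'aiuncheval\t---|-|-----|\n"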
2 changes: 1 addition & 1 deletion boudams/data_generation/splitter.py
@@ -44,7 +44,7 @@ def split(self, text: str) -> Iterable[str]:


class SentenceSplitter(Splitter):
def __init__(self, splitter: re.Pattern = r"(([\.\;!\?\"]+)"):
def __init__(self, splitter: re.Pattern = r"([\.\;!\?\"]+)"):
        self.splitter: re.Pattern = re.compile(splitter)

def split(self, text: str) -> Iterable[str]:
128 changes: 2 additions & 126 deletions boudams/encoder.py
@@ -1,6 +1,3 @@
-import re
-
-import tabulate
import torch
import torch.cuda
import torch.nn
@@ -13,132 +10,11 @@

from mufidecode import mufidecode

+from boudams.modes import SimpleSpaceMode

DEFAULT_INIT_TOKEN = "<SOS>"
DEFAULT_EOS_TOKEN = "<EOS>"
DEFAULT_PAD_TOKEN = "垫"
DEFAULT_UNK_TOKEN = "<UNK>"
DEFAULT_MASK_TOKEN = "x"
DEFAULT_WB_TOKEN = "|"


class SimpleSpaceMode:
NormalizeSpace: bool = True

class MaskValueException(Exception):
""" Exception raised when a token is longer than a character """

class MaskGenerationError(Exception):
""" Exception raised when a mask is not of the same size as the input transformed string """

def __init__(self, masks: Dict[str, int] = None):
self.name = "Default"
self.masks_to_index: Dict[str, int] = masks or {
DEFAULT_PAD_TOKEN: 0,
DEFAULT_MASK_TOKEN: 1,
DEFAULT_WB_TOKEN: 2
}
self.index_to_mask: Dict[str, int] = masks or {
0: DEFAULT_PAD_TOKEN,
1: DEFAULT_MASK_TOKEN,
2: DEFAULT_WB_TOKEN
}
self.index_to_masks_name: Dict[int, str] = {
0: "PAD",
1: "W",
2: "WB"
}
self.masks_name_to_index: Dict[str, int] = {
"PAD": 0,
"W": 1,
"WB": 2
}
self.pad_token = DEFAULT_PAD_TOKEN
self._pad_token_index = self.masks_to_index[self.pad_token]
self._space = re.compile(r"\s")

self._check()

def _check(self):
for char in self.masks_to_index:
if char != self.pad_token: # We do not limit <PAD> to a single char because it's not dumped in a string
if len(char) != 1:
raise SimpleSpaceMode.MaskValueException(
f"Mask characters cannot be longer than one char "
f"(Found: `{char}` "
f"for {self.index_to_masks_name[self.masks_to_index[char]]})")

@property
def pad_token_index(self) -> int:
return self._pad_token_index

@property
def classes_count(self):
return len(self.masks_to_index)

def generate_mask(
self,
string: str,
tokens_ratio: Optional[Dict[str, float]] = None
) -> Tuple[str, str]:
""" From a well-formed ground truth input, generates a fake error-containing string
:param string: Input string
:param tokens_ratio: Dict of TokenName
:return:
>>> (SimpleSpaceMode()).generate_mask("j'ai un cheval")
        ("j'aiuncheval", 'xxx|x|xxxxx|')
"""
split = self._space.split(string)
masks = DEFAULT_WB_TOKEN.join([DEFAULT_MASK_TOKEN * (len(tok)-1) for tok in split]) + DEFAULT_WB_TOKEN
model_input = "".join(split)
assert len(masks) == len(model_input), f"Length of input and mask should be equal `{masks}` + `{model_input}`"
return model_input, masks

def encode_mask(self, masked_string: Sequence[str]) -> List[int]:
""" Encodes into a list of index a string
:param masked_string: String masked by the current Mode
:return: Pre-tensor input
>>> (SimpleSpaceMode()).encode_mask("xxx|x|xxxxx|")
[1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2]
"""
return [self.masks_to_index[char] for char in masked_string]

def apply_mask_to_string(self, input_string: str, masked_string: List[int]) -> str:
def apply():
for char, mask in zip(input_string, masked_string):
if mask == self.pad_token_index:
break
if self.index_to_masks_name[mask] == "WB":
yield char + " "
else:
yield char
return "".join(apply())

def prepare_input(self, string: str) -> str:
return self._space.sub("", string)

def computer_wer(self, confusion_matrix):
indexes = torch.tensor([
i
for i in range(self.classes_count)
if i != self.pad_token_index
]).type_as(confusion_matrix)

clean_matrix = confusion_matrix[indexes][:, indexes]
space_token_index = self.masks_to_index[DEFAULT_WB_TOKEN]
if space_token_index > self.pad_token_index:
space_token_index -= 1
nb_space_gt = (
clean_matrix[space_token_index].sum() +
clean_matrix[:, space_token_index].sum() -
clean_matrix[space_token_index, space_token_index]
)

nb_missed_space = clean_matrix.sum() - torch.diagonal(clean_matrix, 0).sum()
return nb_missed_space / nb_space_gt


class LabelEncoder:
288 changes: 288 additions & 0 deletions boudams/modes.py
@@ -0,0 +1,288 @@
import random
import re
from typing import Dict, Optional, Tuple, Sequence, List

import torch
from boudams.utils import parse_params

DEFAULT_PAD_TOKEN = "垫"
DEFAULT_MASK_TOKEN = "-"
DEFAULT_WB_TOKEN = "|"
DEFAULT_REMOVE_TOKEN = "⌫"
DEFAULT_ORIGINAL_TOKEN = ""


class MaskValueException(Exception):
""" Exception raised when a token is longer than a character """


class MaskGenerationError(Exception):
""" Exception raised when a mask is not of the same size as the input transformed string """


class SimpleSpaceMode:
NormalizeSpace: bool = True

def __init__(self, masks: Dict[str, int] = None):
self.name = "Default"
self.masks_to_index: Dict[str, int] = masks or {
DEFAULT_PAD_TOKEN: 0,
DEFAULT_MASK_TOKEN: 1,
DEFAULT_WB_TOKEN: 2
}
        self.index_to_mask: Dict[int, str] = {
            index: mask for mask, index in self.masks_to_index.items()
        }
self.index_to_masks_name: Dict[int, str] = {
0: "PAD",
1: "W",
2: "WB"
}
self.masks_name_to_index: Dict[str, int] = {
"PAD": 0,
"W": 1,
"WB": 2
}
self.pad_token = DEFAULT_PAD_TOKEN
self._pad_token_index = self.masks_to_index[self.pad_token]
self._space = re.compile(r"\s")

self._check()

def _check(self):
for char in self.masks_to_index:
if char != self.pad_token: # We do not limit <PAD> to a single char because it's not dumped in a string
if len(char) != 1:
raise MaskValueException(
f"Mask characters cannot be longer than one char "
f"(Found: `{char}` "
f"for {self.index_to_masks_name[self.masks_to_index[char]]})")

@property
def pad_token_index(self) -> int:
return self._pad_token_index

@property
def classes_count(self):
return len(self.masks_to_index)

def generate_mask(
self,
string: str,
tokens_ratio: Optional[Dict[str, float]] = None
) -> Tuple[str, str]:
""" From a well-formed ground truth input, generates a fake error-containing string
:param string: Input string
:param tokens_ratio: Dict of TokenName
:return:
>>> (SimpleSpaceMode()).generate_mask("j'ai un cheval")
('xxx|x|xxxxx|', "j'aiuncheval")
"""
split = self._space.split(string)
masks = DEFAULT_WB_TOKEN.join([DEFAULT_MASK_TOKEN * (len(tok)-1) for tok in split]) + DEFAULT_WB_TOKEN
model_input = "".join(split)
        assert len(masks) == len(model_input), f"Mask and input lengths should be equal: `{masks}` vs `{model_input}`"
return model_input, masks

def encode_mask(self, masked_string: Sequence[str]) -> List[int]:
""" Encodes into a list of index a string
:param masked_string: String masked by the current Mode
:return: Pre-tensor input
>>> (SimpleSpaceMode()).encode_mask("xxx|x|xxxxx|")
[1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2]
"""
return [self.masks_to_index[char] for char in masked_string]

def apply_mask_to_string(self, input_string: str, masked_string: List[int]) -> str:
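        """ Apply a predicted mask (a list of class indices) to an unspaced string.

        Illustrative example; note the trailing space yielded after the final
        word-boundary token:
        >>> (SimpleSpaceMode()).apply_mask_to_string("j'aiuncheval", [1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2])
        "j'ai un cheval "
        """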
def apply():
for char, mask in zip(input_string, masked_string):
if mask == self.pad_token_index:
break
if self.index_to_masks_name[mask] == "WB":
yield char + " "
else:
yield char
return "".join(apply())

def prepare_input(self, string: str) -> str:
return self._space.sub("", string)

def _clean_matrix(self, confusion_matrix, pad_token_index):
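        # Drop the PAD row and column so padding never contributes to the metric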
indexes = torch.tensor([
i
for i in range(self.classes_count)
if i != pad_token_index
]).type_as(confusion_matrix)

return confusion_matrix[indexes][:, indexes]

def computer_wer(self, confusion_matrix):
clean_matrix = self._clean_matrix(confusion_matrix, self.pad_token_index)

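        # Row sum + column sum - the doubly counted diagonal cell: the number of
        # samples where the word-boundary class appears in the ground truth, in
        # the prediction, or in both.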
space_token_index = self.masks_to_index[DEFAULT_WB_TOKEN]
if space_token_index > self.pad_token_index:
space_token_index -= 1
nb_space_gt = (
clean_matrix[space_token_index].sum() +
clean_matrix[:, space_token_index].sum() -
clean_matrix[space_token_index, space_token_index]
)

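        # Everything off the diagonal of the cleaned matrix is a labelling error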
nb_missed_space = clean_matrix.sum() - torch.diagonal(clean_matrix, 0).sum()
return nb_missed_space / nb_space_gt


class AdvancedSpaceMode(SimpleSpaceMode):
def __init__(self, masks: Dict[str, int] = None):
self.name = "Default"
self.masks_to_index: Dict[str, int] = masks or {
DEFAULT_PAD_TOKEN: 0,
DEFAULT_MASK_TOKEN: 1,
DEFAULT_WB_TOKEN: 2,
DEFAULT_REMOVE_TOKEN: 3,
DEFAULT_ORIGINAL_TOKEN: 4
}
        self.index_to_mask: Dict[int, str] = {
            index: mask for mask, index in self.masks_to_index.items()
        }
self.index_to_masks_name: Dict[int, str] = {
0: "PAD",
1: "W",
2: "WB",
3: "REMOVE",
4: "ORIGINAL"
}
self.masks_name_to_index: Dict[str, int] = {
"PAD": 0,
"W": 1,
"WB": 2,
"REMOVE": 3,
"ORIGINAL": 4
}
self.pad_token = DEFAULT_PAD_TOKEN
self._pad_token_index = self.masks_to_index[self.pad_token]
self._space = re.compile(r"\s+")

self._check()

def _check(self):
for char in self.masks_to_index:
if char != self.pad_token: # We do not limit <PAD> to a single char because it's not dumped in a string
if len(char) != 1:
raise MaskValueException(
f"Mask characters cannot be longer than one char "
f"(Found: `{char}` "
f"for {self.index_to_masks_name[self.masks_to_index[char]]})")

@property
def pad_token_index(self) -> int:
return self._pad_token_index

@property
def classes_count(self):
return len(self.masks_to_index)

def generate_mask(
self,
string: str,
tokens_ratio: Optional[Dict[str, float]] = None
) -> Tuple[str, str]:
""" From a well-formed ground truth input, generates a fake error-containing string
:param string: Input string
        :param tokens_ratio: Dict mapping `keep-space` / `fake-space` to probabilities in [0, 1]
:return:
>>> (AdvancedSpaceMode()).generate_mask("j'ai un cheval", tokens_ratio={"fake-space": 1, 'keep-space': 0})
("j ' a iu nc h e v a l", '-⌫-⌫-⌫|-⌫|-⌫-⌫-⌫-⌫-⌫|')
>>> (AdvancedSpaceMode()).generate_mask("j'ai un cheval", tokens_ratio={"fake-space": 0, 'keep-space': 1})
("j'ai un cheval", '---|-|-----|')
"""

        tokens_ratio = tokens_ratio or {}
        model_input: List[str] = []
        masks: List[str] = []
        string = string.strip()
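        # One character of lookahead: padding string[1:] with a sentinel space
        # gives the final character a next_char to inspect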
for char, next_char in zip(string, string[1:]+" "):
if char.strip(): # It's not a space
model_input.append(char)
masks.append(DEFAULT_MASK_TOKEN)
if next_char.strip() and random.random() < tokens_ratio.get("fake-space", 0):
model_input.append(" ")
masks.append(DEFAULT_REMOVE_TOKEN)
else:
if len(masks):
masks[-1] = DEFAULT_WB_TOKEN
if random.random() < tokens_ratio.get("keep-space", 0):
model_input.append(" ") # Space are normalized
masks.append(DEFAULT_ORIGINAL_TOKEN)
masks[-1] = DEFAULT_WB_TOKEN
assert len(masks) == len(model_input), f"Length of input and mask should be equal `{masks}` + `{model_input}`"
return "".join(model_input), "".join(masks)

def encode_mask(self, masked_string: Sequence[str]) -> List[int]:
""" Encodes into a list of index a string
:param masked_string: String masked by the current Mode
:return: Pre-tensor input
>>> (AdvancedSpaceMode()).encode_mask("-⌫--|-|-|")
[1, 3, 1, 1, 2, 1, 2, 4, 1, 2]
"""
return [self.masks_to_index[char] for char in masked_string]

def apply_mask_to_string(self, input_string: str, masked_string: List[int]) -> str:
""" Apply a prediction to a string
:param input_string:
:param masked_string:
:return:
>>> (AdvancedSpaceMode()).apply_mask_to_string("J 'aiun nu", [1, 3, 1, 1, 2, 1, 2, 4, 1, 2])
"J'ai un nu"
"""
def apply():
for char, mask in zip(input_string, masked_string):
if mask == self.pad_token_index:
break
if self.index_to_masks_name[mask] == "WB":
yield char + " "
elif self.index_to_masks_name[mask] == "REMOVE":
continue
else:
yield char
return self._space.sub(" ", "".join(apply())).strip()

def prepare_input(self, string: str) -> str:
return self._space.sub(" ", string).strip()

def computer_wer(self, confusion_matrix):
clean_matrix = self._clean_matrix(confusion_matrix, self.pad_token_index)

        space_tokens = [
            space_index if space_index < self.pad_token_index else space_index - 1
            for space_index in [
                self.masks_to_index[DEFAULT_WB_TOKEN],
                self.masks_to_index[DEFAULT_REMOVE_TOKEN],
                self.masks_to_index[DEFAULT_ORIGINAL_TOKEN]
            ]
        ]

        nb_space_gt = (
            clean_matrix[space_tokens].sum() +
            clean_matrix[:, space_tokens].sum() -
            clean_matrix[space_tokens][:, space_tokens].sum()
        )

nb_missed_space = clean_matrix.sum() - torch.diagonal(clean_matrix, 0).sum()
return nb_missed_space / nb_space_gt
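
A minimal round-trip sketch for the mode API above (illustrative, assuming
`boudams.modes` is importable; values follow from the doctests):

    from boudams.modes import SimpleSpaceMode

    mode = SimpleSpaceMode()
    sample, mask = mode.generate_mask("j'ai un cheval")
    # sample == "j'aiuncheval", mask == '---|-|-----|'
    encoded = mode.encode_mask(mask)
    # encoded == [1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2]
    restored = mode.apply_mask_to_string(sample, encoded)
    # "j'ai un cheval " -- note the trailing space after the final WB token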
13 changes: 9 additions & 4 deletions boudams/utils.py
@@ -1,22 +1,27 @@
import math
import gzip
import time
import uuid
from contextlib import contextmanager
import os
import shutil
import unidecode
import warnings
from typing import Dict, Any
from urllib.parse import parse_qs

import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import confusion_matrix
from sklearn.utils.multiclass import unique_labels

Cache = {}


def parse_params(string: str) -> Dict[str, Any]:
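    """ Parse a query-string of `key=value` pairs into a dict of evaluated values.

    Illustrative doctest:
    >>> parse_params("keep-space=.3&fake-space=.01")
    {'keep-space': 0.3, 'fake-space': 0.01}
    """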
return {
        key: eval(value[0])  # eval() is unsafe in general; tolerated here since input is a local CLI parameter
for key, value in parse_qs(string).items()
}


def improvement_on_min_or_max(metric: str) -> str:
if "loss" in metric or "wer" in metric:
return "min"
