
Commit

[WIP] Moving the mask mechanism to a new Mode class
PonteIneptique committed Apr 11, 2022

1 parent 149c9a1 commit 240158c
Showing 5 changed files with 222 additions and 152 deletions.
94 changes: 46 additions & 48 deletions boudams/data_generation/plaintext.py
@@ -52,57 +52,55 @@ def convert(
    os.makedirs(os.path.dirname(output_fp), exist_ok=True)
    key = "form"  # For dict reader

-    with open(input_fp) as input_fio:
-        with open(output_fp, "w") as output_fio:
+    with open(input_fp) as input_fio, open(output_fp, "w") as output_fio:
        sequence = []
        next_sequence = random.randint(min_words, max_words)

        content = _apos.sub(" ", input_fio.read())

        for word in _splitter.split(content):
            word = _space.sub("", word)
            if not word:
                continue
            sequence.append(word)

            char_length = len("".join(sequence))

            # If the char length is greater than our maximum
            # we create a sentence now by saying next sequence is now.
            if char_length > max_char_length * 0.9:
                next_sequence = len(sequence)

            # If we reached the random length for the word count
            if len(sequence) == next_sequence:
                # If however we have a string that is too small (like less then 7 chars), we'll pack it
                # up next time
                if char_length < min_char_length:
                    next_sequence += 1
                    continue

                # If the sentence length is smaller than MAX_CHAR_LENGTH, we randomly add noise
                if random.random() < noise_char_random:
                    index = random.randint(1, len(sequence))
                    sequence = sequence[:index] + \
                        [noise_char] * random.randint(1, max_noise_char) + \
                        sequence[index:]

                write_sentence(output_fio, sequence)

                # We randomly keep the last word for next sentence
                if random.random() < random_keep:
                    kept = random.randint(1, max_kept)
                    sequence = sequence[-kept:] + []
                else:
                    sequence = []

                # We set-up the next sequence length
                next_sequence = random.randint(min_words, max_words) + len(sequence)

        # At the end of the loop, if sequence is not empty
        if sequence and len("".join(sequence)) > min_char_length:
            write_sentence(output_fio, sequence, max_chars=max_char_length)


if __name__ == "__main__":
4 changes: 2 additions & 2 deletions boudams/dataset.py
@@ -50,8 +50,8 @@ def _setup(self):
            x, y = self._l_e.readunit(line)
            self.encoded.append(
                GT_PAIR(
-                    *self._l_e.inp_to_numerical(x),
-                    *self._l_e.gt_to_numerical(y),
+                    *self._l_e.sent_to_numerical(x),
+                    *self._l_e.mask_to_numerical(y),
                    f"File:{file}#Line:{line_index}"
                )
            )
272 changes: 172 additions & 100 deletions boudams/encoder.py
@@ -1,3 +1,5 @@
+import re
+
import tabulate
import torch
import torch.cuda
@@ -13,64 +15,163 @@

DEFAULT_INIT_TOKEN = "<SOS>"
DEFAULT_EOS_TOKEN = "<EOS>"
-DEFAULT_PAD_TOKEN = "<PAD>"
+DEFAULT_PAD_TOKEN = ""
DEFAULT_UNK_TOKEN = "<UNK>"
DEFAULT_MASK_TOKEN = "x"
+DEFAULT_WB_TOKEN = "|"


+class SimpleSpaceMode:
+    class MaskValueException(Exception):
+        """ Exception raised when a token is longer than a character """
+
+    class MaskGenerationError(Exception):
+        """ Exception raised when a mask is not of the same size as the input transformed string """
+
+    def __init__(self, masks: Dict[str, int] = None):
+        self.name = "Default"
+        self.masks_to_index: Dict[str, int] = masks or {
+            DEFAULT_PAD_TOKEN: 0,
+            DEFAULT_MASK_TOKEN: 1,
+            DEFAULT_WB_TOKEN: 2
+        }
+        self.index_to_mask: Dict[str, int] = masks or {
+            0: DEFAULT_PAD_TOKEN,
+            1: DEFAULT_MASK_TOKEN,
+            2: DEFAULT_WB_TOKEN
+        }
+        self.index_to_masks_name: Dict[int, str] = {
+            0: "PAD",
+            1: "W",
+            2: "WB"
+        }
+        self.masks_name_to_index: Dict[str, int] = {
+            "PAD": 0,
+            "W": 1,
+            "WB": 2
+        }
+        self.pad_token = DEFAULT_PAD_TOKEN
+        self._pad_token_index = self.masks_to_index[self.pad_token]
+        self._space = re.compile(r"\s")
+
+        self._check()
+
+    def _check(self):
+        for char in self.masks_to_index:
+            if char != self.pad_token:  # We do not limit <PAD> to a single char because it's not dumped in a string
+                if len(char) != 1:
+                    raise SimpleSpaceMode.MaskValueException(f"Mask characters cannot be longer than one char "
+                                                             f"(Found: `{char}` "
+                                                             f"for {self.index_to_masks_name[self.masks_to_index[char]]})")
+
+    @property
+    def pad_token_index(self) -> int:
+        return self._pad_token_index
+
+    @property
+    def classes_count(self):
+        return len(self.masks_to_index)
+
+    def generate_mask(
+            self,
+            string: str,
+            tokens_ratio: Optional[Dict[str, float]] = None
+    ) -> Tuple[str, str]:
+        """ From a well-formed ground truth input, generates a fake error-containing string
+        :param string: Input string
+        :param tokens_ratio: Dict of TokenName
+        :return:
+        >>> (SimpleSpaceMode()).generate_mask("j'ai un cheval")
+        ('xxx|x|xxxxx|', "j'aiuncheval")
+        """
+        split = self._space.split(string)
+        masks = DEFAULT_WB_TOKEN.join([DEFAULT_MASK_TOKEN * (len(tok)-1) for tok in split]) + DEFAULT_WB_TOKEN
+        model_input = "".join(split)
+        assert len(masks) == len(model_input), f"Length of input and mask should be equal {masks+model_input}"
+        return masks, model_input
+
+    def encode_mask(self, masked_string: Sequence[str]) -> List[int]:
+        """ Encodes into a list of index a string
+        :param masked_string: String masked by the current Mode
+        :return: Pre-tensor input
+        >>> (SimpleSpaceMode()).encode_mask("xxx|x|xxxxx|")
+        [1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2]
+        """
+        return [self.masks_to_index[char] for char in masked_string]
+
+    def apply_mask_to_string(self, input_string: str, masked_string: List[int]) -> str:
+        def apply():
+            for char, mask in zip(input_string, masked_string):
+                if mask == self.pad_token_index:
+                    break
+                if self.index_to_masks_name[mask] == "WB":
+                    yield char + " "
+                else:
+                    yield char
+        return "".join(apply())
+
+    def prepare_input(self, string: str) -> str:
+        return self._space.sub("", string)
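
For orientation, a doctest-style sketch (illustrative, not part of the diff, derived from the class added above) of how the new Mode round-trips a sentence:

>>> mode = SimpleSpaceMode()
>>> mode.generate_mask("j'ai un cheval")
('xxx|x|xxxxx|', "j'aiuncheval")
>>> encoded = mode.encode_mask("xxx|x|xxxxx|")
>>> encoded
[1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2]
>>> mode.apply_mask_to_string("j'aiuncheval", encoded)
"j'ai un cheval "

Note the trailing space: apply_mask_to_string yields a space after every character labelled WB, including the final one.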


class LabelEncoder:
+    Modes = {
+        "SimpleSpace": SimpleSpaceMode
+    }
+
    # For test purposes
    EXAMPLE_LINE = "\t".join(['a b c D', 'x x x x'])

    def __init__(
            self,
-            pad_token=DEFAULT_PAD_TOKEN,
-            unk_token=DEFAULT_UNK_TOKEN,
-            mask_token=DEFAULT_MASK_TOKEN,
+            mode: str = "SimpleSpace",
            maximum_length: int = None,
            lower: bool = True,
-            remove_diacriticals: bool = True
+            remove_diacriticals: bool = True,
+            unk_token: str = DEFAULT_UNK_TOKEN
    ):
+        self._mode: SimpleSpaceMode = self.Modes[mode]()

-        self.pad_token: str = pad_token
-        self.unk_token: str = unk_token
-        self.mask_token: str = mask_token
-        self.space_token: str = " "
+        self.pad_token: str = self._mode.pad_token
+        self.pad_token_index: int = 0  # Only for CHARS

-        self.pad_token_index: int = 2
-        self.space_token_index: int = 1
-        self.mask_token_index: int = 0
-        self.unk_token_index: int = 0  # Put here because it isn't used in masked
+        self.unk_token: str = unk_token
+        self.unk_token_index: int = 1

        self.max_len: Optional[int] = maximum_length
        self.lower = lower
        self.remove_diacriticals = remove_diacriticals

        self.itos: Dict[int, str] = {
            self.pad_token_index: self.pad_token,
-            self.unk_token_index: self.unk_token,
-            self.space_token_index: self.space_token
+            self.unk_token_index: self.unk_token
        } # Id to string for reversal

        self.stoi: Dict[str, int] = {
-            self.pad_token: self.pad_token_index,
-            self.unk_token: self.unk_token_index,
-            self.space_token: self.space_token_index
+            self.pad_token: 0,
+            self.unk_token: self.unk_token_index
        } # String to ID

        # Mask dictionaries
-        self.itom: Dict[int, str] = {
-            self.pad_token_index: self.pad_token,
-            self.mask_token_index: self.mask_token,
-            self.space_token_index: self.space_token
-        }
-        self.mtoi: Dict[str, int] = {
-            self.pad_token: self.pad_token_index,
-            self.mask_token: self.mask_token_index,
-            self.space_token: self.space_token_index
-        }
+        self.itom: Dict[int, str] = dict([
+            (tok_id, mask) for (mask, tok_id) in self._mode.masks_to_index.items()
+        ])
+        self.mtoi: Dict[int, str] = dict([
+            (tok_id, mask) for (mask, tok_id) in self._mode.masks_to_index.items()
+        ])
+
+    @property
+    def mask_count(self):
+        return len(self.mtoi)
+
+    @property
+    def mode(self):
+        return self._mode

    def __len__(self):
        return len(self.stoi)
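
A quick illustrative sketch (not part of the diff) of how the mask dictionaries are now derived from the Mode instead of being hard-coded:

>>> encoder = LabelEncoder()
>>> encoder.itom
{0: '', 1: 'x', 2: '|'}
>>> encoder.mask_count
3
>>> encoder.mode.classes_count
3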

@@ -107,14 +208,17 @@ def build(self, *paths, debug=False):
        logging.debug(str(counter))
        logging.debug(self.stoi)

-    def readunit(self, line) -> Tuple[Tuple[str, ...], Tuple[str, ...]]:
+    def readunit(self, line) -> Tuple[Tuple[str, ...], str]:
        """ Read a single line
        :param line:
        :return:
+        >>> (LabelEncoder(lower=True)).readunit(LabelEncoder.EXAMPLE_LINE)
+        (('a', ' ', 'b', ' ', 'c', ' ', 'd'), 'x x x x')
        """
        inp, out = line.strip().split("\t")
-        return tuple(self.prepare(inp)), tuple(self.prepare(out))
+        return tuple(self.prepare(inp)), out

    def prepare(self, inp: str):
        if self.remove_diacriticals:
@@ -165,21 +269,15 @@ def pad_and_tensorize(

        return torch.tensor(tensor), torch.tensor(lengths), order

-    def gt_to_numerical(self, sentence: Sequence[str]) -> Tuple[List[int], int]:
+    def mask_to_numerical(self, sentence: Sequence[str]) -> Tuple[List[int], int]:
        """ Transform GT to numerical
        :param sentence: Sequence of characters (can be a straight string) with spaces
        :return: List of mask indexes
        """
-        numericals = [
-            self.mask_token_index if ngram[1] != " " else self.space_token_index
-            for ngram in zip(*[sentence[i:] for i in range(2)])
-            if ngram[0] != " "
-        ] + [self.space_token_index]
-
-        return numericals, len(sentence) - sentence.count(" ")
+        return self.mode.encode_mask(sentence), len(sentence)

-    def inp_to_numerical(self, sentence: Sequence[str]) -> Tuple[List[int], int]:
+    def sent_to_numerical(self, sentence: Sequence[str]) -> Tuple[List[int], int]:
        """ Transform input sentence to numerical
        :param sentence: Sequence of characters (can be a straight string) without spaces
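
Illustrative sketch of the renamed helper under the default mode: mask_to_numerical now delegates to Mode.encode_mask and returns the mask indexes together with the mask length:

>>> LabelEncoder().mask_to_numerical("xxx|x|")
([1, 1, 1, 2, 1, 2], 6)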
@@ -190,66 +288,44 @@ def inp_to_numerical(self, sentence: Sequence[str]) -> Tuple[List[int], int]:
            len(sentence)
        )

+    def numerical_to_sent(self, encoded_sentence: List[int]) -> str:
+        """ Transform a list of integers to a string
+        :param encoded_sentence: List of index
+        :return: Characters
+        """
+        return "".join([
+            self.itos[char_idx]
+            for char_idx in encoded_sentence
+            if char_idx != self.pad_token
+        ])
+
    def reverse_batch(
-            self,
-            batch: Union[list, torch.Tensor],
-            ignore: Optional[Tuple[str, ...]] = None,
-            masked: Optional[Union[list, torch.Tensor]] = None
+        self,
+        batch: Union[list, torch.Tensor],
+        mask_batch: Optional[Union[list, torch.Tensor]] = None
    ):
-        ignore = ignore or ()
+        """ Produce result strings for a batch
+        :param batch: Input batch
+        :param mask_batch: Output batch
+        :return: List of strings with applied masks
+        """
        # If dimension is [sentence_len, batch_size]
        if not isinstance(batch, list):
            with torch.cuda.device_of(batch):
                batch = batch.tolist()

-        if masked is not None:
-            if not isinstance(masked, list):
-                with torch.cuda.device_of(masked):
-                    masked = masked.tolist()
-
-            if not isinstance(masked[0][0], str):
-                masked = [
-                    [
-                        self.itos[masked_token]
-                        for masked_token in sentence
-                    ]
-                    for sentence in masked
-                ]
-            else:
-                masked = [
-                    list(sentence)
-                    for sentence in masked
-                ]
-
-            return [
-                [
-                    tok
-                    for masked_token, mask_token in zip(masked_sentence, space_mask)
-                    if space_mask not in ignore and masked_token not in ignore
-                    for tok in [masked_token] + ([" "] if mask_token == self.space_token_index else [])
-                ]
-                for masked_sentence, space_mask in zip(masked, batch)
-            ]
-
-        if ignore is True:
-            batch = [
-                [
-                    self.itos[ind]
-                    for ind in ex
-                    if ind not in ignore
-                ]
-                for ex in batch
-            ]
-        else:
-            batch = [
-                [
-                    self.itos[ind]
-                    for ind in ex
-                ]
-                for ex in batch
-            ]  # denumericalize
-        return batch
+        if not isinstance(mask_batch, list):
+            with torch.cuda.device_of(mask_batch):
+                mask_batch = mask_batch.tolist()
+
+        return [
+            self.mode.apply_mask_to_string(
+                input_string=self.numerical_to_sent(batch_seq),
+                masked_string=masked_seq
+            )
+            for batch_seq, masked_seq in zip(batch, mask_batch)
+        ]

    def transcribe_batch(self, batch: List[List[str]]):
        for sentence in batch:
@@ -289,14 +365,10 @@ def dump(self) -> str:
        })

    def format_confusion_matrix(self, confusion: List[List[int]]):
-        beautiful = {
-            self.mask_token: "Char",
-            self.space_token: "Space char"
-        }
        header = [
            "",
            *[
-                beautiful.get(self.itom[index], self.itom[index])
+                self.mode.index_to_masks_name.get(index, index)
                for index in sorted(list(self.itom.keys()))
            ]
        ]
@@ -368,7 +440,7 @@ def format_confusion_matrix(self, confusion: List[List[int]]):

    # Somehow, although stuff IS padded and each sequence should have the same size, this is not the case...
    # I definitely need to spleep on it
-    reversed_data = list(label_encoder.reverse_batch(y, masked=x))
+    reversed_data = list(label_encoder.reverse_batch(y, mask_batch=x))

    assert [
        l
2 changes: 1 addition & 1 deletion boudams/model/linear.py
@@ -197,6 +197,6 @@ def predict(self, src, src_len, label_encoder: "LabelEncoder",
        logits = torch.argmax(out, -1)
        return label_encoder.reverse_batch(
            logits,
-            masked=override_src or src,
+            mask_batch=override_src or src,
            ignore=(self.pad_idx, )
        )
2 changes: 1 addition & 1 deletion boudams/tagger.py
@@ -399,7 +399,7 @@ def annotate(self, texts: List[str], batch_size=32, device: str = "cpu"):
        for n in range(0, len(texts), batch_size):
            batch = texts[n:n+batch_size]
            xs = [
-                self.vocabulary.inp_to_numerical(self.vocabulary.prepare(s))
+                self.vocabulary.sent_to_numerical(self.vocabulary.prepare(s))
                for s in batch
            ]
            logging.info("Dealing with batch %s " % (int(n/batch_size)+1))
