diff --git a/boudams/cli.py b/boudams/cli.py
index b46b51b..46c7811 100644
--- a/boudams/cli.py
+++ b/boudams/cli.py
@@ -413,14 +413,19 @@ def tag(model, filename, device="cpu", batch_size=64):
     print("Model loaded.")
     for file in tqdm.tqdm(filename):
         out_name = file.name.replace(".txt", ".tokenized.txt")
-        with open(file) as f, open(out_name, "w") as out_io:
-            content = f.read()  # Could definitely be done a better way...
-            if model.vocabulary.mode.name == "simple-space":
-                content = re.sub(r"\s+", "", content)
-            elif model.vocabulary.mode.NormalizeSpace:
-                content = re.sub(r"\s+", " ", content)
+        content = file.read()  # Could definitely be done a better way...
+        if model.vocabulary.mode.name == "simple-space":
+            content = re.sub(r"\s+", "", content)
+        elif model.vocabulary.mode.NormalizeSpace:
+            content = re.sub(r"\s+", " ", content)
+        file.close()
+        with open(out_name, "w") as out_io:
             out = ''
-            for tokenized_string in model.annotate_text(content, batch_size=batch_size, device=device):
+            for tokenized_string in model.annotate_text(
+                content,
+                batch_size=batch_size,
+                device=device
+            ):
                 out = out + tokenized_string + "\n"
             out_io.write(out)
         print("--- File " + file.name + " has been tokenized")
@@ -438,7 +443,7 @@ def tag_check(config_model, content, device="cpu", batch_size=64):
     boudams = BoudamsTagger.load(model, device=device)
     boudams.eval()
     click.echo(f"\t[X] Model loaded")
-    click.echo("\n".join(boudams.annotate_text(content, splitter="([\.!\?]+)", batch_size=batch_size, device=device)))
+    click.echo("\n".join(boudams.annotate_text(content, splitter=r"([\.!\?]+)", batch_size=batch_size, device=device)))
 
 
 @cli.command("graph")
diff --git a/boudams/tagger.py b/boudams/tagger.py
index 7f30845..51f2aad 100644
--- a/boudams/tagger.py
+++ b/boudams/tagger.py
@@ -449,27 +449,47 @@ def annotate(self, texts: List[str], batch_size=32, device: str = "cpu"):
         for index in range(len(translations)):
             yield "".join(translations[order.index(index)])
 
-    def annotate_text(self, string, splitter=r"([⁊\W\d]+)", batch_size=32, device: str = "cpu"):
-        splitter = re.compile(splitter)
-        splits = splitter.split(string)
-
-        tempList = splits + [""] * 2
-        strings = ["".join(tempList[n:n + 2]) for n in range(0, len(splits), 2)]
-        strings = list(filter(lambda x: x.strip(), strings))
+    @staticmethod
+    def _apply_max_size(tokens: str, size: int):
+        # Use finditer when applied to things with spaces?
+        # [(m.start(0), m.end(0)) for m in re.finditer(pattern, string)] ?
+        current = []
+        for tok in re.split(r"(\s+)", tokens):
+            if not tok:
+                continue
+            current.append(tok)
+            string_size = len("".join(current))
+            if string_size > size:
+                yield "".join(current[:-1])
+                current = current[-1:]
+            elif string_size == size:
+                yield "".join(current)
+                current = []
+        if current:
+            yield "".join(current)
+
+    def annotate_text(self, single_sentence, splitter: Optional[str] = None, batch_size=32, device: str = "cpu", rolling=True):
+        if splitter is None:
+            # ToDo: Mode-specific splitter?
+            splitter = r"([\.!\?]+)"
+        splitter = re.compile(splitter)
+        sentences = [tok for tok in splitter.split(single_sentence) if tok.strip()]
+
+        if self._maximum_sentence_size:
+            # This is currently quite limiting.
+            # If the end token ends with a W and not a WB, there is no way to correct it.
+            # We'd need a rolling system: cut in the middle of the maximum sentence size?
             treated = []
             max_size = self._maximum_sentence_size
-        for string in strings:
-            if len(string) > max_size:
-                treated.extend([
-                    "".join(string[n:n + max_size])
-                    for n in range(0, len(string), max_size)
-                ])
+            for single_sentence in sentences:
+                if len(single_sentence) > max_size:
+                    treated.extend(self._apply_max_size(single_sentence, max_size))
                 else:
-                treated.append(string)
-        strings = treated
-        yield from self.annotate(strings, batch_size=batch_size, device=device)
+                    treated.append(single_sentence)
+            sentences = treated
+
+        yield from self.annotate(sentences, batch_size=batch_size, device=device)
 
     @classmethod
     def load(cls, fpath="./model.boudams_model", device=None):
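
Reviewer note on the new default splitter: when no `splitter` is passed, `annotate_text` now falls back to `([\.!\?]+)`. Because the pattern contains a capturing group, `re.split` keeps the delimiters as list items of their own, and the comprehension only filters out whitespace-only entries; unlike the old `tempList` pairing, the punctuation therefore reaches `annotate` as separate segments. A minimal stdlib-only sketch (the sample sentence is invented):

    import re

    splitter = re.compile(r"([\.!\?]+)")
    # The capturing group makes re.split keep the delimiters
    tokens = splitter.split("Lors se partent. Et chevauchent ensemble !")
    # -> ['Lors se partent', '.', ' Et chevauchent ensemble ', '!', '']
    sentences = [tok for tok in tokens if tok.strip()]
    # -> ['Lors se partent', '.', ' Et chevauchent ensemble ', '!']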
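
Reviewer note on `_apply_max_size`: it greedily packs whitespace-separated tokens into chunks of at most `size` characters, so over-long sentences are now cut at token boundaries rather than every `max_size` characters as before. A standalone copy of the generator for tracing (renamed `apply_max_size` here so it runs without importing boudams; the sample string is invented):

    import re

    def apply_max_size(tokens: str, size: int):
        # Same body as BoudamsTagger._apply_max_size in this diff
        current = []
        for tok in re.split(r"(\s+)", tokens):
            if not tok:
                continue
            current.append(tok)
            string_size = len("".join(current))
            if string_size > size:
                # Over budget: flush everything but the last token
                yield "".join(current[:-1])
                current = current[-1:]
            elif string_size == size:
                # Exactly on budget: flush the whole buffer
                yield "".join(current)
                current = []
        if current:
            yield "".join(current)

    print(list(apply_max_size("la dame au chevalier dist", 10)))
    # -> ['la dame au', ' chevalier', ' dist']

One behaviour worth flagging: a single token longer than `size` is still yielded whole (preceded by an empty chunk when it opens the string), so chunks are not strictly bounded by `size`.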