Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allennlp Predictor class that mimics gec_model #6

Merged
merged 25 commits into from
Oct 24, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
9ba0aeb
Minor changes to make these tests pass if a cuda device is available.
ksteimel Sep 27, 2022
34386bd
Adding registered names for use by predictor
ksteimel Sep 27, 2022
36ab359
Adding expected test output.
ksteimel Sep 27, 2022
9bd6a7b
Added WIP docstring to GecBERTModel
ksteimel Sep 27, 2022
6f3f7ae
WIP Gec Predictor.
ksteimel Sep 27, 2022
9dd6bf1
words metadata is getting filled if unspecified when text_to_instance…
ksteimel Sep 28, 2022
64cbc8d
Using JustSpacesWordSplitter so that tokenization matches that used b…
ksteimel Sep 28, 2022
10e8069
Decode now adds the corrected sentence to the output dict.
ksteimel Sep 28, 2022
f6a9185
Updating gitignore to prevent adding .th files
ksteimel Sep 28, 2022
657f72b
Adding directory fixture as analogue to model archive
ksteimel Sep 28, 2022
e3cca59
Fixing errors in modeling code now that model.decode adds the origina…
ksteimel Sep 28, 2022
1c0025c
Adding conditional so that no correction is performed in decode if no…
ksteimel Sep 28, 2022
0f204ea
Appending start token when creating instances from json or string.
ksteimel Sep 28, 2022
67ef511
Start token is expected in ouptut.
ksteimel Sep 28, 2022
53a94dd
Drop START_TOKEN from output_dict["words"]. This interferes with the …
ksteimel Sep 28, 2022
86212b7
The outputs now no longer have $START_TOKEN in the corrected sentence…
ksteimel Sep 28, 2022
20a0692
Handling multiple iterations of correction in predictor now.
ksteimel Sep 28, 2022
0b88028
Changed location of weights file so it can be used by gec_predictor a…
ksteimel Sep 28, 2022
0b6b508
setup is now downloading weights file if it does not already exist.
ksteimel Sep 28, 2022
6f8f23c
Apply suggestions from code review
ksteimel Oct 12, 2022
7427f93
Removing unused imports, adding docstrings.
Oct 12, 2022
73aef1b
Removing unused predictions to labeled_instances method.
Oct 12, 2022
944993c
Updated docstring for decode()
Oct 12, 2022
709ba28
Removed unused imports.
Oct 12, 2022
b04376c
Adding back import of gec_predictor that shouldn't have been removed
ksteimel Oct 22, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,7 @@ dmypy.json
# PyCharm
.idea

*.sh
*.sh

# pytorch weights files
*.th
2 changes: 1 addition & 1 deletion gector/bert_token_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def forward(
return util.uncombine_initial_dims(selected_embeddings, offsets.size())


# @TokenEmbedder.register("bert-pretrained")
@TokenEmbedder.register("gec-bert-pretrained")
class PretrainedBertEmbedder(BertEmbedder):

"""
Expand Down
3 changes: 3 additions & 0 deletions gector/datareader.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def text_to_instance(self, tokens: List[Token], tags: List[str] = None,
fields: Dict[str, Field] = {}
sequence = TextField(tokens, self._token_indexers)
fields["tokens"] = sequence
# If words has not been explicitly passed in, create it from tokens.
if words is None:
words = [token.text for token in tokens]
fields["metadata"] = MetadataField({"words": words})
if tags is not None:
labels, detect_tags, complex_flag_dict = self.extract_tags(tags)
Expand Down
22 changes: 22 additions & 0 deletions gector/gec_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,28 @@ def __init__(self, vocab_path=None, model_paths=None,
del_confidence=0,
resolve_cycles=False,
):
"""
Class used to enable prediction from GECToR model.

Parameters
----------
vocab_path
model_paths: List[Path]
weigths
max_len
min_len
lowercase_tokens
log
iterations
model_name
special_tokens_fix
is_ensemble
min_error_probability
confidence
del_confidence
resolve_cycles: bool
This parameter is unused.
"""
self.model_weights = list(map(float, weigths)) if weigths else [1] * len(model_paths)
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.max_len = max_len
Expand Down
182 changes: 182 additions & 0 deletions gector/gec_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
from typing import Dict, List

import numpy
from allennlp.predictors import Predictor
from allennlp.models import Model
from allennlp.common.util import sanitize
from overrides import overrides
from allennlp.common.util import JsonDict
from allennlp.data import DatasetReader, Instance, Token
from allennlp.data.tokenizers.word_splitter import JustSpacesWordSplitter
from allennlp.models import Model
from utils.helpers import START_TOKEN


@Predictor.register("gec-predictor")
class GecPredictor(Predictor):
"""
A Predictor for generating predictions from GECToR.

Note that currently, this is unable to handle ensemble predictions.
ksteimel marked this conversation as resolved.
Show resolved Hide resolved
"""

def __init__(self,
model: Model,
dataset_reader: DatasetReader,
iterations: int = 3) -> None:
"""
Parameters
---------
model: Model
An instantiated `Seq2Labels` model for performing grammatical error correction.
dataset_reader: DatasetReader
An instantiated dataset reader, typically `Seq2LabelsDatasetReader`.
iterations: int
This represents the number of times grammatical error correction is applied to the input.
"""
super().__init__(model, dataset_reader)
self._tokenizer = JustSpacesWordSplitter()
self._iterations = iterations

def predict(self, sentence: str) -> JsonDict:
"""
Generate error correction predictions for a single input string.

Parameters
----------
sentence: str
The input text to perform error correction on.

Returns
-------
JsonDict
A dictionary containing the following keys:
- logits_labels
- logits_d_tags
- class_probabilities_labels
- class_probabilities_d_tags
- max_error_probability
- words
- labels
- d_tags
- corrected_words
For an explanation of each of these see `Seq2Labels.decode()`.
"""
return self.predict_json({"sentence": sentence})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be a good idea to have proper numpy style docstrings for all methods that we add at least.


def predict_batch(self, sentences: List[str]) -> List[JsonDict]:
"""
Generate predictions for a sequence of input strings.

Parameters
----------
sentences: List[str]
A list of strings to correct.

Returns
-------
List[JsonDict]
A list of dictionaries, each containing the following keys:
- logits_labels
- logits_d_tags
- class_probabilities_labels
- class_probabilities_d_tags
- max_error_probability
- words
- labels
- d_tags
- corrected_words
For an explanation of each of these see `Seq2Labels.decode()`.
"""
return self.predict_batch_json([{"sentence": sentence} for sentence in sentences])

@overrides
def predict_instance(self, instance: Instance) -> JsonDict:
"""
This special predict_instance method allows for applying the correction model multiple times.

Parameters
---------

Returns
-------
JsonDict
A dictionary containing the following keys:
- logits_labels
- logits_d_tags
- class_probabilities_labels
- class_probabilities_d_tags
- max_error_probability
- words
- labels
- d_tags
- corrected_words
For an explanation of each of these see `Seq2Labels.decode()`.
"""
for i in range(self._iterations):
output = self._model.forward_on_instance(instance)
# integrate predictions back into instance for next iteration
tokens = [Token(word) for word in output["corrected_words"]]
instance = self._dataset_reader.text_to_instance(tokens)
return sanitize(output)

@overrides
def predict_batch_instance(self, instances: List[Instance]) -> List[JsonDict]:
"""
This special predict_batch_instance method allows for applying the correction model multiple times.

Parameters
----------

Returns
-------
List[JsonDict]
A list of dictionaries, each containing the following keys:
- logits_labels
- logits_d_tags
- class_probabilities_labels
- class_probabilities_d_tags
- max_error_probability
- words
- labels
- d_tags
- corrected_words
For an explanation of each of these see `Seq2Labels.decode()`.
"""
for i in range(self._iterations):
outputs = self._model.forward_on_instances(instances)
corrected_instances = []
for output in outputs:
tokens = [Token(word) for word in output["corrected_words"]]
instance = self._dataset_reader.text_to_instance(tokens)
corrected_instances.append(instance)
instances = corrected_instances
return sanitize(outputs)
ksteimel marked this conversation as resolved.
Show resolved Hide resolved

@overrides
def _json_to_instance(self, json_dict: JsonDict) -> Instance:
"""
Convert a JsonDict into an Instance.

This is used internally by `self.predict_json()`.

Parameters
----------
json_dict: JsonDict
Expects a dict with a single key "sentence" with a value representing the string to correct.
i.e. ``{"sentence": "..."}``.

Returns
------
Instance
An instance with the following fields:
- tokens
- metadata
- labels
- d_tags
"""
sentence = json_dict["sentence"]
tokens = self._tokenizer.split_words(sentence)
# Add start token to front
tokens = [Token(START_TOKEN)] + tokens
return self._dataset_reader.text_to_instance(tokens)
97 changes: 96 additions & 1 deletion gector/seq2labels_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from allennlp.training.metrics import CategoricalAccuracy
from overrides import overrides
from torch.nn.modules.linear import Linear
from utils.helpers import PAD, UNK, get_target_sent_by_edits, START_TOKEN


@Model.register("seq2labels")
Expand Down Expand Up @@ -60,6 +61,7 @@ def __init__(self, vocab: Vocabulary,
label_smoothing: float = 0.0,
confidence: float = 0.0,
del_confidence: float = 0.0,
min_error_probability: float = 0.0,
initializer: InitializerApplicator = InitializerApplicator(),
regularizer: Optional[RegularizerApplicator] = None) -> None:
super(Seq2Labels, self).__init__(vocab, regularizer)
Expand All @@ -72,6 +74,7 @@ def __init__(self, vocab: Vocabulary,
self.label_smoothing = label_smoothing
self.confidence = confidence
self.del_conf = del_confidence
self.min_error_probability = min_error_probability
self.incorr_index = self.vocab.get_token_index("INCORRECT",
namespace=detect_namespace)

Expand Down Expand Up @@ -161,14 +164,56 @@ def forward(self, # type: ignore
output_dict["loss"] = loss_labels + loss_d

if metadata is not None:
output_dict["words"] = [x["words"] for x in metadata]
output_dict["words"] = []
for instance in metadata:
output_dict["words"].append([word for word in instance["words"] if word != START_TOKEN])
return output_dict

@overrides
def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Does a simple position-wise argmax over each token, converts indices to string labels, and
adds a ``"tags"`` key to the dictionary with the result.

Parameters
----------
output_dict: Dict[str, torch.Tensor]
This is expected to have the following keys:
- logits_labels
- logits_d_tags
- class_probabilities_labels
- class_probabilities_d_tags
- max_error_probability
- words

Returns
------
Dict
A dictionary containing the following keys:
- logits_labels
Logits for labels indicating the types of corrections
to perform.
- logits_d_tags
Logits for labels indicating the presence or absence
of grammatical errors.
- class_probabilities_labels
Class probabilities for labels indicating the types
of corrections to perform.
- class_probabilities_d_tags
Class probabilities for labels indicating the presence
or absence of grammatical errors.
- max_error_probability
A threshold probability that has to be exceeded for an
error to be corrected.
- words
The original tokens used to create the instance.
- labels
Labels indicating the types of corrections to perform.
- d_tags
Labels indicating the presence or absence of grammatical errors.
- corrected_words
`words` after applying the correction operations
specified in `labels`
"""
for label_namespace in self.label_namespaces:
all_predictions = output_dict[f'class_probabilities_{label_namespace}']
Expand All @@ -185,8 +230,58 @@ def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor
for x in argmax_indices]
all_tags.append(tags)
output_dict[f'{label_namespace}'] = all_tags
batch_size = len(output_dict['labels'])
output_dict['corrected_words'] = []
for i in range(batch_size):
words_in_instance = output_dict['words'][i]
batch_len = len(words_in_instance)
probs = output_dict['class_probabilities_labels'][i]
max_probs = torch.max(probs, dim=0)
probs = max_probs[0].tolist()
indices = max_probs[1].tolist()
if max(indices) == 0: # No corrections should be performed
output_dict["corrected_words"].append(output_dict["words"][i])
else:
actions_per_token = []
for j in range(batch_len):
if j == 0:
token = START_TOKEN
else:
token = words_in_instance[j]
if indices[j] == 0:
continue
suggested_token_operation = output_dict['labels'][i][j]
action = self.get_token_action(index=j, prob=probs[j],
sugg_token=suggested_token_operation)
if not action:
continue
actions_per_token.append(action)
corrected_sent = get_target_sent_by_edits(output_dict['words'][i], actions_per_token)
output_dict['corrected_words'].append(corrected_sent)
return output_dict

def get_token_action(self, index, prob, sugg_token):
"""Get list of suggested actions for token."""
# cases when we don't need to do anything
if prob < self.min_error_probability or sugg_token in [UNK, PAD, '$KEEP']:
return None

if sugg_token.startswith('$REPLACE_') or sugg_token.startswith('$TRANSFORM_') or sugg_token == '$DELETE':
start_pos = index
end_pos = index + 1
elif sugg_token.startswith("$APPEND_") or sugg_token.startswith("$MERGE_"):
start_pos = index + 1
end_pos = index + 1

if sugg_token == "$DELETE":
sugg_token_clear = ""
elif sugg_token.startswith('$TRANSFORM_') or sugg_token.startswith("$MERGE_"):
sugg_token_clear = sugg_token[:]
else:
sugg_token_clear = sugg_token[sugg_token.index('_') + 1:]

return start_pos - 1, end_pos - 1, sugg_token_clear, prob

@overrides
ksteimel marked this conversation as resolved.
Show resolved Hide resolved
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
metrics_to_return = {metric_name: metric.get_metric(reset) for
Expand Down
1 change: 1 addition & 0 deletions gector/tokenizer_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def get_keys(self, index_name: str) -> List[str]:
return [index_name, f"{index_name}-offsets", f"{index_name}-type-ids", "mask"]


@TokenIndexer.register("gec-pretrained-bert-indexer")
class PretrainedBertIndexer(TokenizerIndexer):
# pylint: disable=line-too-long
"""
Expand Down
Loading