Skip to content

Commit

Permalink
Merge pull request #6 from EducationalTestingService/feature/predicto…
Browse files Browse the repository at this point in the history
…r_from_gec_model

Allennlp Predictor class that mimics gec_model
  • Loading branch information
ksteimel authored Oct 24, 2022
2 parents 5da5955 + b04376c commit c937981
Show file tree
Hide file tree
Showing 14 changed files with 5,444 additions and 10 deletions.
5 changes: 4 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -132,4 +132,7 @@ dmypy.json
# PyCharm
.idea

*.sh
*.sh

# pytorch weights files
*.th
2 changes: 1 addition & 1 deletion gector/bert_token_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ def forward(
return util.uncombine_initial_dims(selected_embeddings, offsets.size())


# @TokenEmbedder.register("bert-pretrained")
@TokenEmbedder.register("gec-bert-pretrained")
class PretrainedBertEmbedder(BertEmbedder):

"""
Expand Down
3 changes: 3 additions & 0 deletions gector/datareader.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,9 @@ def text_to_instance(self, tokens: List[Token], tags: List[str] = None,
fields: Dict[str, Field] = {}
sequence = TextField(tokens, self._token_indexers)
fields["tokens"] = sequence
# If words has not been explicitly passed in, create it from tokens.
if words is None:
words = [token.text for token in tokens]
fields["metadata"] = MetadataField({"words": words})
if tags is not None:
labels, detect_tags, complex_flag_dict = self.extract_tags(tags)
Expand Down
22 changes: 22 additions & 0 deletions gector/gec_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,28 @@ def __init__(self, vocab_path=None, model_paths=None,
del_confidence=0,
resolve_cycles=False,
):
"""
Class used to enable prediction from GECToR model.
Parameters
----------
vocab_path
model_paths: List[Path]
weigths
max_len
min_len
lowercase_tokens
log
iterations
model_name
special_tokens_fix
is_ensemble
min_error_probability
confidence
del_confidence
resolve_cycles: bool
This parameter is unused.
"""
self.model_weights = list(map(float, weigths)) if weigths else [1] * len(model_paths)
self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
self.max_len = max_len
Expand Down
182 changes: 182 additions & 0 deletions gector/gec_predictor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
from typing import Dict, List

import numpy
from allennlp.predictors import Predictor
from allennlp.models import Model
from allennlp.common.util import sanitize
from overrides import overrides
from allennlp.common.util import JsonDict
from allennlp.data import DatasetReader, Instance, Token
from allennlp.data.tokenizers.word_splitter import JustSpacesWordSplitter
from allennlp.models import Model
from utils.helpers import START_TOKEN


@Predictor.register("gec-predictor")
class GecPredictor(Predictor):
    """
    A Predictor for generating predictions from GECToR.

    The correction model is applied ``iterations`` times: the corrected words
    produced by one pass are turned back into an instance and fed to the next.
    Note that currently, this is unable to handle ensemble predictions.
    """

    def __init__(self,
                 model: Model,
                 dataset_reader: DatasetReader,
                 iterations: int = 3) -> None:
        """
        Parameters
        ----------
        model: Model
            An instantiated `Seq2Labels` model for performing grammatical error correction.
        dataset_reader: DatasetReader
            An instantiated dataset reader, typically `Seq2LabelsDatasetReader`.
        iterations: int
            This represents the number of times grammatical error correction is applied to the input.
            Must be at least 1.

        Raises
        ------
        ValueError
            If `iterations` is smaller than 1 (no prediction could ever be produced).
        """
        if iterations < 1:
            raise ValueError(f"iterations must be at least 1, got {iterations}")
        super().__init__(model, dataset_reader)
        self._tokenizer = JustSpacesWordSplitter()
        self._iterations = iterations

    def predict(self, sentence: str) -> JsonDict:
        """
        Generate error correction predictions for a single input string.

        Parameters
        ----------
        sentence: str
            The input text to perform error correction on.

        Returns
        -------
        JsonDict
            A dictionary containing the following keys:
            - logits_labels
            - logits_d_tags
            - class_probabilities_labels
            - class_probabilities_d_tags
            - max_error_probability
            - words
            - labels
            - d_tags
            - corrected_words
            For an explanation of each of these see `Seq2Labels.decode()`.
        """
        return self.predict_json({"sentence": sentence})

    def predict_batch(self, sentences: List[str]) -> List[JsonDict]:
        """
        Generate predictions for a sequence of input strings.

        Parameters
        ----------
        sentences: List[str]
            A list of strings to correct.

        Returns
        -------
        List[JsonDict]
            A list of dictionaries, each containing the following keys:
            - logits_labels
            - logits_d_tags
            - class_probabilities_labels
            - class_probabilities_d_tags
            - max_error_probability
            - words
            - labels
            - d_tags
            - corrected_words
            For an explanation of each of these see `Seq2Labels.decode()`.
        """
        return self.predict_batch_json([{"sentence": sentence} for sentence in sentences])

    def _words_to_instance(self, words: List[str]) -> Instance:
        """
        Build the instance for the next correction pass from corrected words.

        `_json_to_instance` puts START_TOKEN in front of the first-pass input, so
        every later pass has to present the same layout to the model. The start
        token is only added when it is not already the first word.
        """
        tokens = [Token(word) for word in words]
        if not words or words[0] != START_TOKEN:
            tokens = [Token(START_TOKEN)] + tokens
        return self._dataset_reader.text_to_instance(tokens)

    @overrides
    def predict_instance(self, instance: Instance) -> JsonDict:
        """
        This special predict_instance method allows for applying the correction model multiple times.

        Parameters
        ----------
        instance: Instance
            The instance to correct, e.g. one produced by `_json_to_instance()`.

        Returns
        -------
        JsonDict
            The sanitized output of the final correction pass, containing the following keys:
            - logits_labels
            - logits_d_tags
            - class_probabilities_labels
            - class_probabilities_d_tags
            - max_error_probability
            - words
            - labels
            - d_tags
            - corrected_words
            For an explanation of each of these see `Seq2Labels.decode()`.
        """
        output = self._model.forward_on_instance(instance)
        # Each remaining pass re-reads the previous pass's corrections; no
        # instance is built after the final pass because nothing would use it.
        for _ in range(self._iterations - 1):
            instance = self._words_to_instance(output["corrected_words"])
            output = self._model.forward_on_instance(instance)
        return sanitize(output)

    @overrides
    def predict_batch_instance(self, instances: List[Instance]) -> List[JsonDict]:
        """
        This special predict_batch_instance method allows for applying the correction model multiple times.

        Parameters
        ----------
        instances: List[Instance]
            The instances to correct. An empty list yields an empty result.

        Returns
        -------
        List[JsonDict]
            The sanitized outputs of the final correction pass, one dictionary per
            input instance, each containing the following keys:
            - logits_labels
            - logits_d_tags
            - class_probabilities_labels
            - class_probabilities_d_tags
            - max_error_probability
            - words
            - labels
            - d_tags
            - corrected_words
            For an explanation of each of these see `Seq2Labels.decode()`.
        """
        if not instances:
            return []
        outputs = self._model.forward_on_instances(instances)
        for _ in range(self._iterations - 1):
            # integrate predictions back into instances for the next iteration
            instances = [self._words_to_instance(output["corrected_words"])
                         for output in outputs]
            outputs = self._model.forward_on_instances(instances)
        return sanitize(outputs)

    @overrides
    def _json_to_instance(self, json_dict: JsonDict) -> Instance:
        """
        Convert a JsonDict into an Instance.

        This is used internally by `self.predict_json()`.

        Parameters
        ----------
        json_dict: JsonDict
            Expects a dict with a single key "sentence" with a value representing the string to correct.
            i.e. ``{"sentence": "..."}``.

        Returns
        -------
        Instance
            The instance built by the dataset reader's `text_to_instance()` from
            the whitespace-split sentence with START_TOKEN placed in front.
        """
        sentence = json_dict["sentence"]
        tokens = self._tokenizer.split_words(sentence)
        # Add start token to front
        tokens = [Token(START_TOKEN)] + tokens
        return self._dataset_reader.text_to_instance(tokens)
97 changes: 96 additions & 1 deletion gector/seq2labels_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from allennlp.training.metrics import CategoricalAccuracy
from overrides import overrides
from torch.nn.modules.linear import Linear
from utils.helpers import PAD, UNK, get_target_sent_by_edits, START_TOKEN


@Model.register("seq2labels")
Expand Down Expand Up @@ -60,6 +61,7 @@ def __init__(self, vocab: Vocabulary,
label_smoothing: float = 0.0,
confidence: float = 0.0,
del_confidence: float = 0.0,
min_error_probability: float = 0.0,
initializer: InitializerApplicator = InitializerApplicator(),
regularizer: Optional[RegularizerApplicator] = None) -> None:
super(Seq2Labels, self).__init__(vocab, regularizer)
Expand All @@ -72,6 +74,7 @@ def __init__(self, vocab: Vocabulary,
self.label_smoothing = label_smoothing
self.confidence = confidence
self.del_conf = del_confidence
self.min_error_probability = min_error_probability
self.incorr_index = self.vocab.get_token_index("INCORRECT",
namespace=detect_namespace)

Expand Down Expand Up @@ -161,14 +164,56 @@ def forward(self, # type: ignore
output_dict["loss"] = loss_labels + loss_d

if metadata is not None:
output_dict["words"] = [x["words"] for x in metadata]
output_dict["words"] = []
for instance in metadata:
output_dict["words"].append([word for word in instance["words"] if word != START_TOKEN])
return output_dict

@overrides
def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
"""
Does a simple position-wise argmax over each token, converts indices to string labels, and
adds a ``"tags"`` key to the dictionary with the result.
Parameters
----------
output_dict: Dict[str, torch.Tensor]
This is expected to have the following keys:
- logits_labels
- logits_d_tags
- class_probabilities_labels
- class_probabilities_d_tags
- max_error_probability
- words
Returns
------
Dict
A dictionary containing the following keys:
- logits_labels
Logits for labels indicating the types of corrections
to perform.
- logits_d_tags
Logits for labels indicating the presence or absence
of grammatical errors.
- class_probabilities_labels
Class probabilities for labels indicating the types
of corrections to perform.
- class_probabilities_d_tags
Class probabilities for labels indicating the presence
or absence of grammatical errors.
- max_error_probability
A threshold probability that has to be exceeded for an
error to be corrected.
- words
The original tokens used to create the instance.
- labels
Labels indicating the types of corrections to perform.
- d_tags
Labels indicating the presence or absence of grammatical errors.
- corrected_words
`words` after applying the correction operations
specified in `labels`
"""
for label_namespace in self.label_namespaces:
all_predictions = output_dict[f'class_probabilities_{label_namespace}']
Expand All @@ -185,8 +230,58 @@ def decode(self, output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor
for x in argmax_indices]
all_tags.append(tags)
output_dict[f'{label_namespace}'] = all_tags
batch_size = len(output_dict['labels'])
output_dict['corrected_words'] = []
for i in range(batch_size):
words_in_instance = output_dict['words'][i]
batch_len = len(words_in_instance)
probs = output_dict['class_probabilities_labels'][i]
max_probs = torch.max(probs, dim=0)
probs = max_probs[0].tolist()
indices = max_probs[1].tolist()
if max(indices) == 0: # No corrections should be performed
output_dict["corrected_words"].append(output_dict["words"][i])
else:
actions_per_token = []
for j in range(batch_len):
if j == 0:
token = START_TOKEN
else:
token = words_in_instance[j]
if indices[j] == 0:
continue
suggested_token_operation = output_dict['labels'][i][j]
action = self.get_token_action(index=j, prob=probs[j],
sugg_token=suggested_token_operation)
if not action:
continue
actions_per_token.append(action)
corrected_sent = get_target_sent_by_edits(output_dict['words'][i], actions_per_token)
output_dict['corrected_words'].append(corrected_sent)
return output_dict

def get_token_action(self, index, prob, sugg_token):
    """
    Turn the tag predicted for one token into an edit action.

    Parameters
    ----------
    index: int
        Position of the token as seen by the caller. Returned positions are
        shifted down by one (the caller's indexing reserves position 0 for
        the start token).
    prob: float
        Probability assigned to `sugg_token`.
    sugg_token: str
        The predicted tag, e.g. ``$KEEP``, ``$DELETE``, ``$REPLACE_<word>``,
        ``$APPEND_<word>``, ``$TRANSFORM_<name>`` or ``$MERGE_<name>``.

    Returns
    -------
    Optional[Tuple[int, int, str, float]]
        ``None`` when no edit should be made: `prob` is below
        `self.min_error_probability`, the tag is UNK, PAD or ``$KEEP``, or the
        tag is not one of the recognised operations. Otherwise
        ``(start, end, replacement, prob)`` where `start == end` marks an
        insertion after the token and `replacement` is ``""`` for a deletion.
    """
    # cases when we don't need to do anything
    if prob < self.min_error_probability or sugg_token in [UNK, PAD, '$KEEP']:
        return None

    if sugg_token == '$DELETE' or sugg_token.startswith(('$REPLACE_', '$TRANSFORM_')):
        # The edit rewrites the token itself.
        start_pos = index
        end_pos = index + 1
    elif sugg_token.startswith(('$APPEND_', '$MERGE_')):
        # The edit is inserted after the token, hence the empty span.
        start_pos = index + 1
        end_pos = index + 1
    else:
        # Unrecognised tag: there is no span to edit, so skip it rather than
        # fail on the unassigned positions below.
        return None

    if sugg_token == '$DELETE':
        sugg_token_clear = ''
    elif sugg_token.startswith(('$TRANSFORM_', '$MERGE_')):
        # These tags are passed through whole.
        sugg_token_clear = sugg_token
    else:
        # $REPLACE_<word> / $APPEND_<word>: keep what follows the first underscore.
        sugg_token_clear = sugg_token.partition('_')[2]

    return start_pos - 1, end_pos - 1, sugg_token_clear, prob

@overrides
def get_metrics(self, reset: bool = False) -> Dict[str, float]:
metrics_to_return = {metric_name: metric.get_metric(reset) for
Expand Down
1 change: 1 addition & 0 deletions gector/tokenizer_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def get_keys(self, index_name: str) -> List[str]:
return [index_name, f"{index_name}-offsets", f"{index_name}-type-ids", "mask"]


@TokenIndexer.register("gec-pretrained-bert-indexer")
class PretrainedBertIndexer(TokenizerIndexer):
# pylint: disable=line-too-long
"""
Expand Down
Loading

0 comments on commit c937981

Please sign in to comment.