Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/fix gec predictor #13

Merged
merged 18 commits into from
Oct 24, 2022
Merged
Show file tree
Hide file tree
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .github/workflows/python-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ on:
push:
branches: [ "master" ]
pull_request:
branches: [ "master" ]

jobs:
build:
Expand All @@ -28,6 +27,10 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -e .
- name: Test
- name: Unit Testing
run: |
pytest -v tests
- name: Regression Testing
run: |
python regression_tests/test_gector_roberta.py
python regression_tests/test_regression_data_predictor.py
4 changes: 4 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ pip install -e .
```
The project was tested using Python 3.8.

## Unit tests
After activating the conda environment, simply run the code below:
`pytest -v tests`

## Datasets
All the public GEC datasets used in the paper can be downloaded from [here](https://www.cl.cam.ac.uk/research/nl/bea2019st/#data).<br>
Synthetically created datasets can be generated/downloaded [here](https://github.com/awasthiabhijeet/PIE/tree/master/errorify).<br>
Expand Down
13 changes: 13 additions & 0 deletions environment.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
name: gector
dependencies:
- python=3.8
- pytorch=1.10.0
- python-Levenshtein
- transformers
- scikit-learn
- sentencepiece
- overrides=4.1.2
- numpy
- pip:
- allennlp==0.9.0

90 changes: 58 additions & 32 deletions gector/datareader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@

from allennlp.common.file_utils import cached_path
from allennlp.data.dataset_readers.dataset_reader import DatasetReader
from allennlp.data.fields import TextField, SequenceLabelField, MetadataField, Field
from allennlp.data.fields import (
TextField,
SequenceLabelField,
MetadataField,
Field,
)
from allennlp.data.instance import Instance
from allennlp.data.token_indexers import TokenIndexer, SingleIdTokenIndexer
from allennlp.data.tokenizers import Token
Expand Down Expand Up @@ -37,23 +42,28 @@ class Seq2LabelsDatasetReader(DatasetReader):
are pre-tokenised in the data file.
max_len: if set than will truncate long sentences
"""

# fix broken sentences mostly in Lang8
BROKEN_SENTENCES_REGEXP = re.compile(r'\.[a-zA-RT-Z]')

def __init__(self,
token_indexers: Dict[str, TokenIndexer] = None,
delimeters: dict = SEQ_DELIMETERS,
skip_correct: bool = False,
skip_complex: int = 0,
lazy: bool = False,
max_len: int = None,
test_mode: bool = False,
tag_strategy: str = "keep_one",
tn_prob: float = 0,
tp_prob: float = 0,
broken_dot_strategy: str = "keep") -> None:
BROKEN_SENTENCES_REGEXP = re.compile(r"\.[a-zA-RT-Z]")

def __init__(
self,
token_indexers: Dict[str, TokenIndexer] = None,
delimeters: dict = SEQ_DELIMETERS,
skip_correct: bool = False,
skip_complex: int = 0,
lazy: bool = False,
max_len: int = None,
test_mode: bool = False,
tag_strategy: str = "keep_one",
tn_prob: float = 0,
tp_prob: float = 0,
broken_dot_strategy: str = "keep",
) -> None:
super().__init__(lazy)
self._token_indexers = token_indexers or {'tokens': SingleIdTokenIndexer()}
self._token_indexers = token_indexers or {
"tokens": SingleIdTokenIndexer()
}
self._delimeters = delimeters
self._max_len = max_len
self._skip_correct = skip_correct
Expand All @@ -69,16 +79,23 @@ def _read(self, file_path):
# if `file_path` is a URL, redirect to the cache
file_path = cached_path(file_path)
with open(file_path, "r") as data_file:
logger.info("Reading instances from lines in file at: %s", file_path)
logger.info(
"Reading instances from lines in file at: %s", file_path
)
for line in data_file:
line = line.strip("\n")
# skip blank and broken lines
if not line or (not self._test_mode and self._broken_dot_strategy == 'skip'
and self.BROKEN_SENTENCES_REGEXP.search(line) is not None):
if not line or (
not self._test_mode
and self._broken_dot_strategy == "skip"
and self.BROKEN_SENTENCES_REGEXP.search(line) is not None
):
continue

tokens_and_tags = [pair.rsplit(self._delimeters['labels'], 1)
for pair in line.split(self._delimeters['tokens'])]
tokens_and_tags = [
pair.rsplit(self._delimeters["labels"], 1)
for pair in line.split(self._delimeters["tokens"])
]
try:
tokens = [Token(token) for token, tag in tokens_and_tags]
tags = [tag for token, tag in tokens_and_tags]
Expand All @@ -91,14 +108,14 @@ def _read(self, file_path):

words = [x.text for x in tokens]
if self._max_len is not None:
tokens = tokens[:self._max_len]
tags = None if tags is None else tags[:self._max_len]
tokens = tokens[: self._max_len]
tags = None if tags is None else tags[: self._max_len]
instance = self.text_to_instance(tokens, tags, words)
if instance:
yield instance

def extract_tags(self, tags: List[str]):
op_del = self._delimeters['operations']
op_del = self._delimeters["operations"]

labels = [x.split(op_del) for x in tags]

Expand All @@ -117,17 +134,24 @@ def extract_tags(self, tags: List[str]):
else:
raise Exception("Incorrect tag strategy")

detect_tags = ["CORRECT" if label == "$KEEP" else "INCORRECT" for label in labels]
detect_tags = [
"CORRECT" if label == "$KEEP" else "INCORRECT" for label in labels
]
return labels, detect_tags, comlex_flag_dict

def text_to_instance(self, tokens: List[Token], tags: List[str] = None,
words: List[str] = None) -> Instance: # type: ignore
def text_to_instance(
self,
tokens: List[Token],
tags: List[str] = None,
words: List[str] = None,
) -> Instance: # type: ignore
"""
We take `pre-tokenized` input here, because we don't have a tokenizer in this class.
"""
# pylint: disable=arguments-differ
fields: Dict[str, Field] = {}
sequence = TextField(tokens, self._token_indexers)
# Set size of tokens to _max_len + 1 since $START token is being added
sequence = TextField(tokens[: self._max_len + 1], self._token_indexers)
Frost45 marked this conversation as resolved.
Show resolved Hide resolved
fields["tokens"] = sequence
# If words has not been explicitly passed in, create it from tokens.
if words is None:
Expand All @@ -147,8 +171,10 @@ def text_to_instance(self, tokens: List[Token], tags: List[str] = None,
if rnd > self._tp_prob:
return None

fields["labels"] = SequenceLabelField(labels, sequence,
label_namespace="labels")
fields["d_tags"] = SequenceLabelField(detect_tags, sequence,
label_namespace="d_tags")
fields["labels"] = SequenceLabelField(
labels, sequence, label_namespace="labels"
)
fields["d_tags"] = SequenceLabelField(
detect_tags, sequence, label_namespace="d_tags"
)
return Instance(fields)
Loading