Skip to content

Commit

Permalink
fix #175
Browse files Browse the repository at this point in the history
  • Loading branch information
makcedward committed Nov 11, 2020
1 parent bb2ec0f commit 289f752
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 8 deletions.
12 changes: 5 additions & 7 deletions nlpaug/augmenter/word/context_word_embs.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,9 +264,6 @@ def insert(self, data):
elif len(output) > 1:
candidate = self.sample(output, 1)[0]

# if self.model_type in ['xlnet', 'roberta']:
# candidate = self.model.SUBWORD_PREFIX + candidate # Adding prefix for space

# In XLNet, it can be the first word of sentence which does not come with sapce. E.g. Zombine (ID:29110)
if self.model_type in ['xlnet', 'roberta']:
if candidate != '' and not candidate.startswith(self.model.SUBWORD_PREFIX):
Expand Down Expand Up @@ -394,7 +391,7 @@ def substitute(self, data):
outputs = self.model.predict(masked_texts, target_words=original_tokens, n=2)

# Update doc
for aug_input_pos, output, masked_text in zip(aug_input_poses, outputs, masked_texts):
for original_token, aug_input_pos, output, masked_text in zip(original_tokens, aug_input_poses, outputs, masked_texts):
split_result = split_results[aug_input_pos]
head_doc = split_result[5]
aug_idx = split_result[6][i] # augment position in text
Expand All @@ -408,15 +405,16 @@ def substitute(self, data):
candidate = output[0]
elif len(output) > 1:
candidate = self.sample(output, 1)[0]

# if self.model_type in ['xlnet', 'roberta']:
# candidate = self.model.SUBWORD_PREFIX + candidate # Adding prefix for space

# In XLNet, it can be the first word of sentence which does not come with sapce. E.g. Zombine (ID:29110)
if self.model_type in ['xlnet', 'roberta']:
if candidate != '' and not candidate.startswith(self.model.SUBWORD_PREFIX):
candidate = self.model.SUBWORD_PREFIX + candidate

# Fallback to original token if no candidate is appropriate
if candidate == '':
candidate = original_token

head_doc.update_change_log(aug_idx, token=candidate, action=Action.SUBSTITUTE,
change_seq=self.parent_change_seq+change_seq)

Expand Down
11 changes: 10 additions & 1 deletion test/augmenter/word/test_context_word_embs.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def test_multilingual(self):
# print('[{}]: {}'.format(input_param['lang'], augmented_text))

def test_contextual_word_embs(self):
# # self.execute_by_device('cuda')
# self.execute_by_device('cuda')
self.execute_by_device('cpu')

def execute_by_device(self, device):
Expand All @@ -106,9 +106,18 @@ def execute_by_device(self, device):
self.max_length([insert_aug, substitute_aug])
self.empty_replacement(substitute_aug)
self.skip_short_token(substitute_aug)
self.no_candidiate(substitute_aug)

self.assertLess(0, len(self.model_paths))

def no_candidiate(self, aug):
text = 'This python library helps you with augmenting nlp for your machine learning projects. Visit this introduction to understand about it.'
original_top_p = aug.model.top_p
aug.model.top_p = 0.1
augmented_text = aug.augment(text)
aug.model.top_p = original_top_p
self.assertTrue(aug.model.MASK_TOKEN not in augmented_text)

def skip_short_token(self, aug):
text = 'I am a boy'
self.assertNotEqual(text.lower(), aug.augment(text).lower())
Expand Down

0 comments on commit 289f752

Please sign in to comment.