Skip to content

Commit

Permalink
Merge pull request #6 from FlagAI-Open/master
Browse files Browse the repository at this point in the history
merge
  • Loading branch information
shunxing1234 authored Jan 6, 2023
2 parents 5b9157e + f0ee4a4 commit 6c7850a
Show file tree
Hide file tree
Showing 7 changed files with 46 additions and 29 deletions.
2 changes: 1 addition & 1 deletion examples/AltCLIP/altclip_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,4 @@ def inference():
print(text_probs.cpu().numpy()[0].tolist())

if __name__=="__main__":
inference()
inference()
1 change: 0 additions & 1 deletion examples/glm_blank_filling/glm_generate_samples.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
# Licensed under the Apache License, Version 2.0 (the "License")

import torch

from flagai.model.glm_model import GLMModel
from flagai.data.tokenizer import Tokenizer
from flagai.model.predictor.predictor import Predictor
Expand Down
17 changes: 9 additions & 8 deletions examples/glm_blank_filling/glm_generate_samples_en.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,24 +3,25 @@
# Licensed under the Apache License, Version 2.0 (the "License")

import torch
from flagai.model.predictor.predictor import Predictor
from flagai.model.glm_model import GLMModel
from flagai.data.tokenizer import Tokenizer
from flagai.data.tokenizer.glm_large_en.glm_large_en_tokenizer import GLMLargeEnWordPieceTokenizer
from flagai.auto_model.auto_loader import AutoLoader
from flagai.model.predictor.predictor import Predictor
if __name__ == "__main__":
"""Main training program."""
print('Generate Samples')
# Random seeds for reproducibility.
# Model,

loader = AutoLoader(task_name='lm',
model_name='GLM-large-en',
only_download_config=False)
model_name='GLM-large-en-generation',
only_download_config=False)
model = loader.get_model()
tokenizer = loader.get_tokenizer()
model.cuda(torch.cuda.current_device())

model.cuda(torch.cuda.current_device())
predictor = Predictor(model, tokenizer)
# generate samples
text = [
'Question: Is drinking beer bad for your health? Answer: [gMASK]',
'Is drinking beer bad for your health?',
]
for t in text:
output = predictor.predict_generate_randomsample(
Expand Down
2 changes: 1 addition & 1 deletion examples/gpt2_text_writting/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
if __name__ == '__main__':
loader = AutoLoader("seq2seq",
"GPT2-base-ch",
model_dir="./state_dict/")
model_dir="./checkpoints/")
model = loader.get_model()
tokenizer = loader.get_tokenizer()
predictor = Predictor(model, tokenizer)
Expand Down
1 change: 1 addition & 0 deletions flagai/auto_model/auto_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def __getattr__(self, name):
"glm-large-ch": ["flagai.model.glm_model", "GLMModel", "glm", "nlp"],
"alm-1.0": ["flagai.model.alm_model", "ALMModel", "alm", "nlp"],
"glm-large-en": ["flagai.model.glm_model", "GLMModel", "glm", "nlp"],
"glm-large-en-generation": ["flagai.model.glm_model", "GLMModel", "glm", "nlp"],
"gpt2-base-ch": ["flagai.model.gpt2_model", "GPT2Model", "gpt2", "nlp"],
"cpm-large-ch": ["flagai.model.gpt2_model", "GPT2Model", "cpm", "nlp"],
"opt-125m-en": ["flagai.model.opt_model", "OPTModel", "opt", "nlp"],
Expand Down
24 changes: 10 additions & 14 deletions flagai/data/tokenizer/uni_tokenizer/tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,15 +148,15 @@ def __init__(self,
if add_block_symbols:
self.add_command_token('sop', '<|startofpiece|>')
self.add_command_token('eop', '<|endofpiece|>',)
if add_task_mask:
self.add_command_token('gMASK', '[gMASK]')
self.add_command_token('sMASK', '[sMASK]')
if add_decoder_mask:
self.add_command_token('dBLOCK', '[dBLOCK]')
if add_sentinel_token > 0:
for i in range(1, add_sentinel_token):
self.add_command_token(f'MASK{i}', f'[MASK{i}]')
self.add_command_token(f'sop{i}', f'<|startofpiece{i}|>')
# if add_task_mask:
# self.add_command_token('gMASK', '[gMASK]')
# self.add_command_token('sMASK', '[sMASK]')
# if add_decoder_mask:
# self.add_command_token('dBLOCK', '[dBLOCK]')
# if add_sentinel_token > 0:
# for i in range(1, add_sentinel_token):
# self.add_command_token(f'MASK{i}', f'[MASK{i}]')
# self.add_command_token(f'sop{i}', f'<|startofpiece{i}|>')
elif self.tokenizer_class == "bpe":
if self.tokenizer_model_name.lower().startswith('roberta'):
self.num_command_tokens = 6
Expand Down Expand Up @@ -298,13 +298,9 @@ def __init__(self,
self.num_command_tokens += 6
self.token_end_id = self.text_tokenizer.convert_token_to_id(
'</s>')


if add_block_symbols:
sop_id = self.text_tokenizer.convert_token_to_id('<|startofpiece|>')
eop_id = self.text_tokenizer.convert_token_to_id('<|endofpiece|>')


self._command_tokens.extend([
CommandToken('sop', '<|startofpiece|>',
self.num_tokens + 1),
Expand Down Expand Up @@ -352,7 +348,7 @@ def __init__(self,
}
self.command_id_map = {tok.Id: tok for tok in self._command_tokens}
self._command_token_tokens = list(self.command_token_map.keys())

print([(k,v.Id) for k,v in self.command_name_map.items()])

def get_vocab(self):
return self.text_tokenizer.get_vocab()
Expand Down
28 changes: 24 additions & 4 deletions flagai/model/predictor/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,38 @@
from itertools import islice
from transformers import AutoFeatureExtractor
import math

from flagai.model.file_utils import _get_model_id, _get_checkpoint_path, _get_vocab_path, _get_model_files
join = os.path.join

def download(model_name, download_path):
    """Download all files of ``model_name`` from the model hub.

    Files land in ``download_path/model_name``. Checkpoint files (names
    ending in "bin") go through ``_get_checkpoint_path``; every other file
    (vocab/config/tokenizer assets) goes through ``_get_vocab_path``.
    Best-effort: returns silently when the hub is unreachable or the model
    id cannot be resolved.

    Args:
        model_name: Name of the model on the hub.
        download_path: Local directory under which a ``model_name``
            subdirectory is populated.
    """
    # Bug fix: model_id was unbound when _get_model_id() raised, so the
    # `if model_id` check below crashed with a NameError instead of
    # returning gracefully.
    model_id = None
    try:
        model_id = _get_model_id(model_name)
    except Exception:  # narrow best-effort handling; was a bare `except:`
        print("Model hub is not reachable!")
    # Download each listed file into download_path/model_name.
    if model_id and model_id != "null":
        # SECURITY NOTE(review): eval() on a hub response can execute
        # arbitrary code; the payload looks like a Python/JSON list literal,
        # so ast.literal_eval or json.loads would be safer — TODO confirm
        # the exact format returned by _get_model_files before switching.
        model_files = eval(_get_model_files(model_name))
        print("model files:" + str(model_files))
        target_dir = os.path.join(download_path, model_name)
        for file_name in model_files:
            if file_name.endswith("bin"):
                # Model weights checkpoint.
                _get_checkpoint_path(target_dir, file_name, model_id)
            else:
                # Vocab / config / tokenizer asset.
                _get_vocab_path(target_dir, file_name, model_id)
    return


def get_safety_checker():
    """Load the Stable Diffusion safety checker and its feature extractor.

    On first use the "SafetyChecker" model is downloaded into
    ``<cwd>/checkpoints/SafetyChecker``; both components are then loaded
    from that local directory rather than from the remote
    "CompVis/stable-diffusion-safety-checker" repo.

    Returns:
        Tuple ``(safety_checker, safety_feature_extractor)``.
    """
    # Imported lazily so diffusers is only required when the checker is used.
    from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
    # NOTE(review): the original span interleaved pre- and post-diff lines
    # (remote safety_model_id vs. local path) and was not valid Python; this
    # is the reconstructed post-commit behavior using the local checkpoint.
    path = os.getcwd() + "/checkpoints/"
    local_dir = path + "SafetyChecker"
    if not os.path.exists(local_dir):
        download("SafetyChecker", path)
    safety_feature_extractor = AutoFeatureExtractor.from_pretrained(local_dir)
    safety_checker = StableDiffusionSafetyChecker.from_pretrained(local_dir)
    return safety_checker, safety_feature_extractor


Expand Down

0 comments on commit 6c7850a

Please sign in to comment.