Skip to content

Commit

Permalink
Modification of output path
Browse files Browse the repository at this point in the history
  • Loading branch information
Abel Riboulot committed Aug 1, 2020
1 parent 706197d commit ce26218
Showing 1 changed file with 6 additions and 7 deletions.
13 changes: 6 additions & 7 deletions onnxt5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@ def get_encoder_decoder_tokenizer():

# Checks if encoder is already expanded
if not os.path.exists(path_t5_encoder):
download_generation_model(os.path.join(package_path, 'model_data', 't5-encoder.tar.gz'), 't5-encoder.tar.gz',
package_path)
download_generation_model(os.path.join(package_path, 'model_data'), 't5-encoder.tar.gz')

# Checks if decoder is already expanded
if not os.path.exists(path_t5_decoder):
download_generation_model(os.path.join(package_path, 'model_data', 't5-decoder-with-lm-head.tar.gz'),
't5-decoder-with-lm-head.tar.gz', package_path)
download_generation_model(os.path.join(package_path, 'model_data'), 't5-decoder-with-lm-head.tar.gz')

# Loading the model_data
decoder_sess = InferenceSession(path_t5_decoder)
Expand All @@ -44,10 +42,11 @@ def run_embeddings_text(encoder, decoder, tokenizer, prompt):

return encoder_output, decoder_output

def download_generation_model(path, object, output_dir):
def download_generation_model(output_dir, object):
url = f'https://t5-onnx-models.s3.amazonaws.com/{object}'
r = requests.get(url, allow_redirects=True)
open(path, 'wb').write(r.content)
tar = tarfile.open(path, "r:gz")
downloaded_file_path = os.path.join(output_dir, object)
open(downloaded_file_path, 'wb').write(r.content)
tar = tarfile.open(downloaded_file_path, "r:gz")
tar.extractall(path=output_dir)
tar.close()

0 comments on commit ce26218

Please sign in to comment.