forked from NVIDIA/NeMo
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add AWS SageMaker ASR Examples (NVIDIA#5638)
* Base code for AWS SageMaker example Signed-off-by: SeanNaren <[email protected]> * Remove format Signed-off-by: SeanNaren <[email protected]> * wrap Signed-off-by: SeanNaren <[email protected]> * Add a notebook with the code Signed-off-by: SeanNaren <[email protected]> * Setup Signed-off-by: SeanNaren <[email protected]> * Update notebook Signed-off-by: SeanNaren <[email protected]> * Remove space Signed-off-by: SeanNaren <[email protected]> * Fix spelling mistake Signed-off-by: SeanNaren <[email protected]> * Add message to explain usage Signed-off-by: SeanNaren <[email protected]> * Add CommonVoice esperanto example Signed-off-by: SeanNaren <[email protected]> * Fix path Signed-off-by: SeanNaren <[email protected]> * Fixes Signed-off-by: SeanNaren <[email protected]> * Import sox locally, add documentation Signed-off-by: SeanNaren <[email protected]> * Address reviews Signed-off-by: SeanNaren <[email protected]> * Address reviews Signed-off-by: SeanNaren <[email protected]> * Address reviews Signed-off-by: SeanNaren <[email protected]> * Add cell to download the SSL model Signed-off-by: SeanNaren <[email protected]> * Set max epochs to 300 Signed-off-by: SeanNaren <[email protected]> * Fixes, introduce HF dataset instructions Signed-off-by: SeanNaren <[email protected]> * Upstream updates from other branch Signed-off-by: SeanNaren <[email protected]> * Fix warning Signed-off-by: SeanNaren <[email protected]> * Add README, add image Signed-off-by: SeanNaren <[email protected]> * Fix warning Signed-off-by: SeanNaren <[email protected]> * Address feedback Signed-off-by: SeanNaren <[email protected]> * Feedback Signed-off-by: SeanNaren <[email protected]> --------- Signed-off-by: SeanNaren <[email protected]> Signed-off-by: Jason <[email protected]>
- Loading branch information
Showing
6 changed files
with
1,352 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import glob | ||
import json | ||
import os | ||
import os.path | ||
import subprocess | ||
import tarfile | ||
from typing import Optional | ||
|
||
import wget | ||
|
||
|
||
# Function to build a manifest | ||
def build_manifest(transcripts_path, manifest_path, data_dir, mount_dir, wav_path): | ||
# create manifest with reference to this directory. This is useful when mounting the dataset. | ||
mount_dir = mount_dir if mount_dir else data_dir | ||
with open(transcripts_path, 'r') as fin: | ||
with open(manifest_path, 'w') as fout: | ||
for line in fin: | ||
# Lines look like this: | ||
# <s> transcript </s> (fileID) | ||
transcript = line[: line.find('(') - 1].lower() | ||
transcript = transcript.replace('<s>', '').replace('</s>', '') | ||
transcript = transcript.strip() | ||
|
||
file_id = line[line.find('(') + 1 : -2] # e.g. "cen4-fash-b" | ||
audio_path = os.path.join( | ||
data_dir, wav_path, file_id[file_id.find('-') + 1 : file_id.rfind('-')], file_id + '.wav' | ||
) | ||
|
||
mounted_audio_path = os.path.join( | ||
mount_dir, wav_path, file_id[file_id.find('-') + 1 : file_id.rfind('-')], file_id + '.wav' | ||
) | ||
# import sox here to not require sox to be available for importing all utils. | ||
import sox | ||
|
||
duration = sox.file_info.duration(audio_path) | ||
|
||
# Write the metadata to the manifest | ||
metadata = {"audio_filepath": mounted_audio_path, "duration": duration, "text": transcript} | ||
json.dump(metadata, fout) | ||
fout.write('\n') | ||
|
||
|
||
def download_an4(data_dir: str = "./", train_mount_dir: Optional[str] = None, test_mount_dir: Optional[str] = None): | ||
""" | ||
Function to download the AN4 dataset. This hides pre-processing boilerplate for notebook ASR examples. | ||
Args: | ||
data_dir: Path to store the data. | ||
train_mount_dir: If you plan to mount the dataset, use this to prepend the mount directory to the | ||
audio filepath in the train manifest. | ||
test_mount_dir: If you plan to mount the dataset, use this to prepend the mount directory to the | ||
audio filepath in the test manifest. | ||
""" | ||
print("******") | ||
os.makedirs(data_dir, exist_ok=True) | ||
if not os.path.exists(data_dir + '/an4_sphere.tar.gz'): | ||
an4_url = 'https://dldata-public.s3.us-east-2.amazonaws.com/an4_sphere.tar.gz' | ||
an4_path = wget.download(an4_url, data_dir) | ||
print(f"Dataset downloaded at: {an4_path}") | ||
else: | ||
print("Tarfile already exists.") | ||
an4_path = data_dir + '/an4_sphere.tar.gz' | ||
|
||
if not os.path.exists(data_dir + '/an4/'): | ||
tar = tarfile.open(an4_path) | ||
tar.extractall(path=data_dir) | ||
|
||
print("Converting .sph to .wav...") | ||
sph_list = glob.glob(data_dir + '/an4/**/*.sph', recursive=True) | ||
for sph_path in sph_list: | ||
wav_path = sph_path[:-4] + '.wav' | ||
cmd = ["sox", sph_path, wav_path] | ||
subprocess.run(cmd) | ||
print("Finished conversion.\n******") | ||
|
||
# Building Manifests | ||
print("******") | ||
train_transcripts = data_dir + '/an4/etc/an4_train.transcription' | ||
train_manifest = data_dir + '/an4/train_manifest.json' | ||
|
||
if not os.path.isfile(train_manifest): | ||
build_manifest(train_transcripts, train_manifest, data_dir, train_mount_dir, 'an4/wav/an4_clstk') | ||
print("Training manifest created.") | ||
|
||
test_transcripts = data_dir + '/an4/etc/an4_test.transcription' | ||
test_manifest = data_dir + '/an4/test_manifest.json' | ||
if not os.path.isfile(test_manifest): | ||
build_manifest(test_transcripts, test_manifest, data_dir, test_mount_dir, 'an4/wav/an4test_clstk') | ||
print("Test manifest created.") | ||
print("***Done***") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# NeMo Cloud Tutorials | ||
|
||
The best way to get started with NeMo in the cloud. | ||
|
||
## ASR | ||
* [Quickstart: Training an ASR model on AWS SageMaker](aws/SageMaker_ASR_Training.ipynb) | ||
* [Multi-Node Scaling: ASR Fine-Tuning on AWS SageMaker](aws/ASR_Finetuning_at_Scale_with_AWS_SageMaker.ipynb) |
Oops, something went wrong.