Draft: Fix restoring from checkpoint for the case when `model.common_dataset_parameters.label_vocab_dir` is provided (#4136)

* Fix restoring from checkpoint with label vocab dir

Signed-off-by: PeganovAnton <[email protected]>

* Add tests for various ways to pass label ids to model

Signed-off-by: PeganovAnton <[email protected]>

* Fix typo

Signed-off-by: PeganovAnton <[email protected]>

* Fix typo

Signed-off-by: PeganovAnton <[email protected]>

* Do not create tmp directory

Signed-off-by: PeganovAnton <[email protected]>

* Fix parameter name

Signed-off-by: PeganovAnton <[email protected]>

* finish cherry-pick op

Signed-off-by: PeganovAnton <[email protected]>

* Fix labels errors

Signed-off-by: PeganovAnton <[email protected]>

* Remove duplicate stage

Signed-off-by: PeganovAnton <[email protected]>

* Change target branch

Signed-off-by: PeganovAnton <[email protected]>
PeganovAnton authored May 9, 2022
1 parent 6fd6254 commit 54f6bbf
Showing 2 changed files with 108 additions and 9 deletions.
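For orientation before the diffs: the Jenkinsfile change adds a CI block that trains and then evaluates the punctuation and capitalization model with labels supplied in two different ways, once through `model.common_dataset_parameters.label_vocab_dir` plus per-task vocab files, and once through explicit `punct_label_ids`/`capit_label_ids` mappings. The sketch below only restates those two configuration styles as plain Python dictionaries; it is not code from the commit, and the dict-of-overrides representation is an editorial assumption, while the parameter paths, file names, and labels are taken from the stages themselves.

# Style 1 (first new CI stage): point the model at a directory that holds
# punct_label_vocab.csv ("O", ",", ".", "?") and capit_label_vocab.csv ("O", "U"),
# one label per line.
style_1_overrides = {
    "model.common_dataset_parameters.label_vocab_dir": "label_vocab_dir",
    "model.class_labels.punct_labels_file": "punct_label_vocab.csv",
    "model.class_labels.capit_labels_file": "capit_label_vocab.csv",
}

# Style 2 (second new CI stage): put explicit label-to-id mappings directly
# into the config instead of reading them from files.
style_2_overrides = {
    "model.common_dataset_parameters.punct_label_ids": {"O": 0, ",": 1, ".": 2, "?": 3},
    "model.common_dataset_parameters.capit_label_ids": {"O": 0, "U": 1},
}

print(style_1_overrides)
print(style_2_overrides)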
Jenkinsfile: 99 additions & 0 deletions
@@ -1555,6 +1555,105 @@ pipeline {
}
}
}
stage('Punctuation & Capitalization, Different ways of passing labels to model') {
when {
anyOf {
branch 'r1.9.0'
changeRequest target: 'r1.9.0'
}
}
failFast true
stages {
stage('Punctuation & Capitalization, Using model.common_dataset_parameters.label_vocab_dir') {
steps {
sh 'cd examples/nlp/token_classification && \
label_vocab_dir=label_vocab_dir && \
mkdir -p ${label_vocab_dir} && \
punct_label_vocab="${label_vocab_dir}/punct_label_vocab.csv" && \
capit_label_vocab="${label_vocab_dir}/capit_label_vocab.csv" && \
printf "O\n,\n.\n?\n" > "${punct_label_vocab}" && \
printf "O\nU\n" > "${capit_label_vocab}" && \
CUDA_LAUNCH_BLOCKING=1 python punctuation_capitalization_train_evaluate.py \
model.train_ds.use_tarred_dataset=false \
model.train_ds.ds_item=/home/TestData/nlp/token_classification_punctuation \
model.validation_ds.ds_item=/home/TestData/nlp/token_classification_punctuation \
model.test_ds.ds_item=/home/TestData/nlp/token_classification_punctuation \
model.language_model.pretrained_model_name=distilbert-base-uncased \
model.common_dataset_parameters.label_vocab_dir="${label_vocab_dir}" \
model.class_labels.punct_labels_file="$(basename "${punct_label_vocab}")" \
model.class_labels.capit_labels_file="$(basename "${capit_label_vocab}")" \
+model.train_ds.use_cache=false \
+model.validation_ds.use_cache=false \
+model.test_ds.use_cache=false \
trainer.devices=[0,1] \
trainer.strategy=ddp \
trainer.max_epochs=1 \
+exp_manager.explicit_log_dir=/home/TestData/nlp/token_classification_punctuation/output \
+do_testing=false && \
CUDA_LAUNCH_BLOCKING=1 python punctuation_capitalization_train_evaluate.py \
+do_training=false \
+do_testing=true \
~model.train_ds \
~model.validation_ds \
model.test_ds.ds_item=/home/TestData/nlp/token_classification_punctuation \
pretrained_model=/home/TestData/nlp/token_classification_punctuation/output/checkpoints/Punctuation_and_Capitalization.nemo \
+model.train_ds.use_cache=false \
+model.validation_ds.use_cache=false \
+model.test_ds.use_cache=false \
trainer.devices=[0,1] \
trainer.strategy=ddp \
trainer.max_epochs=1 \
exp_manager=null && \
rm -r "${label_vocab_dir}" && \
rm -rf /home/TestData/nlp/token_classification_punctuation/output/*'
}
}
stage('Punctuation & Capitalization, Using model.common_dataset_parameters.{punct,capit}_label_ids') {
steps {
sh 'cd examples/nlp/token_classification && \
conf_path=/home/TestData/nlp/token_classification_punctuation && \
conf_name=punctuation_capitalization_config_with_ids && \
cp conf/punctuation_capitalization_config.yaml "${conf_path}/${conf_name}.yaml" && \
sed -i $\'s/punct_label_ids: null/punct_label_ids: {O: 0, \\\',\\\': 1, .: 2, \\\'?\\\': 3}/\' \
"${conf_path}/${conf_name}.yaml" && \
sed -i $\'s/capit_label_ids: null/capit_label_ids: {O: 0, U: 1}/\' \
"${conf_path}/${conf_name}.yaml" && \
CUDA_LAUNCH_BLOCKING=1 python punctuation_capitalization_train_evaluate.py \
--config-path "${conf_path}" \
--config-name "${conf_name}" \
model.train_ds.use_tarred_dataset=false \
model.train_ds.ds_item=/home/TestData/nlp/token_classification_punctuation \
model.validation_ds.ds_item=/home/TestData/nlp/token_classification_punctuation \
model.test_ds.ds_item=/home/TestData/nlp/token_classification_punctuation \
model.language_model.pretrained_model_name=distilbert-base-uncased \
+model.train_ds.use_cache=false \
+model.validation_ds.use_cache=false \
+model.test_ds.use_cache=false \
trainer.devices=[0,1] \
trainer.strategy=ddp \
trainer.max_epochs=1 \
+exp_manager.explicit_log_dir=/home/TestData/nlp/token_classification_punctuation/output \
+do_testing=false && \
CUDA_LAUNCH_BLOCKING=1 python punctuation_capitalization_train_evaluate.py \
+do_training=false \
+do_testing=true \
~model.train_ds \
~model.validation_ds \
model.test_ds.ds_item=/home/TestData/nlp/token_classification_punctuation \
pretrained_model=/home/TestData/nlp/token_classification_punctuation/output/checkpoints/Punctuation_and_Capitalization.nemo \
+model.train_ds.use_cache=false \
+model.validation_ds.use_cache=false \
+model.test_ds.use_cache=false \
trainer.devices=[0,1] \
trainer.strategy=ddp \
trainer.max_epochs=1 \
exp_manager=null && \
rm -rf /home/TestData/nlp/token_classification_punctuation/output/* && \
rm "${conf_path}/${conf_name}.yaml"'
}
}
}
}
stage('Punctuation & Capitalization inference') {
when {
anyOf {
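The two stages above are meant to configure the same labels in two equivalent ways: the vocab files written by the first stage and the `punct_label_ids`/`capit_label_ids` dictionaries patched in by the second describe the same mapping, assuming the usual one-label-per-line format in which the line index becomes the id. A minimal sketch of that correspondence, with a hypothetical `read_label_vocab` helper that is not part of NeMo:

from pathlib import Path


def read_label_vocab(path: Path) -> dict:
    # Build a label -> id mapping from a one-label-per-line vocab file.
    return {label: idx for idx, label in enumerate(path.read_text().splitlines())}


vocab_dir = Path("label_vocab_dir")
vocab_dir.mkdir(exist_ok=True)
# Same contents that the first stage writes with printf.
(vocab_dir / "punct_label_vocab.csv").write_text("O\n,\n.\n?\n")
(vocab_dir / "capit_label_vocab.csv").write_text("O\nU\n")

# Expected output: {'O': 0, ',': 1, '.': 2, '?': 3} and {'O': 0, 'U': 1},
# the same dictionaries that the second stage writes into the config via sed.
print(read_label_vocab(vocab_dir / "punct_label_vocab.csv"))
print(read_label_vocab(vocab_dir / "capit_label_vocab.csv"))

The second changed file, shown next, is the model-side fix that makes the first style survive restoring from a checkpoint.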
@@ -638,16 +638,16 @@ def _check_label_config_parameters(self) -> None:
        )

    def _extract_label_vocab_files_from_config(self) -> Tuple[Optional[Path], Optional[Path]]:
-        if self._cfg.common_dataset_parameters.label_vocab_dir is None:
-            if self._is_model_being_restored():
-                punct_label_vocab_file = self._cfg.class_labels.punct_labels_file
-                capit_label_vocab_file = self._cfg.class_labels.capit_labels_file
-            else:
-                punct_label_vocab_file, capit_label_vocab_file = None, None
-        else:
-            label_vocab_dir = Path(self._cfg.common_dataset_parameters.label_vocab_dir).expanduser()
-            punct_label_vocab_file = label_vocab_dir / self._cfg.class_labels.punct_labels_file
-            capit_label_vocab_file = label_vocab_dir / self._cfg.class_labels.capit_labels_file
+        if self._is_model_being_restored():
+            punct_label_vocab_file = self._cfg.class_labels.punct_labels_file
+            capit_label_vocab_file = self._cfg.class_labels.capit_labels_file
+        else:
+            if self._cfg.common_dataset_parameters.label_vocab_dir is None:
+                punct_label_vocab_file, capit_label_vocab_file = None, None
+            else:
+                label_vocab_dir = Path(self._cfg.common_dataset_parameters.label_vocab_dir).expanduser()
+                punct_label_vocab_file = label_vocab_dir / self._cfg.class_labels.punct_labels_file
+                capit_label_vocab_file = label_vocab_dir / self._cfg.class_labels.capit_labels_file
        return punct_label_vocab_file, capit_label_vocab_file

    def _set_label_ids(self) -> None:
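To summarize the behavioral change in `_extract_label_vocab_files_from_config` with a standalone sketch: the config object and the `is_being_restored` flag below are hypothetical stand-ins for the model's `self._cfg` and `self._is_model_being_restored()`. The point is that, when restoring a checkpoint trained with `label_vocab_dir` set, the old logic still prefixed the `class_labels` file names with that directory, which may not exist on the restoring machine, while the fixed logic uses the file names as-is.

from pathlib import Path
from types import SimpleNamespace

# Hypothetical stand-in for the relevant part of the model config.
cfg = SimpleNamespace(
    common_dataset_parameters=SimpleNamespace(label_vocab_dir="label_vocab_dir"),
    class_labels=SimpleNamespace(
        punct_labels_file="punct_label_vocab.csv",
        capit_labels_file="capit_label_vocab.csv",
    ),
)


def extract_old(cfg, is_being_restored):
    # Old logic: the restore check was reached only when label_vocab_dir was None.
    if cfg.common_dataset_parameters.label_vocab_dir is None:
        if is_being_restored:
            return cfg.class_labels.punct_labels_file, cfg.class_labels.capit_labels_file
        return None, None
    d = Path(cfg.common_dataset_parameters.label_vocab_dir).expanduser()
    return d / cfg.class_labels.punct_labels_file, d / cfg.class_labels.capit_labels_file


def extract_new(cfg, is_being_restored):
    # New logic: during restore, take the class_labels file names directly.
    if is_being_restored:
        return cfg.class_labels.punct_labels_file, cfg.class_labels.capit_labels_file
    if cfg.common_dataset_parameters.label_vocab_dir is None:
        return None, None
    d = Path(cfg.common_dataset_parameters.label_vocab_dir).expanduser()
    return d / cfg.class_labels.punct_labels_file, d / cfg.class_labels.capit_labels_file


# extract_old resolves paths under label_vocab_dir even during restore;
# extract_new returns the bare class_labels file names instead.
print(extract_old(cfg, is_being_restored=True))
print(extract_new(cfg, is_being_restored=True))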
