From 50ded3b0a6f5f48dfd674add50dc570e48f3ba75 Mon Sep 17 00:00:00 2001 From: Virginia Adams Date: Thu, 10 Nov 2022 23:14:40 +0000 Subject: [PATCH 1/5] Fix for prompt table restore error Signed-off-by: Virginia Adams --- .../language_modeling/megatron_gpt_prompt_learning_model.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py index add7c898c80c..bc6439517cae 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py @@ -706,6 +706,10 @@ def save_checkpoint_as_nemo_file(self): # Set values back to their training state to continue training self.virtual_prompt_style = current_virtual_prompt_style self.virtual_prompt_source = current_virtual_prompt_source + + # Set revert prompt table back to previous state + for taskname in current_new_tasks: + del self.prompt_table.prompt_table[taskname] with open_dict(self.cfg): self.cfg.existing_tasks = current_existing_tasks From 8334a4bd131deeedf59f5dfa5b3c8ea35d0e94d2 Mon Sep 17 00:00:00 2001 From: Virginia Adams Date: Thu, 10 Nov 2022 23:22:16 +0000 Subject: [PATCH 2/5] Added more safety checks Signed-off-by: Virginia Adams --- .../language_modeling/megatron_gpt_prompt_learning_model.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py index bc6439517cae..68b13355e09b 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py @@ -708,8 +708,10 @@ def save_checkpoint_as_nemo_file(self): self.virtual_prompt_source = 
current_virtual_prompt_source # Set revert prompt table back to previous state - for taskname in current_new_tasks: - del self.prompt_table.prompt_table[taskname] + if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING: + for taskname in current_new_tasks: + if taskname in self.prompt_table.prompt_table: + del self.prompt_table.prompt_table[taskname] with open_dict(self.cfg): self.cfg.existing_tasks = current_existing_tasks From b5951efd86c3533bea2c3300ad4309063f2aa9eb Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Nov 2022 23:28:19 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../language_modeling/megatron_gpt_prompt_learning_model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py index 68b13355e09b..8a39cd25b64b 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py @@ -706,9 +706,9 @@ def save_checkpoint_as_nemo_file(self): # Set values back to their training state to continue training self.virtual_prompt_style = current_virtual_prompt_style self.virtual_prompt_source = current_virtual_prompt_source - + # Set revert prompt table back to previous state - if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING: + if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING: for taskname in current_new_tasks: if taskname in self.prompt_table.prompt_table: del self.prompt_table.prompt_table[taskname] From c6a498104d939c520b2614affbae180ff5b96df6 Mon Sep 17 00:00:00 2001 From: Virginia Adams Date: Thu, 10 Nov 2022 23:55:02 +0000 Subject: [PATCH 4/5] Added more condition checks 
Signed-off-by: Virginia Adams --- .../megatron_base_prompt_learning_model.py | 6 ++++++ .../language_modeling/megatron_gpt_prompt_learning_model.py | 4 ++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py index ffae75ed5a34..4a409e633fc6 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py @@ -435,6 +435,12 @@ def save_checkpoint_as_nemo_file(self): # Set values back to their training state to continue training self.virtual_prompt_style = current_virtual_prompt_style self.virtual_prompt_source = current_virtual_prompt_source + + # Revert prompt table back to previous state + if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING and self.first_stage_of_pipeline(): + for taskname in current_new_tasks: + if taskname in self.prompt_table.prompt_table: + del self.prompt_table.prompt_table[taskname] with open_dict(self.cfg): self.cfg.existing_tasks = current_existing_tasks diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py index 68b13355e09b..daa1a5143afe 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py @@ -707,8 +707,8 @@ def save_checkpoint_as_nemo_file(self): self.virtual_prompt_style = current_virtual_prompt_style self.virtual_prompt_source = current_virtual_prompt_source - # Set revert prompt table back to previous state - if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING: + # Revert prompt table back to previous state + if self.virtual_prompt_style == 
VirtualPromptStyle.P_TUNING and self.frozen_model.model.pre_process: for taskname in current_new_tasks: if taskname in self.prompt_table.prompt_table: del self.prompt_table.prompt_table[taskname] From 952b4652c729205a4afb0a5e5073dff71abf295d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 10 Nov 2022 23:58:12 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../language_modeling/megatron_base_prompt_learning_model.py | 4 ++-- .../language_modeling/megatron_gpt_prompt_learning_model.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py index 4a409e633fc6..67448badb43a 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_base_prompt_learning_model.py @@ -435,9 +435,9 @@ def save_checkpoint_as_nemo_file(self): # Set values back to their training state to continue training self.virtual_prompt_style = current_virtual_prompt_style self.virtual_prompt_source = current_virtual_prompt_source - + # Revert prompt table back to previous state - if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING and self.first_stage_of_pipeline(): + if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING and self.first_stage_of_pipeline(): for taskname in current_new_tasks: if taskname in self.prompt_table.prompt_table: del self.prompt_table.prompt_table[taskname] diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py index daa1a5143afe..5b083ed86b93 100644 --- 
a/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_prompt_learning_model.py @@ -706,9 +706,9 @@ def save_checkpoint_as_nemo_file(self): # Set values back to their training state to continue training self.virtual_prompt_style = current_virtual_prompt_style self.virtual_prompt_source = current_virtual_prompt_source - + # Revert prompt table back to previous state - if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING and self.frozen_model.model.pre_process: + if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING and self.frozen_model.model.pre_process: for taskname in current_new_tasks: if taskname in self.prompt_table.prompt_table: del self.prompt_table.prompt_table[taskname]