Skip to content

Commit

Permalink
Fix for prompt table restore error (NVIDIA#5393) (NVIDIA#5408)
Browse files Browse the repository at this point in the history
* Fix for prompt table restore error

Signed-off-by: Virginia Adams <[email protected]>

* Added more saftey checks

Signed-off-by: Virginia Adams <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added more condition checks

Signed-off-by: Virginia Adams <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Virginia Adams <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>

Signed-off-by: Virginia Adams <[email protected]>
Co-authored-by: Virginia Adams <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Signed-off-by: andrusenkoau <[email protected]>
  • Loading branch information
3 people authored and andrusenkoau committed Jan 5, 2023
1 parent d67f3e7 commit 252c58c
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,12 @@ def save_checkpoint_as_nemo_file(self):
self.virtual_prompt_style = current_virtual_prompt_style
self.virtual_prompt_source = current_virtual_prompt_source

# Revert prompt table back to previous state
if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING and self.first_stage_of_pipeline():
for taskname in current_new_tasks:
if taskname in self.prompt_table.prompt_table:
del self.prompt_table.prompt_table[taskname]

with open_dict(self.cfg):
self.cfg.existing_tasks = current_existing_tasks
self.cfg.new_tasks = current_new_tasks
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,12 @@ def save_checkpoint_as_nemo_file(self):
self.virtual_prompt_style = current_virtual_prompt_style
self.virtual_prompt_source = current_virtual_prompt_source

# Revert prompt table back to previous state
if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING and self.frozen_model.model.pre_process:
for taskname in current_new_tasks:
if taskname in self.prompt_table.prompt_table:
del self.prompt_table.prompt_table[taskname]

with open_dict(self.cfg):
self.cfg.existing_tasks = current_existing_tasks
self.cfg.new_tasks = current_new_tasks
Expand Down

0 comments on commit 252c58c

Please sign in to comment.