Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

config torch to avoid graph breaks caused by logger #6999

Merged
merged 1 commit into from
Feb 24, 2025

Conversation

ShellyNR
Copy link
Contributor

@ShellyNR ShellyNR commented Feb 4, 2025

Following changes in Pytorch trace rules , my previous PR to avoid graph breaks caused by logger is no longer relevant. So instead I've added this functionality to torch dynamo - pytorch/pytorch@16ea0dd
This commit allows the user to config torch to ignore logger methods and avoid associated graph breaks.

To enable ignore logger methods - os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
To ignore logger methods except for a specific method / methods (for example, info and isEnabledFor) - os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
and os.environ["LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE"] = "info, isEnabledFor"

@ShellyNR ShellyNR requested a review from tjruwase as a code owner February 4, 2025 17:43
@tohtana
Copy link
Contributor

tohtana commented Feb 4, 2025

Thank you @ShellyNR, this is a very clean approach.
I would like to approve this, but DCO prevents tests from running because of the migration to deepspeedai. Can you rebase based on the explanation?

@loadams loadams requested a review from tohtana as a code owner February 10, 2025 22:51
@loadams
Copy link
Collaborator

loadams commented Feb 13, 2025

Thank you @ShellyNR, this is a very clean approach. I would like to approve this, but DCO prevents tests from running because of the migration to deepspeedai. Can you rebase based on the explanation?

I believe I've fixed this for now.

@SNahir SNahir force-pushed the disable_logger_for_PT2.6 branch from 4124fa1 to fa2fb8a Compare February 17, 2025 16:41
@ShellyNR
Copy link
Contributor Author

Hi @tohtana sorry for the delay, can you please review this? Thanks

@tjruwase tjruwase added this pull request to the merge queue Feb 24, 2025
Merged via the queue into deepspeedai:master with commit e1903f0 Feb 24, 2025
10 checks passed
deepcharm pushed a commit to deepcharm/DeepSpeed that referenced this pull request Feb 26, 2025
Following changes in Pytorch trace rules , my previous PR to avoid graph
breaks caused by logger is no longer relevant. So instead I've added
this functionality to torch dynamo -
pytorch/pytorch@16ea0dd
This commit allows the user to config torch to ignore logger methods and
avoid associated graph breaks.

To enable ignore logger methods -
os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
To ignore logger methods except for a specific method / methods (for
example, info and isEnabledFor) -
os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
and os.environ["LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE"] = "info,
isEnabledFor"

Signed-off-by: ShellyNR <[email protected]>
Co-authored-by: snahir <[email protected]>
Signed-off-by: Max Kovalenko <[email protected]>
deepcharm pushed a commit to deepcharm/DeepSpeed that referenced this pull request Feb 27, 2025
Following changes in Pytorch trace rules , my previous PR to avoid graph
breaks caused by logger is no longer relevant. So instead I've added
this functionality to torch dynamo -
pytorch/pytorch@16ea0dd
This commit allows the user to config torch to ignore logger methods and
avoid associated graph breaks.

To enable ignore logger methods -
os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
To ignore logger methods except for a specific method / methods (for
example, info and isEnabledFor) -
os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
and os.environ["LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE"] = "info,
isEnabledFor"

Signed-off-by: ShellyNR <[email protected]>
Co-authored-by: snahir <[email protected]>
Signed-off-by: Max Kovalenko <[email protected]>
gyou2021 pushed a commit to gyou2021/DeepSpeed that referenced this pull request Feb 28, 2025
Following changes in Pytorch trace rules , my previous PR to avoid graph
breaks caused by logger is no longer relevant. So instead I've added
this functionality to torch dynamo -
pytorch/pytorch@16ea0dd
This commit allows the user to config torch to ignore logger methods and
avoid associated graph breaks.

To enable ignore logger methods -
os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
To ignore logger methods except for a specific method / methods (for
example, info and isEnabledFor) -
os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
and os.environ["LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE"] = "info,
isEnabledFor"

Signed-off-by: ShellyNR <[email protected]>
Co-authored-by: snahir <[email protected]>
Signed-off-by: gyou2021 <[email protected]>
tohtana pushed a commit that referenced this pull request Feb 28, 2025
Following changes in Pytorch trace rules , my previous PR to avoid graph
breaks caused by logger is no longer relevant. So instead I've added
this functionality to torch dynamo -
pytorch/pytorch@16ea0dd
This commit allows the user to config torch to ignore logger methods and
avoid associated graph breaks.

To enable ignore logger methods -
os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
To ignore logger methods except for a specific method / methods (for
example, info and isEnabledFor) -
os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
and os.environ["LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE"] = "info,
isEnabledFor"

Signed-off-by: ShellyNR <[email protected]>
Co-authored-by: snahir <[email protected]>
Signed-off-by: Masahiro Tanaka <[email protected]>
shenzheyu pushed a commit to shenzheyu/DeepSpeed that referenced this pull request Mar 5, 2025
Following changes in Pytorch trace rules , my previous PR to avoid graph
breaks caused by logger is no longer relevant. So instead I've added
this functionality to torch dynamo -
pytorch/pytorch@16ea0dd
This commit allows the user to config torch to ignore logger methods and
avoid associated graph breaks.

To enable ignore logger methods -
os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
To ignore logger methods except for a specific method / methods (for
example, info and isEnabledFor) -
os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
and os.environ["LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE"] = "info,
isEnabledFor"

Signed-off-by: ShellyNR <[email protected]>
Co-authored-by: snahir <[email protected]>
Signed-off-by: Zheyu SHEN <[email protected]>
ys950902 pushed a commit to ys950902/DeepSpeed that referenced this pull request Mar 6, 2025
Following changes in Pytorch trace rules , my previous PR to avoid graph
breaks caused by logger is no longer relevant. So instead I've added
this functionality to torch dynamo -
pytorch/pytorch@16ea0dd
This commit allows the user to config torch to ignore logger methods and
avoid associated graph breaks.

To enable ignore logger methods -
os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
To ignore logger methods except for a specific method / methods (for
example, info and isEnabledFor) -
os.environ["DISABLE_LOGS_WHILE_COMPILING"] = "1"
and os.environ["LOGGER_METHODS_TO_EXCLUDE_FROM_DISABLE"] = "info,
isEnabledFor"

Signed-off-by: ShellyNR <[email protected]>
Co-authored-by: snahir <[email protected]>
Signed-off-by: yisheng <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants