
[CLAP] Fix logit scales dtype for fp16 #25754

Merged: 1 commit into huggingface:main from clap-dtype on Aug 25, 2023

Conversation

sanchit-gandhi (Contributor) commented Aug 25, 2023

What does this PR do?

On some hardware, taking torch.log of a tensor in float16 on the CPU fails:

in __init__(self, config)
   1956         audio_config = config.audio_config
   1957 
-> 1958         self.logit_scale_a = nn.Parameter(torch.log(torch.tensor(config.logit_scale_init_value)))
   1959         self.logit_scale_t = nn.Parameter(torch.log(torch.tensor(config.logit_scale_init_value)))
   1960 

RuntimeError: "log_vml_cpu" not implemented for 'Half'

Note that this only failed for me on a Colab T4, but not on a Titan RTX (used to test #25682).
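For reference, a minimal repro of the failing call, assuming an affected PyTorch/CPU combination (the 0.07 value is just a placeholder, not CLAP's actual init value):

```python
import torch

# torch.log on an fp16 tensor that lives on the CPU; on affected setups
# this raises: RuntimeError: "log_vml_cpu" not implemented for 'Half'
x = torch.tensor(0.07, dtype=torch.float16)
torch.log(x)
```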

Let's take math.log of the Python float and then convert the result to a tensor: this still respects the dtype of the model, but never calls torch.log on a float16 CPU param.
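A minimal sketch of the described change, using the parameter names from the traceback above (the class and the default init value are illustrative, not the exact diff):

```python
import math

import torch
from torch import nn


class ClapLikeModel(nn.Module):
    """Sketch of the two logit-scale parameters only, not the real ClapModel."""

    def __init__(self, logit_scale_init_value: float = 1 / 0.07):  # placeholder default
        super().__init__()
        # math.log operates on a Python float, so no fp16 CPU kernel is involved;
        # the resulting parameters are later cast along with the rest of the model.
        self.logit_scale_a = nn.Parameter(torch.tensor(math.log(logit_scale_init_value)))
        self.logit_scale_t = nn.Parameter(torch.tensor(math.log(logit_scale_init_value)))
```

After model.half() (or loading with torch_dtype=torch.float16), these parameters are cast with the rest of the module, so torch.log never runs on a Half CPU tensor.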

HuggingFaceDocBuilderDev commented Aug 25, 2023

The documentation is not available anymore as the PR was closed or merged.

ArthurZucker (Collaborator) left a comment

On CPU you can't use half anyway, no?

sanchit-gandhi (Contributor, Author) commented Aug 25, 2023

Yep, try with this:

import torch

torch.tensor(0).half()

Gives:

tensor(0., dtype=torch.float16)

This is used when we load diffusers pipelines in fp16 (load the state dict in fp16 on CPU, then move it to CUDA).
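That loading pattern looks roughly like this (the checkpoint name is only an example):

```python
import torch
from diffusers import DiffusionPipeline

# Weights are materialized in fp16 on the CPU first, then the whole
# pipeline is moved to the GPU.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")
```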

sanchit-gandhi merged commit 0770ce6 into huggingface:main on Aug 25, 2023
sanchit-gandhi deleted the clap-dtype branch on Aug 25, 2023 at 12:30
parambharat pushed a commit to parambharat/transformers that referenced this pull request on Sep 26, 2023