⚠️ [CLAP] Fix dtype of logit scales in init#25682
sanchit-gandhi merged 1 commit into huggingface:main from
Conversation
text_config = config.text_config
audio_config = config.audio_config

self.logit_scale_a = nn.Parameter(torch.tensor(np.log(config.logit_scale_init_value)))
The aforementioned behaviour is a result of the np.log operation defaulting to float64.
Given the original code, we might need to init in float64 and then cast to float32, if it makes a difference. No idea if the actual saved value is in float64!
The parameters are initialised in float64 but are stored in float32 in the state dict.
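The dtype mismatch described above can be reproduced in isolation. This is a minimal sketch (using the default `logit_scale_init_value` of 2.6592 from the CLAP config as an illustrative value): `np.log` returns a NumPy float64 scalar, which `torch.tensor` preserves, whereas a plain Python float from `math.log` falls back to PyTorch's default dtype, float32.

```python
import math

import numpy as np
import torch
import torch.nn as nn

logit_scale_init_value = 2.6592  # illustrative init value

# np.log returns an np.float64 scalar, and torch.tensor preserves that
# dtype, so the parameter ends up float64 even in a float32 model
p_old = nn.Parameter(torch.tensor(np.log(logit_scale_init_value)))
print(p_old.dtype)  # torch.float64

# math.log returns a Python float, so torch.tensor uses the default
# dtype (float32), matching the rest of the model's parameters
p_new = nn.Parameter(torch.tensor(math.log(logit_scale_init_value)))
print(p_new.dtype)  # torch.float32
```

This is why the parameter's dtype diverges from the model default even though no explicit dtype is passed anywhere.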
ArthurZucker left a comment
As mentioned offline, this was never used in the original repo. It is a bit breaking, but it's a bug fix. Let's just add one
The documentation is not available anymore as the PR was closed or merged.
Note that in the original repo, the model is always cast to float16 for all training/inference. Thus, they likely never used the model in its default dtype, and always relied on explicitly casting to float16.
[CLAP] Fix dtype of logit scales
What does this PR do?
The dtype of the CLAP logit scale parameters was always float64 by default (even when the rest of the model was initialised in float32). This PR fixes the logit scales so that they respect the default dtype of the model.
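As a hedged sketch of the resulting behaviour (the class name `ClapLikeModel` here is hypothetical, standing in for the real `ClapModel`), initialising the logit scale from a Python float keeps the parameter in the model's default dtype:

```python
import math

import torch
import torch.nn as nn


class ClapLikeModel(nn.Module):
    # Hypothetical minimal module illustrating the fixed init; the real
    # implementation lives in transformers' ClapModel.
    def __init__(self, logit_scale_init_value=2.6592):
        super().__init__()
        # math.log yields a Python float, so torch.tensor creates the
        # parameter in the default dtype (float32) rather than float64
        self.logit_scale_a = nn.Parameter(
            torch.tensor(math.log(logit_scale_init_value))
        )


model = ClapLikeModel()
print(model.logit_scale_a.dtype)  # torch.float32
```

With this, a subsequent `model.half()` or `model.to(torch.float16)` casts the logit scale together with the rest of the parameters, instead of leaving a stray float64 tensor behind.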