Cast bfloat16 to float32 for Numpy conversions#29755
Conversation
amyeroberts
left a comment
Thanks! Could you add a test for this?
Is there any way we can keep track of the converted weights? This fixes the entering-into-NumPy-hinterland issue, but if I've understood correctly, this will end up producing far larger TF models, which won't match on equivalence tests.
@amyeroberts thankfully, the NumPy values are never assigned directly as TF weights! The way our weight loading works, a TF model is first created with random weights, and then we loop over the PyTorch state dict and copy each tensor into the matching TF weight, casting to that weight's dtype as we go. As a result, the exact dtype we use to load the weights from PyTorch doesn't matter, as long as it doesn't lose any precision (which is why I upcast to float32 for safety here). If the TF model is in bfloat16, the values are simply downcast again at assignment time, so the model doesn't get any bigger.

That said, our support for full-bfloat16 TF models is still a little shaky, but fixing that is probably a separate PR!
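For anyone following along, here's a minimal sketch of the constraint being worked around (standalone tensors, not the actual loading code): PyTorch refuses to convert a `bfloat16` tensor to NumPy, and the `float32` upcast is lossless because every `bfloat16` value is exactly representable in `float32`.

```python
import torch

# NumPy has no bfloat16 dtype, so direct conversion raises:
t = torch.ones(2, 2, dtype=torch.bfloat16)
try:
    t.numpy()
except TypeError as e:
    print(e)  # unsupported ScalarType BFloat16

# Upcasting to float32 first is lossless: bfloat16 shares float32's
# 8-bit exponent and has a strictly narrower mantissa.
arr = t.float().numpy()
print(arr.dtype)  # float32
```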
@Rocketknight1 Great! All we need is a test, then we're good to go 🚀
Force-pushed from b67d54c to c830f36.
@amyeroberts test added!
There is no direct conversion from Torch <-> TF, so passing tensors between the two requires us to go through NumPy. Unfortunately, NumPy doesn't support `bfloat16`. This patch fixes an issue when TF tries to load a PyTorch checkpoint whose weights are stored in `bfloat16`, by upcasting the weights to `float32` before the NumPy conversion; they can be downcast later when they're assigned as the TF weights.
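As a rough illustration of the round trip described above (standalone tensors and a hypothetical weight for demonstration, not the actual `transformers` crossloading code):

```python
import numpy as np
import tensorflow as tf
import torch

# Hypothetical checkpoint weight stored in bfloat16 on the PyTorch side:
pt_weight = torch.randn(4, 4, dtype=torch.bfloat16)

# Upcast to float32 before crossing into NumPy, which has no bfloat16:
np_weight = pt_weight.float().numpy()
assert np_weight.dtype == np.float32

# The float32 array is only a transport format: assignment casts it to
# whatever dtype the TF model was built with, so the final weights are
# no larger than before.
tf_weight = tf.Variable(tf.zeros((4, 4), dtype=tf.bfloat16))
tf_weight.assign(tf.cast(np_weight, tf_weight.dtype))
print(tf_weight.dtype)  # <dtype: 'bfloat16'>
```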