Prevent silent overflow in lapack worker size calculations. #19288
copybara-service[bot] merged 1 commit into jax-ml:main
hawkinsp
left a comment
This is a good improvement, although it would not catch the original issue, in which I think there was an integer overflow inside LAPACK from multiplying m * n when m and n were close to 2**16. To catch that we'd need some additional defensive checking limiting maximum sizes.
I would probably add that checking in Python here, if desired:
https://github.com/google/jax/blob/main/jaxlib/lapack.py
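A minimal sketch of what such a defensive check might look like in Python. The function name `check_lapack_dims` and the constant `INT32_MAX` are hypothetical and are not taken from jaxlib; the sketch also assumes the underlying LAPACK build uses 32-bit signed integers, which is what makes a product of two dimensions near 2**16 overflow.

```python
# Hypothetical helper, not part of jaxlib: reject matrix shapes whose element
# count would not fit in a 32-bit signed integer (assumed LAPACK integer width).
INT32_MAX = 2**31 - 1


def check_lapack_dims(m: int, n: int) -> None:
    """Raise ValueError if m * n would overflow a 32-bit signed integer."""
    if m < 0 or n < 0:
        raise ValueError(f"dimensions must be non-negative, got m={m}, n={n}")
    # Python ints have arbitrary precision, so the product itself cannot
    # overflow here; we only compare it against the assumed LAPACK limit.
    if m * n > INT32_MAX:
        raise ValueError(
            f"m * n = {m * n} exceeds the 32-bit integer limit {INT32_MAX}"
        )
```

Under these assumptions a 46340 x 46340 matrix is accepted, while 46341 x 46341 and 65536 x 65536 are rejected before any LAPACK call is made.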
One other change is needed also. Please add:
    copts = ["-fexceptions"],
    features = ["-use_header_modules"],
to the lapack_kernels build target.
https://github.com/google/jax/blob/adf05d520a9ec95b2b05b6ba28b627ae34df4a02/jaxlib/cpu/BUILD#L32
By default we disable exceptions in our C++ builds, and they have to be enabled file by file. (Mostly we use absl::Status to communicate failure, but that's probably overkill here, where these workspace size functions are called by pybind11 wrappers and pybind11 wants failures as C++ exceptions.)
Actually you're probably right: the workspace size check would probably catch the original problem by itself. So only the BUILD file change is needed.
Done. Thanks for the hint, @hawkinsp!
48bb619 to 4496a25
Add -fexceptions to building lapack_kernels
4496a25 to 3fa1033
As in the title.
Addresses the issues of wrong results and crashes reported in #10420, #10411, etc.