🐛 Bug
When scanning over modules that contain flash attention pallas kernels, the output contains NaN.
To Reproduce
Run the test in 1. The test fails: the output of fake_fa_wrapper contains NaN when use_scan=True.
Expected behavior
The output of fake_fa_wrapper should be the same whether use_scan=True or use_scan=False.
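The expected behavior amounts to a simple invariant: running the same stack of layers through a scan must give the same finite output as running them one by one in a loop. The sketch below illustrates that check in JAX with `jax.lax.scan`. It is not the test in 1 and does not call a Pallas flash attention kernel. `attention_layer` is a hypothetical stand-in (plain softmax attention), and `run_loop` / `run_scan` are hypothetical helpers playing the roles of use_scan=False and use_scan=True.

```python
import jax
import jax.numpy as jnp


def attention_layer(x, w):
    # Hypothetical stand-in for a layer wrapping a flash attention kernel:
    # plain softmax attention that reuses one projection matrix `w` for q, k, v.
    q = x @ w
    k = x @ w
    v = x @ w
    scores = q @ k.T / jnp.sqrt(x.shape[-1])
    return jax.nn.softmax(scores, axis=-1) @ v


def run_loop(x, weights):
    # use_scan=False analogue: apply each layer in sequence with a Python loop.
    for w in weights:
        x = attention_layer(x, w)
    return x


def run_scan(x, weights):
    # use_scan=True analogue: apply the same layers via jax.lax.scan,
    # scanning over the leading (layer) axis of the stacked weights.
    def step(carry, w):
        return attention_layer(carry, w), None

    out, _ = jax.lax.scan(step, x, weights)
    return out


key_x, key_w = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(key_x, (8, 16))
weights = jax.random.normal(key_w, (4, 16, 16)) * 0.1  # 4 stacked layers

loop_out = run_loop(x, weights)
scan_out = run_scan(x, weights)

# Expected behavior: the scanned output is finite and matches the loop output.
print(bool(jnp.all(jnp.isfinite(scan_out))))
print(bool(jnp.allclose(loop_out, scan_out, atol=1e-5)))
```

In the failing case reported here, the first check would be the one to break, since the scanned output contains NaN.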