Support gmm and tgmm trace_pallas caching #7921
Conversation
Still need to add a test for the cache miss case.
    global trace_pallas_arg_to_payload
    # implicit assumption here that everything in kwargs is hashable and not a tensor,
    # which is true for the gmm and tgmm.
    hash_key = (kernel, static_argnums, tuple(static_argnames), tuple(jax_args),
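For context, the caching pattern under discussion can be sketched roughly as below. This is a hedged illustration, not the PR's actual implementation: only the `trace_pallas_arg_to_payload` name and the general shape of the hash key come from the diff above; `trace_fn` and the `(shape, dtype)` metadata helper are stand-ins for the real Pallas tracing step and the JAX meta tensors.

```python
# Minimal sketch of the caching idea (see assumptions in the lead-in above).

trace_pallas_arg_to_payload = {}  # global cache: hash key -> traced payload


def describe(tensor):
    # Key tensors by metadata only (shape and dtype), never by value or identity.
    return (tuple(tensor.shape), str(tensor.dtype))


def trace_pallas_cached(trace_fn, kernel, *tensors,
                        static_argnums=(), static_argnames=(), **kwargs):
    meta_args = tuple(describe(t) for t in tensors)
    # As in the diff: kwargs are assumed to be hashable and tensor-free,
    # which holds for gmm and tgmm.
    hash_key = (kernel, tuple(static_argnums), tuple(static_argnames),
                meta_args, tuple(sorted(kwargs.items())))
    if hash_key not in trace_pallas_arg_to_payload:
        # Cache miss: pay the tracing cost once per kernel/shape combination.
        trace_pallas_arg_to_payload[hash_key] = trace_fn(
            kernel, *tensors, static_argnums=static_argnums,
            static_argnames=static_argnames, **kwargs)
    return trace_pallas_arg_to_payload[hash_key]
```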
How does this work with different objects but with the same size, dtype and device?
jax_args are just meta tensors; I verified that the same size will always map to the same hash. We are not hashing id(static_argnames), so as long as the values are the same it will generate the same hash.
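As a hedged illustration of that point (using the plain `(shape, dtype)` metadata from the sketch above rather than real JAX meta tensors): two distinct tensors of the same size produce equal keys, so they hit the same cache entry.

```python
import torch

a = torch.empty(128, 256, dtype=torch.bfloat16)
b = torch.empty(128, 256, dtype=torch.bfloat16)  # different object, same metadata

key_a = (tuple(a.shape), str(a.dtype))
key_b = (tuple(b.shape), str(b.dtype))

assert a is not b                    # distinct tensor objects
assert key_a == key_b                # identical metadata keys
assert hash(key_a) == hash(key_b)    # so both map to the same cache entry
```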
That's interesting. I guess if it works it works. Then why not just use @cache?
My understanding is that @cache keys on the inputs, and the inputs of this function are XLA tensors; I suspect @cache would try to access the values of those tensors. Here I only cache on the JAX meta tensors.
Also, let me re-verify this with the real MoE models.
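One way to see the mismatch, as a hedged aside (plain CPU tensors here, not XLA tensors): functools' @cache keys on the argument objects themselves, so two tensors with identical shapes do not share an entry.

```python
import functools
import torch

calls = 0

@functools.cache  # keys on the tensor objects passed in
def fake_trace(t):
    global calls
    calls += 1    # count how often we actually "trace"
    return tuple(t.shape)

a = torch.empty(128, 256)
b = torch.empty(128, 256)  # same shape, different object

fake_trace(a)
fake_trace(b)
assert calls == 2  # both calls traced: same-shape tensors did not share a cache entry
```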
Verified in the profile that I was able to reduce the tracing time of gmm from 6 ms to 2.4 ms.

