We now have support for expert parallelism with #842. Currently this is implemented by running `ragged_dot` on the subset of local experts. To implement this efficiently under JAX JIT (i.e. with static shapes), we need the `group_offset` parameter of `ragged_dot` so that we can target only the tokens that were routed to the local experts. However, `group_offset` is currently not implemented in `jax.lax.ragged_dot` (see the discussion in jax-ml/jax#34168, which is worth reading). We therefore implemented a workaround in #860 for the time being; it makes it possible to run the code and is already surprisingly efficient. However, it can be optimized further by not running expert computations on the extra tokens.
There are several ways to do this:
- Implement `group_offset` for `jax.lax.ragged_dot`. We are most interested in the GPU case for now, but later it would also be good to do this for TPUs (for now we can use the fallback code there). There are some pointers on how to do this in a comment on jax-ml/jax#34168 ("Expert parallelism and "Unimplemented group_offset support" for jax.lax.ragged_dot"). It is conceptually pretty simple but will need some understanding of, and wrangling with, XLA. It has the advantage that we would likely not need to do extra auto-tuning on top of what XLA already does. It would need some discussion / sync with the JAX/XLA teams, but I think a prototype could be written with the code that is already open-source.
- (Likely the simplest option) Implement a `ragged_dot` that supports `group_offset` via Pallas. There are several existing implementations we could adapt.
Ideally the implementation would be simple yet performant and support as many platforms as possible (e.g. Hopper, Blackwell, older GPU architectures, TPUs). It may be hard to satisfy all of these requirements, so partial progress along any of these dimensions (while using the current fallback elsewhere) is very welcome.
Improving the performance here (and possibly also improving on the performance of vanilla `jax.lax.ragged_dot` without `group_offset`) will have a huge impact because that kernel is used so heavily, both for MultiLoRA support and for expert handling.
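
To make the target concrete, here is a minimal sketch of the semantics a `group_offset`-aware `ragged_dot` might have, written as a dense pure-JAX reference with static shapes. It is not the workaround from #860 and not a proposed kernel. The function name `ragged_dot_offset_reference`, the scalar form of `group_offset`, and the choice to return zeros for rows belonging to non-local experts are all assumptions made for illustration.

```python
import jax
import jax.numpy as jnp


def ragged_dot_offset_reference(lhs, rhs_local, group_sizes, group_offset):
    """Dense reference for a group_offset-aware ragged dot (illustrative only).

    lhs:          (m, k) tokens, sorted by expert id.
    rhs_local:    (g_local, k, n) weights of the experts held on this device.
    group_sizes:  (g_total,) token counts for *all* experts.
    group_offset: index of the first local expert (hypothetical scalar argument).
    """
    m = lhs.shape[0]
    g_local = rhs_local.shape[0]
    # Exclusive end row of every expert's token block.
    ends = jnp.cumsum(group_sizes)
    # Global expert id of each row = number of blocks ending at or before it.
    expert_id = jnp.searchsorted(ends, jnp.arange(m), side="right")
    local_id = expert_id - group_offset
    is_local = (local_id >= 0) & (local_id < g_local)
    # Clip so the gather stays in bounds; non-local rows are masked afterwards.
    safe_id = jnp.clip(local_id, 0, g_local - 1)
    # Gathers one (k, n) matrix per row: O(m * k * n) memory, so this is a
    # correctness reference, not a kernel.
    out = jnp.einsum("mk,mkn->mn", lhs, rhs_local[safe_id])
    # Assumption: rows that belong to non-local experts come back as zeros.
    return jnp.where(is_local[:, None], out, jnp.zeros_like(out))


# Toy example: 3 experts in total, this device holds experts 1 and 2.
lhs = jnp.arange(12, dtype=jnp.float32).reshape(6, 2)
rhs_local = jnp.stack([2.0 * jnp.eye(2), 3.0 * jnp.eye(2)])
group_sizes = jnp.array([2, 1, 3], dtype=jnp.int32)
out = jax.jit(ragged_dot_offset_reference)(lhs, rhs_local, group_sizes, 1)
```

In the toy example the first two rows belong to expert 0, which is not local, so they are masked out. A Pallas or XLA implementation could be compared against a reference of this kind on small inputs.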