Conversation
cc @tgale96 for review.
GroupMetadata = Any  # TODO(enriqueps): Clean this up and use a namedtuple

def _make_group_metadata(
I wonder if you could trace the metadata function we have in the library with the GMM to avoid duplicating this tricky bit of code? If not this is fine, just curious.
Tracing doesn't seem to be an option AFAIK. Though it would be great if we found a way to call the JAX implementation of this method and make this whole implementation leaner; I suggest we do that as a follow-up PR.
cc @alanwaketan
Why can't we just use the JAX version?
If we use the JAX version to do this compute, let's make sure we pass JAX CPU tensors so that this part of the compute happens on CPU instead. As far as I can tell, only group_sizes is used and it's 1D, so it should be pretty lightweight to compute. We should also benchmark this against the reference_gmm in case this part drastically increases the tracing time. On that note, we should cache the result.
Actually, we can't do this, given that group_sizes is data produced in the middle of the graph. That means we would need a graph break.
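For context on why the CPU route looked attractive: the metadata derived from a 1D group_sizes tensor is essentially a prefix sum, which is cheap to compute on host. A minimal numpy sketch (the function name is hypothetical, not the actual _make_group_metadata logic):

```python
import numpy as np

def group_offsets_on_cpu(group_sizes: np.ndarray) -> np.ndarray:
    # Prefix sum of group sizes: offsets[g] is the first row of group g,
    # and offsets[-1] is the total number of rows.
    return np.concatenate(
        [np.zeros(1, dtype=np.int32),
         np.cumsum(group_sizes, dtype=np.int32)])

offsets = group_offsets_on_cpu(np.array([3, 0, 5], dtype=np.int32))
# offsets -> [0, 3, 3, 8]
```

The catch, as noted in the comment above, is that group_sizes is produced mid-graph, so materializing it on host would force a graph break.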
Picking this up, now rebased from master to fix the conflicts. This PR should be ready to be reviewed/merged; I'll run the TPU CI to verify one more time.
@JackCaoG thanks for the comments, this should be ready for another round of review.
JackCaoG
left a comment
I added the TPUCI tag and reran the CI. Feel free to merge once the v4 test passes.
Hopefully I can take a look tomorrow after going over all the reading materials w.r.t. MoE and megablocks. If I can't get to it, feel free to land it as is; we can always follow up.
Given that the CI (+ TPU CI) is green, I'll go ahead and merge this. I'll follow up with any fixes if needed.
alanwaketan
left a comment
@wonjoolee95 Do you think we can make a follow-up PR to simplify this before moving to tgmm?
lhs: torch.Tensor,
rhs: torch.Tensor,
group_sizes: torch.Tensor,
preferred_element_type: torch.dtype = torch.float32,
I think we should omit the preferred_element_type, tiling, group_offset, existing_out, transpose_rhs and interpret parameters unless we know for sure that users need them.
return (group_offsets, group_ids, m_tile_ids), num_tiles

def _zero_uninitialized_memory(
Why can't we just use the JAX version?
GroupMetadata = Any  # TODO(enriqueps): Clean this up and use a namedtuple

def _make_group_metadata(
Why can't we just use the JAX version?
import numpy as np

def _validate_args(
Why can't we just use the JAX version?
@@ -0,0 +1,22 @@
"""Common utilities for Pallas kernels."""
This file can be deleted if we directly use the helper from JAX.
@@ -0,0 +1 @@
from .gmm import gmm
Once we remove all the duplicated code, we can move this method back to custom_kernel.py.
from jax.experimental import pallas as pl

class MegabloxTest(unittest.TestCase):
Why can't we merge this to test_pallas.py?
group_offset_torch = torch.from_numpy(np.array(group_offset)).to("xla")
output_shape = torch.Size([m, n])
out = torch_xla._XLAC._xla_tpu_custom_call([
    num_active_tiles, group_metadata0, group_metadata1, group_metadata2,
We should only duplicate the logic that gets us these parameters. Anything else can be removed.
    group_offset_torch, lhs, rhs
], payload, [output_shape], [preferred_element_type])

if existing_out is None and num_current_groups < num_total_groups:
As far as I can tell, this is only needed once we have expert parallelism, and I still can't tell whether we'll get there.
class MegabloxTest(unittest.TestCase):

def _reference_gmm(
Can we just do it in torch instead of np?
    start += group_sizes[i]
return np.array(np.concatenate(out, axis=0))
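If the reference moves to torch as suggested, a sketch could look like this (a hypothetical helper, assuming the same grouped-matmul semantics: the rows of lhs are partitioned into contiguous groups, each multiplied by its own rhs slice):

```python
import torch

def reference_gmm_torch(lhs: torch.Tensor, rhs: torch.Tensor,
                        group_sizes: torch.Tensor) -> torch.Tensor:
    # lhs: [m, k], rhs: [num_groups, k, n], group_sizes: [num_groups], sums to m.
    out, start = [], 0
    for g, size in enumerate(group_sizes.tolist()):
        # Rows [start, start + size) of lhs belong to group g.
        out.append(lhs[start:start + size] @ rhs[g])
        start += size
    return torch.cat(out, dim=0)
```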
def _group_sizes_strategy(self, m: int, num_groups: int) -> torch.Tensor:
As far as I can tell, for us we just need to make sure our piping is correct; we don't need to verify that gmm itself is correct, that's JAX's job. So let's remove this and pick one or two cases that are tuned to our wrapper.
starts = np.concatenate([np.zeros(1, dtype=np.int32), ends_no_final])
return torch.from_numpy(ends - starts).to(torch.int32)

def _tolerances(self, lhs_dtype: torch.dtype, rhs_dtype: torch.dtype,
If we use torch, we don't need this; we can just use torch.allclose.
return 1e-3, 1e-2  # atol, rtol
return 1e-4, 1e-2  # atol, rtol
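If the comparison moves to torch, the per-dtype tolerance table could shrink to a single helper around torch.allclose. A sketch (the helper name and the bfloat16 cutoff are assumptions mirroring the 1e-3/1e-4 split above):

```python
import torch

def outputs_close(out: torch.Tensor, ref: torch.Tensor,
                  lhs_dtype: torch.dtype, rhs_dtype: torch.dtype) -> bool:
    # Looser atol when either input dtype is low precision (assumed: bfloat16),
    # matching the 1e-3 vs. 1e-4 tolerances used in the test today.
    low_precision = torch.bfloat16 in (lhs_dtype, rhs_dtype)
    atol = 1e-3 if low_precision else 1e-4
    return torch.allclose(out.float(), ref.float(), atol=atol, rtol=1e-2)
```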
LutFn = Callable[[int, int, int], Optional[tuple[int, int, int]]]

def _init_test_cases(self):
We might not need all of these.
lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla')
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla')
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
This is a CPU tensor!
lhs = torch.rand(m, k, dtype=lhs_dtype).to('xla')
rhs = torch.rand(num_groups, k, n, dtype=rhs_dtype).to('xla')
group_sizes = self._group_sizes_strategy(m=m, num_groups=num_groups)
out = megablox.gmm(lhs, rhs, group_sizes)
We always output fp32 in this test case, regardless of the input dtypes.
Summary: This is an effort to refactor the code from #6940, removing unneeded code in that part; it reduces the amount of code from ~400 lines to ~50 lines. However, a bummer is that the original gmm kernel doesn't work at all: it assumes group_sizes is a CPU tensor. That means we need to materialize this input in order to use the gmm kernel, which will introduce graph breaks in the computation. I'll need yet another follow-up to make this code actually functional. Good news is the test cases seem functional. Test Plan: python test/test_megablox.py
In this PR, we add the megablox gmm kernel. The current implementation adds a new file, megablox_gmm. I plan to merge it into custom_kernel with the rest of the kernels.