[SPMD] Hybrid Device mesh creation #5147
Conversation
cc @alanwaketan
jonb377
left a comment
Looking great, Mohit! Could we also add some basic unit tests in https://github.com/pytorch/xla/blob/master/test/spmd/test_xla_sharding.py?
    out[coords[0], coords[1], coords[2]] = d
    return out

def _create_device_mesh_for_nd_torus(
Can you explain how this function optimizes performance according to the TPU physical topology? What's the algorithm? Is it that the inner ring has the highest performance, so we should assign the back of mesh_shape to it?
Spoke with Mohit offline. The rule is that the TPU topology is always 3D, and the inner 2D tori have a faster ICI than the links that connect across them. Therefore, we should map the most communication-demanding rank, i.e., the highest rank of the mesh, to the inner 2D tori.
Now that I've read more of the code, this algorithm seems quite restrictive:
- It only works for mapping a 2D or 3D logical mesh onto the 3D physical mesh.
- For a 3D logical mesh, I think it needs to be a transpose of the physical mesh.
- For a 2D logical mesh, it just tries to map a combination of the physical axes onto each dimension of the logical mesh.
Given these simple rules, it then makes sure that devices that are physically close to each other are also assigned close to each other in the logical mesh. For example, assuming the logical mesh is 2D, the devices in mesh[0] will always be a 2D slice of the 3D physical mesh.
If my understanding is correct, @khatwanimohit can you polish my comments and turn them into the docstring of this helper?
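To make the discussion concrete, here is a hedged, simplified sketch of the greedy axis-assignment idea described above (the function name and tie-breaking order are illustrative assumptions, not the actual PyTorch/XLA or JAX implementation): walk the logical mesh from its highest rank (assumed the most communication-intensive) down, and give each logical axis the smallest combination of physical torus axes whose sizes multiply to the logical size.

```python
import itertools

import numpy as np


def nd_torus_mesh_sketch(physical_mesh: np.ndarray, mesh_shape: tuple) -> np.ndarray:
  # Sketch only: assign physical torus axes to logical mesh axes, starting
  # from the highest logical rank, preferring the fewest physical axes
  # whose sizes multiply to the target logical size.
  assignment = [() for _ in mesh_shape]
  remaining = list(range(physical_mesh.ndim))
  for logical_idx in reversed(range(len(mesh_shape))):
    size = mesh_shape[logical_idx]
    if size == 1:
      continue  # a trivial logical axis needs no physical axes
    found = None
    for n in range(1, len(remaining) + 1):
      for combo in itertools.combinations(remaining, n):
        if np.prod([physical_mesh.shape[a] for a in combo]) == size:
          found = combo
          break
      if found:
        break
    if found is None:
      raise ValueError(f"no combination of physical axes multiplies to {size}")
    assignment[logical_idx] = found
    for a in found:
      remaining.remove(a)
  # Group each logical axis's physical axes together, then collapse them.
  transpose_order = [a for axes in assignment for a in axes]
  return physical_mesh.transpose(transpose_order).reshape(mesh_shape)


# A 2x2x2 torus of 8 device ordinals mapped onto a (4, 2) logical mesh:
# the last logical axis gets a single physical axis, so neighbors along it
# stay physically adjacent.
physical = np.arange(8).reshape(2, 2, 2)
logical = nd_torus_mesh_sketch(physical, (4, 2))
```

Under these assumptions, devices in each logical row are a contiguous slice of the physical torus, which is the locality property described above.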
You can add:
This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L64.
Force-pushed from ce4f052 to 9c6d8ab (Compare)
    hybrid_mesh = xs.HybridMesh(
        ici_mesh_shape=(1, 4), dcn_mesh_shape=(num_slices, 1))
    print(hybrid_mesh.get_logical_mesh())
    self.assertEqual(hybrid_mesh.get_logical_mesh().tolist(),
Does this result respect the _create_device_mesh_for_nd_torus algorithm?
Yes, I have confirmed this against JAX's mesh.
Can you make ici_mesh_shape=(2, 2)? I think that would better show how the algorithm works.
Changed ici_mesh_shape.
alanwaketan
left a comment
I just noticed that most of the helpers you introduced, @khatwanimohit, are inspired by https://github.com/google/jax/blob/bfe8acb31e04a540daad3f568239ec0e5c3f0d0f/jax/experimental/mesh_utils.py. In fact, all of those helpers have a very nice docstring explaining what they do.
I recommend that next time you import some JAX utils into PyTorch/XLA, you:
- List the source on each util you import.
- Import their docstrings as well. Those are really critical for the readability of the code.
Also, have you checked the licenses to make sure you can copy code from JAX into PyTorch/XLA? If not, I can do the research for you.
Force-pushed from 79336d3 to 572548b (Compare)
alanwaketan
left a comment
Mostly looking good to me. Thanks, @khatwanimohit.
Please address the comments on readability.
    return self.device_ids.reshape(self.mesh_shape)


# HybridDevice class has been inspired from jax's mesh_utils: https://github.com/google/jax/blob/fc5960f2b8b7a0ef74dbae4e27c5c08ff1564cff/jax/experimental/mesh_utils.py#L4
Can you make it per helper that you imported?
    super().__init__(device_ids, mesh_shape, axis_names)

  def _get_physical_tpu_mesh(self, devices: Sequence[Any]) -> np.ndarray:
    r"""Rearrange TPU devices in a slice into a physical mesh."""
Can you add:
1. This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L172
2. The following description of the function:
r"""Rearrange TPU devices in a slice into a physical mesh.
Args:
devices: A list of device logical ordinals in a TPU slice.
Returns:
A np.ndarray of device logical ordinals with shape [global_x, global_y, global_z]. On
v2 and v3, global_z is instead cores_per_chip (i.e., 2).
"""
        physical_mesh, mesh_shape)
    return device_mesh

  def _create_hybrid_device_mesh(self, ici_mesh_shape: Sequence[int],
Can you add:
1. This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L288.
2. The following function description:
"""Creates a device mesh for hybrid (e.g., ICI and DCN) parallelism.
Args:
ici_mesh_shape: shape of the logical mesh for the faster/inner network, ordered
by increasing network intensity, e.g. [replica, data, mdl] where mdl has
the most network communication requirements.
dcn_mesh_shape: shape of the logical mesh for the slower/outer network,
in the same order as mesh_shape.
Returns:
A np.ndarray of device logical ordinal with ici_mesh_shape * dcn_mesh_shape as its shape
that can be fed into HybridMesh for hybrid parallelism.
"""
    return physical_mesh.transpose(transpose).reshape(mesh_shape), assignment


  # This is imported from JAX: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L231
  def _create_device_mesh(self,
I didn't mention this one since your logic is quite different. I suggest you undo it.
Fixed the comment.
alanwaketan
left a comment
LGTM. Thanks, Mohit.
The TPU CI broke after this PR was merged. Is this related?
Let's have a follow-up to disable the test for TPU. You can do that by following: https://github.com/pytorch/xla/blob/master/test/test_zero1.py#L13
No description provided.