Support splitting physical axis in HybridMesh #8381

@tengyifei

🚀 Feature

HybridMesh is a utility that generates a mapping from accelerator device IDs to logical mesh coordinates. Today it doesn't support splitting a physical axis, so you can't use, for example, an ICI mesh shape of (64, 4) with a 16 x 16 TPU topology: the logical axis of size 4 is smaller than either physical axis of size 16.
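
To make the limitation concrete, here is a minimal sketch of the check "can every logical axis be built from whole physical axes?". It is a simplified illustration, not the actual HybridMesh algorithm; the function name `assignable_without_split` is hypothetical.

```python
from itertools import product


def assignable_without_split(physical, logical):
    """Return True if each logical axis is a product of whole physical axes.

    Brute force: try every way of assigning each physical axis to exactly
    one logical axis and compare the resulting axis sizes.
    """
    for assignment in product(range(len(logical)), repeat=len(physical)):
        sizes = [1] * len(logical)
        for phys_size, logical_idx in zip(physical, assignment):
            sizes[logical_idx] *= phys_size
        if sizes == list(logical):
            return True
    return False


# A 16 x 16 topology maps directly onto (16, 16) or (256, 1) ...
print(assignable_without_split((16, 16), (16, 16)))  # True
print(assignable_without_split((16, 16), (256, 1)))  # True
# ... but (64, 4) needs a physical axis of size 16 to be split.
print(assignable_without_split((16, 16), (64, 4)))   # False
```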

Motivation

The motivation is to scale a model to multiple pods of Trillium TPUs. For example, we may want to use:

  • ICI mesh shape: (64, 4)
  • DCN mesh shape: (2, 1)

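over two pods of Trillium 16x16 TPUs. That requires splitting one of the size-16 physical axes into 4 x 4: one factor of 4 maps to the ICI axis of size 4, and the other combines with the remaining size-16 axis to form the ICI axis of size 64 (16 x 4).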
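
The index arithmetic for a single pod can be sketched as follows. This is only an illustration under stated assumptions: device IDs are assumed to be laid out row-major on the 16 x 16 grid (`device_id = x * 16 + y`), and the helper name `split_axis_mesh` is hypothetical. A real implementation would also need to account for the physical ICI wiring when choosing how to split.

```python
PHYS_X, PHYS_Y = 16, 16      # physical topology of one pod
SPLIT_HI, SPLIT_LO = 4, 4    # split the y axis: 16 = 4 * 4
ICI_MESH = (PHYS_X * SPLIT_HI, SPLIT_LO)  # (64, 4)


def split_axis_mesh():
    """Build a 64 x 4 logical mesh of device IDs from a 16 x 16 grid."""
    mesh = [[None] * ICI_MESH[1] for _ in range(ICI_MESH[0])]
    for x in range(PHYS_X):
        for y in range(PHYS_Y):
            device_id = x * PHYS_Y + y            # assumed row-major layout
            y_hi, y_lo = divmod(y, SPLIT_LO)      # split y into two factors
            # x and the high factor of y form the size-64 logical axis;
            # the low factor of y is the size-4 logical axis.
            mesh[x * SPLIT_HI + y_hi][y_lo] = device_id
    return mesh


mesh = split_axis_mesh()
print(len(mesh), len(mesh[0]))  # 64 4
```

Every device ID from 0 to 255 appears exactly once in the resulting mesh, so the split produces a valid one-to-one mapping for the (64, 4) ICI shape.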

Pitch

To do this, we can probably reference what JAX does these days, since parts of HybridMesh were copied from JAX.

Labels

distributed (SPMD and other distributed things)
