This page describes how to set up Ray on GKE with TPUs.
Please follow the official Google Cloud documentation for an introduction to TPUs. In particular, ensure that your GCP project has sufficient quota to provision the cluster; see this link for details.
For additional useful information about TPUs on GKE (such as topology configurations and availability), see this page.
In addition, please ensure the following are installed on your local development environment:
- Helm (v3.9.3)
- Kubectl
The following table shows which versions of the KubeRay TPU webhook are compatible with which versions of KubeRay. Reading from the bottom up, each webhook version applies to its listed KubeRay version and all later KubeRay versions, until the KubeRay version in the row above.
| KubeRay version | Webhook version | TPU support |
|---|---|---|
| 1.4.0 | 1.3.1 | Adds Ironwood (TPU v7x) support for all configurations (multi-slice, multi-host, etc.). |
| 1.4.0 | 1.2.5 | Supports TPU v4 through v6e. |
| 1.1.1 | 1.2.4 | Supports TPU v4 through v6e. |
Pre-built container images are hosted at `us-docker.pkg.dev/ai-on-gke/kuberay-tpu-webhook/tpu-webhook` and have a `-gke.X` suffix.
The KubeRay TPU Webhook automatically bootstraps the TPU environment for TPU clusters. The webhook needs to be installed once per GKE cluster and requires a KubeRay Operator running v1.1+ and a GKE cluster version of 1.28+. The webhook requires cert-manager to be installed in-cluster to handle TLS certificate injection. cert-manager can be installed in both GKE Standard and Autopilot clusters using the following Helm commands:
```shell
helm repo add jetstack https://charts.jetstack.io
helm repo update
helm install --create-namespace --namespace cert-manager --set installCRDs=true --set global.leaderElection.namespace=cert-manager cert-manager jetstack/cert-manager
```

After installing cert-manager, it may take up to two minutes for the certificate to become ready.
Ensure you are authenticated to use Artifact Registry:

```shell
gcloud auth login
gcloud auth configure-docker us-docker.pkg.dev
```

Install the webhook:
```shell
helm install kuberay-tpu-webhook oci://us-docker.pkg.dev/ai-on-gke/kuberay-tpu-webhook-helm/kuberay-tpu-webhook
```

The above command can be edited with the `-f` or `--set` flag to pass in a custom values file or key-value pairs, respectively, for the chart (e.g. `--set tpuWebhook.image.tag=v1.3.1-gke.2`).
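For instance, a minimal custom values file to pass with `-f` might look like the following sketch. The `tpuWebhook.image.tag` key comes from the `--set` example above; the tag value is illustrative only, not a recommendation:

```yaml
# values.yaml -- minimal sketch of a custom values file for the webhook chart.
# The tpuWebhook.image.tag key mirrors the --set example above; the tag value
# shown here is illustrative only.
tpuWebhook:
  image:
    tag: v1.3.1-gke.2
```

You would then pass it with `helm install kuberay-tpu-webhook oci://us-docker.pkg.dev/ai-on-gke/kuberay-tpu-webhook-helm/kuberay-tpu-webhook -f values.yaml`.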
For common errors encountered when deploying the webhook, see the Troubleshooting guide.
When you submit a RayCluster resource requesting TPUs, this mutating webhook intercepts the Pod creation and automatically injects the required configurations so that libtpu and JAX can initialize correctly; a sketch of the resulting Pod environment follows the list below. You do not need to manually configure these in your manifests.
- Network Initialization:
  - TPU v4 - v6e: Automatically generates and injects the `TPU_WORKER_HOSTNAMES` list for multi-host networking. The webhook also sets the `subdomain` and `hostname` fields in the Pod spec.
  - TPU v7x (Ironwood): In addition to the vars and fields injected for previous versions, also automatically generates and injects the new `TPU_PROCESS_ADDRESSES` and `TPU_PROCESS_PORT` required for the v7x architecture. `TPU_PROCESS_ADDRESSES` is identical to `TPU_WORKER_HOSTNAMES`, but with the container port appended to each address.
- Worker Identification: Calculates and injects `TPU_WORKER_ID` and `TPU_NAME` (a unique identifier for the replica group) for multi-host and multi-container coordination.
- Multi-Container (NUMA) Support: Natively supports v7x Pods that run multiple NUMA-aligned containers, assigning unique ports and IDs to each ML process. Note that multi-node support per Pod with KubeRay is experimental.
- Megascale (Multi-Slice) Support: If `MEGASCALE_NUM_SLICES` is set explicitly in the Pod spec of your Ray container, the webhook automatically calculates and injects `MEGASCALE_SLICE_ID`, `MEGASCALE_COORDINATOR_ADDRESS`, and `MEGASCALE_PORT`. If you use the JaxTrainer in Ray Train, `MEGASCALE_NUM_SLICES` and related env vars are calculated for you based on the values of `num_workers`, `accelerator_type`, and `topology` and set automatically at runtime.
- Device Plugin Routing: Injects `TPU_DEVICE_PLUGIN_HOST_IP` and `TPU_DEVICE_PLUGIN_ADDR` to ensure the container communicates with the correct node-level hardware plugin. Ray uses these environment variables to scrape per-node metrics such as TensorCore utilization, HBM utilization, TPU duty cycle, and memory usage, which are viewable on the Ray Dashboard. See View TPU metrics on the Ray Dashboard.
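For illustration, the container of a mutated multi-host TPU worker Pod might end up with env vars along these lines. This is a hedged sketch: the variable names come from the list above, but the values, group name, and addresses are hypothetical.

```yaml
# Sketch of env vars the webhook might inject into one TPU worker container.
# Variable names follow the list above; every value below is hypothetical.
env:
  - name: TPU_WORKER_HOSTNAMES   # injected: hostnames of all hosts in the replica group
    value: "tpu-group-0-0.raycluster-sample-headless,tpu-group-0-1.raycluster-sample-headless"
  - name: TPU_WORKER_ID          # injected: this Pod's index within the replica group
    value: "0"
  - name: TPU_NAME               # injected: unique identifier for the replica group
    value: "tpu-group-0"
  - name: MEGASCALE_NUM_SLICES   # set explicitly by you to opt in to multi-slice
    value: "2"
```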
In addition to automatically injecting environment variables, the webhook also acts as a validating admission controller. It analyzes your RayCluster custom resource upon submission and will reject the creation of the cluster if the configurations of your TPU worker groups are invalid.
The webhook evaluates each `workerGroupSpec` against the following rules:
- Non-TPU Workloads are Ignored: If a worker group's containers do not request `google.com/tpu` resources, the webhook immediately admits them without further checks.
- Missing NumOfHosts: If `numOfHosts` is set to `0` or omitted for a TPU multi-host worker group (determined from the topology and accelerator type), the cluster is rejected. `numOfHosts` defaults to `1` in KubeRay.
- Missing Node Selectors: If a TPU worker group is missing the `cloud.google.com/gke-tpu-topology` node selector, the cluster is rejected.
- Strict Topology Validation: The webhook strictly enforces that the number of physical TPU hosts requested matches your requested physical topology. It calculates this using the following formula:
  - Expected Hosts: `max(Total Chips / Chips Per Host, 1)`
  - If the calculated `Expected Hosts` does not exactly match the `numOfHosts` defined in your `workerGroupSpec`, the cluster is rejected with the error: `"Number of workers in worker group not equal to specified topology"`.
  - Example: If your node selector specifies a `2x2x2` topology (8 total chips) and your container requests 4 TPUs (`google.com/tpu: "4"`), your `numOfHosts` must be set to `2`, as in the manifest sketch after this list.
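Putting that example into manifest form, a worker group that passes validation for a `2x2x2` slice might look like the following sketch. The group name, image, and accelerator value are assumptions for illustration; the fields the webhook checks are `numOfHosts`, the `google.com/tpu` request, and the node selectors.

```yaml
# Sketch of a multi-host TPU worker group satisfying the rules above.
# Group name, image, and accelerator value are illustrative assumptions.
workerGroupSpecs:
  - groupName: tpu-group                   # hypothetical name
    replicas: 1
    numOfHosts: 2                          # max(8 total chips / 4 chips per host, 1) = 2
    rayStartParams: {}
    template:
      spec:
        containers:
          - name: ray-worker
            image: rayproject/ray:latest   # placeholder image
            resources:
              requests:
                google.com/tpu: "4"        # chips per host
              limits:
                google.com/tpu: "4"
        nodeSelector:
          cloud.google.com/gke-tpu-topology: 2x2x2            # 8 chips total
          cloud.google.com/gke-tpu-accelerator: tpu-v4-podslice
```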
To install the KubeRay TPU webhook from source:
```shell
git clone https://github.com/ai-on-gke/kuberay-tpu-webhook
cd kuberay-tpu-webhook
make deploy
make deploy-cert
```

`make deploy` creates the webhook deployment, configs, and service in the `ray-system` namespace. To change the namespace, edit the `namespace` value in each `.yaml` file in `deployments/` and `certs/` before deploying.
You can find sample TPU cluster manifests for single-host and multi-host here.
For a quick-start guide to using TPUs with KubeRay, see Use TPUs with KubeRay.
- Save the following to a local file (e.g. `test_tpu.py`):
```python
import ray

ray.init(
    runtime_env={
        "pip": [
            "jax[tpu]",
            "-f https://storage.googleapis.com/jax-releases/libtpu_releases.html",
        ]
    }
)

@ray.remote(resources={"TPU": 4})
def tpu_cores():
    import jax
    return "TPU cores:" + str(jax.device_count())

num_workers = 4
result = [tpu_cores.remote() for _ in range(num_workers)]
print(ray.get(result))
```

- Port-forward the Ray Dashboard:

```shell
kubectl port-forward svc/RAYCLUSTER-NAME-head-svc dashboard &
```

where `RAYCLUSTER-NAME` is the `.metadata.name` of the RayCluster in the cluster manifest you used.

- Submit the script as a Ray job:

```shell
RAY_ADDRESS=http://localhost:8265 ray job submit --runtime-env-json='{"working_dir": "."}' -- python test_tpu.py
```
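If you're unsure what to substitute for `RAYCLUSTER-NAME`, it's the top-level `.metadata.name` of your RayCluster manifest, as in this minimal sketch (the name shown is a placeholder):

```yaml
# Minimal sketch showing where RAYCLUSTER-NAME comes from; the name is a placeholder.
apiVersion: ray.io/v1
kind: RayCluster
metadata:
  name: raycluster-sample   # head service is then raycluster-sample-head-svc
```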
For a more advanced workload running Stable Diffusion on TPUs, see here. For an example of serving an LLM with TPUs, Ray Serve, and KubeRay, see here.