Skip to content
This repository was archived by the owner on Mar 3, 2026. It is now read-only.
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 31 additions & 11 deletions torchprime/experimental/torchax_models/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import custom_mesh
import jax
from jax import numpy as jnp
import numpy as np
import splash_attn
import torch
Expand Down Expand Up @@ -142,17 +143,36 @@ def make_weight_shard(weight_meta, slice_index):


def create_sharded_weights(model, mesh, sharding_map):
res = {}
for name, weight_meta in model.state_dict().items():
sharding_spec = sharding_map.get(_process_sharding_name(name))
if sharding_spec is None:
print("Skipping weight:", name)
continue
sharding = NamedSharding(mesh, P(*sharding_spec))
res[name] = jax.make_array_from_callback(
weight_meta.shape, sharding,
functools.partial(make_weight_shard, weight_meta))
return res
name_to_sharding = {
name: NamedSharding(mesh, P(*sharding_map.get(_process_sharding_name(name))))
for name in model.state_dict().keys()
if _process_sharding_name(name) in sharding_map
}

kaiming = jax.nn.initializers.he_uniform(dtype=jnp.bfloat16)

key = jax.random.PRNGKey(0)
key = jax.device_put(key, NamedSharding(mesh, P())) # replicate

@functools.partial(
jax.jit,
out_shardings=name_to_sharding,
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Did you test this on 405B? When I did something similar locally, the jit worked for 8B but the graph used >1TiB of HBM on 405B on two slices. The problem I encountered was that each device was still trying to generate the entire weight.

)
def create_weights(rng):
res = {}
for name, weight_meta in model.state_dict().items():
if _process_sharding_name(name) not in sharding_map:
continue
rng, subkey = jax.random.split(rng)
if len(weight_meta.shape) < 2:
res[name] = jax.random.normal(subkey, weight_meta.shape,
interop.jax_view(weight_meta.dtype))
else:
res[name] = kaiming(subkey, weight_meta.shape, interop.jax_view(weight_meta.dtype))
return res

weights = create_weights(key)
return interop.torch_view(weights)


def sharded_device_put(tensor, sharding):
Expand Down