
Support torch.nn.functional.one_hot #9523

Merged
qihqi merged 2 commits into master from xiowei/support_one_hot
Aug 1, 2025

Conversation

@vanbasten23
Collaborator

@vanbasten23 vanbasten23 commented Jul 31, 2025

Without this change, running the commands below

import torch
import torchax

torchax.enable_globally()

T = 1024  # total number of tokens
D = 2048  # hidden_size
L = 128   # lora_rank
N = 16    # num_loras

loras = torch.randn((N, L, D), dtype=torch.float32, device='jax')
idxs = torch.randint(0, N, (T,), dtype=torch.long, device='jax')
torch.nn.functional.one_hot(idxs, loras.shape[0])

would fail with:

>>> torch.nn.functional.one_hot(idxs, loras.shape[0])
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/mnt/disks/persist/github/xla/torchax/torchax/tensor.py", line 247, in __torch_function__
    return func(*args, **(kwargs or {}))
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: one_hot is only applicable to index tensor of type LongTensor.

With the fix, torch.nn.functional.one_hot succeeds:

>>> torch.nn.functional.one_hot(idxs, loras.shape[0])
Tensor(<class 'jaxlib._jax.ArrayImpl'> [[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 1. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 1.]])
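For context on why the repro uses LoRA-shaped tensors: a one-hot matrix over per-token LoRA indices can select each token's adapter weights with a single contraction. The pattern can be sketched in NumPy as follows (the shapes mirror the snippet above at a smaller scale; the selection pattern itself is an assumption about the intended use, not code from this PR):

```python
import numpy as np

# Tiny stand-ins for num_loras, lora_rank, hidden_size, and token count.
N, L, D, T = 4, 3, 5, 6
loras = np.random.randn(N, L, D).astype(np.float32)
idxs = np.random.randint(0, N, size=(T,))

onehot = np.eye(N, dtype=np.float32)[idxs]          # (T, N), like one_hot(idxs, N)
selected = np.einsum('tn,nld->tld', onehot, loras)  # (T, L, D): each token's LoRA

# The one-hot contraction picks exactly the same rows as plain indexing:
assert np.allclose(selected, loras[idxs])
```

Expressing the gather as a one-hot matmul keeps the computation as dense linear algebra, which is often friendlier to XLA-style compilers than data-dependent indexing.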

Test plan: pytest -n 0 -s -vv test/test_ops.py -k test_reference_eager_nn_functional_one_hot_cpu_int64

@vanbasten23
Collaborator Author

vanbasten23 commented Jul 31, 2025

Regarding testing: this op seems to differ from other ops, where we can simply remove the op name from the skip_list. Any suggestions on how to test this change?

[Edit]: I'm currently testing via pytest -n 0 -s -vv test/test_ops.py -k test_reference_eager_nn_functional_one_hot_cpu_int64

@vanbasten23 vanbasten23 requested a review from qihqi July 31, 2025 06:03
@vanbasten23 vanbasten23 marked this pull request as ready for review July 31, 2025 06:04
@qihqi qihqi merged commit 7a48185 into master Aug 1, 2025
23 of 24 checks passed

2 participants