
Introduce int32 index_fill and index_copy indices #142160

Open
rpsilva-aws wants to merge 1 commit into pytorch:main from rpsilva-aws:rpsilva_pt_int32_v2

Conversation

@rpsilva-aws
Contributor

Fixes #142090
This PR extends index_fill and index_copy operations to support int32 indices in addition to the existing int64 support:

  1. Memory efficiency and potential performance gains when handling a large number of indices, particularly when transferring them to accelerator backends. In some cases the backend compiler may not support 64-bit integers, and in others 64-bit indices can cause conflicting casts when used with TorchXLA. I do not see an immediate need to extend this to the dim argument, since it is a scalar tied to the API and its cost is negligible.
  2. Framework interoperability: as mentioned above, this gives more flexibility when working with TorchXLA, since some operations require the tensors' physical (raw) representation to match that of the XLA tensors. For Neuron in particular, XLA generates S32 types that are not compatible with some operations (e.g. casts) when converting across tensors.
  3. Consistency with other APIs, such as index_add and index_select, which already accept int32 indices (see the usage sketch below).
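
A minimal usage sketch of what this change enables (assuming the PR is applied; on current releases these calls require int64 indices):

    import torch

    x = torch.zeros(5, 3)
    src = torch.ones(2, 3)

    # With this PR, int32 index tensors are accepted in addition to int64.
    idx = torch.tensor([0, 4], dtype=torch.int32)

    x.index_copy_(0, idx, src)   # copy rows of src into rows 0 and 4 of x
    x.index_fill_(0, idx, -1.0)  # then fill those same rows with -1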

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10

@pytorch-bot

pytorch-bot Bot commented Dec 5, 2024

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/142160

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 763b7e4 with merge base 0318589:

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@pytorch-bot pytorch-bot Bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label Dec 5, 2024
@rpsilva-aws
Contributor Author

@pytorchbot label "release notes: cuda"

@pytorch-bot pytorch-bot Bot added the release notes: cuda release notes category label Dec 5, 2024
@rpsilva-aws
Contributor Author

FYI: @albanD, reopened #142090

@rpsilva-aws
Contributor Author

cc: @miladm for visibility, since this is related to (and somewhat an extension of) pytorch/xla#8450.

@albanD
Collaborator

albanD commented Dec 5, 2024

@eqy can you take a look at this?

@rpsilva-aws
Contributor Author

rpsilva-aws commented Dec 6, 2024

@albanD, @eqy is there any way we can get this into 2.6? We need it for our TorchXLA counterpart PRs.

@netlify

netlify Bot commented Dec 6, 2024

Deploy Preview for chimerical-cranachan-793287 ready!

Name Link
🔨 Latest commit 763b7e4
🔍 Latest deploy log https://app.netlify.com/sites/chimerical-cranachan-793287/deploys/67539246757645000865b5a8
😎 Deploy Preview https://deploy-preview-142160--chimerical-cranachan-793287.netlify.app

@rpsilva-aws
Contributor Author

rpsilva-aws commented Dec 6, 2024

Fixed a typo in the CUDA file, PTAL.

@rpsilva-aws
Contributor Author

@eqy Can you help restart the workflow? Thank you.

Collaborator

@eqy eqy left a comment


Would it make sense to add some trivial asserts that guarantee that the number of elements of the tensor doesn't exceed the max possible value of the provided index type?

@rpsilva-aws
Contributor Author

rpsilva-aws commented Dec 7, 2024

Would it make sense to add some trivial asserts that guarantee that the number of elements of the tensor doesn't exceed the max possible value of the provided index type?

The indices select elements along a specific dimension, and their values are implicitly capped by the index type. The index tensor can contain repeated values if I am not mistaken, so an assertion on its total number of elements wouldn't necessarily hold. I also thought of aligning the indexed dimension's size with the index type, but that doesn't necessarily need to hold either; we do already have a bounds check for each index:

          TORCH_CHECK_INDEX(idx >= -self_dim_size && idx < self_dim_size,
                            "index ", idx, " is out of bounds for dimension ",
                            dim, " with size ", self_dim_size);


@eqy
Collaborator

eqy commented Dec 7, 2024

Would it make sense to add some trivial asserts that guarantee that the number of elements of the tensor doesn't exceed the max possible value of the provided index type?

The indices select elements along a specific dimension, and their values are implicitly capped by the index type. The index tensor can contain repeated values if I am not mistaken, so an assertion on its total number of elements wouldn't necessarily hold. I also thought of aligning the indexed dimension's size with the index type, but that doesn't necessarily need to hold either; we do already have a bounds check for each index:

          TORCH_CHECK_INDEX(idx >= -self_dim_size && idx < self_dim_size,
                            "index ", idx, " is out of bounds for dimension ",
                            dim, " with size ", self_dim_size);

What I mean is, does it make sense to add a check against a dim with size > INT_MAX when 32-bit indices are used?

@rpsilva-aws
Contributor Author

rpsilva-aws commented Dec 8, 2024

Would it make sense to add some trivial asserts that guarantee that the number of elements of the tensor doesn't exceed the max possible value of the provided index type?

The indices select elements along a specific dimension, and their values are implicitly capped by the index type. The index tensor can contain repeated values if I am not mistaken, so an assertion on its total number of elements wouldn't necessarily hold. I also thought of aligning the indexed dimension's size with the index type, but that doesn't necessarily need to hold either; we do already have a bounds check for each index:

          TORCH_CHECK_INDEX(idx >= -self_dim_size && idx < self_dim_size,
                            "index ", idx, " is out of bounds for dimension ",
                            dim, " with size ", self_dim_size);

What I mean is, does it make sense to add a check against a dim with size > INT_MAX when 32-bit indices are used?

I see, thanks for the suggestion - if I am not mistaken, this is already in place. In the current kernel implementation, as we iterate through the indices, the existing TORCH_CHECK_INDEX macro already bounds-checks each index against the size of the indexed dimension (and an int32 index can never exceed INT_MAX in the first place). This should already make handling 32-bit indices safe: https://github.com/pytorch/pytorch/pull/142160/files#diff-8aa1a200ec63d23db422aa31b6dca1e6cb372887c43b064ef435210b1b0dec0aR3446. If you instead mean the dim argument itself, that only selects the dimension we operate on, so it isn't directly related. Let me know if this answers the question.
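
To illustrate the per-index bounds check discussed above in Python terms (an illustration only, not code added by this PR; the exact error type and message come from the TORCH_CHECK_INDEX call quoted earlier):

    import torch

    x = torch.zeros(4)
    # Each index is validated against the size of the indexed dimension,
    # so an out-of-range value fails loudly regardless of the index dtype.
    try:
        x.index_fill_(0, torch.tensor([7]), 1.0)  # 7 >= size 4 along dim 0
    except (IndexError, RuntimeError) as e:
        print(e)  # e.g. "index 7 is out of bounds for dimension 0 with size 4"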

@janeyx99 janeyx99 added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Dec 9, 2024
@rpsilva-aws
Contributor Author

@eqy do you have any other concerns here? Let me know, in case this could still make it into 2.6.

@eqy
Collaborator

eqy commented Dec 10, 2024

My concern was not memory safety, but rather warning/alerting the user in cases where the index width could not be used to fully address a given dimension, e.g., dim size > 2**32 but index type int. In other words, does the "implicit capping" lead to potentially nonsensical use cases that should raise a warning or exception?

@eqy
Collaborator

eqy commented Dec 10, 2024

A trivial code example:

>>> import torch
>>> a = torch.empty(2**32, device='cuda', dtype=torch.uint8)
>>> a.index_fill_(0, torch.tensor([2**32 - 1], dtype=torch.int64, device='cuda'), -1)
tensor([  0,   0,   0,  ...,   0,   0, 255], device='cuda:0',
       dtype=torch.uint8)
>>> a.index_fill_(0, torch.tensor([2**32 - 1], dtype=torch.int64, device='cuda').to(torch.int32), -1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
IndexError: index_fill_(): Expected dtype int64 for index.

What would the second case do here? IMO relying on potential UB (e.g., assuming the conversion rolls over to a negative value) seems flaky.
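
One possible shape for the kind of guard being discussed (a hypothetical Python-level sketch; the helper name and error message are illustrative and not part of this PR):

    import torch

    def check_index_dtype_fits_dim(self: torch.Tensor, dim: int, index: torch.Tensor) -> None:
        # Hypothetical guard: if 32-bit indices are used, the indexed dimension
        # must be fully addressable with int32 values.
        if index.dtype == torch.int32 and self.size(dim) > torch.iinfo(torch.int32).max:
            raise IndexError(
                f"dimension {dim} has size {self.size(dim)}, which cannot be fully "
                f"addressed with int32 indices; use int64 indices instead"
            )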

@github-actions
Contributor

github-actions Bot commented Feb 8, 2025

Looks like this PR hasn't been updated in a while so we're going to go ahead and mark this as Stale.
Feel free to remove the Stale label if you feel this was a mistake.
If you are unable to remove the Stale label please contact a maintainer in order to do so.
If you want the bot to never mark this PR stale again, add the no-stale label.
Stale pull requests will automatically be closed after 30 days of inactivity.

@github-actions github-actions Bot added the Stale label Feb 8, 2025
@rpsilva-aws
Contributor Author

@pytorchbot label "no-stale"

@pytorch-bot pytorch-bot Bot added the no-stale label Feb 8, 2025
@jeffhataws

Keep this alive.

@jeffhataws

See #141994 for a related RFC.

@jeffhataws

My concern was not memory safety, but rather warning/alerting the user in cases where the index width could not be used to fully address a given dimension, e.g., dim size > 2**32 but index type int. In other words, does the "implicit capping" lead to potentially nonsensical use cases that should raise a warning or exception?

#142160 (comment)

@rpsilva-aws @eqy can the following be used to check the index bounds?

canUse32BitIndexMath

bool canUse32BitIndexMath(const TensorBase& t, int64_t max_elem) {

can_use_32bit_indexing

bool TensorIteratorBase::can_use_32bit_indexing() const {

@jeffhataws

@rpsilva-aws will you rebase?


Labels

module: cpu CPU specific problem (e.g., perf, algorithm) no-stale open source release notes: cuda release notes category Stale triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module


6 participants