
Support explicit aliasing / buffer donation #8711

@rpsilva-aws

Description


🚀 Feature

We propose introducing a new API that allows users to explicitly control buffer donation in XLA computations, particularly useful for memory optimization in large models. This feature provides fine-grained control over tensor memory management without depending on functionalization's aliasing behavior.

When working with IR device data nodes (e.g., x) and derived tensors (e.g., y = x + 1), users can explicitly annotate buffer donations through torch_xla._XLAC.donate_buffer(...). This provides memory benefits similar to functionalization's x += 1, but with explicit control, annotated in the underlying lazy graph executor.

A) X donation annotation

In this case, the signature requires the tensor we wish to donate. The invariant requires it to be a materialized device data IR node. This parameter input is then included in the final XLA computation's buffer donors, whether or not it would have been included by default (propagated aliasing). Changes to this tensor, with or without functionalization, retain the same donation intent at the end.
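To make the intent of A) concrete, here is a minimal pure-Python model of the proposed semantics - the class and method names are illustrative stand-ins, not the actual lazy tensor machinery: an explicitly donated parameter is always added to the computation's buffer-donor set, regardless of whether propagated aliasing would have included it.

```python
class DeviceData:
    """Toy stand-in for a materialized device data IR node."""
    def __init__(self, param_id):
        self.param_id = param_id
        self.donated = False

class Computation:
    """Toy stand-in for the final XLA computation's donor bookkeeping."""
    def __init__(self):
        self.buffer_donors = set()  # parameter input IDs marked as donors

    def donate_buffer(self, tensor):
        # Invariant from A): only materialized device data can be donated.
        if not isinstance(tensor, DeviceData):
            raise ValueError("donation requires a materialized device data IR node")
        tensor.donated = True
        # Explicit donation is honored regardless of propagated aliasing.
        self.buffer_donors.add(tensor.param_id)

comp = Computation()
x = DeviceData(param_id=0)
comp.donate_buffer(x)
assert 0 in comp.buffer_donors
```

The point of the model is that the donor entry is written unconditionally on the explicit call, so it does not depend on what the aliasing propagation later decides.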

B) X-Y donation annotation

In this case, the signature requires a source and a destination tensor. This would leverage the functionalization aliasing for XLATensors, by mutating the destination's aliasing to the source's tensor ID (simulating an in-place op). We expect the same invariant as in A), in addition to requiring that the provided tensors have the same shape and type. Accessing the donated source tensor will throw an error.
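A rough pure-Python model of the B) semantics, again with illustrative names rather than the real XLATensor internals: the destination's alias ID is rewired to the source's tensor ID, the shape/dtype invariant is checked up front, and any later access to the donated source raises.

```python
class XLATensorModel:
    """Toy model of an XLATensor's identity and aliasing state."""
    def __init__(self, tensor_id, shape, dtype):
        self.tensor_id = tensor_id
        self.alias_id = tensor_id  # functionalization alias, initially itself
        self.shape = shape
        self.dtype = dtype
        self.donated = False

def donate_buffer_pair(source, dest):
    """Model of B): rewire dest's alias to source, simulating an in-place op."""
    if source.shape != dest.shape or source.dtype != dest.dtype:
        raise ValueError("source and destination must match in shape and dtype")
    dest.alias_id = source.tensor_id  # destination now aliases the source buffer
    source.donated = True             # source becomes inaccessible

def read(tensor):
    # Buffer donation is irreversible; the donated source must not be read.
    if tensor.donated:
        raise RuntimeError("accessing a donated tensor is an error")
    return tensor.tensor_id

src = XLATensorModel(tensor_id=1, shape=(2, 2), dtype="float32")
dst = XLATensorModel(tensor_id=2, shape=(2, 2), dtype="float32")
donate_buffer_pair(src, dst)
assert dst.alias_id == src.tensor_id
```

This also illustrates why B) needs broader test coverage: the rewired alias still has to interact correctly with whatever aliasing functionalization itself propagates.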

C) Computation buffer donation propagation

At the moment, if we create a user computation, the buffer donation entry will not be honored even if set. Only once we explicitly mark step is the final computation and its respective HLO proto generated with the buffer donation indices annotated - very simply, by mapping the parameter input IDs to all the live tensor alias IDs. Supporting this option would require us to maintain and write heuristics for propagating local user computation buffer donation indices to the global context.
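The current mark-step behavior described above - donor indices produced only at the end, by matching parameter input IDs against live tensor alias IDs - can be sketched in a few lines of pure Python (the dict fields are illustrative, not the real data structures):

```python
def build_donor_indices(parameters, live_tensors):
    """Toy model: at mark step, a parameter becomes a donor if some live
    tensor aliases it (alias_id equals the parameter's input ID)."""
    live_alias_ids = {t["alias_id"] for t in live_tensors}
    return [i for i, p in enumerate(parameters) if p["input_id"] in live_alias_ids]

params = [{"input_id": 10}, {"input_id": 11}]
live = [{"alias_id": 11}]
assert build_donor_indices(params, live) == [1]
```

Under this model, a donation annotation set on a local user computation simply never reaches the mapping step, which is why it is not honored today.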


We're leaning towards A, since B would need an exhaustive set of tests when used interchangeably with functionalization, and A is more evidently honored when building the buffer donation indices. B, on the other hand, would still depend on any remaining aliasing propagation, so it wouldn't necessarily guarantee the parameter input donation - which is the main motivation behind this issue. C would require significantly more effort to support.

Motivation

We currently do not support in-place mutation of function arguments in the scan and gradient accumulation APIs. The constraint is that function arguments (e.g. gradients, weights) should not be in-place mutated at the application level. Hence, we have:

def body_fn(x, y):
  updated_x = x + 1
  updated_y = y.clone()
  return updated_x, updated_y * 2

If functionalization is enabled, this will disassociate and break the alias_id relation between the tensors; without functionalization, both the tensor ID and the aliasing will differ.

If we assume that we're executing the resulting HLO from a large model, with a mark step following the optimizer step, then not aliasing gradients and weights can have a large impact on device memory, requiring a dedicated runtime tensor allocation for the output tensors. For Llama3 8B TP32, this can be up to 6GB per device.

This can simultaneously be used with torch.compile or torch.trace, since dynamo executes the graphs without syncing all live tensors. Hence, without the live tensors, we are unable to leverage LTC to infer the needed parameter aliasings.

Additional context

Note: Buffer donation is irreversible - tensors cannot be accessed after donation. Users should carefully consider their computation graph when applying donations.
