Enable default buffer donation for gradient accumulation #8758
bhavya01 merged 3 commits into pytorch:master
Conversation
```python
# Ensure that the input or pre-initialized gradient tensors can be donated
# after being reassigned to the respective model parameters. If the buffer
# donor config is not enabled, then this is a no-op.
torch_xla._XLAC._set_buffer_donation(param.grad, True)
```
Is the intended behavior that acc_grads shares the same buffer as prev_grad in the body function line 379? Does it help with saving memory?
Exactly. The resulting gradient will not retain the same aliasing (with or without functionalization), so the output of our XLA graph would require the runtime to allocate memory for all the new gradients, instead of reusing the donated input gradient tensors.
Unfortunately, similarly to scan, we need to either clone() or return new tensors when we want to semantically mutate/update the input tensors of the body, as in the sketch below.
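For illustration, a minimal sketch of that pattern (the `body`, `acc_grads`, and `grads` names are hypothetical, mirroring the while-loop body discussed above, not the actual API):

```python
# Hypothetical while-loop body; names are illustrative only.
def body(acc_grads, grads):
    # In-place accumulation (acc_grads[i] += grads[i]) is not supported by
    # the torch.autograd.grad + lowering-context path, so the body returns
    # fresh tensors. The outputs carry a new alias ID, which is why the input
    # buffers must be donated explicitly to avoid a second allocation.
    return [acc_grad + grad for acc_grad, grad in zip(acc_grads, grads)]
```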
Can you elaborate more on the "retain the same aliasing" part? Is it just determined by the ID of the tensors? If so, can you point me to the code for the data structure? Thanks!
Sorry, I am just trying to ramp up on the buffer donation system.
No worries. I'll include my understanding if it helps.
Focusing on mark_step, where we want to sync XLA tensors, we have two types of donation:
- LTC aliasing:
The buffer donation is based on the alias ID of all the live tensors (https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L1256). For LTC aliasing, we look at the alias ID of all the parameter data (deduced from https://github.com/pytorch/pytorch/blob/ce805a5ba5a9fdda793060a0fe1514b5fa1ea163/torch/csrc/lazy/core/lazy_graph_executor.cpp#L703, the upstream base class of our XLA graph executor), and if that alias matches the respective live tensor, we donate that buffer to be reused after the mark step (to avoid having to allocate memory for the outputs), since input and output map to the same live tensor.
With functionalization, if you do an in-place mutation of a tensor, its tensor ID will always be updated, but the alias ID will be retained (IIRC, this was originally implemented to handle views, but there's more context to it):
```python
>>> device = torch_xla.device()
>>> t = torch.randn(10,10).to(device)
>>> print(torch_xla._XLAC._get_xla_tensor_debug_info(t))
XLATensor {
  TensorID: 1
  AliasID: 1
  Device: CPU:0
  XLA Shape: f32[10,10]
  ShardingSpec: None
  IR: [UNKNOWN_SCALAR[]] xla::device_data, xla_shape=f32[10,10]{1,0}, dynamic_dims: (), device=CPU:0
  XLAData: None
  Tensor on host: None
}
>>> t *= 2
>>> print(torch_xla._XLAC._get_xla_tensor_debug_info(t))
XLATensor {
  TensorID: 3
  AliasID: 1
  Device: CPU:0
  XLA Shape: f32[10,10]
  ShardingSpec: None
  IR: [] aten::mul, xla_shape=f32[10,10]{1,0}, dynamic_dims: ()
  XLAData: None
  Tensor on host: None
}
>>> t1 = t + 4
>>> print(torch_xla._XLAC._get_xla_tensor_debug_info(t1))
XLATensor {
  TensorID: 5
  AliasID: 5
  Device: CPU:0
  XLA Shape: f32[10,10]
  ShardingSpec: None
  IR: [] aten::add, xla_shape=f32[10,10]{1,0}, dynamic_dims: ()
  XLAData: None
  Tensor on host: None
}
>>> xm.mark_step()  # This will donate the input device data tensor (t) to its output tensor.
>>> print(met.metric_data("InputOutputAliasCount"))
(1, 1.0, ((1740686077.919644, 1.0),))
```
You will have the following HLO graph:
```
HloModule IrToHlo.11, entry_computation_layout={(f32[], f32[10,10]{1,0}, f32[])->(f32[10,10]{1,0}, f32[10,10]{1,0})}

ENTRY %IrToHlo.11 (p0.1: f32[], p1.2: f32[10,10], p2.6: f32[]) -> (f32[10,10], f32[10,10]) {
  %p1.2 = f32[10,10]{1,0} parameter(1)
  %p0.1 = f32[] parameter(0)
  %broadcast.3 = f32[10,10]{1,0} broadcast(f32[] %p0.1), dimensions={}
  %multiply.4 = f32[10,10]{1,0} multiply(f32[10,10]{1,0} %p1.2, f32[10,10]{1,0} %broadcast.3)
  %p2.6 = f32[] parameter(2)
  %constant.5 = f32[] constant(1)
  %multiply.7 = f32[] multiply(f32[] %p2.6, f32[] %constant.5)
  %broadcast.8 = f32[10,10]{1,0} broadcast(f32[] %multiply.7), dimensions={}
  %add.9 = f32[10,10]{1,0} add(f32[10,10]{1,0} %multiply.4, f32[10,10]{1,0} %broadcast.8)
  ROOT %tuple.10 = (f32[10,10]{1,0}, f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %multiply.4, f32[10,10]{1,0} %add.9)
}
```
We will have two outputs, but since the input tensor can be donated, we will allow the f32[10,10] buffer to be donated, so either of the outputs can reuse that runtime buffer. Since we have an extra live tensor to keep alive and no more tensors to donate, the other output will still need its own buffer.
Usually, in training (e.g. Llama3), you should see donation for all of the gradients (if mark_step precedes zero_grad), parameters, and optimizer states.
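A minimal sketch of that ordering (placeholder model/optimizer, not taken from the PR): because mark_step runs before the next zero_grad, the gradient, parameter, and optimizer-state tensors are still live and mutated in place, so LTC aliasing can donate their input buffers to the graph outputs.

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = torch_xla.device()
model = torch.nn.Linear(16, 16).to(device)  # placeholder model
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

for step in range(10):
    optimizer.zero_grad()
    loss = model(torch.randn(8, 16, device=device)).sum()
    loss.backward()
    optimizer.step()
    # Sync here, before the next zero_grad(): gradients, parameters, and
    # optimizer states are mutated in place, so their input buffers can be
    # donated to the corresponding outputs of this step's graph.
    xm.mark_step()
```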
With functionalization disabled, both the tensor ID and alias ID will remain unchanged:
```python
>>> device = torch_xla.device()
>>> t = torch.randn(10,10).to(device)
>>> print(torch_xla._XLAC._get_xla_tensor_debug_info(t))
XLATensor {
  TensorID: 1
  AliasID: 1
  Device: CPU:0
  XLA Shape: f32[10,10]
  ShardingSpec: None
  IR: [UNKNOWN_SCALAR[]] xla::device_data, xla_shape=f32[10,10]{1,0}, dynamic_dims: (), device=CPU:0
  XLAData: None
  Tensor on host: None
}
>>> t *= 2
>>> print(torch_xla._XLAC._get_xla_tensor_debug_info(t))
XLATensor {
  TensorID: 1
  AliasID: 1
  Device: CPU:0
  XLA Shape: f32[10,10]
  ShardingSpec: None
  IR: [] aten::mul, xla_shape=f32[10,10]{1,0}, dynamic_dims: ()
  XLAData: None
  Tensor on host: None
}
>>> t1 = t + 4
>>> print(torch_xla._XLAC._get_xla_tensor_debug_info(t1))
XLATensor {
  TensorID: 2
  AliasID: 2
  Device: CPU:0
  XLA Shape: f32[10,10]
  ShardingSpec: None
  IR: [] aten::add, xla_shape=f32[10,10]{1,0}, dynamic_dims: ()
  XLAData: None
  Tensor on host: None
}
>>> xm.mark_step()  # This will donate the input device data tensor (t) to its output tensor.
>>> print(met.metric_data("InputOutputAliasCount"))
(1, 1.0, ((1740686077.919644, 1.0),))
```
Since both the tensor ID and the alias ID are kept the same for in-place mutations, this gives the same result as with functionalization.
- Buffer donation aliasing:
You can find some context in #8711, but it was originally brought in by Jack for dynamo execution. This allows a user to explicitly mark a tensor to be donated (https://github.com/pytorch/xla/blob/master/torch_xla/csrc/xla_graph_executor.cpp#L1334). It was needed for torch.compile / torch.trace; my understanding is that dynamo executes without syncing all live tensors, so the first part (LTC aliasing) doesn't apply. Hence, it needed these APIs to give control over what should be donated (similar to our case, where we see an in-place mutation), e.g.:
```python
>>> device = torch_xla.device()
>>> t = torch.randn(10,10).to(device)
>>> print(torch_xla._XLAC._get_xla_tensor_debug_info(t))
XLATensor {
  TensorID: 1
  AliasID: 1
  Device: CPU:0
  XLA Shape: f32[10,10]
  ShardingSpec: None
  IR: [UNKNOWN_SCALAR[]] xla::device_data, xla_shape=f32[10,10]{1,0}, dynamic_dims: (), device=CPU:0
  XLAData: None
  Tensor on host: None
}
>>> t_new = t + 1
>>> print(torch_xla._XLAC._get_xla_tensor_debug_info(t_new))
XLATensor {
  TensorID: 3
  AliasID: 3
  Device: CPU:0
  XLA Shape: f32[10,10]
  ShardingSpec: None
  IR: [] aten::add, xla_shape=f32[10,10]{1,0}, dynamic_dims: ()
  XLAData: None
  Tensor on host: None
}
>>> # Let's say we don't need the live `t` tensor anymore (after mark_step), so we can force it to be donated.
>>> torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(True)
>>> torch_xla._XLAC._set_buffer_donation(t, True)
>>> xm.mark_step()  # This will donate the input device data tensor (t) to its output tensor.
>>> print(met.metric_data("InputOutputAliasCount"))
(1, 1.0, ((1740686447.0998409, 1.0),))
>>> print(t)  # This will fail as expected (Check failed: handle->HasValue())
```
In this PR, we need both simultaneously, since we still have LTC aliasing for any other computation around this experimental gradient accumulation API. At the same time, we have control over the backward, and since we explicitly copy the parameter input to the body (for orthogonal, limiting reasons), we need to donate the former tensor, because, as we saw before, the result will carry a different alias ID ("not retaining the same aliasing"). The XLA while op will ensure that it reflects the same tensor (input to the loop (init) -> output of the loop's last iteration).
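Putting the two together, a rough sketch of what the built-in donation in this PR amounts to (illustrative only; the placeholder `model` loop stands in for the gradient accumulation API's setup):

```python
import torch
import torch_xla

device = torch_xla.device()
model = torch.nn.Linear(16, 16).to(device)  # placeholder model

# Allow user-marked donations to participate in aliasing (off by default).
torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(True)

for param in model.parameters():
    if param.grad is None:
        param.grad = torch.zeros_like(param)  # pre-initialized gradients
    # The gradients re-assigned after the while loop carry a new alias ID, so
    # LTC aliasing cannot match them to these inputs; donate the old buffers
    # explicitly instead. This is a no-op if the config above is disabled.
    torch_xla._XLAC._set_buffer_donation(param.grad, True)
```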
Thanks a lot for such a nice explanation! Just one small question: when you say disable functionalization, do you mean setting XLA_DISABLE_FUNCTIONALIZATION=1?
No worries! Yes, correct.
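For reference, one way to set it from Python (assuming, as is usual for these flags, that it must be set before torch_xla initializes):

```python
import os

# Must be set before torch_xla is imported for the flag to take effect.
os.environ["XLA_DISABLE_FUNCTIONALIZATION"] = "1"

import torch_xla
```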
As a follow-up to #8721 (#8711), we extend the intended buffer donation behavior to the experimental gradient accumulation API.
The current implementation relies on torch.autograd.grad + the lowering context, which does not support in-place mutation of the model's parameter gradients (which are inputs). Hence, since we re-assign the parameter gradients to tensors with a different alias ID (computed with autograd), we inevitably unalias the gradient inputs of the graph from the resulting gradients.
Unfortunately, the buffer donation (user) config is disabled by default, so we currently warn users about this behavior. If we can lift the constraint of retaining aliasing and support in-place mutation of the inputs, we can remove the built-in donation from the API.
Note: an additional commit decouples the context manager for the buffer donation config, so it can be easily reused in application code.
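As a sketch of what that decoupled context manager could look like (hypothetical name and shape; it assumes a matching getter, _xla_get_enable_alias_with_buffer_donor_config, to restore the previous state):

```python
import contextlib

import torch_xla


@contextlib.contextmanager
def alias_with_buffer_donor_config(enabled: bool = True):
    # Hypothetical helper: flip the buffer donor config for the duration of
    # the block and restore the previous value afterwards.
    saved = torch_xla._XLAC._xla_get_enable_alias_with_buffer_donor_config()
    torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(enabled)
    try:
        yield
    finally:
        torch_xla._XLAC._xla_set_enable_alias_with_buffer_donor_config(saved)
```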