Today, if we have a constant tensor (e.g., generated by a torch.tensor(0) call), it gets lifted to be a buffer (https://github.com/pytorch/pytorch/blob/main/torch/_export/passes/lift_constant_tensor_pass.py).
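For illustration, a minimal repro sketch (the exact lifted key name, and whether it lands in the state dict at all, are implementation details that may vary by version):

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        # This constant is created during tracing and is a candidate
        # for lifting at export time.
        return x + torch.tensor(0)

m = M()
print(list(m.state_dict().keys()))  # [] -- no params or buffers

ep = torch.export.export(m, (torch.randn(2),))
# With the lifting pass above, the constant can show up as an extra
# buffer-like entry (e.g. something like '_lifted_tensor_constant0';
# name and location are implementation details).
print(list(ep.state_dict.keys()))
```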
This is potentially problematic: the state dict is a well-established abstraction in PyTorch, and mutating it due to export implementation details may surprise users. As a concrete example, many users guard model compatibility based on state dict keys. If export adds a key based on the details of how tracing happened, those compatibility checks will break and users will be confused as to why.
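For instance, a (hypothetical) downstream guard like the following would start failing even though the model itself is unchanged:

```python
# Hypothetical compatibility check; EXPECTED_KEYS stands in for
# whatever keys the user's checkpoint format pins down.
EXPECTED_KEYS = {"linear.weight", "linear.bias"}

def check_compat(state_dict):
    unexpected = set(state_dict.keys()) - EXPECTED_KEYS
    if unexpected:
        # A lifted constant like '_lifted_tensor_constant0' would
        # trip this check.
        raise RuntimeError(f"unexpected state_dict keys: {sorted(unexpected)}")
```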
Instead, we should handle tensor constants in the tensor specification. Nodes should be able to take constant tensors as inputs, and we should have a table mapping constant tensor handles to in-archive serialized blobs. The serialization/deserialization systems will need to be updated to properly load/save these tensors.
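A rough sketch of what such a table could look like, assuming a simple handle-to-blob mapping built on torch.save/torch.load (names here are illustrative, not a proposed API):

```python
import io
import torch

# Illustrative only: a handle -> serialized-blob table, where 'handle'
# is whatever name graph nodes use to reference the constant. The real
# design would live in the export serialization layer.
def serialize_constants(constants: dict[str, torch.Tensor]) -> dict[str, bytes]:
    table = {}
    for handle, tensor in constants.items():
        buf = io.BytesIO()
        torch.save(tensor, buf)
        table[handle] = buf.getvalue()
    return table

def deserialize_constants(table: dict[str, bytes]) -> dict[str, torch.Tensor]:
    return {
        handle: torch.load(io.BytesIO(blob))
        for handle, blob in table.items()
    }
```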
cc @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi