
[export] Constant tensors should not get lifted to buffers #110484

@suo

Description

Today, if we have a constant tensor (e.g. one generated by a torch.tensor(0) call), it gets lifted to be a buffer (https://github.com/pytorch/pytorch/blob/main/torch/_export/passes/lift_constant_tensor_pass.py).
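To make the lifting behavior concrete, here is a toy sketch in plain Python. It does not use PyTorch: the `Node` class, the list-based graph, and the `_lifted_constant{i}` naming scheme are hypothetical stand-ins, and the real pass in `lift_constant_tensor_pass.py` may work differently. It only illustrates the effect described above: a constant embedded in the graph ends up as a new state-dict entry.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List


@dataclass
class Node:
    # Toy graph node: an op name plus its arguments. An argument is either a
    # literal constant (a list standing in for a tensor) or a string naming
    # a graph input or buffer.
    op: str
    args: List[Any] = field(default_factory=list)


def lift_constants_to_buffers(nodes: List[Node], state_dict: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical lifting pass: replace each literal constant argument with a
    reference to a freshly named buffer and add that buffer to the state dict."""
    new_state = dict(state_dict)  # leave the caller's dict untouched
    counter = 0
    for node in nodes:
        for i, arg in enumerate(node.args):
            if isinstance(arg, list):  # treat list literals as "constant tensors"
                name = f"_lifted_constant{counter}"  # hypothetical naming scheme
                counter += 1
                new_state[name] = arg
                node.args[i] = name  # the node now reads the buffer instead
    return new_state


# Usage: a graph computing roughly `x + torch.tensor(0)`.
graph = [Node("add", ["x", [0.0]])]
params = {"linear.weight": [[1.0]]}
lifted = lift_constants_to_buffers(graph, params)
print(sorted(lifted))  # ['_lifted_constant0', 'linear.weight']
```

The extra `_lifted_constant0` key is exactly the kind of implementation-chosen name the issue is concerned about.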

This is potentially problematic: the state dict is a well-established abstraction in PyTorch, and mutating it because of export implementation details may surprise users. As a concrete example, many of our users guard model compatibility based on state dict keys. If export adds a key that depends on details of how tracing happened, those compatibility checks will break and users will be confused as to why.
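A minimal sketch of the kind of key-based compatibility guard described above, assuming the user compares the exact set of state-dict keys. The `state_dict_compatible` helper and the `_lifted_constant0` key name are hypothetical; plain lists stand in for tensors so the example runs without PyTorch.

```python
from typing import Any, Dict, Iterable


def state_dict_compatible(expected_keys: Iterable[str], state_dict: Dict[str, Any]) -> bool:
    """Hypothetical user-side guard: a state dict is treated as compatible only
    if its keys exactly match the keys the user expects."""
    return set(expected_keys) == set(state_dict)


# Keys the user recorded from the eager model.
expected = {"linear.weight", "linear.bias"}

eager_state = {"linear.weight": [[1.0]], "linear.bias": [0.0]}
# The same model after export, if a constant had been lifted into a buffer
# under an implementation-chosen name (the name here is hypothetical).
exported_state = dict(eager_state, _lifted_constant0=[0.0])

print(state_dict_compatible(expected, eager_state))     # True
print(state_dict_compatible(expected, exported_state))  # False
print(set(exported_state) - expected)                   # {'_lifted_constant0'}
```

Nothing about the model's parameters changed, yet the guard rejects the exported state dict purely because of the extra key.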

Instead, we should handle tensor constants in the tensor specification. Nodes should be able to take constant tensors as inputs, and we should have a table mapping constant tensor handles to in-archive serialized blobs. The serialization/deserialization systems will need to be updated to properly load/save these tensors.
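A minimal sketch of the proposed separation, assuming a zip-based archive and using only the standard library. The `{"const": handle}` argument encoding, the `constants/<handle>.bin` layout, the JSON graph and state-dict entries, and the float64 `array` payload are all illustrative assumptions, not the actual export serialization format. Nodes refer to constants by handle, the constant table lives in its own archive entries, and the state dict is stored exactly as the user defined it.

```python
import io
import json
import zipfile
from array import array
from typing import Dict, List, Tuple


def save_archive(graph: List[dict], state_dict: Dict[str, List[float]],
                 constants: Dict[str, List[float]]) -> bytes:
    """Write graph, state dict, and constant table as separate archive entries."""
    buf = io.BytesIO()
    with zipfile.ZipFile(buf, "w") as zf:
        zf.writestr("graph.json", json.dumps(graph))
        zf.writestr("state_dict.json", json.dumps(state_dict))
        # Each constant becomes its own binary blob, keyed by its handle.
        for handle, values in constants.items():
            zf.writestr(f"constants/{handle}.bin", array("d", values).tobytes())
    return buf.getvalue()


def load_archive(data: bytes) -> Tuple[list, dict, dict]:
    """Read the archive back, rebuilding the constant table from the blobs."""
    with zipfile.ZipFile(io.BytesIO(data)) as zf:
        graph = json.loads(zf.read("graph.json"))
        state_dict = json.loads(zf.read("state_dict.json"))
        constants = {}
        for name in zf.namelist():
            if name.startswith("constants/") and name.endswith(".bin"):
                handle = name[len("constants/"):-len(".bin")]
                blob = array("d")
                blob.frombytes(zf.read(name))
                constants[handle] = blob.tolist()
    return graph, state_dict, constants


# Usage: the "add" node takes a constant tensor input by handle "c0".
graph = [{"op": "add", "args": ["x", {"const": "c0"}]}]
state = {"linear.weight": [1.0, 2.0]}
consts = {"c0": [0.0]}

g2, s2, c2 = load_archive(save_archive(graph, state, consts))
print(sorted(s2))  # ['linear.weight']  (no constant leaked into the state dict)
print(c2)          # {'c0': [0.0]}
```

Keeping constants in a separate table means the round trip preserves the user's state-dict keys exactly, which is what the key-based compatibility checks rely on.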

cc @avikchaudhuri @gmagogsfm @zhxchen17 @tugsbayasgalan @angelayi

Metadata

Labels

export-triaged: This tag is used to tag issues that have been looked at by the PT2 Export team and for which the next step has been determined.
triaged: This issue has been looked at by a team member, and triaged and prioritized into an appropriate module.

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests