Skip to content

Relax int4wo device mismatch error#2254

Merged
andrewor14 merged 1 commit into
mainfrom
relax-int4wo-device-error
May 23, 2025
Merged

Relax int4wo device mismatch error#2254
andrewor14 merged 1 commit into
mainfrom
relax-int4wo-device-error

Conversation

@andrewor14

Copy link
Copy Markdown
Contributor

Summary: We have an guard preventing users from using a cuda quantized on cpu and vice versa. However, this also prevents users who load their checkpoints on cpu first and then move them to cuda later, which is what torchtune does:

quantize_(model.cuda(), Int4WeightOnlyConfig())
# save checkpoint in cuda
torch.save(model.state_dict(), "my_checkpoint.pt")
# load checkpoint on cpu
# This is what torchtune does: https://github.com/pytorch/torchtune/blob/v0.6.1/torchtune/training/checkpointing/_utils.py#L253
sd = torch.load("my_checkpoint.pt", weights_only=False, map_location="cpu")
# move checkpoint to cuda
for k, v in sd.items():
    sd[k] = v.to("cuda")
# load state_dict in cuda
model.load_state_dict(sd, assign=True)

This use case is safe in that the model was quantized in cuda and ultimately used on cuda. This commit relaxes the error to allow the above use case. More details here: #1117.

Test Plan:
python test/quantization/test_quant_api.py -k test_int4wo_cuda_serialization

@pytorch-bot

pytorch-bot Bot commented May 23, 2025

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2254

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit c31eb1d with merge base 4c6188f (image):
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label May 23, 2025
**Summary:** We have an guard preventing users from using a
cuda quantized on cpu and vice versa. However, this also
prevents users who load their checkpoints on cpu first and
then move them to cuda later, which is what torchtune does:

```
quantize_(model.cuda(), Int4WeightOnlyConfig())
# save checkpoint in cuda
torch.save(model.state_dict(), "my_checkpoint.pt")
# load checkpoint on cpu
# This is what torchtune does: https://github.com/pytorch/torchtune/blob/v0.6.1/torchtune/training/checkpointing/_utils.py#L253
sd = torch.load("my_checkpoint.pt", weights_only=False, map_location="cpu")
# move checkpoint to cuda
for k, v in sd.items():
    sd[k] = v.to("cuda")
# load state_dict in cuda
model.load_state_dict(sd, assign=True)
```

This use case is safe in that the model was quantized in
cuda and ultimately used on cuda. This commit relaxes the
error to allow the above use case. More details here:
#1117.

**Test Plan:**
python test/quantization/test_quant_api.py -k test_int4wo_cuda_serialization
@andrewor14 andrewor14 force-pushed the relax-int4wo-device-error branch from 99df5c1 to c31eb1d Compare May 23, 2025 19:32
@andrewor14 andrewor14 added the topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories) label May 23, 2025

@jerryzh168 jerryzh168 left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

I think we also want to support quantizing in CPU directly, but this is a good first step

@andrewor14 andrewor14 merged commit a776b1f into main May 23, 2025
20 checks passed
# cpu and cuda device: https://github.com/pytorch/ao/issues/1117
if not is_device(torch.device(self.device).type, device):
raise ValueError(
logging.warning(

@jerryzh168 jerryzh168 May 27, 2025

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@andrewor14 currently it will print a lot of warning

maybe we can just remove this? since it's supported now.

or we can change it to something like warn_once

liangel-02 pushed a commit that referenced this pull request Aug 25, 2025
**Summary:** We have an guard preventing users from using a
cuda quantized on cpu and vice versa. However, this also
prevents users who load their checkpoints on cpu first and
then move them to cuda later, which is what torchtune does:

```
quantize_(model.cuda(), Int4WeightOnlyConfig())
# save checkpoint in cuda
torch.save(model.state_dict(), "my_checkpoint.pt")
# load checkpoint on cpu
# This is what torchtune does: https://github.com/pytorch/torchtune/blob/v0.6.1/torchtune/training/checkpointing/_utils.py#L253
sd = torch.load("my_checkpoint.pt", weights_only=False, map_location="cpu")
# move checkpoint to cuda
for k, v in sd.items():
    sd[k] = v.to("cuda")
# load state_dict in cuda
model.load_state_dict(sd, assign=True)
```

This use case is safe in that the model was quantized in
cuda and ultimately used on cuda. This commit relaxes the
error to allow the above use case. More details here:
#1117.

**Test Plan:**
python test/quantization/test_quant_api.py -k test_int4wo_cuda_serialization
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: improvement Use this tag if this PR is an improvement (doesn't fit into any of the other categories)

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants