
Fix torch.full scalar type #7010

Merged

JackCaoG merged 4 commits into master from JackCaoG/fix_full_dtype on May 1, 2024
Conversation

@JackCaoG (Collaborator) commented on May 1, 2024

This should fix #6991.

torch.full takes a scalar fill value, and dtype is an optional parameter. When dtype is not specified, we should infer the tensor's dtype from the scalar's type.

Without this change, torch.full((2, 2), False) returns a tensor with dtype float32 instead of bool.
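The intended inference rule can be sketched in plain Python. This is a simplified illustration of the behavior described above, not the actual torch_xla implementation; the helper name `infer_full_dtype` and the returned dtype names are assumptions for this sketch (PyTorch's defaults are int64 for Python ints and float32 for Python floats).

```python
def infer_full_dtype(fill_value, dtype=None):
    """Sketch of the dtype torch.full should pick for `fill_value`.

    Hypothetical helper, not part of torch or torch_xla.
    """
    if dtype is not None:
        return dtype            # an explicit dtype always wins
    # bool must be checked before int: in Python, bool is a subclass of int
    if isinstance(fill_value, bool):
        return "bool"
    if isinstance(fill_value, int):
        return "int64"          # PyTorch's default integer dtype
    if isinstance(fill_value, float):
        return "float32"        # PyTorch's default floating dtype
    raise TypeError(f"unsupported fill value: {fill_value!r}")

print(infer_full_dtype(False))          # bool, not float32 -- the bug fixed here
print(infer_full_dtype(1))              # int64
print(infer_full_dtype(1.0))            # float32
print(infer_full_dtype(0, "float16"))   # explicit dtype wins
```

The `isinstance(fill_value, bool)` check has to come first; otherwise `False` would match the `int` branch and be given an integer dtype.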

@JackCaoG JackCaoG requested review from lsy323 and wonjoo-wj May 1, 2024 21:03
@JackCaoG JackCaoG merged commit 0a54b2b into master May 1, 2024


Development

Successfully merging this pull request may close these issues.

GPT2 CasualLM Inference crashes when using transformers v4.39.0
