Skip to content

[float8] Allow specifying arbitrary dtype for each tensor#1326

Draft
lw wants to merge 7 commits into
gh/lw/2/basefrom
gh/lw/2/head
Draft

[float8] Allow specifying arbitrary dtype for each tensor#1326
lw wants to merge 7 commits into
gh/lw/2/basefrom
gh/lw/2/head

Conversation

@lw

@lw lw commented Nov 22, 2024

Copy link
Copy Markdown
Contributor

[ghstack-poisoned]
lw added a commit that referenced this pull request Nov 22, 2024
@pytorch-bot

pytorch-bot Bot commented Nov 22, 2024

Copy link
Copy Markdown

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1326

Note: Links to docs will display an error until the docs builds have been completed.

❌ 1 New Failure

As of commit 97c9983 with merge base 1a0dbf1 (image):

NEW FAILURE - The following job has failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot facebook-github-bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Nov 22, 2024
[ghstack-poisoned]
lw added a commit that referenced this pull request Nov 22, 2024
@lw lw added the topic: new feature Use this tag if this PR adds a new feature label Nov 22, 2024
[ghstack-poisoned]
lw added a commit that referenced this pull request Nov 22, 2024
[ghstack-poisoned]
lw added a commit that referenced this pull request Nov 22, 2024
[ghstack-poisoned]
lw added a commit that referenced this pull request Nov 22, 2024
[ghstack-poisoned]
lw added a commit that referenced this pull request Nov 22, 2024
Comment thread torchao/float8/config.py
scaling_type: ScalingType = ScalingType.DYNAMIC
scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE
static_scale: Optional[torch.Tensor] = None
dtype: Optional[torch.dtype] = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

  1. can we add a comment on what this is used for, and that None means the default e4m3|e5m2 value will be used?
  2. optional - thoughts about naming this in a more specific way such as target_dtype, lowp_dtype, etc? dtype is a bit ambiguous across torchao unfortunately :(

Comment thread torchao/float8/config.py
# grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise
cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE)
cc_go = CastConfig(
scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe we can also add some context in the comments on L353:L363 that it also uses e4m3 for grads?

NoopFwToFloat8E5M2BwDelayed,
NoopFwToFloat8E5M2BwDynamic,
NoopFwToFloat8E5M2BwStatic,
NoopFwToFloat8BwDelayed,

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for updating these!

Comment thread torchao/float8/float8_linear_utils.py Outdated
# Calculate the new scales from the updated history stacks
new_input_scales = amax_history_to_scale_stack(
fp8_input_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe
fp8_input_amax_history_stack, input_dtype, x_dtype, scale_fn_recipe

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

will likely have to rebase on top of #1329 which changed this line

Comment thread torchao/float8/config.py
static_scale: Optional[torch.Tensor] = None
dtype: Optional[torch.dtype] = None

def short_str(self):

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can we also add the dtype here, so it appears when we print an instance of Float8Linear? Float8Linear.__extra_repr__ calls this method.

@vkuzo

vkuzo commented Nov 26, 2024

Copy link
Copy Markdown
Contributor

This is great! LGTM, had some comments but all are pretty nitty. CI is green - ship it!

[ghstack-poisoned]
lw added a commit that referenced this pull request Dec 4, 2024
@lw

lw commented Dec 4, 2024

Copy link
Copy Markdown
Contributor Author

Superseded by #1378

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. topic: new feature Use this tag if this PR adds a new feature

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants