[float8] Allow specifying arbitrary dtype for each tensor#1326
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1326
Note: Links to docs will display an error until the docs builds have been completed. ❌ 1 New FailureAs of commit 97c9983 with merge base 1a0dbf1 ( NEW FAILURE - The following job has failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
| scaling_type: ScalingType = ScalingType.DYNAMIC | ||
| scaling_granularity: ScalingGranularity = ScalingGranularity.TENSORWISE | ||
| static_scale: Optional[torch.Tensor] = None | ||
| dtype: Optional[torch.dtype] = None |
There was a problem hiding this comment.
nit:
- can we add a comment on what this is used for, and that
Nonemeans the default e4m3|e5m2 value will be used? - optional - thoughts about naming this in a more specific way such as
target_dtype,lowp_dtype, etc?dtypeis a bit ambiguous across torchao unfortunately :(
| # grad_input_hp = grad_output_fp8_axiswise_dim0 @ weight_fp8_tensorwise | ||
| cc_go = CastConfig(scaling_granularity=ScalingGranularity.AXISWISE) | ||
| cc_go = CastConfig( | ||
| scaling_granularity=ScalingGranularity.AXISWISE, dtype=e4m3_dtype |
There was a problem hiding this comment.
nit: maybe we can also add some context in the comments on L353:L363 that it also uses e4m3 for grads?
| NoopFwToFloat8E5M2BwDelayed, | ||
| NoopFwToFloat8E5M2BwDynamic, | ||
| NoopFwToFloat8E5M2BwStatic, | ||
| NoopFwToFloat8BwDelayed, |
| # Calculate the new scales from the updated history stacks | ||
| new_input_scales = amax_history_to_scale_stack( | ||
| fp8_input_amax_history_stack, e4m3_dtype, x_dtype, scale_fn_recipe | ||
| fp8_input_amax_history_stack, input_dtype, x_dtype, scale_fn_recipe |
There was a problem hiding this comment.
will likely have to rebase on top of #1329 which changed this line
| static_scale: Optional[torch.Tensor] = None | ||
| dtype: Optional[torch.dtype] = None | ||
|
|
||
| def short_str(self): |
There was a problem hiding this comment.
can we also add the dtype here, so it appears when we print an instance of Float8Linear? Float8Linear.__extra_repr__ calls this method.
|
This is great! LGTM, had some comments but all are pretty nitty. CI is green - ship it! |
|
Superseded by #1378 |
Stack from ghstack (oldest at bottom):