feat: support group_norm, batch_norm, and layer_norm#2330
gs-olive merged 7 commits into pytorch:main
Conversation
gs-olive
left a comment
Updates look great - added some suggestions to better follow the Torch schemas for these functions
if weight is None:
    weight = np.array(1.0)

if bias is None:
    bias = np.array(0.0)

if running_mean is None:
    running_mean = np.array(0.0)

if running_var is None:
    running_var = np.array(1.0)
For these, it should be okay to not cast to np.array in the converter (instead leave them as ints or floats), since to_numpy should dictate this casting behavior for ints and floats. Specifically, one small difference is that I think np.array(1.0) has shape () (0D), but to_numpy generally adds a dimension, to make it 1D.
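To make the shape difference concrete, here is a minimal numpy-only sketch (the behavior of the project's to_numpy helper is assumed from the description above, not quoted from its source):

import numpy as np

scalar_default = np.array(1.0)
print(scalar_default.shape)    # () -- a 0-D array with no dimensions

# to_numpy is described above as promoting plain Python floats to 1-D arrays,
# roughly equivalent to:
promoted_default = np.atleast_1d(np.asarray(1.0))
print(promoted_default.shape)  # (1,) -- a 1-D array with a single element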
if weight is None:
    weight = np.array(1.0)

if bias is None:
    bias = np.array(0.0)
Since line 189 is shape = weight.shape and lines 191 and 192 call weight.reshape and bias.reshape, I think weight and bias shouldn't be scalars.
I see - in that case, it might be preferable to use to_numpy(0.0), for instance, to get back a default-formatted numpy array for the float default. Additionally, I noticed the code below has some issues:
gamma = to_numpy(weight.reshape(*shape))

Above is invalid, since the reshape should apply to the numpy output. It should instead be:

gamma = to_numpy(weight).reshape(shape)

The same as the above applies for beta.
Additionally, lines 194 - 196 should be using get_axes_for_reduce_op, as here:
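As a rough standalone illustration of the reduce-axes bitmask such a helper produces (hypothetical code for exposition, not the actual get_axes_for_reduce_op implementation or the code the comment links to):

# TensorRT reduce/normalization layers take their reduction dimensions as a
# bitmask in which dimension d contributes bit (1 << d).
def axes_bitmask(dims):
    mask = 0
    for d in dims:
        mask |= 1 << d
    return mask

# e.g. normalizing over the H and W dimensions of an NCHW tensor:
print(bin(axes_bitmask([2, 3])))  # 0b1100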
weight: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
bias: Optional[Union[TRTTensor, torch.Tensor, np.ndarray]],
TRTTensor would not be a valid input here, for the scale layer
Do you mean the type of weight and bias in all three functions should be Optional[Union[torch.Tensor, np.ndarray]]? I see its native function schema:
func: layer_norm(Tensor input, SymInt[] normalized_shape, Tensor? weight=None, Tensor? bias=None, float eps=1e-05, bool cudnn_enable=True) -> Tensor
Yes, I think it should be Optional[Union[torch.Tensor, np.ndarray]], because if either of those is a TRTTensor, the computation below would not work (to_numpy can't be called on a TRTTensor)
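A minimal sketch of the narrowed annotations being agreed on here (the function name and parameter list are illustrative; the real converters take additional network/target/name arguments that are omitted):

from typing import Optional, Union

import numpy as np
import torch

def layer_norm_converter(  # illustrative name, not the converter's real signature
    weight: Optional[Union[torch.Tensor, np.ndarray]] = None,
    bias: Optional[Union[torch.Tensor, np.ndarray]] = None,
    eps: float = 1e-5,
) -> None:
    # TRTTensor is intentionally absent from the weight/bias unions, since
    # to_numpy cannot be called on a TRTTensor.
    pass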
As discussed, add group_norm as well.

@gs-olive group_norm was added!
Force-pushed from fd820e6 to 4aa4dce
gs-olive
left a comment
Added a few comments. Additionally, if the dynamic shape version of this converter is not passing, that is okay since it is not required for the first pass of support
scale = cast(torch.Tensor, to_numpy(weight)) / np.sqrt(
-    cast(torch.Tensor, to_numpy(running_var)) + cast(float, eps)
+    cast(torch.Tensor, to_numpy(running_var)) + eps
The torch.Tensor cast can be removed, because to_numpy will return an np.ndarray, so this typing would be incorrect.
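A minimal sketch of the same computation with the casts dropped, using numpy stand-ins for the to_numpy(...) results (the shift fold is assumed from standard batch-norm algebra and is not quoted from the PR):

import numpy as np

# Stand-ins for to_numpy(weight), to_numpy(running_var), etc., which already
# return np.ndarray, so no torch.Tensor cast is needed.
weight = np.ones(4, dtype=np.float32)
bias = np.zeros(4, dtype=np.float32)
running_mean = np.zeros(4, dtype=np.float32)
running_var = np.ones(4, dtype=np.float32)
eps = 1e-5

scale = weight / np.sqrt(running_var + eps)
# Assumed companion fold for the shift term (standard batch-norm algebra):
shift = bias - running_mean * scale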
eps_field = trt.PluginField(
    "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
)
num_groups_filed = trt.PluginField(
    "num_groups", np.array(num_groups), trt.PluginFieldType.INT32
)

field_collection = trt.PluginFieldCollection([eps_field, num_groups_filed])

try:
    # Here's the schema of the plugin:
    # https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml
    plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1")
except AssertionError:
    _LOGGER.error(
        "Unable to find group norm plugin, fall back to TensorRT implementation."
    )

layer = network.add_plugin_v2([input, scale, bias], plugin)
set_layer_name(layer, target, f"{name}_GroupNormalizationPlugin", source_ir)

# PyTorch requires three return values: (out, mean, rstd)
dummy_tensor = torch.tensor(0)
return layer.get_output(0), dummy_tensor, dummy_tensor
Is it possible to avoid invoking the plugin here, and instead use the full implementation, adapting from here: https://github.com/NVIDIA-AI-IOT/torch2trt/blob/36656b614f3fbc067ac673932e2200d7afdae712/torch2trt/converters/group_norm.py#L7-L73? The plugin is not preferable for use in new converters unless it cannot be otherwise supported.
Alternatively, the TRT layer-based implementation can be the backup for the plugin, etc.
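For reference, a minimal numpy sketch of what a layer-based group_norm implementation has to compute, including the mean and rstd outputs discussed below (illustrative only, not the torch2trt converter linked above):

import numpy as np

def group_norm_reference(x, num_groups, weight, bias, eps=1e-5):
    # x has shape (N, C, H, W); statistics are taken per (sample, group).
    n, c, h, w = x.shape
    grouped = x.reshape(n, num_groups, (c // num_groups) * h * w)
    mean = grouped.mean(axis=-1, keepdims=True)
    var = grouped.var(axis=-1, keepdims=True)
    rstd = 1.0 / np.sqrt(var + eps)
    normalized = ((grouped - mean) * rstd).reshape(n, c, h, w)
    # The per-channel affine parameters are applied after normalization.
    out = normalized * weight.reshape(1, c, 1, 1) + bias.reshape(1, c, 1, 1)
    return out, mean.reshape(n, num_groups), rstd.reshape(n, num_groups)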
eps_field = trt.PluginField(
    "eps", np.array(eps, dtype=np.float32), trt.PluginFieldType.FLOAT32
)
num_groups_filed = trt.PluginField(
    "num_groups", np.array(num_groups), trt.PluginFieldType.INT32
)

field_collection = trt.PluginFieldCollection([eps_field, num_groups_filed])

try:
    # Here's the schema of the plugin:
    # https://github.com/NVIDIA/TensorRT/blob/release/8.6/plugin/groupNormalizationPlugin/GroupNormalizationPlugin_PluginConfig.yaml
    plugin = get_trt_plugin("GroupNormalizationPlugin", field_collection, "1")
except AssertionError:
    _LOGGER.error(
        "Unable to find group norm plugin, fall back to TensorRT implementation."
    )

layer = network.add_plugin_v2([input, scale, bias], plugin)
set_layer_name(layer, target, f"{name}_GroupNormalizationPlugin", source_ir)

# PyTorch requires three return values: (out, mean, rstd)
dummy_tensor = torch.tensor(0)
return layer.get_output(0), dummy_tensor, dummy_tensor
The returned values here should be the actual intermediate tensors from the computation, unless we explicitly remove support for nodes which need the other two values.
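To show concretely what PyTorch expects from this op, here is a small check against the aten schema (arguments follow the documented native_group_norm signature; this is a verification sketch, not converter code):

import torch

x = torch.randn(2, 8, 4, 4)
weight = torch.ones(8)
bias = torch.zeros(8)

# Schema: native_group_norm(input, weight, bias, N, C, HxW, group, eps)
out, mean, rstd = torch.ops.aten.native_group_norm(x, weight, bias, 2, 8, 16, 4, 1e-5)
print(out.shape, mean.shape, rstd.shape)
# torch.Size([2, 8, 4, 4]) torch.Size([2, 4]) torch.Size([2, 4])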
)


@dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default)  # type: ignore[misc]
Based on the schema of native_layer_norm, it looks like it requires 3 outputs much like native_group_norm. As a comment on both of those - if you want to support it with essentially the same converter as the regular layer norm, you can do the following:
Add this validator:

def validator(layer_norm: Node) -> bool:
    # Validate only one user, which is a getitem node that accesses the first element in the list
    return (
        len(layer_norm.users) == 1
        and list(layer_norm.users)[0].target == operator.getitem
        and list(layer_norm.users)[0].args[1] == 0
    )

Add this converter:

@dynamo_tensorrt_converter(torch.ops.aten.native_layer_norm.default, capability_validator=validator)
def converter(...):
    return (regular_layer_norm,)

It is important that the above converter returns a tuple, because it will be accessed by getitem, but as you have validated, it will only access the first element. This should also work for group norm.
@zewenli98 - when you have the chance, please rebase this PR to the latest main

Yes! It's still in progress. Thanks for the reminder!
Force-pushed from 8a41cf9 to 4f585d8
support group norm, and improve batch and layer norms
Force-pushed from f628c0c to 84b58dd
gs-olive
left a comment
Looks good to me - will update again pending a manual check against SD
gs-olive
left a comment
Works on SD - looks good to me!
Description
Update batch_norm and layer_norm.

Fixes #2225
Type of change
Checklist: