Updating BatchNorm to add training_mode #3333
Conversation
Signed-off-by: neginraoof <neginmr@utexas.edu>
Force-pushed from ff21361 to 9641524
Signed-off-by: neginraoof <neginmr@utexas.edu>
Force-pushed from 6540d14 to 2c77431
@gramalingam @askhade @jcwchen @hariharans29
Thank you for redoing this change! Backend test data looks good to me.
I don't have the history context here -- last time this change was reverted due to a lack of verification. Has it been verified now? It would be great if the training team could review it again before check-in.
<dl>
<dt><tt>Y</tt> (differentiable) : T</dt>
<dd>The output tensor of the same shape as X</dd>
<dt><tt>mean</tt> (optional, non-differentiable) : T</dt>
Is it an error if the optional outputs are present and the training_mode input flows through as false (the default value)?
Agreed, we should explicitly indicate whether these extraneous outputs are allowed or not (in inference mode). If they are allowed, we need to specify what their values will be (or whether they are undefined).
@hariharans29 @gramalingam Do you have a suggestion here? Does it make sense to allow these but fill them with garbage values?
Sure - how about: "if training_mode is false and the optional training-related outputs are present, their contents are undefined..."?
('b', TensorProto.FLOAT, (4,)),
('mean', TensorProto.FLOAT, (4,)),
('var', TensorProto.FLOAT, (4,)),
('training_mode', TensorProto.BOOL, ())],
How about adding one test without the optional training_mode input?
Such a test already exists at line 1239 for batch_norm.
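For illustration, a minimal sketch of such an inference-only case (assuming the same `_make_graph`/`_assert_inferred` helpers used by the neighboring tests; the test name and input shapes here are only illustrative):

```python
def test_batch_normalization_inference_mode(self):  # hypothetical name
    # Same setup as the training-mode test, but without the optional
    # training_mode input and with only the single required output Y.
    graph = self._make_graph(
        [('x', TensorProto.FLOAT, (3, 4, 5, 6)),
         ('scale', TensorProto.FLOAT, (4,)),
         ('b', TensorProto.FLOAT, (4,)),
         ('mean', TensorProto.FLOAT, (4,)),
         ('var', TensorProto.FLOAT, (4,))],
        [make_node('BatchNormalization',
                   ['x', 'scale', 'b', 'mean', 'var'], ['out'])],
        [])
    self._assert_inferred(
        graph,
        [make_tensor_value_info('out', TensorProto.FLOAT, (3, 4, 5, 6))])
```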
auto& scale_input_shape = getInputShape(ctx, 1);
if (static_cast<int>(scale_input_shape.dim_size()) != 1 ||
    !scale_input_shape.dim(0).has_dim_value() ||
    static_cast<int>(scale_input_shape.dim(0).dim_value()) !=
Why throw if it doesn't have a dim_value? What if it has a dim_param? Should we check whether it matches the dim_param of the channel index in the input shape (if that is a dim_param too)?
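To illustrate the suggested check, a rough Python sketch of comparing the scale dimension against the channel dimension symbolically (the real check would live in the C++ shape-inference code quoted below; `dims_compatible` is a hypothetical helper):

```python
def dims_compatible(scale_dim, channel_dim):
    """Compare two TensorShapeProto.Dimension values symbolically."""
    # Both concrete: the values must agree.
    if scale_dim.HasField('dim_value') and channel_dim.HasField('dim_value'):
        return scale_dim.dim_value == channel_dim.dim_value
    # Both symbolic: the dim_params must agree.
    if scale_dim.HasField('dim_param') and channel_dim.HasField('dim_param'):
        return scale_dim.dim_param == channel_dim.dim_param
    # Otherwise at least one side is unknown; don't fail shape inference.
    return True
```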
('var', TensorProto.FLOAT, (4,)),
('training_mode', TensorProto.BOOL, ())],
[make_node('BatchNormalization', ['x', 'scale', 'b', 'mean', 'var', 'training_mode'],
           ['out', 'output_mean', 'output_var', 'saved_mean', 'saved_var'])],
How about adding more tests to check the robustness of the shape inference method:
- Shapes with dim_params (along with dim_values)
- Shapes with no dim_values and no dim_params
- Missing shapes for some inputs
What is the difference between 2 and 3?
2 -> shapes are present for the inputs, but they contain no dim_value or dim_param.
3 -> shapes are missing altogether for certain inputs.
Does this make sense?
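As a rough illustration of the three cases with `onnx.helper` (names are arbitrary; this reflects my understanding of how `make_tensor_value_info` handles each shape form):

```python
from onnx import TensorProto, helper

# 1. Shape with a symbolic dim_param alongside concrete dim_values.
x_sym = helper.make_tensor_value_info('x', TensorProto.FLOAT, ['N', 4, 5, 6])

# 2. Shape is present, but its dimensions carry neither dim_value nor dim_param.
x_unknown = helper.make_tensor_value_info('x', TensorProto.FLOAT, [None, None, None, None])

# 3. Shape is missing altogether (passing shape=None leaves tensor_type.shape unset).
x_no_shape = helper.make_tensor_value_info('x', TensorProto.FLOAT, None)
```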
Y = (X - mean) / sqrt(var + epsilon) * scale + B
```

For previous (deprecated) non-spatial cases, implementors are suggested
I don't understand this either. If the inputs are required to have some shape, can we explicitly document it here? Is it required to be 2D?
I think this is a note about how to get the old (deprecated) behavior for the op. There was an attribute in older versions:
'Spatial': If true, compute the mean and variance across per activation. If false, compute the mean and variance across per feature over each mini-batch.
And I think this note specifies how to achieve this behavior.
Input X can have N dimensions.
I can clarify the lines above.
I think input X is expected to have at least rank 2. In the most common case the input will be NCHW, but, as noted in the comment, you can flatten it to N x CHW to replicate the older non-spatial behavior.
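As a concrete illustration of that flattening (a numpy sketch, not taken from the spec text):

```python
import numpy as np

x = np.random.randn(2, 3, 4, 5).astype(np.float32)  # N, C, H, W

# Spatial (current) behavior: statistics are per channel C,
# i.e. reduced over N, H and W.
spatial_mean = x.mean(axis=(0, 2, 3))       # shape (C,) = (3,)

# Old non-spatial behavior: flatten to N x (C*H*W) first, so statistics
# are per activation, i.e. reduced over N only.
x_flat = x.reshape(x.shape[0], -1)          # shape (2, 60)
non_spatial_mean = x_flat.mean(axis=0)      # shape (C*H*W,) = (60,)
```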
Force-pushed from 2208eac to 75e2c7a
Signed-off-by: neginraoof <neginmr@utexas.edu>
Force-pushed from 75e2c7a to bc91f64
Signed-off-by: neginraoof <neginmr@utexas.edu>
if (ctx.getNumInputs() > 5 && hasInputShape(ctx, 5)) {
  auto& mode_input_shape = getInputShape(ctx, 5);
  // if mode is not scalar or tensor of rank 1, fail shape inference
I think that for new op extensions, it is preferable to restrict scalars to rank 0.
@postrational what do you think? This is one of the "unified op interface" items. Traditionally there has been some divergence: some ops use or allow rank-1 tensors of size 1 for a scalar, while others require rank-0 tensors.
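For reference, the two forms being contrasted (a small sketch with numpy and onnx.helper; not something this PR mandates):

```python
import numpy as np
from onnx import TensorProto, helper

# Rank-0 "true" scalar: shape () -- the preferred form for new op extensions.
mode_rank0_info = helper.make_tensor_value_info('training_mode', TensorProto.BOOL, ())
mode_rank0 = np.array(True)        # ndim == 0

# Rank-1 tensor of size 1: shape (1,) -- allowed by some older ops.
mode_rank1_info = helper.make_tensor_value_info('training_mode', TensorProto.BOOL, (1,))
mode_rank1 = np.array([True])      # ndim == 1
```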
current_mean = ReducedMean(X, axis=all_except_channel_index)
current_var = ReducedVar(X, axis=all_except_channel_index)

running_mean = mean * momentum + current_mean * (1 - momentum)
Sorry - still not sure: should this be input_mean and input_var instead of mean and var?
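To make the intended semantics concrete, here is a small numpy sketch of the training-mode computation as written in the quoted formulas (it uses `input_mean`/`input_var` for the mean/var inputs, which is the renaming being discussed; axis 1 is assumed to be the channel axis, and the momentum/epsilon values are just examples):

```python
import numpy as np

def batch_norm_training(x, scale, b, input_mean, input_var,
                        momentum=0.9, epsilon=1e-5):
    # Reduce over every axis except the channel axis (axis 1 for NCHW).
    reduce_axes = tuple(i for i in range(x.ndim) if i != 1)
    current_mean = x.mean(axis=reduce_axes)
    current_var = x.var(axis=reduce_axes)

    # Running statistics, following the formulas quoted above.
    running_mean = input_mean * momentum + current_mean * (1 - momentum)
    running_var = input_var * momentum + current_var * (1 - momentum)

    # Normalize with the current batch statistics.
    bshape = [1, -1] + [1] * (x.ndim - 2)   # broadcast per-channel values
    y = (x - current_mean.reshape(bshape)) / np.sqrt(current_var.reshape(bshape) + epsilon)
    y = y * scale.reshape(bshape) + b.reshape(bshape)
    return y, running_mean, running_var, current_mean, current_var
```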
Output case #1: Y, running_mean, running_var, current_mean, current_var (training_mode=True)
Output case #2: Y (training_mode=False)

When training_mode=False, extra outputs are undefined and the user should not depend on those.
A note about this as it relates to the ORT implementation: if the extra outputs are specified but ignored by the ORT kernel, it is possible that the memory for these output tensors is never allocated. This can cause subsequent errors from use of unallocated memory (depending on how the memory planner reuses these buffers). So it may be preferable to flag the use of extra outputs as an error when training_mode=False. It would be helpful to get some input from the ORT side on this. (@askhade any thoughts?)
Here is the ORT implementation: https://github.com/microsoft/onnxruntime/blob/8892ee4b6d343109699ab292e66c2c7a5e41925a/onnxruntime/core/providers/cpu/nn/batch_norm.h#L78
I don't think the backend needs to access the memory locations of such outputs during inference.
@neginraoof While the batch_norm kernel doesn't access the optional tensors in inference mode (at least as I've implemented it in the linked PR), I believe the concern was over the allocation planner, which might try to re-use a buffer that hasn't been allocated properly. I'm in favor of making it invalid to use the extra outputs in inference mode, since it reduces edge cases and is generally confusing anyway (why introduce the notion of "undefined behavior" when you can just prevent it entirely?).
Btw the discussion in #1042 might be relevant here. I think the consensus (at least for n=2, @gramalingam and I) was that there is not really any use case for the optional outputs during inference. While this means we could technically change the spec so that training vs. inference mode is determined purely by the presence of these optional outputs, I don't think changing the behavior of an op based on the number of outputs is as clean as just having an attribute for it.
One challenge is that the "mode" is an input (as opposed to an attribute). I guess this was chosen to allow a training engine to use the same graph/node but change the mode dynamically? If we want this ability, then we will have to allow the multiple outputs to be specified. This also means that the allocation-planner buffer-reuse issue mentioned above is a real problem, and ORT will need to solve it somehow. A trivial fix would be for the kernel to ensure that the output is allocated by calling "Output" even if its value is not initialized, but this would mean some unnecessary memory allocation in some situations.
Oh, I missed that mode was an input instead of an attribute. IIRC in the older (reverted) PR it was an attribute. @neginraoof can you clarify the motivation for it being an input?
> allow a training-engine to use the same graph/node, but change the mode dynamically
I don't think this is an issue, because ORT already has a pre-training graph-transform pass that can modify attributes or add outputs. Being able to change the mode by setting a graph input doesn't seem to provide any concrete advantage over just changing the attribute.
Yes, an attribute would simplify things. It would be great to clarify the motivation to have it as an input, thanks!
This is the revert PR: https://github.com/onnx/onnx/pull/2750/files
It shows that training_mode was already an input for this op before, and I kept it that way based on the old implementation. Also, training_mode is an input for the Dropout op. So do we want to keep it consistent?
@gramalingam @pranav-prakash
Ah, I see. If it were solely up to me, I think having it as an attribute is cleaner, and for consistency it's Dropout that should be updated to move training_mode to an attribute (various other ops have had inputs moved to attributes in the past, so this is not unprecedented, and it suggests to me that attributes are what schemas generally evolve towards anyway). But I'm just a user, so I'll defer to the others on whatever they think is better.
Btw the Dropout spec states that:
> If it [training_mode] is false... and if mask is requested as output it will contain all ones.
So if we do indeed choose to keep it as an input, explicitly specifying the value of the optional output might be better than leaving it as "undefined behavior".
Signed-off-by: neginraoof <neginmr@utexas.edu>
Signed-off-by: neginraoof <neginmr@utexas.edu>
@pranav-prakash
When training_mode=False, extra outputs are undefined and the user should not depend on those.
The outputs are updated as follows when training_mode=True:
```
current_mean = ReducedMean(X, axis=all_except_channel_index)
Do you mean ReduceMean instead of ReducedMean? (same for var)
Otherwise the formulas look good to me (I especially like the use of running_mean which is less confusing than the saved_mean that the older spec had).
Thanks. I can update this.
I updated the part for the optional outputs to mention that this is invalid. What do you think?
Nice! I think, in addition, the shape inference could be updated to enforce that the optional outputs are present only when training_mode=true.
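A rough Python sketch of that kind of validation at the model level (assuming training_mode ends up as an integer attribute, as discussed above; the actual enforcement would live in the C++ shape-inference/checker code):

```python
from onnx import ModelProto

def check_batch_norm_optional_outputs(model: ModelProto) -> None:
    for node in model.graph.node:
        if node.op_type != 'BatchNormalization':
            continue
        attrs = {a.name: a for a in node.attribute}
        training = bool(attrs['training_mode'].i) if 'training_mode' in attrs else False
        if not training and len(node.output) > 1:
            raise ValueError(
                "Node '%s': optional outputs are only valid when training_mode=1"
                % node.name)
```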
Sure. Will do.
@pranav-prakash Do you think you can help us test this spec with your ORT implementation? Just to check that the spec works fine.
Sure, I can do so (next week though; I'm a bit busy this week). Even with the change in spec, I don't think there would be any changes in my PR for ORT BatchNorm training (aside from checking the newly-added training_mode attribute).
(There's also a separate issue where the pre-existing CUDA BatchNorm training implementation in ORT deviates from the spec by outputting inverse standard deviation instead of variance; I copied that behavior so the tests could remain the same, but that's a discussion for that PR.)
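For context on that deviation, the two quantities are related as follows (a numpy sketch; the exact epsilon handling in the CUDA kernel may differ):

```python
import numpy as np

var = np.array([0.25, 1.0, 4.0], dtype=np.float32)
epsilon = 1e-5

inv_std = 1.0 / np.sqrt(var + epsilon)          # what the existing CUDA kernel outputs
recovered_var = 1.0 / (inv_std ** 2) - epsilon  # back to the spec's variance output
```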
Signed-off-by: neginraoof <neginmr@utexas.edu>
Signed-off-by: neginraoof <neginmr@utexas.edu>
Signed-off-by: neginraoof <neginmr@utexas.edu>
@postrational @linkerzhang

This PR is approved. Please resolve merge conflicts.
…xBatchNorm
# Conflicts:
#   docs/Changelog.md
Signed-off-by: neginraoof <neginmr@utexas.edu>
@postrational Thanks! Conflict is resolved.

@postrational CI looks fine. Please help us merge the PR. Thanks again.
Updating BatchNorm to add training_mode as an input.