[release/2.6] [SWDEV-535259] enable miopen channels last 3d for conv and batchnorm#2209
Conversation
|
I think we also need to change this file to update the NHWC check:
or else it won't try to use NHWC for MIOpen convs even with flag set it looks like. |
|
Not sure on the usual process, but we might need to merge this into 2.6 & 2.7 once ready. |
Not so, a few lines later the return statement does an or of this bool var and the 2d bool var which means 3d is enabled if 2d is enabled. All the other changes in this PR were the lines actually preventing channels last 3d. |
Won't the 2D check return false since the memory format is actually |
|
!cherry-pick --onto release/2.7 |
…and batchnorm (#2209) The same env vars PYTORCH_MIOPEN_SUGGEST_NHWC and PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM now also enable channels last 3d in addition to 2d.
|
Created branch autogenerated/release/2.7_cherry-pick_pr-2209 and #2232 |
Additive on top of #2209 3D batchhorm tests (NHWC3D and NCHW3D) NCHW 3D tests: ``` test_batchnorm_3D_inference_NCHW_vs_cpu_float32 (__main__.TestNN) ... ok (0.149s) test_batchnorm_3D_inference_NCHW_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.062s) test_batchnorm_3D_inference_NCHW_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.042s) test_batchnorm_3D_inference_NCHW_vs_native_float32 (__main__.TestNN) ... ok (0.091s) test_batchnorm_3D_inference_NCHW_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.008s) test_batchnorm_3D_inference_NCHW_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.007s) test_batchnorm_3D_inference_NHWC_vs_NCHW_float32 (__main__.TestNN) ... ok (0.028s) test_batchnorm_3D_inference_NHWC_vs_NCHW_mixed_bfloat16 (__main__.TestNN) ... ok (0.010s) test_batchnorm_3D_inference_NHWC_vs_NCHW_mixed_float16 (__main__.TestNN) ... ok (0.010s) test_batchnorm_3D_inference_NHWC_vs_cpu_float32 (__main__.TestNN) ... ok (0.091s) test_batchnorm_3D_inference_NHWC_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.020s) test_batchnorm_3D_inference_NHWC_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.023s) test_batchnorm_3D_inference_NHWC_vs_native_float32 (__main__.TestNN) ... ok (0.010s) test_batchnorm_3D_inference_NHWC_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.015s) test_batchnorm_3D_inference_NHWC_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.007s) test_batchnorm_3D_train_NCHW_vs_cpu_float32 (__main__.TestNN) ... ok (0.011s) test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.006s) test_batchnorm_3D_train_NCHW_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.006s) test_batchnorm_3D_train_NCHW_vs_native_float32 (__main__.TestNN) ... ok (0.004s) test_batchnorm_3D_train_NCHW_vs_native_mixed_bfloat16 (__main__.TestNN) ... skip: bfloat16 NCHW train failed due to native tolerance issue SWDEV-507600 (0.002s) test_batchnorm_3D_train_NCHW_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_3D_train_NHWC_vs_NCHW_float32 (__main__.TestNN) ... ok (0.006s) test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16 (__main__.TestNN) ... ok (0.006s) test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_float16 (__main__.TestNN) ... ok (0.006s) test_batchnorm_3D_train_NHWC_vs_cpu_float32 (__main__.TestNN) ... ok (0.004s) test_batchnorm_3D_train_NHWC_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_3D_train_NHWC_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_3D_train_NHWC_vs_native_float32 (__main__.TestNN) ... ok (0.011s) test_batchnorm_3D_train_NHWC_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_3D_train_NHWC_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.004s) ``` Old batchnorm tests will have `2D` it their names ``` test_batchnorm_2D_inference_NCHW_vs_cpu_float32 (__main__.TestNN) ... ok (0.023s) test_batchnorm_2D_inference_NCHW_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.005s) test_batchnorm_2D_inference_NCHW_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.005s) test_batchnorm_2D_inference_NCHW_vs_native_float32 (__main__.TestNN) ... ok (0.104s) test_batchnorm_2D_inference_NCHW_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_inference_NCHW_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.003s) test_batchnorm_2D_inference_NHWC_vs_NCHW_float32 (__main__.TestNN) ... ok (0.020s) test_batchnorm_2D_inference_NHWC_vs_NCHW_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_inference_NHWC_vs_NCHW_mixed_float16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_inference_NHWC_vs_cpu_float32 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_inference_NHWC_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.003s) test_batchnorm_2D_inference_NHWC_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.003s) test_batchnorm_2D_inference_NHWC_vs_native_float32 (__main__.TestNN) ... ok (0.003s) test_batchnorm_2D_inference_NHWC_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.003s) test_batchnorm_2D_inference_NHWC_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.003s) test_batchnorm_2D_train_NCHW_vs_cpu_float32 (__main__.TestNN) ... ok (0.011s) test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.006s) test_batchnorm_2D_train_NCHW_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.006s) test_batchnorm_2D_train_NCHW_vs_native_float32 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_train_NCHW_vs_native_mixed_bfloat16 (__main__.TestNN) ... skip: bfloat16 NCHW train failed due to native tolerance issue SWDEV-507600 (0.002s) test_batchnorm_2D_train_NCHW_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_train_NHWC_vs_NCHW_float32 (__main__.TestNN) ... ok (0.006s) test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16 (__main__.TestNN) ... ok (0.006s) test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_float16 (__main__.TestNN) ... ok (0.006s) test_batchnorm_2D_train_NHWC_vs_cpu_float32 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_train_NHWC_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_train_NHWC_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_train_NHWC_vs_native_float32 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_train_NHWC_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_train_NHWC_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.004s) ``` Tested in `compute-rocm-dkms-no-npi-hipclang` image build 16062: `compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-dkms-no-npi-hipclang:16062_ubuntu22.04_py3.10_pytorch_lw_release-2.7_1fee1967` Tests can be run with environment variable `MIOPEN_ENABLE_LOGGING_CMD=1` to collect MIOpenDriver commands ``` MIOPEN_ENABLE_LOGGING_CMD=1 python test_nn.py -v -k test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16 test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16 (__main__.TestNN) ... MIOpen(HIP): Command [LogCmdBNorm] ./bin/MIOpenDriver bnormbfp16 -n 4 -c 8 -D 2 -H 2 -W 2 -m 1 --forw 1 -b 0 -r 1 -s 1 --layout NDHWC MIOpen(HIP): Command [LogCmdBNorm] ./bin/MIOpenDriver bnormbfp16 -n 4 -c 8 -D 2 -H 2 -W 2 -m 1 --forw 0 -b 1 -s 1 --layout NDHWC MIOpen(HIP): Command [LogCmdBNorm] ./bin/MIOpenDriver bnormbfp16 -n 4 -c 8 -D 2 -H 2 -W 2 -m 1 --forw 1 -b 0 -r 1 -s 1 --layout NCDHW MIOpen(HIP): Command [LogCmdBNorm] ./bin/MIOpenDriver bnormbfp16 -n 4 -c 8 -D 2 -H 2 -W 2 -m 1 --forw 0 -b 1 -s 1 --layout NCDHW ok ``` Co-authored-by: Jeff Daily <jeff.daily@amd.com>
|
!cherry-pick --onto rocm7.0_internal_testing |
…and batchnorm (#2209) The same env vars PYTORCH_MIOPEN_SUGGEST_NHWC and PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM now also enable channels last 3d in addition to 2d.
|
Created branch autogenerated/rocm7.0_internal_testing_cherry-pick_pr-2209 and #2242 |
Additive on top of #2209 3D batchhorm tests (NHWC3D and NCHW3D) NCHW 3D tests: ``` test_batchnorm_3D_inference_NCHW_vs_cpu_float32 (__main__.TestNN) ... ok (0.149s) test_batchnorm_3D_inference_NCHW_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.062s) test_batchnorm_3D_inference_NCHW_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.042s) test_batchnorm_3D_inference_NCHW_vs_native_float32 (__main__.TestNN) ... ok (0.091s) test_batchnorm_3D_inference_NCHW_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.008s) test_batchnorm_3D_inference_NCHW_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.007s) test_batchnorm_3D_inference_NHWC_vs_NCHW_float32 (__main__.TestNN) ... ok (0.028s) test_batchnorm_3D_inference_NHWC_vs_NCHW_mixed_bfloat16 (__main__.TestNN) ... ok (0.010s) test_batchnorm_3D_inference_NHWC_vs_NCHW_mixed_float16 (__main__.TestNN) ... ok (0.010s) test_batchnorm_3D_inference_NHWC_vs_cpu_float32 (__main__.TestNN) ... ok (0.091s) test_batchnorm_3D_inference_NHWC_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.020s) test_batchnorm_3D_inference_NHWC_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.023s) test_batchnorm_3D_inference_NHWC_vs_native_float32 (__main__.TestNN) ... ok (0.010s) test_batchnorm_3D_inference_NHWC_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.015s) test_batchnorm_3D_inference_NHWC_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.007s) test_batchnorm_3D_train_NCHW_vs_cpu_float32 (__main__.TestNN) ... ok (0.011s) test_batchnorm_3D_train_NCHW_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.006s) test_batchnorm_3D_train_NCHW_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.006s) test_batchnorm_3D_train_NCHW_vs_native_float32 (__main__.TestNN) ... ok (0.004s) test_batchnorm_3D_train_NCHW_vs_native_mixed_bfloat16 (__main__.TestNN) ... skip: bfloat16 NCHW train failed due to native tolerance issue SWDEV-507600 (0.002s) test_batchnorm_3D_train_NCHW_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_3D_train_NHWC_vs_NCHW_float32 (__main__.TestNN) ... ok (0.006s) test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16 (__main__.TestNN) ... ok (0.006s) test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_float16 (__main__.TestNN) ... ok (0.006s) test_batchnorm_3D_train_NHWC_vs_cpu_float32 (__main__.TestNN) ... ok (0.004s) test_batchnorm_3D_train_NHWC_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_3D_train_NHWC_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_3D_train_NHWC_vs_native_float32 (__main__.TestNN) ... ok (0.011s) test_batchnorm_3D_train_NHWC_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_3D_train_NHWC_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.004s) ``` Old batchnorm tests will have `2D` it their names ``` test_batchnorm_2D_inference_NCHW_vs_cpu_float32 (__main__.TestNN) ... ok (0.023s) test_batchnorm_2D_inference_NCHW_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.005s) test_batchnorm_2D_inference_NCHW_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.005s) test_batchnorm_2D_inference_NCHW_vs_native_float32 (__main__.TestNN) ... ok (0.104s) test_batchnorm_2D_inference_NCHW_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_inference_NCHW_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.003s) test_batchnorm_2D_inference_NHWC_vs_NCHW_float32 (__main__.TestNN) ... ok (0.020s) test_batchnorm_2D_inference_NHWC_vs_NCHW_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_inference_NHWC_vs_NCHW_mixed_float16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_inference_NHWC_vs_cpu_float32 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_inference_NHWC_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.003s) test_batchnorm_2D_inference_NHWC_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.003s) test_batchnorm_2D_inference_NHWC_vs_native_float32 (__main__.TestNN) ... ok (0.003s) test_batchnorm_2D_inference_NHWC_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.003s) test_batchnorm_2D_inference_NHWC_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.003s) test_batchnorm_2D_train_NCHW_vs_cpu_float32 (__main__.TestNN) ... ok (0.011s) test_batchnorm_2D_train_NCHW_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.006s) test_batchnorm_2D_train_NCHW_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.006s) test_batchnorm_2D_train_NCHW_vs_native_float32 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_train_NCHW_vs_native_mixed_bfloat16 (__main__.TestNN) ... skip: bfloat16 NCHW train failed due to native tolerance issue SWDEV-507600 (0.002s) test_batchnorm_2D_train_NCHW_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_train_NHWC_vs_NCHW_float32 (__main__.TestNN) ... ok (0.006s) test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_bfloat16 (__main__.TestNN) ... ok (0.006s) test_batchnorm_2D_train_NHWC_vs_NCHW_mixed_float16 (__main__.TestNN) ... ok (0.006s) test_batchnorm_2D_train_NHWC_vs_cpu_float32 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_train_NHWC_vs_cpu_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_train_NHWC_vs_cpu_mixed_float16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_train_NHWC_vs_native_float32 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_train_NHWC_vs_native_mixed_bfloat16 (__main__.TestNN) ... ok (0.004s) test_batchnorm_2D_train_NHWC_vs_native_mixed_float16 (__main__.TestNN) ... ok (0.004s) ``` Tested in `compute-rocm-dkms-no-npi-hipclang` image build 16062: `compute-artifactory.amd.com:5000/rocm-plus-docker/framework/compute-rocm-dkms-no-npi-hipclang:16062_ubuntu22.04_py3.10_pytorch_lw_release-2.7_1fee1967` Tests can be run with environment variable `MIOPEN_ENABLE_LOGGING_CMD=1` to collect MIOpenDriver commands ``` MIOPEN_ENABLE_LOGGING_CMD=1 python test_nn.py -v -k test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16 test_batchnorm_3D_train_NHWC_vs_NCHW_mixed_bfloat16 (__main__.TestNN) ... MIOpen(HIP): Command [LogCmdBNorm] ./bin/MIOpenDriver bnormbfp16 -n 4 -c 8 -D 2 -H 2 -W 2 -m 1 --forw 1 -b 0 -r 1 -s 1 --layout NDHWC MIOpen(HIP): Command [LogCmdBNorm] ./bin/MIOpenDriver bnormbfp16 -n 4 -c 8 -D 2 -H 2 -W 2 -m 1 --forw 0 -b 1 -s 1 --layout NDHWC MIOpen(HIP): Command [LogCmdBNorm] ./bin/MIOpenDriver bnormbfp16 -n 4 -c 8 -D 2 -H 2 -W 2 -m 1 --forw 1 -b 0 -r 1 -s 1 --layout NCDHW MIOpen(HIP): Command [LogCmdBNorm] ./bin/MIOpenDriver bnormbfp16 -n 4 -c 8 -D 2 -H 2 -W 2 -m 1 --forw 0 -b 1 -s 1 --layout NCDHW ok ``` Co-authored-by: Jeff Daily <jeff.daily@amd.com>
The same env vars PYTORCH_MIOPEN_SUGGEST_NHWC and PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM now also enable channels last 3d in addition to 2d.
Cherry-picked to release/2.7 branch via #2232
Cherry-picked to rocm7.0_internal_testing branch via #2242