
Add scalar conversion using avx instructions for half#102140

Closed
CaoE wants to merge 2 commits into pytorch:main from CaoE:ecao/optimize_half

Conversation

@CaoE (Collaborator) commented May 24, 2023

Motivation

Scalar conversion between Half and Float on CPU is more time-consuming than BFloat16 <-> Float. There is no direct data type conversion instruction for a single Half value on CPU, so we add scalar conversion with AVX instructions for Half to speed it up.
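For illustration, the kind of AVX-based scalar conversion described here can be sketched with the F16C intrinsics. This is a minimal sketch, not the PR's exact code: the function names and the `target("f16c")` attribute are illustrative, and it assumes an x86 CPU with F16C support (available alongside AVX2).

```cpp
#include <cstdint>
#include <immintrin.h>

// Sketch only: convert one float to IEEE binary16 and back using the
// F16C vcvtps2ph / vcvtph2ps instructions. The target attribute lets
// GCC/Clang emit F16C code without a global -mf16c flag.
__attribute__((target("f16c")))
uint16_t float2half_scalar(float val) {
  __m128i h = _mm_cvtps_ph(_mm_set_ss(val),
                           _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
  return static_cast<uint16_t>(_mm_cvtsi128_si32(h));
}

__attribute__((target("f16c")))
float half2float_scalar(uint16_t val) {
  return _mm_cvtss_f32(_mm_cvtph_ps(_mm_cvtsi32_si128(val)));
}
```

Even though the instructions are vector instructions, only lane 0 is used, so a single scalar value round-trips in one instruction each way instead of a multi-step bit-manipulation sequence.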

Testing

Tested maxpool and compared with the results of #98819.
Single socket (28 cores):

| shape | fp16 forward / ms | bf16 forward / ms | fp16 backward / ms | bf16 backward / ms | speedup ratio (fp16 forward) | speedup ratio (fp16 backward) |
| --- | --- | --- | --- | --- | --- | --- |
| size: (1, 56, 264, 264), kernel: 3, stride: 1, mem_format: contig | 5.07165 | 5.418 | 0.5798 | 0.5123 | 1.373694951 | 3.430786 |
| size: (1, 56, 264, 264), kernel: 3, stride: 1, mem_format: CL | 1.37455 | 1.2505 | 8.8336 | 9.7684 | 1.373635008 | 4.132924 |
| size: (32, 16, 200, 200), kernel: 3, stride: 1, mem_format: contig | 28.72 | 30.7069 | 3.813 | 3.75 | 1.31977124 | 2.783006 |
| size: (32, 16, 200, 200), kernel: 3, stride: 1, mem_format: CL | 4.5783 | 4.703 | 4.703 | 5.1 | 1.028980189 | 3.1293 |
| size: (32, 32, 100, 100), kernel: 3, stride: 1, mem_format: contig | 13.896 | 14.8138 | 1.6635 | 1.6274 | 1.298704663 | 2.982699 |
| size: (32, 32, 100, 100), kernel: 3, stride: 1, mem_format: CL | 2.11291 | 2.1158 | 2.26778 | 2.272 | 0.951105348 | 3.179012 |
| size: (4, 19, 10, 16, 16), kernel: 3, stride: 1, mem_format: contig | 0.4204 | 0.3843 | 0.0649 | 0.0633 | 2.102711703 | 1.779492 |
| size: (4, 19, 10, 16, 16), kernel: 3, stride: 1, mem_format: CL3d | 0.1134 | 0.11 | 0.1476 | 0.143 | 2.23042328 | 3.612398 |

Single core:

| shape | fp16 forward / ms | bf16 forward / ms | fp16 backward / ms | bf16 backward / ms | speedup ratio (fp16 forward) | speedup ratio (fp16 backward) |
| --- | --- | --- | --- | --- | --- | --- |
| size: (1, 56, 264, 264), kernel: 3, stride: 1, mem_format: contig | 124.413 | 114.44 | 10.553 | 11.2486 | 1.31395433 | 3.923844 |
| size: (1, 56, 264, 264), kernel: 3, stride: 1, mem_format: CL | 28.99 | 28.0781 | 9.5092 | 10.9258 | 1.324296999 | 3.888377 |
| size: (32, 16, 200, 200), kernel: 3, stride: 1, mem_format: contig | 640.8276 | 591.964 | 59.18776 | 60.854 | 1.334956391 | 3.704458 |
| size: (32, 16, 200, 200), kernel: 3, stride: 1, mem_format: CL | 88.57 | 90.214 | 54.358 | 59.205 | 1.031258214 | 3.75285 |
| size: (32, 32, 100, 100), kernel: 3, stride: 1, mem_format: contig | 318.6197 | 285.155 | 28.4999 | 29.4387 | 1.315298144 | 3.759747 |
| size: (32, 32, 100, 100), kernel: 3, stride: 1, mem_format: CL | 31.3981 | 34.0544 | 25.6557 | 28.7811 | 1.068505738 | 3.841587 |
| size: (4, 19, 10, 16, 16), kernel: 3, stride: 1, mem_format: contig | 8.87882 | 8.207 | 0.386056 | 0.3939 | 1.567866 | 3.50387 |
| size: (4, 19, 10, 16, 16), kernel: 3, stride: 1, mem_format: CL3d | 2.4167 | 2.38295 | 0.3769 | 0.4066 | 1.39402491 | 3.30061 |

cc @jgong5 @mingfeima @XiaobingSuper @sanchitintel @ashokei @jingxu10 @voznesenskym @penguinwu @EikanWang @Guobing-Chen @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @ngimel @yf225 @chenyang78 @kadeng @muchulee8 @aakhundov

pytorch-bot bot commented May 24, 2023

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/102140

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit b7079ee with merge base 3a79621:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@github-actions github-actions bot added the module: cpu CPU specific problem (e.g., perf, algorithm) label May 24, 2023
@CaoE CaoE requested a review from mingfeima May 24, 2023 02:40
@CaoE CaoE added module: half Related to float16 half-precision floats topic: not user facing topic category ciflow/trunk Trigger trunk jobs on your pull request ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR module: inductor labels May 24, 2023
@CaoE (Collaborator, Author) commented May 24, 2023

Moved #101378 from the stack to here.

@CaoE CaoE requested a review from jgong5 May 30, 2023 05:29
@mingfeima (Collaborator) left a comment:


LGTM now

@CaoE CaoE marked this pull request as ready for review May 31, 2023 05:11
@CaoE CaoE requested a review from ngimel June 14, 2023 08:38
@CaoE (Collaborator, Author) commented Jul 7, 2023

@ngimel Could you please review this PR? Thank you.

@CaoE CaoE requested a review from cpuhrsch July 7, 2023 03:18
@CaoE (Collaborator, Author) commented Jul 7, 2023

@cpuhrsch Could you please review this PR? Thanks.

```cpp
inline namespace CPU_CAPABILITY {

#if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
inline uint16_t float2half_scalar(float val) {
```
A contributor commented:

Why introduce these new functions that are only defined under both of these options?

It looks like they're only used once.

@CaoE (Collaborator, Author) replied Jul 11, 2023:

There is no direct data type conversion instruction for a single Half value on CPU, so I add scalar conversion with AVX instructions for Half to speed it up when AVX2 or AVX512 is supported on the platform.
If neither AVX2 nor AVX512 is supported, Half <-> Float conversion falls back to the original implementations.
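The fast-path / fallback split described above can be sketched as follows. This is a hedged, self-contained sketch: PyTorch keys off `CPU_CAPABILITY_AVX2`/`CPU_CAPABILITY_AVX512` inside ATen, whereas this sketch keys off the compiler's `__F16C__` macro so it compiles anywhere, and the software branch is simplified to normal values only (tiny values flush to zero, the mantissa is truncated rather than rounded).

```cpp
#include <cstdint>
#include <cstring>
#ifdef __F16C__
#include <immintrin.h>
#endif

// Illustrative fast-path / fallback split (names are not PyTorch's).
// Fast path: one F16C instruction. Fallback: portable bit manipulation,
// simplified here to normal values for brevity.
uint16_t float2half(float v) {
#ifdef __F16C__
  return _cvtss_sh(v, _MM_FROUND_TO_NEAREST_INT | _MM_FROUND_NO_EXC);
#else
  uint32_t bits;
  std::memcpy(&bits, &v, sizeof(bits));
  uint16_t sign = (bits >> 16) & 0x8000;
  int32_t  e    = static_cast<int32_t>((bits >> 23) & 0xFF) - 127 + 15;
  uint16_t mant = (bits >> 13) & 0x3FF;
  if (e <= 0)  return sign;            // tiny values flushed to zero (simplified)
  if (e >= 31) return sign | 0x7C00;   // overflow to infinity (simplified)
  return sign | (e << 10) | mant;      // mantissa truncated, not rounded
#endif
}
```

Both branches agree on exactly representable values, which is what makes a compile-time split like this transparent to callers.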

```cpp
return float(c10::bit_cast<sycl::half>(x));
#else
#if defined(CPU_CAPABILITY_AVX2) || defined(CPU_CAPABILITY_AVX512)
return at::vec::half2float_scalar(x);
```
A contributor commented:

Should we just inline the code here?

A contributor commented:

Alternatively we can unify code that works for all CPU types in the half2float, float2half conversion functions.

@CaoE (Collaborator, Author) replied:

> Should we just inline the code here?

I cannot inline at::vec::half2float_scalar or at::vec::float2half_scalar in c10/util/Half-inl.h because PyTorch's build system does not add the AVX2/AVX512 compilation flags for c10/util/Half-inl.h, so AVX instructions cannot be compiled there. That's why I chose to add these two functions under aten/src/ATen/cpu.

@cpuhrsch (Contributor) left a comment:

Are we sure that CI covers these changes? Can you raise an exception within the conversion and see which tests and which environment fails? I want to make sure we have a CPU with the required capability and a test that exercises this.

@CaoE CaoE force-pushed the ecao/optimize_half branch from 995760a to 1381809 on July 12, 2023 03:06
@soulitzer soulitzer added the triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module label Jul 14, 2023
@CaoE CaoE force-pushed the ecao/optimize_half branch from 1381809 to b71763f on July 27, 2023 08:09
@CaoE (Collaborator, Author) commented Jul 27, 2023

@cpuhrsch I added a C++ test case for Half conversion.

> Can you raise an exception within the conversion and see which tests and which environment fails?

Half conversion will not raise an exception: if the CPU does not support AVX2 or AVX512, it falls back to fp16_ieee_from_fp32_value and detail::fp16_ieee_to_fp32_value.
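For reference, a software IEEE binary16 → binary32 decoder in the spirit of the fp16_ieee_to_fp32_value fallback can be sketched as follows. This is an illustrative sketch, not PyTorch's implementation; it handles signed zeros, subnormals, normals, and inf/NaN via bit manipulation.

```cpp
#include <cstdint>
#include <cstring>

// Illustrative IEEE binary16 -> binary32 decoder (not PyTorch's code).
float fp16_to_fp32(uint16_t h) {
  uint32_t sign = static_cast<uint32_t>(h & 0x8000) << 16;
  uint32_t exp  = (h >> 10) & 0x1F;
  uint32_t mant = h & 0x3FF;
  uint32_t bits;
  if (exp == 0) {
    if (mant == 0) {
      bits = sign;  // signed zero
    } else {
      // subnormal: renormalize the mantissa and adjust the exponent
      int e = -1;
      do { mant <<= 1; ++e; } while (!(mant & 0x400));
      mant &= 0x3FF;
      bits = sign | (static_cast<uint32_t>(127 - 15 - e) << 23) | (mant << 13);
    }
  } else if (exp == 0x1F) {
    bits = sign | 0x7F800000u | (mant << 13);  // inf / NaN
  } else {
    bits = sign | ((exp + (127 - 15)) << 23) | (mant << 13);  // normal value
  }
  float f;
  std::memcpy(&f, &bits, sizeof(f));
  return f;
}
```

Every binary16 value is exactly representable in binary32, so this direction of the conversion is exact; the cost relative to an F16C instruction is the branching and bit shuffling above.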

@ZainRizvi (Contributor):

@pytorchbot revert -m "Sorry, this is still breaking internal builds. Specifically, the dynamo test test_repros.py::DynamicShapesReproTests::test_odict_get_item_index_name" -c ghfirst

@cpuhrsch can help debug this

```python
def test_odict_get_item_index_name(self):
    d = {float: torch.float32, np.float16: torch.float16}

    @torch.compile
    def f(x, y1, y2):
        return torch.zeros(5, dtype=d[y1]), torch.zeros(5, dtype=d[y2])

>       f(torch.zeros(4), float, np.float16)
```

@pytorchmergebot (Collaborator):

@pytorchbot successfully started a revert job. Check the current status here.
Questions? Feedback? Please reach out to the PyTorch DevX Team

@pytorchmergebot (Collaborator):

@CaoE your PR has been successfully reverted.

pytorchmergebot added a commit that referenced this pull request Aug 21, 2023
This reverts commit 1d6a446.

Reverted #102140 on behalf of https://github.com/ZainRizvi due to Sorry, this is still breaking internal builds. Specifically, the dynamo test test_repros.py::DynamicShapesReproTests::test_odict_get_item_index_name
@cpuhrsch cpuhrsch reopened this Aug 21, 2023
@cpuhrsch (Contributor):

This looks like an unrelated failure. @CaoE could you please rebase and we'll try again?

@CaoE CaoE force-pushed the ecao/optimize_half branch from eee3577 to f9530b4 on August 22, 2023 02:10
@CaoE (Collaborator, Author) commented Aug 22, 2023

@cpuhrsch Thanks for your help! I rebased and also removed the Apple check by using _cvtss_sh and _cvtsh_ss instead; I'm not sure whether it can pass the internal macOS builds and tests.

@CaoE CaoE force-pushed the ecao/optimize_half branch 3 times, most recently from 6c5f279 to 31c2cc4 on August 29, 2023 06:01
@CaoE CaoE added ciflow/mps Run MPS tests (subset of trunk) ciflow/android Trigger android build and test (run_android_test.yml) labels Aug 29, 2023
pytorch-bot bot commented Aug 29, 2023

Warning: Unknown label ciflow/android.
Currently recognized labels are:

  • ciflow/binaries
  • ciflow/binaries_conda
  • ciflow/binaries_libtorch
  • ciflow/binaries_wheel
  • ciflow/inductor
  • ciflow/inductor-perf-compare
  • ciflow/mps
  • ciflow/nightly
  • ciflow/periodic
  • ciflow/slow
  • ciflow/trunk
  • ciflow/unstable

Please add the new label to .github/pytorch-probot.yml

@CaoE CaoE added the ciflow/binaries Trigger all binary build and upload jobs on the PR label Aug 29, 2023
@CaoE (Collaborator, Author) commented Aug 30, 2023

_cvtss_sh and _cvtsh_ss still hit the same issue with Xcode 14.3.1 on macOS, so I added the Apple check back.

@CaoE CaoE force-pushed the ecao/optimize_half branch from 31c2cc4 to b7079ee on August 30, 2023 08:40
@CaoE (Collaborator, Author) commented Aug 30, 2023

@pytorchbot merge

@pytorchmergebot (Collaborator):

Merge started

Your change will be merged once all checks pass (ETA 0-4 Hours).

Learn more about merging in the wiki.

Questions? Feedback? Please reach out to the PyTorch DevX Team

Advanced Debugging: check the merge workflow status here.


Labels

  • ciflow/android Trigger android build and test (run_android_test.yml)
  • ciflow/binaries Trigger all binary build and upload jobs on the PR
  • ciflow/inductor
  • ciflow/mps Run MPS tests (subset of trunk)
  • ciflow/periodic Trigger jobs ran periodically on master (periodic.yml) on the PR
  • ciflow/trunk Trigger trunk jobs on your pull request
  • Merged
  • module: cpu CPU specific problem (e.g., perf, algorithm)
  • module: half Related to float16 half-precision floats
  • open source
  • topic: performance topic category
  • triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module
