
FLOPS Roofline Analysis Feature for PyTorch Profiler #46506

Closed

xuzhao9 wants to merge 1 commit into pytorch:master from xuzhao9:xzhao9/roofline-analysis

Conversation

@xuzhao9
Contributor

@xuzhao9 xuzhao9 commented Oct 17, 2020

FLOPs Roofline Analysis Feature for PyTorch Profiler.

Summary:

Currently, the PyTorch Profiler lacks the ability to measure the FLOPs of operators such as mm and conv.
FLOPs are helpful for estimating the computational complexity of operators.
For now, we use input shapes to estimate the number of floating point operations.
In the future, we may compute this information by tracking hardware counters.

Test Plan:
Run `python test/test_profiler_flops.py -k test_flops`. The test will print a profiler table with a "FLOPS" column, like the following:


----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------  ------------
                        Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls                                   Input Shapes        MFLOPS
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------  ------------
                aten::matmul         0.06%      57.653us        82.97%      79.310ms      79.310ms             1                 [[40, 33, 1, 243], [243, 243]]            --
                    aten::mm        82.84%      79.186ms        82.86%      79.204ms      79.204ms             1                      [[1320, 243], [243, 243]]       984.323
                aten::conv2d         0.04%      36.345us        16.06%      15.347ms      15.347ms             1  [[40, 16, 18, 260], [33, 16, 18, 18], [33], [  44065010.318
           aten::convolution         0.02%      16.016us        16.02%      15.310ms      15.310ms             1  [[40, 16, 18, 260], [33, 16, 18, 18], [33], [            --
          aten::_convolution         0.07%      63.855us        16.00%      15.294ms      15.294ms             1  [[40, 16, 18, 260], [33, 16, 18, 18], [33], [            --
    aten::mkldnn_convolution        15.89%      15.188ms        15.93%      15.225ms      15.225ms             1  [[40, 16, 18, 260], [33, 16, 18, 18], [33], [            --
                  aten::relu         0.10%      98.223us         0.64%     612.157us     306.079us             2                             [[40, 33, 1, 243]]            --
             aten::threshold         0.49%     465.416us         0.54%     513.934us     256.967us             2                     [[40, 33, 1, 243], [], []]            --
                  aten::add_         0.29%     279.301us         0.29%     279.301us     279.301us             1                  [[40, 33, 1, 243], [243], []]            --
                 aten::empty         0.10%      99.113us         0.10%      99.113us      24.778us             4                       [[], [], [], [], [], []]            --
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------  ------------

Self CPU time total: 95.584ms

.

Ran 1 test in 0.176s

For now, we only provide FLOPs calculation for aten::conv2d and aten::mm operators.
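The MFLOPS column is a rate: the shape-based flop count divided by the measured time. A minimal Python sketch of the idea (a paraphrase, not the PR's actual C++ code; the small gap versus the table's 984.323 comes from exactly which timing the profiler divides by):

```python
# Shape-based estimate for aten::mm: for mat1 of shape [m, k] and mat2 of
# shape [k, n], count m * k * n operations (one multiply per output element
# per reduction step).
def mm_flops(mat1_size, mat2_size):
    flops = mat1_size[0]
    for dim in mat2_size:
        flops *= dim
    return flops

count = mm_flops([1320, 243], [243, 243])
print(count)                   # 77944680 flops for the table's aten::mm row
print(count / 0.079204 / 1e6)  # roughly 984 MFLOPS over the measured 79.204 ms
```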

@xuzhao9 xuzhao9 requested review from ilia-cher and malfet October 17, 2020 02:24
@xuzhao9 xuzhao9 self-assigned this Oct 17, 2020
@facebook-github-bot
Contributor

facebook-github-bot commented Oct 17, 2020

💊 CI failures summary and remediations

As of commit 1e0fcba1a1 (more details on the Dr. CI page):


  • 5/5 failures possibly* introduced in this PR
    • 2/5 non-CircleCI failure(s)

3 failures not recognized by patterns:

Job Step Action
CircleCI pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc5_4_build (Optional) Merge target branch 🔁 rerun
CircleCI pytorch_linux_xenial_py3_6_gcc5_4_build (Optional) Merge target branch 🔁 rerun
CircleCI pytorch_xla_linux_bionic_py3_6_clang9_build (Optional) Merge target branch 🔁 rerun

Extra GitHub checks: 2 failed



@xuzhao9 xuzhao9 removed request for albanD and apaszke October 17, 2020 02:24
@xuzhao9 xuzhao9 marked this pull request as draft October 17, 2020 02:24
@dr-ci

dr-ci Bot commented Oct 17, 2020

💊 CI failures summary and remediations

As of commit 6fa25f5 (more details on the Dr. CI page):


  • 1/2 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)
  • 1/2 broken upstream at merge base ea4ccc7 on Dec 17 from 2:42pm to 4:02pm

🚧 1 fixed upstream failure:

These were probably caused by upstream breakages that were already fixed.

Please rebase on the viable/strict branch.

If your commit is older than viable/strict, run these commands:

git fetch https://github.com/pytorch/pytorch viable/strict
git rebase FETCH_HEAD

Check out the recency history of this "viable master" tracking branch.



@xuzhao9 xuzhao9 force-pushed the xzhao9/roofline-analysis branch 2 times, most recently from 6535ac7 to 9443759 on October 20, 2020 17:20
Contributor

@ilia-cher ilia-cher left a comment


LG, a few comments inline and a high-level comment:
how about passing extra args back into python? then we don't need flops code in C++ and can just implement flops compute as a python module;
we could also find these args useful for other purposes

Comment thread test/test_profiler_flops.py Outdated
Comment thread torch/csrc/autograd/profiler.h Outdated
@ilia-cher
Contributor

btw we could also expand the list of ops we set extra args for, e.g. some element-wise ops (add, mult) - not blocking but could be easy to include too
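For element-wise ops like those, the shape-based count is simple: one (or a few) operations per output element. A hedged sketch, not code from the PR:

```python
from math import prod

def elementwise_flops(shape, ops_per_element=1):
    # e.g. aten::add_ on a [40, 33, 1, 243] tensor does one add per element
    return prod(shape) * ops_per_element

print(elementwise_flops([40, 33, 1, 243]))  # 320760
```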

@ilia-cher ilia-cher self-requested a review October 21, 2020 21:15
@ilia-cher
Contributor

Discussed offline: we can export extra_args to Python in follow-ups and keep the logic in C++; let's move the flops logic and the extra-args extraction logic into a separate .h/.cpp.

@ppwwyyxx
Collaborator

ppwwyyxx commented Nov 6, 2020

FYI: we've implemented flop counting for a few key aten operations at https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/jit_handles.py

One note: we deliberately decided to print warnings for every operator that's not counted, unless the operator is explicitly ignored - because silently ignored operations give users a silent false sense that the model has low flops.
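That policy can be sketched in a few lines of Python (hypothetical names; this is not fvcore's actual API): keep a registry of per-op formulas and warn about any op that is neither counted nor explicitly ignored.

```python
import warnings

# Hypothetical registry of shape-based flop formulas, keyed by op name.
FLOP_FORMULAS = {
    "aten::mm": lambda m1, m2: m1[0] * m2[0] * m2[1],
}
IGNORED_OPS = {"aten::relu"}  # explicitly ignored: contributes no counted flops

def count_flops(ops):
    """ops: iterable of (op_name, input_shapes) pairs. Returns the total of
    all counted flops, warning about ops neither counted nor ignored."""
    total = 0
    for name, shapes in ops:
        if name in FLOP_FORMULAS:
            total += FLOP_FORMULAS[name](*shapes)
        elif name not in IGNORED_OPS:
            warnings.warn(f"{name} is not counted; the total is an underestimate")
    return total

total = count_flops([
    ("aten::mm", ([1320, 243], [243, 243])),
    ("aten::relu", ([40, 33, 1, 243],)),
    ("aten::mystery_op", ([8, 8],)),  # triggers the underestimate warning
])
print(total)  # 77944680
```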

@ilia-cher
Contributor

ilia-cher commented Nov 6, 2020

FYI: we've implemented flop counting for a few key aten operations at https://github.com/facebookresearch/fvcore/blob/master/fvcore/nn/jit_handles.py

One note: we deliberately decided to print warnings for every operator that's not counted, unless the operator is explicitly ignored - because silently ignored operations give users a silent false sense that the model has low flops.

Thanks! Might be useful for us. At the moment we keep the formula code in C++, but Python is also an option. Also, having the formulas in Python would make it easy to implement them in C++ too.

@ppwwyyxx
Collaborator

ppwwyyxx commented Nov 6, 2020

There are a few reasons why in fvcore we prefer to count flops in Python:

  • flop is sometimes ambiguously defined and we allow users to provide their own formulas
  • allow users to provide formulas for custom operators
  • potentially allow more detailed per-module flop counting when combined with the module's forward hook (not yet available in fvcore, but available internally). To researchers, this is much more informative than an overall number for the whole model
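The per-module idea can be illustrated without the framework: register a forward hook on each module and let the hook attribute flops to the module that just ran. The toy Module class and the flop numbers below are stand-ins for illustration only; with PyTorch one would use nn.Module.register_forward_hook.

```python
class Module:
    """Toy stand-in for a framework module that supports forward hooks."""
    def __init__(self, name, flops_per_call):
        self.name, self.flops_per_call, self._hooks = name, flops_per_call, []

    def register_forward_hook(self, fn):
        self._hooks.append(fn)

    def __call__(self, x):
        out = x  # a real module would compute something here
        for hook in self._hooks:
            hook(self, x, out)  # hooks see (module, input, output)
        return out

per_module_flops = {}

def flop_hook(module, inputs, output):
    # Attribute this call's flops to the module that just ran.
    per_module_flops[module.name] = (
        per_module_flops.get(module.name, 0) + module.flops_per_call)

layers = [Module("fc1", 77_944_680), Module("conv1", 1_662_819_840)]
for layer in layers:
    layer.register_forward_hook(flop_hook)

x = 0
for layer in layers:
    x = layer(x)
print(per_module_flops)
```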

@xuzhao9 xuzhao9 force-pushed the xzhao9/roofline-analysis branch from 9443759 to 074c0f2 on November 9, 2020 21:43
@xuzhao9 xuzhao9 changed the title [WIP] FLOPs Roofline Analysis Feature for PyTorch Profiler. FLOPs Roofline Analysis Feature for PyTorch Profiler. Nov 9, 2020
@xuzhao9 xuzhao9 marked this pull request as ready for review November 9, 2020 21:46
@xuzhao9 xuzhao9 force-pushed the xzhao9/roofline-analysis branch 2 times, most recently from 5ee9a3a to 79412b2 on November 9, 2020 22:41
Comment thread test/test_profiler.py Outdated
Comment thread test/test_profiler.py Outdated
Comment thread tools/build_variables.bzl Outdated
Comment thread torch/autograd/profiler.py Outdated
Comment thread torch/autograd/profiler.py Outdated
Comment thread torch/csrc/autograd/profiler_extra.h Outdated
Comment thread torch/csrc/autograd/profiler.h Outdated
Comment thread torch/csrc/autograd/profiler.cpp Outdated
Comment thread torch/autograd/profiler.py Outdated
Comment thread torch/autograd/profiler.py Outdated
@ilia-cher
Contributor

also please update the test section with new test output

@ilia-cher
Contributor

ilia-cher commented Nov 11, 2020

we should check the type of the inputs to make sure that we compute actual FLOPs (probably just adding a check would be enough)

@xuzhao9 xuzhao9 force-pushed the xzhao9/roofline-analysis branch from 8bdfdd4 to d41f150 on December 17, 2020 02:16
Contributor

@facebook-github-bot facebook-github-bot left a comment


@ilia-cher has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@codecov

codecov Bot commented Dec 17, 2020

Codecov Report

Merging #46506 (e00871b) into master (001ff3a) will increase coverage by 0.01%.
The diff coverage is 84.93%.

@@            Coverage Diff             @@
##           master   #46506      +/-   ##
==========================================
+ Coverage   80.60%   80.62%   +0.01%     
==========================================
  Files        1879     1880       +1     
  Lines      202892   203412     +520     
==========================================
+ Hits       163543   163993     +450     
- Misses      39349    39419      +70     

@xuzhao9 xuzhao9 force-pushed the xzhao9/roofline-analysis branch from d41f150 to e00871b on December 17, 2020 16:13
Contributor

@malfet malfet left a comment


Looks good to me, although please consider getting rid of the raw string literals in the code; for example, instead of typing "mat1_size" and "mat2_size", define a `constexpr auto kMat1Size = "mat1_size"` in profiler_utils.h and then reference this constant in the code.
Otherwise, your code can be subject to typos between different literals that would only be detected at runtime.

Comment thread torch/csrc/autograd/profiler_utils.cpp Outdated
Comment thread torch/csrc/autograd/profiler_utils.cpp Outdated
Comment thread torch/csrc/autograd/profiler_utils.cpp Outdated
Comment thread torch/csrc/autograd/profiler_utils.cpp Outdated
Comment thread torch/csrc/autograd/profiler_utils.cpp Outdated
@xuzhao9 xuzhao9 force-pushed the xzhao9/roofline-analysis branch from e00871b to 9909319 on December 17, 2020 21:59
Contributor

@facebook-github-bot facebook-github-bot left a comment


@xuzhao9 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@xuzhao9 xuzhao9 force-pushed the xzhao9/roofline-analysis branch from 9909319 to bd40860 on December 17, 2020 22:12
@xuzhao9 xuzhao9 changed the title FLOPs Roofline Analysis Feature for PyTorch Profiler. FLOPS Roofline Analysis Feature for PyTorch Profiler. Dec 17, 2020
@xuzhao9 xuzhao9 force-pushed the xzhao9/roofline-analysis branch 3 times, most recently from 15b859d to 4f8ed9c on December 17, 2020 23:03
@xuzhao9 xuzhao9 force-pushed the xzhao9/roofline-analysis branch from 4f8ed9c to 6fa25f5 on December 17, 2020 23:14
Contributor

@facebook-github-bot facebook-github-bot left a comment


@xuzhao9 has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot
Contributor

@xuzhao9 merged this pull request in 573f4aa.


@xuzhao9 xuzhao9 deleted the xzhao9/roofline-analysis branch December 18, 2020 15:40
@ngimel
Collaborator

ngimel commented Dec 24, 2020

Profiler now warns when profiling with (record_shapes=False, with_flops=False). This is the default mode and should not warn. Likely because of these warnings, performance regressed (on the fastrnns benchmark: 12s instead of 9s pre-PR).

@ilia-cher
Contributor

sending a fix #49896

kernel_sizes[2], kernel_sizes[3]);

// grouping is NOT properly handled yet
return conv2d_multiply_factor * minibatch * input_h * input_w * kernel_h * kernel_w * in_channels * out_channels;
Contributor


To properly handle stride, padding, dilation, and so on, we should use output_h * output_w instead of input_h * input_w.

for(int64_t dim : mat2_size) {
flops *= dim;
}
return flops;
Contributor


Shouldn't we have an extra factor of 2 to account for multiplication and addition?

std::tie(out_channels, std::ignore, kernel_h, kernel_w) = std::make_tuple(kernel_sizes[0], kernel_sizes[1],
kernel_sizes[2], kernel_sizes[3]);

// grouping is NOT properly handled yet
Contributor


Handling groups is as easy as dividing the flops by extra_args.at(kGroups), or is there a corner case I didn't think of?
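Taken together, the review comments above suggest a corrected formula: use the output spatial size (which accounts for stride, padding, and dilation), multiply by 2 for the multiply-add pair, and divide the per-output work by the group count. A hedged Python sketch (the stride/padding/dilation defaults are assumptions; the merged PR used input_h * input_w and did not yet handle groups):

```python
def conv_out_dim(in_dim, kernel, stride=1, pad=0, dilation=1):
    # Standard convolution output-size formula.
    return (in_dim + 2 * pad - dilation * (kernel - 1) - 1) // stride + 1

def conv2d_flops(input_size, kernel_size, stride=1, pad=0, dilation=1, groups=1):
    # input_size: [N, C_in, H, W]; kernel_size: [C_out, C_in/groups, kH, kW]
    n, c_in, h, w = input_size
    c_out, _, kh, kw = kernel_size
    out_h = conv_out_dim(h, kh, stride, pad, dilation)
    out_w = conv_out_dim(w, kw, stride, pad, dilation)
    macs = n * c_out * out_h * out_w * (c_in // groups) * kh * kw
    return 2 * macs  # one multiply plus one add per MAC

# The test plan's conv: input [40, 16, 18, 260] with kernel [33, 16, 18, 18]
# (stride 1, no padding) yields the [40, 33, 1, 243] output seen in the table.
print(conv2d_flops([40, 16, 18, 260], [33, 16, 18, 18]))  # 3325639680
```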

hwangdeyu pushed a commit to hwangdeyu/pytorch that referenced this pull request Jan 6, 2021
Summary:
FLOPs Roofline Analysis Feature for PyTorch Profiler.

Currently, the PyTorch Profiler lacks the ability to measure the FLOPs of operators such as mm and conv.
FLOPs are helpful for estimating the computational complexity of operators.
For now, we use input shapes to estimate the number of floating point operations.
In the future, we may compute this information by tracking hardware counters.

Pull Request resolved: pytorch#46506

Test Plan:
Run `python test/test_profiler_flops.py -k test_flops`. The test will print a profiler table with "FLOPS" column, like the following:
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------  ------------
                        Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls                                   Input Shapes        MFLOPS
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------  ------------
                aten::matmul         0.06%      57.653us        82.97%      79.310ms      79.310ms             1                 [[40, 33, 1, 243], [243, 243]]            --
                    aten::mm        82.84%      79.186ms        82.86%      79.204ms      79.204ms             1                      [[1320, 243], [243, 243]]       984.323
                aten::conv2d         0.04%      36.345us        16.06%      15.347ms      15.347ms             1  [[40, 16, 18, 260], [33, 16, 18, 18], [33], [  44065010.318
           aten::convolution         0.02%      16.016us        16.02%      15.310ms      15.310ms             1  [[40, 16, 18, 260], [33, 16, 18, 18], [33], [            --
          aten::_convolution         0.07%      63.855us        16.00%      15.294ms      15.294ms             1  [[40, 16, 18, 260], [33, 16, 18, 18], [33], [            --
    aten::mkldnn_convolution        15.89%      15.188ms        15.93%      15.225ms      15.225ms             1  [[40, 16, 18, 260], [33, 16, 18, 18], [33], [            --
                  aten::relu         0.10%      98.223us         0.64%     612.157us     306.079us             2                             [[40, 33, 1, 243]]            --
             aten::threshold         0.49%     465.416us         0.54%     513.934us     256.967us             2                     [[40, 33, 1, 243], [], []]            --
                  aten::add_         0.29%     279.301us         0.29%     279.301us     279.301us             1                  [[40, 33, 1, 243], [243], []]            --
                 aten::empty         0.10%      99.113us         0.10%      99.113us      24.778us             4                       [[], [], [], [], [], []]            --
----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------  ------------
Self CPU time total: 95.584ms

.
----------------------------------------------------------------------
Ran 1 test in 0.176s

For now, we only provide FLOPs calculation for aten::conv2d and aten::mm operators.

Reviewed By: ezyang

Differential Revision: D25214452

Pulled By: xuzhao9

fbshipit-source-id: 0ae841bd8dbdeb032346dc3d9d38e19875aa1da3
@ilia-cher
Contributor

@jspark1105 thanks for the feedback! we'll send an update to the formulas, cc. @xuzhao9

@xuzhao9
Contributor Author

xuzhao9 commented Jan 29, 2021

Thanks for the comments, @jspark1105 ! I have created #51377 to address your comments.

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026