Prototype benchmarking util #38338

Closed
robieta wants to merge 27 commits into master from gh/taylorrobie/timeit_benchmark

Conversation

@robieta (Contributor) commented May 12, 2020

This is the prototype for the modular utils that we've been discussing. It is admittedly a large PR, but a good fraction of that is documentation and examples. I've trimmed a bit on the edges since we last discussed this design (for instance Timer is no longer Fuzzer aware), but it's mostly the same.

In addition to the library and hermetic examples, I've included examples.end_to_end which tests #38061 over a variety of shapes, dtypes, degrees of broadcasting, and layouts. (CC @crcrpar) I only did CPU as I'm not set up on a GPU machine yet. Results from my devserver

Key takeaways:

  1. For contiguous Tensors, larger dtypes (fp32 and fp64), and lots of reuse of the mask due to broadcasting, the improvements are significant. (Presumably due to better vectorization?)
  2. There is an extra ~1.5 us overhead, which dominates small kernels.
  3. Cases with lower write intensity (int8, lower mask fraction, etc) or non-contiguous seem to suffer.

Hopefully this demonstrates the proof-of-concept for how this tooling can be used to tune kernels and assess PRs. Looking forward to thoughts and feedback.
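The measurement workflow the PR's Timer targets (the utility later merged as `torch.utils.benchmark`) can be sketched with the stdlib `timeit` module alone. The function name and defaults below are illustrative, not the PR's actual API:

```python
# Minimal stdlib analogue of a Timer-style workflow: run a statement
# many times per trial, over several trials, and report per-call times.
import timeit

def measure(stmt, setup="pass", globals_=None, number=100, repeat=5):
    """Return per-call times (seconds) for each of `repeat` trials,
    where each trial runs `stmt` `number` times."""
    t = timeit.Timer(stmt=stmt, setup=setup, globals=globals_)
    return [total / number for total in t.repeat(repeat=repeat, number=number)]

times = measure("sorted(data)", globals_={"data": list(range(1000, 0, -1))})
print(len(times))  # 5 trials
```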

@dr-ci (Bot) commented May 12, 2020

💊 CI failures summary and remediations

As of commit d068b7a (more details on the Dr. CI page):


None of the CI failures appear to be your fault 💚



❄️ 2 failures tentatively classified as flaky, but reruns have not yet been triggered to confirm:

See CircleCI build caffe2_onnx_main_py3_6_clang7_ubuntu16_04_build (1/2)

Step: "Build" (full log | diagnosis details | 🔁 rerun) ❄️

Jun 29 23:01:44 Failed to recurse into submodule path 'third_party/ideep'
```
sys	0m0.060s
Jun 29 23:01:19 ++ export BUILD_ENVIRONMENT=caffe2-onnx-main-py3.6-clang7-ubuntu16.04-build
Jun 29 23:01:19 ++ BUILD_ENVIRONMENT=caffe2-onnx-main-py3.6-clang7-ubuntu16.04-build
Jun 29 23:01:19 ++ git submodule sync
Jun 29 23:01:19 ++ git submodule update -q --init --recursive
Jun 29 23:01:44 error: RPC failed; curl 56 GnuTLS recv error (-54): Error in the pull function.
Jun 29 23:01:44 fatal: The remote end hung up unexpectedly
Jun 29 23:01:44 fatal: early EOF
Jun 29 23:01:44 fatal: index-pack failed
Jun 29 23:01:44 fatal: clone of 'https://github.com/intel/mkl-dnn.git' into submodule path 'mkl-dnn' failed
Jun 29 23:01:44 Failed to recurse into submodule path 'third_party/ideep'
```

See CircleCI build binary_windows_libtorch_3_7_cpu_debug_build (2/2)

Step: "Build" (full log | diagnosis details | 🔁 rerun) ❄️

CondaHTTPError: HTTP 000 CONNECTION FAILED for url
```
The system cannot find the file specified.
Could Not Find C:\w\b\windows\miniconda.exe
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
   0     0    0     0    0     0      0      0 --:--:-- --:--:-- --:--:--     0 100 54.6M  100 54.6M    0     0  54.6M      0  0:00:01 --:--:--  0:00:01  205M
Collecting package metadata (current_repodata.json): ...working... done
Solving environment: ...working... failed with repodata from current_repodata.json, will retry with next repodata source.
Collecting package metadata (repodata.json): ...working... done
Solving environment: ...working... done

CondaHTTPError: HTTP 000 CONNECTION FAILED for url <https://repo.anaconda.com/pkgs/main/win-64/mkl-2020.1-216.conda>
Elapsed: -

An HTTP error occurred when trying to retrieve this URL.
HTTP errors are often intermittent, and a simple retry will get you on your way.

## Package Plan ##

  environment location: C:\w\b\windows\conda\envs\py37
```


@robieta (Contributor, Author) commented May 12, 2020

CI is failing because of merge conflicts when it tries to fast forward the branch. Looking into it now.

@vadimkantorov (Contributor) commented May 12, 2020

One thing to consider for CPU benchmarks (but may be hard to do properly): controlling for CPU throttling https://lemire.me/blog/2018/01/16/microbenchmarking-calls-for-idealized-conditions/ and maybe for CPU thread affinity

@robieta (Contributor, Author) commented May 12, 2020

> One thing to consider for CPU benchmarks (but may be hard to do properly): controlling for CPU throttling and maybe for CPU thread affinity

Yeah, it's a tough problem to be sure. I'm using the standard set of tricks to mitigate this:

  1. Conduct a number of trials rather than one long run. (Generally tens to hundreds)
  2. Use the median rather than the mean for robustness to outliers / systematic throttling / etc.
  3. Discard trials where too much variation is observed.
  4. Trim visualization based on estimated significant figures.

I've found that measurements are fairly stable (at least on my machine), but I agree that it's definitely something to watch out for and this is not a panacea. The hope is that later versions will have proper runtime integration so we can get counts, allocations, etc.
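Steps 2 and 3 above can be sketched in a few lines; the culling threshold here is an assumed illustration, not the PR's actual cutoff:

```python
# Summarize a set of short trials: report the median, but discard the
# measurement entirely if the inter-quartile spread is too wide.
import statistics

def summarize(trial_times, max_rel_iqr=0.1):
    med = statistics.median(trial_times)
    q1, _, q3 = statistics.quantiles(trial_times, n=4)
    if (q3 - q1) / med > max_rel_iqr:
        return None  # too noisy; discard this measurement
    return med

print(summarize([1.00, 1.01, 0.99, 1.02, 1.00]))  # stable -> median
print(summarize([1.0, 2.0, 0.5, 3.0, 1.5]))       # noisy -> None
```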

@facebook-github-bot (Contributor) left a comment

@robieta has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@dzhulgakov (Collaborator) left a comment

This is pretty cool, I like nice APIs!

For CPU control it's hard to do so from within process. But we could provide a standard wrapper script that would turn off turbo (not sure whether there's a standard way to do it), pin to a single cpu and set thread scheduler priority to performance.
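The affinity half of that wrapper idea can be sketched portably from Python; turning off turbo and setting the governor needs root and is platform-specific, so only CPU pinning is shown here (illustrative, not part of this PR):

```python
# Pin this process to one core, then launch the benchmark as a child;
# on Linux the child inherits the affinity mask.
import os
import subprocess
import sys

def run_pinned(cmd, core=0):
    if hasattr(os, "sched_setaffinity"):   # Linux-only API
        os.sched_setaffinity(0, {core})    # pin ourselves first
    return subprocess.run(cmd, check=True)

run_pinned([sys.executable, "-c", "print('benchmark would run here')"])
```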

Comment thread benchmarks/experimental_components/examples/end_to_end.py Outdated
Comment thread benchmarks/experimental_components/utils/timer.py Outdated
@robieta (Contributor, Author) commented May 20, 2020

> This is pretty cool, I like nice APIs!
>
> For CPU control it's hard to do so from within process. But we could provide a standard wrapper script that would turn off turbo (not sure whether there's a standard way to do it), pin to a single cpu and set thread scheduler priority to performance.

Thanks!

I agree, there doesn't seem to be any way to control the environment without spawning a subprocess. One concern would be overhead; given that most measurements tend to be short (< 1s), I worry that the vast majority of time would be spent creating and destroying these controlled envs. One could imagine keeping a pool of subprocesses and having Timers "submit" work to it, but you start to get into non-trivial engineering complexity vs just spacing out replicates.
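The "pool of workers" idea can be sketched as follows: pay startup cost once and submit many short measurements to warm workers. For portability this sketch uses threads; the real version would keep worker subprocesses alive (e.g. `concurrent.futures.ProcessPoolExecutor`) for isolation. Purely illustrative; the PR does not implement this:

```python
# Keep a pool of warm workers and have Timers "submit" work to it,
# instead of paying process creation per measurement.
import timeit
from concurrent.futures import ThreadPoolExecutor

def time_stmt(stmt, number=1000):
    return timeit.timeit(stmt, number=number) / number

stmts = ["sum(range(100))", "min(range(100))"]
with ThreadPoolExecutor(max_workers=2) as pool:
    results = dict(zip(stmts, pool.map(time_stmt, stmts)))
print(sorted(results))
```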

@facebook-github-bot (Contributor) left a comment

@robieta has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@ngimel (Collaborator) left a comment

Overall this looks great!

broadcastable to the shape of `x`:

```
fuzzer = Fuzzer(
```
@ngimel (Collaborator) commented:

As a follow-up, it probably makes sense to create fuzzer helpers that would fuzz tensors for common cases - e.g. UnaryOpFuzzer, BinaryOpFuzzer (with/without broadcast) with number of dimensions as an argument and sensible defaults.

@robieta (Contributor, Author) replied:

I've added two which are pretty comprehensive for unary and binary ops, and I'll add more in the future. Let me know what you think.
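To give a sense of what such helpers generate, here is a hypothetical shape-fuzzing sketch; the function name, parameters, and defaults are invented for illustration and are not the PR's actual `BinaryOpFuzzer` API:

```python
# Generate a random shape for `x` and a broadcastable shape for `y`,
# the kind of case a binary-op fuzzer would enumerate.
import random

def fuzz_binary_shapes(dims=3, max_size=64, broadcast=True, seed=0):
    rng = random.Random(seed)
    x_shape = [rng.randint(1, max_size) for _ in range(dims)]
    y_shape = list(x_shape)
    if broadcast:
        # randomly collapse dimensions of y to 1 so it broadcasts against x
        for i in range(dims):
            if rng.random() < 0.5:
                y_shape[i] = 1
    return tuple(x_shape), tuple(y_shape)

x, y = fuzz_binary_shapes()
assert all(a == b or b == 1 for a, b in zip(x, y))  # y broadcasts to x
```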

Comment thread benchmarks/experimental_components/utils/common.py Outdated
Comment thread benchmarks/experimental_components/utils/common.py Outdated
Comment thread benchmarks/experimental_components/utils/common.py Outdated
```
    return output

    @staticmethod
    def color_segment(segment, value, group_values):
```
@ngimel (Collaborator) commented:

color-coding is very nice!

Comment thread benchmarks/experimental_components/utils/fuzzer.py Outdated
Comment thread benchmarks/experimental_components/examples/end_to_end.py Outdated
Comment thread benchmarks/experimental_components/utils/fuzzer.py Outdated
@facebook-github-bot (Contributor) left a comment

@robieta has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@robieta force-pushed the gh/taylorrobie/timeit_benchmark branch from 615a36b to a3db551 on June 4, 2020.
@robieta (Contributor, Author) commented Jun 24, 2020

I've updated the end-to-end example to test 3 PRs, and included a script to build the expected environments. The following were run on a 56-core, 8×P100 machine:

#39850 (cc @xwang233)

#39967 (cc @ShawnZhong)

#39744 (cc @nikitaved)

@robieta (Contributor, Author) commented Jun 24, 2020

@nikitaved Thanks for pointing out that the GPU variation for your PR is quite high. (I can't seem to find the comment)
I re-ran it using the same environment for both before and after and saw similarly high levels of variation. After some experimentation, I found two things:

  1. Running CPU and GPU at the same time significantly increases the variance (~ +/- 70% with vs. ~ +/- 30% without). I had limited the number of workers in the CPU pool, but htop indicates that I was still saturating the CPU at times so it was probably contending on the CPU portion. (e.g. stream management and op launch.)

  2. Replicates are essential. I had this notion that GPU measurements were more stable than CPU, but empirically I had to go to 5 replicates to get down to < 5% variability. (The Timer does internal replicates, but those are clearly correlated either by device or temporally.)

Even then, you have to throw out almost 50% of the test cases because the variation is too high. These are basically all <75 us cases; there's just too much overhead jitter. Which is a shame. I think I'm going to experiment with having a worker do more measurements if it detects that a result is noisy to see if that lets us save some of those results.
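The culling rule described here can be sketched as: take both sets of replicates in the same environment, discard the case if either side's spread is too wide, otherwise report the median ratio. Thresholds below are illustrative, not the PR's:

```python
# Compare "before" and "after" replicates; cull the case when the
# within-side spread exceeds a relative threshold.
import statistics

def compare(before_reps, after_reps, cull_rel_spread=0.05):
    for reps in (before_reps, after_reps):
        med = statistics.median(reps)
        if (max(reps) - min(reps)) / med > cull_rel_spread:
            return None  # variation too high; cull this case
    return statistics.median(after_reps) / statistics.median(before_reps)

print(compare([100, 101, 99, 100, 100], [90, 91, 89, 90, 90]))  # -> 0.9
```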

@robieta (Contributor, Author) commented Jun 25, 2020

It looks like the underlying issue is that the GPUs are not identical, so the variance was due to master and the branch being scheduled on different devices. Running them on the same GPU significantly reduces GPU run-to-run variation. And for CPU, 5 replicates is sufficient to clamp down on the variance. (Though a non-trivial number still get culled due to excessive variation.) To make this easier to validate, I added a --test_variance option that runs both the before and after in the same environment, so any difference is noise.

#39744 Variance test
#39744 summary (new methodology)

The variance test shows that we're now within ~5%. Rerunning the actual PR shows no meaningful change for GPU (yay!) and a CPU change that is beyond noise but much more muted than before, indicating that the +/- 50% swings were noise rather than real changes; once we clamp down on that noise, the real signal is a ~5-10% relative difference.

I'll rerun all of the PR benchmarks tomorrow.

@robieta (Contributor, Author) commented Jun 25, 2020

Updated runs: (This is on an 8xV100 machine since I lost the 8xP100 machine and assignment is random.)
#39850

#39967

#39744

@nikitaved (Collaborator) commented:

@robieta, thanks for the update! When it comes to sorting, I observed that TensorIterator might give a significant boost for large multidimensional contiguous tensors with the last dimensions sorted. Maybe it could be of use, as it might affect the performance of any dim-apply type of algorithm.

@ngimel (Collaborator) left a comment

This looks good and you should merge it. The end-to-end example still contains PR-specific pieces spread around, so some restructuring may be needed to make it easier to reuse for other PRs, but that can be done later. The general utilities are in good shape and it definitely makes sense to merge them.

```
def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--pr", type=str, default=_PR_LIST[0], choices=_PR_LIST)
    parser.add_argument("--num_gpus", type=int, default=8)
```
@ngimel (Collaborator) commented:

default probably should be "use all" instead of hardcoded 8 - number of available gpus can be queried later
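The suggested change can be sketched as: default to `None` meaning "use all", and query the device count at runtime. In the real code that query would be `torch.cuda.device_count()`; a stub stands in for it here so the sketch is self-contained:

```python
# Default num_gpus to "use all available" instead of a hardcoded 8.
import argparse

def device_count():  # stand-in for torch.cuda.device_count()
    return 8

parser = argparse.ArgumentParser()
parser.add_argument("--num_gpus", type=int, default=None,
                    help="default: use all available GPUs")
args = parser.parse_args([])
num_gpus = args.num_gpus if args.num_gpus is not None else device_count()
print(num_gpus)  # -> 8 with the stub
```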


```
layout:
    Indicates that `x` is not contiguous due to permutation. Invoking
    `x.permute(steps)` (e.g. x.permute((2, 0, 1)) if steps = [2, 0, 1])
```
@ngimel (Collaborator) commented:

did you mean steps here or order?

```
layout:
    Indicates that `x` is not contiguous due to permutation. Invoking
    `x.permute(steps)` (e.g. x.permute((2, 0, 1)) if steps = [2, 0, 1])
    would produce a Tensor whose shape matches memory order. (Though still
```
@ngimel (Collaborator) commented:

a Tensor with physical memory layout matching logical memory layout? shape matches memory order is not very clear

@facebook-github-bot (Contributor) left a comment

@robieta has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@facebook-github-bot (Contributor) commented:

@robieta merged this pull request in f394979.

@fmassa deleted the gh/taylorrobie/timeit_benchmark branch on July 6, 2020.
facebook-github-bot referenced this pull request Jan 22, 2021
Summary:
This is benchmarking tooling for sparse tensors. To implement it, we extended the `benchmarking util` PR ([#38338](https://github.com/pytorch/pytorch/pull/38338)) to sparse tensors: the **FuzzedTensor** class was extended by creating the new **FuzzedSparseTensor** class, and two new operator classes were added, `UnaryOpSparseFuzzer` and `BinaryOpSparseFuzzer`.

The class `FuzzedSparseTensor` adds new input parameters to the constructor:
1. `sparse_dim`: The number of sparse dimensions in a sparse tensor.
2. `nnz`:   Number of non-zero elements in the sparse tensor.
3. `density`: The density of the sparse tensor.
4. `coalesced`: Whether the sparse tensor is coalesced or uncoalesced (the sparse format permits both).

and removes `probability_contiguous`, `max_allocation_bytes`, `roll_parameter`, `tensor_constructor`, as they are dense-tensor-related parameters.

In addition, I've extended the `torch.utils.benchmark.examples` to work with the new classes `FuzzedSparseTensor`, `UnaryOpSparseFuzzer` and `BinaryOpSparseFuzzer`.
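To illustrate what a sparse fuzzer has to produce, here is a self-contained sketch of generating COO-style indices and values for a given shape and density; the function name and parameters are invented for illustration and are not `FuzzedSparseTensor`'s actual implementation:

```python
# Generate COO indices/values for a sparse tensor of the given shape,
# with nnz derived from the requested density. Sampling distinct flat
# positions yields no duplicate indices (a coalesced-like layout).
import random

def fuzz_coo(shape, density, seed=0):
    rng = random.Random(seed)
    numel = 1
    for s in shape:
        numel *= s
    nnz = max(1, int(density * numel))
    flat = rng.sample(range(numel), nnz)  # distinct flat positions
    indices = []
    for f in flat:
        idx = []
        for s in reversed(shape):         # unravel flat index into coords
            idx.append(f % s)
            f //= s
        indices.append(tuple(reversed(idx)))
    values = [rng.gauss(0, 1) for _ in range(nnz)]
    return indices, values

idx, vals = fuzz_coo((4, 5), density=0.2)
print(len(vals))  # 0.2 * 20 = 4 non-zeros
```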

Hopefully, this tooling and these examples will help with benchmarking in other PRs. Looking forward to your thoughts and feedback. cc robieta, mruberry, ngimel

Pull Request resolved: #48397

Reviewed By: ejguan

Differential Revision: D26008137

Pulled By: mruberry

fbshipit-source-id: 2f37811c7c3eaa3494a0f2500e519267f2186dfb
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary: duplicate of the PR description above.
Pull Request resolved: pytorch#38338

Differential Revision: D21551048

Pulled By: robieta

fbshipit-source-id: 6c50e5439a04eac98b8a2355ef731852ba0500db