Remove optional for view_fn during View Tracking #50067

Closed
ejguan wants to merge 7 commits into gh/ejguan/18/base from gh/ejguan/18/head

Conversation

@ejguan
Contributor

@ejguan ejguan commented Jan 4, 2021

Fixes #49257

Stack from ghstack:

Using Callgrind to test the performance:

```python
import torch
import timeit
from torch.utils.benchmark import Timer

timer = Timer("x.view({100, 5, 20});", setup="torch::Tensor x = torch::ones({10, 10, 100});", language="c++", timer=timeit.default_timer)
res = timer.collect_callgrind(number=10)
```

### Nightly

```
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7fdd49d61c70>
x.view({100, 5, 20});
setup: torch::Tensor x = torch::ones({10, 10, 100});
                           All          Noisy symbols removed
    Instructions:        41850                      41850
    Baseline:                0                          0
10 runs per measurement, 1 thread
Warning: PyTorch was not built with debug symbols.
         Source information may be limited. Rebuild with
         REL_WITH_DEB_INFO=1 for more detailed results.
```

### Current

```
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7fdcf322d1c0>
x.view({100, 5, 20});
setup: torch::Tensor x = torch::ones({10, 10, 100});
                           All          Noisy symbols removed
    Instructions:        42020                      42020
    Baseline:                0                          0
10 runs per measurement, 1 thread
Warning: PyTorch was not built with debug symbols.
         Source information may be limited. Rebuild with
         REL_WITH_DEB_INFO=1 for more detailed results.
```

### Compare

170 instructions were added:

```
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7fdd427df7f0>
    970  ???:torch::autograd::as_view(at::Tensor const&, at::Tensor const&, bool, bool, std::function<at::Tensor (at::Tensor const&)>, torch::autograd::CreationMeta, bool)
    240  ???:torch::autograd::ViewInfo::~ViewInfo()
    180  ???:torch::autograd::ViewInfo::ViewInfo(at::Tensor, std::function<at::Tensor (at::Tensor const&)>)
    130  ???:torch::autograd::make_variable_differentiable_view(at::Tensor const&, c10::optional<torch::autograd::ViewInfo>, c10::optional<torch::autograd::ViewInfo>, torch::autograd::CreationMeta, bool)
    105  /tmp/benchmark_utils_jit_build_66d13e5035484028ab02edcb72545e37/timer_cpp_2208139565547260839/timer_src.cpp:main
    100  ???:std::function<at::Tensor (at::Tensor const&)>::function(std::function<at::Tensor (at::Tensor const&)> const&)
     70  ???:torch::autograd::DifferentiableViewMeta::~DifferentiableViewMeta()
     70  ???:torch::autograd::DifferentiableViewMeta::DifferentiableViewMeta(c10::TensorImpl*, c10::optional<torch::autograd::ViewInfo>, c10::optional<torch::autograd::ViewInfo>, torch::autograd::CreationMeta)
   -100  ???:c10::optional_base<torch::autograd::ViewInfo>::optional_base(c10::optional_base<torch::autograd::ViewInfo>&&)
   -105  /tmp/benchmark_utils_jit_build_a62a84fbc4cd4f9e91376c1bb3dcb697/timer_cpp_7842929001219208348/timer_src.cpp:main
   -120  ???:torch::autograd::ViewInfo::ViewInfo(at::Tensor, c10::optional<std::function<at::Tensor (at::Tensor const&)> >)
   -210  ???:c10::optional_base<std::function<at::Tensor (at::Tensor const&)> >::~optional_base()
   -240  ???:c10::optional_base<torch::autograd::ViewInfo>::~optional_base()
   -920  ???:torch::autograd::as_view(at::Tensor const&, at::Tensor const&, bool, bool, c10::optional<std::function<at::Tensor (at::Tensor const&)> >, torch::autograd::CreationMeta, bool)
```

Total: 170

Differential Revision: D25900495

@ejguan ejguan requested a review from albanD as a code owner January 4, 2021 21:18
ejguan added a commit that referenced this pull request Jan 4, 2021
ghstack-source-id: 88cc152
Pull Request resolved: #50067
@facebook-github-bot
Contributor

facebook-github-bot commented Jan 4, 2021

💊 CI failures summary and remediations

As of commit 31b2a5f (more details on the Dr. CI page):


💚 💚 Looks good so far! There are no failures yet. 💚 💚


This comment was automatically generated by Dr. CI.

@ejguan ejguan linked an issue Jan 4, 2021 that may be closed by this pull request
@ejguan ejguan marked this pull request as draft January 5, 2021 14:29
Fixes #49257


Using the following script to test the performance.
```python
import torch

x = torch.randn(100, 100, requires_grad=True)
with torch.autograd.profiler.profile() as prof:
    for _ in range(100):
        y = torch.sum(torch.diagonal(x))
        y.backward()

print(prof.key_averages().table(sort_by="self_cpu_time_total"))
```

### Nightly
![image](https://user-images.githubusercontent.com/68879799/103581107-53f5f680-4ea9-11eb-800a-d6570d1c8b6b.png)

### Current
![image](https://user-images.githubusercontent.com/68879799/103581114-5bb59b00-4ea9-11eb-9aa8-038741bc8407.png)

I don't see any significant regression in the backward function (which uses the `view_fn`).

Differential Revision: [D25768847](https://our.internmc.facebook.com/intern/diff/D25768847)

[ghstack-poisoned]
ejguan added a commit that referenced this pull request Jan 5, 2021
ghstack-source-id: 72c9679
Pull Request resolved: #50067
@ejguan ejguan marked this pull request as ready for review January 5, 2021 16:59
@codecov

codecov Bot commented Jan 5, 2021

Codecov Report

Merging #50067 (31b2a5f) into gh/ejguan/18/base (5834438) will decrease coverage by 0.00%.
The diff coverage is 86.66%.

```
@@                  Coverage Diff                  @@
##           gh/ejguan/18/base   #50067      +/-   ##
=====================================================
- Coverage              80.70%   80.70%   -0.01%     
=====================================================
  Files                   1905     1905              
  Lines                 206813   206811       -2     
=====================================================
- Hits                  166916   166912       -4     
- Misses                 39897    39899       +2     
```

Collaborator

@albanD albanD left a comment


This change looks good.

Can you check the perf implications with the new Timer class (https://pytorch.org/docs/master/benchmark_utils.html?highlight=timer#torch.utils.benchmark.Timer) using the instruction count mode, which will be much more precise?
You can do something like this to save a trace (here for a basic view op, but maybe something else if it is more relevant to your PR):

```python
import torch
from torch.utils.benchmark import Timer

timer = Timer("x.view(-1);", setup="torch::Tensor x = torch::ones({1,2,3});", language="c++")
res = timer.collect_callgrind(number=10)

torch.save(res, "save.pth")
```

And once you do that for master and your branch, you can compare with:

```python
import torch
torch.set_printoptions(linewidth=260)

master_instr = torch.load("master.pth")
fast_view_instr = torch.load("fast_view.pth")

master_instr = master_instr.as_standardized()
fast_view_instr = fast_view_instr.as_standardized()

print(master_instr)
print(fast_view_instr)

delta = fast_view_instr.delta(master_instr).denoise()

print(delta)
```
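For intuition, the comparison step above boils down to a per-symbol subtraction of instruction counts. The following stdlib-only sketch (a hypothetical `count_delta` helper, not the actual `FunctionCounts.delta` implementation) illustrates the idea with made-up symbol names:

```python
from collections import Counter

def count_delta(after: dict, before: dict) -> dict:
    """Subtract per-symbol instruction counts and keep nonzero entries,
    sorted by magnitude, similar to how FunctionCounts displays a delta."""
    delta = Counter(after)
    delta.subtract(before)
    return {sym: n for sym, n in sorted(delta.items(), key=lambda kv: -abs(kv[1])) if n != 0}

# Toy counts loosely modeled on the tables in this PR (values are illustrative).
before = {"as_view(optional)": 920, "ViewInfo::~ViewInfo": 0}
after = {"as_view(std::function)": 970, "ViewInfo::~ViewInfo": 240}
print(count_delta(after, before))
# → {'as_view(std::function)': 970, 'as_view(optional)': -920, 'ViewInfo::~ViewInfo': 240}
```

Summing the values of such a delta gives the net instruction change reported in the PR description.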

- Comment thread: torch/csrc/autograd/functions/tensor.cpp (outdated)
- Comment thread: torch/csrc/autograd/variable.cpp (outdated)
- Comment thread: torch/csrc/autograd/variable.h
ejguan added a commit that referenced this pull request Jan 5, 2021
ghstack-source-id: 63e21c1
Pull Request resolved: #50067
ejguan added a commit that referenced this pull request Jan 6, 2021
ghstack-source-id: 9a65e14
Pull Request resolved: #50067
Fixes #49257


Using `Callgrind` to test the performance:
```python
import torch
import timeit
from torch.utils.benchmark import Timer

timer = Timer("x.view({100, 5, 20});", setup="torch::Tensor x = torch::ones({10, 10, 100});", language="c++", timer=timeit.default_timer)
res = timer.collect_callgrind(number=10)
```
### Nightly
```python
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7f7949138c40>
x.view({100, 5, 20});
setup: torch::Tensor x = torch::ones({10, 10, 100});
                           All          Noisy symbols removed
    Instructions:        42310                      42310
    Baseline:                0                          0
10 runs per measurement, 1 thread
Warning: PyTorch was not built with debug symbols.
         Source information may be limited. Rebuild with
         REL_WITH_DEB_INFO=1 for more detailed results.
```
### Current
```python
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7f78f271a580>
x.view({100, 5, 20});
setup: torch::Tensor x = torch::ones({10, 10, 100});
                           All          Noisy symbols removed
    Instructions:        42480                      42480
    Baseline:                0                          0
10 runs per measurement, 1 thread
Warning: PyTorch was not built with debug symbols.
         Source information may be limited. Rebuild with
         REL_WITH_DEB_INFO=1 for more detailed results.
```
### Compare
170 instructions were added (42480 vs. 42310):
```python
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f7941b7a7c0>
    970  ???:torch::autograd::as_view(at::Tensor const&, at::Tensor const&, bool, bool, std::function<at::Tensor (at::Tensor const&)>, torch::autograd::CreationMeta, bool)
    240  ???:torch::autograd::ViewInfo::~ViewInfo()
    180  ???:torch::autograd::ViewInfo::ViewInfo(at::Tensor, std::function<at::Tensor (at::Tensor const&)>)
    130  ???:torch::autograd::make_variable_differentiable_view(at::Tensor const&, c10::optional<torch::autograd::ViewInfo>, c10::optional<torch::autograd::ViewInfo>, torch::autograd::CreationMeta, bool)
    105  /tmp/benchmark_utils_jit_build_69e2f1710544485588feeca0719a3a57/timer_cpp_4435526292782672407/timer_src.cpp:main
    100  ???:std::function<at::Tensor (at::Tensor const&)>::function(std::function<at::Tensor (at::Tensor const&)> const&)
     70  ???:torch::autograd::DifferentiableViewMeta::~DifferentiableViewMeta()
     70  ???:torch::autograd::DifferentiableViewMeta::DifferentiableViewMeta(c10::TensorImpl*, c10::optional<torch::autograd::ViewInfo>, c10::optional<torch::autograd::ViewInfo>, torch::autograd::CreationMeta)
   -100  ???:c10::optional_base<torch::autograd::ViewInfo>::optional_base(c10::optional_base<torch::autograd::ViewInfo>&&)
   -105  /tmp/benchmark_utils_jit_build_2e75f38b553e42eba00523a86ad9aa05/timer_cpp_3360771523810516633/timer_src.cpp:main
   -120  ???:torch::autograd::ViewInfo::ViewInfo(at::Tensor, c10::optional<std::function<at::Tensor (at::Tensor const&)> >)
   -210  ???:c10::optional_base<std::function<at::Tensor (at::Tensor const&)> >::~optional_base()
   -240  ???:c10::optional_base<torch::autograd::ViewInfo>::~optional_base()
   -920  ???:torch::autograd::as_view(at::Tensor const&, at::Tensor const&, bool, bool, c10::optional<std::function<at::Tensor (at::Tensor const&)> >, torch::autograd::CreationMeta, bool)
```
Differential Revision: [D25768847](https://our.internmc.facebook.com/intern/diff/D25768847)

[ghstack-poisoned]
ejguan added a commit that referenced this pull request Jan 8, 2021
ghstack-source-id: cf2c646
Pull Request resolved: #50067
ejguan added a commit that referenced this pull request Jan 11, 2021
ghstack-source-id: 302d8da
Pull Request resolved: #50067
ejguan added a commit that referenced this pull request Jan 13, 2021
ghstack-source-id: c002b77
Pull Request resolved: #50067
@ejguan ejguan requested a review from albanD January 13, 2021 17:57
@ejguan
Contributor Author

ejguan commented Jan 13, 2021

@albanD
It's quite interesting: there doesn't seem to be any benefit in terms of instruction counts from removing the optional around the function, as shown in the comment above.

@albanD
Collaborator

albanD commented Jan 14, 2021

Hi,

Thanks for the benchmark. After discussing with other people, it is expected that the runtime is the same (the compiler is too smart for us): https://godbolt.org/z/W16xos (thanks @swolchok for the godbolt link).
One benefit that is not measured here, though, is the smaller memory footprint, which is always nice!

Collaborator

@albanD albanD left a comment


The updated code looks good to me, thanks!

@facebook-github-bot
Contributor

@ejguan merged this pull request in 00d432a.

@facebook-github-bot facebook-github-bot deleted the gh/ejguan/18/head branch January 19, 2021 15:16
laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
Pull Request resolved: pytorch#50067

Fixes pytorch#49257


Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D25900495

Pulled By: ejguan

fbshipit-source-id: dedd30e69db6b48601a18ae98d6b28faeae30d90
Successfully merging this pull request may close these issues.

View tracking for autograd should not save optional std::function

3 participants