
[Need Discussion] Implement twice backward of ConvNd#1569

Closed
caogang wants to merge 7 commits into pytorch:master from caogang:feature_conv-backward

Conversation

@caogang
Contributor

@caogang caogang commented May 16, 2017

Hi, all.

This PR is following the PR #1555 with cleaner rebase.

@apaszke I have followed your suggestion to modify ConvForward directly, and changed the implementation of ConvTransposeNd to use ConvBackward directly instead of ConvForward(transposed). Here is the commit. I have used gradcheck to check both gradients; it passes the following test.

import torch
from torch import autograd, nn
from torch.autograd import gradcheck

input = (autograd.Variable(torch.randn(1, 20, 22), requires_grad=True),)
test = gradcheck(nn.Conv1d(20, 12, 4), input, eps=1e-3, atol=1e-4)
print(test)
input = (autograd.Variable(torch.randn(1, 20, 22), requires_grad=True),)
test = gradcheck(nn.ConvTranspose1d(20, 12, 4), input, eps=1e-3, atol=1e-4)
print(test)

However, when I use it to compute the grad of grad of ConvForward, it raises an error. So for now, do not merge this PR.

Assume the forward graph is ConvForward(forward_mode) -> ConvForward(forward_mode) -> ConvBackward(grad_mode) -> ConvBackward(grad_mode), and the backward pass processes ConvForward(grad_mode) -> ConvForward(grad_mode) -> ConvBackward(grad_mode) -> ConvBackward(grad_mode). It raised a segmentation fault at the third unit of the backward pass, ConvBackward(grad_mode).

I have found that the problem is caused by the std::move(convolution) in ConvBackward::apply. During the forward pass of ConvBackward(grad_mode), std::move(convolution) is used to pass the argument, which leaves the Function's own convolution empty. When this unit is processed a second time during the backward pass, the now-empty convolution causes the segmentation fault.

  • So how can I pass the convolution argument so that it is retained, rather than cleared by std::move, in the following place?
file: torch/csrc/autograd/functions/convolution.cpp
568+            std::move(columns), std::move(ones), std::move(convolution));
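The crash mechanism can be reproduced outside autograd: a moved-from unique_ptr is left empty, so a second traversal of the same Function that dereferences it hits null. A minimal sketch (ConvState is a hypothetical stand-in for the saved convolution object, not the real type):

```cpp
#include <cassert>
#include <memory>
#include <utility>

// Hypothetical stand-in for the convolution state saved on the Function.
struct ConvState { int stride = 1; };

// Mimics what ConvBackward::apply does: the saved state is handed off
// with std::move, which leaves the source pointer empty.
inline bool moved_from_is_null() {
    auto convolution = std::make_unique<ConvState>();
    auto passed_on = std::move(convolution);  // first traversal: fine
    // A second traversal that dereferences `convolution` would crash.
    return convolution == nullptr && passed_on != nullptr;
}
```

Any later read of `convolution->stride` after the move is exactly the segfault described above.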

I have one more question. Given the same graph as above (ConvForward -> ConvForward -> ConvBackward -> ConvBackward forward, with the corresponding backward pass), should I accumulate the grads of weight and bias in every backward unit, or only in the ConvBackward unit?

@apaszke
Contributor

apaszke commented May 16, 2017

I don't understand what these ConvForward and ConvBackward sequences mean. I also don't really know where you want to accumulate the grads -- these functions only compute and return them; there's no accumulation going on.

@apaszke
Contributor

apaszke commented May 16, 2017

Also, you're not doing gradcheck(ConvNd(20, 12, 4), ...), right? That's not valid; you should never reuse a Function object. This is correct: gradcheck(lambda input: ConvNd(20, 12, 4)(input), ...).

@caogang
Contributor Author

caogang commented May 16, 2017

OK, I changed the test code as follows; it still passes:

import torch
from torch import autograd
from torch.autograd import gradcheck
import torch.nn.functional as F

inputs = (autograd.Variable(torch.randn(1, 12, 12).double(), requires_grad=True),
          autograd.Variable(torch.randn(10, 12, 5).double(), requires_grad=True),
          autograd.Variable(torch.randn(10).double(), requires_grad=True))
test = gradcheck(lambda i, w, b: F.conv1d(i, w, b), inputs, eps=1e-6, atol=1e-4)
print(test)
inputs = (autograd.Variable(torch.randn(1, 12, 12).double(), requires_grad=True),
          autograd.Variable(torch.randn(12, 10, 5).double(), requires_grad=True),
          autograd.Variable(torch.randn(10).double(), requires_grad=True))
test = gradcheck(lambda i, w, b: F.conv1d(i, w, b), inputs, eps=1e-6, atol=1e-4)
print(test)

@caogang
Contributor Author

caogang commented May 16, 2017

By the way, do you have any idea about the first problem? I want some other way of passing the convolution argument instead of std::move(convolution), because the move resets convolution and I want to retain its value. I tried passing convolution without std::move, but it raises an error. :(

@apaszke
Contributor

apaszke commented May 16, 2017

You need to remove std::move, but can't because it's a unique_ptr (i.e. there can be only one owner of this pointer). It needs to be changed to a shared_ptr.
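A sketch of the suggested change, again with `ConvState` as a hypothetical stand-in for the saved convolution object: with shared_ptr, the handle can be copied into the next Function while the current Function keeps a live reference for a later re-traversal.

```cpp
#include <cassert>
#include <memory>

// Hypothetical stand-in for the saved convolution state.
struct ConvState { int stride = 1; };

// With shared ownership the state is handed to the next Function by
// copy, not move, so the current Function's pointer stays valid.
inline bool shared_state_survives_handoff() {
    auto convolution = std::make_shared<ConvState>();
    std::shared_ptr<ConvState> next_fn = convolution;  // copy, not move
    return convolution != nullptr && convolution.use_count() == 2;
}
```

Both owners see the same object, so a second backward pass through the Function can still read its parameters.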

@caogang
Contributor Author

caogang commented May 17, 2017

OK, the first problem is solved. Now I have found that something goes wrong when accumulating grad_weight into the SavedVariable weight_. How can I accumulate the grad into weight_? Maybe using Variable::get_grad_accumulator()? But I don't know how to use it. @apaszke

@apaszke
Contributor

apaszke commented May 17, 2017

You don't need to accumulate any grads! Just return them from apply and autograd will take care of the rest.

Contributor

@apaszke apaszke left a comment


Ok, so @albanD pointed out that grad of grad of ConvNd is not a plain forward convolution, but an expression that has 2x forward convolution and a few additions (I haven't verified that yet, though). If that's the case, then you shouldn't create backward and forward modes in the ConvForward function, but create a ConvBackwardBackward that doesn't unpack the Variables in its apply, but uses the ConvForward and Add functions to compute its output. This will be enough to make it infinitely many times differentiable.
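If that analysis is right, the structure falls out of the bilinearity of convolution. A sketch, with flip/transpose conventions elided, so treat it as illustrative rather than exact:

```latex
% Forward: y = x \ast w + b, with incoming grad g_O = \partial L / \partial y.
% First backward (\hat{\ast} denotes the appropriately transposed correlation):
g_I = g_O \,\hat{\ast}\, w, \qquad
g_W = x \,\hat{\ast}\, g_O, \qquad
g_b = \textstyle\sum g_O.
% Each right-hand side is linear in g_O, so differentiating the backward
% with incoming grads gg_I, gg_W, gg_b gives
\frac{\partial L}{\partial g_O} = gg_I \ast w \;+\; x \,\hat{\ast}\, gg_W \;+\; gg_b,
% i.e. two forward-style convolutions plus a few additions.
```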

@caogang
Copy link
Contributor Author

caogang commented May 25, 2017

albanD has implemented this feature. Please refer to #1643 .

@caogang caogang closed this May 25, 2017
@caogang caogang deleted the feature_conv-backward branch June 8, 2017 08:08
houseroad added a commit to houseroad/pytorch that referenced this pull request Jan 15, 2019
…827566

Summary:
Previous import was 7abd834091f1024c11749dcfd25126802db9fdd5

Included changes:
- **[84a0441](onnx/onnx@84a0441)**: Clarify namescopes in the presence of nested subgraphs (pytorch#1665) <G. Ramalingam>
- **[118fec5](onnx/onnx@118fec5)**: Add Where op. (pytorch#1569) <Sergii Dymchenko>
- **[beefa15](onnx/onnx@beefa15)**: Use strings directly for casing as np.object w/o redundant StringHolder. (pytorch#1736) <Dmitri Smirnov>
- **[4023bae](onnx/onnx@4023bae)**: Add a capability to input/output unicode strings (pytorch#1734) <Dmitri Smirnov>
- **[1a8a7fc](onnx/onnx@1a8a7fc)**: typos fixed: iutput -> input (pytorch#1726) <Beomsoo Kim>
- **[0128478](onnx/onnx@0128478)**: Scan test update (pytorch#1732) <G. Ramalingam>
- **[c6a24fd](onnx/onnx@c6a24fd)**: turn rtol to 0.002 on densenet121, since AMD and Nvidia GPU's precion difference (pytorch#1733) <Lu Fang>
- **[5b7ac72](onnx/onnx@5b7ac72)**: Add Shrink operator (pytorch#1622) <Rui Zhu>

Differential Revision: D13676711

fbshipit-source-id: 0b7b8a398afa4a3b54752fb792f19e7efca80f65
facebook-github-bot pushed a commit that referenced this pull request Jan 16, 2019
…827566 (#16046)

Summary:
Pull Request resolved: #16046

Previous import was 7abd834091f1024c11749dcfd25126802db9fdd5

Included changes:
- **[84a0441](onnx/onnx@84a0441)**: Clarify namescopes in the presence of nested subgraphs (#1665) <G. Ramalingam>
- **[118fec5](onnx/onnx@118fec5)**: Add Where op. (#1569) <Sergii Dymchenko>
- **[beefa15](onnx/onnx@beefa15)**: Use strings directly for casing as np.object w/o redundant StringHolder. (#1736) <Dmitri Smirnov>
- **[4023bae](onnx/onnx@4023bae)**: Add a capability to input/output unicode strings (#1734) <Dmitri Smirnov>
- **[1a8a7fc](onnx/onnx@1a8a7fc)**: typos fixed: iutput -> input (#1726) <Beomsoo Kim>
- **[0128478](onnx/onnx@0128478)**: Scan test update (#1732) <G. Ramalingam>
- **[c6a24fd](onnx/onnx@c6a24fd)**: turn rtol to 0.002 on densenet121, since AMD and Nvidia GPU's precion difference (#1733) <Lu Fang>
- **[5b7ac72](onnx/onnx@5b7ac72)**: Add Shrink operator (#1622) <Rui Zhu>

Reviewed By: yinghai

Differential Revision: D13676711

fbshipit-source-id: 513cc137223469b47af48919432aaecf58006012
jjsjann123 pushed a commit to jjsjann123/pytorch that referenced this pull request Apr 18, 2022
 * cache_after to cacheAfter
 * cache_before to cacheBefore
 * cache_fork to cacheFork
