Add complex support for torch.nn.L1Loss #49912

Closed
soulitzer wants to merge 20 commits into pytorch:master from soulitzer:l1-complex-support

Conversation

@soulitzer (Contributor) commented Dec 29, 2020

Building on top of the work of @anjali411 (#46640)

Things added in this PR:

  1. Modify backward and double-backward formulas
  2. Add complex support for new module tests and criterion tests (and add complex tests for L1)
  3. Modify some existing tests to support complex
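As a quick illustration of what this enables, here is a minimal sketch (my own example, not from the PR; it assumes a build that includes this change, and the shapes/values are arbitrary):

    import torch
    import torch.nn as nn

    # L1Loss on complex inputs: |input - target| is real-valued, so the
    # loss is a real scalar even though the inputs are complex.
    loss_fn = nn.L1Loss()
    inp = torch.randn(4, 3, dtype=torch.complex64, requires_grad=True)
    tgt = torch.randn(4, 3, dtype=torch.complex64)

    loss = loss_fn(inp, tgt)
    loss.backward()                    # backward uses sgn(input - target)
    print(loss.dtype, inp.grad.dtype)  # torch.float32 torch.complex64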

@soulitzer added the module: nn (Related to torch.nn) and module: complex (Related to complex number support in PyTorch) labels on Dec 29, 2020
@facebook-github-bot (Contributor) commented Dec 29, 2020

💊 CI failures summary and remediations

As of commit 8aee998 (more details on the Dr. CI page):


  • 1/1 failures possibly* introduced in this PR
    • 1/1 non-CircleCI failure(s)

ci.pytorch.org: 1 failed


This comment was automatically generated by Dr. CI.

@codecov (bot) commented Dec 29, 2020

Codecov Report

Merging #49912 (8aee998) into master (f10e7aa) will increase coverage by 10.22%.
The diff coverage is 96.55%.

@@             Coverage Diff             @@
##           master   #49912       +/-   ##
===========================================
+ Coverage   70.48%   80.71%   +10.22%     
===========================================
  Files        1904     1904               
  Lines      206632   206633        +1     
===========================================
+ Hits       145653   166789    +21136     
+ Misses      60979    39844    -21135     

Comment thread test/test_nn.py Outdated
Comment thread torch/testing/_internal/common_nn.py
@anjali411 (Contributor) left a comment:

We should also update the documentation to indicate that L1Loss now supports complex numbers.

@soulitzer requested a review from anjali411 on December 30, 2020
type_map = {}
if isinstance(obj, torch.Tensor):
    assert obj.is_leaf
    t = type_map.get(obj.type(), get_gpu_type(obj.type()))
@anjali411 (Contributor) commented Jan 4, 2021:

get_gpu_type is only used in one other place, so it would be awesome if you could update that too and get rid of the get_gpu_type method:
https://github.com/pytorch/pytorch/blob/master/test/test_cuda.py#L770

Contributor:

Also, a nit: let's just change this to t = type_map.get(obj.type(), obj.type()) and change the line below to res = obj.clone().type(t).cuda().

Contributor Author:

Might as well get rid of this test in that case, since

    def test_is_tensor(self):
        for t in types:
            tensor = get_gpu_type(t)()
            self.assertTrue(torch.is_tensor(tensor))
        self.assertTrue(torch.is_tensor(torch.cuda.HalfTensor()))

becomes something like

    for t in types:
        tensor = torch.tensor(data, dtype=t).cuda()

Contributor:

Yeah right, makes sense! Let's do that.
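For reference, a hedged sketch of the helper after the suggested change (the function name to_gpu and the surrounding structure are my reconstruction from the quoted snippet, not the exact patch):

    import torch

    def to_gpu(obj, type_map=None):
        # Per the review: look up the type in type_map, falling back to
        # the tensor's own type instead of calling get_gpu_type, then
        # clone and move the result to the GPU.
        type_map = type_map or {}
        if isinstance(obj, torch.Tensor):
            assert obj.is_leaf
            t = type_map.get(obj.type(), obj.type())
            return obj.clone().type(t).cuda()
        return obj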

Comment thread torch/testing/_internal/common_nn.py Outdated
Comment thread torch/testing/_internal/common_nn.py Outdated
Comment thread aten/src/ATen/native/Loss.cpp Outdated
Comment thread aten/src/ATen/native/Loss.cpp Outdated
@soulitzer requested a review from albanD as a code owner on January 5, 2021
Comment thread test/test_nn.py Outdated
if gradOutput is None:
    gradOutput = torch.ones(())
criterion(*args).backward(gradOutput.to(input_tuple[0]))  # old
criterion(*args).backward(gradOutput.to(output))          # new
@soulitzer (Contributor Author) commented Jan 5, 2021:

For C-to-R functions, the input's dtype is not equal to the output's dtype. In general, we'd like gradOutput to be the same dtype as the output anyway.

Contributor:

Hmm, what if the output is a tuple? I think we should add a similar check for output as for input:
output_tuple = output if isinstance(output, tuple) else (output,)

Contributor Author:

Can output be a tuple though? Input might be a tuple only because, when we backward, we might want to populate the grads of multiple inputs. I'm curious which functions return tuples.

Contributor:

Not sure about torch.nn module functions, but some torch functions that come to mind are triangular_solve and qr.

Contributor:

@soulitzer did you look into this?

Contributor Author:

There are functions like nn.AdaptiveMaxPool2d that do return a tuple, so I ended up adding the check for the tuple case.
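To make the tuple case concrete, a small runnable sketch (module choice and shapes are mine, for illustration only):

    import torch
    import torch.nn as nn

    # nn.AdaptiveMaxPool2d with return_indices=True returns a tuple,
    # which is why the harness needs an output_tuple check.
    pool = nn.AdaptiveMaxPool2d((2, 2), return_indices=True)
    inp = torch.randn(1, 1, 4, 4, requires_grad=True)
    output = pool(inp)

    output_tuple = output if isinstance(output, tuple) else (output,)
    # Match gradOutput's dtype/device to the first output, as discussed:
    grad_output = torch.ones(()).to(output_tuple[0])
    output_tuple[0].backward(grad_output.expand_as(output_tuple[0]))
    print(inp.grad.shape)  # torch.Size([1, 1, 4, 4])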

auto norm = reduction == Reduction::Mean ? grad_output / input.numel() : grad_output;
// old:
at::sub_out(grad_input, input, target).sign_().mul_(norm);
return grad_input;
// new:
return at::sub_out(grad_input, input, target).sgn_().mul_(norm);
Contributor:

looks good!
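For context on the sign_ to sgn_ change: for complex z, the gradient of |z| is sgn(z) = z/|z|, and sgn reduces to sign on real inputs. A small sanity check of the formula above (my example):

    import torch

    z = torch.tensor([3 + 4j, -2 + 0j], requires_grad=True)
    z.abs().sum().backward()
    # The backward of |z| multiplies grad_output by sgn(z); for the
    # real entry -2+0j this is just sign(-2) = -1.
    print(z.grad)                 # tensor([ 0.6000+0.8000j, -1.0000+0.0000j])
    print(torch.sgn(z.detach()))  # same values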

@anjali411 (Contributor) left a comment:

Thanks @soulitzer, the PR changes look good to me overall. Let's rebase on master and check the CI tests.

Could you also remove get_gpu_type and test_is_tensor?

@soulitzer requested a review from gchanan on January 7, 2021
target_fn=lambda: torch.randn((2, 3, 4), requires_grad=True),
reference_fn=lambda i, t, _: 1. / i.numel() *
    sum((a - b).abs().sum() for a, b in zip(i, t)),
check_complex=True,
Contributor:

Is the target in this case supposed to be complex or real? The math makes it look like it should be complex, but the target created is real?

Contributor Author:

target_fn is only used by test_jit, which basically just tries to see whether the scripted module behaves the same as the Python module. I don't see it handling check_bfloat16 or check_half either.

Comment thread aten/src/ATen/native/Loss.cpp Outdated
// old:
} else {
  at::sub_out(result, input, target).abs_();
// new:
Tensor& l1_loss_out(Tensor& result, const Tensor& input, const Tensor& target, int64_t reduction) {
  auto diff = at::sub_out(result, input, target);
Contributor:

Does this cause warnings? We usually warn when the result is resized (unless it started out 0-sized).

Contributor Author:

Yep

Contributor Author:

Should be fixed in the latest update. When the shape of result matches the post-reduction shape, a warning no longer appears.
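For reference, the warning in question is PyTorch's generic out= resize warning; a small sketch of that behavior (shown with torch.add, since the Python-level functional l1_loss has no out= parameter):

    import torch
    import warnings

    x, y = torch.randn(3), torch.randn(3)

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        torch.add(x, y, out=torch.empty(5))  # non-empty, wrong shape: warns
        torch.add(x, y, out=torch.empty(0))  # 0-sized out: resized silently
    print(len(caught))  # 1 -- only the first call warned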

Contributor Author:

@gchanan do you want to take another look at this?

@facebook-github-bot (Contributor) left a comment:

@soulitzer has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@soulitzer (Contributor Author) commented Jan 11, 2021:

Fixes the l1_loss case for #50382.

@facebook-github-bot (Contributor) left a comment:

@soulitzer has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@soulitzer (Contributor Author):

@anjali411 @albanD Made a non-trivial change to the code. In the latest commit, everything is now routed through the out variant instead of having two separate code paths.

Comment thread aten/src/ATen/native/Loss.cpp Outdated
  Tensor result = at::empty({0}, input.options().dtype(float_type));
  return at::l1_loss_out(result, input, target, reduction);
}
Tensor result = at::empty({0}, input.options());
Contributor:

This should go in an else branch:

Tensor result;
if (input.is_complex()) {
  ...
} else {
  result = at::empty({0}, input.options());
}

Contributor:

Oh wait, I just saw you have a return statement inside the if branch. I still think it might be cleaner to change it to an if-else statement with a common return.

Contributor Author:

Hmm, then you'd have to declare result before the if-else; otherwise it would go out of scope by the time you try to return it.

Contributor Author:

is that necessarily cleaner? :P I feel like it could be good either way.

Contributor Author:

Wait, if you 'declare' it before, you are technically doing an extra default initialization and then copy-assigning, instead of simply copy-initializing.

Collaborator:

A nice trick for this that @ezyang showed me is to use a lambda:

const auto float_type = [&]() {
  if (input.is_complex()) {
    return c10::toValueType(input.scalar_type());
  } else {
    return input.scalar_type();
  }
}();
Tensor result = at::empty({0}, input.options().dtype(float_type));
return at::l1_loss_out(result, input, target, reduction);

But beyond that, what happens if you call c10::toValueType on a non-complex dtype? Is it just returned as-is? If so, you don't need branching in this function at all!

Contributor Author:

Ahh, you're right: c10::toValueType does handle non-complex dtypes by just returning them as-is.
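A quick way to see the equivalent behavior from Python (an illustrative analogue of c10::toValueType, not the C++ API itself; it relies on Tensor.real returning the tensor unchanged for real dtypes):

    import torch

    def value_type(dtype):
        # Complex dtypes map to their real value type; real dtypes come
        # back unchanged, mirroring c10::toValueType.
        return torch.empty(0, dtype=dtype).real.dtype

    print(value_type(torch.complex64))   # torch.float32
    print(value_type(torch.complex128))  # torch.float64
    print(value_type(torch.float32))     # torch.float32 (as-is)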

@albanD (Collaborator) left a comment:

LGTM, just a small potential simplification of the composite l1_loss.

@facebook-github-bot (Contributor) left a comment:

@soulitzer has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.

@anjali411 (Contributor):

@soulitzer thanks again! This PR looks good to me with the current changes. Is there anything blocking it?

@soulitzer (Contributor Author):

@anjali411 I actually almost landed this yesterday, but held off due to the CI issues. One thing I wanted to check again was whether there were any updates to gen_variable_type that caused conflicts the last time I pulled, but it doesn't seem like there are, so it's on its way now!

@facebook-github-bot (Contributor):

@soulitzer merged this pull request in 6e3e570.

@soulitzer deleted the l1-complex-support branch on April 14, 2021
@AhmedBoin:

For the second time: you can implement your own

def complex_mse_loss(output, target):
    # Note: this returns a complex scalar; a real variant would use (output - target).abs() ** 2.
    return (0.5 * (output - target) ** 2).mean(dtype=torch.complex64)

You can also implement layers or any custom utilities you need:

class CLinear(nn.Module):
    def __init__(self, size_in, size_out):
        super().__init__()
        self.weights = nn.Parameter(torch.randn(size_in, size_out, dtype=torch.complex64))
        self.bias = nn.Parameter(torch.zeros(size_out, dtype=torch.complex64))

    def forward(self, x):
        if x.dtype != torch.complex64:
            x = x.type(torch.complex64)
        return x @ self.weights + self.bias
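A hedged usage sketch for the snippets above (torch and torch.nn assumed imported; shapes arbitrary). Note the complex-valued loss needs a real reduction such as .abs() before backward(), since autograd requires a real scalar loss:

    layer = CLinear(4, 2)
    x = torch.randn(3, 4)            # real input is cast to complex64
    out = layer(x)                   # shape (3, 2), dtype complex64

    loss = complex_mse_loss(out, torch.zeros_like(out)).abs()
    loss.backward()
    print(layer.weights.grad.dtype)  # torch.complex64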

laurentdupin pushed a commit to laurentdupin/pytorch that referenced this pull request Apr 24, 2026
Summary:
Building on top of the work of anjali411 (pytorch#46640)

Things added in this PR:
1. Modify backward and double-backward formulas
2. Add complex support for `new module tests` and criterion tests (and add complex tests for L1)
3. Modify some existing tests to support complex

Pull Request resolved: pytorch#49912

Reviewed By: zhangguanheng66

Differential Revision: D25853036

Pulled By: soulitzer

fbshipit-source-id: df619f1b71c450ab2818eb17804e0c55990aa8ad

Labels

cla signed, complex_autograd, Merged, module: complex (Related to complex number support in PyTorch), module: nn (Related to torch.nn)

6 participants