
torch.sparse.sum backward fails when reducing over dense dimensions. #99147

@nikitaved


🐛 Describe the bug

As per title. To reproduce:

In [1]: import torch

In [2]: def make_args(x, dim):
   ...:     x_g = x.clone().requires_grad_(True)
   ...:     y_g = torch.sparse.sum(x_g, dim=dim)
   ...:     return x_g, y_g
   ...:

In [3]: idx = torch.tensor([[0, 0, 0], [0, 1, 2]])

In [4]: val = torch.rand(3, 5, 5)

In [5]: x = torch.sparse_coo_tensor(idx, val, (5, 5, 5, 5))

In [6]: x_g, y_g = make_args(x, -1)

In [7]: torch.autograd.grad(y_g, x_g, torch.ones(*y_g.shape).to_sparse(y_g.sparse_dim()))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [7], in <cell line: 1>()
----> 1 torch.autograd.grad(y_g, x_g, torch.ones(*y_g.shape).to_sparse(y_g.sparse_dim()))

File ~/git/Quansight/pytorch/torch/autograd/__init__.py:319, in grad(outputs, inputs, grad_outputs, retain_graph, create_graph, only_inputs, allow_unused, is_grads_batched, materialize_grads)
    317     result = _vmap_internals._vmap(vjp, 0, 0, allow_none_pass_through=True)(grad_outputs_)
    318 else:
--> 319     result = Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
    320         t_outputs, grad_outputs_, retain_graph, create_graph, t_inputs,
    321         allow_unused, accumulate_grad=False)
    322 if materialize_grads:
    323     result = tuple(output if output is not None else torch.zeros_like(input, requires_grad=True)
    324                    for (output, input) in zip(result, t_inputs))

RuntimeError: The expanded size of the tensor (3) must match the existing size (25) at non-singleton dimension 0.  Target sizes: [3, 5, 5].  Tensor sizes: [25, 5, 1]
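Until this is fixed, one possible workaround (my suggestion, not part of the original report) is to densify the tensor before reducing over the dense dimension, so autograd takes the ordinary dense path:

```python
import torch

idx = torch.tensor([[0, 0, 0], [0, 1, 2]])
val = torch.rand(3, 5, 5)
x = torch.sparse_coo_tensor(idx, val, (5, 5, 5, 5))

# Dense clone as the leaf: the same reduction as torch.sparse.sum(x, dim=-1),
# but expressed as a dense Tensor.sum, which has a working backward.
x_g = x.to_dense().clone().requires_grad_(True)
y_g = x_g.sum(dim=-1)

grad, = torch.autograd.grad(y_g, x_g, torch.ones_like(y_g))
print(grad.shape)  # torch.Size([5, 5, 5, 5])
```

This trades away the memory savings of the sparse layout, so it is only viable when the dense shape fits in memory.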

No such issues arise when reducing over sparse dimensions only.
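For contrast, here is the same setup reducing over a sparse dimension instead (`dim=0`; the tensor has `sparse_dim=2`, `dense_dim=2`), which runs without error per the report:

```python
import torch

idx = torch.tensor([[0, 0, 0], [0, 1, 2]])
val = torch.rand(3, 5, 5)
x = torch.sparse_coo_tensor(idx, val, (5, 5, 5, 5))  # sparse_dim=2, dense_dim=2

x_g = x.clone().requires_grad_(True)
y_g = torch.sparse.sum(x_g, dim=0)  # reduce over a sparse dimension: works

grad, = torch.autograd.grad(
    y_g, x_g, torch.ones(*y_g.shape).to_sparse(y_g.sparse_dim()))
print(grad.shape)  # torch.Size([5, 5, 5, 5])
```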

Versions

Current master.

cc @alexsamardzic @pearu @cpuhrsch @amjames @bhosmer @ezyang @albanD @zou3519 @gqchen @soulitzer @lezcano @Varal7


Labels

module: autograd (related to torch.autograd, and the autograd engine in general), module: sparse (related to torch.sparse), triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
