🚀 Feature
Support basic linear algebra for complex numbers.
Motivation
I talked with @sw005320 about https://github.com/nttcslab-sp/dnn_wpe and it turns out that the matrix inversion implemented with real numbers is unstable. In a beamforming example, @Emrys365 observed a difference of 5 dB in signal-to-distortion ratio (SDR) after he replaced the inversion with numpy code (torch: 5 dB, numpy: 10 dB).
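For context, emulating a complex inverse with real arithmetic presumably means inverting the real block representation of the matrix (this is a sketch of the general trick, not the actual dnn_wpe code):

```python
import numpy as np

rng = np.random.default_rng(0)
n = 4
A = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))

# Real block representation of A = X + jY: [[X, -Y], [Y, X]]
R = np.block([[A.real, -A.imag],
              [A.imag,  A.real]])
R_inv = np.linalg.inv(R)

# If A^{-1} = P + jQ, then R^{-1} = [[P, -Q], [Q, P]],
# so the complex inverse can be read off the first block column.
A_inv_via_real = R_inv[:n, :n] + 1j * R_inv[n:, :n]

assert np.allclose(A_inv_via_real, np.linalg.inv(A))
```

Mathematically this is equivalent to a native complex inverse, so the instability we observed is presumably an implementation/conditioning issue rather than a flaw of the representation itself.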
I tried `torch.inverse` and `torch.solve`, and interestingly they already work in `1.6.0.dev20200623+cpu` (not mentioned in pytorch/pytorch#33152).
Is it possible to support `torch.matmul` and some other linear algebra functions?
I also tried to call backward after `torch.solve`, and the code fails with an exception saying that `matmul` is not implemented.
Does someone know how the gradient is defined in torch for complex numbers?
Is it `grad_real + j grad_imag` or `grad_real - j grad_imag`?
And how can I add/fix the gradient when I find a broken implementation?
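A small script could probe the convention empirically on a build where complex autograd works, by taking a real-valued loss whose real/imaginary partial derivatives are known in closed form:

```python
import torch

# Leaf variable z = a + bj with a = 3, b = 4.
z = torch.tensor(3.0 + 4.0j, requires_grad=True)

# Real-valued loss L = |z|^2 = a^2 + b^2, so dL/da = 2a, dL/db = 2b.
loss = (z.conj() * z).real
loss.backward()

# If the convention is grad_real + j * grad_imag, this prints 6+8j;
# if it is grad_real - j * grad_imag, it prints 6-8j.
print(z.grad)
```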
Pitch
- `torch.solve` and `torch.inverse` miss a backward function
- `torch.matmul` does not work
Alternatives
Additional context
Currently, I am considering jumping between `pytorch_complex` and `torch.autograd.Function`:
def hermite(a):
    # conjugate transpose of the last two dimensions
    return a.transpose(-2, -1).conj()

def matmul(t1, t2):
    # complex matmul built from four real matmuls:
    # (r1 + j i1)(r2 + j i2) = (r1 r2 - i1 i2) + j (r1 i2 + i1 r2)
    real1, imag1 = t1.real, t1.imag
    real2, imag2 = t2.real, t2.imag
    o_real = torch.matmul(real1, real2) - torch.matmul(imag1, imag2)
    o_imag = torch.matmul(real1, imag2) + torch.matmul(imag1, real2)
    return o_real + 1j * o_imag
class Solve(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, b):
        # x solves A x = b
        x, _ = torch.solve(b, A)
        ctx.save_for_backward(A, x)
        return x

    @staticmethod
    def backward(ctx, grad_output):
        A, x = ctx.saved_tensors
        # adjoint of the solve: gb = A^{-H} g, gA = -gb x^H
        gb, _ = torch.solve(grad_output, hermite(A))
        gA = -matmul(gb, hermite(x))
        return gA, gb
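As a sanity check of those backward formulas (`gb = A^{-H} g`, `gA = -gb x^H`), here is a numpy finite-difference test, assuming the `grad_real + j grad_imag` convention for the gradient of a real-valued loss:

```python
import numpy as np

rng = np.random.default_rng(1)
n = 3
A = rng.standard_normal((n, n)) + 1j * rng.standard_normal((n, n))
b = rng.standard_normal((n, 1)) + 1j * rng.standard_normal((n, 1))
c = rng.standard_normal((n, 1)) + 1j * rng.standard_normal((n, 1))

def loss(A, b):
    # real-valued scalar loss L = Re(c^H A^{-1} b)
    return (c.conj().T @ np.linalg.solve(A, b)).real.item()

# For this loss the upstream gradient w.r.t. x = A^{-1} b is c
# (in the grad_real + j * grad_imag convention); apply the adjoint formulas:
x = np.linalg.solve(A, b)
gb = np.linalg.solve(A.conj().T, c)   # gb = A^{-H} grad_output
gA = -gb @ x.conj().T                 # gA = -gb x^H

eps = 1e-6

# central finite differences on one real and one imaginary entry of b
db = np.zeros_like(b); db[0] = eps
assert abs((loss(A, b + db) - loss(A, b - db)) / (2 * eps) - gb[0].real) < 1e-5
db = np.zeros_like(b); db[0] = 1j * eps
assert abs((loss(A, b + db) - loss(A, b - db)) / (2 * eps) - gb[0].imag) < 1e-5

# and on one real entry of A
dA = np.zeros_like(A); dA[0, 0] = eps
assert abs((loss(A + dA, b) - loss(A - dA, b)) / (2 * eps) - gA[0, 0].real) < 1e-5
```

If the convention were `grad_real - j grad_imag` instead, the imaginary-part check would fail with the opposite sign.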