
OP lowering for _prelu_kernel_backward #5695


Description

@Edelbert

🚀 Feature

Hello! I found out that there is no XLA lowering for the backward pass of nn.PReLU: the forward op _prelu_kernel is supported, but its backward counterpart _prelu_kernel_backward is not. For the sake of example, I added PReLU layers to the default MNIST example.

import torch
import torch.nn as nn
import torch.nn.functional as F

class MNIST(nn.Module):

  def __init__(self):
    super(MNIST, self).__init__()
    self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
    self.prelu1 = nn.PReLU(10)
    self.bn1 = nn.BatchNorm2d(10)
    self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
    self.prelu2 = nn.PReLU(20)
    self.bn2 = nn.BatchNorm2d(20)
    self.fc1 = nn.Linear(320, 50)
    self.prelu3 = nn.PReLU(50)
    self.fc2 = nn.Linear(50, 10)

  def forward(self, x):
    x = self.prelu1(F.max_pool2d(self.conv1(x), 2))
    x = self.bn1(x)
    x = self.prelu2(F.max_pool2d(self.conv2(x), 2))
    x = self.bn2(x)
    x = torch.flatten(x, 1)
    x = self.prelu3(self.fc1(x))
    x = self.fc2(x)
    return F.log_softmax(x, dim=1)
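
For context, the missing op computes two gradients. Below is a minimal sketch of the math a lowering would need to implement, assuming the weight has already been broadcast to the input's shape (the function name is illustrative; reducing grad_weight back to the weight's shape happens outside the kernel, in autograd):

import torch

def _prelu_kernel_backward_sketch(grad_output, x, weight):
  # Gradient wrt input: grad_output where x > 0, weight * grad_output elsewhere
  grad_input = torch.where(x > 0, grad_output, weight * grad_output)
  # Gradient wrt weight: zero where x > 0, x * grad_output elsewhere
  grad_weight = torch.where(x > 0, torch.zeros_like(x), x * grad_output)
  return grad_input, grad_weight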

Debug output

Counter: aten::_prelu_kernel_backward
  Value: 305

pt-xla-profiler: TransferFromServerTime too frequent: 311 counts during 105 steps
pt-xla-profiler: Op(s) not lowered: aten::_prelu_kernel_backward,  Please open a GitHub issue with the above op lowering requests.
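
For what it's worth, the counter can also be checked outside the profiler with torch_xla's metrics API. A minimal sketch, assuming torch/torch-xla 2.1 (the model and shapes are illustrative):

import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.metrics as met

device = xm.xla_device()
model = nn.Sequential(nn.Linear(8, 8), nn.PReLU(8)).to(device)

loss = model(torch.randn(4, 8, device=device)).sum()
loss.backward()  # dispatches aten::_prelu_kernel_backward
xm.mark_step()   # materialize the pending graph

# Non-zero while the op has no XLA lowering (CPU fallback)
print(met.counter_value('aten::_prelu_kernel_backward'))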

Package versions:

torch                    2.1.0
torch-xla                2.1.0
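
Until a lowering lands, one possible workaround is to express PReLU with ops that are already lowered, letting autograd derive the backward pass. A minimal sketch (this module is illustrative, not part of torch):

import torch
import torch.nn as nn

class ManualPReLU(nn.Module):

  def __init__(self, num_parameters):
    super().__init__()
    # nn.PReLU initializes its per-channel slope to 0.25
    self.weight = nn.Parameter(torch.full((num_parameters,), 0.25))

  def forward(self, x):
    # Broadcast the per-channel weight over trailing dims, like nn.PReLU
    w = self.weight.view(1, -1, *([1] * (x.dim() - 2)))
    return torch.where(x > 0, x, w * x)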

Metadata

Labels

lowering (ATen Operation lowering)
