### 🐛 Describe the bug

The following script doesn't validate consistently on TOT (top of tree). It seems we may still be dropping out some values even though the dropout probability is 0. I think this may be caused by https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/ops/composite.cpp#L31, which maybe should be `le` rather than `lt`.

```python
import torch
import torch.nn.functional as F
def composite_definition(
    input1: torch.Tensor,
    input2: torch.Tensor,
    weight: torch.Tensor,
    bias1: torch.Tensor,
    bias2: torch.Tensor,
    normalization_axis: int,
    dropout_prob: float,
) -> torch.Tensor:
    bias1_out = input1 + bias1
    # Apply dropout with the given probability (0.0 here, so this should be a no-op)
    dropout_out = F.dropout(bias1_out, dropout_prob, True)
    norm_input = dropout_out + input2
    norm_output = F.layer_norm(
        norm_input, (input1.size(normalization_axis),), weight, bias2
    )
    return norm_output
# Set up initial tensors and parameters
input_size = [64, 128, 1024]
device = "cuda"
dtype = torch.float32
# Create sample inputs
input1 = torch.randn(*input_size, device=device, dtype=dtype, requires_grad=True)
input2 = torch.rand_like(input1).requires_grad_()
# Precompute a grad output tensor; for this example it's the same size as the inputs
grad_output = torch.rand_like(input1)
# Randomly initialize the model parameters
weight = torch.nn.Parameter(torch.randn(input_size[2], dtype=dtype, device=device))
bias1 = torch.nn.Parameter(torch.randn(input_size[2], dtype=dtype, device=device))
bias2 = torch.nn.Parameter(torch.randn(input_size[2], dtype=dtype, device=device))
parameters = [input1, input2, weight, bias1, bias2]
# Eager-mode reference output
ref_composite = composite_definition(
    input1, input2, weight, bias1, bias2, normalization_axis=2, dropout_prob=0.0
)
# Script the function; with dropout_prob == 0.0 every iteration should match the eager reference exactly
scripted_composite_definition = torch.jit.script(composite_definition)
for i in range(20):
    scripted = scripted_composite_definition(
        input1, input2, weight, bias1, bias2, normalization_axis=2, dropout_prob=0.0
    )
    print("output abs max {}".format((ref_composite - scripted).abs().max()))
```
### Versions

TOT