
Dropout with prob == 0 doesn't validate consistently #1799

@csarofeen

Description

🐛 Describe the bug

The following script doesn't validate consistently on TOT (tip of tree). It seems we may still be dropping out some values even though the probability is 0. I think this may be caused by https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/ops/composite.cpp#L31, where the comparison should perhaps be `le` rather than `lt`.
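A minimal sketch of the suspected boundary behavior, assuming the dropout mask is built by comparing uniform draws against `1 - prob` (the `dropout_mask` helper and its `strict` flag are hypothetical illustrations, not the actual composite.cpp code): with `prob == 0` a draw of exactly 1.0 is dropped under a strict `<` but kept under `<=`.

```python
import torch

def dropout_mask(rand_vals: torch.Tensor, prob: float, strict: bool) -> torch.Tensor:
    # Hypothetical mask: keep an element when its uniform draw falls
    # below (strict) or at-or-below (non-strict) the keep probability.
    keep_prob = 1.0 - prob
    return rand_vals < keep_prob if strict else rand_vals <= keep_prob

# Include the boundary draw of exactly 1.0; with prob == 0 nothing should drop.
rand_vals = torch.tensor([0.0, 0.5, 1.0])
lt_mask = dropout_mask(rand_vals, prob=0.0, strict=True)   # drops the 1.0 draw
le_mask = dropout_mask(rand_vals, prob=0.0, strict=False)  # keeps all elements
```

If the RNG can produce draws at the upper end of the interval, only the `<=` variant preserves all values at `prob == 0`.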


import torch
import torch.nn.functional as F

def composite_definition(
    input1: torch.Tensor,
    input2: torch.Tensor,
    weight: torch.Tensor,
    bias1: torch.Tensor,
    bias2: torch.Tensor,
    normalization_axis: int,
    dropout_prob: float,
) -> torch.Tensor:
    bias1_out = input1 + bias1
    dropout_out = F.dropout(bias1_out, dropout_prob, True)
    norm_input = dropout_out + input2
    norm_output = F.layer_norm(norm_input, (input1.size(normalization_axis),), weight, bias2)
    return norm_output

# Setup initial tensors and parameters
input_size = [64, 128, 1024]
device = "cuda"
dtype = torch.float32

# Create sample inputs
input1 = torch.randn(*input_size, device=device, dtype=dtype, requires_grad=True)
input2 = torch.rand_like(input1).requires_grad_()
 
# Precompute a grad output tensor, for this example it's the same size as the inputs
grad_output = torch.rand_like(input1)
 
# Randomly initialize the model parameters
weight = torch.nn.Parameter(torch.randn(input_size[2], dtype=dtype, device=device))
bias1 = torch.nn.Parameter(torch.randn(input_size[2], dtype=dtype, device=device))
bias2 = torch.nn.Parameter(torch.randn(input_size[2], dtype=dtype, device=device))

parameters = [input1, input2, weight, bias1, bias2]
ref_composite = composite_definition(input1, input2, weight, bias1, bias2, normalization_axis=2, dropout_prob=0.0)

scripted_composite_definition = torch.jit.script(composite_definition)

for i in range(20):
    scripted = scripted_composite_definition(input1, input2, weight, bias1, bias2, normalization_axis=2, dropout_prob=0.0)
    print("output abs max {}".format((ref_composite - scripted).abs().max()))
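For reference, eager-mode `F.dropout` with `p == 0` should act as the identity (the scale factor `1/(1-p)` is 1 and nothing is masked), which is why the script expects a zero difference on every iteration. A quick CPU sanity check of that baseline:

```python
import torch
import torch.nn.functional as F

x = torch.randn(64, 128)
out = F.dropout(x, p=0.0, training=True)
# With p == 0, eager dropout should leave the input unchanged.
diff = (out - x).abs().max().item()
```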

Versions

TOT (tip of tree)
