Skip to content

[Bug] Outputs of torch.abs abnormally mismatch on GPU and CPU when applying commutative law of multiplication #382

@Azyka

Description

@Azyka

Describe the bug
Multiplication is commutative, so reordering the operands of torch.mul should never change the result: a*b*c == a*c*b.
However, when the multiplications that feed torch.abs are reordered in this way, the outputs of torch.abs no longer match.
This mismatch is seen on both cuda and cpu.

To Reproduce
Repro script:

import numpy as np
from numpy import testing
import torch

# Device on which the constant tensors and all model inputs are placed.
DEVICE='cuda'

# Scalar int8 constants that both models store as attributes and multiply
# with the input.
# NOTE(review): 4 * 6 = 24, and the reported outputs (16 / 40 vs. -16 / -40)
# are consistent with inputs of 10 / 9, whose full products 240 / 216 exceed
# the int8 maximum of 127 and wrap to -16 / -40. The mismatch presumably
# depends on whether abs is applied before or after that wrap-around —
# confirm.
p0 = torch.tensor(4, device=DEVICE, dtype=torch.int8)
p1 = torch.tensor(6, device=DEVICE, dtype=torch.int8)

class Model0(torch.nn.Module):
    """Variant that multiplies the two constants first, then the input.

    Computes ``abs(x * (v5_0 * v3_0))`` and additionally returns
    ``(x * (v5_0 * v3_0)) * (v5_0 * v3_0)`` as a second output.
    """
    def __init__(self):
        super().__init__()
        # Keep references to the module-level scalar tensors as plain
        # attributes.
        self.v3_0 = p0
        self.v5_0 = p1

    def forward(self, *args):
        """Return ``(abs_1, mul_2)`` for the single input ``args[0]``."""
        v3_0 = self.v3_0
        v5_0 = self.v5_0
        # Constant-by-constant product is formed first ...
        mul = torch.mul(v5_0, v3_0)
        # ... and only then multiplied with the input.
        mul_1 = torch.mul(args[0], mul)
        abs_1 = torch.abs(mul_1)
        # Second output reuses both intermediates; it is returned but not
        # compared by the script's output_name_dict.
        mul_2 = torch.mul(mul_1, mul)
        return (abs_1, mul_2)

model_0 = Model0()
# Names later zipped, in order, with Model0's outputs:
# abs_1 -> 'v5_0', mul_2 -> 'v4_0'.
output_names_0 = ['v5_0', 'v4_0']

class Model1(torch.nn.Module):
    """Variant that multiplies the input by each constant in turn.

    Computes ``abs(v5_0 * (v3_0 * x))`` — the same product as Model0's first
    output, with the multiplications applied in a different order.
    """
    def __init__(self):
        super().__init__()
        # Keep references to the module-level scalar tensors as plain
        # attributes.
        self.v3_0 = p0
        self.v5_0 = p1

    def forward(self, *args):
        """Return a 1-tuple ``(abs_1,)`` for the single input ``args[0]``."""
        v3_0 = self.v3_0
        v5_0 = self.v5_0
        # Input is multiplied by the first constant ...
        mul = torch.mul(v3_0, args[0])
        # ... and the result by the second constant.
        mul_1 = torch.mul(v5_0, mul)
        abs_1 = torch.abs(mul_1)
        return (abs_1, )

model_1 = Model1()
# Name later zipped with Model1's single output: abs_1 -> 'v5_0'.
output_names_1 = ['v5_0',]

# Draw a (53, 33) array from N(mean=10, std=0.1) and cast it to int8.
# NOTE(review): the float -> int8 cast presumably leaves only the values
# 9 and 10 in practice (the reported outputs 40 / 16 match 216 / 240 wrapped
# into int8) — confirm.
data = np.random.normal(10, 0.1, size=(53, 33)).astype(np.int8)
input_data_0 = [data]

# Compile Model0 with the hidet backend and run it on DEVICE.
optmodel_0 = torch.compile(model_0, fullgraph=True, backend='hidet', mode=None)
model_out_0 = optmodel_0(*[torch.from_numpy(v).to(DEVICE) for v in input_data_0])
# Normalise the result to a list of detached tensors, whether the model
# returned a tuple or a single tensor.
model_out_0 = [v.to(DEVICE).detach() for v in model_out_0] if isinstance(model_out_0, tuple) else [model_out_0.to(DEVICE).detach()]
# Convert every output to a host NumPy array, resolving conjugation first
# when the tensor reports is_conj().
model_out_0 = [v.cpu().resolve_conj().numpy() if v.is_conj() else v.cpu().numpy() for v in model_out_0]
# Key each output array by its name.
output_0 = dict(zip(output_names_0, model_out_0))

# Model1 receives the very same input array as Model0.
input_data_1 = [data]

# Same compile / run / convert sequence for Model1.
optmodel_1 = torch.compile(model_1, fullgraph=True, backend='hidet', mode=None)
model_out_1 = optmodel_1(*[torch.from_numpy(v).to(DEVICE) for v in input_data_1])
model_out_1 = [v.to(DEVICE).detach() for v in model_out_1] if isinstance(model_out_1, tuple) else [model_out_1.to(DEVICE).detach()]
model_out_1 = [v.cpu().resolve_conj().numpy() if v.is_conj() else v.cpu().numpy() for v in model_out_1]
output_1 = dict(zip(output_names_1, model_out_1))
# Maps an output name of Model0 to the output name of Model1 it is compared
# against; only the torch.abs output ('v5_0') is compared.
output_name_dict = {'v5_0': 'v5_0'}

# Compare the hidet-compiled outputs of the two models and report the result.
print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        # NOTE(review): with rtol=1 and the default atol=0 the check is
        # |actual - desired| <= |desired|, which is very loose; it still
        # fails when the two outputs have opposite signs.
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1], rtol=1, err_msg=f'at {tensor_name_0}, {tensor_name_1}')
    print("hidet does not trigger assertion")
except AssertionError as e:
    # A mismatch is reported rather than raised so the script can continue
    # with the eager-mode comparison.
    print("hidet triggers assertion")
    print(e)
print('=========================')

# Re-run both models without torch.compile (eager mode) on the same inputs.
# The names model_out_* and output_* are rebound, replacing the hidet results.
model_out_0 = model_0(*[torch.from_numpy(v).to(DEVICE) for v in input_data_0])
# Normalise to a list of detached tensors, then convert to host NumPy arrays.
model_out_0 = [v.to(DEVICE).detach() for v in model_out_0] if isinstance(model_out_0, tuple) else [model_out_0.to(DEVICE).detach()]
model_out_0 = [v.cpu().resolve_conj().numpy() if v.is_conj() else v.cpu().numpy() for v in model_out_0]
output_0 = dict(zip(output_names_0, model_out_0))

model_out_1 = model_1(*[torch.from_numpy(v).to(DEVICE) for v in input_data_1])
model_out_1 = [v.to(DEVICE).detach() for v in model_out_1] if isinstance(model_out_1, tuple) else [model_out_1.to(DEVICE).detach()]
model_out_1 = [v.cpu().resolve_conj().numpy() if v.is_conj() else v.cpu().numpy() for v in model_out_1]
output_1 = dict(zip(output_names_1, model_out_1))

# Compare the eager-mode outputs of the two models and report the result.
print('=========================')
try:
    for tensor_name_0, tensor_name_1 in output_name_dict.items():
        # NOTE(review): with rtol=1 and the default atol=0 the check is
        # |actual - desired| <= |desired|, which is very loose; it still
        # fails when the two outputs have opposite signs.
        testing.assert_allclose(output_0[tensor_name_0], output_1[tensor_name_1], rtol=1, err_msg=f'at {tensor_name_0}, {tensor_name_1}')
    print("torch_eager does not trigger assertion")
except AssertionError as e:
    # A mismatch is reported rather than raised.
    print("torch_eager triggers assertion")
    print(e)
print('=========================')

Output:

=========================
hidet triggers assertion

Not equal to tolerance rtol=1, atol=0
at v5_0, v5_0
Mismatched elements: 1749 / 1749 (100%)
Max absolute difference: 80
Max relative difference: 2.
 x: array([[16, 40, 40, ..., 16, 16, 16],
       [16, 40, 16, ..., 40, 40, 40],
       [40, 16, 40, ..., 40, 16, 40],...
 y: array([[-16, -40, -40, ..., -16, -16, -16],
       [-16, -40, -16, ..., -40, -40, -40],
       [-40, -16, -40, ..., -40, -16, -40],...
=========================
=========================
torch_eager does not trigger assertion
=========================

Expected behavior
The output of torch.abs is expected to be the same for the same input.

Environment

  • OS: Ubuntu 22.04.3 LTS (x86_64)
  • GPU: RTX 1660
  • NVIDIA GPU Driver: 525.147.05
  • Hidet Version: 0.3.0
  • PyTorch Version: 2.1.0+cu118

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug: Something isn't working

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions