
Numerical inconsistencies on GPU when computing A.T@B vs (B.T@A).T #67185

@wjablonski-work

Description


🐛 Bug

I'm getting inconsistent results when computing a.T@b vs (b.T@a).T on GPU, even though the two expressions should be mathematically identical.

To Reproduce

The example only reproduces with specific a and b matrices, which can be downloaded here:
https://ufile.io/f/odd2s
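In case the link above expires, here is a hypothetical stand-in using synthetic matrices (shapes scaled down from 7000 to 2000 rows to keep the matmuls quick). It verifies the mathematical identity on CPU, though it won't necessarily trigger the GPU discrepancy, which seems to depend on the specific data:

```python
import numpy as np

# Hypothetical substitute for the downloaded a.npy / b.npy (the originals
# have 7000 rows). On CPU, a.T@b and (b.T@a).T should agree to within
# float32 rounding.
rng = np.random.default_rng(0)
a = rng.standard_normal((2000, 6)).astype(np.float32)
b = rng.standard_normal((2000, 2000)).astype(np.float32)
b = b @ b.T  # mirror the b = b@b.T step from the report

m1 = a.T @ b
m2 = (b.T @ a).T
rel = np.mean((m1 - m2) ** 2) / np.mean(m1 ** 2)
# rel stays at float32 rounding level on CPU
```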

import torch
import numpy as np
a = torch.from_numpy(np.load("a.npy"))
b = torch.from_numpy(np.load("b.npy"))
b = b@b.T

def norm(m):
  return torch.mean(m**2)

def check(m1, m2):
  print(f"m1 norm: {norm(m1)} m2 norm: {norm(m2)}, difference norm: {norm(m1-m2)}")
  # print(m1-m2)

print(a.shape, b.shape)

# on cpu everything is more or less OK (error is on the order of float precision)
check(a.T@b, (b.T@a).T)

# subset of b columns - same
check(a.T@b[:,:10], (b[:,:10].T@a).T)

# on cuda however...
a = a.cuda()
b = b.cuda()
check(a.T@b, (b.T@a).T)

# also on cuda, this time it's OK
check(a.T@b[:,:10], (b[:,:10].T@a).T)

output:

torch.Size([7000, 6]) torch.Size([7000, 7000])
m1 norm: 1.5420925003586944e+23 m2 norm: 1.5420961032383963e+23, difference norm: 3.177730552941773e+16
m1 norm: 8.02644225740919e+22 m2 norm: 8.026431448770084e+22, difference norm: 627637878784.0
m1 norm: 1.5420941216545603e+23 m2 norm: 1.686173280933397e+23, difference norm: 8.479770993020602e+22
m1 norm: 1.1493808646518008e+23 m2 norm: 1.1493808646518008e+23, difference norm: 0.0

Expected behavior

When running on CPU, the difference between a.T@b and (b.T@a).T is small enough to be accounted for by floating-point precision. On GPU, however, the error is huge. It disappears when I slice off only a small fraction of the b matrix.
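For context on the error scale: on Ampere GPUs such as the A100s listed below, float32 matmuls go through TF32 by default, which keeps only 10 explicit mantissa bits instead of float32's 23. A rough NumPy sketch (an illustration only; it simulates the precision loss by truncating mantissa bits, whereas the hardware rounds):

```python
import numpy as np

# TF32 keeps 10 explicit mantissa bits vs float32's 23; simulate the
# precision loss crudely by zeroing the 13 low mantissa bits.
def simulate_tf32(x):
    bits = x.astype(np.float32).view(np.uint32)
    return (bits & np.uint32(0xFFFFE000)).view(np.float32)

x = np.array([1.2345678], dtype=np.float32)
rel_err = float(abs(simulate_tf32(x) - x)[0] / x[0])
# rel_err is below 2**-10: each input effectively loses ~3 decimal digits
```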

Also, casting everything to double fixes the issue.
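A likely cause on A100 is TF32: since PyTorch 1.7, torch.backends.cuda.matmul.allow_tf32 defaults to True on Ampere GPUs, so float32 matmuls run at reduced mantissa precision. Disabling it should restore full float32 results:

```python
import torch

# PyTorch >= 1.7 enables TF32 for float32 matmuls on Ampere GPUs by default;
# turning it off makes cuBLAS use full float32 precision again.
torch.backends.cuda.matmul.allow_tf32 = False
# cuDNN convolutions have a separate, independent flag:
torch.backends.cudnn.allow_tf32 = False
```

With these flags set before the matmuls above, the GPU results should match the CPU ones to within normal float32 rounding; casting to double works for the same reason, since float64 operations never use TF32.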

Environment

Collecting environment information...
PyTorch version: 1.9.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: CentOS Linux release 7.7.1908 (Core) (x86_64)
GCC version: (GCC) 9.2.0
Clang version: Could not collect
CMake version: version 2.8.12.2
Libc version: glibc-2.17

Python version: 3.8.7 (default, Oct 18 2021, 14:45:54)  [GCC 8.2.1 20180905 (Red Hat 8.2.1-3)] (64-bit runtime)
Python platform: Linux-3.10.0-1062.1.1.el7.x86_64-x86_64-with-glibc2.2.5
Is CUDA available: True
CUDA runtime version: Could not collect
GPU models and configuration: 
GPU 0: A100-PCIE-40GB
GPU 1: A100-PCIE-40GB
GPU 2: A100-PCIE-40GB
GPU 3: A100-PCIE-40GB

Nvidia driver version: 460.32.03
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] mypy==0.780
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.20.1
[pip3] torch==1.9.0+cu111
[pip3] torchaudio==0.9.0
[pip3] torchvision==0.10.0+cu111
[conda] Could not collect

Additional context

cc @zasdfgbnm @ptrblck


    Labels

    module: numerical-stability (Problems related to numerical stability of operations)
    module: tf32 (Related to tf32 data format)
    triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
