
SDPA accuracy issue for CUDA (Disable temporarily) #117957

@Valentine233

Description


🐛 Describe the bug

The fuse_attention pattern 16 makes BERT on CUDA dispatch to the efficient attention backend instead of the SDPA math backend, which causes an accuracy failure. For now, the pattern is disabled on CUDA. Found in #113004.
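Why can switching backends change results at all? The math path and efficient attention compute the same function. Efficient kernels, however, generally avoid materializing the full score matrix: they visit keys in blocks and keep a running (online) softmax, so the floating-point operations happen in a different order. The pure-Python sketch below illustrates only that reordering, for one query row in float64. It is not PyTorch's kernel, the function names are hypothetical, and it does not diagnose the failure reported here.

```python
import math


def math_attention(q, keys, values):
    """Reference path: softmax(q . k / sqrt(d)) weighted sum of V, for one query row."""
    d = len(q)
    scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    dim_v = len(values[0])
    return [sum(w * v[j] for w, v in zip(weights, values)) / total for j in range(dim_v)]


def blockwise_attention(q, keys, values, block=2):
    """Online-softmax path: visit keys in blocks, rescaling the running sums."""
    d = len(q)
    dim_v = len(values[0])
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * dim_v
    for start in range(0, len(keys), block):
        k_blk = keys[start:start + block]
        v_blk = values[start:start + block]
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new running maximum
        # (on the first block exp(-inf) == 0.0, so nothing carries over).
        scale = math.exp(running_max - new_max)
        running_sum *= scale
        acc = [a * scale for a in acc]
        for s, v in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


# Example with made-up inputs: 5 keys of dimension 4, values of dimension 2.
q = [0.1, 0.2, 0.3, 0.4]
keys = [[0.5, -0.1, 0.2, 0.0], [0.3, 0.3, -0.2, 0.1], [-0.4, 0.2, 0.1, 0.5],
        [0.0, 0.1, 0.6, -0.3], [0.2, -0.5, 0.4, 0.2]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-1.0, 3.0]]
ref = math_attention(q, keys, values)
out = blockwise_attention(q, keys, values, block=2)
max_diff = max(abs(a - b) for a, b in zip(ref, out))
print(max_diff)  # tiny, but not guaranteed to be exactly zero
```

In float64 the two paths agree to within rounding error. With the float32 or reduced-precision arithmetic typically used on GPUs, the same reordering produces larger discrepancies, which is one way a backend switch can trip a strict tolerance such as `tol=0.0001`.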

Error logs

```
Error: :[2024-01-18 02:26:46,462] torch._dynamo.utils: [ERROR] Accuracy failed: allclose not within tol=0.0001
Error: :[2024-01-18 02:26:46,463] torch._dynamo.utils: [ERROR] Accuracy failed: allclose not within tol=0.0001
Error: :[2024-01-18 02:26:46,463] torch._dynamo.utils: [ERROR] Accuracy failed: allclose not within tol=0.0001
Error: :[2024-01-18 02:26:46,463] torch._dynamo.utils: [ERROR] Accuracy failed: allclose not within tol=0.0001
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR] Caught exception:
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR] Traceback (most recent call last):
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_distributed.py", line 650, in run_test
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]     getattr(self, test_name)()
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_distributed.py", line 536, in wrapper
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]     fn()
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_utils.py", line 2683, in wrapper
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]     method(*args, **kwargs)
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_distributed.py", line 175, in wrapper
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]     return func(*args, **kwargs)
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]   File "/opt/conda/envs/py_3.10/lib/python3.10/site-packages/torch/testing/_internal/common_distributed.py", line 161, in wrapper
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]     return func(*args, **kwargs)
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]   File "/opt/conda/envs/py_3.10/lib/python3.10/unittest/mock.py", line 1379, in patched
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]     return func(*newargs, **newkeywargs)
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]   File "/var/lib/jenkins/workspace/test/distributed/test_dynamo_distributed.py", line 326, in test_hf_bert_ddp_inductor
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]     self._test_hf_bert_ddp_inductor(static_graph=False)
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]   File "/var/lib/jenkins/workspace/test/distributed/test_dynamo_distributed.py", line 318, in _test_hf_bert_ddp_inductor
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]     run_hf_bert_ddp(self, model, inputs, "inductor")
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]   File "/var/lib/jenkins/workspace/test/distributed/test_dynamo_distributed.py", line 242, in run_hf_bert_ddp
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]     self.assertTrue(same(correct_results, opt_results))
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]   File "/opt/conda/envs/py_3.10/lib/python3.10/unittest/case.py", line 687, in assertTrue
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]     raise self.failureException(msg)
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR] AssertionError: False is not true
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR] To execute this test, run the following from the base repo dir:
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]      python test/distributed/test_dynamo_distributed.py -k test_hf_bert_ddp_inductor
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR] This message can be suppressed by setting PYTORCH_PRINT_REPRO_ON_FAILURE=0
Error: 1-18 02:26:46,464] torch.testing._internal.common_distributed: [ERROR]  exiting process 3 with exit code: 10
```

Processes 2, 1, and 0 log the identical traceback, each ending with `exiting process N with exit code: 10`.
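The failing assertion is `self.assertTrue(same(correct_results, opt_results))`, logged as `allclose not within tol=0.0001`. The elementwise rule behind `torch.allclose(input, other, rtol, atol)` is |input − other| ≤ atol + rtol·|other|. The pure-Python sketch below applies that rule to flat lists, assuming both atol and rtol are set to `tol`. The function names are hypothetical; the real `torch._dynamo.utils.same` also walks nested outputs and has extra handling that is not shown here.

```python
def allclose_within_tol(input_vals, other_vals, tol=1e-4):
    """|input - other| <= atol + rtol * |other| elementwise, with atol = rtol = tol."""
    if len(input_vals) != len(other_vals):
        return False
    return all(abs(i - o) <= tol + tol * abs(o) for i, o in zip(input_vals, other_vals))


def worst_violation(input_vals, other_vals, tol=1e-4):
    """Largest amount by which any element exceeds its allowed error (<= 0 means pass)."""
    return max(abs(i - o) - (tol + tol * abs(o)) for i, o in zip(input_vals, other_vals))


# Made-up example values, standing in for eager vs. compiled outputs.
expected = [1.0, -2.0, 0.5]
close = [1.00005, -2.0001, 0.50002]  # every element inside its tolerance
far = [1.0, -2.0, 0.5003]            # last element off by 3e-4, allowed ~1.5e-4

print(allclose_within_tol(expected, close))  # True
print(allclose_within_tol(expected, far))    # False
```

Because the relative term scales with the magnitude of the value, a tolerance of 1e-4 is tight for activations near 1, and small per-element differences from a different kernel can be enough to fail it.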

Minified repro

  1. Enable CUDA for fuse_attention pattern 16: in https://github.com/pytorch/pytorch/blob/main/torch/_inductor/fx_passes/fuse_attention.py#L670, remove `disable_cuda=True`.
  2. Reproduce with `python test/distributed/test_dynamo_distributed.py -k test_hf_bert_ddp_inductor`.

Versions

PyTorch version: branch cpu-device.

cc @ezyang @gchanan @zou3519 @kadeng @msaroufim @bdhirsh @anijain2305 @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @peterbell10 @ipiszy @yf225 @chenyang78 @muchulee8 @aakhundov @ColinPeppler


Labels: module: inductor, oncall: pt2, triaged
