Migrate var & std to ATen #39967
Conversation
💊 CI failures summary and remediations

As of commit 552f160 (more details on the Dr. CI page):

🕵️ 4 new failures recognized by patterns

The following CI failures do not appear to be due to upstream breakages.
Force-pushed from eaacf05 to ac79193 (Compare).
facebook-github-bot left a comment:

@VitalyFedyunin has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
@ShawnZhong @VitalyFedyunin @ngimel 7a3c223 (this PR): My testing on GPU agrees that this is generally an improvement, though there are some cases with regressions. (#38338 will soon be updated with the script that I used to benchmark this PR.) Unfortunately, we may need to revert this PR, since the impact on single-threaded CPU speed is quite severe.
@robieta Sounds reasonable; let me revert it first, and afterwards we can either fix the single-threaded case (if it's quick) or at least land the GPU part first.
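The timing loop from the benchmark script in the summary can be factored into a small helper. To reproduce the single-threaded CPU numbers discussed above, pin PyTorch to one thread with `torch.set_num_threads(1)` before benchmarking; the torch-specific lines are shown as comments so this sketch stays self-contained, and the stand-in workload is purely illustrative:

```python
import time


def bench(fn, iters=10):
    """Average wall-clock seconds per call, same shape as the loop in the PR summary."""
    total = 0.0
    for _ in range(iters):
        t1 = time.time()
        fn()
        t2 = time.time()
        total += t2 - t1
    return total / iters


# To measure the single-threaded CPU case discussed above, pin PyTorch to one
# thread before benchmarking (torch.set_num_threads is the standard torch API):
#
#   import torch
#   torch.set_num_threads(1)
#   t = torch.randn(10000, 10000)
#   print(bench(t.std))

# Stand-in workload so the sketch runs without torch installed:
print(bench(lambda: sum(range(100000)), iters=3))
```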
Summary:
Not sure why there are so many issues for std & var, but this PR should close them all:

std: Fix pytorch#24771, Fix pytorch#24676, Fix pytorch#24639, Fix pytorch#24529
var: Fix pytorch#24782, Fix pytorch#24677, Fix pytorch#24652, Fix pytorch#24530

```py
import time

import torch


def _time():
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    return time.time()


for device in (torch.device("cpu"), torch.device("cuda")):
    for size in (
        [100000000],
        [10000, 10000],
        [1000, 1000, 100],
        [100, 100, 100, 100],
    ):
        t = torch.randn(*size, device=device)
        total_time = 0
        for i in range(10):
            t1 = _time()
            t.std()
            t2 = _time()
            total_time += t2 - t1
        print(f"Tensor of size {size} on {device}: {total_time / 10}")
```

Before:

```
Tensor of size [100000000] on cpu: 0.36041643619537356
Tensor of size [10000, 10000] on cpu: 0.37235140800476074
Tensor of size [1000, 1000, 100] on cpu: 0.386572527885437
Tensor of size [100, 100, 100, 100] on cpu: 0.37404844760894773
Tensor of size [100000000] on cuda: 0.0021645784378051757
Tensor of size [10000, 10000] on cuda: 0.002090191841125488
Tensor of size [1000, 1000, 100] on cuda: 0.00208127498626709
Tensor of size [100, 100, 100, 100] on cuda: 0.0020844221115112306
```

After:

```
Tensor of size [100000000] on cpu: 0.1339871883392334
Tensor of size [10000, 10000] on cpu: 0.1343991994857788
Tensor of size [1000, 1000, 100] on cpu: 0.1346735954284668
Tensor of size [100, 100, 100, 100] on cpu: 0.11906447410583496
Tensor of size [100000000] on cuda: 0.0013531208038330077
Tensor of size [10000, 10000] on cuda: 0.0012922048568725585
Tensor of size [1000, 1000, 100] on cuda: 0.001285886764526367
Tensor of size [100, 100, 100, 100] on cuda: 0.0012899160385131836
```

cc: VitalyFedyunin

Pull Request resolved: pytorch#39967
Differential Revision: D22162469
Pulled By: VitalyFedyunin
fbshipit-source-id: 8d901c779767b00f81cd6231bc665e04f297b4c3
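As an aside not stated in the PR text: variance reductions of this kind are commonly implemented with a Welford-style single-pass update for numerical stability. A minimal pure-Python sketch of that formula, with the `unbiased` flag applying Bessel's correction as `torch.var`/`torch.std` do by default:

```python
import math
import statistics


def welford_var(xs, unbiased=True):
    """Single-pass Welford update: accumulate a running mean and M2 (sum of
    squared deviations) without a separate mean-computing pass."""
    mean = 0.0
    m2 = 0.0
    n = 0
    for x in xs:
        n += 1
        delta = x - mean
        mean += delta / n
        m2 += delta * (x - mean)
    denom = n - 1 if unbiased else n  # Bessel's correction when unbiased
    return m2 / denom


data = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
# Matches the stdlib's two-pass population variance:
assert abs(welford_var(data, unbiased=False) - statistics.pvariance(data)) < 1e-12
print(math.sqrt(welford_var(data)))  # std with Bessel's correction
```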