Commit f456748
Iurii Zdebskyi
Update on "[wip] Replace optimizers in torch.optim with the ones from torch.optim._multi_tensor"
Differential Revision: [D25406490](https://our.internmc.facebook.com/intern/diff/D25406490)
------
### Benchmark results

| Optimizer (config) | Current | Foreach |
|---|---|---|
| SGD (lr=1e-3, momentum=1, dampening=0, weight_decay=1, nesterov=True) | 201.63 ms | 56.99 ms |
| Adam (weight_decay=1., amsgrad=True) | 233.27 ms | 46.89 ms |
| AdamW (weight_decay=1., amsgrad=True) | 371.18 ms | 121.04 ms |
| RMSprop (weight_decay=1, momentum=1, centered=True) | 364.88 ms | 47.52 ms |
| Rprop (lr=1e-2, etas=(0.5, 1.2), step_sizes=(1e-6, 50)) | 1.43 s | 1.26 s |
| ASGD (weight_decay=1) | 165.39 ms | 40.61 ms |
| Adamax (weight_decay=1) | 374.42 ms | 291.06 ms |
| Adadelta (weight_decay=1) | 252.64 ms | 29.62 ms |
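The speedups above come from the multi-tensor ("foreach") variants replacing the per-parameter update loop with horizontally fused ops, so launch overhead is amortized across all parameter tensors. The sketch below is plain Python (no torch) and purely illustrative — the function names are mine, not PyTorch APIs — it only shows that the two formulations compute the same update:

```python
# Illustrative sketch: per-parameter loop vs. batched ("foreach") SGD update.
# In PyTorch the batched form maps to torch._foreach_* calls that process the
# whole tensor list in a few kernel launches; here both are plain Python.

def sgd_loop(params, grads, lr):
    # Current optimizers: one update op per parameter tensor,
    # i.e. one kernel launch per tensor on CUDA.
    return [p - lr * g for p, g in zip(params, grads)]

def sgd_foreach(params, grads, lr):
    # Multi-tensor optimizers: each list comprehension stands in for a
    # single horizontal op applied to the entire parameter list at once.
    scaled = [lr * g for g in grads]                 # ~ one _foreach_mul
    return [p - s for p, s in zip(params, scaled)]   # ~ one _foreach_sub

params = [1.0, 2.0, 3.0]
grads = [0.5, 0.25, 0.125]
assert sgd_loop(params, grads, 0.1) == sgd_foreach(params, grads, 0.1)
```

Both paths produce identical updates; only the number of dispatched operations differs, which is why the gap is largest for optimizers with many elementwise steps (Adadelta, RMSprop) and smallest for Rprop.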
### Benchmark script
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.benchmark as benchmark_utils
import torchvision

model = torchvision.models.resnet.resnet101(pretrained=True).to("cuda")
criterion = nn.CrossEntropyLoss()

# Optimizers under test: current implementation vs. multi-tensor implementation.
params = dict(weight_decay=1)
optimizer = optim.Adadelta(model.parameters(), **params)
optimizer_mta = optim._multi_tensor.Adadelta(model.parameters(), **params)

# One forward/backward pass to populate .grad on every parameter
# before the step() calls are timed.
target = torch.empty(128, dtype=torch.long, device="cuda").random_(5)
inputs = torch.rand(128, 3, 100, 100, device="cuda", requires_grad=True)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, target)
loss.backward()
optimizer.step()

def main():
    timer = benchmark_utils.Timer(
        stmt="torch.cuda.synchronize(); optimizer.step()",
        globals=globals(),
        label=str(optimizer),
    )
    print(f"autorange:\n{timer.blocked_autorange()}\n\n")

    timer_mta = benchmark_utils.Timer(
        stmt="torch.cuda.synchronize(); optimizer_mta.step()",
        globals=globals(),
        label=str(optimizer_mta),
    )
    print(f"autorange:\n{timer_mta.blocked_autorange()}\n\n")

if __name__ == "__main__":
    main()
```
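For reference, the update rule the Adadelta benchmark exercises can be sketched per-element in plain Python. This is an illustrative scalar version of the standard Adadelta formula with the usual defaults (lr=1.0, rho=0.9, eps=1e-6) and weight_decay=1 as in the script — not the PyTorch implementation itself:

```python
import math

def adadelta_step(p, grad, state, lr=1.0, rho=0.9, eps=1e-6, weight_decay=1.0):
    """One scalar Adadelta step (illustrative sketch of the update rule).

    state carries the two running averages the optimizer keeps per parameter.
    """
    g = grad + weight_decay * p  # L2 weight decay folded into the gradient
    state["square_avg"] = rho * state["square_avg"] + (1 - rho) * g * g
    delta = math.sqrt(state["acc_delta"] + eps) / math.sqrt(state["square_avg"] + eps) * g
    state["acc_delta"] = rho * state["acc_delta"] + (1 - rho) * delta * delta
    return p - lr * delta

state = {"square_avg": 0.0, "acc_delta": 0.0}
p = adadelta_step(1.0, 0.1, state)  # parameter moves slightly below 1.0
```

Every line of this update is an elementwise op repeated for each parameter tensor, which is exactly the pattern the foreach variants fuse across the parameter list.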
[ghstack-poisoned]