Skip to content

Commit f11c4f9

Browse files
csarofeen authored and facebook-github-bot committed
New CUDA Fuser: Unrolling support, interface refactor (#36435)
Summary: Unrolling support has been added in a way that we get good performing code on GPUs. Not sure how long this link will last but an example of a generated unrolled kernel is: https://godbolt.org/z/i0uAv3 What can be seen from there is multiple calls of "ld.global.f32" without "st.global.f32" in between them (and vice versa). This means that we are launching multiple loads that can be run in parallel, as well as multiple stores that can be run in parallel. This can be a crucial optimization for memory-bound kernels. This was generally a point of concern in TVM, as an attempt at a similar kernel from TVM produces: https://godbolt.org/z/Vu97vG which surrounds load-store pairs in conditional branches, preventing the benefits of unrolling. Pull Request resolved: #36435 Reviewed By: ZolotukhinM Differential Revision: D21024011 Pulled By: soumith fbshipit-source-id: e852e282fa7a304aba962e1926f756098c011fe0
1 parent d7fabfd commit f11c4f9

44 files changed

Lines changed: 2841 additions & 1208 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

caffe2/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -587,8 +587,11 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
587587
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/ir_iostream.cpp
588588
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/iter_visitor.cpp
589589
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel.cpp
590+
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/kernel_cache.cpp
590591
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/manager.cpp
591592
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/mutator.cpp
593+
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_loops.cpp
594+
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower_utils.cpp
592595
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/lower2device.cpp
593596
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/parser.cpp
594597
${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/partition.cpp

test/cpp/jit/test_gpu.cpp

Lines changed: 193 additions & 121 deletions
Large diffs are not rendered by default.

test/cpp/jit/tests.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -117,7 +117,8 @@ namespace jit {
117117
_(GPU_FusionCodeGen2) \
118118
_(GPU_FusionSimplePWise) \
119119
_(GPU_FusionExecKernel) \
120-
_(GPU_FusionForLoop)
120+
_(GPU_FusionForLoop) \
121+
_(GPU_FusionLoopUnroll)
121122
#else
122123
#define TH_FORALL_TESTS_CUDA(_) \
123124
_(ArgumentSpec) \

test/test_jit_cuda_fuser.py

Lines changed: 41 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -86,20 +86,57 @@ def t(x, y, z, q):
8686
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
8787
@skipIfRocm
8888
def test_scalar_input(self):
89-
def t(x, y, z):
90-
# type: (Tensor, Tensor, float) -> Tensor
89+
def t(x : torch.Tensor, y : torch.Tensor, z : float):
9190
o = x + y
9291
o = o + z
9392
return o
9493
t_jit = torch.jit.script(t)
95-
x = torch.randn(4, 8, dtype=torch.float, device="cuda")
96-
y = torch.randn(4, 8, dtype=torch.float, device="cuda")
94+
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
95+
y = torch.randn(4, 8, 1, 32, dtype=torch.float, device="cuda")
96+
y = y.expand(4, 8, 32, 32)
97+
jit_o = t_jit(x, y, 2.0)
98+
jit_o = t_jit(x, y, 2.0)
99+
o = t(x, y, 2.0)
100+
self.assertEqual(o, jit_o)
101+
self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))
102+
103+
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
104+
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
105+
@skipIfRocm
106+
def test_broadcasting(self):
107+
def t(x : torch.Tensor, y : torch.Tensor, z : float):
108+
o = x + y
109+
o = o + z
110+
return o
111+
t_jit = torch.jit.script(t)
112+
x = torch.randn(4, 8, 32, 32, dtype=torch.float, device="cuda")
113+
y = torch.randn(32, 32, dtype=torch.float, device="cuda")
97114
jit_o = t_jit(x, y, 2.0)
98115
jit_o = t_jit(x, y, 2.0)
99116
o = t(x, y, 2.0)
100117
self.assertEqual(o, jit_o)
101118
self.assertTrue(self._has_cuda_fusion_group(t_jit.graph_for(x, y, 2.0)))
102119

120+
@unittest.skipIf(not RUN_CUDA, "requires CUDA")
121+
@unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "Requires profiling node to run cuda fuser")
122+
@skipIfRocm
123+
def test_broadcasting_multiple_output_shape(self):
124+
def t(x : torch.Tensor, y : torch.Tensor, z : torch.Tensor):
125+
o = x + 12
126+
o1 = o + y
127+
o2 = o + z
128+
oo = o1.sum() + o2.sum()
129+
return oo
130+
t_jit = torch.jit.script(t)
131+
x = torch.randn(32, 32, dtype=torch.float, device="cuda")
132+
y = torch.randn(2, 32, 32, dtype=torch.float, device="cuda")
133+
z = torch.randn(4, 32, 32, dtype=torch.float, device="cuda")
134+
jit_o = t_jit(x, y, z)
135+
jit_o = t_jit(x, y, z)
136+
o = t(x, y, z)
137+
self.assertEqual(o, jit_o)
138+
# Currently cannot fuse this
139+
self.assertFalse(self._has_cuda_fusion_group(t_jit.graph_for(x, y, z)))
103140

104141
if __name__ == '__main__':
105142
run_tests()

tools/build_variables.bzl

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -246,6 +246,9 @@ libtorch_cuda_sources = [
246246
"torch/csrc/jit/codegen/cuda/ir_iostream.cpp",
247247
"torch/csrc/jit/codegen/cuda/iter_visitor.cpp",
248248
"torch/csrc/jit/codegen/cuda/kernel.cpp",
249+
"torch/csrc/jit/codegen/cuda/kernel_cache.cpp",
250+
"torch/csrc/jit/codegen/cuda/lower_loops.cpp",
251+
"torch/csrc/jit/codegen/cuda/lower_utils.cpp",
249252
"torch/csrc/jit/codegen/cuda/lower2device.cpp",
250253
"torch/csrc/jit/codegen/cuda/manager.cpp",
251254
"torch/csrc/jit/codegen/cuda/mutator.cpp",

torch/csrc/jit/codegen/cuda/data_struct_str.h

Lines changed: 0 additions & 11 deletions
This file was deleted.

torch/csrc/jit/codegen/cuda/dispatch.cpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,5 @@
11
#include <torch/csrc/jit/codegen/cuda/fusion.h>
22
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
3-
#include <torch/csrc/jit/codegen/cuda/tensor.h>
43
#include <torch/csrc/jit/codegen/cuda/type.h>
54

65
#include <torch/csrc/jit/codegen/cuda/dispatch.h>

torch/csrc/jit/codegen/cuda/fusion.cpp

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
#include <torch/csrc/jit/codegen/cuda/fusion.h>
2+
#include <torch/csrc/jit/codegen/cuda/ir_all_nodes.h>
23
#include <torch/csrc/jit/codegen/cuda/ir_printer.h>
34

45
namespace torch {
@@ -33,6 +34,22 @@ std::vector<Expr*> ExprSort::getExprs(
3334
return es.exprs;
3435
}
3536

37+
void InputsOf::handle(TensorView* tv) {
38+
if (FusionGuard::getCurFusion()->hasInput(tv))
39+
inputs.push_back(tv);
40+
}
41+
42+
std::vector<TensorView*> InputsOf::output(Fusion* fusion, Val* output_) {
43+
TORCH_CHECK(
44+
fusion->hasOutput(output_),
45+
"Asked for the inputs of ",
46+
output_,
47+
" however, it is not an output of the provided fusion.");
48+
InputsOf io;
49+
io.traverseFrom(FusionGuard::getCurFusion(), {output_});
50+
return io.inputs;
51+
}
52+
3653
Fusion::~Fusion() {
3754
{
3855
auto it = val_set_.begin();
@@ -140,6 +157,10 @@ std::vector<Expr*> Fusion::exprs(bool from_outputs_only, bool breadth_first) {
140157
return ExprSort::getExprs(this, from_outputs_only, breadth_first);
141158
}
142159

160+
std::vector<TensorView*> Fusion::inputsOf(Val* val) {
161+
return InputsOf::output(this, val);
162+
}
163+
143164
void Fusion::print() {
144165
FusionGuard fg(this);
145166
std::cout << "%kernel {\n";

torch/csrc/jit/codegen/cuda/fusion.h

Lines changed: 17 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -49,6 +49,7 @@ struct TypeHash {
4949
*/
5050

5151
struct Fusion;
52+
struct TensorView;
5253

5354
// Fusion Guard is our "context manager". It holds the actrive fusion and allows
5455
// it to be accessed anywhere through FusionGuard::getCurFusion().
@@ -79,6 +80,20 @@ struct ExprSort : public IterVisitor {
7980
bool breadth_first);
8081
};
8182

83+
// Expr sort will take a fusion and return a topologically sorted list of
84+
// expressions.
85+
struct InputsOf : public IterVisitor {
86+
using IterVisitor::handle;
87+
88+
private:
89+
std::vector<TensorView*> inputs;
90+
91+
void handle(TensorView* tv) override;
92+
93+
public:
94+
static std::vector<TensorView*> output(Fusion* fusion, Val* output_);
95+
};
96+
8297
/*
8398
* Fusion is mutable but unique. Nodes cannot be copied in any way from one
8499
* Fusion to another. If anything like that is desired, it would require
@@ -139,6 +154,8 @@ struct TORCH_CUDA_API Fusion : public IRInputOutput {
139154
bool from_outputs_only = false,
140155
bool breadth_first = false);
141156

157+
std::vector<TensorView*> inputsOf(Val* val);
158+
142159
// Print this fusion to cout.
143160
void print();
144161

@@ -174,8 +191,6 @@ struct TORCH_CUDA_API Fusion : public IRInputOutput {
174191
// Return the Expr that produces val (const version)
175192
const Expr* origin(const Val* val) const;
176193

177-
bool lowered = false;
178-
179194
private:
180195
// Sets of all Vals/Exprs registered with this fusion
181196
std::set<Val*> val_set_;

torch/csrc/jit/codegen/cuda/index_compute.cpp

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@ void IndexCompute::replayBackward(Merge* expr) {
2121
ax >= 0 && ax < indices.size(),
2222
"Hit an invalid MERGE transformation during IndexCompute, axis is not within bounds.");
2323

24-
Val* I = expr->in()->axis(ax + 1)->size();
24+
Val* I = expr->in()->axis(ax + 1)->extent();
2525
Val* ind = indices[ax];
2626
indices[ax] = div(ind, I);
2727
indices.insert(indices.begin() + ax + 1, mod(ind, I));
@@ -62,18 +62,18 @@ IndexCompute::IndexCompute(const TensorView* tv, std::vector<Val*> _indices) {
6262

6363
TensorDomain* td = tv->domain();
6464

65-
bool exclude_reduction = td->size() > indices.size();
65+
bool exclude_reduction = td->nDims() > indices.size();
6666

6767
TORCH_CHECK(
68-
exclude_reduction || td->size() == indices.size(),
68+
exclude_reduction || td->nDims() == indices.size(),
6969
"For IndexCompute the number of axis should match the number of dimensions"
7070
" in the TensorView.");
7171

7272
// If we need to ignore the reduction dimensions because a tensor is
7373
// being consumed, not produced, then insert dummy dimensions in the
7474
// indices for bookkeeping while replaying split/merge/reorder operations.
7575
if (exclude_reduction)
76-
for (decltype(td->size()) i{0}; i < td->size(); i++)
76+
for (decltype(td->nDims()) i{0}; i < td->nDims(); i++)
7777
if (td->axis(i)->isReduction())
7878
indices.insert(indices.begin() + i, new Int(-1));
7979

@@ -83,15 +83,15 @@ IndexCompute::IndexCompute(const TensorView* tv, std::vector<Val*> _indices) {
8383
TensorDomain* root = TransformIter::runBackward(td, true);
8484

8585
TORCH_INTERNAL_ASSERT(
86-
root->size() == indices.size(),
86+
root->nDims() == indices.size(),
8787
"Error during IndexCompute. The number of indices generated"
8888
" after running the transformations backwards should match"
8989
" the number of dimensions of the root TensorView.");
9090

9191
// Remove indices associated with reduction axes, we had them just for
9292
// bookkeeping.
9393
if (exclude_reduction) {
94-
for (auto i = root->size() - 1; i >= 0; i--)
94+
for (auto i = root->nDims() - 1; i >= 0; i--)
9595
if (root->axis(i)->isReduction())
9696
indices.erase(indices.begin() + i);
9797
}

0 commit comments

Comments (0)