Get a GEMM example with all bells and whistles #368
Conversation
```cpp
tv6->reorder({
    {2, -2},
    {3, -1},
    {4, 2},
    {5, 3},
    {6, 4},
});
```

Maybe comment that rFactor moves the reduction axes to the inner-most dimension and this reorder returns tv6 to its previous mapping.
```cpp
// Sum the K-dim
TensorView* tv5 = sum(tv4, {1});
```

Comment that the K-dim becomes a reduction axis in tv5[M, rK, N].
```cpp
// Before:
{t0, t1, BSX},
torch::jit::fuser::cuda::LaunchParams(-1, -1, -1, BSX, -1, -1));
// After:
{t0, t1, 3, 4, 5},
torch::jit::fuser::cuda::LaunchParams(-1, -1, -1, -1, -1, -1));
```

Could we remove LaunchParams since it is completely inferred?
```cpp
at::Tensor aten_output = mul(t0.unsqueeze(2), t1.unsqueeze(0)).sum(1);
```
Maybe replace with `at::Tensor aten_output = matmul(t0, t1);`

I went back and forth on this; for the presentation I'm including this in, I think it's interesting to structure it similarly to how we structured the kernel.
naoyam left a comment:
The change regarding gridReduce looks good to me. Note that there is a bug in the use of shared memory inside the loop of the newly added test: __syncthreads() is needed at the end of the loop body.
… reading before writing.
```cpp
fl->body().push_back(new kir::Sync());
}

bool needs_sync = prev_needs_sync;
```

Is this supposed to declare a new variable, or is it supposed to assign `prev_needs_sync` to the existing `needs_sync`?
```cpp
void SyncInserter::handle(kir::ForLoop* fl) {
  bool prev_needs_sync = needs_sync;
  active_scope = fl;
```

Don't we need to reset `needs_sync` here?
```cpp
for (auto inp : expr->inputs()) {
  if (ir_utils::isTV(inp)) {
    if (inp->as<TensorView>()->getMemoryType() == MemoryType::Shared) {
      needs_sync = true;
```

Reading shared memory does not necessarily mean __syncthreads() is required; it is needed when there is a read-write dependency across threads. This is a safe approach, but it doesn't seem very efficient.
I was working on a PR to detect this Write-After-Read race condition. It could be used here to be more efficient.
@rdspring1 I'm cleaning up this PR and thinking about dropping the __syncthreads changes here in favor of your #374. Does that sound good to you?
Force-pushed from 65ae8ea to 913fc90:

- Basic Write-After-Read (WAR) check to add __syncthreads to the end of the for-loop
- Enable Tiled GEMM example
- Check that IterDomain iterates from zero to some positive integer

Co-authored-by: Ryan Spring <rspring@nvidia.com>
Force-pushed from 4a09dcf to 3b63be6.
@csarofeen The change in this PR is now only about inserting thread predicates for expressions writing into shared memory. I tried to do a little cleanup at 3b63be6. The analysis didn't change, but it is done only once per expression, whereas the original implementation computed it on demand, potentially multiple times redundantly. Please let me know if this looks good to you.
This remaining change avoids redundant writes to broadcast tensors in shared memory. For example, in the SmemDynamicTiledGemm test, the inner-most loop looks like below: Notice that writes to
I'll put up an upstream PR once this is merged.
Summary: X-link: meta-pytorch/data#368. This PR aims to expose the right data-related API. Two more changes in this PR convert public APIs to private ones: `check_lambda_fn` -> `_check_lambda_fn` and `deprecation_warning` -> `_deprecation_warning`.
Pull Request resolved: pytorch#76143
Reviewed By: albanD, NivekT
Differential Revision: D35798311
Pulled By: ejguan
fbshipit-source-id: b13fded5c88a533c706702fb2070c918c839dca4
(cherry picked from commit 0b534b8)
Changed one of the tests to be a GEMM example with a combination of compile-time and run-time tiling parameters, including symbolic values for both inter-CTA and intra-CTA reductions. A few experimental fixes went in that we need to vet more thoroughly.
Related issues:
#363
#364
#365
#366
#367