Get a GEMM example with all bells and whistles#368

Merged
naoyam merged 15 commits into 20_8_18_devel from crazy_example
Sep 22, 2020

Conversation

@csarofeen
Owner

Changed one of the tests to be a GEMM example with a combination of compile-time and runtime tiling parameters, including symbolic values for both inter-CTA and intra-CTA reductions. A few experimental fixes went in that we need to vet more thoroughly.

Related issues:
#363
#364
#365
#366
#367

Comment thread test/cpp/jit/test_gpu.cpp
Comment on lines +5932 to +5938
tv6->reorder({
{2, -2},
{3, -1},
{4, 2},
{5, 3},
{6, 4},
});
Collaborator

@rdspring1 Sep 8, 2020

Maybe comment that rFactor moves the reduction axes to the inner-most dimensions and that this reorder returns tv6 to its previous mapping.
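The `{old_pos, new_pos}` pairs passed to `reorder` can be read as a partial permutation of axes. Below is a stand-alone C++ model of that semantics (hypothetical, not the fuser implementation), assuming negative positions wrap around and unmentioned axes keep their relative order in the remaining slots:

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-alone model of a TensorView::reorder-style
// {old_pos, new_pos} map (not the real fuser code). Negative positions
// wrap, so -1 means the last axis; axes not mentioned in the map keep
// their relative order and fill the remaining slots left to right.
std::vector<std::string> reorder(
    const std::vector<std::string>& axes,
    const std::map<int, int>& old2new) {
  const int n = static_cast<int>(axes.size());
  auto wrap = [n](int i) { return i < 0 ? i + n : i; };
  std::vector<std::string> result(n);
  std::vector<bool> slot_taken(n, false), axis_used(n, false);
  // Place explicitly mapped axes first.
  for (const auto& [oldp, newp] : old2new) {
    result[wrap(newp)] = axes[wrap(oldp)];
    slot_taken[wrap(newp)] = true;
    axis_used[wrap(oldp)] = true;
  }
  // Fill the remaining slots with the untouched axes, in order.
  int slot = 0;
  for (int i = 0; i < n; ++i) {
    if (axis_used[i]) continue;
    while (slot_taken[slot]) ++slot;
    result[slot] = axes[i];
    slot_taken[slot] = true;
  }
  return result;
}
```

Under this model, the map in the diff (`{2,-2}, {3,-1}, {4,2}, {5,3}, {6,4}` on seven axes) sends axes 2 and 3 to the last two positions while pulling axes 4 through 6 forward.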

Comment thread test/cpp/jit/test_gpu.cpp Outdated
Comment on lines +5902 to +5903
// Sum the K-dim
TensorView* tv5 = sum(tv4, {1});
Collaborator

Comment that the K-dim becomes a reduction axis in tv5[M, rK, N].

Comment thread test/cpp/jit/test_gpu.cpp Outdated
{t0, t1, BSX},
torch::jit::fuser::cuda::LaunchParams(-1, -1, -1, BSX, -1, -1));
{t0, t1, 3, 4, 5},
torch::jit::fuser::cuda::LaunchParams(-1, -1, -1, -1, -1, -1));
Collaborator

Could we remove LaunchParams since it is completely inferred?

Comment thread test/cpp/jit/test_gpu.cpp Outdated
{t0, t1, 3, 4, 5},
torch::jit::fuser::cuda::LaunchParams(-1, -1, -1, -1, -1, -1));

at::Tensor aten_output = mul(t0.unsqueeze(2), t1.unsqueeze(0)).sum(1);
Collaborator

Maybe replace with at::Tensor aten_output = matmul(t0, t1);

Owner Author

I went back and forth on this. For the presentation I'm including this in, I think it's interesting to structure the reference computation similarly to how we structured the kernel.

Collaborator

@naoyam left a comment

The change regarding gridReduce looks good to me. Note that there is a bug in the use of shared memory inside the loop of the newly added test: a __syncthreads() is needed at the end of the loop body, otherwise the next iteration can overwrite the shared-memory tiles before all threads have finished reading them.
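The hazard can be modeled with an adversarial schedule. In this hypothetical single-threaded C++ sketch (standing in for the CUDA loop, not generated code), each "thread" writes its shared-memory slot, passes the mid-loop barrier, then reads the other thread's slot; without a barrier at the end of the body, a fast thread's next-iteration write can land before a slow thread's read:

```cpp
#include <cassert>
#include <vector>

// Hypothetical adversarial-schedule model of the write-after-read race.
// Two "threads" share smem[2]; each iteration writes its own slot, hits
// the mid-loop barrier, then reads the other slot. Returns what thread 1
// observed in each iteration.
std::vector<int> run(bool sync_at_loop_end) {
  int smem[2] = {-1, -1};
  std::vector<int> t1_reads;
  for (int i = 0; i < 2; ++i) {
    smem[0] = i;  // T0 writes its slot
    smem[1] = i;  // T1 writes its slot
    // --- mid-loop __syncthreads(): all writes visible before any read ---
    int t0_read = smem[1];  // T0 reads T1's slot
    if (!sync_at_loop_end && i + 1 < 2) {
      // No end-of-loop barrier: T0 races ahead and its next-iteration
      // write clobbers the tile before T1 has read it.
      smem[0] = i + 1;
    }
    t1_reads.push_back(smem[0]);  // T1 reads T0's slot -- may see i + 1
    (void)t0_read;
  }
  return t1_reads;
}
```

With the end-of-loop barrier, thread 1 observes the correct value for each iteration; without it, the first read is already corrupted by the next iteration's write.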

fl->body().push_back(new kir::Sync());
}

bool needs_sync = prev_needs_sync;
Collaborator

Is this supposed to declare a new variable? Or is it supposed to assign prev_needs_sync to the existing needs_sync?

void SyncInserter::handle(kir::ForLoop* fl) {
bool prev_needs_sync = needs_sync;
active_scope = fl;

Collaborator

Don't we need to reset needs_sync here?
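The two comments above point at the same save/reset/restore idiom for per-scope state in a recursive visitor. A hypothetical sketch (plain C++, not the real SyncInserter or kir types) of how that flag handling is usually written:

```cpp
#include <cassert>
#include <vector>

// Hypothetical stand-in for a kir::ForLoop scope.
struct Loop {
  std::vector<Loop*> body;
  bool touches_shared = false;
};

// Sketch of the save/reset/restore idiom: entering a nested scope should
// (1) save the caller's flag, (2) reset it for the inner scope, and
// (3) restore it on exit with plain assignment. Writing
// `bool needs_sync = prev_needs_sync;` at step (3) would only declare a
// shadowing local and leave the member untouched.
struct SyncInserter {
  bool needs_sync = false;
  int syncs_inserted = 0;

  void handle(Loop* fl) {
    bool prev_needs_sync = needs_sync;  // (1) save the outer scope's flag
    needs_sync = false;                 // (2) reset for this scope
    for (Loop* inner : fl->body)
      handle(inner);
    if (fl->touches_shared)
      needs_sync = true;
    if (needs_sync)
      ++syncs_inserted;                 // would push_back a kir::Sync here
    needs_sync = prev_needs_sync;       // (3) restore, not re-declare
  }
};
```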

for (auto inp : expr->inputs()) {
if (ir_utils::isTV(inp)) {
if (inp->as<TensorView>()->getMemoryType() == MemoryType::Shared) {
needs_sync = true;
Collaborator

Reading shared memory does not necessarily mean __syncthreads() is required. The __syncthreads() is needed only when there is a read-write dependency across threads.

This is a safe approach, but doesn't seem very efficient.

Collaborator

I was working on a PR to detect this Write-After-Read race condition. It could be used here to be more efficient.

Collaborator

@rdspring1 I'm cleaning up this PR. I'm thinking about dropping the syncthreads changes in this PR in favor of your #374. Does it sound good to you?

Collaborator

Yes.

This was referenced Sep 14, 2020
Naoya Maruyama and others added 5 commits September 15, 2020 15:37
…e finish reading before writing."

This reverts commit dffaa76.

Revert this in favor of #383
* Basic Write-After-Read (WAR) check to add __syncthreads to end of for-loop

* Enable Tiled GEMM example

* Check that IterDomain iterates from zero to some positive integer

Co-authored-by: Ryan Spring <rspring@nvidia.com>
@naoyam
Collaborator

naoyam commented Sep 22, 2020

@csarofeen The change in this PR is only about inserting thread predicates for expressions writing into shared memory. I tried to do a little cleanup at 3b63be6. The analysis didn't change, but it is now done only once for each expression, whereas it was done on demand in the original implementation, potentially multiple times redundantly.

Please let me know if this looks good to you.

@naoyam
Collaborator

naoyam commented Sep 22, 2020

This remaining change avoids redundant writes to broadcast tensors on shared memory. For example, in the SmemDynamicTiledGemm test, the inner-most loop looks like the following:

 for(size_t i11 = 0; i11 < (ceilDiv((ceilDiv(T0.size[1], i2)), i1)); ++i11) {
    if ((((((blockIdx.z * blockDim.z) + threadIdx.z) < T0.size[0]) && (((((i11 * blockDim.x) + threadIdx.x) * gridDim.x) + blockIdx.x) < T0.size[1])) && (threadIdx.y == 0))) {
      T2[(threadIdx.z * blockDim.x) + threadIdx.x]
         = T0[(((blockIdx.z * blockDim.z) + threadIdx.z) * T0.stride[0]) + (((((i11 * blockDim.x) + threadIdx.x) * gridDim.x) + blockIdx.x) * T0.stride[1])];
    }
    if ((((((((i11 * blockDim.x) + threadIdx.x) * gridDim.x) + blockIdx.x) < T1.size[0]) && (((blockIdx.y * 8) + threadIdx.y) < T1.size[1])) && (threadIdx.z == 0))) {
      T3[(threadIdx.x * 8) + threadIdx.y]
         = T1[(((((i11 * blockDim.x) + threadIdx.x) * gridDim.x) + blockIdx.x) * T1.stride[0]) + (((blockIdx.y * 8) + threadIdx.y) * T1.stride[1])];
    }
    __syncthreads();
    float T4[1];
    if ((((((blockIdx.z * blockDim.z) + threadIdx.z) < T0.size[0]) && (((((i11 * blockDim.x) + threadIdx.x) * gridDim.x) + blockIdx.x) < T0.size[1])) && (((blockIdx.y * 8) + threadIdx.y) < T1.size[1]))) {
      T4[0]
        = T2[(threadIdx.z * blockDim.x) + threadIdx.x]
        * T3[(threadIdx.x * 8) + threadIdx.y];
    }
    if ((((((blockIdx.z * blockDim.z) + threadIdx.z) < T0.size[0]) && (((((i11 * blockDim.x) + threadIdx.x) * gridDim.x) + blockIdx.x) < T0.size[1])) && (((blockIdx.y * 8) + threadIdx.y) < T1.size[1]))) {
      T6[0]
        = T6[0]
        + T4[0];
    }
    __syncthreads();
  }

Notice that writes to T2 and T3 are predicated with threadIdx.y == 0 and threadIdx.z == 0, respectively, in addition to the other bound-check predicates.
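The effect of those extra predicates can be modeled outside of CUDA. In this hypothetical C++ sketch (illustrative, not generated code), a tile indexed only by threadIdx.x is broadcast along threadIdx.y; counting writes per element shows that the threadIdx.y == 0 predicate leaves exactly one writer where unpredicated code would write blockDim.y times:

```cpp
#include <cassert>
#include <vector>

// Hypothetical model of the redundant-write problem. The shared-memory
// tile is indexed only by threadIdx.x, so every thread along threadIdx.y
// would store the same element. Returns how many times each element is
// written by a (bdimx x bdimy) block.
std::vector<int> writes_per_element(int bdimx, int bdimy, bool predicate_y0) {
  std::vector<int> count(bdimx, 0);
  for (int ty = 0; ty < bdimy; ++ty)
    for (int tx = 0; tx < bdimx; ++tx)
      if (!predicate_y0 || ty == 0)
        ++count[tx];  // models T2[tx] = ... (index does not depend on ty)
  return count;
}
```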

@jjsjann123
Collaborator

I'll put up an upstream PR once this is merged.

Owner Author

@csarofeen left a comment

LGTM

@naoyam merged commit 1c67154 into 20_8_18_devel Sep 22, 2020
@csarofeen deleted the crazy_example branch June 9, 2021 13:38
jjsjann123 pushed a commit that referenced this pull request May 3, 2022
Summary:
X-link: meta-pytorch/data#368

This PR aims to expose the right data-related API.

There are two more changes made in this PR to convert public API to private API:
`check_lambda_fn` -> `_check_lambda_fn`
`deprecation_warning` -> `_deprecation_warning`

Pull Request resolved: pytorch#76143

Reviewed By: albanD, NivekT

Differential Revision: D35798311

Pulled By: ejguan

fbshipit-source-id: b13fded5c88a533c706702fb2070c918c839dca4
(cherry picked from commit 0b534b8)