
Add _syncthreads for Write-After-Read Race#383

Merged
rdspring1 merged 8 commits into 20_8_18_devel from rds_smem_war on Sep 18, 2020

Conversation

@rdspring1 (Collaborator) commented Sep 16, 2020

Fixes #380

  • Basic write-after-read (WAR) check that adds __syncthreads to the end of for-loops.
  • Enable result validation of the GEMM test.

Goal: insert a sync at the end of for-loops to prevent a write-after-read (WAR) race condition. A WAR race occurs when the next iteration of the loop overwrites a shared-memory value before a previous operation has finished reading it.

WAR Race Check:

  • Track all shared-memory TVs written before the first sync.
  • Track all shared-memory TVs read after the last sync.
  • If the intersection is non-empty, there is a WAR race condition.
  • Recursively check each nested for-loop.

@naoyam (Collaborator) commented Sep 16, 2020

In the relevant tests, can you also check whether a required syncthreads is actually inserted at its proper location? Result validation may not be able to expose potential race conditions.

@rdspring1 (Collaborator, Author) commented Sep 16, 2020

I thought about that but didn't come up with a good solution, e.g., checking the entire kernel string.
Do you have any ideas?

@naoyam (Collaborator) commented Sep 16, 2020

How about just traversing KIR to find a relevant ForLoop node and check whether it ends with a sync node? Finding relevant ForLoops may not be trivial for complex fusions, though.

@rdspring1 (Collaborator, Author) commented Sep 16, 2020

Here are two other options:

  1. Instead of traversing KIR, flatten KIR and check if sync exists at position x
  2. Check if substring at position x is __syncthreads

Do we have access to the KIR from the test? Option 2 is the easiest to implement.

@naoyam (Collaborator) commented Sep 16, 2020

I think (read-only) accesses to KIR should be allowed for verification.

@naoyam (Collaborator) commented Sep 17, 2020

last_op_sync_ seems to be used to suppress inserting syncthreads when the last operation is also syncthreads. It seems that the "last operation" here can mean the last operation in a nested loop body. Something like this:

for i in X
  Write to SMEM
  __syncthreads()
  Read from SMEM
  for j in Y
    do something
    __syncthreads()
  end for
  // Insert __syncthreads() here?
end for

If I read the code correctly, in the above case the syncthreads at the commented position is NOT inserted. However, since the trip count of the inner loop can be 0, the nested syncthreads may never execute, so we need that final syncthreads.

Am I missing anything?

@rdspring1 (Collaborator, Author) commented Sep 17, 2020

Don't all for-loops iterate from 0 to some positive integer? If we enforce this constraint, then every for-loop is entered.
Since we perform this pass at compile time, we have to assume either that all for-loops are entered or that none are.

@naoyam (Collaborator) commented Sep 17, 2020

> Don't all for-loops iterate from 0 to some positive integer? If we enforce this constraint, then every for-loop is entered.
> Since we perform this pass at compile time, we have to assume either that all for-loops are entered or that none are.

Are we sure all for loops must have a trip count greater than 0? That may be the case, actually, but not 100% sure.

@rdspring1 (Collaborator, Author) commented Sep 17, 2020

I ran all the cpp unit tests with these assertions inside the IterDomain constructor, and they passed.

  TORCH_INTERNAL_ASSERT(
      _start->isZeroInt(),
      "Cannot create an iter domain with a start that is not zero but received ",
      _start,
      " .");

  TORCH_INTERNAL_ASSERT(
      !_extent->isZeroInt(),
      "Cannot create an iter domain with an extent that is zero but received ",
      _extent,
      " .");

@naoyam (Collaborator) commented Sep 18, 2020

Thanks for adding the check for inserted syncthreads. Looks very good!

@rdspring1 rdspring1 merged commit 944dad5 into 20_8_18_devel Sep 18, 2020
@rdspring1 rdspring1 deleted the rds_smem_war branch September 18, 2020 20:50
naoyam pushed a commit that referenced this pull request Sep 21, 2020
…e finish reading before writing."

This reverts commit dffaa76.

Revert this in favor of #383
naoyam pushed a commit that referenced this pull request Sep 22, 2020
* Basic Write-After-Read (WAR) check to add __syncthreads to end of for-loop

* Enable Tiled GEMM example

* Check that IterDomain iterates from zero to some positive integer

Co-authored-by: Ryan Spring <rspring@nvidia.com>
naoyam pushed a commit that referenced this pull request Sep 22, 2020
* Get a crazy test example working.

* Change problem size and tile size, still an issue with N > 32.

* Add sync threads in loops that read from smem, to make sure we finish reading before writing.

* Predicate off threads bound to a broadcast dim of an output when its in shared memory.

* Predicate smem tiling writing based on broadcasted dims in consumer.

* Cleanup example a bit.

* Revert "Add sync threads in loops that read from smem, to make sure we finish reading before writing."

This reverts commit dffaa76.

Revert this in favor of #383

* Add _syncthreads for Write-After-Read Race (#383)

* Basic Write-After-Read (WAR) check to add __syncthreads to end of for-loop

* Enable Tiled GEMM example

* Check that IterDomain iterates from zero to some positive integer

Co-authored-by: Ryan Spring <rspring@nvidia.com>

* Refactor thread predication for writes to smem

Co-authored-by: Naoya Maruyama <nmaruyama@nvidia.com>
Co-authored-by: Ryan Spring <rdspring1@gmail.com>
Co-authored-by: Ryan Spring <rspring@nvidia.com>

Linked issue: Missing _syncthreads