Skip to content

Kernel IR: Splitting CUDA codegen from IrPrinter#379

Merged
tlemo merged 35 commits into20_8_18_develfrom
kernel_ir_part7c
Sep 15, 2020
Merged

Kernel IR: Splitting CUDA codegen from IrPrinter#379
tlemo merged 35 commits into20_8_18_develfrom
kernel_ir_part7c

Conversation

@tlemo
Copy link
Copy Markdown
Collaborator

@tlemo tlemo commented Sep 15, 2020

One of the main goals of having a dedicated kernel IR was separation of concerns: simpler and smaller components which do one thing instead of monolithic implementations.

This PR is a significant step in that direction: the CUDA code generation is now separate from the IrPrinter.

@tlemo tlemo requested review from naoyam and rdspring1 September 15, 2020 17:40
Comment thread test/cpp/jit/test_gpu.cpp
Comment on lines +1121 to +1126
if ((((((blockIdx.x * 1) + (1 - 1)) * 128) + threadIdx.x) < T0.size[0])) {
for(size_t i6 = 0; i6 < 1; ++i6) {
T2[i6]
= T0[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)] * T1[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)];
T3[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)]
= T2[i6] * T0[((((blockIdx.x * 1) + i6) * 128) + threadIdx.x)];
Copy link
Copy Markdown
Collaborator

@rdspring1 rdspring1 Sep 15, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I like collapsing the parenthesis, but I'd prefer to have the operators on separate lines for readability.

T2[] = T0[]
     * T1[]

For reads and writes, I'm fine with this: T2[] = T0[]

Copy link
Copy Markdown
Collaborator Author

@tlemo tlemo Sep 15, 2020

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree. The formatting is not final, and I was planning to revisit it in a follow up PR to keep the changes a bit smaller (and also since we have opportunities to improve the formatting while also simplify the codegen code itself)

But if this is something we don't want to wait for, I'd be happy to update this PR.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm just adding my 2 cents on the kernel formatting. 😄

I also noticed that the for-loop is redundant, since it is only run once.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I also noticed that the for-loop is redundant, since it is only run once.

Yep. That's a completely different beast altogether though. We're not doing any low-level optimizations today (but we could, and probably should - another reason to have a standalone kernel IR)


// Predicate map
// TODO(kir): consider a simpler, kernel IR based version
ThreadPredicateMap predicate_map_;
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIRC, the only reason we need to keep this mapping is for code generation of broadcastOp. A device function, blockBroacast, must be used when broadcasting thread-parallelized dimensions. Whether we should call that function is currently only determined at the code-gen time, but really I think this should be captured when lowering to KIR. One idea may be to have BlockBroadcast KIR node and generate that KIR node instance when FIR is lowered to KIR.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One idea may be to have BlockBroadcast KIR node and generate that KIR node instance when FIR is lowered to KIR.

I really like this idea. In general, I think this is the right pattern for conditional code generation: generate the intended operations during lowering rather than deciding what to print at the last minute.

Copy link
Copy Markdown
Collaborator

@naoyam naoyam left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good. Left a comment on ThreadPredicateMap.

@tlemo tlemo merged commit 385fb96 into 20_8_18_devel Sep 15, 2020
@tlemo tlemo deleted the kernel_ir_part7c branch September 15, 2020 21:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants