Skip to content

[release/2.3] [ROCm] Correct numerical issues in layer norm backwards kernel (#140259)#1766

Merged
pruthvistony merged 1 commit intorelease/2.3from
rel23-picks-jack
Dec 6, 2024
Merged

[release/2.3] [ROCm] Correct numerical issues in layer norm backwards kernel (#140259)#1766
pruthvistony merged 1 commit intorelease/2.3from
rel23-picks-jack

Conversation

@jataylo
Copy link
Copy Markdown
Collaborator

@jataylo jataylo commented Dec 4, 2024

It was raised that the backwards layer norm on AMD was slightly off the accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel cuLoadWriteStridedInputs which processes strided input and accumulates the partial gradients into shared memory.

In this kernel (pytorch#87635) we truncated mean and rstd from T_ACC type to T which causes numerical issues in the warp buffers created in this kernel. This PR will use the correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)

…ch#140259)

It was raised that the backwards layer norm on AMD was slightly off the accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which processes strided input and accumulates the partial gradients into shared memory.

In this kernel (pytorch#87635) we truncated `mean` and `rstd` from T_ACC type to T which causes numerical issues in the warp buffers created in this kernel. This PR will use the correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)
@rocm-repo-management-api
Copy link
Copy Markdown

Jenkins build for 376941ea81baea6726add3a82a7e78144ed2bebe commit finished as FAILURE
Links: Blue Ocean view / Build artifacts

@pruthvistony pruthvistony merged commit a7b07f9 into release/2.3 Dec 6, 2024
@pruthvistony pruthvistony deleted the rel23-picks-jack branch December 6, 2024 05:57
@ROCm ROCm deleted a comment from rocm-mici Dec 13, 2024
@ROCm ROCm deleted a comment from rocm-mici Dec 13, 2024
@ROCm ROCm deleted a comment from rocm-mici Dec 13, 2024
rocm-mici pushed a commit that referenced this pull request Dec 13, 2024
… kernel (pytorch#140259) (#1766)

It was raised that the backwards layer norm on AMD was slightly off the
accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which
processes strided input and accumulates the partial gradients into
shared memory.

In this kernel (pytorch#87635) we
truncated `mean` and `rstd` from T_ACC type to T which causes numerical
issues in the warp buffers created in this kernel. This PR will use the
correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so
this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)
@ROCm ROCm deleted a comment from rocm-mici Dec 13, 2024
rocm-mici pushed a commit that referenced this pull request Dec 13, 2024
… kernel (pytorch#140259) (#1766)

It was raised that the backwards layer norm on AMD was slightly off the
accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which
processes strided input and accumulates the partial gradients into
shared memory.

In this kernel (pytorch#87635) we
truncated `mean` and `rstd` from T_ACC type to T which causes numerical
issues in the warp buffers created in this kernel. This PR will use the
correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so
this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)
@ROCm ROCm deleted a comment from rocm-mici Dec 13, 2024
rocm-mici pushed a commit that referenced this pull request Dec 13, 2024
… kernel (pytorch#140259) (#1766)

It was raised that the backwards layer norm on AMD was slightly off the
accuracy of the equivalent NVIDIA implementation.

On AMD we call into a helper kernel `cuLoadWriteStridedInputs` which
processes strided input and accumulates the partial gradients into
shared memory.

In this kernel (pytorch#87635) we
truncated `mean` and `rstd` from T_ACC type to T which causes numerical
issues in the warp buffers created in this kernel. This PR will use the
correct accumulator type for mean and rstd.

Note: Only AMD call into this call stack for backwards layer norm, so
this was not an issue for NV.

Pull Request resolved: pytorch#140259
Approved by: https://github.com/jianyuh

(cherry picked from commit 001f736)
@ROCm ROCm deleted a comment from rocm-mici Dec 13, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants