[BugFix] Call synchronization when using the td.to("cpu") operation on third-party devices to avoid potential precision issues #1425
vmoens merged 2 commits into pytorch:main from ji-huazhong:fix
vmoens left a comment:
Thanks for this PR!
Can you further explain why this is needed and what problem is solved?
Thanks!
Hi @vmoens,

We have recently been using VERL to post-train our model on Ascend NPUs. We observed that as training progresses, the gradient norm (grad norm) consistently diverges. In contrast, all metrics converge well when training on NVIDIA GPUs. We found that adding synchronization (torch.npu.synchronize) at the connection points between the different components of the RL workflow resolved this issue, and it no longer recurs.

In VERL, a single controller distributes data (organized as a TensorDict) to various workers (each associated with a device) via Ray for computation. After the computation is completed,

Through further investigation, we discovered that the

Lines 14130 to 14136 in 5e78151
Lines 14260 to 14268 in 5e78151

However, this guarantee only covers CUDA and MPS devices. PyTorch also supports third-party hardware such as Intel XPU and Ascend NPU, and this PR aims to extend this guarantee to cover these third-party hardware platforms as well.
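For readers following along, here is a minimal sketch of the kind of device-agnostic synchronization described above. It is not the code from this PR. The helper name `synchronize_source_device` and the `backend_root` parameter are hypothetical, introduced only for illustration. The sketch assumes each accelerator exposes a `synchronize()` function under a namespace named after its device type (as `torch.cuda`, `torch.mps`, `torch.xpu`, and `torch.npu` via torch_npu do), and it uses a stand-in namespace so it can run without an accelerator.

```python
from types import SimpleNamespace


def synchronize_source_device(device_type, backend_root=None):
    """Block until pending work on `device_type` finishes, if possible.

    Hypothetical helper, not the PR's implementation.
    `backend_root` is the module that holds per-device namespaces;
    it defaults to `torch` (assumption: torch.<device_type>.synchronize exists).
    Returns True if a synchronize() call was made, False otherwise.
    """
    if device_type == "cpu":
        # Nothing asynchronous to wait for on the host side.
        return False
    if backend_root is None:
        import torch  # imported lazily so the sketch runs without torch installed

        backend_root = torch
    backend = getattr(backend_root, device_type, None)
    sync = getattr(backend, "synchronize", None)
    if sync is None:
        # Unknown backend or no synchronize(): nothing we can safely call.
        return False
    sync()
    return True


# Stand-in for an accelerator backend, used only to exercise the sketch.
calls = []
fake_root = SimpleNamespace(
    npu=SimpleNamespace(synchronize=lambda: calls.append("npu"))
)
synced = synchronize_source_device("npu", fake_root)
```

In real use, one would call such a helper with the source device type right after an asynchronous device-to-CPU copy and before reading the CPU tensors, so that the host never observes partially written data.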
Great to see this PR merged! When will the next version containing this feature be released? @vmoens
Hoping to release this by end of week but I have a bunch of CI issues...
Description
Following up on the comment, this PR addresses the precision issue when moving a tensordict from a third-party device to CPU.
cc @vmoens
Motivation and Context
Why is this change required? What problem does it solve?
If it fixes an open issue, please link to the issue here.
You can use the syntax close #15213 if this solves the issue #15213
Types of changes
What types of changes does your code introduce? Remove all that do not apply:
Checklist
Go over all the following points, and put an x in all the boxes that apply. If you are unsure about any of these, don't hesitate to ask. We are here to help!