Workaround for performance bug in PyTorch with subclassed tensors #3683

Merged
jph00 merged 3 commits into fastai:master from warner-benjamin:subclass_speed_fix
Jun 10, 2022

Conversation

@warner-benjamin warner-benjamin commented Jun 10, 2022

Collaborator

Resolves #3682 by adding a CastToTensor callback as a workaround for a performance bug in PyTorch with subclassed tensors.

Fixes 4cff258 by testing whether learn.xb and learn.yb are tuples and applying the cast to Tensor if they are. (Unless I made a mistake in testing, b[:i] and b[i:] are tuples.)
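
A minimal sketch of the tuple-aware cast described above, assuming plain PyTorch only. `MyTensor` and `cast_to_tensor` are hypothetical names chosen for illustration; this is not the code merged in this PR, and fastai's own helper may differ.

```python
import torch
from torch import Tensor

class MyTensor(Tensor):
    """Hypothetical stand-in for a subclassed tensor type."""

def cast_to_tensor(x):
    """Cast tensor subclasses to plain Tensor, recursing into tuples."""
    if isinstance(x, tuple):
        return tuple(cast_to_tensor(o) for o in x)
    # as_subclass returns a view of the same data as the requested class
    return x.as_subclass(Tensor) if isinstance(x, Tensor) else x

# xb and yb are built as tuples here, mirroring the b[:i] / b[i:] observation
xb = (torch.zeros(2, 3).as_subclass(MyTensor),)
yb = (torch.zeros(2).as_subclass(MyTensor),)
xb_cast, yb_cast = cast_to_tensor(xb), cast_to_tensor(yb)
```

Non-tensor items pass through unchanged, so a batch that mixes tensors with other objects would not raise under this sketch.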

Unlike 4cff258, this PR adds the workaround as a callback, so callbacks that depend on the input tensor type can still use it before the input is cast to Tensor for the training performance increase.

It also allows turning off the Tensor casting by removing the callback, should it break a workflow. Anyone who does this is encouraged to implement their own custom callback that casts to Tensor, to keep the free training performance increase.

Currently CastToTensor.order is right before MixedPrecision.
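
To illustrate why running the cast as an ordered callback matters, here is a dependency-free sketch. It does not use fastai's Callback class: `Batch`, `TensorImage`, `TypeLogger`, `CastToBase`, and `run_before_batch` are hypothetical stand-ins, and the order values are illustrative only.

```python
class Batch:
    """Hypothetical holder for the current input, standing in for a Learner."""
    def __init__(self, x):
        self.x = x

class TensorImage(list):
    """Stand-in for a subclassed tensor; `list` plays the role of Tensor."""

class TypeLogger:
    order = 0  # runs early, so it still sees the subclass type
    def __init__(self):
        self.seen = None
    def before_batch(self, batch):
        self.seen = type(batch.x).__name__

class CastToBase:
    order = 9  # illustrative value: runs after TypeLogger
    def before_batch(self, batch):
        batch.x = list(batch.x)  # drop the subclass, keep the data

def run_before_batch(callbacks, batch):
    # Callbacks run in ascending `order`, regardless of list position
    for cb in sorted(callbacks, key=lambda cb: cb.order):
        cb.before_batch(batch)

logger = TypeLogger()
batch = Batch(TensorImage([1, 2, 3]))
run_before_batch([CastToBase(), logger], batch)

# Removing the cast callback leaves the subclass type in place
uncast = Batch(TensorImage([4, 5]))
run_before_batch([TypeLogger()], uncast)
```

In this sketch the type-dependent callback records the subclass before the cast runs, and leaving the cast callback out of the list disables the behavior entirely.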

@warner-benjamin warner-benjamin requested a review from jph00 as a code owner June 10, 2022 07:22
@warner-benjamin

Collaborator Author

I should have fixed the sync error, but if a nbdev _all_ isn't on one line, it won't be added to the module's __all__.

jph00 commented Jun 10, 2022

Member

Much better - thanks!

@jph00 jph00 merged commit 94edfc5 into fastai:master Jun 10, 2022
@jph00 jph00 added the bug label Jun 10, 2022
@warner-benjamin warner-benjamin deleted the subclass_speed_fix branch October 3, 2022 05:02
Development

Successfully merging this pull request may close these issues.

workaround pytorch subclass performance bug

2 participants