Concatenate directly into shared memory when constructing batches#1323
Merged
soumith merged 1 commit intopytorch:masterfrom Apr 22, 2017
Merged
Concatenate directly into shared memory when constructing batches#1323soumith merged 1 commit intopytorch:masterfrom
soumith merged 1 commit intopytorch:masterfrom
Conversation
This saves an extra memory copy, which speeds up data loading a bit (5-10% with accimage). As part of this change: * torch.cat accepts keyword argument out * sepcifiying out=None is treated like not specifying out
apaszke
approved these changes
Apr 21, 2017
| if _use_shared_memory: | ||
| # If we're in a background process, concatenate directly into a | ||
| # shared memory tensor to avoid an extra copy | ||
| numel = sum([x.numel() for x in batch]) |
This comment was marked as off-topic.
This comment was marked as off-topic.
Sorry, something went wrong.
Contributor
|
seems to cause error TypeError Traceback (most recent call last)
<ipython-input-1-ce3d61609aec> in <module>()
3 b=t.Tensor(2,4)
4 c=t.Tensor(2,8)
----> 5 t.cat(a,b,out=c)
TypeError: cat received an invalid combination of arguments - got (torch.FloatTensor, torch.FloatTensor, out=torch.FloatTensor), but expected one of:
* (sequence[torch.FloatTensor] seq)
* (sequence[torch.FloatTensor] seq, int dim)also see https://discuss.pytorch.org/t/cat-got-an-unexpected-keyword-argument-out/2151 |
Member
Author
|
cat expects a sequence and a dimension : torch.cat([a, b], dim, out=c)
…On Sat, Apr 22, 2017 at 10:34 AM 陈云 ***@***.***> wrote:
seems to cause error
import torch as t
a=t.Tensor(2,4)
b=t.Tensor(2,4)
c=t.Tensor(2,8)t.cat(a,b,out=c)
TypeError Traceback (most recent call last)<ipython-input-1-ce3d61609aec> in <module>()
3 b=t.Tensor(2,4)
4 c=t.Tensor(2,8)----> 5 t.cat(a,b,out=c)
TypeError: cat received an invalid combination of arguments - got (torch.FloatTensor, torch.FloatTensor, out=torch.FloatTensor), but expected one of:
* (sequence[torch.FloatTensor] seq)
* (sequence[torch.FloatTensor] seq, int dim)
also see
https://discuss.pytorch.org/t/cat-got-an-unexpected-keyword-argument-out/2151
—
You are receiving this because you authored the thread.
Reply to this email directly, view it on GitHub
<#1323 (comment)>, or mute
the thread
<https://github.com/notifications/unsubscribe-auth/AAoB-slsNRtQY5CF-Ah07662EdwMALfNks5ryg_ngaJpZM4NEuC8>
.
|
bunelr
added a commit
to bunelr/pytorch
that referenced
this pull request
May 13, 2017
Reflect the changes of pytorch#1323
Jiaming-Liu
pushed a commit
to Jiaming-Liu/pytorch
that referenced
this pull request
May 18, 2017
Reflect the changes of pytorch#1323
Contributor
|
@colesbury Is there a reason why numpy arrays in pytorch/torch/utils/data/dataloader.py Line 218 in e8754ee to return default_collate([torch.from_numpy(b) for b in batch]) |
Member
Author
|
@boeddeker that seems fine. Can you send a PR with the change? |
Contributor
|
ok, I opened a PR |
eqy
pushed a commit
to eqy/pytorch
that referenced
this pull request
Jan 20, 2022
scotts
added a commit
to scotts/pytorch
that referenced
this pull request
Mar 31, 2026
8b42d4c..185fe9c includes the following commits: 185fe9c Expose occupany limiting factors (pytorch#1330) 0c8ede0 remove the rocprofiler early exit hack (pytorch#1329) 4826a43 Remove duplicate test ignore (pytorch#1328) 37fada9 Ensure that async doesn't loop while sync is active (pytorch#1327) 628e1d0 Add host_name to OSS Kineto trace metadata via gethostname() (pytorch#1323) 9d7373b Revert D97166802 (pytorch#1326) 3a61657 Fix Lingering INT32 Overflow (pytorch#1324) 50a0085 Re-enabled some hardcoded tests (pytorch#1321) e19dd92 Expose occupany limiting factors (pytorch#1322) Authored with Claude.
laurentdupin
pushed a commit
to laurentdupin/pytorch
that referenced
this pull request
Apr 24, 2026
… numpy (pytorch#14534) Summary: Since pytorch#1323 tensors are shared with shared memory, but this feature is not active for numpy. This PR fix this. Pull Request resolved: pytorch#14534 Differential Revision: D13561649 Pulled By: soumith fbshipit-source-id: b6bc9e99fb91e8b675c2ef131fba9fa11c1647c0
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
This saves an extra memory copy, which speeds up data loading a bit
(5-10% with accimage).
As part of this change: